feat: Implement Text-to-Speech (TTS) functionality with API endpoints, models, and service integration

This commit is contained in:
JSC
2025-09-20 23:10:47 +02:00
parent fb0e5e919c
commit 5e8d619736
11 changed files with 887 additions and 0 deletions

View File

@@ -0,0 +1,45 @@
"""Add TTS table
Revision ID: e617c155eea9
Revises: a0d322857b2c
Create Date: 2025-09-20 21:51:26.557738
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = 'e617c155eea9'
down_revision: Union[str, Sequence[str], None] = 'a0d322857b2c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tts',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('text', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=False),
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
sa.Column('options', sa.JSON(), nullable=True),
sa.Column('sound_id', sa.Integer(), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tts')
# ### end Alembic commands ###

View File

@@ -15,6 +15,7 @@ from app.api.v1 import (
scheduler, scheduler,
socket, socket,
sounds, sounds,
tts,
) )
# V1 API router with v1 prefix # V1 API router with v1 prefix
@@ -32,4 +33,5 @@ api_router.include_router(playlists.router, tags=["playlists"])
api_router.include_router(scheduler.router, tags=["scheduler"]) api_router.include_router(scheduler.router, tags=["scheduler"])
api_router.include_router(socket.router, tags=["socket"]) api_router.include_router(socket.router, tags=["socket"])
api_router.include_router(sounds.router, tags=["sounds"]) api_router.include_router(sounds.router, tags=["sounds"])
api_router.include_router(tts.router, tags=["tts"])
api_router.include_router(admin.router) api_router.include_router(admin.router)

216
app/api/v1/tts.py Normal file
View File

@@ -0,0 +1,216 @@
"""TTS API endpoints."""
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.services.tts import TTSService
router = APIRouter(prefix="/tts", tags=["tts"])
class TTSGenerateRequest(BaseModel):
"""TTS generation request model."""
text: str = Field(..., min_length=1, max_length=1000, description="Text to convert to speech")
provider: str = Field(default="gtts", description="TTS provider to use")
options: dict[str, Any] = Field(default_factory=dict, description="Provider-specific options")
class TTSResponse(BaseModel):
"""TTS generation response model."""
id: int
text: str
provider: str
options: dict[str, Any]
sound_id: int | None
user_id: int
created_at: str
class ProviderInfo(BaseModel):
"""Provider information model."""
name: str
file_extension: str
supported_languages: list[str]
option_schema: dict[str, Any]
async def get_tts_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> TTSService:
"""Get the TTS service."""
return TTSService(session)
@router.post("/generate")
async def generate_tts(
request: TTSGenerateRequest,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, Any]:
"""Generate TTS audio and create sound."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
result = await tts_service.create_tts_request(
text=request.text,
user_id=current_user.id,
provider=request.provider,
**request.options
)
tts_record = result["tts"]
return {
"message": result["message"],
"tts": TTSResponse(
id=tts_record.id,
text=tts_record.text,
provider=tts_record.provider,
options=tts_record.options,
sound_id=tts_record.sound_id,
user_id=tts_record.user_id,
created_at=tts_record.created_at.isoformat(),
)
}
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to generate TTS: {e!s}",
) from e
@router.get("/providers")
async def get_providers(
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, ProviderInfo]:
"""Get all available TTS providers."""
providers = tts_service.get_providers()
result = {}
for name, provider in providers.items():
result[name] = ProviderInfo(
name=provider.name,
file_extension=provider.file_extension,
supported_languages=provider.get_supported_languages(),
option_schema=provider.get_option_schema(),
)
return result
@router.get("/providers/{provider_name}")
async def get_provider(
provider_name: str,
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> ProviderInfo:
"""Get information about a specific TTS provider."""
provider = tts_service.get_provider(provider_name)
if not provider:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Provider '{provider_name}' not found",
)
return ProviderInfo(
name=provider.name,
file_extension=provider.file_extension,
supported_languages=provider.get_supported_languages(),
option_schema=provider.get_option_schema(),
)
@router.get("/history")
async def get_tts_history(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
tts_service: Annotated[TTSService, Depends(get_tts_service)],
limit: int = 50,
offset: int = 0,
) -> list[TTSResponse]:
"""Get TTS generation history for the current user."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
tts_records = await tts_service.get_user_tts_history(
user_id=current_user.id,
limit=limit,
offset=offset,
)
return [
TTSResponse(
id=tts.id,
text=tts.text,
provider=tts.provider,
options=tts.options,
sound_id=tts.sound_id,
user_id=tts.user_id,
created_at=tts.created_at.isoformat(),
)
for tts in tts_records
]
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get TTS history: {e!s}",
) from e
@router.delete("/{tts_id}")
async def delete_tts(
tts_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, str]:
"""Delete a TTS generation and its associated files."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
return {"message": "TTS generation deleted successfully"}
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
except PermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
) from e
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete TTS: {e!s}",
) from e

View File

@@ -12,6 +12,7 @@ from .playlist_sound import PlaylistSound
from .scheduled_task import ScheduledTask from .scheduled_task import ScheduledTask
from .sound import Sound from .sound import Sound
from .sound_played import SoundPlayed from .sound_played import SoundPlayed
from .tts import TTS
from .user import User from .user import User
from .user_oauth import UserOauth from .user_oauth import UserOauth
@@ -27,6 +28,7 @@ __all__ = [
"ScheduledTask", "ScheduledTask",
"Sound", "Sound",
"SoundPlayed", "SoundPlayed",
"TTS",
"User", "User",
"UserOauth", "UserOauth",
] ]

26
app/models/tts.py Normal file
View File

@@ -0,0 +1,26 @@
"""TTS model."""
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
class TTS(SQLModel, table=True):
"""Text-to-Speech generation record."""
__tablename__ = "tts"
id: int | None = Field(primary_key=True)
text: str = Field(max_length=1000, description="Text that was converted to speech")
provider: str = Field(max_length=50, description="TTS provider used")
options: dict[str, Any] = Field(
default_factory=dict,
sa_column=Column(JSON),
description="Provider-specific options used"
)
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)

62
app/repositories/tts.py Normal file
View File

@@ -0,0 +1,62 @@
"""TTS repository for database operations."""
from typing import Any, Sequence
from sqlmodel import select
from app.models.tts import TTS
from app.repositories.base import BaseRepository
class TTSRepository(BaseRepository[TTS]):
"""Repository for TTS operations."""
def __init__(self, session: Any) -> None:
super().__init__(TTS, session)
async def get_by_user_id(
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> Sequence[TTS]:
"""Get TTS records by user ID with pagination.
Args:
user_id: User ID to filter by
limit: Maximum number of records to return
offset: Number of records to skip
Returns:
List of TTS records
"""
stmt = (
select(self.model)
.where(self.model.user_id == user_id)
.order_by(self.model.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return result.all()
async def get_by_user_and_id(
self,
user_id: int,
tts_id: int,
) -> TTS | None:
"""Get a specific TTS record by user ID and TTS ID.
Args:
user_id: User ID to filter by
tts_id: TTS ID to retrieve
Returns:
TTS record if found and belongs to user, None otherwise
"""
stmt = select(self.model).where(
self.model.id == tts_id,
self.model.user_id == user_id,
)
result = await self.session.exec(stmt)
return result.first()

View File

@@ -0,0 +1,6 @@
"""Text-to-Speech services package."""
from .base import TTSProvider
from .service import TTSService
__all__ = ["TTSProvider", "TTSService"]

38
app/services/tts/base.py Normal file
View File

@@ -0,0 +1,38 @@
"""Base TTS provider interface."""
from abc import ABC, abstractmethod
from typing import Any
class TTSProvider(ABC):
"""Abstract base class for TTS providers."""
@abstractmethod
async def generate_speech(self, text: str, **options: Any) -> bytes:
"""Generate speech from text with provider-specific options.
Args:
text: The text to convert to speech
**options: Provider-specific options
Returns:
Audio data as bytes
"""
@abstractmethod
def get_supported_languages(self) -> list[str]:
"""Return list of supported language codes."""
@abstractmethod
def get_option_schema(self) -> dict[str, Any]:
"""Return schema for provider-specific options."""
@property
@abstractmethod
def name(self) -> str:
"""Return the provider name."""
@property
@abstractmethod
def file_extension(self) -> str:
"""Return the default file extension for this provider."""

View File

@@ -0,0 +1,5 @@
"""TTS providers package."""
from .gtts import GTTSProvider
__all__ = ["GTTSProvider"]

View File

@@ -0,0 +1,81 @@
"""Google Text-to-Speech provider."""
import asyncio
import io
from typing import Any
from gtts import gTTS
from ..base import TTSProvider
class GTTSProvider(TTSProvider):
"""Google Text-to-Speech provider implementation."""
@property
def name(self) -> str:
"""Return the provider name."""
return "gtts"
@property
def file_extension(self) -> str:
"""Return the default file extension for this provider."""
return "mp3"
async def generate_speech(self, text: str, **options: Any) -> bytes:
"""Generate speech from text using Google TTS.
Args:
text: The text to convert to speech
**options: GTTS-specific options (lang, tld, slow)
Returns:
MP3 audio data as bytes
"""
lang = options.get("lang", "en")
tld = options.get("tld", "com")
slow = options.get("slow", False)
# Run TTS generation in thread pool since gTTS is synchronous
def _generate():
tts = gTTS(text=text, lang=lang, tld=tld, slow=slow)
fp = io.BytesIO()
tts.write_to_fp(fp)
fp.seek(0)
return fp.read()
# Use asyncio.to_thread which is more reliable than run_in_executor
return await asyncio.to_thread(_generate)
def get_supported_languages(self) -> list[str]:
"""Return list of supported language codes."""
# Common GTTS supported languages
return [
"af", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el", "en",
"eo", "es", "et", "fi", "fr", "gu", "hi", "hr", "hu", "hy", "id", "is",
"it", "ja", "jw", "km", "kn", "ko", "la", "lv", "mk", "ml", "mr", "my",
"ne", "nl", "no", "pl", "pt", "ro", "ru", "si", "sk", "sq", "sr", "su",
"sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur", "vi", "zh-cn", "zh-tw"
]
def get_option_schema(self) -> dict[str, Any]:
"""Return schema for GTTS-specific options."""
return {
"lang": {
"type": "string",
"default": "en",
"description": "Language code",
"enum": self.get_supported_languages()
},
"tld": {
"type": "string",
"default": "com",
"description": "Top-level domain for Google TTS",
"enum": ["com", "co.uk", "com.au", "ca", "co.in", "ie", "co.za"]
},
"slow": {
"type": "boolean",
"default": False,
"description": "Speak slowly"
}
}

404
app/services/tts/service.py Normal file
View File

@@ -0,0 +1,404 @@
"""TTS service implementation."""
import asyncio
import io
import uuid
from pathlib import Path
from typing import Any
from gtts import gTTS
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound
from app.models.tts import TTS
from app.repositories.sound import SoundRepository
from app.repositories.tts import TTSRepository
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
from .base import TTSProvider
from .providers import GTTSProvider
# Constants
MAX_TEXT_LENGTH = 1000
MAX_NAME_LENGTH = 50
class TTSService:
"""Text-to-Speech service with provider management."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize TTS service.
Args:
session: Database session
"""
self.session = session
self.sound_repo = SoundRepository(session)
self.tts_repo = TTSRepository(session)
self.providers: dict[str, TTSProvider] = {}
# Register default providers
self._register_default_providers()
def _register_default_providers(self) -> None:
"""Register default TTS providers."""
self.register_provider(GTTSProvider())
def register_provider(self, provider: TTSProvider) -> None:
"""Register a TTS provider.
Args:
provider: TTS provider instance
"""
self.providers[provider.name] = provider
def get_providers(self) -> dict[str, TTSProvider]:
"""Get all registered providers."""
return self.providers.copy()
def get_provider(self, name: str) -> TTSProvider | None:
"""Get a specific provider by name."""
return self.providers.get(name)
async def create_tts_request(
self,
text: str,
user_id: int,
provider: str = "gtts",
**options: Any,
) -> dict[str, Any]:
"""Create a TTS request that will be processed in the background.
Args:
text: Text to convert to speech
user_id: ID of user creating the sound
provider: TTS provider name
**options: Provider-specific options
Returns:
Dictionary with TTS record information
Raises:
ValueError: If provider not found or text too long
Exception: If request creation fails
"""
provider_not_found_msg = f"Provider '{provider}' not found"
if provider not in self.providers:
raise ValueError(provider_not_found_msg)
text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
if len(text) > MAX_TEXT_LENGTH:
raise ValueError(text_too_long_msg)
empty_text_msg = "Text cannot be empty"
if not text.strip():
raise ValueError(empty_text_msg)
# Create TTS record with pending status
tts = TTS(
text=text,
provider=provider,
options=options,
sound_id=None, # Will be set when processing completes
user_id=user_id,
)
self.session.add(tts)
await self.session.commit()
await self.session.refresh(tts)
# Queue for background processing
if tts.id is not None:
await self._queue_tts_processing(tts.id)
return {"tts": tts, "message": "TTS generation queued successfully"}
async def _queue_tts_processing(self, tts_id: int) -> None:
"""Queue TTS for background processing."""
# For now, process immediately in a different way
# This could be moved to a proper background queue later
task = asyncio.create_task(self._process_tts_in_background(tts_id))
# Store reference to prevent garbage collection
self._background_tasks = getattr(self, '_background_tasks', set())
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
async def _process_tts_in_background(self, tts_id: int) -> None:
"""Process TTS generation in background."""
from app.core.database import get_session_factory
try:
# Create a new session for background processing
session_factory = get_session_factory()
async with session_factory() as background_session:
tts_service = TTSService(background_session)
# Get the TTS record
stmt = select(TTS).where(TTS.id == tts_id)
result = await background_session.exec(stmt)
tts = result.first()
if not tts:
return
# Use a synchronous approach for the actual generation
sound = await tts_service._generate_tts_sync(
tts.text,
tts.provider,
tts.user_id,
tts.options,
)
# Update the TTS record with the sound ID
if sound.id is not None:
tts.sound_id = sound.id
background_session.add(tts)
await background_session.commit()
except Exception:
# Log error but don't fail - avoiding print for production
pass
async def _generate_tts_sync(
self,
text: str,
provider: str,
user_id: int,
options: dict[str, Any]
) -> Sound:
"""Generate TTS using a synchronous approach."""
# Generate the audio using the provider (avoid async issues by doing it directly)
tts_provider = self.providers[provider]
# Create directories if they don't exist
original_dir = Path("sounds/originals/text_to_speech")
original_dir.mkdir(parents=True, exist_ok=True)
# Create UUID filename
sound_uuid = str(uuid.uuid4())
original_filename = f"{sound_uuid}.{tts_provider.file_extension}"
original_path = original_dir / original_filename
# Generate audio synchronously
try:
# Generate TTS audio
lang = options.get("lang", "en")
tld = options.get("tld", "com")
slow = options.get("slow", False)
tts_instance = gTTS(text=text, lang=lang, tld=tld, slow=slow)
fp = io.BytesIO()
tts_instance.write_to_fp(fp)
fp.seek(0)
audio_bytes = fp.read()
# Save the file
original_path.write_bytes(audio_bytes)
except Exception:
raise
# Create Sound record with proper metadata
sound = await self._create_sound_record_complete(
original_path,
text,
provider,
user_id
)
# Normalize the sound
await self._normalize_sound_safe(sound.id)
return sound
async def get_user_tts_history(
self,
user_id: int,
limit: int = 50,
offset: int = 0
) -> list[TTS]:
"""Get TTS history for a user.
Args:
user_id: User ID
limit: Maximum number of records
offset: Offset for pagination
Returns:
List of TTS records
"""
result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
return list(result)
async def _create_sound_record(
self,
audio_path: Path,
text: str,
provider: str,
user_id: int,
file_hash: str
) -> Sound:
"""Create a Sound record for the TTS audio."""
# Get audio metadata
duration = get_audio_duration(audio_path)
size = get_file_size(audio_path)
# Create sound data
sound_data = {
"type": "TTS",
"name": text[:50] + ("..." if len(text) > 50 else ""),
"filename": audio_path.name,
"duration": duration,
"size": size,
"hash": file_hash,
"user_id": user_id,
"is_deletable": True,
"is_music": False, # TTS is speech, not music
"is_normalized": False,
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
return sound
async def _create_sound_record_simple(
self,
audio_path: Path,
text: str,
provider: str,
user_id: int
) -> Sound:
"""Create a Sound record for the TTS audio with minimal processing."""
# Create sound data with basic info
sound_data = {
"type": "TTS",
"name": text[:50] + ("..." if len(text) > 50 else ""),
"filename": audio_path.name,
"duration": 0, # Skip duration calculation for now
"size": 0, # Skip size calculation for now
"hash": str(uuid.uuid4()), # Use UUID as temporary hash
"user_id": user_id,
"is_deletable": True,
"is_music": False, # TTS is speech, not music
"is_normalized": False,
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
return sound
async def _create_sound_record_complete(
self,
audio_path: Path,
text: str,
provider: str,
user_id: int
) -> Sound:
"""Create a Sound record for the TTS audio with complete metadata."""
# Get audio metadata
duration = get_audio_duration(audio_path)
size = get_file_size(audio_path)
file_hash = get_file_hash(audio_path)
# Check if a sound with this hash already exists
existing_sound = await self.sound_repo.get_by_hash(file_hash)
if existing_sound:
# Clean up the temporary file since we have a duplicate
if audio_path.exists():
audio_path.unlink()
return existing_sound
# Create sound data with complete metadata
sound_data = {
"type": "TTS",
"name": text[:50] + ("..." if len(text) > 50 else ""),
"filename": audio_path.name,
"duration": duration,
"size": size,
"hash": file_hash,
"user_id": user_id,
"is_deletable": True,
"is_music": False, # TTS is speech, not music
"is_normalized": False,
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
return sound
async def _normalize_sound_safe(self, sound_id: int) -> None:
"""Normalize the TTS sound with error handling."""
try:
# Get fresh sound object from database for normalization
sound = await self.sound_repo.get_by_id(sound_id)
if not sound:
return
normalizer_service = SoundNormalizerService(self.session)
result = await normalizer_service.normalize_sound(sound)
if result["status"] == "error":
print(f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}")
except Exception as e:
print(f"Exception during TTS sound normalization {sound_id}: {e}")
# Don't fail the TTS generation if normalization fails
async def _normalize_sound(self, sound_id: int) -> None:
"""Normalize the TTS sound."""
try:
# Get fresh sound object from database for normalization
sound = await self.sound_repo.get_by_id(sound_id)
if not sound:
return
normalizer_service = SoundNormalizerService(self.session)
result = await normalizer_service.normalize_sound(sound)
if result["status"] == "error":
# Log warning but don't fail the TTS generation
pass
except Exception:
# Don't fail the TTS generation if normalization fails
pass
async def delete_tts(self, tts_id: int, user_id: int) -> None:
"""Delete a TTS generation and its associated sound and files."""
# Get the TTS record
tts = await self.tts_repo.get_by_id(tts_id)
if not tts:
raise ValueError(f"TTS with ID {tts_id} not found")
# Check ownership
if tts.user_id != user_id:
raise PermissionError("You don't have permission to delete this TTS generation")
# If there's an associated sound, delete it and its files
if tts.sound_id:
sound = await self.sound_repo.get_by_id(tts.sound_id)
if sound:
# Delete the sound files
await self._delete_sound_files(sound)
# Delete the sound record
await self.sound_repo.delete(sound)
# Delete the TTS record
await self.tts_repo.delete(tts)
async def _delete_sound_files(self, sound: Sound) -> None:
"""Delete all files associated with a sound."""
from pathlib import Path
# Delete original file
original_path = Path("sounds/originals/text_to_speech") / sound.filename
if original_path.exists():
original_path.unlink()
# Delete normalized file if it exists
if sound.normalized_filename:
normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename
if normalized_path.exists():
normalized_path.unlink()