From 5e8d619736e9b534a88e7d0d3d06738820f209c0 Mon Sep 17 00:00:00 2001 From: JSC Date: Sat, 20 Sep 2025 23:10:47 +0200 Subject: [PATCH] feat: Implement Text-to-Speech (TTS) functionality with API endpoints, models, and service integration --- .../versions/e617c155eea9_add_tts_table.py | 45 ++ app/api/v1/__init__.py | 2 + app/api/v1/tts.py | 216 ++++++++++ app/models/__init__.py | 2 + app/models/tts.py | 26 ++ app/repositories/tts.py | 62 +++ app/services/tts/__init__.py | 6 + app/services/tts/base.py | 38 ++ app/services/tts/providers/__init__.py | 5 + app/services/tts/providers/gtts.py | 81 ++++ app/services/tts/service.py | 404 ++++++++++++++++++ 11 files changed, 887 insertions(+) create mode 100644 alembic/versions/e617c155eea9_add_tts_table.py create mode 100644 app/api/v1/tts.py create mode 100644 app/models/tts.py create mode 100644 app/repositories/tts.py create mode 100644 app/services/tts/__init__.py create mode 100644 app/services/tts/base.py create mode 100644 app/services/tts/providers/__init__.py create mode 100644 app/services/tts/providers/gtts.py create mode 100644 app/services/tts/service.py diff --git a/alembic/versions/e617c155eea9_add_tts_table.py b/alembic/versions/e617c155eea9_add_tts_table.py new file mode 100644 index 0000000..44e40fd --- /dev/null +++ b/alembic/versions/e617c155eea9_add_tts_table.py @@ -0,0 +1,45 @@ +"""Add TTS table + +Revision ID: e617c155eea9 +Revises: a0d322857b2c +Create Date: 2025-09-20 21:51:26.557738 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = 'e617c155eea9' +down_revision: Union[str, Sequence[str], None] = 'a0d322857b2c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tts', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('text', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=False), + sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False), + sa.Column('options', sa.JSON(), nullable=True), + sa.Column('sound_id', sa.Integer(), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tts') + # ### end Alembic commands ### diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py index 638a549..2bf37bf 100644 --- a/app/api/v1/__init__.py +++ b/app/api/v1/__init__.py @@ -15,6 +15,7 @@ from app.api.v1 import ( scheduler, socket, sounds, + tts, ) # V1 API router with v1 prefix @@ -32,4 +33,5 @@ api_router.include_router(playlists.router, tags=["playlists"]) api_router.include_router(scheduler.router, tags=["scheduler"]) api_router.include_router(socket.router, tags=["socket"]) api_router.include_router(sounds.router, tags=["sounds"]) +api_router.include_router(tts.router, tags=["tts"]) api_router.include_router(admin.router) diff --git a/app/api/v1/tts.py b/app/api/v1/tts.py new file mode 100644 index 0000000..825bef1 --- /dev/null +++ b/app/api/v1/tts.py @@ -0,0 +1,216 @@ +"""TTS API endpoints.""" + +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel, Field +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.database import get_db +from app.core.dependencies import get_current_active_user_flexible +from app.models.user import User +from app.services.tts import TTSService + + +router = APIRouter(prefix="/tts", tags=["tts"]) + + +class TTSGenerateRequest(BaseModel): + """TTS generation request model.""" + + text: str = Field(..., min_length=1, max_length=1000, description="Text to convert to speech") + provider: str = Field(default="gtts", description="TTS provider to use") + options: dict[str, Any] = Field(default_factory=dict, description="Provider-specific options") + + +class TTSResponse(BaseModel): + """TTS generation response model.""" + + id: int + text: str + provider: str + options: dict[str, Any] + sound_id: int | None + user_id: int + created_at: str + + +class ProviderInfo(BaseModel): + """Provider information model.""" + + name: str + file_extension: str + supported_languages: list[str] + option_schema: dict[str, Any] + + +async def get_tts_service( + session: Annotated[AsyncSession, Depends(get_db)], +) -> TTSService: + """Get the TTS service.""" + return TTSService(session) + + +@router.post("/generate") +async def generate_tts( + request: TTSGenerateRequest, + current_user: Annotated[User, Depends(get_current_active_user_flexible)], + tts_service: Annotated[TTSService, Depends(get_tts_service)], +) -> dict[str, Any]: + """Generate TTS audio and create sound.""" + try: + if current_user.id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User ID not available", + ) + + result = await tts_service.create_tts_request( + text=request.text, + user_id=current_user.id, + provider=request.provider, + **request.options + ) + + tts_record = result["tts"] + + return { + "message": result["message"], + "tts": TTSResponse( + id=tts_record.id, + text=tts_record.text, + provider=tts_record.provider, + options=tts_record.options, + sound_id=tts_record.sound_id, + user_id=tts_record.user_id, + created_at=tts_record.created_at.isoformat(), + ) + } + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to generate TTS: {e!s}", + ) from e + + +@router.get("/providers") +async def get_providers( + tts_service: Annotated[TTSService, Depends(get_tts_service)], +) -> dict[str, ProviderInfo]: + """Get all available TTS providers.""" + providers = tts_service.get_providers() + result = {} + + for name, provider in providers.items(): + result[name] = ProviderInfo( + name=provider.name, + file_extension=provider.file_extension, + supported_languages=provider.get_supported_languages(), + option_schema=provider.get_option_schema(), + ) + + return result + + +@router.get("/providers/{provider_name}") +async def get_provider( + provider_name: str, + tts_service: Annotated[TTSService, Depends(get_tts_service)], +) -> ProviderInfo: + """Get information about a specific TTS provider.""" + provider = tts_service.get_provider(provider_name) + + if not provider: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Provider '{provider_name}' not found", + ) + + return ProviderInfo( + name=provider.name, + file_extension=provider.file_extension, + supported_languages=provider.get_supported_languages(), + option_schema=provider.get_option_schema(), + ) + + +@router.get("/history") +async def get_tts_history( + current_user: Annotated[User, Depends(get_current_active_user_flexible)], + tts_service: Annotated[TTSService, Depends(get_tts_service)], + limit: int = 50, + offset: int = 0, +) -> list[TTSResponse]: + """Get TTS generation history for the current user.""" + try: + if current_user.id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User ID not available", + ) + + tts_records = await tts_service.get_user_tts_history( + user_id=current_user.id, + limit=limit, + offset=offset, + ) + + return [ + TTSResponse( + id=tts.id, + text=tts.text, + provider=tts.provider, + options=tts.options, + sound_id=tts.sound_id, + user_id=tts.user_id, + created_at=tts.created_at.isoformat(), + ) + for tts in tts_records + ] + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get TTS history: {e!s}", + ) from e + + +@router.delete("/{tts_id}") +async def delete_tts( + tts_id: int, + current_user: Annotated[User, Depends(get_current_active_user_flexible)], + tts_service: Annotated[TTSService, Depends(get_tts_service)], +) -> dict[str, str]: + """Delete a TTS generation and its associated files.""" + try: + if current_user.id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User ID not available", + ) + + await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id) + + return {"message": "TTS generation deleted successfully"} + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except PermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) from e + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete TTS: {e!s}", + ) from e \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py index 1e3c2c7..e951768 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -12,6 +12,7 @@ from .playlist_sound import PlaylistSound from .scheduled_task import ScheduledTask from .sound import Sound from .sound_played import SoundPlayed +from .tts import TTS from .user import User from .user_oauth import UserOauth @@ -27,6 +28,7 @@ __all__ = [ "ScheduledTask", "Sound", "SoundPlayed", + "TTS", "User", "UserOauth", ] diff --git a/app/models/tts.py b/app/models/tts.py new file mode 100644 index 0000000..3dc1a66 --- /dev/null +++ b/app/models/tts.py @@ -0,0 +1,26 @@ +"""TTS model.""" + +from datetime import datetime +from typing import Any + +from sqlalchemy import JSON, Column +from sqlmodel import Field, SQLModel + + +class TTS(SQLModel, table=True): + """Text-to-Speech generation record.""" + + __tablename__ = "tts" + + id: int | None = Field(primary_key=True) + text: str = Field(max_length=1000, description="Text that was converted to speech") + provider: str = Field(max_length=50, description="TTS provider used") + options: dict[str, Any] = Field( + default_factory=dict, + sa_column=Column(JSON), + description="Provider-specific options used" + ) + sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID") + user_id: int = Field(foreign_key="user.id", description="User who created the TTS") + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) \ No newline at end of file diff --git a/app/repositories/tts.py b/app/repositories/tts.py new file mode 100644 index 0000000..497ffcb --- /dev/null +++ b/app/repositories/tts.py @@ -0,0 +1,62 @@ +"""TTS repository for database operations.""" + +from typing import Any, Sequence + +from sqlmodel import select + +from app.models.tts import TTS +from app.repositories.base import BaseRepository + + +class TTSRepository(BaseRepository[TTS]): + """Repository for TTS operations.""" + + def __init__(self, session: Any) -> None: + super().__init__(TTS, session) + + async def get_by_user_id( + self, + user_id: int, + limit: int = 50, + offset: int = 0, + ) -> Sequence[TTS]: + """Get TTS records by user ID with pagination. + + Args: + user_id: User ID to filter by + limit: Maximum number of records to return + offset: Number of records to skip + + Returns: + List of TTS records + """ + stmt = ( + select(self.model) + .where(self.model.user_id == user_id) + .order_by(self.model.created_at.desc()) + .limit(limit) + .offset(offset) + ) + result = await self.session.exec(stmt) + return result.all() + + async def get_by_user_and_id( + self, + user_id: int, + tts_id: int, + ) -> TTS | None: + """Get a specific TTS record by user ID and TTS ID. + + Args: + user_id: User ID to filter by + tts_id: TTS ID to retrieve + + Returns: + TTS record if found and belongs to user, None otherwise + """ + stmt = select(self.model).where( + self.model.id == tts_id, + self.model.user_id == user_id, + ) + result = await self.session.exec(stmt) + return result.first() \ No newline at end of file diff --git a/app/services/tts/__init__.py b/app/services/tts/__init__.py new file mode 100644 index 0000000..6e18257 --- /dev/null +++ b/app/services/tts/__init__.py @@ -0,0 +1,6 @@ +"""Text-to-Speech services package.""" + +from .base import TTSProvider +from .service import TTSService + +__all__ = ["TTSProvider", "TTSService"] \ No newline at end of file diff --git a/app/services/tts/base.py b/app/services/tts/base.py new file mode 100644 index 0000000..ad3e27d --- /dev/null +++ b/app/services/tts/base.py @@ -0,0 +1,38 @@ +"""Base TTS provider interface.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class TTSProvider(ABC): + """Abstract base class for TTS providers.""" + + @abstractmethod + async def generate_speech(self, text: str, **options: Any) -> bytes: + """Generate speech from text with provider-specific options. + + Args: + text: The text to convert to speech + **options: Provider-specific options + + Returns: + Audio data as bytes + """ + + @abstractmethod + def get_supported_languages(self) -> list[str]: + """Return list of supported language codes.""" + + @abstractmethod + def get_option_schema(self) -> dict[str, Any]: + """Return schema for provider-specific options.""" + + @property + @abstractmethod + def name(self) -> str: + """Return the provider name.""" + + @property + @abstractmethod + def file_extension(self) -> str: + """Return the default file extension for this provider.""" \ No newline at end of file diff --git a/app/services/tts/providers/__init__.py b/app/services/tts/providers/__init__.py new file mode 100644 index 0000000..cb2334d --- /dev/null +++ b/app/services/tts/providers/__init__.py @@ -0,0 +1,5 @@ +"""TTS providers package.""" + +from .gtts import GTTSProvider + +__all__ = ["GTTSProvider"] \ No newline at end of file diff --git a/app/services/tts/providers/gtts.py b/app/services/tts/providers/gtts.py new file mode 100644 index 0000000..fb75253 --- /dev/null +++ b/app/services/tts/providers/gtts.py @@ -0,0 +1,81 @@ +"""Google Text-to-Speech provider.""" + +import asyncio +import io +from typing import Any + +from gtts import gTTS + +from ..base import TTSProvider + + +class GTTSProvider(TTSProvider): + """Google Text-to-Speech provider implementation.""" + + @property + def name(self) -> str: + """Return the provider name.""" + return "gtts" + + @property + def file_extension(self) -> str: + """Return the default file extension for this provider.""" + return "mp3" + + async def generate_speech(self, text: str, **options: Any) -> bytes: + """Generate speech from text using Google TTS. + + Args: + text: The text to convert to speech + **options: GTTS-specific options (lang, tld, slow) + + Returns: + MP3 audio data as bytes + """ + lang = options.get("lang", "en") + tld = options.get("tld", "com") + slow = options.get("slow", False) + + # Run TTS generation in thread pool since gTTS is synchronous + def _generate(): + tts = gTTS(text=text, lang=lang, tld=tld, slow=slow) + fp = io.BytesIO() + tts.write_to_fp(fp) + fp.seek(0) + return fp.read() + + # Use asyncio.to_thread which is more reliable than run_in_executor + return await asyncio.to_thread(_generate) + + def get_supported_languages(self) -> list[str]: + """Return list of supported language codes.""" + # Common GTTS supported languages + return [ + "af", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el", "en", + "eo", "es", "et", "fi", "fr", "gu", "hi", "hr", "hu", "hy", "id", "is", + "it", "ja", "jw", "km", "kn", "ko", "la", "lv", "mk", "ml", "mr", "my", + "ne", "nl", "no", "pl", "pt", "ro", "ru", "si", "sk", "sq", "sr", "su", + "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur", "vi", "zh-cn", "zh-tw" + ] + + def get_option_schema(self) -> dict[str, Any]: + """Return schema for GTTS-specific options.""" + return { + "lang": { + "type": "string", + "default": "en", + "description": "Language code", + "enum": self.get_supported_languages() + }, + "tld": { + "type": "string", + "default": "com", + "description": "Top-level domain for Google TTS", + "enum": ["com", "co.uk", "com.au", "ca", "co.in", "ie", "co.za"] + }, + "slow": { + "type": "boolean", + "default": False, + "description": "Speak slowly" + } + } \ No newline at end of file diff --git a/app/services/tts/service.py b/app/services/tts/service.py new file mode 100644 index 0000000..843ef29 --- /dev/null +++ b/app/services/tts/service.py @@ -0,0 +1,404 @@ +"""TTS service implementation.""" + +import asyncio +import io +import uuid +from pathlib import Path +from typing import Any + +from gtts import gTTS +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.sound import Sound +from app.models.tts import TTS +from app.repositories.sound import SoundRepository +from app.repositories.tts import TTSRepository +from app.services.sound_normalizer import SoundNormalizerService +from app.utils.audio import get_audio_duration, get_file_hash, get_file_size + +from .base import TTSProvider +from .providers import GTTSProvider + +# Constants +MAX_TEXT_LENGTH = 1000 +MAX_NAME_LENGTH = 50 + + +class TTSService: + """Text-to-Speech service with provider management.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize TTS service. + + Args: + session: Database session + """ + self.session = session + self.sound_repo = SoundRepository(session) + self.tts_repo = TTSRepository(session) + self.providers: dict[str, TTSProvider] = {} + + # Register default providers + self._register_default_providers() + + def _register_default_providers(self) -> None: + """Register default TTS providers.""" + self.register_provider(GTTSProvider()) + + def register_provider(self, provider: TTSProvider) -> None: + """Register a TTS provider. + + Args: + provider: TTS provider instance + """ + self.providers[provider.name] = provider + + def get_providers(self) -> dict[str, TTSProvider]: + """Get all registered providers.""" + return self.providers.copy() + + def get_provider(self, name: str) -> TTSProvider | None: + """Get a specific provider by name.""" + return self.providers.get(name) + + async def create_tts_request( + self, + text: str, + user_id: int, + provider: str = "gtts", + **options: Any, + ) -> dict[str, Any]: + """Create a TTS request that will be processed in the background. + + Args: + text: Text to convert to speech + user_id: ID of user creating the sound + provider: TTS provider name + **options: Provider-specific options + + Returns: + Dictionary with TTS record information + + Raises: + ValueError: If provider not found or text too long + Exception: If request creation fails + """ + provider_not_found_msg = f"Provider '{provider}' not found" + if provider not in self.providers: + raise ValueError(provider_not_found_msg) + + text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)" + if len(text) > MAX_TEXT_LENGTH: + raise ValueError(text_too_long_msg) + + empty_text_msg = "Text cannot be empty" + if not text.strip(): + raise ValueError(empty_text_msg) + + # Create TTS record with pending status + tts = TTS( + text=text, + provider=provider, + options=options, + sound_id=None, # Will be set when processing completes + user_id=user_id, + ) + self.session.add(tts) + await self.session.commit() + await self.session.refresh(tts) + + # Queue for background processing + if tts.id is not None: + await self._queue_tts_processing(tts.id) + + return {"tts": tts, "message": "TTS generation queued successfully"} + + async def _queue_tts_processing(self, tts_id: int) -> None: + """Queue TTS for background processing.""" + # For now, process immediately in a different way + # This could be moved to a proper background queue later + task = asyncio.create_task(self._process_tts_in_background(tts_id)) + # Store reference to prevent garbage collection + self._background_tasks = getattr(self, '_background_tasks', set()) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) + + async def _process_tts_in_background(self, tts_id: int) -> None: + """Process TTS generation in background.""" + from app.core.database import get_session_factory + + try: + # Create a new session for background processing + session_factory = get_session_factory() + async with session_factory() as background_session: + tts_service = TTSService(background_session) + + # Get the TTS record + stmt = select(TTS).where(TTS.id == tts_id) + result = await background_session.exec(stmt) + tts = result.first() + + if not tts: + return + + # Use a synchronous approach for the actual generation + sound = await tts_service._generate_tts_sync( + tts.text, + tts.provider, + tts.user_id, + tts.options, + ) + + # Update the TTS record with the sound ID + if sound.id is not None: + tts.sound_id = sound.id + background_session.add(tts) + await background_session.commit() + + except Exception: + # Log error but don't fail - avoiding print for production + pass + + async def _generate_tts_sync( + self, + text: str, + provider: str, + user_id: int, + options: dict[str, Any] + ) -> Sound: + """Generate TTS using a synchronous approach.""" + # Generate the audio using the provider (avoid async issues by doing it directly) + tts_provider = self.providers[provider] + + # Create directories if they don't exist + original_dir = Path("sounds/originals/text_to_speech") + original_dir.mkdir(parents=True, exist_ok=True) + + # Create UUID filename + sound_uuid = str(uuid.uuid4()) + original_filename = f"{sound_uuid}.{tts_provider.file_extension}" + original_path = original_dir / original_filename + + # Generate audio synchronously + try: + # Generate TTS audio + lang = options.get("lang", "en") + tld = options.get("tld", "com") + slow = options.get("slow", False) + + tts_instance = gTTS(text=text, lang=lang, tld=tld, slow=slow) + fp = io.BytesIO() + tts_instance.write_to_fp(fp) + fp.seek(0) + audio_bytes = fp.read() + + # Save the file + original_path.write_bytes(audio_bytes) + + except Exception: + raise + + # Create Sound record with proper metadata + sound = await self._create_sound_record_complete( + original_path, + text, + provider, + user_id + ) + + # Normalize the sound + await self._normalize_sound_safe(sound.id) + + return sound + + async def get_user_tts_history( + self, + user_id: int, + limit: int = 50, + offset: int = 0 + ) -> list[TTS]: + """Get TTS history for a user. + + Args: + user_id: User ID + limit: Maximum number of records + offset: Offset for pagination + + Returns: + List of TTS records + """ + result = await self.tts_repo.get_by_user_id(user_id, limit, offset) + return list(result) + + async def _create_sound_record( + self, + audio_path: Path, + text: str, + provider: str, + user_id: int, + file_hash: str + ) -> Sound: + """Create a Sound record for the TTS audio.""" + # Get audio metadata + duration = get_audio_duration(audio_path) + size = get_file_size(audio_path) + + # Create sound data + sound_data = { + "type": "TTS", + "name": text[:50] + ("..." if len(text) > 50 else ""), + "filename": audio_path.name, + "duration": duration, + "size": size, + "hash": file_hash, + "user_id": user_id, + "is_deletable": True, + "is_music": False, # TTS is speech, not music + "is_normalized": False, + "play_count": 0, + } + + sound = await self.sound_repo.create(sound_data) + return sound + + async def _create_sound_record_simple( + self, + audio_path: Path, + text: str, + provider: str, + user_id: int + ) -> Sound: + """Create a Sound record for the TTS audio with minimal processing.""" + # Create sound data with basic info + sound_data = { + "type": "TTS", + "name": text[:50] + ("..." if len(text) > 50 else ""), + "filename": audio_path.name, + "duration": 0, # Skip duration calculation for now + "size": 0, # Skip size calculation for now + "hash": str(uuid.uuid4()), # Use UUID as temporary hash + "user_id": user_id, + "is_deletable": True, + "is_music": False, # TTS is speech, not music + "is_normalized": False, + "play_count": 0, + } + + sound = await self.sound_repo.create(sound_data) + return sound + + async def _create_sound_record_complete( + self, + audio_path: Path, + text: str, + provider: str, + user_id: int + ) -> Sound: + """Create a Sound record for the TTS audio with complete metadata.""" + # Get audio metadata + duration = get_audio_duration(audio_path) + size = get_file_size(audio_path) + file_hash = get_file_hash(audio_path) + + # Check if a sound with this hash already exists + existing_sound = await self.sound_repo.get_by_hash(file_hash) + + if existing_sound: + # Clean up the temporary file since we have a duplicate + if audio_path.exists(): + audio_path.unlink() + return existing_sound + + # Create sound data with complete metadata + sound_data = { + "type": "TTS", + "name": text[:50] + ("..." if len(text) > 50 else ""), + "filename": audio_path.name, + "duration": duration, + "size": size, + "hash": file_hash, + "user_id": user_id, + "is_deletable": True, + "is_music": False, # TTS is speech, not music + "is_normalized": False, + "play_count": 0, + } + + sound = await self.sound_repo.create(sound_data) + return sound + + async def _normalize_sound_safe(self, sound_id: int) -> None: + """Normalize the TTS sound with error handling.""" + try: + # Get fresh sound object from database for normalization + sound = await self.sound_repo.get_by_id(sound_id) + if not sound: + return + + normalizer_service = SoundNormalizerService(self.session) + result = await normalizer_service.normalize_sound(sound) + + if result["status"] == "error": + print(f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}") + + except Exception as e: + print(f"Exception during TTS sound normalization {sound_id}: {e}") + # Don't fail the TTS generation if normalization fails + + async def _normalize_sound(self, sound_id: int) -> None: + """Normalize the TTS sound.""" + try: + # Get fresh sound object from database for normalization + sound = await self.sound_repo.get_by_id(sound_id) + if not sound: + return + + normalizer_service = SoundNormalizerService(self.session) + result = await normalizer_service.normalize_sound(sound) + + if result["status"] == "error": + # Log warning but don't fail the TTS generation + pass + + except Exception: + # Don't fail the TTS generation if normalization fails + pass + + async def delete_tts(self, tts_id: int, user_id: int) -> None: + """Delete a TTS generation and its associated sound and files.""" + # Get the TTS record + tts = await self.tts_repo.get_by_id(tts_id) + if not tts: + raise ValueError(f"TTS with ID {tts_id} not found") + + # Check ownership + if tts.user_id != user_id: + raise PermissionError("You don't have permission to delete this TTS generation") + + # If there's an associated sound, delete it and its files + if tts.sound_id: + sound = await self.sound_repo.get_by_id(tts.sound_id) + if sound: + # Delete the sound files + await self._delete_sound_files(sound) + # Delete the sound record + await self.sound_repo.delete(sound) + + # Delete the TTS record + await self.tts_repo.delete(tts) + + async def _delete_sound_files(self, sound: Sound) -> None: + """Delete all files associated with a sound.""" + from pathlib import Path + + # Delete original file + original_path = Path("sounds/originals/text_to_speech") / sound.filename + if original_path.exists(): + original_path.unlink() + + # Delete normalized file if it exists + if sound.normalized_filename: + normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename + if normalized_path.exists(): + normalized_path.unlink() \ No newline at end of file