From acdf191a5a45305a70b57d036b93caf87e4e393c Mon Sep 17 00:00:00 2001 From: JSC Date: Sun, 21 Sep 2025 19:07:32 +0200 Subject: [PATCH] refactor: Improve code readability and structure across TTS modules --- app/api/v1/tts.py | 4 +- app/models/tts.py | 4 +- app/repositories/tts.py | 13 +++- app/services/playlist.py | 11 ++- app/services/tts/base.py | 8 ++- app/services/tts/providers/gtts.py | 9 ++- app/services/tts/service.py | 111 ++++++++++++++++++----------- app/services/tts_processor.py | 2 +- 8 files changed, 106 insertions(+), 56 deletions(-) diff --git a/app/api/v1/tts.py b/app/api/v1/tts.py index 9e951bf..b927753 100644 --- a/app/api/v1/tts.py +++ b/app/api/v1/tts.py @@ -206,8 +206,6 @@ async def delete_tts( await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id) - return {"message": "TTS generation deleted successfully"} - except ValueError as e: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -223,3 +221,5 @@ async def delete_tts( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to delete TTS: {e!s}", ) from e + else: + return {"message": "TTS generation deleted successfully"} diff --git a/app/models/tts.py b/app/models/tts.py index 005b229..32044a7 100644 --- a/app/models/tts.py +++ b/app/models/tts.py @@ -22,7 +22,9 @@ class TTS(SQLModel, table=True): ) status: str = Field(default="pending", description="Processing status") error: str | None = Field(default=None, description="Error message if failed") - sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID") + sound_id: int | None = Field( + foreign_key="sound.id", description="Associated sound ID", + ) user_id: int = Field(foreign_key="user.id", description="User who created the TTS") created_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow) diff --git a/app/repositories/tts.py b/app/repositories/tts.py index d17430e..4c65ad3 100644 --- a/app/repositories/tts.py +++ b/app/repositories/tts.py @@ -1,7 +1,10 @@ """TTS repository for database operations.""" from collections.abc import Sequence -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import select @@ -12,7 +15,13 @@ from app.repositories.base import BaseRepository class TTSRepository(BaseRepository[TTS]): """Repository for TTS operations.""" - def __init__(self, session: Any) -> None: + def __init__(self, session: "AsyncSession") -> None: + """Initialize TTS repository. + + Args: + session: Database session for operations + + """ super().__init__(TTS, session) async def get_by_user_id( diff --git a/app/services/playlist.py b/app/services/playlist.py index ba29e68..1222227 100644 --- a/app/services/playlist.py +++ b/app/services/playlist.py @@ -244,7 +244,10 @@ class PlaylistService: if was_current: main_playlist = await self.get_main_playlist() await self.playlist_repo.update(main_playlist, {"is_current": True}) - logger.info("Set main playlist as current after deleting current playlist %s", playlist_id) + logger.info( + "Set main playlist as current after deleting current playlist %s", + playlist_id, + ) # Reload player to reflect the change await _reload_player_playlist() @@ -562,7 +565,11 @@ class PlaylistService: await self.session.delete(playlist_sound) await self.session.commit() - logger.info("Deleted %d playlist_sound records for playlist %s", len(playlist_sounds), playlist_id) + logger.info( + "Deleted %d playlist_sound records for playlist %s", + len(playlist_sounds), + playlist_id, + ) async def _unset_current_playlist(self) -> None: """Unset any current playlist globally.""" diff --git a/app/services/tts/base.py b/app/services/tts/base.py index 5f3cd11..b3b1e81 100644 --- a/app/services/tts/base.py +++ b/app/services/tts/base.py @@ -1,14 +1,16 @@ """Base TTS provider interface.""" from abc import ABC, abstractmethod -from typing import Any + +# Type alias for TTS options +TTSOptions = dict[str, str | bool | int | float] class TTSProvider(ABC): """Abstract base class for TTS providers.""" @abstractmethod - async def generate_speech(self, text: str, **options: Any) -> bytes: + async def generate_speech(self, text: str, **options: str | bool | float) -> bytes: """Generate speech from text with provider-specific options. Args: @@ -25,7 +27,7 @@ class TTSProvider(ABC): """Return list of supported language codes.""" @abstractmethod - def get_option_schema(self) -> dict[str, Any]: + def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]: """Return schema for provider-specific options.""" @property diff --git a/app/services/tts/providers/gtts.py b/app/services/tts/providers/gtts.py index 7f31600..6894c27 100644 --- a/app/services/tts/providers/gtts.py +++ b/app/services/tts/providers/gtts.py @@ -2,11 +2,10 @@ import asyncio import io -from typing import Any from gtts import gTTS -from ..base import TTSProvider +from app.services.tts.base import TTSProvider class GTTSProvider(TTSProvider): @@ -22,7 +21,7 @@ class GTTSProvider(TTSProvider): """Return the default file extension for this provider.""" return "mp3" - async def generate_speech(self, text: str, **options: Any) -> bytes: + async def generate_speech(self, text: str, **options: str | bool | float) -> bytes: """Generate speech from text using Google TTS. Args: @@ -38,7 +37,7 @@ class GTTSProvider(TTSProvider): slow = options.get("slow", False) # Run TTS generation in thread pool since gTTS is synchronous - def _generate(): + def _generate() -> bytes: tts = gTTS(text=text, lang=lang, tld=tld, slow=slow) fp = io.BytesIO() tts.write_to_fp(fp) @@ -64,7 +63,7 @@ class GTTSProvider(TTSProvider): "vi", "yo", "zh", "zh-cn", "zh-tw", "zu", ] - def get_option_schema(self) -> dict[str, Any]: + def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]: """Return schema for GTTS-specific options.""" return { "lang": { diff --git a/app/services/tts/service.py b/app/services/tts/service.py index 11afa94..064f15e 100644 --- a/app/services/tts/service.py +++ b/app/services/tts/service.py @@ -10,10 +10,13 @@ from gtts import gTTS from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +from app.core.database import get_session_factory +from app.core.logging import get_logger from app.models.sound import Sound from app.models.tts import TTS from app.repositories.sound import SoundRepository from app.repositories.tts import TTSRepository +from app.services.socket import socket_manager from app.services.sound_normalizer import SoundNormalizerService from app.utils.audio import get_audio_duration, get_file_hash, get_file_size @@ -25,6 +28,12 @@ MAX_TEXT_LENGTH = 1000 MAX_NAME_LENGTH = 50 +async def _get_tts_processor() -> object: + """Get TTS processor instance, avoiding circular import.""" + from app.services.tts_processor import tts_processor # noqa: PLC0415 + return tts_processor + + class TTSService: """Text-to-Speech service with provider management.""" @@ -69,7 +78,7 @@ class TTSService: text: str, user_id: int, provider: str = "gtts", - **options: Any, + **options: str | bool | float, ) -> dict[str, Any]: """Create a TTS request that will be processed in the background. @@ -114,8 +123,7 @@ class TTSService: # Queue for background processing using the TTS processor if tts.id is not None: - from app.services.tts_processor import tts_processor - + tts_processor = await _get_tts_processor() await tts_processor.queue_tts(tts.id) return {"tts": tts, "message": "TTS generation queued successfully"} @@ -132,8 +140,6 @@ class TTSService: async def _process_tts_in_background(self, tts_id: int) -> None: """Process TTS generation in background.""" - from app.core.database import get_session_factory - try: # Create a new session for background processing session_factory = get_session_factory() @@ -164,13 +170,19 @@ class TTSService: except Exception: # Log error but don't fail - avoiding print for production - pass + logger = get_logger(__name__) + logger.exception("Error processing TTS generation %s", tts_id) async def _generate_tts_sync( - self, text: str, provider: str, user_id: int, options: dict[str, Any], + self, + text: str, + provider: str, + user_id: int, + options: dict[str, Any], ) -> Sound: """Generate TTS using a synchronous approach.""" - # Generate the audio using the provider (avoid async issues by doing it directly) + # Generate the audio using the provider + # (avoid async issues by doing it directly) tts_provider = self.providers[provider] # Create directories if they don't exist @@ -199,20 +211,28 @@ class TTSService: original_path.write_bytes(audio_bytes) except Exception: + logger = get_logger(__name__) + logger.exception("Error generating TTS audio") raise # Create Sound record with proper metadata sound = await self._create_sound_record_complete( - original_path, text, provider, user_id, + original_path, + text, + user_id, ) # Normalize the sound - await self._normalize_sound_safe(sound.id) + if sound.id is not None: + await self._normalize_sound_safe(sound.id) return sound async def get_user_tts_history( - self, user_id: int, limit: int = 50, offset: int = 0, + self, + user_id: int, + limit: int = 50, + offset: int = 0, ) -> list[TTS]: """Get TTS history for a user. @@ -229,7 +249,11 @@ class TTSService: return list(result) async def _create_sound_record( - self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str, + self, + audio_path: Path, + text: str, + user_id: int, + file_hash: str, ) -> Sound: """Create a Sound record for the TTS audio.""" # Get audio metadata @@ -253,11 +277,13 @@ class TTSService: "play_count": 0, } - sound = await self.sound_repo.create(sound_data) - return sound + return await self.sound_repo.create(sound_data) async def _create_sound_record_simple( - self, audio_path: Path, text: str, provider: str, user_id: int, + self, + audio_path: Path, + text: str, + user_id: int, ) -> Sound: """Create a Sound record for the TTS audio with minimal processing.""" # Create sound data with basic info @@ -278,11 +304,13 @@ class TTSService: "play_count": 0, } - sound = await self.sound_repo.create(sound_data) - return sound + return await self.sound_repo.create(sound_data) async def _create_sound_record_complete( - self, audio_path: Path, text: str, provider: str, user_id: int, + self, + audio_path: Path, + text: str, + user_id: int, ) -> Sound: """Create a Sound record for the TTS audio with complete metadata.""" # Get audio metadata @@ -316,8 +344,7 @@ class TTSService: "play_count": 0, } - sound = await self.sound_repo.create(sound_data) - return sound + return await self.sound_repo.create(sound_data) async def _normalize_sound_safe(self, sound_id: int) -> None: """Normalize the TTS sound with error handling.""" @@ -331,12 +358,16 @@ class TTSService: result = await normalizer_service.normalize_sound(sound) if result["status"] == "error": - print( - f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}", + logger = get_logger(__name__) + logger.warning( + "Warning: Failed to normalize TTS sound %s: %s", + sound_id, + result.get("error"), ) - except Exception as e: - print(f"Exception during TTS sound normalization {sound_id}: {e}") + except Exception: + logger = get_logger(__name__) + logger.exception("Exception during TTS sound normalization %s", sound_id) # Don't fail the TTS generation if normalization fails async def _normalize_sound(self, sound_id: int) -> None: @@ -356,20 +387,23 @@ class TTSService: except Exception: # Don't fail the TTS generation if normalization fails - pass + logger = get_logger(__name__) + logger.exception("Error normalizing sound %s", sound_id) async def delete_tts(self, tts_id: int, user_id: int) -> None: """Delete a TTS generation and its associated sound and files.""" # Get the TTS record tts = await self.tts_repo.get_by_id(tts_id) if not tts: - raise ValueError(f"TTS with ID {tts_id} not found") + tts_not_found_msg = f"TTS with ID {tts_id} not found" + raise ValueError(tts_not_found_msg) # Check ownership if tts.user_id != user_id: - raise PermissionError( - "You don't have permission to delete this TTS generation", + permission_error_msg = ( + "You don't have permission to delete this TTS generation" ) + raise PermissionError(permission_error_msg) # If there's an associated sound, delete it and its files if tts.sound_id: @@ -385,8 +419,6 @@ class TTSService: async def _delete_sound_files(self, sound: Sound) -> None: """Delete all files associated with a sound.""" - from pathlib import Path - # Delete original file original_path = Path("sounds/originals/text_to_speech") / sound.filename if original_path.exists(): @@ -465,7 +497,8 @@ class TTSService: tts = result.first() if not tts: - raise ValueError(f"TTS with ID {tts_id} not found") + tts_not_found_msg = f"TTS with ID {tts_id} not found" + raise ValueError(tts_not_found_msg) # Generate the TTS sound = await self._generate_tts_sync( @@ -477,6 +510,9 @@ class TTSService: # Capture sound ID before session issues sound_id = sound.id + if sound_id is None: + sound_creation_error = "Sound creation failed - no ID assigned" + raise ValueError(sound_creation_error) # Mark as completed await self.mark_tts_completed(tts_id, sound_id) @@ -501,9 +537,6 @@ class TTSService: ) -> None: """Emit a socket event for TTS status change.""" try: - from app.core.logging import get_logger - from app.services.socket import socket_manager - logger = get_logger(__name__) data = { @@ -513,12 +546,10 @@ class TTSService: if error: data["error"] = error - logger.info(f"Emitting TTS socket event: {event} with data: {data}") + logger.info("Emitting TTS socket event: %s with data: %s", event, data) await socket_manager.broadcast_to_all(event, data) - logger.info(f"Successfully emitted TTS socket event: {event}") - except Exception as e: + logger.info("Successfully emitted TTS socket event: %s", event) + except Exception: # Don't fail TTS processing if socket emission fails - from app.core.logging import get_logger - logger = get_logger(__name__) - logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True) + logger.exception("Failed to emit TTS socket event %s", event) diff --git a/app/services/tts_processor.py b/app/services/tts_processor.py index 80c7bf3..bf10387 100644 --- a/app/services/tts_processor.py +++ b/app/services/tts_processor.py @@ -14,7 +14,7 @@ logger = get_logger(__name__) class TTSProcessor: - """Background processor for handling TTS generation queue with concurrency control.""" + """Background processor for handling TTS generation queue.""" def __init__(self) -> None: """Initialize the TTS processor."""