Compare commits

..

2 Commits

Author SHA1 Message Date
JSC
acdf191a5a refactor: Improve code readability and structure across TTS modules
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m36s
2025-09-21 19:07:32 +02:00
JSC
35b857fd0d feat: Add GitHub as an available OAuth provider and remove database initialization logs 2025-09-21 18:58:20 +02:00
10 changed files with 107 additions and 60 deletions

View File

@@ -331,7 +331,7 @@ async def oauth_callback(
async def get_oauth_providers() -> dict[str, list[str]]: async def get_oauth_providers() -> dict[str, list[str]]:
"""Get list of available OAuth providers.""" """Get list of available OAuth providers."""
return { return {
"providers": ["google"], "providers": ["google", "github"],
} }

View File

@@ -206,8 +206,6 @@ async def delete_tts(
await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id) await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
return {"message": "TTS generation deleted successfully"}
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
@@ -223,3 +221,5 @@ async def delete_tts(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete TTS: {e!s}", detail=f"Failed to delete TTS: {e!s}",
) from e ) from e
else:
return {"message": "TTS generation deleted successfully"}

View File

@@ -29,9 +29,6 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
logger = get_logger(__name__) logger = get_logger(__name__)
logger.info("Starting application") logger.info("Starting application")
# await init_db()
# logger.info("Database initialized")
# Start the extraction processor # Start the extraction processor
await extraction_processor.start() await extraction_processor.start()
logger.info("Extraction processor started") logger.info("Extraction processor started")

View File

@@ -22,7 +22,9 @@ class TTS(SQLModel, table=True):
) )
status: str = Field(default="pending", description="Processing status") status: str = Field(default="pending", description="Processing status")
error: str | None = Field(default=None, description="Error message if failed") error: str | None = Field(default=None, description="Error message if failed")
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID") sound_id: int | None = Field(
foreign_key="sound.id", description="Associated sound ID",
)
user_id: int = Field(foreign_key="user.id", description="User who created the TTS") user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
created_at: datetime = Field(default_factory=datetime.utcnow) created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -1,7 +1,10 @@
"""TTS repository for database operations.""" """TTS repository for database operations."""
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select from sqlmodel import select
@@ -12,7 +15,13 @@ from app.repositories.base import BaseRepository
class TTSRepository(BaseRepository[TTS]): class TTSRepository(BaseRepository[TTS]):
"""Repository for TTS operations.""" """Repository for TTS operations."""
def __init__(self, session: Any) -> None: def __init__(self, session: "AsyncSession") -> None:
"""Initialize TTS repository.
Args:
session: Database session for operations
"""
super().__init__(TTS, session) super().__init__(TTS, session)
async def get_by_user_id( async def get_by_user_id(

View File

@@ -244,7 +244,10 @@ class PlaylistService:
if was_current: if was_current:
main_playlist = await self.get_main_playlist() main_playlist = await self.get_main_playlist()
await self.playlist_repo.update(main_playlist, {"is_current": True}) await self.playlist_repo.update(main_playlist, {"is_current": True})
logger.info("Set main playlist as current after deleting current playlist %s", playlist_id) logger.info(
"Set main playlist as current after deleting current playlist %s",
playlist_id,
)
# Reload player to reflect the change # Reload player to reflect the change
await _reload_player_playlist() await _reload_player_playlist()
@@ -562,7 +565,11 @@ class PlaylistService:
await self.session.delete(playlist_sound) await self.session.delete(playlist_sound)
await self.session.commit() await self.session.commit()
logger.info("Deleted %d playlist_sound records for playlist %s", len(playlist_sounds), playlist_id) logger.info(
"Deleted %d playlist_sound records for playlist %s",
len(playlist_sounds),
playlist_id,
)
async def _unset_current_playlist(self) -> None: async def _unset_current_playlist(self) -> None:
"""Unset any current playlist globally.""" """Unset any current playlist globally."""

View File

@@ -1,14 +1,16 @@
"""Base TTS provider interface.""" """Base TTS provider interface."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
# Type alias for TTS options
TTSOptions = dict[str, str | bool | int | float]
class TTSProvider(ABC): class TTSProvider(ABC):
"""Abstract base class for TTS providers.""" """Abstract base class for TTS providers."""
@abstractmethod @abstractmethod
async def generate_speech(self, text: str, **options: Any) -> bytes: async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
"""Generate speech from text with provider-specific options. """Generate speech from text with provider-specific options.
Args: Args:
@@ -25,7 +27,7 @@ class TTSProvider(ABC):
"""Return list of supported language codes.""" """Return list of supported language codes."""
@abstractmethod @abstractmethod
def get_option_schema(self) -> dict[str, Any]: def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
"""Return schema for provider-specific options.""" """Return schema for provider-specific options."""
@property @property

View File

@@ -2,11 +2,10 @@
import asyncio import asyncio
import io import io
from typing import Any
from gtts import gTTS from gtts import gTTS
from ..base import TTSProvider from app.services.tts.base import TTSProvider
class GTTSProvider(TTSProvider): class GTTSProvider(TTSProvider):
@@ -22,7 +21,7 @@ class GTTSProvider(TTSProvider):
"""Return the default file extension for this provider.""" """Return the default file extension for this provider."""
return "mp3" return "mp3"
async def generate_speech(self, text: str, **options: Any) -> bytes: async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
"""Generate speech from text using Google TTS. """Generate speech from text using Google TTS.
Args: Args:
@@ -38,7 +37,7 @@ class GTTSProvider(TTSProvider):
slow = options.get("slow", False) slow = options.get("slow", False)
# Run TTS generation in thread pool since gTTS is synchronous # Run TTS generation in thread pool since gTTS is synchronous
def _generate(): def _generate() -> bytes:
tts = gTTS(text=text, lang=lang, tld=tld, slow=slow) tts = gTTS(text=text, lang=lang, tld=tld, slow=slow)
fp = io.BytesIO() fp = io.BytesIO()
tts.write_to_fp(fp) tts.write_to_fp(fp)
@@ -64,7 +63,7 @@ class GTTSProvider(TTSProvider):
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu", "vi", "yo", "zh", "zh-cn", "zh-tw", "zu",
] ]
def get_option_schema(self) -> dict[str, Any]: def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
"""Return schema for GTTS-specific options.""" """Return schema for GTTS-specific options."""
return { return {
"lang": { "lang": {

View File

@@ -10,10 +10,13 @@ from gtts import gTTS
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_session_factory
from app.core.logging import get_logger
from app.models.sound import Sound from app.models.sound import Sound
from app.models.tts import TTS from app.models.tts import TTS
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
from app.repositories.tts import TTSRepository from app.repositories.tts import TTSRepository
from app.services.socket import socket_manager
from app.services.sound_normalizer import SoundNormalizerService from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
@@ -25,6 +28,12 @@ MAX_TEXT_LENGTH = 1000
MAX_NAME_LENGTH = 50 MAX_NAME_LENGTH = 50
async def _get_tts_processor() -> object:
"""Get TTS processor instance, avoiding circular import."""
from app.services.tts_processor import tts_processor # noqa: PLC0415
return tts_processor
class TTSService: class TTSService:
"""Text-to-Speech service with provider management.""" """Text-to-Speech service with provider management."""
@@ -69,7 +78,7 @@ class TTSService:
text: str, text: str,
user_id: int, user_id: int,
provider: str = "gtts", provider: str = "gtts",
**options: Any, **options: str | bool | float,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create a TTS request that will be processed in the background. """Create a TTS request that will be processed in the background.
@@ -114,8 +123,7 @@ class TTSService:
# Queue for background processing using the TTS processor # Queue for background processing using the TTS processor
if tts.id is not None: if tts.id is not None:
from app.services.tts_processor import tts_processor tts_processor = await _get_tts_processor()
await tts_processor.queue_tts(tts.id) await tts_processor.queue_tts(tts.id)
return {"tts": tts, "message": "TTS generation queued successfully"} return {"tts": tts, "message": "TTS generation queued successfully"}
@@ -132,8 +140,6 @@ class TTSService:
async def _process_tts_in_background(self, tts_id: int) -> None: async def _process_tts_in_background(self, tts_id: int) -> None:
"""Process TTS generation in background.""" """Process TTS generation in background."""
from app.core.database import get_session_factory
try: try:
# Create a new session for background processing # Create a new session for background processing
session_factory = get_session_factory() session_factory = get_session_factory()
@@ -164,13 +170,19 @@ class TTSService:
except Exception: except Exception:
# Log error but don't fail - avoiding print for production # Log error but don't fail - avoiding print for production
pass logger = get_logger(__name__)
logger.exception("Error processing TTS generation %s", tts_id)
async def _generate_tts_sync( async def _generate_tts_sync(
self, text: str, provider: str, user_id: int, options: dict[str, Any], self,
text: str,
provider: str,
user_id: int,
options: dict[str, Any],
) -> Sound: ) -> Sound:
"""Generate TTS using a synchronous approach.""" """Generate TTS using a synchronous approach."""
# Generate the audio using the provider (avoid async issues by doing it directly) # Generate the audio using the provider
# (avoid async issues by doing it directly)
tts_provider = self.providers[provider] tts_provider = self.providers[provider]
# Create directories if they don't exist # Create directories if they don't exist
@@ -199,20 +211,28 @@ class TTSService:
original_path.write_bytes(audio_bytes) original_path.write_bytes(audio_bytes)
except Exception: except Exception:
logger = get_logger(__name__)
logger.exception("Error generating TTS audio")
raise raise
# Create Sound record with proper metadata # Create Sound record with proper metadata
sound = await self._create_sound_record_complete( sound = await self._create_sound_record_complete(
original_path, text, provider, user_id, original_path,
text,
user_id,
) )
# Normalize the sound # Normalize the sound
await self._normalize_sound_safe(sound.id) if sound.id is not None:
await self._normalize_sound_safe(sound.id)
return sound return sound
async def get_user_tts_history( async def get_user_tts_history(
self, user_id: int, limit: int = 50, offset: int = 0, self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> list[TTS]: ) -> list[TTS]:
"""Get TTS history for a user. """Get TTS history for a user.
@@ -229,7 +249,11 @@ class TTSService:
return list(result) return list(result)
async def _create_sound_record( async def _create_sound_record(
self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str, self,
audio_path: Path,
text: str,
user_id: int,
file_hash: str,
) -> Sound: ) -> Sound:
"""Create a Sound record for the TTS audio.""" """Create a Sound record for the TTS audio."""
# Get audio metadata # Get audio metadata
@@ -253,11 +277,13 @@ class TTSService:
"play_count": 0, "play_count": 0,
} }
sound = await self.sound_repo.create(sound_data) return await self.sound_repo.create(sound_data)
return sound
async def _create_sound_record_simple( async def _create_sound_record_simple(
self, audio_path: Path, text: str, provider: str, user_id: int, self,
audio_path: Path,
text: str,
user_id: int,
) -> Sound: ) -> Sound:
"""Create a Sound record for the TTS audio with minimal processing.""" """Create a Sound record for the TTS audio with minimal processing."""
# Create sound data with basic info # Create sound data with basic info
@@ -278,11 +304,13 @@ class TTSService:
"play_count": 0, "play_count": 0,
} }
sound = await self.sound_repo.create(sound_data) return await self.sound_repo.create(sound_data)
return sound
async def _create_sound_record_complete( async def _create_sound_record_complete(
self, audio_path: Path, text: str, provider: str, user_id: int, self,
audio_path: Path,
text: str,
user_id: int,
) -> Sound: ) -> Sound:
"""Create a Sound record for the TTS audio with complete metadata.""" """Create a Sound record for the TTS audio with complete metadata."""
# Get audio metadata # Get audio metadata
@@ -316,8 +344,7 @@ class TTSService:
"play_count": 0, "play_count": 0,
} }
sound = await self.sound_repo.create(sound_data) return await self.sound_repo.create(sound_data)
return sound
async def _normalize_sound_safe(self, sound_id: int) -> None: async def _normalize_sound_safe(self, sound_id: int) -> None:
"""Normalize the TTS sound with error handling.""" """Normalize the TTS sound with error handling."""
@@ -331,12 +358,16 @@ class TTSService:
result = await normalizer_service.normalize_sound(sound) result = await normalizer_service.normalize_sound(sound)
if result["status"] == "error": if result["status"] == "error":
print( logger = get_logger(__name__)
f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}", logger.warning(
"Warning: Failed to normalize TTS sound %s: %s",
sound_id,
result.get("error"),
) )
except Exception as e: except Exception:
print(f"Exception during TTS sound normalization {sound_id}: {e}") logger = get_logger(__name__)
logger.exception("Exception during TTS sound normalization %s", sound_id)
# Don't fail the TTS generation if normalization fails # Don't fail the TTS generation if normalization fails
async def _normalize_sound(self, sound_id: int) -> None: async def _normalize_sound(self, sound_id: int) -> None:
@@ -356,20 +387,23 @@ class TTSService:
except Exception: except Exception:
# Don't fail the TTS generation if normalization fails # Don't fail the TTS generation if normalization fails
pass logger = get_logger(__name__)
logger.exception("Error normalizing sound %s", sound_id)
async def delete_tts(self, tts_id: int, user_id: int) -> None: async def delete_tts(self, tts_id: int, user_id: int) -> None:
"""Delete a TTS generation and its associated sound and files.""" """Delete a TTS generation and its associated sound and files."""
# Get the TTS record # Get the TTS record
tts = await self.tts_repo.get_by_id(tts_id) tts = await self.tts_repo.get_by_id(tts_id)
if not tts: if not tts:
raise ValueError(f"TTS with ID {tts_id} not found") tts_not_found_msg = f"TTS with ID {tts_id} not found"
raise ValueError(tts_not_found_msg)
# Check ownership # Check ownership
if tts.user_id != user_id: if tts.user_id != user_id:
raise PermissionError( permission_error_msg = (
"You don't have permission to delete this TTS generation", "You don't have permission to delete this TTS generation"
) )
raise PermissionError(permission_error_msg)
# If there's an associated sound, delete it and its files # If there's an associated sound, delete it and its files
if tts.sound_id: if tts.sound_id:
@@ -385,8 +419,6 @@ class TTSService:
async def _delete_sound_files(self, sound: Sound) -> None: async def _delete_sound_files(self, sound: Sound) -> None:
"""Delete all files associated with a sound.""" """Delete all files associated with a sound."""
from pathlib import Path
# Delete original file # Delete original file
original_path = Path("sounds/originals/text_to_speech") / sound.filename original_path = Path("sounds/originals/text_to_speech") / sound.filename
if original_path.exists(): if original_path.exists():
@@ -465,7 +497,8 @@ class TTSService:
tts = result.first() tts = result.first()
if not tts: if not tts:
raise ValueError(f"TTS with ID {tts_id} not found") tts_not_found_msg = f"TTS with ID {tts_id} not found"
raise ValueError(tts_not_found_msg)
# Generate the TTS # Generate the TTS
sound = await self._generate_tts_sync( sound = await self._generate_tts_sync(
@@ -477,6 +510,9 @@ class TTSService:
# Capture sound ID before session issues # Capture sound ID before session issues
sound_id = sound.id sound_id = sound.id
if sound_id is None:
sound_creation_error = "Sound creation failed - no ID assigned"
raise ValueError(sound_creation_error)
# Mark as completed # Mark as completed
await self.mark_tts_completed(tts_id, sound_id) await self.mark_tts_completed(tts_id, sound_id)
@@ -501,9 +537,6 @@ class TTSService:
) -> None: ) -> None:
"""Emit a socket event for TTS status change.""" """Emit a socket event for TTS status change."""
try: try:
from app.core.logging import get_logger
from app.services.socket import socket_manager
logger = get_logger(__name__) logger = get_logger(__name__)
data = { data = {
@@ -513,12 +546,10 @@ class TTSService:
if error: if error:
data["error"] = error data["error"] = error
logger.info(f"Emitting TTS socket event: {event} with data: {data}") logger.info("Emitting TTS socket event: %s with data: %s", event, data)
await socket_manager.broadcast_to_all(event, data) await socket_manager.broadcast_to_all(event, data)
logger.info(f"Successfully emitted TTS socket event: {event}") logger.info("Successfully emitted TTS socket event: %s", event)
except Exception as e: except Exception:
# Don't fail TTS processing if socket emission fails # Don't fail TTS processing if socket emission fails
from app.core.logging import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True) logger.exception("Failed to emit TTS socket event %s", event)

View File

@@ -14,7 +14,7 @@ logger = get_logger(__name__)
class TTSProcessor: class TTSProcessor:
"""Background processor for handling TTS generation queue with concurrency control.""" """Background processor for handling TTS generation queue."""
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the TTS processor.""" """Initialize the TTS processor."""