"""TTS service implementation."""

import asyncio
import io
import uuid
from pathlib import Path
from typing import Any

from gtts import gTTS
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.sound import Sound
from app.models.tts import TTS
from app.repositories.sound import SoundRepository
from app.repositories.tts import TTSRepository
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size

from .base import TTSProvider
from .providers import GTTSProvider

# Constants
MAX_TEXT_LENGTH = 1000
MAX_NAME_LENGTH = 50

# Directories (relative to the working directory) holding TTS audio files.
ORIGINAL_TTS_DIR = Path("sounds/originals/text_to_speech")
NORMALIZED_TTS_DIR = Path("sounds/normalized/text_to_speech")


def _get_logger() -> Any:
    """Return the project logger for this module.

    The import is done lazily, mirroring how this module already imports
    its other ``app.core`` / ``app.services`` helpers at call time.
    """
    from app.core.logging import get_logger

    return get_logger(__name__)


def _synthesize_to_file(
    text: str,
    options: dict[str, Any],
    destination: Path,
) -> None:
    """Generate speech with gTTS and write it to ``destination``.

    This function is blocking (gTTS call plus file I/O) and is meant to be
    run in a worker thread.

    Args:
        text: Text to convert to speech
        options: gTTS options; ``lang``, ``tld`` and ``slow`` are read,
            defaulting to ``"en"``, ``"com"`` and ``False``
        destination: Path of the audio file to create
    """
    destination.parent.mkdir(parents=True, exist_ok=True)

    tts_instance = gTTS(
        text=text,
        lang=options.get("lang", "en"),
        tld=options.get("tld", "com"),
        slow=options.get("slow", False),
    )

    # Render into memory first so a failed generation never leaves a
    # partially written file on disk.
    buffer = io.BytesIO()
    tts_instance.write_to_fp(buffer)
    destination.write_bytes(buffer.getvalue())


class TTSService:
    """Text-to-Speech service with provider management."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize TTS service.

        Args:
            session: Database session
        """
        self.session = session
        self.sound_repo = SoundRepository(session)
        self.tts_repo = TTSRepository(session)
        self.providers: dict[str, TTSProvider] = {}
        # Strong references to in-flight background tasks so they are not
        # garbage collected before they finish.
        self._background_tasks: set[asyncio.Task[None]] = set()

        # Register default providers
        self._register_default_providers()

    def _register_default_providers(self) -> None:
        """Register default TTS providers."""
        self.register_provider(GTTSProvider())

    def register_provider(self, provider: TTSProvider) -> None:
        """Register a TTS provider.

        A provider registered under an already-used name replaces the
        previous one.

        Args:
            provider: TTS provider instance
        """
        self.providers[provider.name] = provider

    def get_providers(self) -> dict[str, TTSProvider]:
        """Get a shallow copy of all registered providers, keyed by name."""
        return self.providers.copy()

    def get_provider(self, name: str) -> TTSProvider | None:
        """Get a specific provider by name, or ``None`` if not registered."""
        return self.providers.get(name)

    async def create_tts_request(
        self,
        text: str,
        user_id: int,
        provider: str = "gtts",
        **options: Any,
    ) -> dict[str, Any]:
        """Create a TTS request that will be processed in the background.

        Args:
            text: Text to convert to speech
            user_id: ID of user creating the sound
            provider: TTS provider name
            **options: Provider-specific options

        Returns:
            Dictionary with the created TTS record under ``"tts"`` and a
            human-readable ``"message"``

        Raises:
            ValueError: If provider not found, text too long or text empty
            Exception: If request creation fails
        """
        if provider not in self.providers:
            provider_not_found_msg = f"Provider '{provider}' not found"
            raise ValueError(provider_not_found_msg)

        if len(text) > MAX_TEXT_LENGTH:
            text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
            raise ValueError(text_too_long_msg)

        if not text.strip():
            empty_text_msg = "Text cannot be empty"
            raise ValueError(empty_text_msg)

        # Create TTS record with pending status
        tts = TTS(
            text=text,
            provider=provider,
            options=options,
            status="pending",
            sound_id=None,  # Will be set when processing completes
            user_id=user_id,
        )
        self.session.add(tts)
        await self.session.commit()
        await self.session.refresh(tts)

        # Queue for background processing using the TTS processor
        if tts.id is not None:
            # Imported at call time rather than at module level.
            from app.services.tts_processor import tts_processor

            await tts_processor.queue_tts(tts.id)

        return {"tts": tts, "message": "TTS generation queued successfully"}

    async def _queue_tts_processing(self, tts_id: int) -> None:
        """Schedule TTS processing as an asyncio task on the running loop.

        Args:
            tts_id: ID of the TTS record to process
        """
        # For now, process immediately in a different way
        # This could be moved to a proper background queue later
        task = asyncio.create_task(self._process_tts_in_background(tts_id))

        # Store reference to prevent garbage collection
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _process_tts_in_background(self, tts_id: int) -> None:
        """Process TTS generation in background using a dedicated session.

        Failures are logged and never propagated to the caller.

        Args:
            tts_id: ID of the TTS record to process
        """
        from app.core.database import get_session_factory

        try:
            # Create a new session for background processing
            session_factory = get_session_factory()
            async with session_factory() as background_session:
                tts_service = TTSService(background_session)

                # Get the TTS record
                stmt = select(TTS).where(TTS.id == tts_id)
                result = await background_session.exec(stmt)
                tts = result.first()
                if not tts:
                    return

                sound = await tts_service._generate_tts_sync(
                    tts.text,
                    tts.provider,
                    tts.user_id,
                    tts.options,
                )

                # Update the TTS record with the sound ID
                if sound.id is not None:
                    tts.sound_id = sound.id
                    background_session.add(tts)
                    await background_session.commit()
        except Exception as e:
            # Best effort: the task must not raise, but the failure has to
            # be visible to operators instead of being silently dropped.
            _get_logger().error(
                f"Background TTS processing failed for TTS {tts_id}: {e}",
                exc_info=True,
            )

    async def _generate_tts_sync(
        self,
        text: str,
        provider: str,
        user_id: int,
        options: dict[str, Any],
    ) -> Sound:
        """Generate the audio file, create its Sound record and normalize it.

        Audio is always produced with gTTS; the registered provider is only
        consulted for the file extension.

        Args:
            text: Text to convert to speech
            provider: Name of a registered provider
            user_id: ID of the user owning the sound
            options: gTTS options (``lang``, ``tld``, ``slow``)

        Returns:
            The Sound record for the generated audio (an existing record is
            returned when one with the same file hash already exists)

        Raises:
            KeyError: If ``provider`` is not registered
            Exception: If audio generation or record creation fails
        """
        tts_provider = self.providers[provider]

        # Create UUID filename
        original_filename = f"{uuid.uuid4()}.{tts_provider.file_extension}"
        original_path = ORIGINAL_TTS_DIR / original_filename

        # gTTS and the file write are blocking; run them in a worker thread
        # so the event loop keeps serving other requests meanwhile.
        await asyncio.to_thread(
            _synthesize_to_file,
            text,
            options or {},
            original_path,
        )

        # Create Sound record with proper metadata
        try:
            sound = await self._create_sound_record_complete(
                original_path,
                text,
                provider,
                user_id,
            )
        except Exception:
            # Without a Sound record nothing references the file anymore;
            # remove it so failed generations do not pile up on disk.
            original_path.unlink(missing_ok=True)
            raise

        # Normalize the sound
        await self._normalize_sound_safe(sound.id)

        return sound

    async def get_user_tts_history(
        self,
        user_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> list[TTS]:
        """Get TTS history for a user.

        Args:
            user_id: User ID
            limit: Maximum number of records
            offset: Offset for pagination

        Returns:
            List of TTS records
        """
        result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
        return list(result)

    @staticmethod
    def _build_sound_name(text: str) -> str:
        """Derive a display name from the spoken text.

        The text is cut to ``MAX_NAME_LENGTH`` characters (with ``...``
        appended when it was longer), whitespace runs are collapsed to a
        single space and every word is capitalized.
        """
        truncated = text[:MAX_NAME_LENGTH] + (
            "..." if len(text) > MAX_NAME_LENGTH else ""
        )
        return " ".join(word.capitalize() for word in truncated.split())

    @staticmethod
    def _build_sound_data(
        audio_path: Path,
        text: str,
        user_id: int,
        duration: Any,
        size: Any,
        file_hash: str,
    ) -> dict[str, Any]:
        """Build the field dictionary used to create a TTS Sound record."""
        return {
            "type": "TTS",
            "name": TTSService._build_sound_name(text),
            "filename": audio_path.name,
            "duration": duration,
            "size": size,
            "hash": file_hash,
            "user_id": user_id,
            "is_deletable": True,
            "is_music": False,  # TTS is speech, not music
            "is_normalized": False,
            "play_count": 0,
        }

    async def _create_sound_record(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
        file_hash: str,
    ) -> Sound:
        """Create a Sound record for the TTS audio using a given hash.

        Args:
            audio_path: Path of the generated audio file
            text: Spoken text, used to derive the sound name
            provider: Provider name (currently unused)
            user_id: ID of the user owning the sound
            file_hash: Hash to store on the record

        Returns:
            The created Sound record
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=get_audio_duration(audio_path),
            size=get_file_size(audio_path),
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_simple(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a Sound record for the TTS audio with minimal processing.

        Duration and size are stored as ``0`` and a random UUID is used as a
        temporary hash; the audio file is not inspected.

        Args:
            audio_path: Path of the generated audio file
            text: Spoken text, used to derive the sound name
            provider: Provider name (currently unused)
            user_id: ID of the user owning the sound

        Returns:
            The created Sound record
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=0,  # Skip duration calculation for now
            size=0,  # Skip size calculation for now
            file_hash=str(uuid.uuid4()),  # Use UUID as temporary hash
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_complete(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a Sound record for the TTS audio with complete metadata.

        When a sound with the same file hash already exists, the new audio
        file is deleted and the existing record is returned instead.

        Args:
            audio_path: Path of the generated audio file
            text: Spoken text, used to derive the sound name
            provider: Provider name (currently unused)
            user_id: ID of the user owning the sound

        Returns:
            The created Sound record, or the existing duplicate
        """
        file_hash = get_file_hash(audio_path)

        # Check for a duplicate first: there is no point probing duration
        # and size of a file that is about to be discarded.
        existing_sound = await self.sound_repo.get_by_hash(file_hash)
        if existing_sound:
            # Clean up the temporary file since we have a duplicate
            audio_path.unlink(missing_ok=True)
            return existing_sound

        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=get_audio_duration(audio_path),
            size=get_file_size(audio_path),
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _normalize_sound_safe(self, sound_id: int) -> None:
        """Normalize the TTS sound with error handling.

        Normalization problems are logged and never raised, so a failed
        normalization does not fail the TTS generation.

        Args:
            sound_id: ID of the sound to normalize
        """
        logger = _get_logger()
        try:
            # Get fresh sound object from database for normalization
            sound = await self.sound_repo.get_by_id(sound_id)
            if not sound:
                return

            normalizer_service = SoundNormalizerService(self.session)
            result = await normalizer_service.normalize_sound(sound)

            if result["status"] == "error":
                logger.warning(
                    f"Failed to normalize TTS sound {sound_id}: "
                    f"{result.get('error')}",
                )
        except Exception as e:
            # Don't fail the TTS generation if normalization fails
            logger.error(
                f"Exception during TTS sound normalization {sound_id}: {e}",
                exc_info=True,
            )

    async def _normalize_sound(self, sound_id: int) -> None:
        """Normalize the TTS sound.

        Kept for backward compatibility; delegates to
        ``_normalize_sound_safe`` so failures are logged rather than
        silently discarded. Never raises.

        Args:
            sound_id: ID of the sound to normalize
        """
        await self._normalize_sound_safe(sound_id)

    async def delete_tts(self, tts_id: int, user_id: int) -> None:
        """Delete a TTS generation and its associated sound and files.

        Args:
            tts_id: ID of the TTS record to delete
            user_id: ID of the user requesting the deletion

        Raises:
            ValueError: If the TTS record does not exist
            PermissionError: If the record belongs to another user
        """
        # Get the TTS record
        tts = await self.tts_repo.get_by_id(tts_id)
        if not tts:
            raise ValueError(f"TTS with ID {tts_id} not found")

        # Check ownership
        if tts.user_id != user_id:
            raise PermissionError(
                "You don't have permission to delete this TTS generation",
            )

        # If there's an associated sound, delete it and its files
        if tts.sound_id:
            sound = await self.sound_repo.get_by_id(tts.sound_id)
            if sound:
                # Delete the sound files
                await self._delete_sound_files(sound)
                # Delete the sound record
                await self.sound_repo.delete(sound)

        # Delete the TTS record
        await self.tts_repo.delete(tts)

    async def _delete_sound_files(self, sound: Sound) -> None:
        """Delete all files associated with a sound.

        Files that are already gone are ignored.

        Args:
            sound: Sound whose original and normalized files are removed
        """
        # ``missing_ok`` avoids the race between an existence check and the
        # actual unlink.
        (ORIGINAL_TTS_DIR / sound.filename).unlink(missing_ok=True)

        # Delete normalized file if it exists
        if sound.normalized_filename:
            (NORMALIZED_TTS_DIR / sound.normalized_filename).unlink(
                missing_ok=True,
            )

    async def _get_tts(self, tts_id: int) -> TTS | None:
        """Fetch a TTS record by ID through this service's session."""
        stmt = select(TTS).where(TTS.id == tts_id)
        result = await self.session.exec(stmt)
        return result.first()

    async def get_pending_tts(self) -> list[TTS]:
        """Get all pending TTS generations, ordered by ``created_at``."""
        stmt = select(TTS).where(TTS.status == "pending").order_by(TTS.created_at)
        result = await self.session.exec(stmt)
        return list(result.all())

    async def mark_tts_processing(self, tts_id: int) -> None:
        """Mark a TTS generation as processing (no-op if it does not exist)."""
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "processing"
            self.session.add(tts)
            await self.session.commit()

    async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None:
        """Mark a TTS generation as completed (no-op if it does not exist).

        Args:
            tts_id: ID of the TTS record
            sound_id: ID of the generated sound to link to the record
        """
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "completed"
            tts.sound_id = sound_id
            tts.error = None
            self.session.add(tts)
            await self.session.commit()

    async def mark_tts_failed(self, tts_id: int, error_message: str) -> None:
        """Mark a TTS generation as failed (no-op if it does not exist).

        Args:
            tts_id: ID of the TTS record
            error_message: Error description stored on the record
        """
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "failed"
            tts.error = error_message
            self.session.add(tts)
            await self.session.commit()

    async def reset_stuck_tts(self) -> int:
        """Reset stuck TTS generations from processing back to pending.

        Returns:
            Number of records that were reset
        """
        stmt = select(TTS).where(TTS.status == "processing")
        result = await self.session.exec(stmt)
        stuck_tts = list(result.all())

        for tts in stuck_tts:
            tts.status = "pending"
            tts.error = None
            self.session.add(tts)

        await self.session.commit()
        return len(stuck_tts)

    async def process_tts_generation(self, tts_id: int) -> None:
        """Process a TTS generation (used by the processor).

        The record is marked as processing, the audio is generated, and the
        record is then marked completed or failed. A ``tts_completed`` or
        ``tts_failed`` socket event is emitted accordingly.

        Args:
            tts_id: ID of the TTS record to process

        Raises:
            ValueError: If the TTS record does not exist
            Exception: Any generation error, re-raised after the record has
                been marked as failed
        """
        # Mark as processing
        await self.mark_tts_processing(tts_id)

        try:
            # Get the TTS record
            tts = await self._get_tts(tts_id)
            if not tts:
                raise ValueError(f"TTS with ID {tts_id} not found")

            # Generate the TTS
            sound = await self._generate_tts_sync(
                tts.text,
                tts.provider,
                tts.user_id,
                tts.options,
            )

            # Capture sound ID before session issues
            sound_id = sound.id

            # Mark as completed
            await self.mark_tts_completed(tts_id, sound_id)

            # Emit socket event for completion
            await self._emit_tts_event("tts_completed", tts_id, sound_id)
        except Exception as e:
            # Mark as failed
            await self.mark_tts_failed(tts_id, str(e))

            # Emit socket event for failure
            await self._emit_tts_event("tts_failed", tts_id, None, str(e))
            raise

    async def _emit_tts_event(
        self,
        event: str,
        tts_id: int,
        sound_id: int | None = None,
        error: str | None = None,
    ) -> None:
        """Emit a socket event for TTS status change.

        Emission failures are logged and never raised, so they cannot fail
        TTS processing.

        Args:
            event: Socket event name
            tts_id: ID of the TTS record the event is about
            sound_id: ID of the generated sound, if any
            error: Error description; added to the payload only when truthy
        """
        logger = _get_logger()
        try:
            from app.services.socket import socket_manager

            # Values are a mix of ints and strings, hence the explicit type.
            data: dict[str, Any] = {
                "tts_id": tts_id,
                "sound_id": sound_id,
            }
            if error:
                data["error"] = error

            logger.info(f"Emitting TTS socket event: {event} with data: {data}")
            await socket_manager.broadcast_to_all(event, data)
            logger.info(f"Successfully emitted TTS socket event: {event}")
        except Exception as e:
            # Don't fail TTS processing if socket emission fails
            logger.error(
                f"Failed to emit TTS socket event {event}: {e}",
                exc_info=True,
            )