# sdb2-backend/app/services/tts/service.py (Python, 522 lines, 18 KiB)
"""TTS service implementation."""
import asyncio
import io
import uuid
from pathlib import Path
from typing import Any
from gtts import gTTS
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound
from app.models.tts import TTS
from app.repositories.sound import SoundRepository
from app.repositories.tts import TTSRepository
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
from .base import TTSProvider
from .providers import GTTSProvider
# Constants
# Upper bound on the number of characters accepted in one TTS request.
MAX_TEXT_LENGTH: int = 1000
# Length limit for sound names; matches the 50-character truncation applied
# to the names of sounds generated from TTS text.
MAX_NAME_LENGTH: int = 50
class TTSService:
    """Text-to-Speech service with provider management.

    Validates and stores TTS requests, generates the audio, creates the
    resulting ``Sound`` records, tracks request status for the background
    processor and removes sounds/files when a generation is deleted.
    """

    # Where generated audio files are written (relative path, so resolved
    # against the process working directory).
    _ORIGINAL_DIR = Path("sounds/originals/text_to_speech")
    # Where normalized copies of TTS sounds are looked up on deletion.
    _NORMALIZED_DIR = Path("sounds/normalized/text_to_speech")

    # Strong references to tasks started by ``_queue_tts_processing``. The
    # event loop only keeps weak references to tasks, and a per-instance set
    # would only live as long as the service instance that created it, so
    # the set is deliberately shared at class level.
    _background_tasks: set[asyncio.Task[None]] = set()

    def __init__(self, session: AsyncSession) -> None:
        """Initialize TTS service.

        Args:
            session: Database session shared by this service and its
                repositories.
        """
        self.session = session
        self.sound_repo = SoundRepository(session)
        self.tts_repo = TTSRepository(session)
        self.providers: dict[str, TTSProvider] = {}
        # Register default providers
        self._register_default_providers()

    def _register_default_providers(self) -> None:
        """Register default TTS providers."""
        self.register_provider(GTTSProvider())

    def register_provider(self, provider: TTSProvider) -> None:
        """Register a TTS provider under its ``name``.

        Args:
            provider: TTS provider instance; replaces any provider already
                registered with the same name.
        """
        self.providers[provider.name] = provider

    def get_providers(self) -> dict[str, TTSProvider]:
        """Return a shallow copy of the registered providers, keyed by name."""
        return self.providers.copy()

    def get_provider(self, name: str) -> TTSProvider | None:
        """Return the provider registered as ``name``, or ``None``."""
        return self.providers.get(name)

    @staticmethod
    def _logger() -> Any:
        """Return the logger for this module (imported lazily, as before)."""
        from app.core.logging import get_logger

        return get_logger(__name__)

    async def create_tts_request(
        self,
        text: str,
        user_id: int,
        provider: str = "gtts",
        **options: Any,
    ) -> dict[str, Any]:
        """Create a TTS request that will be processed in the background.

        The request is stored with status ``"pending"`` and handed to the TTS
        processor queue; the audio itself is generated later.

        Args:
            text: Text to convert to speech.
            user_id: ID of user creating the sound.
            provider: TTS provider name.
            **options: Provider-specific options, stored on the TTS record.

        Returns:
            Dictionary with the persisted ``TTS`` record under ``"tts"`` and a
            status message under ``"message"``.

        Raises:
            ValueError: If the provider is unknown, or the text is too long
                or blank.
            Exception: If persisting or queueing the request fails.
        """
        if provider not in self.providers:
            provider_not_found_msg = f"Provider '{provider}' not found"
            raise ValueError(provider_not_found_msg)
        if len(text) > MAX_TEXT_LENGTH:
            text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
            raise ValueError(text_too_long_msg)
        if not text.strip():
            empty_text_msg = "Text cannot be empty"
            raise ValueError(empty_text_msg)

        # Persist as pending; sound_id is set once processing completes.
        tts = TTS(
            text=text,
            provider=provider,
            options=options,
            status="pending",
            sound_id=None,
            user_id=user_id,
        )
        self.session.add(tts)
        await self.session.commit()
        await self.session.refresh(tts)

        if tts.id is not None:
            from app.services.tts_processor import tts_processor

            await tts_processor.queue_tts(tts.id)

        return {"tts": tts, "message": "TTS generation queued successfully"}

    async def _queue_tts_processing(self, tts_id: int) -> None:
        """Schedule ``tts_id`` for processing on a fire-and-forget task.

        In-process alternative to the TTS processor queue that
        ``create_tts_request`` uses.
        """
        task = asyncio.create_task(self._process_tts_in_background(tts_id))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _process_tts_in_background(self, tts_id: int) -> None:
        """Process ``tts_id`` on a fresh database session.

        Delegates to ``process_tts_generation`` so the record's status,
        ``sound_id`` and error are kept up to date. Nothing awaits this task,
        so failures are logged instead of raised. Unknown IDs are ignored.
        """
        from app.core.database import get_session_factory

        try:
            session_factory = get_session_factory()
            async with session_factory() as background_session:
                service = TTSService(background_session)
                if await service._get_tts(tts_id) is None:
                    return
                await service.process_tts_generation(tts_id)
        except Exception as e:
            self._logger().error(
                f"Background TTS processing failed for {tts_id}: {e}",
                exc_info=True,
            )

    @staticmethod
    def _synthesize_with_gtts(text: str, options: dict[str, Any]) -> bytes:
        """Synthesize ``text`` with gTTS and return the encoded audio bytes.

        Blocking call (gTTS fetches the audio over the network); run it off
        the event loop.
        """
        tts_instance = gTTS(
            text=text,
            lang=options.get("lang", "en"),
            tld=options.get("tld", "com"),
            slow=options.get("slow", False),
        )
        buffer = io.BytesIO()
        tts_instance.write_to_fp(buffer)
        return buffer.getvalue()

    async def _generate_tts_sync(
        self,
        text: str,
        provider: str,
        user_id: int,
        options: dict[str, Any],
    ) -> Sound:
        """Generate audio for ``text``, store it and create its ``Sound``.

        The audio is currently always synthesized with gTTS; ``provider`` only
        selects the file extension. Synthesis and the file write run in worker
        threads so the event loop is not blocked.

        Args:
            text: Text to synthesize.
            provider: Name of a registered provider.
            user_id: ID of the user the sound belongs to.
            options: Generation options; ``lang`` (default ``"en"``), ``tld``
                (default ``"com"``) and ``slow`` (default ``False``) are read.

        Returns:
            The new ``Sound``, or an existing one with identical audio.

        Raises:
            KeyError: If ``provider`` is not registered.
        """
        tts_provider = self.providers[provider]

        self._ORIGINAL_DIR.mkdir(parents=True, exist_ok=True)
        original_path = (
            self._ORIGINAL_DIR / f"{uuid.uuid4()}.{tts_provider.file_extension}"
        )

        audio_bytes = await asyncio.to_thread(
            self._synthesize_with_gtts, text, options or {}
        )
        await asyncio.to_thread(original_path.write_bytes, audio_bytes)

        try:
            sound = await self._create_sound_record_complete(
                original_path, text, provider, user_id
            )
        except Exception:
            # Don't leave an orphaned audio file if the record can't be made.
            original_path.unlink(missing_ok=True)
            raise

        if sound.id is not None:
            await self._normalize_sound_safe(sound.id)
        return sound

    async def get_user_tts_history(
        self,
        user_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> list[TTS]:
        """Get TTS history for a user.

        Args:
            user_id: User ID.
            limit: Maximum number of records.
            offset: Offset for pagination.

        Returns:
            List of TTS records.
        """
        result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
        return list(result)

    @staticmethod
    def _build_sound_data(
        audio_path: Path,
        text: str,
        user_id: int,
        *,
        duration: Any,
        size: Any,
        file_hash: str,
    ) -> dict[str, Any]:
        """Return the field values for a new TTS ``Sound`` record.

        The name is the first ``MAX_NAME_LENGTH`` characters of ``text``,
        followed by ``"..."`` when the text was truncated.
        """
        name = text[:MAX_NAME_LENGTH]
        if len(text) > MAX_NAME_LENGTH:
            name += "..."
        return {
            "type": "TTS",
            "name": name,
            "filename": audio_path.name,
            "duration": duration,
            "size": size,
            "hash": file_hash,
            "user_id": user_id,
            "is_deletable": True,
            "is_music": False,  # TTS is speech, not music
            "is_normalized": False,
            "play_count": 0,
        }

    async def _create_sound_record(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
        file_hash: str,
    ) -> Sound:
        """Create a TTS ``Sound`` record using a precomputed ``file_hash``.

        ``provider`` is accepted for interface compatibility but not stored.
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=get_audio_duration(audio_path),
            size=get_file_size(audio_path),
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_simple(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a TTS ``Sound`` record without reading the audio file.

        Duration and size are stored as 0 and a random UUID stands in for the
        file hash. ``provider`` is accepted but not stored.
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=0,
            size=0,
            file_hash=str(uuid.uuid4()),
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_complete(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a TTS ``Sound`` record with full file metadata.

        If a sound with the same file hash already exists, the new file at
        ``audio_path`` is deleted and the existing sound is returned instead.
        ``provider`` is accepted but not stored.
        """
        duration = get_audio_duration(audio_path)
        size = get_file_size(audio_path)
        file_hash = get_file_hash(audio_path)

        existing_sound = await self.sound_repo.get_by_hash(file_hash)
        if existing_sound:
            # Identical audio is already stored: drop the duplicate file.
            audio_path.unlink(missing_ok=True)
            return existing_sound

        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=duration,
            size=size,
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _normalize_sound_safe(self, sound_id: int) -> None:
        """Normalize a TTS sound; failures are logged, never raised.

        Does nothing if the sound no longer exists.
        """
        try:
            sound = await self.sound_repo.get_by_id(sound_id)
            if not sound:
                return
            normalizer_service = SoundNormalizerService(self.session)
            result = await normalizer_service.normalize_sound(sound)
            if result["status"] == "error":
                self._logger().warning(
                    f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}"
                )
        except Exception as e:
            # Normalization is best effort: don't fail the TTS generation.
            self._logger().error(
                f"Exception during TTS sound normalization {sound_id}: {e}",
                exc_info=True,
            )

    async def _normalize_sound(self, sound_id: int) -> None:
        """Normalize the TTS sound (alias of ``_normalize_sound_safe``)."""
        await self._normalize_sound_safe(sound_id)

    async def _is_sound_shared(self, sound_id: int, tts_id: int) -> bool:
        """Return whether a TTS record other than ``tts_id`` uses ``sound_id``."""
        stmt = select(TTS).where(TTS.sound_id == sound_id, TTS.id != tts_id)
        result = await self.session.exec(stmt)
        return result.first() is not None

    async def delete_tts(self, tts_id: int, user_id: int) -> None:
        """Delete a TTS generation and, if unshared, its sound and files.

        Args:
            tts_id: ID of the TTS generation to delete.
            user_id: ID of the requesting user; must own the generation.

        Raises:
            ValueError: If no TTS generation with ``tts_id`` exists.
            PermissionError: If the generation belongs to another user.
        """
        tts = await self.tts_repo.get_by_id(tts_id)
        if not tts:
            raise ValueError(f"TTS with ID {tts_id} not found")
        if tts.user_id != user_id:
            raise PermissionError("You don't have permission to delete this TTS generation")

        # _create_sound_record_complete reuses an existing Sound when the
        # audio hash matches, so several generations (of any user) may point
        # at the same Sound. Only remove it when no other generation does.
        if tts.sound_id and not await self._is_sound_shared(tts.sound_id, tts_id):
            sound = await self.sound_repo.get_by_id(tts.sound_id)
            if sound:
                await self._delete_sound_files(sound)
                await self.sound_repo.delete(sound)

        await self.tts_repo.delete(tts)

    async def _delete_sound_files(self, sound: Sound) -> None:
        """Delete the original and (if recorded) normalized file of ``sound``.

        Files that are already missing are ignored.
        """
        (self._ORIGINAL_DIR / sound.filename).unlink(missing_ok=True)
        if sound.normalized_filename:
            (self._NORMALIZED_DIR / sound.normalized_filename).unlink(missing_ok=True)

    async def _get_tts(self, tts_id: int) -> TTS | None:
        """Return the TTS record with ``tts_id``, or ``None``."""
        stmt = select(TTS).where(TTS.id == tts_id)
        result = await self.session.exec(stmt)
        return result.first()

    async def get_pending_tts(self) -> list[TTS]:
        """Return all pending TTS generations, oldest first."""
        stmt = select(TTS).where(TTS.status == "pending").order_by(TTS.created_at)
        result = await self.session.exec(stmt)
        return list(result.all())

    async def mark_tts_processing(self, tts_id: int) -> None:
        """Mark a TTS generation as processing (no-op if it doesn't exist)."""
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "processing"
            self.session.add(tts)
            await self.session.commit()

    async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None:
        """Mark a TTS generation as completed, link its sound, clear errors."""
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "completed"
            tts.sound_id = sound_id
            tts.error = None
            self.session.add(tts)
            await self.session.commit()

    async def mark_tts_failed(self, tts_id: int, error_message: str) -> None:
        """Mark a TTS generation as failed and record ``error_message``."""
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "failed"
            tts.error = error_message
            self.session.add(tts)
            await self.session.commit()

    async def reset_stuck_tts(self) -> int:
        """Move all ``processing`` generations back to ``pending``.

        Returns:
            Number of generations that were reset.
        """
        stmt = select(TTS).where(TTS.status == "processing")
        result = await self.session.exec(stmt)
        stuck_tts = list(result.all())
        for tts in stuck_tts:
            tts.status = "pending"
            tts.error = None
            self.session.add(tts)
        await self.session.commit()
        return len(stuck_tts)

    async def process_tts_generation(self, tts_id: int) -> None:
        """Process a TTS generation (used by the processor).

        Marks the record ``processing``, generates the audio, then marks it
        ``completed`` or ``failed`` and emits the matching socket event.

        Raises:
            ValueError: If the TTS record does not exist.
            Exception: Any generation error, re-raised after the record has
                been marked failed.
        """
        await self.mark_tts_processing(tts_id)
        try:
            tts = await self._get_tts(tts_id)
            if not tts:
                raise ValueError(f"TTS with ID {tts_id} not found")

            sound = await self._generate_tts_sync(
                tts.text,
                tts.provider,
                tts.user_id,
                tts.options,
            )
            # Capture the ID before later commits can expire the ORM object.
            sound_id = sound.id

            await self.mark_tts_completed(tts_id, sound_id)
            await self._emit_tts_event("tts_completed", tts_id, sound_id)
        except Exception as e:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; otherwise mark_tts_failed would raise as well and
            # the record would stay stuck in "processing".
            await self.session.rollback()
            await self.mark_tts_failed(tts_id, str(e))
            await self._emit_tts_event("tts_failed", tts_id, None, str(e))
            raise

    async def _emit_tts_event(
        self,
        event: str,
        tts_id: int,
        sound_id: int | None = None,
        error: str | None = None,
    ) -> None:
        """Broadcast a TTS status change via ``socket_manager.broadcast_to_all``.

        Best effort: emission failures are logged and never propagated, so
        they cannot fail TTS processing.
        """
        logger = self._logger()
        data: dict[str, Any] = {"tts_id": tts_id, "sound_id": sound_id}
        if error:
            data["error"] = error
        try:
            from app.services.socket import socket_manager

            logger.info(f"Emitting TTS socket event: {event} with data: {data}")
            await socket_manager.broadcast_to_all(event, data)
            logger.info(f"Successfully emitted TTS socket event: {event}")
        except Exception as e:
            logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)