522 lines
18 KiB
Python
522 lines
18 KiB
Python
"""TTS service implementation."""
|
|
|
|
import asyncio
|
|
import io
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from gtts import gTTS
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.models.sound import Sound
|
|
from app.models.tts import TTS
|
|
from app.repositories.sound import SoundRepository
|
|
from app.repositories.tts import TTSRepository
|
|
from app.services.sound_normalizer import SoundNormalizerService
|
|
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
|
|
|
from .base import TTSProvider
|
|
from .providers import GTTSProvider
|
|
|
|
# Longest text, in characters, that TTSService.create_tts_request accepts.
MAX_TEXT_LENGTH: int = 1_000

# Presumably the cap for sound names derived from TTS text (those names keep
# the first 50 characters of the text).
MAX_NAME_LENGTH: int = 50
|
|
|
|
|
|
class TTSService:
    """Text-to-Speech service with provider management.

    Validates and stores TTS requests, synthesizes audio, creates the matching
    ``Sound`` record, triggers normalization and tracks each ``TTS`` record's
    status ("pending" -> "processing" -> "completed" / "failed").
    """

    # Where generated TTS audio is written (relative to the working directory).
    _ORIGINALS_DIR = Path("sounds/originals/text_to_speech")
    # Where normalized versions of TTS sounds are looked up for deletion.
    _NORMALIZED_DIR = Path("sounds/normalized/text_to_speech")

    # Strong references to fire-and-forget tasks created by
    # ``_queue_tts_processing``. asyncio only keeps weak references to tasks,
    # so the set lives on the class rather than on a (request-scoped) instance
    # that may be discarded while the task is still running.
    _background_tasks: set[asyncio.Task[None]] = set()

    def __init__(self, session: AsyncSession) -> None:
        """Initialize TTS service.

        Args:
            session: Database session used by the repositories and by the
                status-tracking queries of this service.
        """
        self.session = session
        self.sound_repo = SoundRepository(session)
        self.tts_repo = TTSRepository(session)
        self.providers: dict[str, TTSProvider] = {}

        # Register default providers
        self._register_default_providers()

    def _register_default_providers(self) -> None:
        """Register default TTS providers (currently only gTTS)."""
        self.register_provider(GTTSProvider())

    def register_provider(self, provider: TTSProvider) -> None:
        """Register a TTS provider under its ``name``, replacing any previous one.

        Args:
            provider: TTS provider instance
        """
        self.providers[provider.name] = provider

    def get_providers(self) -> dict[str, TTSProvider]:
        """Get a shallow copy of all registered providers, keyed by name."""
        return self.providers.copy()

    def get_provider(self, name: str) -> TTSProvider | None:
        """Get a specific provider by name, or None if it is not registered."""
        return self.providers.get(name)

    @staticmethod
    def _get_logger() -> Any:
        """Return this module's logger (imported lazily, like elsewhere in this class)."""
        from app.core.logging import get_logger

        return get_logger(__name__)

    async def _get_tts(self, tts_id: int) -> TTS | None:
        """Return the TTS record with ``tts_id`` from this session, if any."""
        stmt = select(TTS).where(TTS.id == tts_id)
        result = await self.session.exec(stmt)
        return result.first()

    async def create_tts_request(
        self,
        text: str,
        user_id: int,
        provider: str = "gtts",
        **options: Any,
    ) -> dict[str, Any]:
        """Create a TTS request that will be processed in the background.

        The record is committed with status "pending" and then handed to the
        ``tts_processor`` queue.

        Args:
            text: Text to convert to speech
            user_id: ID of user creating the sound
            provider: TTS provider name
            **options: Provider-specific options, stored on the TTS record

        Returns:
            Dictionary with the created TTS record (``"tts"``) and a
            ``"message"``.

        Raises:
            ValueError: If the provider is not registered, or the text is too
                long or blank.
            Exception: If request creation fails
        """
        provider_not_found_msg = f"Provider '{provider}' not found"
        if provider not in self.providers:
            raise ValueError(provider_not_found_msg)

        text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
        if len(text) > MAX_TEXT_LENGTH:
            raise ValueError(text_too_long_msg)

        empty_text_msg = "Text cannot be empty"
        if not text.strip():
            raise ValueError(empty_text_msg)

        # Create TTS record with pending status
        tts = TTS(
            text=text,
            provider=provider,
            options=options,
            status="pending",
            sound_id=None,  # Will be set when processing completes
            user_id=user_id,
        )
        self.session.add(tts)
        await self.session.commit()
        await self.session.refresh(tts)

        # Queue for background processing using the TTS processor
        # (imported here, presumably to avoid an import cycle).
        if tts.id is not None:
            from app.services.tts_processor import tts_processor

            await tts_processor.queue_tts(tts.id)

        return {"tts": tts, "message": "TTS generation queued successfully"}

    async def _queue_tts_processing(self, tts_id: int) -> None:
        """Schedule ``_process_tts_in_background`` as a fire-and-forget task.

        Not used by ``create_tts_request``, which queues through
        ``tts_processor`` instead.
        """
        task = asyncio.create_task(self._process_tts_in_background(tts_id))
        # Keep a strong reference until the task finishes so it cannot be
        # garbage-collected mid-execution.
        background_tasks = type(self)._background_tasks
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

    async def _process_tts_in_background(self, tts_id: int) -> None:
        """Process one TTS record in a fresh database session.

        Delegates to ``process_tts_generation`` so the record's status is
        updated and socket events are emitted. A successful run therefore ends
        as "completed" instead of remaining "pending", where
        ``get_pending_tts`` would pick it up again. Errors are logged, never
        raised, because nothing awaits this task.
        """
        from app.core.database import get_session_factory

        try:
            # Create a new session for background processing
            session_factory = get_session_factory()
            async with session_factory() as background_session:
                tts_service = TTSService(background_session)

                # Nothing to do if the record no longer exists.
                if await tts_service._get_tts(tts_id) is None:
                    return

                await tts_service.process_tts_generation(tts_id)
        except Exception:
            # Best-effort background job: record the failure instead of
            # silently discarding it.
            self._get_logger().error(
                f"Background TTS processing failed for TTS {tts_id}",
                exc_info=True,
            )

    @staticmethod
    def _synthesize_gtts_to_file(
        text: str,
        path: Path,
        lang: str,
        tld: str,
        slow: bool,
    ) -> None:
        """Synthesize ``text`` with gTTS and write the audio bytes to ``path``.

        Blocking (network + file I/O); called through ``asyncio.to_thread``.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        buffer = io.BytesIO()
        gTTS(text=text, lang=lang, tld=tld, slow=slow).write_to_fp(buffer)
        path.write_bytes(buffer.getvalue())

    async def _generate_tts_sync(
        self,
        text: str,
        provider: str,
        user_id: int,
        options: dict[str, Any],
    ) -> Sound:
        """Generate the audio file, create its Sound record and normalize it.

        NOTE: synthesis always uses gTTS directly; the registered provider only
        supplies the output file extension.

        Args:
            text: Text to synthesize.
            provider: Name of a registered provider.
            user_id: Owner recorded on a newly created sound.
            options: Reads ``lang`` (default "en"), ``tld`` (default "com") and
                ``slow`` (default False).

        Returns:
            The newly created Sound, or an existing Sound whose audio hash
            matches (see ``_create_sound_record_complete``).

        Raises:
            ValueError: If ``provider`` is not registered.
        """
        tts_provider = self.providers.get(provider)
        if tts_provider is None:
            provider_not_found_msg = f"Provider '{provider}' not found"
            raise ValueError(provider_not_found_msg)

        options = options or {}
        original_path = (
            self._ORIGINALS_DIR / f"{uuid.uuid4()}.{tts_provider.file_extension}"
        )

        # gTTS performs a blocking HTTP request and the file write blocks too;
        # run both in a worker thread so the event loop is not stalled.
        # NOTE(review): options come from the request; per the gTTS docs
        # ``tld`` selects the Google Translate host - confirm it is validated
        # upstream.
        await asyncio.to_thread(
            self._synthesize_gtts_to_file,
            text,
            original_path,
            options.get("lang", "en"),
            options.get("tld", "com"),
            options.get("slow", False),
        )

        # Create Sound record with proper metadata
        sound = await self._create_sound_record_complete(
            original_path, text, provider, user_id
        )

        # Normalization failures are logged, not raised.
        await self._normalize_sound_safe(sound.id)

        return sound

    async def get_user_tts_history(
        self,
        user_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> list[TTS]:
        """Get TTS history for a user.

        Args:
            user_id: User ID
            limit: Maximum number of records
            offset: Offset for pagination

        Returns:
            List of TTS records
        """
        result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
        return list(result)

    @staticmethod
    def _build_sound_data(
        audio_path: Path,
        text: str,
        user_id: int,
        duration: Any,
        size: Any,
        file_hash: str,
    ) -> dict[str, Any]:
        """Return the column values for a new, not yet normalized TTS Sound.

        The name is the first ``MAX_NAME_LENGTH`` characters of ``text``, with
        "..." appended when the text was cut.
        """
        name = text[:MAX_NAME_LENGTH]
        if len(text) > MAX_NAME_LENGTH:
            name += "..."

        return {
            "type": "TTS",
            "name": name,
            "filename": audio_path.name,
            "duration": duration,
            "size": size,
            "hash": file_hash,
            "user_id": user_id,
            "is_deletable": True,
            "is_music": False,  # TTS is speech, not music
            "is_normalized": False,
            "play_count": 0,
        }

    async def _create_sound_record(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
        file_hash: str,
    ) -> Sound:
        """Create a Sound record for the TTS audio using a precomputed hash.

        ``provider`` is accepted for signature compatibility but not stored.
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=get_audio_duration(audio_path),
            size=get_file_size(audio_path),
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_simple(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a Sound record without reading the audio file.

        Duration and size are stored as 0 and a random UUID stands in for the
        hash. ``provider`` is accepted for signature compatibility only.
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=0,  # Skip duration calculation for now
            size=0,  # Skip size calculation for now
            file_hash=str(uuid.uuid4()),  # Use UUID as temporary hash
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_complete(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a Sound record for the TTS audio with complete metadata.

        If a sound with the same file hash already exists, the freshly written
        file is deleted and the existing sound is returned instead. That sound
        may therefore be shared by several TTS records (and may belong to
        another user). ``provider`` is accepted for signature compatibility
        only.
        """
        # Get audio metadata
        duration = get_audio_duration(audio_path)
        size = get_file_size(audio_path)
        file_hash = get_file_hash(audio_path)

        existing_sound = await self.sound_repo.get_by_hash(file_hash)
        if existing_sound:
            # Drop the duplicate audio file; reuse the existing record.
            audio_path.unlink(missing_ok=True)
            return existing_sound

        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=duration,
            size=size,
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _normalize_sound_safe(self, sound_id: int | None) -> None:
        """Normalize the TTS sound; problems are logged and never raised."""
        logger = self._get_logger()
        if sound_id is None:
            return

        try:
            # Get fresh sound object from database for normalization
            sound = await self.sound_repo.get_by_id(sound_id)
            if not sound:
                return

            normalizer_service = SoundNormalizerService(self.session)
            result = await normalizer_service.normalize_sound(sound)

            if result["status"] == "error":
                logger.warning(
                    f"Failed to normalize TTS sound {sound_id}: {result.get('error')}"
                )
        except Exception as e:
            # Don't fail the TTS generation if normalization fails
            logger.error(
                f"Exception during TTS sound normalization {sound_id}: {e}",
                exc_info=True,
            )

    async def _normalize_sound(self, sound_id: int) -> None:
        """Normalize the TTS sound without raising (alias of ``_normalize_sound_safe``)."""
        await self._normalize_sound_safe(sound_id)

    async def _sound_shared_with_other_tts(self, sound_id: int, tts_id: int) -> bool:
        """Return True if a TTS record other than ``tts_id`` references ``sound_id``."""
        stmt = select(TTS).where(TTS.sound_id == sound_id, TTS.id != tts_id)
        result = await self.session.exec(stmt)
        return result.first() is not None

    async def delete_tts(self, tts_id: int, user_id: int) -> None:
        """Delete a TTS generation together with its sound and audio files.

        The sound (and its files) is only removed when no other TTS record
        references it, because ``_create_sound_record_complete`` reuses
        existing sounds with identical audio.

        Raises:
            ValueError: If no TTS record with ``tts_id`` exists.
            PermissionError: If the record belongs to a different user.
        """
        tts = await self.tts_repo.get_by_id(tts_id)
        if not tts:
            tts_not_found_msg = f"TTS with ID {tts_id} not found"
            raise ValueError(tts_not_found_msg)

        if tts.user_id != user_id:
            permission_msg = "You don't have permission to delete this TTS generation"
            raise PermissionError(permission_msg)

        if tts.sound_id and not await self._sound_shared_with_other_tts(
            tts.sound_id, tts_id
        ):
            sound = await self.sound_repo.get_by_id(tts.sound_id)
            if sound:
                await self._delete_sound_files(sound)
                await self.sound_repo.delete(sound)

        await self.tts_repo.delete(tts)

    async def _delete_sound_files(self, sound: Sound) -> None:
        """Delete the original and (if recorded) normalized audio files of ``sound``.

        Files that are already missing are ignored.
        """
        (self._ORIGINALS_DIR / sound.filename).unlink(missing_ok=True)

        if sound.normalized_filename:
            (self._NORMALIZED_DIR / sound.normalized_filename).unlink(missing_ok=True)

    async def get_pending_tts(self) -> list[TTS]:
        """Get all pending TTS generations, oldest ``created_at`` first."""
        stmt = select(TTS).where(TTS.status == "pending").order_by(TTS.created_at)
        result = await self.session.exec(stmt)
        return list(result.all())

    async def mark_tts_processing(self, tts_id: int) -> None:
        """Mark a TTS generation as processing (no-op if it does not exist)."""
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "processing"
            self.session.add(tts)
            await self.session.commit()

    async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None:
        """Mark a TTS generation as completed, link its sound and clear any error."""
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "completed"
            tts.sound_id = sound_id
            tts.error = None
            self.session.add(tts)
            await self.session.commit()

    async def mark_tts_failed(self, tts_id: int, error_message: str) -> None:
        """Mark a TTS generation as failed and store the error message."""
        tts = await self._get_tts(tts_id)
        if tts:
            tts.status = "failed"
            tts.error = error_message
            self.session.add(tts)
            await self.session.commit()

    async def reset_stuck_tts(self) -> int:
        """Reset all "processing" TTS generations back to "pending".

        Returns:
            Number of records reset.
        """
        stmt = select(TTS).where(TTS.status == "processing")
        result = await self.session.exec(stmt)
        stuck_tts = list(result.all())

        for tts in stuck_tts:
            tts.status = "pending"
            tts.error = None
            self.session.add(tts)

        await self.session.commit()
        return len(stuck_tts)

    async def _record_failure(self, tts_id: int, error_message: str) -> None:
        """Best-effort: roll back the session, then mark ``tts_id`` as failed.

        The rollback clears a session left unusable by a failed flush/commit.
        Without it, ``mark_tts_failed`` would raise and hide the original
        error. Errors raised here are logged, not propagated.
        """
        try:
            await self.session.rollback()
            await self.mark_tts_failed(tts_id, error_message)
        except Exception:
            self._get_logger().error(
                f"Failed to mark TTS {tts_id} as failed", exc_info=True
            )

    async def process_tts_generation(self, tts_id: int) -> None:
        """Process a TTS generation (used by the processor).

        Marks the record "processing" and generates the sound. On success it
        marks the record "completed" and emits ``tts_completed``. On any error
        it marks the record "failed", emits ``tts_failed`` and re-raises the
        original exception.

        Raises:
            ValueError: If no TTS record with ``tts_id`` exists.
        """
        await self.mark_tts_processing(tts_id)

        try:
            tts = await self._get_tts(tts_id)
            if not tts:
                tts_not_found_msg = f"TTS with ID {tts_id} not found"
                raise ValueError(tts_not_found_msg)

            sound = await self._generate_tts_sync(
                tts.text,
                tts.provider,
                tts.user_id,
                tts.options,
            )

            # Capture sound ID before further session activity touches the instance
            sound_id = sound.id

            await self.mark_tts_completed(tts_id, sound_id)
            await self._emit_tts_event("tts_completed", tts_id, sound_id)
        except Exception as e:
            error_message = str(e)
            await self._record_failure(tts_id, error_message)
            await self._emit_tts_event("tts_failed", tts_id, None, error_message)
            raise

    async def _emit_tts_event(
        self,
        event: str,
        tts_id: int,
        sound_id: int | None = None,
        error: str | None = None,
    ) -> None:
        """Broadcast a TTS status change to all socket clients.

        The payload holds ``tts_id`` and ``sound_id``, plus ``error`` when one
        is given. Failures are logged and never propagated, so socket problems
        cannot fail TTS processing.
        """
        logger = self._get_logger()
        data: dict[str, Any] = {"tts_id": tts_id, "sound_id": sound_id}
        if error:
            data["error"] = error

        try:
            from app.services.socket import socket_manager

            logger.info(f"Emitting TTS socket event: {event} with data: {data}")
            await socket_manager.broadcast_to_all(event, data)
            logger.info(f"Successfully emitted TTS socket event: {event}")
        except Exception as e:
            # Don't fail TTS processing if socket emission fails
            logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)