refactor: Improve code readability and structure across TTS modules
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
"""Base TTS provider interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
# Type alias for TTS options
|
||||
TTSOptions = dict[str, str | bool | int | float]
|
||||
|
||||
|
||||
class TTSProvider(ABC):
|
||||
"""Abstract base class for TTS providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_speech(self, text: str, **options: Any) -> bytes:
|
||||
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
|
||||
"""Generate speech from text with provider-specific options.
|
||||
|
||||
Args:
|
||||
@@ -25,7 +27,7 @@ class TTSProvider(ABC):
|
||||
"""Return list of supported language codes."""
|
||||
|
||||
@abstractmethod
|
||||
def get_option_schema(self) -> dict[str, Any]:
|
||||
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
|
||||
"""Return schema for provider-specific options."""
|
||||
|
||||
@property
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
from gtts import gTTS
|
||||
|
||||
from ..base import TTSProvider
|
||||
from app.services.tts.base import TTSProvider
|
||||
|
||||
|
||||
class GTTSProvider(TTSProvider):
|
||||
@@ -22,7 +21,7 @@ class GTTSProvider(TTSProvider):
|
||||
"""Return the default file extension for this provider."""
|
||||
return "mp3"
|
||||
|
||||
async def generate_speech(self, text: str, **options: Any) -> bytes:
|
||||
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
|
||||
"""Generate speech from text using Google TTS.
|
||||
|
||||
Args:
|
||||
@@ -38,7 +37,7 @@ class GTTSProvider(TTSProvider):
|
||||
slow = options.get("slow", False)
|
||||
|
||||
# Run TTS generation in thread pool since gTTS is synchronous
|
||||
def _generate():
|
||||
def _generate() -> bytes:
|
||||
tts = gTTS(text=text, lang=lang, tld=tld, slow=slow)
|
||||
fp = io.BytesIO()
|
||||
tts.write_to_fp(fp)
|
||||
@@ -64,7 +63,7 @@ class GTTSProvider(TTSProvider):
|
||||
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu",
|
||||
]
|
||||
|
||||
def get_option_schema(self) -> dict[str, Any]:
|
||||
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
|
||||
"""Return schema for GTTS-specific options."""
|
||||
return {
|
||||
"lang": {
|
||||
|
||||
@@ -10,10 +10,13 @@ from gtts import gTTS
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_session_factory
|
||||
from app.core.logging import get_logger
|
||||
from app.models.sound import Sound
|
||||
from app.models.tts import TTS
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.tts import TTSRepository
|
||||
from app.services.socket import socket_manager
|
||||
from app.services.sound_normalizer import SoundNormalizerService
|
||||
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
|
||||
@@ -25,6 +28,12 @@ MAX_TEXT_LENGTH = 1000
|
||||
MAX_NAME_LENGTH = 50
|
||||
|
||||
|
||||
async def _get_tts_processor() -> object:
    """Return the shared TTS processor object.

    The processor is imported inside the function body instead of at module
    level so that importing this module does not trigger a circular import
    with ``app.services.tts_processor``.

    Returns:
        The ``tts_processor`` object exported by
        ``app.services.tts_processor``.
    """
    # Deferred import: resolved on each call, after both modules have loaded.
    from app.services.tts_processor import tts_processor as processor  # noqa: PLC0415

    return processor
|
||||
|
||||
|
||||
class TTSService:
|
||||
"""Text-to-Speech service with provider management."""
|
||||
|
||||
@@ -69,7 +78,7 @@ class TTSService:
|
||||
text: str,
|
||||
user_id: int,
|
||||
provider: str = "gtts",
|
||||
**options: Any,
|
||||
**options: str | bool | float,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a TTS request that will be processed in the background.
|
||||
|
||||
@@ -114,8 +123,7 @@ class TTSService:
|
||||
|
||||
# Queue for background processing using the TTS processor
|
||||
if tts.id is not None:
|
||||
from app.services.tts_processor import tts_processor
|
||||
|
||||
tts_processor = await _get_tts_processor()
|
||||
await tts_processor.queue_tts(tts.id)
|
||||
|
||||
return {"tts": tts, "message": "TTS generation queued successfully"}
|
||||
@@ -132,8 +140,6 @@ class TTSService:
|
||||
|
||||
async def _process_tts_in_background(self, tts_id: int) -> None:
|
||||
"""Process TTS generation in background."""
|
||||
from app.core.database import get_session_factory
|
||||
|
||||
try:
|
||||
# Create a new session for background processing
|
||||
session_factory = get_session_factory()
|
||||
@@ -164,13 +170,19 @@ class TTSService:
|
||||
|
||||
except Exception:
|
||||
# Log error but don't fail - avoiding print for production
|
||||
pass
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Error processing TTS generation %s", tts_id)
|
||||
|
||||
async def _generate_tts_sync(
|
||||
self, text: str, provider: str, user_id: int, options: dict[str, Any],
|
||||
self,
|
||||
text: str,
|
||||
provider: str,
|
||||
user_id: int,
|
||||
options: dict[str, Any],
|
||||
) -> Sound:
|
||||
"""Generate TTS using a synchronous approach."""
|
||||
# Generate the audio using the provider (avoid async issues by doing it directly)
|
||||
# Generate the audio using the provider
|
||||
# (avoid async issues by doing it directly)
|
||||
tts_provider = self.providers[provider]
|
||||
|
||||
# Create directories if they don't exist
|
||||
@@ -199,20 +211,28 @@ class TTSService:
|
||||
original_path.write_bytes(audio_bytes)
|
||||
|
||||
except Exception:
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Error generating TTS audio")
|
||||
raise
|
||||
|
||||
# Create Sound record with proper metadata
|
||||
sound = await self._create_sound_record_complete(
|
||||
original_path, text, provider, user_id,
|
||||
original_path,
|
||||
text,
|
||||
user_id,
|
||||
)
|
||||
|
||||
# Normalize the sound
|
||||
await self._normalize_sound_safe(sound.id)
|
||||
if sound.id is not None:
|
||||
await self._normalize_sound_safe(sound.id)
|
||||
|
||||
return sound
|
||||
|
||||
async def get_user_tts_history(
|
||||
self, user_id: int, limit: int = 50, offset: int = 0,
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[TTS]:
|
||||
"""Get TTS history for a user.
|
||||
|
||||
@@ -229,7 +249,11 @@ class TTSService:
|
||||
return list(result)
|
||||
|
||||
async def _create_sound_record(
|
||||
self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str,
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
user_id: int,
|
||||
file_hash: str,
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio."""
|
||||
# Get audio metadata
|
||||
@@ -253,11 +277,13 @@ class TTSService:
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
return sound
|
||||
return await self.sound_repo.create(sound_data)
|
||||
|
||||
async def _create_sound_record_simple(
|
||||
self, audio_path: Path, text: str, provider: str, user_id: int,
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
user_id: int,
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio with minimal processing."""
|
||||
# Create sound data with basic info
|
||||
@@ -278,11 +304,13 @@ class TTSService:
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
return sound
|
||||
return await self.sound_repo.create(sound_data)
|
||||
|
||||
async def _create_sound_record_complete(
|
||||
self, audio_path: Path, text: str, provider: str, user_id: int,
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
user_id: int,
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio with complete metadata."""
|
||||
# Get audio metadata
|
||||
@@ -316,8 +344,7 @@ class TTSService:
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
return sound
|
||||
return await self.sound_repo.create(sound_data)
|
||||
|
||||
async def _normalize_sound_safe(self, sound_id: int) -> None:
|
||||
"""Normalize the TTS sound with error handling."""
|
||||
@@ -331,12 +358,16 @@ class TTSService:
|
||||
result = await normalizer_service.normalize_sound(sound)
|
||||
|
||||
if result["status"] == "error":
|
||||
print(
|
||||
f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}",
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(
|
||||
"Warning: Failed to normalize TTS sound %s: %s",
|
||||
sound_id,
|
||||
result.get("error"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception during TTS sound normalization {sound_id}: {e}")
|
||||
except Exception:
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Exception during TTS sound normalization %s", sound_id)
|
||||
# Don't fail the TTS generation if normalization fails
|
||||
|
||||
async def _normalize_sound(self, sound_id: int) -> None:
|
||||
@@ -356,20 +387,23 @@ class TTSService:
|
||||
|
||||
except Exception:
|
||||
# Don't fail the TTS generation if normalization fails
|
||||
pass
|
||||
logger = get_logger(__name__)
|
||||
logger.exception("Error normalizing sound %s", sound_id)
|
||||
|
||||
async def delete_tts(self, tts_id: int, user_id: int) -> None:
|
||||
"""Delete a TTS generation and its associated sound and files."""
|
||||
# Get the TTS record
|
||||
tts = await self.tts_repo.get_by_id(tts_id)
|
||||
if not tts:
|
||||
raise ValueError(f"TTS with ID {tts_id} not found")
|
||||
tts_not_found_msg = f"TTS with ID {tts_id} not found"
|
||||
raise ValueError(tts_not_found_msg)
|
||||
|
||||
# Check ownership
|
||||
if tts.user_id != user_id:
|
||||
raise PermissionError(
|
||||
"You don't have permission to delete this TTS generation",
|
||||
permission_error_msg = (
|
||||
"You don't have permission to delete this TTS generation"
|
||||
)
|
||||
raise PermissionError(permission_error_msg)
|
||||
|
||||
# If there's an associated sound, delete it and its files
|
||||
if tts.sound_id:
|
||||
@@ -385,8 +419,6 @@ class TTSService:
|
||||
|
||||
async def _delete_sound_files(self, sound: Sound) -> None:
|
||||
"""Delete all files associated with a sound."""
|
||||
from pathlib import Path
|
||||
|
||||
# Delete original file
|
||||
original_path = Path("sounds/originals/text_to_speech") / sound.filename
|
||||
if original_path.exists():
|
||||
@@ -465,7 +497,8 @@ class TTSService:
|
||||
tts = result.first()
|
||||
|
||||
if not tts:
|
||||
raise ValueError(f"TTS with ID {tts_id} not found")
|
||||
tts_not_found_msg = f"TTS with ID {tts_id} not found"
|
||||
raise ValueError(tts_not_found_msg)
|
||||
|
||||
# Generate the TTS
|
||||
sound = await self._generate_tts_sync(
|
||||
@@ -477,6 +510,9 @@ class TTSService:
|
||||
|
||||
# Capture sound ID before session issues
|
||||
sound_id = sound.id
|
||||
if sound_id is None:
|
||||
sound_creation_error = "Sound creation failed - no ID assigned"
|
||||
raise ValueError(sound_creation_error)
|
||||
|
||||
# Mark as completed
|
||||
await self.mark_tts_completed(tts_id, sound_id)
|
||||
@@ -501,9 +537,6 @@ class TTSService:
|
||||
) -> None:
|
||||
"""Emit a socket event for TTS status change."""
|
||||
try:
|
||||
from app.core.logging import get_logger
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
data = {
|
||||
@@ -513,12 +546,10 @@ class TTSService:
|
||||
if error:
|
||||
data["error"] = error
|
||||
|
||||
logger.info(f"Emitting TTS socket event: {event} with data: {data}")
|
||||
logger.info("Emitting TTS socket event: %s with data: %s", event, data)
|
||||
await socket_manager.broadcast_to_all(event, data)
|
||||
logger.info(f"Successfully emitted TTS socket event: {event}")
|
||||
except Exception as e:
|
||||
logger.info("Successfully emitted TTS socket event: %s", event)
|
||||
except Exception:
|
||||
# Don't fail TTS processing if socket emission fails
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)
|
||||
logger.exception("Failed to emit TTS socket event %s", event)
|
||||
|
||||
Reference in New Issue
Block a user