refactor: Improve code readability and structure across TTS modules
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m36s

This commit is contained in:
JSC
2025-09-21 19:07:32 +02:00
parent 35b857fd0d
commit acdf191a5a
8 changed files with 106 additions and 56 deletions

View File

@@ -206,8 +206,6 @@ async def delete_tts(
await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
return {"message": "TTS generation deleted successfully"}
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -223,3 +221,5 @@ async def delete_tts(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete TTS: {e!s}",
) from e
else:
return {"message": "TTS generation deleted successfully"}

View File

@@ -22,7 +22,9 @@ class TTS(SQLModel, table=True):
)
status: str = Field(default="pending", description="Processing status")
error: str | None = Field(default=None, description="Error message if failed")
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
sound_id: int | None = Field(
foreign_key="sound.id", description="Associated sound ID",
)
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -1,7 +1,10 @@
"""TTS repository for database operations."""
from collections.abc import Sequence
from typing import Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
@@ -12,7 +15,13 @@ from app.repositories.base import BaseRepository
class TTSRepository(BaseRepository[TTS]):
"""Repository for TTS operations."""
def __init__(self, session: Any) -> None:
def __init__(self, session: "AsyncSession") -> None:
"""Initialize TTS repository.
Args:
session: Database session for operations
"""
super().__init__(TTS, session)
async def get_by_user_id(

View File

@@ -244,7 +244,10 @@ class PlaylistService:
if was_current:
main_playlist = await self.get_main_playlist()
await self.playlist_repo.update(main_playlist, {"is_current": True})
logger.info("Set main playlist as current after deleting current playlist %s", playlist_id)
logger.info(
"Set main playlist as current after deleting current playlist %s",
playlist_id,
)
# Reload player to reflect the change
await _reload_player_playlist()
@@ -562,7 +565,11 @@ class PlaylistService:
await self.session.delete(playlist_sound)
await self.session.commit()
logger.info("Deleted %d playlist_sound records for playlist %s", len(playlist_sounds), playlist_id)
logger.info(
"Deleted %d playlist_sound records for playlist %s",
len(playlist_sounds),
playlist_id,
)
async def _unset_current_playlist(self) -> None:
"""Unset any current playlist globally."""

View File

@@ -1,14 +1,16 @@
"""Base TTS provider interface."""
from abc import ABC, abstractmethod
from typing import Any
# Type alias for TTS options
TTSOptions = dict[str, str | bool | int | float]
class TTSProvider(ABC):
"""Abstract base class for TTS providers."""
@abstractmethod
async def generate_speech(self, text: str, **options: Any) -> bytes:
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
"""Generate speech from text with provider-specific options.
Args:
@@ -25,7 +27,7 @@ class TTSProvider(ABC):
"""Return list of supported language codes."""
@abstractmethod
def get_option_schema(self) -> dict[str, Any]:
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
"""Return schema for provider-specific options."""
@property

View File

@@ -2,11 +2,10 @@
import asyncio
import io
from typing import Any
from gtts import gTTS
from ..base import TTSProvider
from app.services.tts.base import TTSProvider
class GTTSProvider(TTSProvider):
@@ -22,7 +21,7 @@ class GTTSProvider(TTSProvider):
"""Return the default file extension for this provider."""
return "mp3"
async def generate_speech(self, text: str, **options: Any) -> bytes:
async def generate_speech(self, text: str, **options: str | bool | float) -> bytes:
"""Generate speech from text using Google TTS.
Args:
@@ -38,7 +37,7 @@ class GTTSProvider(TTSProvider):
slow = options.get("slow", False)
# Run TTS generation in thread pool since gTTS is synchronous
def _generate():
def _generate() -> bytes:
tts = gTTS(text=text, lang=lang, tld=tld, slow=slow)
fp = io.BytesIO()
tts.write_to_fp(fp)
@@ -64,7 +63,7 @@ class GTTSProvider(TTSProvider):
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu",
]
def get_option_schema(self) -> dict[str, Any]:
def get_option_schema(self) -> dict[str, dict[str, str | list[str] | bool]]:
"""Return schema for GTTS-specific options."""
return {
"lang": {

View File

@@ -10,10 +10,13 @@ from gtts import gTTS
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_session_factory
from app.core.logging import get_logger
from app.models.sound import Sound
from app.models.tts import TTS
from app.repositories.sound import SoundRepository
from app.repositories.tts import TTSRepository
from app.services.socket import socket_manager
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
@@ -25,6 +28,12 @@ MAX_TEXT_LENGTH = 1000
MAX_NAME_LENGTH = 50
async def _get_tts_processor() -> object:
"""Get TTS processor instance, avoiding circular import."""
from app.services.tts_processor import tts_processor # noqa: PLC0415
return tts_processor
class TTSService:
"""Text-to-Speech service with provider management."""
@@ -69,7 +78,7 @@ class TTSService:
text: str,
user_id: int,
provider: str = "gtts",
**options: Any,
**options: str | bool | float,
) -> dict[str, Any]:
"""Create a TTS request that will be processed in the background.
@@ -114,8 +123,7 @@ class TTSService:
# Queue for background processing using the TTS processor
if tts.id is not None:
from app.services.tts_processor import tts_processor
tts_processor = await _get_tts_processor()
await tts_processor.queue_tts(tts.id)
return {"tts": tts, "message": "TTS generation queued successfully"}
@@ -132,8 +140,6 @@ class TTSService:
async def _process_tts_in_background(self, tts_id: int) -> None:
"""Process TTS generation in background."""
from app.core.database import get_session_factory
try:
# Create a new session for background processing
session_factory = get_session_factory()
@@ -164,13 +170,19 @@ class TTSService:
except Exception:
# Log error but don't fail - avoiding print for production
pass
logger = get_logger(__name__)
logger.exception("Error processing TTS generation %s", tts_id)
async def _generate_tts_sync(
self, text: str, provider: str, user_id: int, options: dict[str, Any],
self,
text: str,
provider: str,
user_id: int,
options: dict[str, Any],
) -> Sound:
"""Generate TTS using a synchronous approach."""
# Generate the audio using the provider (avoid async issues by doing it directly)
# Generate the audio using the provider
# (avoid async issues by doing it directly)
tts_provider = self.providers[provider]
# Create directories if they don't exist
@@ -199,20 +211,28 @@ class TTSService:
original_path.write_bytes(audio_bytes)
except Exception:
logger = get_logger(__name__)
logger.exception("Error generating TTS audio")
raise
# Create Sound record with proper metadata
sound = await self._create_sound_record_complete(
original_path, text, provider, user_id,
original_path,
text,
user_id,
)
# Normalize the sound
if sound.id is not None:
await self._normalize_sound_safe(sound.id)
return sound
async def get_user_tts_history(
self, user_id: int, limit: int = 50, offset: int = 0,
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> list[TTS]:
"""Get TTS history for a user.
@@ -229,7 +249,11 @@ class TTSService:
return list(result)
async def _create_sound_record(
self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str,
self,
audio_path: Path,
text: str,
user_id: int,
file_hash: str,
) -> Sound:
"""Create a Sound record for the TTS audio."""
# Get audio metadata
@@ -253,11 +277,13 @@ class TTSService:
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
return sound
return await self.sound_repo.create(sound_data)
async def _create_sound_record_simple(
self, audio_path: Path, text: str, provider: str, user_id: int,
self,
audio_path: Path,
text: str,
user_id: int,
) -> Sound:
"""Create a Sound record for the TTS audio with minimal processing."""
# Create sound data with basic info
@@ -278,11 +304,13 @@ class TTSService:
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
return sound
return await self.sound_repo.create(sound_data)
async def _create_sound_record_complete(
self, audio_path: Path, text: str, provider: str, user_id: int,
self,
audio_path: Path,
text: str,
user_id: int,
) -> Sound:
"""Create a Sound record for the TTS audio with complete metadata."""
# Get audio metadata
@@ -316,8 +344,7 @@ class TTSService:
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
return sound
return await self.sound_repo.create(sound_data)
async def _normalize_sound_safe(self, sound_id: int) -> None:
"""Normalize the TTS sound with error handling."""
@@ -331,12 +358,16 @@ class TTSService:
result = await normalizer_service.normalize_sound(sound)
if result["status"] == "error":
print(
f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}",
logger = get_logger(__name__)
logger.warning(
"Warning: Failed to normalize TTS sound %s: %s",
sound_id,
result.get("error"),
)
except Exception as e:
print(f"Exception during TTS sound normalization {sound_id}: {e}")
except Exception:
logger = get_logger(__name__)
logger.exception("Exception during TTS sound normalization %s", sound_id)
# Don't fail the TTS generation if normalization fails
async def _normalize_sound(self, sound_id: int) -> None:
@@ -356,20 +387,23 @@ class TTSService:
except Exception:
# Don't fail the TTS generation if normalization fails
pass
logger = get_logger(__name__)
logger.exception("Error normalizing sound %s", sound_id)
async def delete_tts(self, tts_id: int, user_id: int) -> None:
"""Delete a TTS generation and its associated sound and files."""
# Get the TTS record
tts = await self.tts_repo.get_by_id(tts_id)
if not tts:
raise ValueError(f"TTS with ID {tts_id} not found")
tts_not_found_msg = f"TTS with ID {tts_id} not found"
raise ValueError(tts_not_found_msg)
# Check ownership
if tts.user_id != user_id:
raise PermissionError(
"You don't have permission to delete this TTS generation",
permission_error_msg = (
"You don't have permission to delete this TTS generation"
)
raise PermissionError(permission_error_msg)
# If there's an associated sound, delete it and its files
if tts.sound_id:
@@ -385,8 +419,6 @@ class TTSService:
async def _delete_sound_files(self, sound: Sound) -> None:
"""Delete all files associated with a sound."""
from pathlib import Path
# Delete original file
original_path = Path("sounds/originals/text_to_speech") / sound.filename
if original_path.exists():
@@ -465,7 +497,8 @@ class TTSService:
tts = result.first()
if not tts:
raise ValueError(f"TTS with ID {tts_id} not found")
tts_not_found_msg = f"TTS with ID {tts_id} not found"
raise ValueError(tts_not_found_msg)
# Generate the TTS
sound = await self._generate_tts_sync(
@@ -477,6 +510,9 @@ class TTSService:
# Capture sound ID before session issues
sound_id = sound.id
if sound_id is None:
sound_creation_error = "Sound creation failed - no ID assigned"
raise ValueError(sound_creation_error)
# Mark as completed
await self.mark_tts_completed(tts_id, sound_id)
@@ -501,9 +537,6 @@ class TTSService:
) -> None:
"""Emit a socket event for TTS status change."""
try:
from app.core.logging import get_logger
from app.services.socket import socket_manager
logger = get_logger(__name__)
data = {
@@ -513,12 +546,10 @@ class TTSService:
if error:
data["error"] = error
logger.info(f"Emitting TTS socket event: {event} with data: {data}")
logger.info("Emitting TTS socket event: %s with data: %s", event, data)
await socket_manager.broadcast_to_all(event, data)
logger.info(f"Successfully emitted TTS socket event: {event}")
except Exception as e:
logger.info("Successfully emitted TTS socket event: %s", event)
except Exception:
# Don't fail TTS processing if socket emission fails
from app.core.logging import get_logger
logger = get_logger(__name__)
logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)
logger.exception("Failed to emit TTS socket event %s", event)

View File

@@ -14,7 +14,7 @@ logger = get_logger(__name__)
class TTSProcessor:
"""Background processor for handling TTS generation queue with concurrency control."""
"""Background processor for handling TTS generation queue."""
def __init__(self) -> None:
"""Initialize the TTS processor."""