style: Format code for consistency and readability across TTS modules

This commit is contained in:
JSC
2025-09-21 18:05:20 +02:00
parent 50eeae4c62
commit d3b6e90262
11 changed files with 36 additions and 27 deletions

View File

@@ -18,11 +18,11 @@ class TTSGenerateRequest(BaseModel):
"""TTS generation request model.""" """TTS generation request model."""
text: str = Field( text: str = Field(
..., min_length=1, max_length=1000, description="Text to convert to speech" ..., min_length=1, max_length=1000, description="Text to convert to speech",
) )
provider: str = Field(default="gtts", description="TTS provider to use") provider: str = Field(default="gtts", description="TTS provider to use")
options: dict[str, Any] = Field( options: dict[str, Any] = Field(
default_factory=dict, description="Provider-specific options" default_factory=dict, description="Provider-specific options",
) )

View File

@@ -7,12 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware
from app.api import api_router from app.api import api_router
from app.core.config import settings from app.core.config import settings
from app.core.database import get_session_factory, init_db from app.core.database import get_session_factory
from app.core.logging import get_logger, setup_logging from app.core.logging import get_logger, setup_logging
from app.core.services import app_services from app.core.services import app_services
from app.middleware.logging import LoggingMiddleware from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor from app.services.extraction_processor import extraction_processor
from app.services.tts_processor import tts_processor
from app.services.player import ( from app.services.player import (
get_player_service, get_player_service,
initialize_player_service, initialize_player_service,
@@ -20,6 +19,7 @@ from app.services.player import (
) )
from app.services.scheduler import SchedulerService from app.services.scheduler import SchedulerService
from app.services.socket import socket_manager from app.services.socket import socket_manager
from app.services.tts_processor import tts_processor
@asynccontextmanager @asynccontextmanager

View File

@@ -17,6 +17,7 @@ from .user import User
from .user_oauth import UserOauth from .user_oauth import UserOauth
__all__ = [ __all__ = [
"TTS",
"BaseModel", "BaseModel",
"CreditAction", "CreditAction",
"CreditTransaction", "CreditTransaction",
@@ -28,7 +29,6 @@ __all__ = [
"ScheduledTask", "ScheduledTask",
"Sound", "Sound",
"SoundPlayed", "SoundPlayed",
"TTS",
"User", "User",
"UserOauth", "UserOauth",
] ]

View File

@@ -18,11 +18,11 @@ class TTS(SQLModel, table=True):
options: dict[str, Any] = Field( options: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
sa_column=Column(JSON), sa_column=Column(JSON),
description="Provider-specific options used" description="Provider-specific options used",
) )
status: str = Field(default="pending", description="Processing status") status: str = Field(default="pending", description="Processing status")
error: str | None = Field(default=None, description="Error message if failed") error: str | None = Field(default=None, description="Error message if failed")
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID") sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
user_id: int = Field(foreign_key="user.id", description="User who created the TTS") user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
created_at: datetime = Field(default_factory=datetime.utcnow) created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow) updated_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -1,6 +1,7 @@
"""TTS repository for database operations.""" """TTS repository for database operations."""
from typing import Any, Sequence from collections.abc import Sequence
from typing import Any
from sqlmodel import select from sqlmodel import select
@@ -29,6 +30,7 @@ class TTSRepository(BaseRepository[TTS]):
Returns: Returns:
List of TTS records List of TTS records
""" """
stmt = ( stmt = (
select(self.model) select(self.model)
@@ -53,10 +55,11 @@ class TTSRepository(BaseRepository[TTS]):
Returns: Returns:
TTS record if found and belongs to user, None otherwise TTS record if found and belongs to user, None otherwise
""" """
stmt = select(self.model).where( stmt = select(self.model).where(
self.model.id == tts_id, self.model.id == tts_id,
self.model.user_id == user_id, self.model.user_id == user_id,
) )
result = await self.session.exec(stmt) result = await self.session.exec(stmt)
return result.first() return result.first()

View File

@@ -3,4 +3,4 @@
from .base import TTSProvider from .base import TTSProvider
from .service import TTSService from .service import TTSService
__all__ = ["TTSProvider", "TTSService"] __all__ = ["TTSProvider", "TTSService"]

View File

@@ -17,6 +17,7 @@ class TTSProvider(ABC):
Returns: Returns:
Audio data as bytes Audio data as bytes
""" """
@abstractmethod @abstractmethod
@@ -35,4 +36,4 @@ class TTSProvider(ABC):
@property @property
@abstractmethod @abstractmethod
def file_extension(self) -> str: def file_extension(self) -> str:
"""Return the default file extension for this provider.""" """Return the default file extension for this provider."""

View File

@@ -2,4 +2,4 @@
from .gtts import GTTSProvider from .gtts import GTTSProvider
__all__ = ["GTTSProvider"] __all__ = ["GTTSProvider"]

View File

@@ -31,6 +31,7 @@ class GTTSProvider(TTSProvider):
Returns: Returns:
MP3 audio data as bytes MP3 audio data as bytes
""" """
lang = options.get("lang", "en") lang = options.get("lang", "en")
tld = options.get("tld", "com") tld = options.get("tld", "com")
@@ -60,7 +61,7 @@ class GTTSProvider(TTSProvider):
"lv", "mk", "ml", "mr", "ms", "mt", "my", "ne", "nl", "no", "pa", "lv", "mk", "ml", "mr", "ms", "mt", "my", "ne", "nl", "no", "pa",
"pl", "pt", "pt-br", "pt-pt", "ro", "ru", "si", "sk", "sl", "sq", "pl", "pt", "pt-br", "pt-pt", "ro", "ru", "si", "sk", "sl", "sq",
"sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur", "sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur",
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu" "vi", "yo", "zh", "zh-cn", "zh-tw", "zu",
] ]
def get_option_schema(self) -> dict[str, Any]: def get_option_schema(self) -> dict[str, Any]:
@@ -70,11 +71,11 @@ class GTTSProvider(TTSProvider):
"type": "string", "type": "string",
"default": "en", "default": "en",
"description": "Language code", "description": "Language code",
"enum": self.get_supported_languages() "enum": self.get_supported_languages(),
}, },
"slow": { "slow": {
"type": "boolean", "type": "boolean",
"default": False, "default": False,
"description": "Speak slowly" "description": "Speak slowly",
} },
} }

View File

@@ -33,6 +33,7 @@ class TTSService:
Args: Args:
session: Database session session: Database session
""" """
self.session = session self.session = session
self.sound_repo = SoundRepository(session) self.sound_repo = SoundRepository(session)
@@ -51,6 +52,7 @@ class TTSService:
Args: Args:
provider: TTS provider instance provider: TTS provider instance
""" """
self.providers[provider.name] = provider self.providers[provider.name] = provider
@@ -83,6 +85,7 @@ class TTSService:
Raises: Raises:
ValueError: If provider not found or text too long ValueError: If provider not found or text too long
Exception: If request creation fails Exception: If request creation fails
""" """
provider_not_found_msg = f"Provider '{provider}' not found" provider_not_found_msg = f"Provider '{provider}' not found"
if provider not in self.providers: if provider not in self.providers:
@@ -164,7 +167,7 @@ class TTSService:
pass pass
async def _generate_tts_sync( async def _generate_tts_sync(
self, text: str, provider: str, user_id: int, options: dict[str, Any] self, text: str, provider: str, user_id: int, options: dict[str, Any],
) -> Sound: ) -> Sound:
"""Generate TTS using a synchronous approach.""" """Generate TTS using a synchronous approach."""
# Generate the audio using the provider (avoid async issues by doing it directly) # Generate the audio using the provider (avoid async issues by doing it directly)
@@ -200,7 +203,7 @@ class TTSService:
# Create Sound record with proper metadata # Create Sound record with proper metadata
sound = await self._create_sound_record_complete( sound = await self._create_sound_record_complete(
original_path, text, provider, user_id original_path, text, provider, user_id,
) )
# Normalize the sound # Normalize the sound
@@ -209,7 +212,7 @@ class TTSService:
return sound return sound
async def get_user_tts_history( async def get_user_tts_history(
self, user_id: int, limit: int = 50, offset: int = 0 self, user_id: int, limit: int = 50, offset: int = 0,
) -> list[TTS]: ) -> list[TTS]:
"""Get TTS history for a user. """Get TTS history for a user.
@@ -220,12 +223,13 @@ class TTSService:
Returns: Returns:
List of TTS records List of TTS records
""" """
result = await self.tts_repo.get_by_user_id(user_id, limit, offset) result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
return list(result) return list(result)
async def _create_sound_record( async def _create_sound_record(
self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str,
) -> Sound: ) -> Sound:
"""Create a Sound record for the TTS audio.""" """Create a Sound record for the TTS audio."""
# Get audio metadata # Get audio metadata
@@ -253,7 +257,7 @@ class TTSService:
return sound return sound
async def _create_sound_record_simple( async def _create_sound_record_simple(
self, audio_path: Path, text: str, provider: str, user_id: int self, audio_path: Path, text: str, provider: str, user_id: int,
) -> Sound: ) -> Sound:
"""Create a Sound record for the TTS audio with minimal processing.""" """Create a Sound record for the TTS audio with minimal processing."""
# Create sound data with basic info # Create sound data with basic info
@@ -278,7 +282,7 @@ class TTSService:
return sound return sound
async def _create_sound_record_complete( async def _create_sound_record_complete(
self, audio_path: Path, text: str, provider: str, user_id: int self, audio_path: Path, text: str, provider: str, user_id: int,
) -> Sound: ) -> Sound:
"""Create a Sound record for the TTS audio with complete metadata.""" """Create a Sound record for the TTS audio with complete metadata."""
# Get audio metadata # Get audio metadata
@@ -328,7 +332,7 @@ class TTSService:
if result["status"] == "error": if result["status"] == "error":
print( print(
f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}" f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}",
) )
except Exception as e: except Exception as e:
@@ -364,7 +368,7 @@ class TTSService:
# Check ownership # Check ownership
if tts.user_id != user_id: if tts.user_id != user_id:
raise PermissionError( raise PermissionError(
"You don't have permission to delete this TTS generation" "You don't have permission to delete this TTS generation",
) )
# If there's an associated sound, delete it and its files # If there's an associated sound, delete it and its files

View File

@@ -18,7 +18,7 @@ class TTSProcessor:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the TTS processor.""" """Initialize the TTS processor."""
self.max_concurrent = getattr(settings, 'TTS_MAX_CONCURRENT', 3) self.max_concurrent = getattr(settings, "TTS_MAX_CONCURRENT", 3)
self.running_tts: set[int] = set() self.running_tts: set[int] = set()
self.processing_lock = asyncio.Lock() self.processing_lock = asyncio.Lock()
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
@@ -190,4 +190,4 @@ class TTSProcessor:
# Global TTS processor instance # Global TTS processor instance
tts_processor = TTSProcessor() tts_processor = TTSProcessor()