style: Format code for consistency and readability across TTS modules
This commit is contained in:
@@ -18,11 +18,11 @@ class TTSGenerateRequest(BaseModel):
|
|||||||
"""TTS generation request model."""
|
"""TTS generation request model."""
|
||||||
|
|
||||||
text: str = Field(
|
text: str = Field(
|
||||||
..., min_length=1, max_length=1000, description="Text to convert to speech"
|
..., min_length=1, max_length=1000, description="Text to convert to speech",
|
||||||
)
|
)
|
||||||
provider: str = Field(default="gtts", description="TTS provider to use")
|
provider: str = Field(default="gtts", description="TTS provider to use")
|
||||||
options: dict[str, Any] = Field(
|
options: dict[str, Any] = Field(
|
||||||
default_factory=dict, description="Provider-specific options"
|
default_factory=dict, description="Provider-specific options",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,11 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
|
|
||||||
from app.api import api_router
|
from app.api import api_router
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import get_session_factory, init_db
|
from app.core.database import get_session_factory
|
||||||
from app.core.logging import get_logger, setup_logging
|
from app.core.logging import get_logger, setup_logging
|
||||||
from app.core.services import app_services
|
from app.core.services import app_services
|
||||||
from app.middleware.logging import LoggingMiddleware
|
from app.middleware.logging import LoggingMiddleware
|
||||||
from app.services.extraction_processor import extraction_processor
|
from app.services.extraction_processor import extraction_processor
|
||||||
from app.services.tts_processor import tts_processor
|
|
||||||
from app.services.player import (
|
from app.services.player import (
|
||||||
get_player_service,
|
get_player_service,
|
||||||
initialize_player_service,
|
initialize_player_service,
|
||||||
@@ -20,6 +19,7 @@ from app.services.player import (
|
|||||||
)
|
)
|
||||||
from app.services.scheduler import SchedulerService
|
from app.services.scheduler import SchedulerService
|
||||||
from app.services.socket import socket_manager
|
from app.services.socket import socket_manager
|
||||||
|
from app.services.tts_processor import tts_processor
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from .user import User
|
|||||||
from .user_oauth import UserOauth
|
from .user_oauth import UserOauth
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"TTS",
|
||||||
"BaseModel",
|
"BaseModel",
|
||||||
"CreditAction",
|
"CreditAction",
|
||||||
"CreditTransaction",
|
"CreditTransaction",
|
||||||
@@ -28,7 +29,6 @@ __all__ = [
|
|||||||
"ScheduledTask",
|
"ScheduledTask",
|
||||||
"Sound",
|
"Sound",
|
||||||
"SoundPlayed",
|
"SoundPlayed",
|
||||||
"TTS",
|
|
||||||
"User",
|
"User",
|
||||||
"UserOauth",
|
"UserOauth",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ class TTS(SQLModel, table=True):
|
|||||||
options: dict[str, Any] = Field(
|
options: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
sa_column=Column(JSON),
|
sa_column=Column(JSON),
|
||||||
description="Provider-specific options used"
|
description="Provider-specific options used",
|
||||||
)
|
)
|
||||||
status: str = Field(default="pending", description="Processing status")
|
status: str = Field(default="pending", description="Processing status")
|
||||||
error: str | None = Field(default=None, description="Error message if failed")
|
error: str | None = Field(default=None, description="Error message if failed")
|
||||||
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
|
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
|
||||||
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
|
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""TTS repository for database operations."""
|
"""TTS repository for database operations."""
|
||||||
|
|
||||||
from typing import Any, Sequence
|
from collections.abc import Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
@@ -29,6 +30,7 @@ class TTSRepository(BaseRepository[TTS]):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of TTS records
|
List of TTS records
|
||||||
|
|
||||||
"""
|
"""
|
||||||
stmt = (
|
stmt = (
|
||||||
select(self.model)
|
select(self.model)
|
||||||
@@ -53,10 +55,11 @@ class TTSRepository(BaseRepository[TTS]):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TTS record if found and belongs to user, None otherwise
|
TTS record if found and belongs to user, None otherwise
|
||||||
|
|
||||||
"""
|
"""
|
||||||
stmt = select(self.model).where(
|
stmt = select(self.model).where(
|
||||||
self.model.id == tts_id,
|
self.model.id == tts_id,
|
||||||
self.model.user_id == user_id,
|
self.model.user_id == user_id,
|
||||||
)
|
)
|
||||||
result = await self.session.exec(stmt)
|
result = await self.session.exec(stmt)
|
||||||
return result.first()
|
return result.first()
|
||||||
|
|||||||
@@ -3,4 +3,4 @@
|
|||||||
from .base import TTSProvider
|
from .base import TTSProvider
|
||||||
from .service import TTSService
|
from .service import TTSService
|
||||||
|
|
||||||
__all__ = ["TTSProvider", "TTSService"]
|
__all__ = ["TTSProvider", "TTSService"]
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class TTSProvider(ABC):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Audio data as bytes
|
Audio data as bytes
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -35,4 +36,4 @@ class TTSProvider(ABC):
|
|||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def file_extension(self) -> str:
|
def file_extension(self) -> str:
|
||||||
"""Return the default file extension for this provider."""
|
"""Return the default file extension for this provider."""
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
|
|
||||||
from .gtts import GTTSProvider
|
from .gtts import GTTSProvider
|
||||||
|
|
||||||
__all__ = ["GTTSProvider"]
|
__all__ = ["GTTSProvider"]
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class GTTSProvider(TTSProvider):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MP3 audio data as bytes
|
MP3 audio data as bytes
|
||||||
|
|
||||||
"""
|
"""
|
||||||
lang = options.get("lang", "en")
|
lang = options.get("lang", "en")
|
||||||
tld = options.get("tld", "com")
|
tld = options.get("tld", "com")
|
||||||
@@ -60,7 +61,7 @@ class GTTSProvider(TTSProvider):
|
|||||||
"lv", "mk", "ml", "mr", "ms", "mt", "my", "ne", "nl", "no", "pa",
|
"lv", "mk", "ml", "mr", "ms", "mt", "my", "ne", "nl", "no", "pa",
|
||||||
"pl", "pt", "pt-br", "pt-pt", "ro", "ru", "si", "sk", "sl", "sq",
|
"pl", "pt", "pt-br", "pt-pt", "ro", "ru", "si", "sk", "sl", "sq",
|
||||||
"sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur",
|
"sr", "su", "sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur",
|
||||||
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu"
|
"vi", "yo", "zh", "zh-cn", "zh-tw", "zu",
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_option_schema(self) -> dict[str, Any]:
|
def get_option_schema(self) -> dict[str, Any]:
|
||||||
@@ -70,11 +71,11 @@ class GTTSProvider(TTSProvider):
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"default": "en",
|
"default": "en",
|
||||||
"description": "Language code",
|
"description": "Language code",
|
||||||
"enum": self.get_supported_languages()
|
"enum": self.get_supported_languages(),
|
||||||
},
|
},
|
||||||
"slow": {
|
"slow": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"default": False,
|
"default": False,
|
||||||
"description": "Speak slowly"
|
"description": "Speak slowly",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class TTSService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
session: Database session
|
session: Database session
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.session = session
|
self.session = session
|
||||||
self.sound_repo = SoundRepository(session)
|
self.sound_repo = SoundRepository(session)
|
||||||
@@ -51,6 +52,7 @@ class TTSService:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider: TTS provider instance
|
provider: TTS provider instance
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.providers[provider.name] = provider
|
self.providers[provider.name] = provider
|
||||||
|
|
||||||
@@ -83,6 +85,7 @@ class TTSService:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If provider not found or text too long
|
ValueError: If provider not found or text too long
|
||||||
Exception: If request creation fails
|
Exception: If request creation fails
|
||||||
|
|
||||||
"""
|
"""
|
||||||
provider_not_found_msg = f"Provider '{provider}' not found"
|
provider_not_found_msg = f"Provider '{provider}' not found"
|
||||||
if provider not in self.providers:
|
if provider not in self.providers:
|
||||||
@@ -164,7 +167,7 @@ class TTSService:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
async def _generate_tts_sync(
|
async def _generate_tts_sync(
|
||||||
self, text: str, provider: str, user_id: int, options: dict[str, Any]
|
self, text: str, provider: str, user_id: int, options: dict[str, Any],
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Generate TTS using a synchronous approach."""
|
"""Generate TTS using a synchronous approach."""
|
||||||
# Generate the audio using the provider (avoid async issues by doing it directly)
|
# Generate the audio using the provider (avoid async issues by doing it directly)
|
||||||
@@ -200,7 +203,7 @@ class TTSService:
|
|||||||
|
|
||||||
# Create Sound record with proper metadata
|
# Create Sound record with proper metadata
|
||||||
sound = await self._create_sound_record_complete(
|
sound = await self._create_sound_record_complete(
|
||||||
original_path, text, provider, user_id
|
original_path, text, provider, user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalize the sound
|
# Normalize the sound
|
||||||
@@ -209,7 +212,7 @@ class TTSService:
|
|||||||
return sound
|
return sound
|
||||||
|
|
||||||
async def get_user_tts_history(
|
async def get_user_tts_history(
|
||||||
self, user_id: int, limit: int = 50, offset: int = 0
|
self, user_id: int, limit: int = 50, offset: int = 0,
|
||||||
) -> list[TTS]:
|
) -> list[TTS]:
|
||||||
"""Get TTS history for a user.
|
"""Get TTS history for a user.
|
||||||
|
|
||||||
@@ -220,12 +223,13 @@ class TTSService:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of TTS records
|
List of TTS records
|
||||||
|
|
||||||
"""
|
"""
|
||||||
result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
|
result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
|
||||||
return list(result)
|
return list(result)
|
||||||
|
|
||||||
async def _create_sound_record(
|
async def _create_sound_record(
|
||||||
self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str
|
self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str,
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Create a Sound record for the TTS audio."""
|
"""Create a Sound record for the TTS audio."""
|
||||||
# Get audio metadata
|
# Get audio metadata
|
||||||
@@ -253,7 +257,7 @@ class TTSService:
|
|||||||
return sound
|
return sound
|
||||||
|
|
||||||
async def _create_sound_record_simple(
|
async def _create_sound_record_simple(
|
||||||
self, audio_path: Path, text: str, provider: str, user_id: int
|
self, audio_path: Path, text: str, provider: str, user_id: int,
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Create a Sound record for the TTS audio with minimal processing."""
|
"""Create a Sound record for the TTS audio with minimal processing."""
|
||||||
# Create sound data with basic info
|
# Create sound data with basic info
|
||||||
@@ -278,7 +282,7 @@ class TTSService:
|
|||||||
return sound
|
return sound
|
||||||
|
|
||||||
async def _create_sound_record_complete(
|
async def _create_sound_record_complete(
|
||||||
self, audio_path: Path, text: str, provider: str, user_id: int
|
self, audio_path: Path, text: str, provider: str, user_id: int,
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Create a Sound record for the TTS audio with complete metadata."""
|
"""Create a Sound record for the TTS audio with complete metadata."""
|
||||||
# Get audio metadata
|
# Get audio metadata
|
||||||
@@ -328,7 +332,7 @@ class TTSService:
|
|||||||
|
|
||||||
if result["status"] == "error":
|
if result["status"] == "error":
|
||||||
print(
|
print(
|
||||||
f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}"
|
f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -364,7 +368,7 @@ class TTSService:
|
|||||||
# Check ownership
|
# Check ownership
|
||||||
if tts.user_id != user_id:
|
if tts.user_id != user_id:
|
||||||
raise PermissionError(
|
raise PermissionError(
|
||||||
"You don't have permission to delete this TTS generation"
|
"You don't have permission to delete this TTS generation",
|
||||||
)
|
)
|
||||||
|
|
||||||
# If there's an associated sound, delete it and its files
|
# If there's an associated sound, delete it and its files
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class TTSProcessor:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize the TTS processor."""
|
"""Initialize the TTS processor."""
|
||||||
self.max_concurrent = getattr(settings, 'TTS_MAX_CONCURRENT', 3)
|
self.max_concurrent = getattr(settings, "TTS_MAX_CONCURRENT", 3)
|
||||||
self.running_tts: set[int] = set()
|
self.running_tts: set[int] = set()
|
||||||
self.processing_lock = asyncio.Lock()
|
self.processing_lock = asyncio.Lock()
|
||||||
self.shutdown_event = asyncio.Event()
|
self.shutdown_event = asyncio.Event()
|
||||||
@@ -190,4 +190,4 @@ class TTSProcessor:
|
|||||||
|
|
||||||
|
|
||||||
# Global TTS processor instance
|
# Global TTS processor instance
|
||||||
tts_processor = TTSProcessor()
|
tts_processor = TTSProcessor()
|
||||||
|
|||||||
Reference in New Issue
Block a user