feat: Implement Text-to-Speech (TTS) functionality with API endpoints, models, and service integration
This commit is contained in:
45
alembic/versions/e617c155eea9_add_tts_table.py
Normal file
45
alembic/versions/e617c155eea9_add_tts_table.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""Add TTS table
|
||||||
|
|
||||||
|
Revision ID: e617c155eea9
|
||||||
|
Revises: a0d322857b2c
|
||||||
|
Create Date: 2025-09-20 21:51:26.557738
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
|
||||||
|
|
||||||
|
# Alembic revision identifiers: this migration's id and the revision it
# is applied on top of. No branch labels or cross-branch dependencies.
revision: str = "e617c155eea9"
down_revision: Union[str, Sequence[str], None] = "a0d322857b2c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the ``tts`` table.

    Each row stores the submitted text, the provider name, the
    provider-specific options (JSON), the owning user and an optional
    reference to a ``sound`` row.
    """
    auto_string = sqlmodel.sql.sqltypes.AutoString

    columns = [
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("text", auto_string(length=1000), nullable=False),
        sa.Column("provider", auto_string(length=50), nullable=False),
        sa.Column("options", sa.JSON(), nullable=True),
        sa.Column("sound_id", sa.Integer(), nullable=True),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), nullable=False),
    ]
    constraints = [
        sa.ForeignKeyConstraint(["sound_id"], ["sound.id"]),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
        sa.PrimaryKeyConstraint("id"),
    ]

    op.create_table("tts", *columns, *constraints)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Revert this migration by dropping the ``tts`` table."""
    table_name = "tts"
    op.drop_table(table_name)
|
||||||
@@ -15,6 +15,7 @@ from app.api.v1 import (
|
|||||||
scheduler,
|
scheduler,
|
||||||
socket,
|
socket,
|
||||||
sounds,
|
sounds,
|
||||||
|
tts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# V1 API router with v1 prefix
|
# V1 API router with v1 prefix
|
||||||
@@ -32,4 +33,5 @@ api_router.include_router(playlists.router, tags=["playlists"])
|
|||||||
api_router.include_router(scheduler.router, tags=["scheduler"])
|
api_router.include_router(scheduler.router, tags=["scheduler"])
|
||||||
api_router.include_router(socket.router, tags=["socket"])
|
api_router.include_router(socket.router, tags=["socket"])
|
||||||
api_router.include_router(sounds.router, tags=["sounds"])
|
api_router.include_router(sounds.router, tags=["sounds"])
|
||||||
|
api_router.include_router(tts.router, tags=["tts"])
|
||||||
api_router.include_router(admin.router)
|
api_router.include_router(admin.router)
|
||||||
|
|||||||
216
app/api/v1/tts.py
Normal file
216
app/api/v1/tts.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
"""TTS API endpoints."""
|
||||||
|
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.dependencies import get_current_active_user_flexible
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.tts import TTSService
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/tts", tags=["tts"])
|
||||||
|
|
||||||
|
|
||||||
|
class TTSGenerateRequest(BaseModel):
    """Request body accepted by the TTS generation endpoint."""

    # Required; between 1 and 1000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="Text to convert to speech",
    )
    # Name of a registered provider; "gtts" when omitted.
    provider: str = Field(
        default="gtts",
        description="TTS provider to use",
    )
    # Free-form mapping forwarded to the provider; empty when omitted.
    options: dict[str, Any] = Field(
        default_factory=dict,
        description="Provider-specific options",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TTSResponse(BaseModel):
    """Serialized view of a stored TTS record returned by the API."""

    id: int
    text: str
    provider: str
    options: dict[str, Any]
    # None until a sound has been attached to the record.
    sound_id: int | None
    user_id: int
    # The endpoints in this module fill this with ``datetime.isoformat()``.
    created_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderInfo(BaseModel):
    """Public description of a registered TTS provider."""

    # Provider name as reported by the provider itself.
    name: str
    # Default file extension of the audio the provider produces.
    file_extension: str
    # Language codes the provider reports as supported.
    supported_languages: list[str]
    # Provider-defined description of the options it accepts.
    option_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tts_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> TTSService:
    """FastAPI dependency that builds a ``TTSService`` bound to the request's DB session."""
    service = TTSService(session)
    return service
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate")
async def generate_tts(
    request: TTSGenerateRequest,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, Any]:
    """Create a TTS record for the current user and queue audio generation.

    Args:
        request: Text, provider name and provider-specific options.
        current_user: Authenticated user making the request.
        tts_service: Service used to create and queue the TTS request.

    Returns:
        A dict with the service's ``message`` and the created record
        serialized as a ``TTSResponse`` under ``tts``.

    Raises:
        HTTPException: 401 if the user has no ID, 400 if the options use a
            reserved name or the service rejects the request with
            ``ValueError``, 500 for any other failure.
    """
    # Validate up front, outside the ``try`` below: its blanket
    # ``except Exception`` would otherwise turn these into a 500.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    # The options are splatted into ``create_tts_request`` as keyword
    # arguments; a key that matches one of its named parameters would raise
    # ``TypeError`` ("multiple values for keyword argument") and surface as
    # a 500. Reject such keys explicitly as a client error instead.
    reserved_names = {"self", "text", "user_id", "provider"}
    clashing = sorted(reserved_names.intersection(request.options))
    if clashing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Reserved option name(s) not allowed: {', '.join(clashing)}",
        )

    try:
        result = await tts_service.create_tts_request(
            text=request.text,
            user_id=current_user.id,
            provider=request.provider,
            **request.options,
        )

        tts_record = result["tts"]

        return {
            "message": result["message"],
            "tts": TTSResponse(
                id=tts_record.id,
                text=tts_record.text,
                provider=tts_record.provider,
                # The column is nullable; the response model needs a dict.
                options=tts_record.options or {},
                sound_id=tts_record.sound_id,
                user_id=tts_record.user_id,
                created_at=tts_record.created_at.isoformat(),
            ),
        }

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate TTS: {e!s}",
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers")
async def get_providers(
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, ProviderInfo]:
    """Describe every registered TTS provider, keyed by its registry name."""
    return {
        key: ProviderInfo(
            name=impl.name,
            file_extension=impl.file_extension,
            supported_languages=impl.get_supported_languages(),
            option_schema=impl.get_option_schema(),
        )
        for key, impl in tts_service.get_providers().items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers/{provider_name}")
async def get_provider(
    provider_name: str,
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> ProviderInfo:
    """Describe a single TTS provider, or answer 404 if it is not registered."""
    found = tts_service.get_provider(provider_name)

    if found:
        info = ProviderInfo(
            name=found.name,
            file_extension=found.file_extension,
            supported_languages=found.get_supported_languages(),
            option_schema=found.get_option_schema(),
        )
        return info

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Provider '{provider_name}' not found",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history")
async def get_tts_history(
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
    limit: int = 50,
    offset: int = 0,
) -> list[TTSResponse]:
    """Return the current user's TTS records, one page at a time.

    Args:
        current_user: Authenticated user whose history is requested.
        tts_service: Service used to read the history.
        limit: Maximum number of records to return (non-negative).
        offset: Number of records to skip (non-negative).

    Raises:
        HTTPException: 401 if the user has no ID, 400 if ``limit`` or
            ``offset`` is negative, 500 if reading the history fails.
    """
    # These checks sit outside the ``try`` so the blanket ``except`` below
    # cannot rewrite their status codes into a 500.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    # Negative values would be passed straight into LIMIT/OFFSET.
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit and offset must be non-negative",
        )

    try:
        tts_records = await tts_service.get_user_tts_history(
            user_id=current_user.id,
            limit=limit,
            offset=offset,
        )

        return [
            TTSResponse(
                id=record.id,
                text=record.text,
                provider=record.provider,
                # The column is nullable; the response model needs a dict.
                options=record.options or {},
                sound_id=record.sound_id,
                user_id=record.user_id,
                created_at=record.created_at.isoformat(),
            )
            for record in tts_records
        ]

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get TTS history: {e!s}",
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{tts_id}")
async def delete_tts(
    tts_id: int,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, str]:
    """Delete one of the current user's TTS generations.

    Args:
        tts_id: ID of the TTS record to delete.
        current_user: Authenticated user requesting the deletion.
        tts_service: Service performing the deletion.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 401 if the user has no ID, 404 if the service raises
            ``ValueError``, 403 if it raises ``PermissionError``, 500 for
            any other failure.
    """
    # Raised before the ``try`` so the blanket ``except Exception`` below
    # cannot turn this 401 into a 500.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    try:
        await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except PermissionError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete TTS: {e!s}",
        ) from e

    return {"message": "TTS generation deleted successfully"}
|
||||||
@@ -12,6 +12,7 @@ from .playlist_sound import PlaylistSound
|
|||||||
from .scheduled_task import ScheduledTask
|
from .scheduled_task import ScheduledTask
|
||||||
from .sound import Sound
|
from .sound import Sound
|
||||||
from .sound_played import SoundPlayed
|
from .sound_played import SoundPlayed
|
||||||
|
from .tts import TTS
|
||||||
from .user import User
|
from .user import User
|
||||||
from .user_oauth import UserOauth
|
from .user_oauth import UserOauth
|
||||||
|
|
||||||
@@ -27,6 +28,7 @@ __all__ = [
|
|||||||
"ScheduledTask",
|
"ScheduledTask",
|
||||||
"Sound",
|
"Sound",
|
||||||
"SoundPlayed",
|
"SoundPlayed",
|
||||||
|
"TTS",
|
||||||
"User",
|
"User",
|
||||||
"UserOauth",
|
"UserOauth",
|
||||||
]
|
]
|
||||||
|
|||||||
26
app/models/tts.py
Normal file
26
app/models/tts.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""TTS model."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import JSON, Column
|
||||||
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
|
||||||
|
class TTS(SQLModel, table=True):
    """Text-to-Speech generation record.

    One row per generation request: the submitted text, the provider and
    options used, the owning user, and the sound produced for it (if any).
    """

    __tablename__ = "tts"

    # Optional fields carry an explicit ``default=None`` so they are not
    # treated as required when the model is validated; the primary key is
    # assigned by the database on insert.
    id: int | None = Field(default=None, primary_key=True)
    text: str = Field(max_length=1000, description="Text that was converted to speech")
    provider: str = Field(max_length=50, description="TTS provider used")
    options: dict[str, Any] = Field(
        default_factory=dict,
        sa_column=Column(JSON),
        description="Provider-specific options used",
    )
    # Stays None until a sound has been attached to this record.
    sound_id: int | None = Field(
        default=None,
        foreign_key="sound.id",
        description="Associated sound ID",
    )
    user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
    # Naive UTC timestamps, matching the plain DateTime columns.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
62
app/repositories/tts.py
Normal file
62
app/repositories/tts.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""TTS repository for database operations."""
|
||||||
|
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from app.models.tts import TTS
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TTSRepository(BaseRepository[TTS]):
    """Data-access layer for ``TTS`` rows."""

    def __init__(self, session: Any) -> None:
        """Bind the repository to ``session`` for the ``TTS`` model."""
        super().__init__(TTS, session)

    async def get_by_user_id(
        self,
        user_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> Sequence[TTS]:
        """Return one page of a user's TTS records, newest first.

        Args:
            user_id: Owner whose records are selected
            limit: Page size
            offset: Number of leading records to skip

        Returns:
            The matching TTS records ordered by ``created_at`` descending
        """
        query = select(self.model).where(self.model.user_id == user_id)
        query = query.order_by(self.model.created_at.desc())
        query = query.limit(limit).offset(offset)

        rows = await self.session.exec(query)
        return rows.all()

    async def get_by_user_and_id(
        self,
        user_id: int,
        tts_id: int,
    ) -> TTS | None:
        """Fetch a TTS record only if it belongs to the given user.

        Args:
            user_id: Expected owner of the record
            tts_id: Primary key of the record

        Returns:
            The record when both the ID and the owner match, else None
        """
        matches_id = self.model.id == tts_id
        matches_owner = self.model.user_id == user_id

        rows = await self.session.exec(
            select(self.model).where(matches_id, matches_owner),
        )
        return rows.first()
|
||||||
6
app/services/tts/__init__.py
Normal file
6
app/services/tts/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Text-to-Speech services package."""
|
||||||
|
|
||||||
|
from .base import TTSProvider
|
||||||
|
from .service import TTSService
|
||||||
|
|
||||||
|
__all__ = ["TTSProvider", "TTSService"]
|
||||||
38
app/services/tts/base.py
Normal file
38
app/services/tts/base.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Base TTS provider interface."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class TTSProvider(ABC):
    """Interface every TTS provider must implement."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Name identifying this provider."""

    @property
    @abstractmethod
    def file_extension(self) -> str:
        """Default file extension of the audio this provider produces."""

    @abstractmethod
    async def generate_speech(self, text: str, **options: Any) -> bytes:
        """Convert ``text`` to audio.

        Args:
            text: The text to convert to speech
            **options: Provider-specific options

        Returns:
            Audio data as bytes
        """

    @abstractmethod
    def get_supported_languages(self) -> list[str]:
        """List the language codes this provider supports."""

    @abstractmethod
    def get_option_schema(self) -> dict[str, Any]:
        """Describe the provider-specific options this provider accepts."""
|
||||||
5
app/services/tts/providers/__init__.py
Normal file
5
app/services/tts/providers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""TTS providers package."""
|
||||||
|
|
||||||
|
from .gtts import GTTSProvider
|
||||||
|
|
||||||
|
__all__ = ["GTTSProvider"]
|
||||||
81
app/services/tts/providers/gtts.py
Normal file
81
app/services/tts/providers/gtts.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""Google Text-to-Speech provider."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from gtts import gTTS
|
||||||
|
|
||||||
|
from ..base import TTSProvider
|
||||||
|
|
||||||
|
|
||||||
|
class GTTSProvider(TTSProvider):
    """TTS provider backed by the ``gTTS`` library."""

    @property
    def name(self) -> str:
        """Name identifying this provider."""
        return "gtts"

    @property
    def file_extension(self) -> str:
        """Default file extension of the generated audio."""
        return "mp3"

    async def generate_speech(self, text: str, **options: Any) -> bytes:
        """Convert ``text`` to MP3 audio with gTTS.

        Args:
            text: The text to convert to speech
            **options: ``lang`` (default "en"), ``tld`` (default "com") and
                ``slow`` (default False); other keys are ignored

        Returns:
            MP3 audio data as bytes
        """
        language = options.get("lang", "en")
        domain = options.get("tld", "com")
        speak_slowly = options.get("slow", False)

        def _render() -> bytes:
            # gTTS is synchronous, so the whole render runs in a worker thread.
            buffer = io.BytesIO()
            engine = gTTS(text=text, lang=language, tld=domain, slow=speak_slowly)
            engine.write_to_fp(buffer)
            return buffer.getvalue()

        return await asyncio.to_thread(_render)

    def get_supported_languages(self) -> list[str]:
        """List the language codes offered for this provider."""
        # Hand-maintained list of common gTTS language codes.
        codes = (
            "af ar bg bn bs ca cs cy da de el en "
            "eo es et fi fr gu hi hr hu hy id is "
            "it ja jw km kn ko la lv mk ml mr my "
            "ne nl no pl pt ro ru si sk sq sr su "
            "sv sw ta te th tl tr uk ur vi zh-cn zh-tw"
        )
        return codes.split()

    def get_option_schema(self) -> dict[str, Any]:
        """Describe the ``lang``, ``tld`` and ``slow`` options."""
        lang_option = {
            "type": "string",
            "default": "en",
            "description": "Language code",
            "enum": self.get_supported_languages(),
        }
        tld_option = {
            "type": "string",
            "default": "com",
            "description": "Top-level domain for Google TTS",
            "enum": ["com", "co.uk", "com.au", "ca", "co.in", "ie", "co.za"],
        }
        slow_option = {
            "type": "boolean",
            "default": False,
            "description": "Speak slowly",
        }
        return {"lang": lang_option, "tld": tld_option, "slow": slow_option}
|
||||||
404
app/services/tts/service.py
Normal file
404
app/services/tts/service.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
"""TTS service implementation."""
|
||||||
|
|
||||||
|
import asyncio
import io
import logging
import uuid
from pathlib import Path
from typing import Any

from gtts import gTTS
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.sound import Sound
from app.models.tts import TTS
from app.repositories.sound import SoundRepository
from app.repositories.tts import TTSRepository
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size

from .base import TTSProvider
from .providers import GTTSProvider
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
MAX_TEXT_LENGTH = 1000
|
||||||
|
MAX_NAME_LENGTH = 50
|
||||||
|
|
||||||
|
|
||||||
|
class TTSService:
    """Text-to-Speech service with provider management.

    Creates TTS records, generates their audio in a background task using
    the registered providers, stores the result as a ``Sound`` and handles
    history lookup and deletion.
    """

    # On-disk locations of generated audio and of its normalized copy.
    _ORIGINAL_DIR = Path("sounds/originals/text_to_speech")
    _NORMALIZED_DIR = Path("sounds/normalized/text_to_speech")

    _logger = logging.getLogger(__name__)

    # Strong references to in-flight background tasks. Kept on the class
    # rather than on an instance: the event loop only holds weak references
    # to tasks, and service instances are created per request.
    _background_tasks: set[asyncio.Task[None]] = set()

    def __init__(self, session: AsyncSession) -> None:
        """Initialize TTS service.

        Args:
            session: Database session
        """
        self.session = session
        self.sound_repo = SoundRepository(session)
        self.tts_repo = TTSRepository(session)
        self.providers: dict[str, TTSProvider] = {}

        # Register default providers
        self._register_default_providers()

    def _register_default_providers(self) -> None:
        """Register default TTS providers."""
        self.register_provider(GTTSProvider())

    def register_provider(self, provider: TTSProvider) -> None:
        """Register a TTS provider under its own ``name``.

        Args:
            provider: TTS provider instance
        """
        self.providers[provider.name] = provider

    def get_providers(self) -> dict[str, TTSProvider]:
        """Get a shallow copy of the registered providers, keyed by name."""
        return self.providers.copy()

    def get_provider(self, name: str) -> TTSProvider | None:
        """Get a specific provider by name, or None if it is not registered."""
        return self.providers.get(name)

    async def create_tts_request(
        self,
        text: str,
        user_id: int,
        provider: str = "gtts",
        **options: Any,
    ) -> dict[str, Any]:
        """Create a TTS request that will be processed in the background.

        Args:
            text: Text to convert to speech
            user_id: ID of user creating the sound
            provider: TTS provider name
            **options: Provider-specific options

        Returns:
            Dictionary with the created record under ``tts`` and a
            ``message``

        Raises:
            ValueError: If provider not found, text too long or text empty
            Exception: If request creation fails
        """
        if provider not in self.providers:
            provider_not_found_msg = f"Provider '{provider}' not found"
            raise ValueError(provider_not_found_msg)

        if len(text) > MAX_TEXT_LENGTH:
            text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
            raise ValueError(text_too_long_msg)

        if not text.strip():
            empty_text_msg = "Text cannot be empty"
            raise ValueError(empty_text_msg)

        # The record is stored without a sound; ``sound_id`` is filled in by
        # the background task once the audio exists.
        tts = TTS(
            text=text,
            provider=provider,
            options=options,
            sound_id=None,
            user_id=user_id,
        )
        self.session.add(tts)
        await self.session.commit()
        await self.session.refresh(tts)

        # Queue for background processing
        if tts.id is not None:
            await self._queue_tts_processing(tts.id)

        return {"tts": tts, "message": "TTS generation queued successfully"}

    async def _queue_tts_processing(self, tts_id: int) -> None:
        """Start background processing of a TTS record.

        Runs as an asyncio task on the current loop; this could be moved to
        a proper background queue later.
        """
        task = asyncio.create_task(self._process_tts_in_background(tts_id))
        # Hold a strong reference until the task finishes so it cannot be
        # garbage-collected mid-flight.
        TTSService._background_tasks.add(task)
        task.add_done_callback(TTSService._background_tasks.discard)

    async def _process_tts_in_background(self, tts_id: int) -> None:
        """Generate audio for a TTS record and attach the resulting sound.

        Never raises: failures are logged, because nothing awaits the
        background task that runs this coroutine.
        """
        from app.core.database import get_session_factory

        try:
            # Use a dedicated session instead of the caller's session.
            session_factory = get_session_factory()
            async with session_factory() as background_session:
                tts_service = TTSService(background_session)

                stmt = select(TTS).where(TTS.id == tts_id)
                result = await background_session.exec(stmt)
                tts = result.first()

                if not tts:
                    self._logger.warning(
                        "TTS %s not found; skipping background generation",
                        tts_id,
                    )
                    return

                sound = await tts_service._generate_tts_sync(
                    tts.text,
                    tts.provider,
                    tts.user_id,
                    tts.options,
                )

                # Update the TTS record with the sound ID
                if sound.id is not None:
                    tts.sound_id = sound.id
                    background_session.add(tts)
                    await background_session.commit()

        except Exception:
            # Previously swallowed silently, which made failed generations
            # invisible; keep the task non-fatal but record the traceback.
            self._logger.exception(
                "Background TTS generation failed for TTS %s",
                tts_id,
            )

    async def _generate_tts_sync(
        self,
        text: str,
        provider: str,
        user_id: int,
        options: dict[str, Any],
    ) -> Sound:
        """Generate audio with the given provider and store it as a Sound.

        Args:
            text: Text to convert to speech
            provider: Name of a registered provider
            user_id: Owner of the resulting sound
            options: Provider-specific options (None is treated as empty)

        Returns:
            The Sound for the generated audio; an existing Sound when one
            with the same file hash is already stored

        Raises:
            KeyError: If the provider is not registered
            Exception: If audio generation or record creation fails
        """
        tts_provider = self.providers[provider]

        # Create directories if they don't exist
        original_dir = self._ORIGINAL_DIR
        original_dir.mkdir(parents=True, exist_ok=True)

        # Create UUID filename
        sound_uuid = str(uuid.uuid4())
        original_filename = f"{sound_uuid}.{tts_provider.file_extension}"
        original_path = original_dir / original_filename

        # Delegate to the provider rather than calling gTTS directly: this
        # honours the selected provider and its file extension, and the
        # gTTS provider runs the blocking call in a worker thread instead
        # of on the event loop.
        audio_bytes = await tts_provider.generate_speech(text, **(options or {}))
        original_path.write_bytes(audio_bytes)

        try:
            sound = await self._create_sound_record_complete(
                original_path,
                text,
                provider,
                user_id,
            )
        except Exception:
            # Don't leave an orphaned audio file behind.
            original_path.unlink(missing_ok=True)
            raise

        # Normalization is best-effort and never fails the generation.
        if sound.id is not None:
            await self._normalize_sound_safe(sound.id)

        return sound

    async def get_user_tts_history(
        self,
        user_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> list[TTS]:
        """Get TTS history for a user.

        Args:
            user_id: User ID
            limit: Maximum number of records
            offset: Offset for pagination

        Returns:
            List of TTS records
        """
        result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
        return list(result)

    @staticmethod
    def _build_sound_data(
        audio_path: Path,
        text: str,
        user_id: int,
        *,
        duration: Any,
        size: Any,
        file_hash: str,
    ) -> dict[str, Any]:
        """Build the attribute dict used to create a TTS ``Sound``.

        The name is the text truncated to ``MAX_NAME_LENGTH`` characters,
        with "..." appended when something was cut off.
        """
        name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
        return {
            "type": "TTS",
            "name": name,
            "filename": audio_path.name,
            "duration": duration,
            "size": size,
            "hash": file_hash,
            "user_id": user_id,
            "is_deletable": True,
            "is_music": False,  # TTS is speech, not music
            "is_normalized": False,
            "play_count": 0,
        }

    async def _create_sound_record(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
        file_hash: str,
    ) -> Sound:
        """Create a Sound record using a caller-supplied file hash.

        Duration and size are read from ``audio_path``. ``provider`` is
        accepted for signature compatibility and is not stored.
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=get_audio_duration(audio_path),
            size=get_file_size(audio_path),
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_simple(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a Sound record without inspecting the audio file.

        Duration and size are stored as 0 and a random UUID stands in for
        the file hash. ``provider`` is not stored.
        """
        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=0,
            size=0,
            file_hash=str(uuid.uuid4()),
        )
        return await self.sound_repo.create(sound_data)

    async def _create_sound_record_complete(
        self,
        audio_path: Path,
        text: str,
        provider: str,
        user_id: int,
    ) -> Sound:
        """Create a Sound record with duration, size and hash from the file.

        If a Sound with the same file hash already exists, the new file at
        ``audio_path`` is removed and the existing Sound is returned
        instead. ``provider`` is not stored.
        """
        duration = get_audio_duration(audio_path)
        size = get_file_size(audio_path)
        file_hash = get_file_hash(audio_path)

        # Check if a sound with this hash already exists
        existing_sound = await self.sound_repo.get_by_hash(file_hash)
        if existing_sound:
            # The freshly written file is a duplicate; discard it.
            audio_path.unlink(missing_ok=True)
            return existing_sound

        sound_data = self._build_sound_data(
            audio_path,
            text,
            user_id,
            duration=duration,
            size=size,
            file_hash=file_hash,
        )
        return await self.sound_repo.create(sound_data)

    async def _normalize_sound_safe(self, sound_id: int) -> None:
        """Normalize the TTS sound; log failures instead of raising."""
        try:
            # Get fresh sound object from database for normalization
            sound = await self.sound_repo.get_by_id(sound_id)
            if not sound:
                return

            normalizer_service = SoundNormalizerService(self.session)
            result = await normalizer_service.normalize_sound(sound)

            if result["status"] == "error":
                self._logger.warning(
                    "Failed to normalize TTS sound %s: %s",
                    sound_id,
                    result.get("error"),
                )

        except Exception:
            # Don't fail the TTS generation if normalization fails
            self._logger.exception(
                "Exception during TTS sound normalization %s",
                sound_id,
            )

    async def _normalize_sound(self, sound_id: int) -> None:
        """Normalize the TTS sound.

        Same contract as ``_normalize_sound_safe`` (never raises); kept as
        a separate name for existing callers.
        """
        await self._normalize_sound_safe(sound_id)

    async def _can_delete_sound(self, sound: Sound, tts: TTS) -> bool:
        """Tell whether deleting ``tts`` may also delete ``sound``.

        Sounds are deduplicated by file hash, so one sound can be attached
        to several TTS rows or belong to a different user. The sound is
        only removed when the TTS owner also owns it and no other TTS row
        still points at it.
        """
        if sound.user_id != tts.user_id:
            return False

        stmt = (
            select(TTS.id)
            .where(TTS.sound_id == sound.id, TTS.id != tts.id)
            .limit(1)
        )
        result = await self.session.exec(stmt)
        return result.first() is None

    async def delete_tts(self, tts_id: int, user_id: int) -> None:
        """Delete a TTS generation and, when safe, its sound and files.

        Args:
            tts_id: ID of the TTS record to delete
            user_id: ID of the user requesting the deletion

        Raises:
            ValueError: If no TTS record has this ID
            PermissionError: If the record belongs to another user
        """
        tts = await self.tts_repo.get_by_id(tts_id)
        if not tts:
            raise ValueError(f"TTS with ID {tts_id} not found")

        # Check ownership
        if tts.user_id != user_id:
            raise PermissionError("You don't have permission to delete this TTS generation")

        if tts.sound_id:
            sound = await self.sound_repo.get_by_id(tts.sound_id)
            if sound and await self._can_delete_sound(sound, tts):
                await self._delete_sound_files(sound)
                # NOTE(review): the sound row is deleted while this TTS row
                # still references it through sound_id - confirm the
                # repository/foreign-key semantics allow this order.
                await self.sound_repo.delete(sound)

        # Delete the TTS record
        await self.tts_repo.delete(tts)

    async def _delete_sound_files(self, sound: Sound) -> None:
        """Delete the original and normalized files of a sound.

        Missing files are ignored. Other filesystem errors are logged
        rather than raised, so an OS-level ``PermissionError`` cannot be
        mistaken by callers for the ownership error of ``delete_tts``.
        """
        targets = [self._ORIGINAL_DIR / sound.filename]
        if sound.normalized_filename:
            targets.append(self._NORMALIZED_DIR / sound.normalized_filename)

        for target in targets:
            try:
                target.unlink(missing_ok=True)
            except OSError:
                self._logger.exception("Could not delete sound file %s", target)
|
||||||
Reference in New Issue
Block a user