feat: Implement Text-to-Speech (TTS) functionality with API endpoints, models, and service integration
This commit is contained in:
45
alembic/versions/e617c155eea9_add_tts_table.py
Normal file
45
alembic/versions/e617c155eea9_add_tts_table.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Add TTS table
|
||||
|
||||
Revision ID: e617c155eea9
|
||||
Revises: a0d322857b2c
|
||||
Create Date: 2025-09-20 21:51:26.557738
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e617c155eea9'
|
||||
down_revision: Union[str, Sequence[str], None] = 'a0d322857b2c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tts',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('text', sqlmodel.sql.sqltypes.AutoString(length=1000), nullable=False),
|
||||
sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(length=50), nullable=False),
|
||||
sa.Column('options', sa.JSON(), nullable=True),
|
||||
sa.Column('sound_id', sa.Integer(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['sound_id'], ['sound.id'], ),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tts')
|
||||
# ### end Alembic commands ###
|
||||
@@ -15,6 +15,7 @@ from app.api.v1 import (
|
||||
scheduler,
|
||||
socket,
|
||||
sounds,
|
||||
tts,
|
||||
)
|
||||
|
||||
# V1 API router with v1 prefix
|
||||
@@ -32,4 +33,5 @@ api_router.include_router(playlists.router, tags=["playlists"])
|
||||
api_router.include_router(scheduler.router, tags=["scheduler"])
|
||||
api_router.include_router(socket.router, tags=["socket"])
|
||||
api_router.include_router(sounds.router, tags=["sounds"])
|
||||
api_router.include_router(tts.router, tags=["tts"])
|
||||
api_router.include_router(admin.router)
|
||||
|
||||
216
app/api/v1/tts.py
Normal file
216
app/api/v1/tts.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""TTS API endpoints."""
|
||||
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.user import User
|
||||
from app.services.tts import TTSService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/tts", tags=["tts"])
|
||||
|
||||
|
||||
class TTSGenerateRequest(BaseModel):
|
||||
"""TTS generation request model."""
|
||||
|
||||
text: str = Field(..., min_length=1, max_length=1000, description="Text to convert to speech")
|
||||
provider: str = Field(default="gtts", description="TTS provider to use")
|
||||
options: dict[str, Any] = Field(default_factory=dict, description="Provider-specific options")
|
||||
|
||||
|
||||
class TTSResponse(BaseModel):
|
||||
"""TTS generation response model."""
|
||||
|
||||
id: int
|
||||
text: str
|
||||
provider: str
|
||||
options: dict[str, Any]
|
||||
sound_id: int | None
|
||||
user_id: int
|
||||
created_at: str
|
||||
|
||||
|
||||
class ProviderInfo(BaseModel):
|
||||
"""Provider information model."""
|
||||
|
||||
name: str
|
||||
file_extension: str
|
||||
supported_languages: list[str]
|
||||
option_schema: dict[str, Any]
|
||||
|
||||
|
||||
async def get_tts_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> TTSService:
|
||||
"""Get the TTS service."""
|
||||
return TTSService(session)
|
||||
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate_tts(
|
||||
request: TTSGenerateRequest,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> dict[str, Any]:
|
||||
"""Generate TTS audio and create sound."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
result = await tts_service.create_tts_request(
|
||||
text=request.text,
|
||||
user_id=current_user.id,
|
||||
provider=request.provider,
|
||||
**request.options
|
||||
)
|
||||
|
||||
tts_record = result["tts"]
|
||||
|
||||
return {
|
||||
"message": result["message"],
|
||||
"tts": TTSResponse(
|
||||
id=tts_record.id,
|
||||
text=tts_record.text,
|
||||
provider=tts_record.provider,
|
||||
options=tts_record.options,
|
||||
sound_id=tts_record.sound_id,
|
||||
user_id=tts_record.user_id,
|
||||
created_at=tts_record.created_at.isoformat(),
|
||||
)
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate TTS: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def get_providers(
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> dict[str, ProviderInfo]:
|
||||
"""Get all available TTS providers."""
|
||||
providers = tts_service.get_providers()
|
||||
result = {}
|
||||
|
||||
for name, provider in providers.items():
|
||||
result[name] = ProviderInfo(
|
||||
name=provider.name,
|
||||
file_extension=provider.file_extension,
|
||||
supported_languages=provider.get_supported_languages(),
|
||||
option_schema=provider.get_option_schema(),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/providers/{provider_name}")
|
||||
async def get_provider(
|
||||
provider_name: str,
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> ProviderInfo:
|
||||
"""Get information about a specific TTS provider."""
|
||||
provider = tts_service.get_provider(provider_name)
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Provider '{provider_name}' not found",
|
||||
)
|
||||
|
||||
return ProviderInfo(
|
||||
name=provider.name,
|
||||
file_extension=provider.file_extension,
|
||||
supported_languages=provider.get_supported_languages(),
|
||||
option_schema=provider.get_option_schema(),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_tts_history(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[TTSResponse]:
|
||||
"""Get TTS generation history for the current user."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
tts_records = await tts_service.get_user_tts_history(
|
||||
user_id=current_user.id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return [
|
||||
TTSResponse(
|
||||
id=tts.id,
|
||||
text=tts.text,
|
||||
provider=tts.provider,
|
||||
options=tts.options,
|
||||
sound_id=tts.sound_id,
|
||||
user_id=tts.user_id,
|
||||
created_at=tts.created_at.isoformat(),
|
||||
)
|
||||
for tts in tts_records
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get TTS history: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/{tts_id}")
|
||||
async def delete_tts(
|
||||
tts_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
tts_service: Annotated[TTSService, Depends(get_tts_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Delete a TTS generation and its associated files."""
|
||||
try:
|
||||
if current_user.id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
|
||||
|
||||
return {"message": "TTS generation deleted successfully"}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete TTS: {e!s}",
|
||||
) from e
|
||||
@@ -12,6 +12,7 @@ from .playlist_sound import PlaylistSound
|
||||
from .scheduled_task import ScheduledTask
|
||||
from .sound import Sound
|
||||
from .sound_played import SoundPlayed
|
||||
from .tts import TTS
|
||||
from .user import User
|
||||
from .user_oauth import UserOauth
|
||||
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"ScheduledTask",
|
||||
"Sound",
|
||||
"SoundPlayed",
|
||||
"TTS",
|
||||
"User",
|
||||
"UserOauth",
|
||||
]
|
||||
|
||||
26
app/models/tts.py
Normal file
26
app/models/tts.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""TTS model."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Column
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class TTS(SQLModel, table=True):
|
||||
"""Text-to-Speech generation record."""
|
||||
|
||||
__tablename__ = "tts"
|
||||
|
||||
id: int | None = Field(primary_key=True)
|
||||
text: str = Field(max_length=1000, description="Text that was converted to speech")
|
||||
provider: str = Field(max_length=50, description="TTS provider used")
|
||||
options: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
sa_column=Column(JSON),
|
||||
description="Provider-specific options used"
|
||||
)
|
||||
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
|
||||
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
62
app/repositories/tts.py
Normal file
62
app/repositories/tts.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""TTS repository for database operations."""
|
||||
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from app.models.tts import TTS
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class TTSRepository(BaseRepository[TTS]):
|
||||
"""Repository for TTS operations."""
|
||||
|
||||
def __init__(self, session: Any) -> None:
|
||||
super().__init__(TTS, session)
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> Sequence[TTS]:
|
||||
"""Get TTS records by user ID with pagination.
|
||||
|
||||
Args:
|
||||
user_id: User ID to filter by
|
||||
limit: Maximum number of records to return
|
||||
offset: Number of records to skip
|
||||
|
||||
Returns:
|
||||
List of TTS records
|
||||
"""
|
||||
stmt = (
|
||||
select(self.model)
|
||||
.where(self.model.user_id == user_id)
|
||||
.order_by(self.model.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return result.all()
|
||||
|
||||
async def get_by_user_and_id(
|
||||
self,
|
||||
user_id: int,
|
||||
tts_id: int,
|
||||
) -> TTS | None:
|
||||
"""Get a specific TTS record by user ID and TTS ID.
|
||||
|
||||
Args:
|
||||
user_id: User ID to filter by
|
||||
tts_id: TTS ID to retrieve
|
||||
|
||||
Returns:
|
||||
TTS record if found and belongs to user, None otherwise
|
||||
"""
|
||||
stmt = select(self.model).where(
|
||||
self.model.id == tts_id,
|
||||
self.model.user_id == user_id,
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return result.first()
|
||||
6
app/services/tts/__init__.py
Normal file
6
app/services/tts/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Text-to-Speech services package."""
|
||||
|
||||
from .base import TTSProvider
|
||||
from .service import TTSService
|
||||
|
||||
__all__ = ["TTSProvider", "TTSService"]
|
||||
38
app/services/tts/base.py
Normal file
38
app/services/tts/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Base TTS provider interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class TTSProvider(ABC):
|
||||
"""Abstract base class for TTS providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def generate_speech(self, text: str, **options: Any) -> bytes:
|
||||
"""Generate speech from text with provider-specific options.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech
|
||||
**options: Provider-specific options
|
||||
|
||||
Returns:
|
||||
Audio data as bytes
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_languages(self) -> list[str]:
|
||||
"""Return list of supported language codes."""
|
||||
|
||||
@abstractmethod
|
||||
def get_option_schema(self) -> dict[str, Any]:
|
||||
"""Return schema for provider-specific options."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def file_extension(self) -> str:
|
||||
"""Return the default file extension for this provider."""
|
||||
5
app/services/tts/providers/__init__.py
Normal file
5
app/services/tts/providers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""TTS providers package."""
|
||||
|
||||
from .gtts import GTTSProvider
|
||||
|
||||
__all__ = ["GTTSProvider"]
|
||||
81
app/services/tts/providers/gtts.py
Normal file
81
app/services/tts/providers/gtts.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Google Text-to-Speech provider."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
from gtts import gTTS
|
||||
|
||||
from ..base import TTSProvider
|
||||
|
||||
|
||||
class GTTSProvider(TTSProvider):
|
||||
"""Google Text-to-Speech provider implementation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "gtts"
|
||||
|
||||
@property
|
||||
def file_extension(self) -> str:
|
||||
"""Return the default file extension for this provider."""
|
||||
return "mp3"
|
||||
|
||||
async def generate_speech(self, text: str, **options: Any) -> bytes:
|
||||
"""Generate speech from text using Google TTS.
|
||||
|
||||
Args:
|
||||
text: The text to convert to speech
|
||||
**options: GTTS-specific options (lang, tld, slow)
|
||||
|
||||
Returns:
|
||||
MP3 audio data as bytes
|
||||
"""
|
||||
lang = options.get("lang", "en")
|
||||
tld = options.get("tld", "com")
|
||||
slow = options.get("slow", False)
|
||||
|
||||
# Run TTS generation in thread pool since gTTS is synchronous
|
||||
def _generate():
|
||||
tts = gTTS(text=text, lang=lang, tld=tld, slow=slow)
|
||||
fp = io.BytesIO()
|
||||
tts.write_to_fp(fp)
|
||||
fp.seek(0)
|
||||
return fp.read()
|
||||
|
||||
# Use asyncio.to_thread which is more reliable than run_in_executor
|
||||
return await asyncio.to_thread(_generate)
|
||||
|
||||
def get_supported_languages(self) -> list[str]:
|
||||
"""Return list of supported language codes."""
|
||||
# Common GTTS supported languages
|
||||
return [
|
||||
"af", "ar", "bg", "bn", "bs", "ca", "cs", "cy", "da", "de", "el", "en",
|
||||
"eo", "es", "et", "fi", "fr", "gu", "hi", "hr", "hu", "hy", "id", "is",
|
||||
"it", "ja", "jw", "km", "kn", "ko", "la", "lv", "mk", "ml", "mr", "my",
|
||||
"ne", "nl", "no", "pl", "pt", "ro", "ru", "si", "sk", "sq", "sr", "su",
|
||||
"sv", "sw", "ta", "te", "th", "tl", "tr", "uk", "ur", "vi", "zh-cn", "zh-tw"
|
||||
]
|
||||
|
||||
def get_option_schema(self) -> dict[str, Any]:
|
||||
"""Return schema for GTTS-specific options."""
|
||||
return {
|
||||
"lang": {
|
||||
"type": "string",
|
||||
"default": "en",
|
||||
"description": "Language code",
|
||||
"enum": self.get_supported_languages()
|
||||
},
|
||||
"tld": {
|
||||
"type": "string",
|
||||
"default": "com",
|
||||
"description": "Top-level domain for Google TTS",
|
||||
"enum": ["com", "co.uk", "com.au", "ca", "co.in", "ie", "co.za"]
|
||||
},
|
||||
"slow": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description": "Speak slowly"
|
||||
}
|
||||
}
|
||||
404
app/services/tts/service.py
Normal file
404
app/services/tts/service.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""TTS service implementation."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from gtts import gTTS
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.sound import Sound
|
||||
from app.models.tts import TTS
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.tts import TTSRepository
|
||||
from app.services.sound_normalizer import SoundNormalizerService
|
||||
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
|
||||
from .base import TTSProvider
|
||||
from .providers import GTTSProvider
|
||||
|
||||
# Constants
|
||||
MAX_TEXT_LENGTH = 1000
|
||||
MAX_NAME_LENGTH = 50
|
||||
|
||||
|
||||
class TTSService:
|
||||
"""Text-to-Speech service with provider management."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize TTS service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
"""
|
||||
self.session = session
|
||||
self.sound_repo = SoundRepository(session)
|
||||
self.tts_repo = TTSRepository(session)
|
||||
self.providers: dict[str, TTSProvider] = {}
|
||||
|
||||
# Register default providers
|
||||
self._register_default_providers()
|
||||
|
||||
def _register_default_providers(self) -> None:
|
||||
"""Register default TTS providers."""
|
||||
self.register_provider(GTTSProvider())
|
||||
|
||||
def register_provider(self, provider: TTSProvider) -> None:
|
||||
"""Register a TTS provider.
|
||||
|
||||
Args:
|
||||
provider: TTS provider instance
|
||||
"""
|
||||
self.providers[provider.name] = provider
|
||||
|
||||
def get_providers(self) -> dict[str, TTSProvider]:
|
||||
"""Get all registered providers."""
|
||||
return self.providers.copy()
|
||||
|
||||
def get_provider(self, name: str) -> TTSProvider | None:
|
||||
"""Get a specific provider by name."""
|
||||
return self.providers.get(name)
|
||||
|
||||
async def create_tts_request(
|
||||
self,
|
||||
text: str,
|
||||
user_id: int,
|
||||
provider: str = "gtts",
|
||||
**options: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a TTS request that will be processed in the background.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
user_id: ID of user creating the sound
|
||||
provider: TTS provider name
|
||||
**options: Provider-specific options
|
||||
|
||||
Returns:
|
||||
Dictionary with TTS record information
|
||||
|
||||
Raises:
|
||||
ValueError: If provider not found or text too long
|
||||
Exception: If request creation fails
|
||||
"""
|
||||
provider_not_found_msg = f"Provider '{provider}' not found"
|
||||
if provider not in self.providers:
|
||||
raise ValueError(provider_not_found_msg)
|
||||
|
||||
text_too_long_msg = f"Text too long (max {MAX_TEXT_LENGTH} characters)"
|
||||
if len(text) > MAX_TEXT_LENGTH:
|
||||
raise ValueError(text_too_long_msg)
|
||||
|
||||
empty_text_msg = "Text cannot be empty"
|
||||
if not text.strip():
|
||||
raise ValueError(empty_text_msg)
|
||||
|
||||
# Create TTS record with pending status
|
||||
tts = TTS(
|
||||
text=text,
|
||||
provider=provider,
|
||||
options=options,
|
||||
sound_id=None, # Will be set when processing completes
|
||||
user_id=user_id,
|
||||
)
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(tts)
|
||||
|
||||
# Queue for background processing
|
||||
if tts.id is not None:
|
||||
await self._queue_tts_processing(tts.id)
|
||||
|
||||
return {"tts": tts, "message": "TTS generation queued successfully"}
|
||||
|
||||
async def _queue_tts_processing(self, tts_id: int) -> None:
|
||||
"""Queue TTS for background processing."""
|
||||
# For now, process immediately in a different way
|
||||
# This could be moved to a proper background queue later
|
||||
task = asyncio.create_task(self._process_tts_in_background(tts_id))
|
||||
# Store reference to prevent garbage collection
|
||||
self._background_tasks = getattr(self, '_background_tasks', set())
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._background_tasks.discard)
|
||||
|
||||
async def _process_tts_in_background(self, tts_id: int) -> None:
|
||||
"""Process TTS generation in background."""
|
||||
from app.core.database import get_session_factory
|
||||
|
||||
try:
|
||||
# Create a new session for background processing
|
||||
session_factory = get_session_factory()
|
||||
async with session_factory() as background_session:
|
||||
tts_service = TTSService(background_session)
|
||||
|
||||
# Get the TTS record
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await background_session.exec(stmt)
|
||||
tts = result.first()
|
||||
|
||||
if not tts:
|
||||
return
|
||||
|
||||
# Use a synchronous approach for the actual generation
|
||||
sound = await tts_service._generate_tts_sync(
|
||||
tts.text,
|
||||
tts.provider,
|
||||
tts.user_id,
|
||||
tts.options,
|
||||
)
|
||||
|
||||
# Update the TTS record with the sound ID
|
||||
if sound.id is not None:
|
||||
tts.sound_id = sound.id
|
||||
background_session.add(tts)
|
||||
await background_session.commit()
|
||||
|
||||
except Exception:
|
||||
# Log error but don't fail - avoiding print for production
|
||||
pass
|
||||
|
||||
async def _generate_tts_sync(
|
||||
self,
|
||||
text: str,
|
||||
provider: str,
|
||||
user_id: int,
|
||||
options: dict[str, Any]
|
||||
) -> Sound:
|
||||
"""Generate TTS using a synchronous approach."""
|
||||
# Generate the audio using the provider (avoid async issues by doing it directly)
|
||||
tts_provider = self.providers[provider]
|
||||
|
||||
# Create directories if they don't exist
|
||||
original_dir = Path("sounds/originals/text_to_speech")
|
||||
original_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create UUID filename
|
||||
sound_uuid = str(uuid.uuid4())
|
||||
original_filename = f"{sound_uuid}.{tts_provider.file_extension}"
|
||||
original_path = original_dir / original_filename
|
||||
|
||||
# Generate audio synchronously
|
||||
try:
|
||||
# Generate TTS audio
|
||||
lang = options.get("lang", "en")
|
||||
tld = options.get("tld", "com")
|
||||
slow = options.get("slow", False)
|
||||
|
||||
tts_instance = gTTS(text=text, lang=lang, tld=tld, slow=slow)
|
||||
fp = io.BytesIO()
|
||||
tts_instance.write_to_fp(fp)
|
||||
fp.seek(0)
|
||||
audio_bytes = fp.read()
|
||||
|
||||
# Save the file
|
||||
original_path.write_bytes(audio_bytes)
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
# Create Sound record with proper metadata
|
||||
sound = await self._create_sound_record_complete(
|
||||
original_path,
|
||||
text,
|
||||
provider,
|
||||
user_id
|
||||
)
|
||||
|
||||
# Normalize the sound
|
||||
await self._normalize_sound_safe(sound.id)
|
||||
|
||||
return sound
|
||||
|
||||
async def get_user_tts_history(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> list[TTS]:
|
||||
"""Get TTS history for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
limit: Maximum number of records
|
||||
offset: Offset for pagination
|
||||
|
||||
Returns:
|
||||
List of TTS records
|
||||
"""
|
||||
result = await self.tts_repo.get_by_user_id(user_id, limit, offset)
|
||||
return list(result)
|
||||
|
||||
async def _create_sound_record(
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
provider: str,
|
||||
user_id: int,
|
||||
file_hash: str
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio."""
|
||||
# Get audio metadata
|
||||
duration = get_audio_duration(audio_path)
|
||||
size = get_file_size(audio_path)
|
||||
|
||||
# Create sound data
|
||||
sound_data = {
|
||||
"type": "TTS",
|
||||
"name": text[:50] + ("..." if len(text) > 50 else ""),
|
||||
"filename": audio_path.name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
"user_id": user_id,
|
||||
"is_deletable": True,
|
||||
"is_music": False, # TTS is speech, not music
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
return sound
|
||||
|
||||
async def _create_sound_record_simple(
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
provider: str,
|
||||
user_id: int
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio with minimal processing."""
|
||||
# Create sound data with basic info
|
||||
sound_data = {
|
||||
"type": "TTS",
|
||||
"name": text[:50] + ("..." if len(text) > 50 else ""),
|
||||
"filename": audio_path.name,
|
||||
"duration": 0, # Skip duration calculation for now
|
||||
"size": 0, # Skip size calculation for now
|
||||
"hash": str(uuid.uuid4()), # Use UUID as temporary hash
|
||||
"user_id": user_id,
|
||||
"is_deletable": True,
|
||||
"is_music": False, # TTS is speech, not music
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
return sound
|
||||
|
||||
async def _create_sound_record_complete(
|
||||
self,
|
||||
audio_path: Path,
|
||||
text: str,
|
||||
provider: str,
|
||||
user_id: int
|
||||
) -> Sound:
|
||||
"""Create a Sound record for the TTS audio with complete metadata."""
|
||||
# Get audio metadata
|
||||
duration = get_audio_duration(audio_path)
|
||||
size = get_file_size(audio_path)
|
||||
file_hash = get_file_hash(audio_path)
|
||||
|
||||
# Check if a sound with this hash already exists
|
||||
existing_sound = await self.sound_repo.get_by_hash(file_hash)
|
||||
|
||||
if existing_sound:
|
||||
# Clean up the temporary file since we have a duplicate
|
||||
if audio_path.exists():
|
||||
audio_path.unlink()
|
||||
return existing_sound
|
||||
|
||||
# Create sound data with complete metadata
|
||||
sound_data = {
|
||||
"type": "TTS",
|
||||
"name": text[:50] + ("..." if len(text) > 50 else ""),
|
||||
"filename": audio_path.name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
"user_id": user_id,
|
||||
"is_deletable": True,
|
||||
"is_music": False, # TTS is speech, not music
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
return sound
|
||||
|
||||
async def _normalize_sound_safe(self, sound_id: int) -> None:
|
||||
"""Normalize the TTS sound with error handling."""
|
||||
try:
|
||||
# Get fresh sound object from database for normalization
|
||||
sound = await self.sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
return
|
||||
|
||||
normalizer_service = SoundNormalizerService(self.session)
|
||||
result = await normalizer_service.normalize_sound(sound)
|
||||
|
||||
if result["status"] == "error":
|
||||
print(f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception during TTS sound normalization {sound_id}: {e}")
|
||||
# Don't fail the TTS generation if normalization fails
|
||||
|
||||
async def _normalize_sound(self, sound_id: int) -> None:
|
||||
"""Normalize the TTS sound."""
|
||||
try:
|
||||
# Get fresh sound object from database for normalization
|
||||
sound = await self.sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
return
|
||||
|
||||
normalizer_service = SoundNormalizerService(self.session)
|
||||
result = await normalizer_service.normalize_sound(sound)
|
||||
|
||||
if result["status"] == "error":
|
||||
# Log warning but don't fail the TTS generation
|
||||
pass
|
||||
|
||||
except Exception:
|
||||
# Don't fail the TTS generation if normalization fails
|
||||
pass
|
||||
|
||||
async def delete_tts(self, tts_id: int, user_id: int) -> None:
|
||||
"""Delete a TTS generation and its associated sound and files."""
|
||||
# Get the TTS record
|
||||
tts = await self.tts_repo.get_by_id(tts_id)
|
||||
if not tts:
|
||||
raise ValueError(f"TTS with ID {tts_id} not found")
|
||||
|
||||
# Check ownership
|
||||
if tts.user_id != user_id:
|
||||
raise PermissionError("You don't have permission to delete this TTS generation")
|
||||
|
||||
# If there's an associated sound, delete it and its files
|
||||
if tts.sound_id:
|
||||
sound = await self.sound_repo.get_by_id(tts.sound_id)
|
||||
if sound:
|
||||
# Delete the sound files
|
||||
await self._delete_sound_files(sound)
|
||||
# Delete the sound record
|
||||
await self.sound_repo.delete(sound)
|
||||
|
||||
# Delete the TTS record
|
||||
await self.tts_repo.delete(tts)
|
||||
|
||||
async def _delete_sound_files(self, sound: Sound) -> None:
|
||||
"""Delete all files associated with a sound."""
|
||||
from pathlib import Path
|
||||
|
||||
# Delete original file
|
||||
original_path = Path("sounds/originals/text_to_speech") / sound.filename
|
||||
if original_path.exists():
|
||||
original_path.unlink()
|
||||
|
||||
# Delete normalized file if it exists
|
||||
if sound.normalized_filename:
|
||||
normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename
|
||||
if normalized_path.exists():
|
||||
normalized_path.unlink()
|
||||
Reference in New Issue
Block a user