Files
sdb2-backend/app/api/v1/tts.py

216 lines
6.3 KiB
Python

"""TTS API endpoints."""
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.services.tts import TTSService
router = APIRouter(prefix="/tts", tags=["tts"])
class TTSGenerateRequest(BaseModel):
    """Request body accepted by the TTS generation endpoint."""

    # Text to synthesize; validation bounds it to 1-1000 characters.
    text: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="Text to convert to speech",
    )
    # Provider identifier; falls back to "gtts" when the client omits it.
    provider: str = Field(
        default="gtts",
        description="TTS provider to use",
    )
    # Free-form provider options; a fresh empty dict is created per instance.
    options: dict[str, Any] = Field(
        default_factory=dict,
        description="Provider-specific options",
    )
class TTSResponse(BaseModel):
    """Serialized view of one TTS generation record returned to clients."""

    # Identifier of the TTS record.
    id: int
    # Text that was submitted for synthesis.
    text: str
    # Name of the provider recorded for this generation.
    provider: str
    # Provider options recorded for this generation.
    options: dict[str, Any]
    # Identifier of the linked sound, or None when no sound is attached.
    sound_id: int | None
    # Identifier of the user the record belongs to.
    user_id: int
    # Creation timestamp as a string (endpoints fill it via ``isoformat()``).
    created_at: str
class ProviderInfo(BaseModel):
    """Description of a single TTS provider as exposed by the API."""

    # Provider name as reported by the provider object.
    name: str
    # File extension reported by the provider object.
    file_extension: str
    # Values returned by the provider's ``get_supported_languages()``.
    supported_languages: list[str]
    # Mapping returned by the provider's ``get_option_schema()``.
    option_schema: dict[str, Any]
async def get_tts_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> TTSService:
    """Build a ``TTSService`` bound to the injected database session.

    Used as a FastAPI dependency by the endpoints in this module.
    """
    service = TTSService(session)
    return service
@router.post("")
async def generate_tts(
    request: TTSGenerateRequest,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, Any]:
    """Generate TTS audio and create sound.

    Args:
        request: Validated request body (text, provider, provider options).
        current_user: Authenticated user resolved by the auth dependency.
        tts_service: Service used to create the TTS request.

    Returns:
        A dict with the service's ``"message"`` and a ``"tts"`` entry holding
        the created record as a ``TTSResponse``.

    Raises:
        HTTPException: 401 when the user has no ID, 400 when the options
            collide with reserved fields or the service raises ``ValueError``,
            500 for any other failure.
    """
    # Checked outside the ``try`` block: HTTPException derives from Exception,
    # so raising it inside would be rewritten to a 500 by the generic handler.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    # These names are already passed explicitly below; repeating them through
    # ``**request.options`` would raise TypeError (duplicate keyword argument)
    # and surface as a 500 even though the request itself is at fault.
    reserved = {"text", "user_id", "provider"} & request.options.keys()
    if reserved:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=(
                "Options must not contain reserved fields: "
                f"{', '.join(sorted(reserved))}"
            ),
        )

    try:
        result = await tts_service.create_tts_request(
            text=request.text,
            user_id=current_user.id,
            provider=request.provider,
            **request.options,
        )
        tts_record = result["tts"]
        return {
            "message": result["message"],
            "tts": TTSResponse(
                id=tts_record.id,
                text=tts_record.text,
                provider=tts_record.provider,
                options=tts_record.options,
                sound_id=tts_record.sound_id,
                user_id=tts_record.user_id,
                created_at=tts_record.created_at.isoformat(),
            ),
        }
    except HTTPException:
        # Preserve any HTTP error raised further down instead of masking it.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate TTS: {e!s}",
        ) from e
@router.get("/providers")
async def get_providers(
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, ProviderInfo]:
    """Return every provider known to the TTS service, keyed by its registry name."""
    registry = tts_service.get_providers()
    return {
        key: ProviderInfo(
            name=entry.name,
            file_extension=entry.file_extension,
            supported_languages=entry.get_supported_languages(),
            option_schema=entry.get_option_schema(),
        )
        for key, entry in registry.items()
    }
@router.get("/providers/{provider_name}")
async def get_provider(
    provider_name: str,
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> ProviderInfo:
    """Return the description of one TTS provider.

    Raises:
        HTTPException: 404 when the service returns a falsy value for
            ``provider_name``.
    """
    found = tts_service.get_provider(provider_name)
    if found:
        languages = found.get_supported_languages()
        schema = found.get_option_schema()
        return ProviderInfo(
            name=found.name,
            file_extension=found.file_extension,
            supported_languages=languages,
            option_schema=schema,
        )

    # Fall-through: the service had nothing for this name.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Provider '{provider_name}' not found",
    )
@router.get("")
async def get_tts_list(
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
    limit: int = 50,
    offset: int = 0,
) -> list[TTSResponse]:
    """Get TTS generation history for the current user.

    Args:
        current_user: Authenticated user resolved by the auth dependency.
        tts_service: Service used to read the history.
        limit: Forwarded to the service as ``limit`` (default 50).
        offset: Forwarded to the service as ``offset`` (default 0).

    Returns:
        The records returned by the service, each converted to ``TTSResponse``.

    Raises:
        HTTPException: 401 when the user has no ID, 500 when fetching or
            converting the history fails.
    """
    # Checked outside the ``try`` block: HTTPException derives from Exception,
    # so raising it inside would be rewritten to a 500 by the generic handler.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    try:
        tts_records = await tts_service.get_user_tts_history(
            user_id=current_user.id,
            limit=limit,
            offset=offset,
        )
        return [
            TTSResponse(
                id=tts.id,
                text=tts.text,
                provider=tts.provider,
                options=tts.options,
                sound_id=tts.sound_id,
                user_id=tts.user_id,
                created_at=tts.created_at.isoformat(),
            )
            for tts in tts_records
        ]
    except HTTPException:
        # Preserve any HTTP error raised further down instead of masking it.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get TTS history: {e!s}",
        ) from e
@router.delete("/{tts_id}")
async def delete_tts(
    tts_id: int,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, str]:
    """Delete a TTS generation and its associated files.

    Args:
        tts_id: Identifier of the TTS record to delete.
        current_user: Authenticated user resolved by the auth dependency.
        tts_service: Service that performs the deletion.

    Returns:
        A dict with a confirmation ``"message"``.

    Raises:
        HTTPException: 401 when the user has no ID, 404 when the service
            raises ``ValueError``, 403 when it raises ``PermissionError``,
            500 for any other failure.
    """
    # Checked outside the ``try`` block: HTTPException derives from Exception,
    # so raising it inside would be rewritten to a 500 by the generic handler.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    try:
        await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)
    except HTTPException:
        # Preserve any HTTP error raised further down instead of masking it.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except PermissionError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete TTS: {e!s}",
        ) from e

    return {"message": "TTS generation deleted successfully"}