226 lines
6.5 KiB
Python
226 lines
6.5 KiB
Python
"""TTS API endpoints."""
|
|
|
|
from typing import Annotated, Any
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from pydantic import BaseModel, Field
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.database import get_db
|
|
from app.core.dependencies import get_current_active_user_flexible
|
|
from app.models.user import User
|
|
from app.services.tts import TTSService
|
|
|
|
# Routes registered on this router are prefixed with "/tts" and tagged "tts".
router = APIRouter(tags=["tts"], prefix="/tts")
|
|
|
|
|
|
class TTSGenerateRequest(BaseModel):
    """Request body accepted by the TTS generation endpoint.

    Carries the text to synthesize, the provider name, and any extra
    provider-specific options.
    """

    # Required text to synthesize; length constrained to 1..1000 characters.
    text: str = Field(
        ...,
        description="Text to convert to speech",
        min_length=1,
        max_length=1000,
    )
    # Name of the TTS provider to use; "gtts" when omitted.
    provider: str = Field("gtts", description="TTS provider to use")
    # Free-form provider-specific options; a fresh empty dict per instance.
    options: dict[str, Any] = Field(
        description="Provider-specific options",
        default_factory=dict,
    )
|
|
|
|
|
|
class TTSResponse(BaseModel):
    """TTS generation response model.

    Serialized view of a TTS record, as built by the list and generate
    endpoints in this module from the records returned by ``TTSService``.
    """

    id: int  # id of the TTS record
    text: str  # text that was submitted for synthesis
    provider: str  # name of the provider used for this record
    options: dict[str, Any]  # provider-specific options stored on the record
    status: str  # status string from the record (allowed values not visible here)
    error: str | None  # error text from the record, or None
    sound_id: int | None  # id of the associated sound, or None
    user_id: int  # id of the user who owns the record
    created_at: str  # creation time, produced via ``datetime.isoformat()``
|
|
|
|
|
|
class ProviderInfo(BaseModel):
    """Provider information model.

    Summary of a TTS provider as exposed by the ``/providers`` endpoints.
    """

    name: str  # the provider object's ``name`` attribute
    file_extension: str  # the provider object's ``file_extension`` attribute
    supported_languages: list[str]  # value of ``provider.get_supported_languages()``
    option_schema: dict[str, Any]  # value of ``provider.get_option_schema()``
|
|
|
|
|
|
async def get_tts_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> TTSService:
    """FastAPI dependency that wraps the ``get_db`` session in a ``TTSService``.

    Args:
        session: Async database session supplied by ``get_db``.

    Returns:
        A new ``TTSService`` constructed around ``session``.
    """
    service = TTSService(session)
    return service
|
|
|
|
|
|
@router.get("")
async def get_tts_list(
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
    limit: int = 50,
    offset: int = 0,
) -> list[TTSResponse]:
    """Get TTS list for the current user.

    Args:
        current_user: Authenticated user from
            ``get_current_active_user_flexible``.
        tts_service: Service instance from ``get_tts_service``.
        limit: Maximum number of records requested from the service.
        offset: Number of records to skip, passed through to the service.

    Returns:
        One ``TTSResponse`` per record returned by
        ``tts_service.get_user_tts_history``.

    Raises:
        HTTPException: 401 if the user has no id; 500 if fetching or
            converting the history fails.
    """
    # Checked outside the ``try``: HTTPException is an Exception subclass, so
    # inside it the generic handler below would turn this 401 into a 500.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    try:
        tts_records = await tts_service.get_user_tts_history(
            user_id=current_user.id,
            limit=limit,
            offset=offset,
        )

        return [
            TTSResponse(
                id=tts.id,
                text=tts.text,
                provider=tts.provider,
                options=tts.options,
                status=tts.status,
                error=tts.error,
                sound_id=tts.sound_id,
                user_id=tts.user_id,
                created_at=tts.created_at.isoformat(),
            )
            for tts in tts_records
        ]

    except HTTPException:
        # Propagate HTTP errors unchanged instead of masking them as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get TTS history: {e!s}",
        ) from e
|
|
|
|
|
|
@router.post("")
async def generate_tts(
    request: TTSGenerateRequest,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, Any]:
    """Generate TTS audio and create sound.

    Args:
        request: Text, provider name and provider-specific options.
        current_user: Authenticated user from
            ``get_current_active_user_flexible``.
        tts_service: Service instance from ``get_tts_service``.

    Returns:
        ``{"message": ..., "tts": TTSResponse}`` built from the dict returned
        by ``tts_service.create_tts_request``.

    Raises:
        HTTPException: 401 if the user has no id; 400 if ``request.options``
            contains a reserved key or the service raises ``ValueError``;
            500 for any other failure.
    """
    # Checked outside the ``try``: HTTPException is an Exception subclass, so
    # inside it the generic handler below would turn this 401 into a 500.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    # ``text``, ``user_id`` and ``provider`` are passed explicitly below; the
    # same keys inside ``options`` would make the ``**`` expansion raise
    # TypeError ("multiple values for keyword argument"), i.e. a 500 for what
    # is really a bad client request.
    conflicting = sorted(request.options.keys() & {"text", "user_id", "provider"})
    if conflicting:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Options must not contain reserved keys: {', '.join(conflicting)}",
        )

    try:
        result = await tts_service.create_tts_request(
            text=request.text,
            user_id=current_user.id,
            provider=request.provider,
            **request.options,
        )

        tts_record = result["tts"]

        return {
            "message": result["message"],
            "tts": TTSResponse(
                id=tts_record.id,
                text=tts_record.text,
                provider=tts_record.provider,
                options=tts_record.options,
                status=tts_record.status,
                error=tts_record.error,
                sound_id=tts_record.sound_id,
                user_id=tts_record.user_id,
                created_at=tts_record.created_at.isoformat(),
            ),
        }

    except HTTPException:
        # Propagate HTTP errors unchanged instead of masking them as a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate TTS: {e!s}",
        ) from e
|
|
|
|
|
|
@router.get("/providers")
async def get_providers(
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, ProviderInfo]:
    """Get all available TTS providers.

    Returns:
        A dict keyed exactly as ``tts_service.get_providers()`` is keyed,
        mapping each key to a ``ProviderInfo`` summary of that provider.
    """
    return {
        key: ProviderInfo(
            name=impl.name,
            file_extension=impl.file_extension,
            supported_languages=impl.get_supported_languages(),
            option_schema=impl.get_option_schema(),
        )
        for key, impl in tts_service.get_providers().items()
    }
|
|
|
|
|
|
@router.get("/providers/{provider_name}")
async def get_provider(
    provider_name: str,
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> ProviderInfo:
    """Get information about a specific TTS provider.

    Args:
        provider_name: Name looked up via ``tts_service.get_provider``.
        tts_service: Service instance from ``get_tts_service``.

    Returns:
        ``ProviderInfo`` describing the matching provider.

    Raises:
        HTTPException: 404 when the lookup returns a falsy value.
    """
    found = tts_service.get_provider(provider_name)

    if found:
        return ProviderInfo(
            name=found.name,
            file_extension=found.file_extension,
            supported_languages=found.get_supported_languages(),
            option_schema=found.get_option_schema(),
        )

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Provider '{provider_name}' not found",
    )
|
|
|
|
|
|
@router.delete("/{tts_id}")
async def delete_tts(
    tts_id: int,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    tts_service: Annotated[TTSService, Depends(get_tts_service)],
) -> dict[str, str]:
    """Delete a TTS generation and its associated files.

    Args:
        tts_id: Id of the TTS record to delete.
        current_user: Authenticated user from
            ``get_current_active_user_flexible``.
        tts_service: Service instance from ``get_tts_service``.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 401 if the user has no id; 404 if the service raises
            ``ValueError``; 403 if it raises ``PermissionError``; 500 for any
            other failure.
    """
    # Checked outside the ``try``: HTTPException is an Exception subclass, so
    # inside it the generic handler below would turn this 401 into a 500.
    if current_user.id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User ID not available",
        )

    try:
        await tts_service.delete_tts(tts_id=tts_id, user_id=current_user.id)

        return {"message": "TTS generation deleted successfully"}

    except HTTPException:
        # Propagate HTTP errors unchanged instead of masking them as a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(e),
        ) from e
    except PermissionError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete TTS: {e!s}",
        ) from e
|