feat: Add status and error fields to TTS model and implement background processing for TTS generations
This commit is contained in:
@@ -0,0 +1,34 @@
|
|||||||
|
"""Add status and error fields to TTS table
|
||||||
|
|
||||||
|
Revision ID: 0d9b7f1c367f
|
||||||
|
Revises: e617c155eea9
|
||||||
|
Create Date: 2025-09-21 14:09:56.418372
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '0d9b7f1c367f'
|
||||||
|
down_revision: Union[str, Sequence[str], None] = 'e617c155eea9'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Upgrade schema: add the ``status`` and ``error`` columns to ``tts``."""
    # ``status`` is NOT NULL with a server-side default of 'pending';
    # ``error`` is nullable.
    status_column = sa.Column(
        'status',
        sa.String(),
        nullable=False,
        server_default='pending',
    )
    error_column = sa.Column('error', sa.String(), nullable=True)

    # Same order as the generated migration: status first, then error.
    for new_column in (status_column, error_column):
        op.add_column('tts', new_column)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Downgrade schema: drop the ``error`` and ``status`` columns from ``tts``."""
    # Columns are removed in the reverse of the order upgrade() adds them.
    for column_name in ('error', 'status'):
        op.drop_column('tts', column_name)
|
||||||
@@ -33,6 +33,8 @@ class TTSResponse(BaseModel):
|
|||||||
text: str
|
text: str
|
||||||
provider: str
|
provider: str
|
||||||
options: dict[str, Any]
|
options: dict[str, Any]
|
||||||
|
status: str
|
||||||
|
error: str | None
|
||||||
sound_id: int | None
|
sound_id: int | None
|
||||||
user_id: int
|
user_id: int
|
||||||
created_at: str
|
created_at: str
|
||||||
@@ -81,6 +83,8 @@ async def get_tts_list(
|
|||||||
text=tts.text,
|
text=tts.text,
|
||||||
provider=tts.provider,
|
provider=tts.provider,
|
||||||
options=tts.options,
|
options=tts.options,
|
||||||
|
status=tts.status,
|
||||||
|
error=tts.error,
|
||||||
sound_id=tts.sound_id,
|
sound_id=tts.sound_id,
|
||||||
user_id=tts.user_id,
|
user_id=tts.user_id,
|
||||||
created_at=tts.created_at.isoformat(),
|
created_at=tts.created_at.isoformat(),
|
||||||
@@ -125,6 +129,8 @@ async def generate_tts(
|
|||||||
text=tts_record.text,
|
text=tts_record.text,
|
||||||
provider=tts_record.provider,
|
provider=tts_record.provider,
|
||||||
options=tts_record.options,
|
options=tts_record.options,
|
||||||
|
status=tts_record.status,
|
||||||
|
error=tts_record.error,
|
||||||
sound_id=tts_record.sound_id,
|
sound_id=tts_record.sound_id,
|
||||||
user_id=tts_record.user_id,
|
user_id=tts_record.user_id,
|
||||||
created_at=tts_record.created_at.isoformat(),
|
created_at=tts_record.created_at.isoformat(),
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from app.core.logging import get_logger, setup_logging
|
|||||||
from app.core.services import app_services
|
from app.core.services import app_services
|
||||||
from app.middleware.logging import LoggingMiddleware
|
from app.middleware.logging import LoggingMiddleware
|
||||||
from app.services.extraction_processor import extraction_processor
|
from app.services.extraction_processor import extraction_processor
|
||||||
|
from app.services.tts_processor import tts_processor
|
||||||
from app.services.player import (
|
from app.services.player import (
|
||||||
get_player_service,
|
get_player_service,
|
||||||
initialize_player_service,
|
initialize_player_service,
|
||||||
@@ -35,6 +36,10 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
|||||||
await extraction_processor.start()
|
await extraction_processor.start()
|
||||||
logger.info("Extraction processor started")
|
logger.info("Extraction processor started")
|
||||||
|
|
||||||
|
# Start the TTS processor
|
||||||
|
await tts_processor.start()
|
||||||
|
logger.info("TTS processor started")
|
||||||
|
|
||||||
# Start the player service
|
# Start the player service
|
||||||
await initialize_player_service(get_session_factory())
|
await initialize_player_service(get_session_factory())
|
||||||
logger.info("Player service started")
|
logger.info("Player service started")
|
||||||
@@ -65,6 +70,10 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
|||||||
await shutdown_player_service()
|
await shutdown_player_service()
|
||||||
logger.info("Player service stopped")
|
logger.info("Player service stopped")
|
||||||
|
|
||||||
|
# Stop the TTS processor
|
||||||
|
await tts_processor.stop()
|
||||||
|
logger.info("TTS processor stopped")
|
||||||
|
|
||||||
# Stop the extraction processor
|
# Stop the extraction processor
|
||||||
await extraction_processor.stop()
|
await extraction_processor.stop()
|
||||||
logger.info("Extraction processor stopped")
|
logger.info("Extraction processor stopped")
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ class TTS(SQLModel, table=True):
|
|||||||
sa_column=Column(JSON),
|
sa_column=Column(JSON),
|
||||||
description="Provider-specific options used"
|
description="Provider-specific options used"
|
||||||
)
|
)
|
||||||
|
status: str = Field(default="pending", description="Processing status")
|
||||||
|
error: str | None = Field(default=None, description="Error message if failed")
|
||||||
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
|
sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID")
|
||||||
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
|
user_id: int = Field(foreign_key="user.id", description="User who created the TTS")
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class TTSService:
|
|||||||
text=text,
|
text=text,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
options=options,
|
options=options,
|
||||||
|
status="pending",
|
||||||
sound_id=None, # Will be set when processing completes
|
sound_id=None, # Will be set when processing completes
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
@@ -108,9 +109,10 @@ class TTSService:
|
|||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(tts)
|
await self.session.refresh(tts)
|
||||||
|
|
||||||
# Queue for background processing
|
# Queue for background processing using the TTS processor
|
||||||
if tts.id is not None:
|
if tts.id is not None:
|
||||||
await self._queue_tts_processing(tts.id)
|
from app.services.tts_processor import tts_processor
|
||||||
|
await tts_processor.queue_tts(tts.id)
|
||||||
|
|
||||||
return {"tts": tts, "message": "TTS generation queued successfully"}
|
return {"tts": tts, "message": "TTS generation queued successfully"}
|
||||||
|
|
||||||
@@ -401,4 +403,120 @@ class TTSService:
|
|||||||
if sound.normalized_filename:
|
if sound.normalized_filename:
|
||||||
normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename
|
normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename
|
||||||
if normalized_path.exists():
|
if normalized_path.exists():
|
||||||
normalized_path.unlink()
|
normalized_path.unlink()
|
||||||
|
|
||||||
|
async def get_pending_tts(self) -> list[TTS]:
    """Return every TTS row whose status is "pending", ordered by ``created_at``."""
    pending_query = (
        select(TTS)
        .where(TTS.status == "pending")
        .order_by(TTS.created_at)
    )
    rows = await self.session.exec(pending_query)
    return [*rows.all()]
|
||||||
|
|
||||||
|
async def mark_tts_processing(self, tts_id: int) -> None:
    """Set the status of the TTS row with the given ID to "processing".

    Does nothing (and does not commit) when no row matches ``tts_id``.
    """
    lookup = await self.session.exec(select(TTS).where(TTS.id == tts_id))
    record = lookup.first()
    if not record:
        return

    record.status = "processing"
    self.session.add(record)
    await self.session.commit()
|
||||||
|
|
||||||
|
async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None:
    """Set a TTS row to "completed", attach its sound and clear any error.

    Does nothing (and does not commit) when no row matches ``tts_id``.
    """
    lookup = await self.session.exec(select(TTS).where(TTS.id == tts_id))
    record = lookup.first()
    if not record:
        return

    record.status = "completed"
    record.sound_id = sound_id
    # Drop any error text the row was carrying.
    record.error = None
    self.session.add(record)
    await self.session.commit()
|
||||||
|
|
||||||
|
async def mark_tts_failed(self, tts_id: int, error_message: str) -> None:
    """Set a TTS row to "failed" and store ``error_message`` on it.

    Does nothing (and does not commit) when no row matches ``tts_id``.
    """
    lookup = await self.session.exec(select(TTS).where(TTS.id == tts_id))
    record = lookup.first()
    if not record:
        return

    record.status = "failed"
    record.error = error_message
    self.session.add(record)
    await self.session.commit()
|
||||||
|
|
||||||
|
async def reset_stuck_tts(self) -> int:
    """Move every TTS row in "processing" back to "pending" and clear its error.

    Returns:
        The number of rows that were reset.
    """
    processing_rows = await self.session.exec(
        select(TTS).where(TTS.status == "processing"),
    )
    stuck_records = [*processing_rows.all()]

    for record in stuck_records:
        record.status, record.error = "pending", None
        self.session.add(record)

    # The commit is issued even when nothing matched, as before.
    await self.session.commit()
    return len(stuck_records)
|
||||||
|
|
||||||
|
async def process_tts_generation(self, tts_id: int) -> None:
    """Process a TTS generation (used by the processor).

    Marks the row as "processing", generates the sound, then marks the row
    "completed" and emits a ``tts_completed`` socket event. On any failure
    the row is marked "failed" with the exception text, a ``tts_failed``
    event is emitted, and the original exception is re-raised.

    Args:
        tts_id: Primary key of the TTS row to process.

    Raises:
        ValueError: If no TTS row with ``tts_id`` exists.
        Exception: Whatever the generation or persistence steps raise.
    """
    # Mark as processing
    await self.mark_tts_processing(tts_id)

    try:
        # Get the TTS record
        stmt = select(TTS).where(TTS.id == tts_id)
        result = await self.session.exec(stmt)
        tts = result.first()

        if not tts:
            raise ValueError(f"TTS with ID {tts_id} not found")

        # Generate the TTS
        sound = await self._generate_tts_sync(
            tts.text,
            tts.provider,
            tts.user_id,
            tts.options,
        )

        # Capture the sound ID before the commit in mark_tts_completed
        sound_id = sound.id

        # Mark as completed
        await self.mark_tts_completed(tts_id, sound_id)

        # Emit socket event for completion
        await self._emit_tts_event("tts_completed", tts_id, sound_id)

    except Exception as e:
        error_message = str(e)

        # If the failure came from a flush/commit, the session refuses any
        # further work until it is rolled back. Roll back first so the
        # "failed" status below can actually be persisted.
        await self.session.rollback()

        # Mark as failed
        await self.mark_tts_failed(tts_id, error_message)

        # Emit socket event for failure
        await self._emit_tts_event("tts_failed", tts_id, None, error_message)
        raise
|
||||||
|
|
||||||
|
async def _emit_tts_event(self, event: str, tts_id: int, sound_id: int | None = None, error: str | None = None) -> None:
    """Broadcast a TTS status-change event to all socket clients.

    The payload always carries ``tts_id`` and ``sound_id``; ``error`` is
    added only when a non-empty error string is given. Failures while
    importing the socket manager or broadcasting are logged, not raised.
    """
    from app.core.logging import get_logger

    logger = get_logger(__name__)

    try:
        from app.services.socket import socket_manager

        data = {"tts_id": tts_id, "sound_id": sound_id}
        if error:
            data["error"] = error

        logger.info(f"Emitting TTS socket event: {event} with data: {data}")
        await socket_manager.broadcast_to_all(event, data)
        logger.info(f"Successfully emitted TTS socket event: {event}")
    except Exception as e:
        # Don't fail TTS processing if socket emission fails
        logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)
|
||||||
193
app/services/tts_processor.py
Normal file
193
app/services/tts_processor.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""Background TTS processor for handling TTS generation queue."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.database import engine
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.services.tts import TTSService
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSProcessor:
    """Background processor for handling TTS generation queue with concurrency control.

    A single long-running task polls the database for pending TTS rows and
    spawns one task per row, never running more than ``max_concurrent`` at
    the same time.
    """

    def __init__(self) -> None:
        """Initialize the TTS processor."""
        # Falls back to 3 when settings has no TTS_MAX_CONCURRENT attribute.
        self.max_concurrent = getattr(settings, 'TTS_MAX_CONCURRENT', 3)
        # IDs of TTS rows that currently have a processing task.
        self.running_tts: set[int] = set()
        # Strong references to in-flight tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._tasks: set[asyncio.Task] = set()
        self.processing_lock = asyncio.Lock()
        self.shutdown_event = asyncio.Event()
        self.processor_task: asyncio.Task | None = None

        logger.info(
            "Initialized TTS processor with max concurrent: %d",
            self.max_concurrent,
        )

    async def start(self) -> None:
        """Start the background TTS processor.

        Logs a warning and returns if the processor task is already running.
        """
        if self.processor_task and not self.processor_task.done():
            logger.warning("TTS processor is already running")
            return

        # Reset any stuck TTS generations from previous runs
        await self._reset_stuck_tts()

        self.shutdown_event.clear()
        self.processor_task = asyncio.create_task(self._process_queue())
        logger.info("Started TTS processor")

    async def stop(self) -> None:
        """Stop the background TTS processor.

        Signals shutdown and waits up to 30 seconds for the queue loop to
        exit; after that the loop task is cancelled.
        """
        logger.info("Stopping TTS processor...")
        self.shutdown_event.set()

        if self.processor_task and not self.processor_task.done():
            try:
                await asyncio.wait_for(self.processor_task, timeout=30.0)
            except TimeoutError:
                logger.warning(
                    "TTS processor did not stop gracefully, cancelling...",
                )
                self.processor_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self.processor_task

        logger.info("TTS processor stopped")

    async def queue_tts(self, tts_id: int) -> None:
        """Queue a TTS generation for processing.

        This only logs: the polling loop discovers pending rows by querying
        the database on its next cycle.
        """
        async with self.processing_lock:
            if tts_id not in self.running_tts:
                logger.info("Queued TTS %d for processing", tts_id)
                # The processor will pick it up on the next cycle
            else:
                logger.warning(
                    "TTS %d is already being processed",
                    tts_id,
                )

    async def _process_queue(self) -> None:
        """Process the TTS queue in the main processing loop."""
        logger.info("Starting TTS queue processor")

        while not self.shutdown_event.is_set():
            try:
                await self._process_pending_tts()

                # Wait before checking for new TTS generations
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
                    break  # Shutdown requested
                except TimeoutError:
                    continue  # Continue processing

            except Exception:
                logger.exception("Error in TTS queue processor")
                # Wait a bit before retrying to avoid tight error loops
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
                    break  # Shutdown requested
                except TimeoutError:
                    continue

        logger.info("TTS queue processor stopped")

    async def _process_pending_tts(self) -> None:
        """Process pending TTS generations up to the concurrency limit."""
        async with self.processing_lock:
            # Check how many slots are available
            available_slots = self.max_concurrent - len(self.running_tts)

            if available_slots <= 0:
                return  # No available slots

            # Get pending TTS generations from database
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                pending_tts = await tts_service.get_pending_tts()

                # Read the ids while the session is still open. Skip rows
                # that are already being processed or have no id.
                available_ids = [
                    tts.id
                    for tts in pending_tts
                    if tts.id is not None and tts.id not in self.running_tts
                ]

            # Start processing up to available slots
            for tts_id in available_ids[:available_slots]:
                self.running_tts.add(tts_id)

                # Start processing this TTS in the background
                task = asyncio.create_task(
                    self._process_single_tts(tts_id),
                )
                # Hold a reference until the done-callback releases it.
                self._tasks.add(task)
                task.add_done_callback(
                    lambda t, tid=tts_id: self._on_tts_completed(
                        tid,
                        t,
                    ),
                )

                logger.info(
                    "Started processing TTS %d (%d/%d slots used)",
                    tts_id,
                    len(self.running_tts),
                    self.max_concurrent,
                )

    async def _process_single_tts(self, tts_id: int) -> None:
        """Process a single TTS generation.

        Exceptions are logged and swallowed. As a fallback the row is marked
        failed through a fresh session, in case the service could not
        persist the failure itself.
        """
        try:
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                await tts_service.process_tts_generation(tts_id)
                logger.info("Successfully processed TTS %d", tts_id)

        except Exception as exc:
            logger.exception("Failed to process TTS %d", tts_id)
            # Keep the real failure reason rather than overwriting what the
            # service stored with a generic message.
            error_message = str(exc) or "Processing failed"
            # Mark TTS as failed in database
            try:
                async with AsyncSession(engine) as session:
                    tts_service = TTSService(session)
                    await tts_service.mark_tts_failed(tts_id, error_message)
            except Exception:
                logger.exception("Failed to mark TTS %d as failed", tts_id)

    def _on_tts_completed(self, tts_id: int, task: asyncio.Task) -> None:
        """Handle completion of a TTS processing task."""
        self.running_tts.discard(tts_id)
        self._tasks.discard(task)

        # Task.exception() raises CancelledError on a cancelled task, so
        # check for cancellation before asking for the exception.
        if task.cancelled():
            logger.warning("TTS processing task %d was cancelled", tts_id)
            return

        failure = task.exception()
        if failure is not None:
            logger.error(
                "TTS processing task %d failed: %s",
                tts_id,
                failure,
            )
        else:
            logger.info("TTS processing task %d completed successfully", tts_id)

    async def _reset_stuck_tts(self) -> None:
        """Reset any TTS generations that were stuck in 'processing' state."""
        try:
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                reset_count = await tts_service.reset_stuck_tts()
                if reset_count > 0:
                    logger.info("Reset %d stuck TTS generations", reset_count)
                else:
                    logger.info("No stuck TTS generations found to reset")
        except Exception:
            # Best effort: a failed reset is logged and does not propagate.
            logger.exception("Failed to reset stuck TTS generations")
|
||||||
|
|
||||||
|
|
||||||
|
# Global TTS processor instance, created once when this module is imported.
|
||||||
|
tts_processor = TTSProcessor()
|
||||||
Reference in New Issue
Block a user