diff --git a/alembic/versions/0d9b7f1c367f_add_status_and_error_fields_to_tts_table.py b/alembic/versions/0d9b7f1c367f_add_status_and_error_fields_to_tts_table.py new file mode 100644 index 0000000..66e8275 --- /dev/null +++ b/alembic/versions/0d9b7f1c367f_add_status_and_error_fields_to_tts_table.py @@ -0,0 +1,34 @@ +"""Add status and error fields to TTS table + +Revision ID: 0d9b7f1c367f +Revises: e617c155eea9 +Create Date: 2025-09-21 14:09:56.418372 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0d9b7f1c367f' +down_revision: Union[str, Sequence[str], None] = 'e617c155eea9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('tts', sa.Column('status', sa.String(), nullable=False, server_default='pending')) + op.add_column('tts', sa.Column('error', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('tts', 'error') + op.drop_column('tts', 'status') + # ### end Alembic commands ### diff --git a/app/api/v1/tts.py b/app/api/v1/tts.py index 6021285..245351f 100644 --- a/app/api/v1/tts.py +++ b/app/api/v1/tts.py @@ -33,6 +33,8 @@ class TTSResponse(BaseModel): text: str provider: str options: dict[str, Any] + status: str + error: str | None sound_id: int | None user_id: int created_at: str @@ -81,6 +83,8 @@ async def get_tts_list( text=tts.text, provider=tts.provider, options=tts.options, + status=tts.status, + error=tts.error, sound_id=tts.sound_id, user_id=tts.user_id, created_at=tts.created_at.isoformat(), @@ -125,6 +129,8 @@ async def generate_tts( text=tts_record.text, provider=tts_record.provider, options=tts_record.options, + status=tts_record.status, + error=tts_record.error, sound_id=tts_record.sound_id, user_id=tts_record.user_id, created_at=tts_record.created_at.isoformat(), diff --git a/app/main.py b/app/main.py index fd1c422..a351fd7 100644 --- a/app/main.py +++ b/app/main.py @@ -12,6 +12,7 @@ from app.core.logging import get_logger, setup_logging from app.core.services import app_services from app.middleware.logging import LoggingMiddleware from app.services.extraction_processor import extraction_processor +from app.services.tts_processor import tts_processor from app.services.player import ( get_player_service, initialize_player_service, @@ -35,6 +36,10 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]: await extraction_processor.start() logger.info("Extraction processor started") + # Start the TTS processor + await tts_processor.start() + logger.info("TTS processor started") + # Start the player service await initialize_player_service(get_session_factory()) logger.info("Player service started") @@ -65,6 +70,10 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]: await shutdown_player_service() logger.info("Player service stopped") + # Stop the TTS processor + await tts_processor.stop() + logger.info("TTS processor stopped") + # Stop the extraction processor await extraction_processor.stop() logger.info("Extraction processor stopped") diff --git a/app/models/tts.py b/app/models/tts.py index 3dc1a66..830f087 100644 --- a/app/models/tts.py +++ b/app/models/tts.py @@ -20,6 +20,8 @@ class TTS(SQLModel, table=True): sa_column=Column(JSON), description="Provider-specific options used" ) + status: str = Field(default="pending", description="Processing status") + error: str | None = Field(default=None, description="Error message if failed") sound_id: int | None = Field(foreign_key="sound.id", description="Associated sound ID") user_id: int = Field(foreign_key="user.id", description="User who created the TTS") created_at: datetime = Field(default_factory=datetime.utcnow) diff --git a/app/services/tts/service.py b/app/services/tts/service.py index 843ef29..4168192 100644 --- a/app/services/tts/service.py +++ b/app/services/tts/service.py @@ -101,6 +101,7 @@ class TTSService: text=text, provider=provider, options=options, + status="pending", sound_id=None, # Will be set when processing completes user_id=user_id, ) @@ -108,9 +109,10 @@ class TTSService: await self.session.commit() await self.session.refresh(tts) - # Queue for background processing + # Queue for background processing using the TTS processor if tts.id is not None: - await self._queue_tts_processing(tts.id) + from app.services.tts_processor import tts_processor + await tts_processor.queue_tts(tts.id) return {"tts": tts, "message": "TTS generation queued successfully"} @@ -401,4 +403,120 @@ class TTSService: if sound.normalized_filename: normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename if normalized_path.exists(): - normalized_path.unlink() \ No newline at end of file + normalized_path.unlink() + + async def get_pending_tts(self) -> list[TTS]: + """Get all pending TTS generations.""" + stmt = select(TTS).where(TTS.status == "pending").order_by(TTS.created_at) + result = await self.session.exec(stmt) + return list(result.all()) + + async def mark_tts_processing(self, tts_id: int) -> None: + """Mark a TTS generation as processing.""" + stmt = select(TTS).where(TTS.id == tts_id) + result = await self.session.exec(stmt) + tts = result.first() + if tts: + tts.status = "processing" + self.session.add(tts) + await self.session.commit() + + async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None: + """Mark a TTS generation as completed.""" + stmt = select(TTS).where(TTS.id == tts_id) + result = await self.session.exec(stmt) + tts = result.first() + if tts: + tts.status = "completed" + tts.sound_id = sound_id + tts.error = None + self.session.add(tts) + await self.session.commit() + + async def mark_tts_failed(self, tts_id: int, error_message: str) -> None: + """Mark a TTS generation as failed.""" + stmt = select(TTS).where(TTS.id == tts_id) + result = await self.session.exec(stmt) + tts = result.first() + if tts: + tts.status = "failed" + tts.error = error_message + self.session.add(tts) + await self.session.commit() + + async def reset_stuck_tts(self) -> int: + """Reset stuck TTS generations from processing back to pending.""" + stmt = select(TTS).where(TTS.status == "processing") + result = await self.session.exec(stmt) + stuck_tts = list(result.all()) + + for tts in stuck_tts: + tts.status = "pending" + tts.error = None + self.session.add(tts) + + await self.session.commit() + return len(stuck_tts) + + async def process_tts_generation(self, tts_id: int) -> None: + """Process a TTS generation (used by the processor).""" + # Mark as processing + await self.mark_tts_processing(tts_id) + + try: + # Get the TTS record + stmt = select(TTS).where(TTS.id == tts_id) + result = await self.session.exec(stmt) + tts = result.first() + + if not tts: + raise ValueError(f"TTS with ID {tts_id} not found") + + # Generate the TTS + sound = await self._generate_tts_sync( + tts.text, + tts.provider, + tts.user_id, + tts.options, + ) + + # Capture sound ID before session issues + sound_id = sound.id + + # Mark as completed + await self.mark_tts_completed(tts_id, sound_id) + + # Emit socket event for completion + await self._emit_tts_event("tts_completed", tts_id, sound_id) + + except Exception as e: + # Mark as failed + await self.mark_tts_failed(tts_id, str(e)) + + # Emit socket event for failure + await self._emit_tts_event("tts_failed", tts_id, None, str(e)) + raise + + async def _emit_tts_event(self, event: str, tts_id: int, sound_id: int | None = None, error: str | None = None) -> None: + """Emit a socket event for TTS status change.""" + try: + from app.services.socket import socket_manager + from app.core.logging import get_logger + + logger = get_logger(__name__) + + data = { + "tts_id": tts_id, + "sound_id": sound_id, + } + if error: + data["error"] = error + + logger.info(f"Emitting TTS socket event: {event} with data: {data}") + await socket_manager.broadcast_to_all(event, data) + logger.info(f"Successfully emitted TTS socket event: {event}") + except Exception as e: + # Don't fail TTS processing if socket emission fails + from app.core.logging import get_logger + logger = get_logger(__name__) + logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True) \ No newline at end of file diff --git a/app/services/tts_processor.py b/app/services/tts_processor.py new file mode 100644 index 0000000..1497568 --- /dev/null +++ b/app/services/tts_processor.py @@ -0,0 +1,193 @@ +"""Background TTS processor for handling TTS generation queue.""" + +import asyncio +import contextlib + +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.config import settings +from app.core.database import engine +from app.core.logging import get_logger +from app.services.tts import TTSService + +logger = get_logger(__name__) + + +class TTSProcessor: + """Background processor for handling TTS generation queue with concurrency control.""" + + def __init__(self) -> None: + """Initialize the TTS processor.""" + self.max_concurrent = getattr(settings, 'TTS_MAX_CONCURRENT', 3) + self.running_tts: set[int] = set() + self.processing_lock = asyncio.Lock() + self.shutdown_event = asyncio.Event() + self.processor_task: asyncio.Task | None = None + + logger.info( + "Initialized TTS processor with max concurrent: %d", + self.max_concurrent, + ) + + async def start(self) -> None: + """Start the background TTS processor.""" + if self.processor_task and not self.processor_task.done(): + logger.warning("TTS processor is already running") + return + + # Reset any stuck TTS generations from previous runs + await self._reset_stuck_tts() + + self.shutdown_event.clear() + self.processor_task = asyncio.create_task(self._process_queue()) + logger.info("Started TTS processor") + + async def stop(self) -> None: + """Stop the background TTS processor.""" + logger.info("Stopping TTS processor...") + self.shutdown_event.set() + + if self.processor_task and not self.processor_task.done(): + try: + await asyncio.wait_for(self.processor_task, timeout=30.0) + except TimeoutError: + logger.warning( + "TTS processor did not stop gracefully, cancelling...", + ) + self.processor_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self.processor_task + + logger.info("TTS processor stopped") + + async def queue_tts(self, tts_id: int) -> None: + """Queue a TTS generation for processing.""" + async with self.processing_lock: + if tts_id not in self.running_tts: + logger.info("Queued TTS %d for processing", tts_id) + # The processor will pick it up on the next cycle + else: + logger.warning( + "TTS %d is already being processed", + tts_id, + ) + + async def _process_queue(self) -> None: + """Process the TTS queue in the main processing loop.""" + logger.info("Starting TTS queue processor") + + while not self.shutdown_event.is_set(): + try: + await self._process_pending_tts() + + # Wait before checking for new TTS generations + try: + await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0) + break # Shutdown requested + except TimeoutError: + continue # Continue processing + + except Exception: + logger.exception("Error in TTS queue processor") + # Wait a bit before retrying to avoid tight error loops + try: + await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0) + break # Shutdown requested + except TimeoutError: + continue + + logger.info("TTS queue processor stopped") + + async def _process_pending_tts(self) -> None: + """Process pending TTS generations up to the concurrency limit.""" + async with self.processing_lock: + # Check how many slots are available + available_slots = self.max_concurrent - len(self.running_tts) + + if available_slots <= 0: + return # No available slots + + # Get pending TTS generations from database + async with AsyncSession(engine) as session: + tts_service = TTSService(session) + pending_tts = await tts_service.get_pending_tts() + + # Filter out TTS that are already being processed + available_tts = [ + tts + for tts in pending_tts + if tts.id not in self.running_tts + ] + + # Start processing up to available slots + tts_to_start = available_tts[:available_slots] + + for tts in tts_to_start: + tts_id = tts.id + self.running_tts.add(tts_id) + + # Start processing this TTS in the background + task = asyncio.create_task( + self._process_single_tts(tts_id), + ) + task.add_done_callback( + lambda t, tid=tts_id: self._on_tts_completed( + tid, + t, + ), + ) + + logger.info( + "Started processing TTS %d (%d/%d slots used)", + tts_id, + len(self.running_tts), + self.max_concurrent, + ) + + async def _process_single_tts(self, tts_id: int) -> None: + """Process a single TTS generation.""" + try: + async with AsyncSession(engine) as session: + tts_service = TTSService(session) + await tts_service.process_tts_generation(tts_id) + logger.info("Successfully processed TTS %d", tts_id) + + except Exception: + logger.exception("Failed to process TTS %d", tts_id) + # Mark TTS as failed in database + try: + async with AsyncSession(engine) as session: + tts_service = TTSService(session) + await tts_service.mark_tts_failed(tts_id, "Processing failed") + except Exception: + logger.exception("Failed to mark TTS %d as failed", tts_id) + + def _on_tts_completed(self, tts_id: int, task: asyncio.Task) -> None: + """Handle completion of a TTS processing task.""" + self.running_tts.discard(tts_id) + + if task.exception(): + logger.error( + "TTS processing task %d failed: %s", + tts_id, + task.exception(), + ) + else: + logger.info("TTS processing task %d completed successfully", tts_id) + + async def _reset_stuck_tts(self) -> None: + """Reset any TTS generations that were stuck in 'processing' state.""" + try: + async with AsyncSession(engine) as session: + tts_service = TTSService(session) + reset_count = await tts_service.reset_stuck_tts() + if reset_count > 0: + logger.info("Reset %d stuck TTS generations", reset_count) + else: + logger.info("No stuck TTS generations found to reset") + except Exception: + logger.exception("Failed to reset stuck TTS generations") + + +# Global TTS processor instance +tts_processor = TTSProcessor() \ No newline at end of file