"""Background TTS processor for handling TTS generation queue.""" import asyncio import contextlib from sqlmodel.ext.asyncio.session import AsyncSession from app.core.config import settings from app.core.database import engine from app.core.logging import get_logger from app.services.tts import TTSService logger = get_logger(__name__) class TTSProcessor: """Background processor for handling TTS generation queue with concurrency control.""" def __init__(self) -> None: """Initialize the TTS processor.""" self.max_concurrent = getattr(settings, 'TTS_MAX_CONCURRENT', 3) self.running_tts: set[int] = set() self.processing_lock = asyncio.Lock() self.shutdown_event = asyncio.Event() self.processor_task: asyncio.Task | None = None logger.info( "Initialized TTS processor with max concurrent: %d", self.max_concurrent, ) async def start(self) -> None: """Start the background TTS processor.""" if self.processor_task and not self.processor_task.done(): logger.warning("TTS processor is already running") return # Reset any stuck TTS generations from previous runs await self._reset_stuck_tts() self.shutdown_event.clear() self.processor_task = asyncio.create_task(self._process_queue()) logger.info("Started TTS processor") async def stop(self) -> None: """Stop the background TTS processor.""" logger.info("Stopping TTS processor...") self.shutdown_event.set() if self.processor_task and not self.processor_task.done(): try: await asyncio.wait_for(self.processor_task, timeout=30.0) except TimeoutError: logger.warning( "TTS processor did not stop gracefully, cancelling...", ) self.processor_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self.processor_task logger.info("TTS processor stopped") async def queue_tts(self, tts_id: int) -> None: """Queue a TTS generation for processing.""" async with self.processing_lock: if tts_id not in self.running_tts: logger.info("Queued TTS %d for processing", tts_id) # The processor will pick it up on the next cycle else: logger.warning( "TTS %d is already being processed", tts_id, ) async def _process_queue(self) -> None: """Process the TTS queue in the main processing loop.""" logger.info("Starting TTS queue processor") while not self.shutdown_event.is_set(): try: await self._process_pending_tts() # Wait before checking for new TTS generations try: await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0) break # Shutdown requested except TimeoutError: continue # Continue processing except Exception: logger.exception("Error in TTS queue processor") # Wait a bit before retrying to avoid tight error loops try: await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0) break # Shutdown requested except TimeoutError: continue logger.info("TTS queue processor stopped") async def _process_pending_tts(self) -> None: """Process pending TTS generations up to the concurrency limit.""" async with self.processing_lock: # Check how many slots are available available_slots = self.max_concurrent - len(self.running_tts) if available_slots <= 0: return # No available slots # Get pending TTS generations from database async with AsyncSession(engine) as session: tts_service = TTSService(session) pending_tts = await tts_service.get_pending_tts() # Filter out TTS that are already being processed available_tts = [ tts for tts in pending_tts if tts.id not in self.running_tts ] # Start processing up to available slots tts_to_start = available_tts[:available_slots] for tts in tts_to_start: tts_id = tts.id self.running_tts.add(tts_id) # Start processing this TTS in the background task = asyncio.create_task( self._process_single_tts(tts_id), ) task.add_done_callback( lambda t, tid=tts_id: self._on_tts_completed( tid, t, ), ) logger.info( "Started processing TTS %d (%d/%d slots used)", tts_id, len(self.running_tts), self.max_concurrent, ) async def _process_single_tts(self, tts_id: int) -> None: """Process a single TTS generation.""" try: async with AsyncSession(engine) as session: tts_service = TTSService(session) await tts_service.process_tts_generation(tts_id) logger.info("Successfully processed TTS %d", tts_id) except Exception: logger.exception("Failed to process TTS %d", tts_id) # Mark TTS as failed in database try: async with AsyncSession(engine) as session: tts_service = TTSService(session) await tts_service.mark_tts_failed(tts_id, "Processing failed") except Exception: logger.exception("Failed to mark TTS %d as failed", tts_id) def _on_tts_completed(self, tts_id: int, task: asyncio.Task) -> None: """Handle completion of a TTS processing task.""" self.running_tts.discard(tts_id) if task.exception(): logger.error( "TTS processing task %d failed: %s", tts_id, task.exception(), ) else: logger.info("TTS processing task %d completed successfully", tts_id) async def _reset_stuck_tts(self) -> None: """Reset any TTS generations that were stuck in 'processing' state.""" try: async with AsyncSession(engine) as session: tts_service = TTSService(session) reset_count = await tts_service.reset_stuck_tts() if reset_count > 0: logger.info("Reset %d stuck TTS generations", reset_count) else: logger.info("No stuck TTS generations found to reset") except Exception: logger.exception("Failed to reset stuck TTS generations") # Global TTS processor instance tts_processor = TTSProcessor()