194 lines
7.2 KiB
Python
194 lines
7.2 KiB
Python
"""Background TTS processor for handling TTS generation queue."""
|
|
|
|
import asyncio
|
|
import contextlib
|
|
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.database import engine
|
|
from app.core.logging import get_logger
|
|
from app.services.tts import TTSService
|
|
|
|
# Module logger obtained from app.core.logging.get_logger; TTSProcessor uses it
# for lifecycle messages and per-generation status/error reporting.
logger = get_logger(__name__)
|
|
|
|
|
|
class TTSProcessor:
    """Background processor for handling TTS generation queue with concurrency control.

    A single polling loop (``_process_queue``) periodically asks ``TTSService``
    for pending TTS generations and starts at most ``max_concurrent`` of them
    at a time, each as an independent asyncio task running
    ``_process_single_tts`` with its own database session.
    """

    def __init__(self) -> None:
        """Initialize the TTS processor.

        Reads ``TTS_MAX_CONCURRENT`` from settings (falling back to 3) as the
        maximum number of generations processed concurrently. This does not
        start the polling loop; call :meth:`start` for that.
        """
        self.max_concurrent = getattr(settings, "TTS_MAX_CONCURRENT", 3)
        # IDs of TTS generations that currently occupy a worker slot.
        self.running_tts: set[int] = set()
        # Serializes slot accounting in _process_pending_tts with queue_tts.
        self.processing_lock = asyncio.Lock()
        # Set by stop() to make the polling loop exit.
        self.shutdown_event = asyncio.Event()
        self.processor_task: asyncio.Task | None = None
        # Strong references to in-flight worker tasks. The event loop keeps
        # only weak references to tasks, so without this set a worker could be
        # garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

        logger.info(
            "Initialized TTS processor with max concurrent: %d",
            self.max_concurrent,
        )

    async def start(self) -> None:
        """Start the background TTS processor.

        No-op (with a warning) if the polling loop is already running.
        Otherwise resets generations left stuck by a previous run, clears the
        shutdown flag and launches the polling loop as a task.
        """
        if self.processor_task and not self.processor_task.done():
            logger.warning("TTS processor is already running")
            return

        # Reset any stuck TTS generations from previous runs
        await self._reset_stuck_tts()

        self.shutdown_event.clear()
        self.processor_task = asyncio.create_task(self._process_queue())
        logger.info("Started TTS processor")

    async def stop(self) -> None:
        """Stop the background TTS processor.

        Signals the polling loop to exit and waits up to 30 seconds for it;
        if it does not finish in time it is cancelled.

        NOTE: worker tasks already started by the loop are not awaited or
        cancelled here; they keep running on the event loop until they finish.
        """
        logger.info("Stopping TTS processor...")
        self.shutdown_event.set()

        if self.processor_task and not self.processor_task.done():
            try:
                await asyncio.wait_for(self.processor_task, timeout=30.0)
            except TimeoutError:
                logger.warning(
                    "TTS processor did not stop gracefully, cancelling...",
                )
                self.processor_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self.processor_task

        logger.info("TTS processor stopped")

    async def queue_tts(self, tts_id: int) -> None:
        """Queue a TTS generation for processing.

        This does not hand anything to the loop directly. It only logs whether
        ``tts_id`` is already being processed. Actual pickup happens when the
        polling loop next sees the record via ``TTSService.get_pending_tts()``.
        Presumably the caller has already persisted it in a pending state;
        TODO confirm against callers.
        """
        async with self.processing_lock:
            if tts_id not in self.running_tts:
                logger.info("Queued TTS %d for processing", tts_id)
                # The processor will pick it up on the next cycle
            else:
                logger.warning(
                    "TTS %d is already being processed",
                    tts_id,
                )

    async def _process_queue(self) -> None:
        """Process the TTS queue in the main processing loop.

        Every 5 seconds, starts pending generations into free slots. After an
        unexpected error it backs off for 10 seconds instead. Either wait ends
        early as soon as ``shutdown_event`` is set.
        """
        logger.info("Starting TTS queue processor")

        while not self.shutdown_event.is_set():
            try:
                await self._process_pending_tts()

                # Wait before checking for new TTS generations
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
                    break  # Shutdown requested
                except TimeoutError:
                    continue  # Continue processing

            except Exception:
                logger.exception("Error in TTS queue processor")
                # Wait a bit before retrying to avoid tight error loops
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
                    break  # Shutdown requested
                except TimeoutError:
                    continue

        logger.info("TTS queue processor stopped")

    async def _process_pending_tts(self) -> None:
        """Process pending TTS generations up to the concurrency limit.

        Fetches pending generations, skips those already running, and starts
        one worker task per generation until all ``max_concurrent`` slots are
        occupied.
        """
        async with self.processing_lock:
            # Check how many slots are available
            available_slots = self.max_concurrent - len(self.running_tts)
            if available_slots <= 0:
                return  # No available slots

            # Get pending TTS generations from database
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                pending_tts = await tts_service.get_pending_tts()

                # Filter out TTS that are already being processed
                available_tts = [
                    tts for tts in pending_tts if tts.id not in self.running_tts
                ]

                # Start processing up to available slots
                for tts in available_tts[:available_slots]:
                    tts_id = tts.id
                    self.running_tts.add(tts_id)

                    # Start processing this TTS in the background
                    task = asyncio.create_task(self._process_single_tts(tts_id))
                    # Keep a strong reference until the task is done.
                    self._background_tasks.add(task)
                    task.add_done_callback(self._background_tasks.discard)
                    # Default arg binds the current id (avoids late binding).
                    task.add_done_callback(
                        lambda t, tid=tts_id: self._on_tts_completed(tid, t),
                    )

                    logger.info(
                        "Started processing TTS %d (%d/%d slots used)",
                        tts_id,
                        len(self.running_tts),
                        self.max_concurrent,
                    )

    async def _process_single_tts(self, tts_id: int) -> None:
        """Process a single TTS generation.

        Runs ``TTSService.process_tts_generation`` in a fresh session. On any
        ``Exception``, logs it and makes a best-effort attempt to mark the
        generation as failed in a second session. Cancellation is not caught
        and propagates to the task.
        """
        try:
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                await tts_service.process_tts_generation(tts_id)
                logger.info("Successfully processed TTS %d", tts_id)

        except Exception:
            logger.exception("Failed to process TTS %d", tts_id)
            # Mark TTS as failed in database
            try:
                async with AsyncSession(engine) as session:
                    tts_service = TTSService(session)
                    await tts_service.mark_tts_failed(tts_id, "Processing failed")
            except Exception:
                logger.exception("Failed to mark TTS %d as failed", tts_id)

    def _on_tts_completed(self, tts_id: int, task: asyncio.Task) -> None:
        """Handle completion of a TTS processing task.

        Frees the slot held by ``tts_id`` and logs the outcome. This runs as a
        task done-callback, so it must not raise. ``Task.exception()`` raises
        ``CancelledError`` for a cancelled task, which is why ``cancelled()``
        is checked first.
        """
        self.running_tts.discard(tts_id)

        if task.cancelled():
            logger.warning("TTS processing task %d was cancelled", tts_id)
            return

        exc = task.exception()
        if exc is not None:
            logger.error(
                "TTS processing task %d failed: %s",
                tts_id,
                exc,
            )
        else:
            logger.info("TTS processing task %d completed successfully", tts_id)

    async def _reset_stuck_tts(self) -> None:
        """Reset any TTS generations that were stuck in 'processing' state.

        Delegates to ``TTSService.reset_stuck_tts()`` and logs how many were
        reset. Failures are logged and swallowed so that start() can proceed.
        """
        try:
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                reset_count = await tts_service.reset_stuck_tts()
                if reset_count > 0:
                    logger.info("Reset %d stuck TTS generations", reset_count)
                else:
                    logger.info("No stuck TTS generations found to reset")
        except Exception:
            logger.exception("Failed to reset stuck TTS generations")
|
|
|
|
|
|
# Global TTS processor instance, created at import time. Construction only reads
# settings, creates the asyncio primitives and logs; the polling loop does not
# run until start() is awaited.
|
|
tts_processor = TTSProcessor()
|