Files
sdb2-backend/app/services/tts_processor.py
JSC acdf191a5a
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m36s
refactor: Improve code readability and structure across TTS modules
2025-09-21 19:07:32 +02:00

194 lines
7.1 KiB
Python

"""Background TTS processor for handling TTS generation queue."""
import asyncio
import contextlib
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.database import engine
from app.core.logging import get_logger
from app.services.tts import TTSService
# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)
class TTSProcessor:
    """Background processor for handling TTS generation queue.

    A single long-running task (``_process_queue``) periodically looks up
    pending generations through ``TTSService`` and starts each one as its own
    asyncio task, keeping at most ``max_concurrent`` of them in flight.
    """

    # Seconds to wait between two polls of the pending queue.
    POLL_INTERVAL: float = 5.0
    # Seconds to back off after an unexpected error in the polling loop.
    ERROR_BACKOFF: float = 10.0
    # Seconds stop() waits for the polling loop before cancelling it.
    STOP_TIMEOUT: float = 30.0

    def __init__(self) -> None:
        """Initialize the TTS processor."""
        # Falls back to 3 when settings has no TTS_MAX_CONCURRENT attribute.
        self.max_concurrent = getattr(settings, "TTS_MAX_CONCURRENT", 3)
        # IDs of the generations that currently have a task in flight.
        self.running_tts: set[int] = set()
        # Strong references to the in-flight tasks. The event loop only keeps
        # weak references to tasks, so they must be held somewhere until done.
        self._tasks: set[asyncio.Task] = set()
        self.processing_lock = asyncio.Lock()
        self.shutdown_event = asyncio.Event()
        self.processor_task: asyncio.Task | None = None
        logger.info(
            "Initialized TTS processor with max concurrent: %d",
            self.max_concurrent,
        )

    async def start(self) -> None:
        """Start the background TTS processor.

        Only logs a warning when the polling task is already running.
        Otherwise resets leftover generations, clears the shutdown flag and
        schedules ``_process_queue`` as a task.
        """
        if self.processor_task and not self.processor_task.done():
            logger.warning("TTS processor is already running")
            return
        # Reset any stuck TTS generations from previous runs
        await self._reset_stuck_tts()
        self.shutdown_event.clear()
        self.processor_task = asyncio.create_task(self._process_queue())
        logger.info("Started TTS processor")

    async def stop(self) -> None:
        """Stop the background TTS processor.

        Signals shutdown and gives the polling task ``STOP_TIMEOUT`` seconds
        to finish before cancelling it. Only the polling task is handled
        here: generation tasks already in flight are neither awaited nor
        cancelled by this method.
        """
        logger.info("Stopping TTS processor...")
        self.shutdown_event.set()
        if self.processor_task and not self.processor_task.done():
            try:
                await asyncio.wait_for(
                    self.processor_task,
                    timeout=self.STOP_TIMEOUT,
                )
            # asyncio.wait_for raises the builtin TimeoutError on Python 3.11+.
            except TimeoutError:
                logger.warning(
                    "TTS processor did not stop gracefully, cancelling...",
                )
                self.processor_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await self.processor_task
        logger.info("TTS processor stopped")

    async def queue_tts(self, tts_id: int) -> None:
        """Queue a TTS generation for processing.

        Nothing is stored in memory: this method only logs whether
        ``tts_id`` is already in flight. Pending generations are discovered
        by the polling loop on its next cycle.

        Args:
            tts_id: Identifier of the TTS generation.
        """
        async with self.processing_lock:
            if tts_id not in self.running_tts:
                logger.info("Queued TTS %d for processing", tts_id)
                # The processor will pick it up on the next cycle
            else:
                logger.warning(
                    "TTS %d is already being processed",
                    tts_id,
                )

    async def _wait_for_shutdown(self, timeout: float) -> bool:
        """Wait up to ``timeout`` seconds for the shutdown event.

        Returns:
            True when shutdown was requested, False when the wait timed out.
        """
        try:
            await asyncio.wait_for(self.shutdown_event.wait(), timeout=timeout)
        except TimeoutError:
            return False
        return True

    async def _process_queue(self) -> None:
        """Process the TTS queue in the main processing loop.

        Runs until the shutdown event is set. Unexpected errors are logged
        and followed by a longer pause so a persistent failure cannot turn
        into a tight retry loop.
        """
        logger.info("Starting TTS queue processor")
        while not self.shutdown_event.is_set():
            try:
                await self._process_pending_tts()
                # Wait before checking for new TTS generations
                if await self._wait_for_shutdown(self.POLL_INTERVAL):
                    break  # Shutdown requested
            except Exception:
                logger.exception("Error in TTS queue processor")
                # Wait a bit before retrying to avoid tight error loops
                if await self._wait_for_shutdown(self.ERROR_BACKOFF):
                    break  # Shutdown requested
        logger.info("TTS queue processor stopped")

    async def _process_pending_tts(self) -> None:
        """Process pending TTS generations up to the concurrency limit."""
        async with self.processing_lock:
            # Check how many slots are available
            available_slots = self.max_concurrent - len(self.running_tts)
            if available_slots <= 0:
                return  # No available slots
            # Get pending TTS generations from database
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                pending_tts = await tts_service.get_pending_tts()
                # Filter out TTS that are already being processed
                available_tts = [
                    tts
                    for tts in pending_tts
                    if tts.id not in self.running_tts
                ]
                # Start processing up to available slots
                for tts in available_tts[:available_slots]:
                    tts_id = tts.id
                    self.running_tts.add(tts_id)
                    # Start processing this TTS in the background
                    task = asyncio.create_task(
                        self._process_single_tts(tts_id),
                    )
                    # Hold a strong reference until the done-callback runs.
                    self._tasks.add(task)
                    # tid is bound as a default so each callback keeps its
                    # own ID instead of the loop variable's final value.
                    task.add_done_callback(
                        lambda t, tid=tts_id: self._on_tts_completed(
                            tid,
                            t,
                        ),
                    )
                    logger.info(
                        "Started processing TTS %d (%d/%d slots used)",
                        tts_id,
                        len(self.running_tts),
                        self.max_concurrent,
                    )

    async def _process_single_tts(self, tts_id: int) -> None:
        """Process a single TTS generation.

        Any exception is logged, and the generation is then marked as failed
        through a fresh session. A failure of that second step is logged as
        well, so this coroutine does not propagate ``Exception`` subclasses.
        """
        try:
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                await tts_service.process_tts_generation(tts_id)
            logger.info("Successfully processed TTS %d", tts_id)
        except Exception:
            logger.exception("Failed to process TTS %d", tts_id)
            # Mark TTS as failed in database
            try:
                async with AsyncSession(engine) as session:
                    tts_service = TTSService(session)
                    await tts_service.mark_tts_failed(
                        tts_id,
                        "Processing failed",
                    )
            except Exception:
                logger.exception("Failed to mark TTS %d as failed", tts_id)

    def _on_tts_completed(self, tts_id: int, task: asyncio.Task) -> None:
        """Handle completion of a TTS processing task.

        Frees the slot held by ``tts_id``, drops the stored task reference
        and logs the outcome.

        Args:
            tts_id: Identifier of the generation the task was processing.
            task: The finished (or cancelled) task.
        """
        self.running_tts.discard(tts_id)
        self._tasks.discard(task)
        try:
            error = task.exception()
        except asyncio.CancelledError:
            # Task.exception() raises for a cancelled task; without this
            # guard the done-callback itself would fail on cancellation.
            logger.warning("TTS processing task %d was cancelled", tts_id)
            return
        if error is not None:
            logger.error(
                "TTS processing task %d failed: %s",
                tts_id,
                error,
            )
        else:
            logger.info("TTS processing task %d completed successfully", tts_id)

    async def _reset_stuck_tts(self) -> None:
        """Reset any TTS generations that were stuck in 'processing' state.

        Delegates to ``TTSService.reset_stuck_tts()`` and logs the count it
        returns. Errors are logged and swallowed.
        """
        try:
            async with AsyncSession(engine) as session:
                tts_service = TTSService(session)
                reset_count = await tts_service.reset_stuck_tts()
            if reset_count > 0:
                logger.info("Reset %d stuck TTS generations", reset_count)
            else:
                logger.info("No stuck TTS generations found to reset")
        except Exception:
            logger.exception("Failed to reset stuck TTS generations")
# Global TTS processor instance, constructed once when this module is
# imported. It does no work until its start() coroutine is awaited.
tts_processor = TTSProcessor()