feat: Add status and error fields to TTS model and implement background processing for TTS generations
This commit is contained in:
@@ -101,6 +101,7 @@ class TTSService:
|
||||
text=text,
|
||||
provider=provider,
|
||||
options=options,
|
||||
status="pending",
|
||||
sound_id=None, # Will be set when processing completes
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -108,9 +109,10 @@ class TTSService:
|
||||
await self.session.commit()
|
||||
await self.session.refresh(tts)
|
||||
|
||||
# Queue for background processing
|
||||
# Queue for background processing using the TTS processor
|
||||
if tts.id is not None:
|
||||
await self._queue_tts_processing(tts.id)
|
||||
from app.services.tts_processor import tts_processor
|
||||
await tts_processor.queue_tts(tts.id)
|
||||
|
||||
return {"tts": tts, "message": "TTS generation queued successfully"}
|
||||
|
||||
@@ -401,4 +403,120 @@ class TTSService:
|
||||
if sound.normalized_filename:
|
||||
normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename
|
||||
if normalized_path.exists():
|
||||
normalized_path.unlink()
|
||||
normalized_path.unlink()
|
||||
|
||||
async def get_pending_tts(self) -> list[TTS]:
|
||||
"""Get all pending TTS generations."""
|
||||
stmt = select(TTS).where(TTS.status == "pending").order_by(TTS.created_at)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def mark_tts_processing(self, tts_id: int) -> None:
|
||||
"""Mark a TTS generation as processing."""
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
if tts:
|
||||
tts.status = "processing"
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
|
||||
async def mark_tts_completed(self, tts_id: int, sound_id: int) -> None:
|
||||
"""Mark a TTS generation as completed."""
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
if tts:
|
||||
tts.status = "completed"
|
||||
tts.sound_id = sound_id
|
||||
tts.error = None
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
|
||||
async def mark_tts_failed(self, tts_id: int, error_message: str) -> None:
|
||||
"""Mark a TTS generation as failed."""
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
if tts:
|
||||
tts.status = "failed"
|
||||
tts.error = error_message
|
||||
self.session.add(tts)
|
||||
await self.session.commit()
|
||||
|
||||
async def reset_stuck_tts(self) -> int:
|
||||
"""Reset stuck TTS generations from processing back to pending."""
|
||||
stmt = select(TTS).where(TTS.status == "processing")
|
||||
result = await self.session.exec(stmt)
|
||||
stuck_tts = list(result.all())
|
||||
|
||||
for tts in stuck_tts:
|
||||
tts.status = "pending"
|
||||
tts.error = None
|
||||
self.session.add(tts)
|
||||
|
||||
await self.session.commit()
|
||||
return len(stuck_tts)
|
||||
|
||||
async def process_tts_generation(self, tts_id: int) -> None:
|
||||
"""Process a TTS generation (used by the processor)."""
|
||||
# Mark as processing
|
||||
await self.mark_tts_processing(tts_id)
|
||||
|
||||
try:
|
||||
# Get the TTS record
|
||||
stmt = select(TTS).where(TTS.id == tts_id)
|
||||
result = await self.session.exec(stmt)
|
||||
tts = result.first()
|
||||
|
||||
if not tts:
|
||||
raise ValueError(f"TTS with ID {tts_id} not found")
|
||||
|
||||
# Generate the TTS
|
||||
sound = await self._generate_tts_sync(
|
||||
tts.text,
|
||||
tts.provider,
|
||||
tts.user_id,
|
||||
tts.options,
|
||||
)
|
||||
|
||||
# Capture sound ID before session issues
|
||||
sound_id = sound.id
|
||||
|
||||
# Mark as completed
|
||||
await self.mark_tts_completed(tts_id, sound_id)
|
||||
|
||||
# Emit socket event for completion
|
||||
await self._emit_tts_event("tts_completed", tts_id, sound_id)
|
||||
|
||||
except Exception as e:
|
||||
# Mark as failed
|
||||
await self.mark_tts_failed(tts_id, str(e))
|
||||
|
||||
# Emit socket event for failure
|
||||
await self._emit_tts_event("tts_failed", tts_id, None, str(e))
|
||||
raise
|
||||
|
||||
async def _emit_tts_event(self, event: str, tts_id: int, sound_id: int | None = None, error: str | None = None) -> None:
|
||||
"""Emit a socket event for TTS status change."""
|
||||
try:
|
||||
from app.services.socket import socket_manager
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
data = {
|
||||
"tts_id": tts_id,
|
||||
"sound_id": sound_id,
|
||||
}
|
||||
if error:
|
||||
data["error"] = error
|
||||
|
||||
logger.info(f"Emitting TTS socket event: {event} with data: {data}")
|
||||
await socket_manager.broadcast_to_all(event, data)
|
||||
logger.info(f"Successfully emitted TTS socket event: {event}")
|
||||
except Exception as e:
|
||||
# Don't fail TTS processing if socket emission fails
|
||||
from app.core.logging import get_logger
|
||||
logger = get_logger(__name__)
|
||||
logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)
|
||||
Reference in New Issue
Block a user