refactor: Clean up TTSService methods for improved readability and consistency
This commit is contained in:
@@ -112,6 +112,7 @@ class TTSService:
|
|||||||
# Queue for background processing using the TTS processor
|
# Queue for background processing using the TTS processor
|
||||||
if tts.id is not None:
|
if tts.id is not None:
|
||||||
from app.services.tts_processor import tts_processor
|
from app.services.tts_processor import tts_processor
|
||||||
|
|
||||||
await tts_processor.queue_tts(tts.id)
|
await tts_processor.queue_tts(tts.id)
|
||||||
|
|
||||||
return {"tts": tts, "message": "TTS generation queued successfully"}
|
return {"tts": tts, "message": "TTS generation queued successfully"}
|
||||||
@@ -122,7 +123,7 @@ class TTSService:
|
|||||||
# This could be moved to a proper background queue later
|
# This could be moved to a proper background queue later
|
||||||
task = asyncio.create_task(self._process_tts_in_background(tts_id))
|
task = asyncio.create_task(self._process_tts_in_background(tts_id))
|
||||||
# Store reference to prevent garbage collection
|
# Store reference to prevent garbage collection
|
||||||
self._background_tasks = getattr(self, '_background_tasks', set())
|
self._background_tasks = getattr(self, "_background_tasks", set())
|
||||||
self._background_tasks.add(task)
|
self._background_tasks.add(task)
|
||||||
task.add_done_callback(self._background_tasks.discard)
|
task.add_done_callback(self._background_tasks.discard)
|
||||||
|
|
||||||
@@ -163,11 +164,7 @@ class TTSService:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
async def _generate_tts_sync(
|
async def _generate_tts_sync(
|
||||||
self,
|
self, text: str, provider: str, user_id: int, options: dict[str, Any]
|
||||||
text: str,
|
|
||||||
provider: str,
|
|
||||||
user_id: int,
|
|
||||||
options: dict[str, Any]
|
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Generate TTS using a synchronous approach."""
|
"""Generate TTS using a synchronous approach."""
|
||||||
# Generate the audio using the provider (avoid async issues by doing it directly)
|
# Generate the audio using the provider (avoid async issues by doing it directly)
|
||||||
@@ -203,10 +200,7 @@ class TTSService:
|
|||||||
|
|
||||||
# Create Sound record with proper metadata
|
# Create Sound record with proper metadata
|
||||||
sound = await self._create_sound_record_complete(
|
sound = await self._create_sound_record_complete(
|
||||||
original_path,
|
original_path, text, provider, user_id
|
||||||
text,
|
|
||||||
provider,
|
|
||||||
user_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalize the sound
|
# Normalize the sound
|
||||||
@@ -215,10 +209,7 @@ class TTSService:
|
|||||||
return sound
|
return sound
|
||||||
|
|
||||||
async def get_user_tts_history(
|
async def get_user_tts_history(
|
||||||
self,
|
self, user_id: int, limit: int = 50, offset: int = 0
|
||||||
user_id: int,
|
|
||||||
limit: int = 50,
|
|
||||||
offset: int = 0
|
|
||||||
) -> list[TTS]:
|
) -> list[TTS]:
|
||||||
"""Get TTS history for a user.
|
"""Get TTS history for a user.
|
||||||
|
|
||||||
@@ -234,22 +225,19 @@ class TTSService:
|
|||||||
return list(result)
|
return list(result)
|
||||||
|
|
||||||
async def _create_sound_record(
|
async def _create_sound_record(
|
||||||
self,
|
self, audio_path: Path, text: str, provider: str, user_id: int, file_hash: str
|
||||||
audio_path: Path,
|
|
||||||
text: str,
|
|
||||||
provider: str,
|
|
||||||
user_id: int,
|
|
||||||
file_hash: str
|
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Create a Sound record for the TTS audio."""
|
"""Create a Sound record for the TTS audio."""
|
||||||
# Get audio metadata
|
# Get audio metadata
|
||||||
duration = get_audio_duration(audio_path)
|
duration = get_audio_duration(audio_path)
|
||||||
size = get_file_size(audio_path)
|
size = get_file_size(audio_path)
|
||||||
|
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
|
||||||
|
name = " ".join(word.capitalize() for word in name.split())
|
||||||
|
|
||||||
# Create sound data
|
# Create sound data
|
||||||
sound_data = {
|
sound_data = {
|
||||||
"type": "TTS",
|
"type": "TTS",
|
||||||
"name": text[:50] + ("..." if len(text) > 50 else ""),
|
"name": name,
|
||||||
"filename": audio_path.name,
|
"filename": audio_path.name,
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
"size": size,
|
"size": size,
|
||||||
@@ -265,20 +253,19 @@ class TTSService:
|
|||||||
return sound
|
return sound
|
||||||
|
|
||||||
async def _create_sound_record_simple(
|
async def _create_sound_record_simple(
|
||||||
self,
|
self, audio_path: Path, text: str, provider: str, user_id: int
|
||||||
audio_path: Path,
|
|
||||||
text: str,
|
|
||||||
provider: str,
|
|
||||||
user_id: int
|
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Create a Sound record for the TTS audio with minimal processing."""
|
"""Create a Sound record for the TTS audio with minimal processing."""
|
||||||
# Create sound data with basic info
|
# Create sound data with basic info
|
||||||
|
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
|
||||||
|
name = " ".join(word.capitalize() for word in name.split())
|
||||||
|
|
||||||
sound_data = {
|
sound_data = {
|
||||||
"type": "TTS",
|
"type": "TTS",
|
||||||
"name": text[:50] + ("..." if len(text) > 50 else ""),
|
"name": name,
|
||||||
"filename": audio_path.name,
|
"filename": audio_path.name,
|
||||||
"duration": 0, # Skip duration calculation for now
|
"duration": 0, # Skip duration calculation for now
|
||||||
"size": 0, # Skip size calculation for now
|
"size": 0, # Skip size calculation for now
|
||||||
"hash": str(uuid.uuid4()), # Use UUID as temporary hash
|
"hash": str(uuid.uuid4()), # Use UUID as temporary hash
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"is_deletable": True,
|
"is_deletable": True,
|
||||||
@@ -291,17 +278,15 @@ class TTSService:
|
|||||||
return sound
|
return sound
|
||||||
|
|
||||||
async def _create_sound_record_complete(
|
async def _create_sound_record_complete(
|
||||||
self,
|
self, audio_path: Path, text: str, provider: str, user_id: int
|
||||||
audio_path: Path,
|
|
||||||
text: str,
|
|
||||||
provider: str,
|
|
||||||
user_id: int
|
|
||||||
) -> Sound:
|
) -> Sound:
|
||||||
"""Create a Sound record for the TTS audio with complete metadata."""
|
"""Create a Sound record for the TTS audio with complete metadata."""
|
||||||
# Get audio metadata
|
# Get audio metadata
|
||||||
duration = get_audio_duration(audio_path)
|
duration = get_audio_duration(audio_path)
|
||||||
size = get_file_size(audio_path)
|
size = get_file_size(audio_path)
|
||||||
file_hash = get_file_hash(audio_path)
|
file_hash = get_file_hash(audio_path)
|
||||||
|
name = text[:MAX_NAME_LENGTH] + ("..." if len(text) > MAX_NAME_LENGTH else "")
|
||||||
|
name = " ".join(word.capitalize() for word in name.split())
|
||||||
|
|
||||||
# Check if a sound with this hash already exists
|
# Check if a sound with this hash already exists
|
||||||
existing_sound = await self.sound_repo.get_by_hash(file_hash)
|
existing_sound = await self.sound_repo.get_by_hash(file_hash)
|
||||||
@@ -315,7 +300,7 @@ class TTSService:
|
|||||||
# Create sound data with complete metadata
|
# Create sound data with complete metadata
|
||||||
sound_data = {
|
sound_data = {
|
||||||
"type": "TTS",
|
"type": "TTS",
|
||||||
"name": text[:50] + ("..." if len(text) > 50 else ""),
|
"name": name,
|
||||||
"filename": audio_path.name,
|
"filename": audio_path.name,
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
"size": size,
|
"size": size,
|
||||||
@@ -342,7 +327,9 @@ class TTSService:
|
|||||||
result = await normalizer_service.normalize_sound(sound)
|
result = await normalizer_service.normalize_sound(sound)
|
||||||
|
|
||||||
if result["status"] == "error":
|
if result["status"] == "error":
|
||||||
print(f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}")
|
print(
|
||||||
|
f"Warning: Failed to normalize TTS sound {sound_id}: {result.get('error')}"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Exception during TTS sound normalization {sound_id}: {e}")
|
print(f"Exception during TTS sound normalization {sound_id}: {e}")
|
||||||
@@ -376,7 +363,9 @@ class TTSService:
|
|||||||
|
|
||||||
# Check ownership
|
# Check ownership
|
||||||
if tts.user_id != user_id:
|
if tts.user_id != user_id:
|
||||||
raise PermissionError("You don't have permission to delete this TTS generation")
|
raise PermissionError(
|
||||||
|
"You don't have permission to delete this TTS generation"
|
||||||
|
)
|
||||||
|
|
||||||
# If there's an associated sound, delete it and its files
|
# If there's an associated sound, delete it and its files
|
||||||
if tts.sound_id:
|
if tts.sound_id:
|
||||||
@@ -401,7 +390,9 @@ class TTSService:
|
|||||||
|
|
||||||
# Delete normalized file if it exists
|
# Delete normalized file if it exists
|
||||||
if sound.normalized_filename:
|
if sound.normalized_filename:
|
||||||
normalized_path = Path("sounds/normalized/text_to_speech") / sound.normalized_filename
|
normalized_path = (
|
||||||
|
Path("sounds/normalized/text_to_speech") / sound.normalized_filename
|
||||||
|
)
|
||||||
if normalized_path.exists():
|
if normalized_path.exists():
|
||||||
normalized_path.unlink()
|
normalized_path.unlink()
|
||||||
|
|
||||||
@@ -497,11 +488,17 @@ class TTSService:
|
|||||||
await self._emit_tts_event("tts_failed", tts_id, None, str(e))
|
await self._emit_tts_event("tts_failed", tts_id, None, str(e))
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _emit_tts_event(self, event: str, tts_id: int, sound_id: int | None = None, error: str | None = None) -> None:
|
async def _emit_tts_event(
|
||||||
|
self,
|
||||||
|
event: str,
|
||||||
|
tts_id: int,
|
||||||
|
sound_id: int | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
) -> None:
|
||||||
"""Emit a socket event for TTS status change."""
|
"""Emit a socket event for TTS status change."""
|
||||||
try:
|
try:
|
||||||
from app.services.socket import socket_manager
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
from app.services.socket import socket_manager
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -518,5 +515,6 @@ class TTSService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Don't fail TTS processing if socket emission fails
|
# Don't fail TTS processing if socket emission fails
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)
|
logger.error(f"Failed to emit TTS socket event {event}: {e}", exc_info=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user