"""Sound repository for database operations."""

from datetime import datetime
from enum import Enum

from sqlalchemy import func
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.repositories.base import BaseRepository

logger = get_logger(__name__)

# Escape character for LIKE patterns built from user input. "/" is used rather
# than a backslash to avoid backend-specific backslash handling in SQL string
# literals.
_LIKE_ESCAPE_CHAR = "/"


def _contains_pattern(text: str) -> str:
    """Build a lower-cased LIKE pattern matching values that contain *text*.

    ``%`` and ``_`` in *text* are escaped so they match literally. The escape
    character itself is doubled first, so a literal ``/`` in the input is not
    read as the start of an escape sequence. The returned pattern must be used
    with ``.like(pattern, escape=_LIKE_ESCAPE_CHAR)``.
    """
    esc = _LIKE_ESCAPE_CHAR
    literal = (
        text.lower()
        .replace(esc, esc * 2)
        .replace("%", f"{esc}%")
        .replace("_", f"{esc}_")
    )
    return f"%{literal}%"


class SoundSortField(str, Enum):
    """Sound sort field enumeration."""

    # Each value is looked up as an attribute of ``Sound`` in search_and_sort.
    NAME = "name"
    FILENAME = "filename"
    DURATION = "duration"
    SIZE = "size"
    TYPE = "type"
    PLAY_COUNT = "play_count"
    CREATED_AT = "created_at"
    UPDATED_AT = "updated_at"


class SortOrder(str, Enum):
    """Sort order enumeration."""

    ASC = "asc"
    DESC = "desc"


class SoundRepository(BaseRepository[Sound]):
    """Repository for sound operations."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the sound repository.

        Args:
            session: Async session used for every query issued by this
                repository.
        """
        super().__init__(Sound, session)

    async def get_by_filename(self, filename: str) -> Sound | None:
        """Get a sound by filename.

        Returns:
            The first matching sound, or ``None`` if there is no match.
        """
        try:
            statement = select(Sound).where(Sound.filename == filename)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get sound by filename: %s", filename)
            raise

    async def get_by_hash(self, hash_value: str) -> Sound | None:
        """Get a sound by hash.

        Returns:
            The first matching sound, or ``None`` if there is no match.
        """
        try:
            statement = select(Sound).where(Sound.hash == hash_value)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get sound by hash")
            raise

    async def get_by_type(self, sound_type: str) -> list[Sound]:
        """Get all sounds whose ``type`` equals *sound_type*."""
        try:
            statement = select(Sound).where(Sound.type == sound_type)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get sounds by type: %s", sound_type)
            raise

    async def search_by_name(self, query: str) -> list[Sound]:
        """Search sounds whose name contains *query* (case-insensitive).

        ``%`` and ``_`` in *query* are matched literally, not as wildcards.
        """
        try:
            statement = select(Sound).where(
                func.lower(Sound.name).like(
                    _contains_pattern(query),
                    escape=_LIKE_ESCAPE_CHAR,
                ),
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to search sounds by name: %s", query)
            raise

    async def get_popular_sounds(self, limit: int = 10) -> list[Sound]:
        """Get up to *limit* sounds ordered by ``play_count`` descending."""
        try:
            statement = select(Sound).order_by(Sound.play_count.desc()).limit(limit)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get popular sounds")
            raise

    async def get_unnormalized_sounds(self) -> list[Sound]:
        """Get all sounds whose ``is_normalized`` flag is false."""
        try:
            # SQL expression comparison; ``is False`` would not build a clause.
            statement = select(Sound).where(Sound.is_normalized == False)  # noqa: E712
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get unnormalized sounds")
            raise

    async def get_unnormalized_sounds_by_type(self, sound_type: str) -> list[Sound]:
        """Get sounds of *sound_type* whose ``is_normalized`` flag is false."""
        try:
            statement = select(Sound).where(
                Sound.type == sound_type,
                Sound.is_normalized == False,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception(
                "Failed to get unnormalized sounds by type: %s",
                sound_type,
            )
            raise

    async def get_by_types(self, sound_types: list[str] | None = None) -> list[Sound]:
        """Get sounds whose type is in *sound_types*.

        If *sound_types* is ``None`` or empty, all sounds are returned.
        """
        try:
            statement = select(Sound)
            if sound_types:
                statement = statement.where(col(Sound.type).in_(sound_types))
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get sounds by types: %s", sound_types)
            raise

    async def search_and_sort(  # noqa: PLR0913
        self,
        search_query: str | None = None,
        sound_types: list[str] | None = None,
        sort_by: SoundSortField | None = None,
        sort_order: SortOrder = SortOrder.ASC,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[Sound]:
        """Search and sort sounds with optional filtering and pagination.

        Args:
            search_query: Case-insensitive substring to match against the
                name. Leading and trailing whitespace is stripped. A blank or
                ``None`` value disables the filter. ``%`` and ``_`` are
                matched literally.
            sound_types: Restrict results to these types. ``None`` or empty
                disables the filter.
            sort_by: Column to sort on. Defaults to name ascending when
                ``None``.
            sort_order: Direction used when *sort_by* is given.
            limit: Maximum number of rows, or ``None`` for no limit.
            offset: Number of rows to skip; applied only when positive.

        Returns:
            The matching sounds in the requested order.
        """
        try:
            statement = select(Sound)

            # Apply type filter
            if sound_types:
                statement = statement.where(col(Sound.type).in_(sound_types))

            # Apply search filter
            if search_query and search_query.strip():
                statement = statement.where(
                    func.lower(Sound.name).like(
                        _contains_pattern(search_query.strip()),
                        escape=_LIKE_ESCAPE_CHAR,
                    ),
                )

            # Apply sorting
            if sort_by:
                sort_column = getattr(Sound, sort_by.value)
                if sort_order == SortOrder.DESC:
                    statement = statement.order_by(sort_column.desc())
                else:
                    statement = statement.order_by(sort_column.asc())
            else:
                # Default sorting by name ascending
                statement = statement.order_by(Sound.name.asc())

            # Apply pagination
            if offset > 0:
                statement = statement.offset(offset)
            if limit is not None:
                statement = statement.limit(limit)

            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception(
                (
                    "Failed to search and sort sounds: "
                    "query=%s, types=%s, sort_by=%s, sort_order=%s"
                ),
                search_query,
                sound_types,
                sort_by,
                sort_order,
            )
            raise

    async def _get_type_statistics(
        self,
        sound_type: str,
        label: str,
    ) -> dict[str, int | float]:
        """Aggregate count, plays, duration and size for one sound type.

        Args:
            sound_type: Value of ``Sound.type`` to aggregate over.
            label: Word inserted into the failure log message
                ("Failed to get <label> statistics").

        Returns:
            Dict with keys ``count``, ``total_plays``, ``total_duration`` and
            ``total_size``. SUM over zero rows yields NULL, which is
            reported as 0.
        """
        try:
            statement = select(
                func.count(Sound.id).label("count"),
                func.sum(Sound.play_count).label("total_plays"),
                func.sum(Sound.duration).label("total_duration"),
                # size plus normalized_size, with a NULL normalized_size
                # counted as 0.
                func.sum(
                    Sound.size + func.coalesce(Sound.normalized_size, 0),
                ).label("total_size"),
            ).where(Sound.type == sound_type)
            result = await self.session.exec(statement)
            row = result.first()
        except Exception:
            logger.exception("Failed to get %s statistics", label)
            raise
        else:
            return {
                "count": row.count if row.count is not None else 0,
                "total_plays": row.total_plays if row.total_plays is not None else 0,
                "total_duration": (
                    row.total_duration if row.total_duration is not None else 0
                ),
                "total_size": row.total_size if row.total_size is not None else 0,
            }

    async def get_soundboard_statistics(self) -> dict[str, int | float]:
        """Get statistics for SDB type sounds.

        Returns:
            Dict with ``count``, ``total_plays``, ``total_duration`` and
            ``total_size``, each 0 when there is nothing to aggregate.
        """
        return await self._get_type_statistics("SDB", "soundboard")

    async def get_track_statistics(self) -> dict[str, int | float]:
        """Get statistics for EXT type sounds.

        Returns:
            Dict with ``count``, ``total_plays``, ``total_duration`` and
            ``total_size``, each 0 when there is nothing to aggregate.
        """
        return await self._get_type_statistics("EXT", "track")

    async def get_top_sounds(
        self,
        sound_type: str,
        date_filter: datetime | None = None,
        limit: int = 10,
    ) -> list[dict]:
        """Get top sounds by number of plays for a specific type and period.

        Plays are counted from ``SoundPlayed`` rows (inner-joined to
        ``Sound``), so sounds with no plays in the period are not returned.

        Args:
            sound_type: Sound type to filter on (upper-cased before
                comparison), or ``"all"`` to include every type.
            date_filter: If given, only plays with ``created_at`` at or after
                this value are counted.
            limit: Maximum number of sounds to return.

        Returns:
            Dicts with ``id``, ``name``, ``type``, ``play_count`` (plays in
            the period, not the stored ``Sound.play_count``), ``duration``
            and ``created_at``, ordered by ``play_count`` descending.
        """
        try:
            # Join SoundPlayed with Sound and count plays within the period
            statement = (
                select(
                    Sound.id,
                    Sound.name,
                    Sound.type,
                    Sound.duration,
                    Sound.created_at,
                    func.count(SoundPlayed.id).label("play_count"),
                )
                .select_from(SoundPlayed)
                .join(Sound, SoundPlayed.sound_id == Sound.id)
            )

            # Apply sound type filter
            if sound_type != "all":
                statement = statement.where(Sound.type == sound_type.upper())

            # Apply date filter if provided
            if date_filter:
                statement = statement.where(SoundPlayed.created_at >= date_filter)

            # Group by sound and order by play count descending
            statement = (
                statement
                .group_by(
                    Sound.id,
                    Sound.name,
                    Sound.type,
                    Sound.duration,
                    Sound.created_at,
                )
                .order_by(func.count(SoundPlayed.id).desc())
                .limit(limit)
            )

            result = await self.session.exec(statement)
            rows = result.all()

            # Convert to dictionaries with the play count from the period
            return [
                {
                    "id": row.id,
                    "name": row.name,
                    "type": row.type,
                    "play_count": row.play_count,
                    "duration": row.duration,
                    "created_at": row.created_at,
                }
                for row in rows
            ]
        except Exception:
            logger.exception(
                "Failed to get top sounds: type=%s, date_filter=%s, limit=%s",
                sound_type,
                date_filter,
                limit,
            )
            raise