"""Playlist repository for database operations."""

from enum import Enum
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import selectinload
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
from app.models.user import User
from app.repositories.base import BaseRepository

logger = get_logger(__name__)

# Escape character for LIKE patterns built from user input. "/" is used instead
# of a backslash so the rendered ESCAPE clause needs no dialect-specific
# backslash handling.
_LIKE_ESCAPE = "/"


def _contains_pattern(text: str) -> str:
    """Return a lower-cased LIKE pattern that matches *text* as a literal substring.

    ``%``, ``_`` and the escape character itself are escaped with
    ``_LIKE_ESCAPE`` so user input cannot act as a wildcard. The pattern must
    be used together with ``.like(pattern, escape=_LIKE_ESCAPE)``.
    """
    escaped = (
        text.lower()
        # Escape the escape character first so later escapes are not doubled.
        .replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


class PlaylistSortField(str, Enum):
    """Playlist sort field enumeration."""

    NAME = "name"
    GENRE = "genre"
    CREATED_AT = "created_at"
    UPDATED_AT = "updated_at"
    SOUND_COUNT = "sound_count"
    TOTAL_DURATION = "total_duration"


class SortOrder(str, Enum):
    """Sort order enumeration."""

    ASC = "asc"
    DESC = "desc"


class PlaylistRepository(BaseRepository[Playlist]):
    """Repository for playlist operations."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the playlist repository."""
        super().__init__(Playlist, session)

    async def get_by_name(self, name: str) -> Playlist | None:
        """Get a playlist by exact name, or ``None`` if there is none."""
        try:
            statement = select(Playlist).where(Playlist.name == name)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get playlist by name: %s", name)
            raise

    async def get_by_user_id(self, user_id: int) -> list[Playlist]:
        """Get all playlists for a user."""
        try:
            statement = select(Playlist).where(Playlist.user_id == user_id)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get playlists for user: %s", user_id)
            raise

    async def get_main_playlist(self) -> Playlist | None:
        """Get the global main playlist (the first row with ``is_main`` set)."""
        try:
            statement = select(Playlist).where(
                Playlist.is_main == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get main playlist")
            raise

    async def get_current_playlist(self) -> Playlist | None:
        """Get the global current playlist (app-wide, first row with ``is_current`` set)."""
        try:
            statement = select(Playlist).where(
                Playlist.is_current == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get current playlist")
            raise

    async def search_by_name(
        self,
        query: str,
        user_id: int | None = None,
    ) -> list[Playlist]:
        """Search playlists by name (case-insensitive substring match).

        Args:
            query: Text to look for in the playlist name. ``%`` and ``_`` are
                matched literally, not as SQL wildcards.
            user_id: When given, only playlists owned by this user are returned.
        """
        try:
            statement = select(Playlist).where(
                func.lower(Playlist.name).like(
                    _contains_pattern(query),
                    escape=_LIKE_ESCAPE,
                ),
            )
            if user_id is not None:
                statement = statement.where(Playlist.user_id == user_id)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to search playlists by name: %s", query)
            raise

    async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
        """Get all sounds in a playlist with extractions, ordered by position."""
        try:
            statement = (
                select(Sound)
                .join(PlaylistSound)
                # Eager-load extractions so callers can use them without
                # triggering lazy loads on the async session.
                .options(selectinload(Sound.extractions))
                .where(PlaylistSound.playlist_id == playlist_id)
                .order_by(PlaylistSound.position)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get sounds for playlist: %s", playlist_id)
            raise

    async def add_sound_to_playlist(
        self,
        playlist_id: int,
        sound_id: int,
        position: int | None = None,
    ) -> PlaylistSound:
        """Add a sound to a playlist and commit.

        Args:
            playlist_id: Target playlist ID.
            sound_id: Sound to add.
            position: Explicit position; when ``None`` the sound is appended
                after the current highest position (0 for an empty playlist).

        Returns:
            The persisted, refreshed ``PlaylistSound`` link row.

        Raises:
            Exception: Any database error; the session is rolled back first.
        """
        try:
            if position is None:
                # Next free slot: max(position) + 1, or 0 when the playlist is empty.
                statement = select(
                    func.coalesce(func.max(PlaylistSound.position), -1) + 1,
                ).where(PlaylistSound.playlist_id == playlist_id)
                result = await self.session.exec(statement)
                position = result.first() or 0

            playlist_sound = PlaylistSound(
                playlist_id=playlist_id,
                sound_id=sound_id,
                position=position,
            )
            self.session.add(playlist_sound)
            await self.session.commit()
            await self.session.refresh(playlist_sound)
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to add sound %s to playlist %s",
                sound_id,
                playlist_id,
            )
            raise
        else:
            logger.info(
                "Added sound %s to playlist %s at position %s",
                sound_id,
                playlist_id,
                position,
            )
            return playlist_sound

    async def remove_sound_from_playlist(self, playlist_id: int, sound_id: int) -> None:
        """Remove a sound from a playlist and commit.

        Does nothing (beyond the lookup) when the sound is not in the playlist.
        Remaining positions are not renumbered.
        """
        try:
            statement = select(PlaylistSound).where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.sound_id == sound_id,
            )
            result = await self.session.exec(statement)
            playlist_sound = result.first()

            if playlist_sound:
                await self.session.delete(playlist_sound)
                await self.session.commit()
                logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to remove sound %s from playlist %s",
                sound_id,
                playlist_id,
            )
            raise

    async def reorder_playlist_sounds(
        self,
        playlist_id: int,
        sound_positions: list[tuple[int, int]],
    ) -> None:
        """Reorder sounds in a playlist and commit.

        Positions are rewritten in two flushed phases: every affected row is
        first parked on a distinct negative position, then moved to its final
        position, so no intermediate UPDATE puts two of the reordered rows on
        the same position. (NOTE(review): presumably this guards a uniqueness
        constraint on position — confirm against the model.)

        Args:
            playlist_id: The playlist ID
            sound_positions: List of (sound_id, new_position) tuples. Sound IDs
                that are not in the playlist are ignored.

        Raises:
            Exception: Any database error; the session is rolled back first.
        """
        try:
            if sound_positions:
                sound_ids = [sound_id for sound_id, _ in sound_positions]
                statement = select(PlaylistSound).where(
                    PlaylistSound.playlist_id == playlist_id,
                    col(PlaylistSound.sound_id).in_(sound_ids),
                )
                result = await self.session.exec(statement)
                # One query for all affected rows instead of two per sound.
                rows_by_sound = {row.sound_id: row for row in result.all()}

                # Phase 1: move rows to unique negative positions. -(index + 1)
                # stays negative however long the list is.
                for index, (sound_id, _) in enumerate(sound_positions):
                    playlist_sound = rows_by_sound.get(sound_id)
                    if playlist_sound is not None:
                        playlist_sound.position = -(index + 1)
                # Flush explicitly: otherwise the unit of work would collapse
                # both assignments into a single UPDATE per row.
                await self.session.flush()

                # Phase 2: set the final positions.
                for sound_id, new_position in sound_positions:
                    playlist_sound = rows_by_sound.get(sound_id)
                    if playlist_sound is not None:
                        playlist_sound.position = new_position

            await self.session.commit()
            logger.info("Reordered sounds in playlist %s", playlist_id)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to reorder sounds in playlist %s", playlist_id)
            raise

    async def get_playlist_sound_count(self, playlist_id: int) -> int:
        """Get the number of sounds in a playlist."""
        try:
            statement = select(func.count(PlaylistSound.id)).where(
                PlaylistSound.playlist_id == playlist_id,
            )
            result = await self.session.exec(statement)
            return result.first() or 0
        except Exception:
            logger.exception("Failed to get sound count for playlist: %s", playlist_id)
            raise

    async def is_sound_in_playlist(self, playlist_id: int, sound_id: int) -> bool:
        """Check if a sound is already in a playlist."""
        try:
            statement = select(PlaylistSound).where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.sound_id == sound_id,
            )
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception(
                "Failed to check if sound %s is in playlist %s",
                sound_id,
                playlist_id,
            )
            raise

    async def search_and_sort(
        self,
        search_query: str | None = None,
        sort_by: PlaylistSortField | None = None,
        sort_order: SortOrder = SortOrder.ASC,
        user_id: int | None = None,
        include_stats: bool = False,  # noqa: ARG002
        limit: int | None = None,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Search and sort playlists, returning rows with per-playlist statistics.

        Every row carries ``sound_count``, ``total_duration`` (sum of
        ``Sound.duration`` over the playlist's sounds, 0 when empty) and the
        owner's ``user_name`` alongside the playlist's own columns.

        Args:
            search_query: Case-insensitive substring matched against the
                playlist name (``%``/``_`` are literal); blank or ``None``
                disables the filter.
            sort_by: Field to sort by; ``None`` sorts by name ascending.
            sort_order: Direction applied when ``sort_by`` is given.
            user_id: Restrict results to playlists owned by this user.
            include_stats: Kept for backward compatibility and has no effect.
                The statistics are always selected, and sorting by
                ``sound_count`` / ``total_duration`` is always honored.
            limit: Maximum number of rows; ``None`` for no limit.
            offset: Rows to skip; ignored unless positive.

        Returns:
            One dict per playlist.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            statement = self._build_search_statement(search_query, user_id)
            statement = statement.order_by(
                self._order_clause(sort_by, sort_order),
                # Unique tie-breaker so offset/limit pages are stable when sort
                # keys tie (same genre, same sound_count, ...).
                col(Playlist.id).asc(),
            )

            if offset > 0:
                statement = statement.offset(offset)
            if limit is not None:
                statement = statement.limit(limit)

            result = await self.session.exec(statement)
            return [self._row_to_dict(row) for row in result.all()]
        except Exception:
            logger.exception(
                "Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
                search_query,
                sort_by,
                sort_order,
            )
            raise

    @staticmethod
    def _build_search_statement(search_query: str | None, user_id: int | None) -> Any:
        """Build the filtered playlist + owner + statistics query (unordered)."""
        statement = (
            select(
                Playlist.id,
                Playlist.name,
                Playlist.description,
                Playlist.genre,
                Playlist.user_id,
                Playlist.is_main,
                Playlist.is_current,
                Playlist.is_deletable,
                Playlist.created_at,
                Playlist.updated_at,
                func.count(PlaylistSound.id).label("sound_count"),
                func.coalesce(func.sum(Sound.duration), 0).label("total_duration"),
                User.name.label("user_name"),
            )
            .select_from(Playlist)
            # Outer joins keep playlists without an owner row or without sounds.
            .join(User, Playlist.user_id == User.id, isouter=True)
            .join(PlaylistSound, Playlist.id == PlaylistSound.playlist_id, isouter=True)
            .join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
            .group_by(Playlist.id, User.name)
        )

        if search_query and search_query.strip():
            statement = statement.where(
                func.lower(Playlist.name).like(
                    _contains_pattern(search_query.strip()),
                    escape=_LIKE_ESCAPE,
                ),
            )
        if user_id is not None:
            statement = statement.where(Playlist.user_id == user_id)
        return statement

    @staticmethod
    def _order_clause(sort_by: PlaylistSortField | None, sort_order: SortOrder) -> Any:
        """Map the requested sort field/direction to an ORDER BY expression."""
        if not sort_by:
            # No explicit sort: name ascending regardless of sort_order.
            return col(Playlist.name).asc()

        # `==` comparisons (rather than a dict lookup) keep raw string values
        # such as "sound_count" working, since str-mixin Enum members hash by
        # member name, not by value.
        if sort_by == PlaylistSortField.GENRE:
            column = col(Playlist.genre)
        elif sort_by == PlaylistSortField.CREATED_AT:
            column = col(Playlist.created_at)
        elif sort_by == PlaylistSortField.UPDATED_AT:
            column = col(Playlist.updated_at)
        elif sort_by == PlaylistSortField.SOUND_COUNT:
            column = func.count(PlaylistSound.id)
        elif sort_by == PlaylistSortField.TOTAL_DURATION:
            column = func.coalesce(func.sum(Sound.duration), 0)
        else:
            column = col(Playlist.name)

        return column.desc() if sort_order == SortOrder.DESC else column.asc()

    @staticmethod
    def _row_to_dict(row: Any) -> dict[str, Any]:
        """Convert one result row of the search query into a plain dict."""
        return {
            "id": row.id,
            "name": row.name,
            "description": row.description,
            "genre": row.genre,
            "user_id": row.user_id,
            "user_name": row.user_name,
            "is_main": row.is_main,
            "is_current": row.is_current,
            "is_deletable": row.is_deletable,
            "created_at": row.created_at,
            "updated_at": row.updated_at,
            "sound_count": row.sound_count or 0,
            "total_duration": row.total_duration or 0,
        }