"""Playlist repository for database operations."""

from datetime import UTC, datetime
from enum import Enum

from sqlalchemy import func, update
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
from app.models.user import User
from app.repositories.base import BaseRepository

logger = get_logger(__name__)


class PlaylistSortField(str, Enum):
    """Playlist sort field enumeration."""

    NAME = "name"
    GENRE = "genre"
    CREATED_AT = "created_at"
    UPDATED_AT = "updated_at"
    SOUND_COUNT = "sound_count"
    TOTAL_DURATION = "total_duration"


class SortOrder(str, Enum):
    """Sort order enumeration."""

    ASC = "asc"
    DESC = "desc"


class PlaylistRepository(BaseRepository[Playlist]):
    """Repository for playlist operations.

    The write methods defined here (add/remove/reorder sounds) commit the
    session on success and roll it back on failure; read methods never
    commit.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the playlist repository."""
        super().__init__(Playlist, session)

    @staticmethod
    def _name_contains(query: str):
        """Build a case-insensitive "playlist name contains *query*" filter.

        ``%``, ``_`` and the escape character ``/`` inside *query* are
        escaped so user input is matched literally rather than acting as
        LIKE wildcards.
        """
        term = query.lower()
        # Escape the escape character first so later escapes aren't doubled.
        escaped = term.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return func.lower(Playlist.name).like(f"%{escaped}%", escape="/")

    async def _update_playlist_timestamp(self, playlist_id: int) -> None:
        """Update the playlist's updated_at timestamp."""
        try:
            update_stmt = (
                update(Playlist)
                .where(Playlist.id == playlist_id)
                .values(updated_at=datetime.now(UTC))
            )
            await self.session.exec(update_stmt)
            # Note: No commit here - let the calling method handle transaction
            # management
        except Exception:
            logger.exception(
                "Failed to update playlist timestamp for playlist: %s",
                playlist_id,
            )
            raise

    async def _shift_positions_down(self, playlist_id: int, start: int) -> None:
        """Increment by one every position >= *start* in *playlist_id*.

        Uses two UPDATEs so that a (presumed) unique constraint on the
        playlist/position pair never sees two rows sharing a position
        mid-statement: each affected position ``p`` is first parked at
        ``-(p + 1)`` (negative and distinct for any ``p >= 0``), then
        flipped to ``p + 1``. Nothing is committed here, so the caller's
        rollback undoes the shift if a later step fails.
        """
        await self.session.exec(
            update(PlaylistSound)
            .where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.position >= start,
            )
            .values(position=-(PlaylistSound.position + 1)),
        )
        await self.session.exec(
            update(PlaylistSound)
            .where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.position < 0,
            )
            .values(position=-PlaylistSound.position),
        )

    async def get_by_name(self, name: str) -> Playlist | None:
        """Get a playlist by name."""
        try:
            statement = select(Playlist).where(Playlist.name == name)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get playlist by name: %s", name)
            raise

    async def get_by_user_id(self, user_id: int) -> list[Playlist]:
        """Get all playlists for a user."""
        try:
            statement = select(Playlist).where(Playlist.user_id == user_id)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get playlists for user: %s", user_id)
            raise

    async def get_main_playlist(self) -> Playlist | None:
        """Get the global main playlist."""
        try:
            statement = select(Playlist).where(
                Playlist.is_main == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get main playlist")
            raise

    async def get_current_playlist(self) -> Playlist | None:
        """Get the global current playlist (app-wide)."""
        try:
            statement = select(Playlist).where(
                Playlist.is_current == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get current playlist")
            raise

    async def search_by_name(
        self,
        query: str,
        user_id: int | None = None,
    ) -> list[Playlist]:
        """Search playlists by name (case-insensitive substring match).

        Args:
            query: Text to look for; LIKE wildcards in it match literally.
            user_id: When given, only that user's playlists are searched.

        """
        try:
            statement = select(Playlist).where(self._name_contains(query))
            if user_id is not None:
                statement = statement.where(Playlist.user_id == user_id)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to search playlists by name: %s", query)
            raise

    async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
        """Get all sounds in a playlist with extractions, ordered by position."""
        try:
            statement = (
                select(Sound)
                .join(PlaylistSound)
                .options(selectinload(Sound.extractions))
                .where(PlaylistSound.playlist_id == playlist_id)
                .order_by(PlaylistSound.position)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get sounds for playlist: %s", playlist_id)
            raise

    async def get_playlist_sound_entries(self, playlist_id: int) -> list[PlaylistSound]:
        """Get all PlaylistSound entries for a playlist, ordered by position."""
        try:
            statement = (
                select(PlaylistSound)
                .where(PlaylistSound.playlist_id == playlist_id)
                .order_by(PlaylistSound.position)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception(
                "Failed to get playlist sound entries for playlist: %s",
                playlist_id,
            )
            raise

    async def add_sound_to_playlist(
        self,
        playlist_id: int,
        sound_id: int,
        position: int | None = None,
    ) -> PlaylistSound:
        """Add a sound to a playlist and commit.

        The position shift, the insert and the timestamp update all happen
        in one transaction; on any failure the session is rolled back.

        Args:
            playlist_id: The playlist ID
            sound_id: The sound ID to add
            position: Target position. ``None`` appends after the current
                highest position (0 for an empty playlist); otherwise every
                entry at or after ``position`` is moved down by one.

        Returns:
            The persisted and refreshed ``PlaylistSound`` entry.

        """
        try:
            if position is None:
                # Next available position: max(position) + 1, or 0 when empty.
                statement = select(
                    func.coalesce(func.max(PlaylistSound.position), -1) + 1,
                ).where(PlaylistSound.playlist_id == playlist_id)
                result = await self.session.exec(statement)
                position = result.first() or 0
            else:
                await self._shift_positions_down(playlist_id, position)

            playlist_sound = PlaylistSound(
                playlist_id=playlist_id,
                sound_id=sound_id,
                position=position,
            )
            self.session.add(playlist_sound)
            # Update playlist timestamp before commit
            await self._update_playlist_timestamp(playlist_id)
            await self.session.commit()
            await self.session.refresh(playlist_sound)
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to add sound %s to playlist %s",
                sound_id,
                playlist_id,
            )
            raise
        else:
            logger.info(
                "Added sound %s to playlist %s at position %s",
                sound_id,
                playlist_id,
                position,
            )
            return playlist_sound

    async def remove_sound_from_playlist(self, playlist_id: int, sound_id: int) -> None:
        """Remove a sound from a playlist and commit.

        If the sound is not in the playlist nothing is deleted (the session
        is still committed, as before). On failure the session is rolled
        back and the error re-raised.
        """
        try:
            statement = select(PlaylistSound).where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.sound_id == sound_id,
            )
            result = await self.session.exec(statement)
            playlist_sound = result.first()
            if playlist_sound is not None:
                await self.session.delete(playlist_sound)
                # Update playlist timestamp before commit
                await self._update_playlist_timestamp(playlist_id)
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to remove sound %s from playlist %s",
                sound_id,
                playlist_id,
            )
            raise
        else:
            if playlist_sound is None:
                logger.info(
                    "Sound %s not found in playlist %s; nothing removed",
                    sound_id,
                    playlist_id,
                )
            else:
                logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)

    async def reorder_playlist_sounds(
        self,
        playlist_id: int,
        sound_positions: list[tuple[int, int]],
    ) -> None:
        """Reorder sounds in a playlist and commit.

        Sound IDs not present in the playlist are ignored. Entries not named
        in ``sound_positions`` keep their current positions.

        Args:
            playlist_id: The playlist ID
            sound_positions: List of (sound_id, new_position) tuples

        """
        try:
            # Load the playlist's entries once instead of one SELECT per sound.
            result = await self.session.exec(
                select(PlaylistSound).where(PlaylistSound.playlist_id == playlist_id),
            )
            entries_by_sound: dict[int, PlaylistSound] = {}
            for entry in result.all():
                entries_by_sound.setdefault(entry.sound_id, entry)

            # Phase 1: park each listed entry on a distinct negative position
            # (always < 0 regardless of list length) so phase 2 can't collide
            # with a position another listed entry still holds.
            temp_base = -10000
            for index, (sound_id, _) in enumerate(sound_positions):
                entry = entries_by_sound.get(sound_id)
                if entry is not None:
                    entry.position = temp_base - index
            # Write the temporary positions now; without an explicit flush the
            # unit of work would only see the final values.
            await self.session.flush()

            # Phase 2: set the final positions (written at commit).
            for sound_id, new_position in sound_positions:
                entry = entries_by_sound.get(sound_id)
                if entry is not None:
                    entry.position = new_position

            # Update playlist timestamp before commit
            await self._update_playlist_timestamp(playlist_id)
            await self.session.commit()
            logger.info("Reordered sounds in playlist %s", playlist_id)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to reorder sounds in playlist %s", playlist_id)
            raise

    async def get_playlist_sound_count(self, playlist_id: int) -> int:
        """Get the number of sounds in a playlist."""
        try:
            statement = select(func.count(PlaylistSound.id)).where(
                PlaylistSound.playlist_id == playlist_id,
            )
            result = await self.session.exec(statement)
            return result.first() or 0
        except Exception:
            logger.exception("Failed to get sound count for playlist: %s", playlist_id)
            raise

    async def is_sound_in_playlist(self, playlist_id: int, sound_id: int) -> bool:
        """Check if a sound is already in a playlist."""
        try:
            statement = select(PlaylistSound).where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.sound_id == sound_id,
            )
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception(
                "Failed to check if sound %s is in playlist %s",
                sound_id,
                playlist_id,
            )
            raise

    @staticmethod
    def _playlist_summary_select():
        """Build the per-playlist summary SELECT used by ``search_and_sort``.

        Each row has the playlist's own columns, ``user_name`` from an outer
        join on ``User`` (NULL when no user matches), ``sound_count`` (number
        of linked ``PlaylistSound`` rows) and ``total_duration`` (sum of the
        linked sounds' durations, 0 when there are none).
        """
        return (
            select(
                Playlist.id,
                Playlist.name,
                Playlist.description,
                Playlist.genre,
                Playlist.user_id,
                Playlist.is_main,
                Playlist.is_current,
                Playlist.is_deletable,
                Playlist.created_at,
                Playlist.updated_at,
                func.count(PlaylistSound.id).label("sound_count"),
                func.coalesce(func.sum(Sound.duration), 0).label(
                    "total_duration",
                ),
                User.name.label("user_name"),
            )
            .select_from(Playlist)
            .join(User, Playlist.user_id == User.id, isouter=True)
            .join(
                PlaylistSound,
                Playlist.id == PlaylistSound.playlist_id,
                isouter=True,
            )
            .join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
            .group_by(Playlist.id, User.name)
        )

    @staticmethod
    def _playlist_order_by(
        sort_by: PlaylistSortField | None,
        sort_order: SortOrder,
        *,
        include_stats: bool,
    ):
        """Return the primary ORDER BY expression for ``search_and_sort``.

        With no ``sort_by`` the result is name ascending and ``sort_order``
        is ignored. ``SOUND_COUNT``/``TOTAL_DURATION`` are honored only when
        ``include_stats`` is true; otherwise they fall back to ``name`` in
        the requested direction.
        """
        if not sort_by:
            return Playlist.name.asc()
        if include_stats and sort_by == PlaylistSortField.SOUND_COUNT:
            key = func.count(PlaylistSound.id)
        elif include_stats and sort_by == PlaylistSortField.TOTAL_DURATION:
            key = func.coalesce(func.sum(Sound.duration), 0)
        else:
            key = {
                PlaylistSortField.NAME: Playlist.name,
                PlaylistSortField.GENRE: Playlist.genre,
                PlaylistSortField.CREATED_AT: Playlist.created_at,
                PlaylistSortField.UPDATED_AT: Playlist.updated_at,
            }.get(sort_by, Playlist.name)
        return key.desc() if sort_order == SortOrder.DESC else key.asc()

    async def search_and_sort(  # noqa: PLR0913
        self,
        search_query: str | None = None,
        sort_by: PlaylistSortField | None = None,
        sort_order: SortOrder = SortOrder.ASC,
        user_id: int | None = None,
        include_stats: bool = False,  # noqa: FBT001, FBT002
        limit: int | None = None,
        offset: int = 0,
    ) -> list[dict]:
        """Search and sort playlists with optional statistics.

        ``sound_count`` and ``total_duration`` are always computed and
        returned; ``include_stats`` only controls whether sorting by them is
        honored (see ``_playlist_order_by``).

        Args:
            search_query: Case-insensitive substring of the playlist name;
                blank or ``None`` disables the filter.
            sort_by: Field to sort by; ``None`` means name ascending.
            sort_order: Sort direction for ``sort_by``.
            user_id: When given, only that user's playlists are returned.
            include_stats: Allow sorting by the stats fields.
            limit: Maximum number of rows; ``None`` means no limit.
            offset: Number of rows to skip (applied only when > 0).

        Returns:
            One dict per playlist with keys ``id``, ``name``, ``description``,
            ``genre``, ``user_id``, ``user_name``, ``is_main``, ``is_current``,
            ``is_deletable``, ``created_at``, ``updated_at``, ``sound_count``
            and ``total_duration``.

        """
        try:
            statement = self._playlist_summary_select()

            # Apply filters
            if search_query and search_query.strip():
                statement = statement.where(
                    self._name_contains(search_query.strip()),
                )
            if user_id is not None:
                statement = statement.where(Playlist.user_id == user_id)

            statement = statement.order_by(
                self._playlist_order_by(
                    sort_by,
                    sort_order,
                    include_stats=include_stats,
                ),
                # Tie-breaker so limit/offset pages are stable when the sort
                # key has duplicates (e.g. genre, sound_count).
                Playlist.id.asc(),
            )

            # Apply pagination
            if offset > 0:
                statement = statement.offset(offset)
            if limit is not None:
                statement = statement.limit(limit)

            result = await self.session.exec(statement)
            rows = result.all()

            # Convert to dictionary format
            playlists = [
                {
                    "id": row.id,
                    "name": row.name,
                    "description": row.description,
                    "genre": row.genre,
                    "user_id": row.user_id,
                    "user_name": row.user_name,
                    "is_main": row.is_main,
                    "is_current": row.is_current,
                    "is_deletable": row.is_deletable,
                    "created_at": row.created_at,
                    "updated_at": row.updated_at,
                    "sound_count": row.sound_count or 0,
                    "total_duration": row.total_duration or 0,
                }
                for row in rows
            ]
        except Exception:
            logger.exception(
                (
                    "Failed to search and sort playlists: "
                    "query=%s, sort_by=%s, sort_order=%s"
                ),
                search_query,
                sort_by,
                sort_order,
            )
            raise
        else:
            return playlists