# Files
# sdb2-backend/app/repositories/playlist.py
#
# 536 lines
# 20 KiB
# Python

"""Playlist repository for database operations."""
from datetime import UTC, datetime
from enum import Enum
from sqlalchemy import func, update
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
from app.models.user import User
from app.repositories.base import BaseRepository
# Module-level logger from the project's logging helper, named after this module.
logger = get_logger(__name__)
class PlaylistSortField(str, Enum):
    """Playlist sort field enumeration.

    Members subclass ``str``, so each compares equal to its raw value
    (e.g. ``PlaylistSortField.NAME == "name"``). ``SOUND_COUNT`` and
    ``TOTAL_DURATION`` name per-playlist aggregates computed by
    ``PlaylistRepository.search_and_sort`` rather than stored columns.
    """

    NAME = "name"
    GENRE = "genre"
    CREATED_AT = "created_at"
    UPDATED_AT = "updated_at"
    SOUND_COUNT = "sound_count"
    TOTAL_DURATION = "total_duration"
class SortOrder(str, Enum):
    """Sort order enumeration (ascending or descending).

    Members subclass ``str``, so ``SortOrder.DESC == "desc"``.
    """

    ASC = "asc"
    DESC = "desc"
class PlaylistRepository(BaseRepository[Playlist]):
    """Repository for playlist operations.

    Read methods only query. Mutating methods (add / remove / reorder) bump
    the playlist's ``updated_at`` and commit, rolling back and re-raising on
    any error.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the playlist repository."""
        super().__init__(Playlist, session)

    @staticmethod
    def _contains_pattern(text: str) -> str:
        """Return a LIKE pattern matching *text* literally as a substring.

        ``/``, ``%`` and ``_`` in *text* are prefixed with ``/`` so user input
        cannot act as wildcards; use the pattern together with ``escape="/"``.
        """
        # Escape the escape character first so it is not double-processed.
        escaped = text.replace("/", "//").replace("%", "/%").replace("_", "/_")
        return f"%{escaped}%"

    async def _update_playlist_timestamp(self, playlist_id: int) -> None:
        """Set the playlist's ``updated_at`` to now (UTC) without committing.

        The calling method owns the transaction and commits or rolls back.
        """
        try:
            update_stmt = (
                update(Playlist)
                .where(Playlist.id == playlist_id)
                .values(updated_at=datetime.now(UTC))
            )
            await self.session.exec(update_stmt)
        except Exception:
            logger.exception(
                "Failed to update playlist timestamp for playlist: %s", playlist_id,
            )
            raise

    async def _shift_positions_from(self, playlist_id: int, start: int) -> None:
        """Move every entry with ``position >= start`` down by one, uncommitted.

        Two UPDATEs avoid transient unique-constraint violations while
        shifting. Each affected position ``p`` is first parked at
        ``-(p + 1)``, which is negative for any ``p >= 0``. Every negative
        position ``n`` is then set to ``-n``, i.e. ``p + 1``. Unlike a fixed
        offset, this works for any number of entries. The caller owns the
        transaction.
        """
        park = (
            update(PlaylistSound)
            .where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.position >= start,
            )
            .values(position=-(PlaylistSound.position + 1))
        )
        await self.session.exec(park)
        settle = (
            update(PlaylistSound)
            .where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.position < 0,
            )
            .values(position=-PlaylistSound.position)
        )
        await self.session.exec(settle)

    async def get_by_name(self, name: str) -> Playlist | None:
        """Return a playlist whose name equals *name* exactly, or ``None``.

        No ORDER BY is applied, so if several playlists share the name, which
        one is returned is unspecified.
        """
        try:
            statement = select(Playlist).where(Playlist.name == name)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get playlist by name: %s", name)
            raise

    async def get_by_user_id(self, user_id: int) -> list[Playlist]:
        """Return all playlists owned by *user_id* (unordered)."""
        try:
            statement = select(Playlist).where(Playlist.user_id == user_id)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get playlists for user: %s", user_id)
            raise

    async def get_main_playlist(self) -> Playlist | None:
        """Return a playlist flagged ``is_main``, or ``None`` if there is none."""
        try:
            statement = select(Playlist).where(
                Playlist.is_main == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get main playlist")
            raise

    async def get_current_playlist(self) -> Playlist | None:
        """Return a playlist flagged ``is_current`` (app-wide), or ``None``."""
        try:
            statement = select(Playlist).where(
                Playlist.is_current == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get current playlist")
            raise

    async def search_by_name(
        self,
        query: str,
        user_id: int | None = None,
    ) -> list[Playlist]:
        """Search playlists whose name contains *query* (case-insensitive).

        *query* is matched literally: ``%`` and ``_`` are not wildcards.

        Args:
            query: Substring to look for in the playlist name.
            user_id: When given, only playlists owned by this user are returned.
        """
        try:
            statement = select(Playlist).where(
                func.lower(Playlist.name).like(
                    self._contains_pattern(query.lower()),
                    escape="/",
                ),
            )
            if user_id is not None:
                statement = statement.where(Playlist.user_id == user_id)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to search playlists by name: %s", query)
            raise

    async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
        """Return the playlist's sounds ordered by position.

        ``Sound.extractions`` is eagerly loaded with ``selectinload``.
        """
        try:
            statement = (
                select(Sound)
                .join(PlaylistSound)
                .options(selectinload(Sound.extractions))
                .where(PlaylistSound.playlist_id == playlist_id)
                .order_by(PlaylistSound.position)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get sounds for playlist: %s", playlist_id)
            raise

    async def get_playlist_sound_entries(self, playlist_id: int) -> list[PlaylistSound]:
        """Return all ``PlaylistSound`` entries of a playlist, ordered by position."""
        try:
            statement = (
                select(PlaylistSound)
                .where(PlaylistSound.playlist_id == playlist_id)
                .order_by(PlaylistSound.position)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception(
                "Failed to get playlist sound entries for playlist: %s",
                playlist_id,
            )
            raise

    async def add_sound_to_playlist(
        self,
        playlist_id: int,
        sound_id: int,
        position: int | None = None,
    ) -> PlaylistSound:
        """Add a sound to a playlist and commit.

        Args:
            playlist_id: Playlist receiving the sound.
            sound_id: Sound to add.
            position: Target position. ``None`` appends after the highest
                existing position (0 for an empty playlist). Otherwise every
                entry at or after ``position`` is moved down one slot first.

        Returns:
            The committed and refreshed ``PlaylistSound`` entry.

        Raises:
            Exception: Any database error is logged and re-raised after a
                rollback. The shift, insert and timestamp update share one
                transaction, so a failure undoes all of them.
        """
        try:
            if position is None:
                # Next free slot: max(position) + 1, or 0 for an empty playlist.
                statement = select(
                    func.coalesce(func.max(PlaylistSound.position), -1) + 1,
                ).where(PlaylistSound.playlist_id == playlist_id)
                result = await self.session.exec(statement)
                position = result.first() or 0
            else:
                # No intermediate commit: the shift must roll back together
                # with the insert if anything below fails.
                await self._shift_positions_from(playlist_id, position)
            playlist_sound = PlaylistSound(
                playlist_id=playlist_id,
                sound_id=sound_id,
                position=position,
            )
            self.session.add(playlist_sound)
            await self._update_playlist_timestamp(playlist_id)
            await self.session.commit()
            await self.session.refresh(playlist_sound)
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to add sound %s to playlist %s",
                sound_id,
                playlist_id,
            )
            raise
        else:
            logger.info(
                "Added sound %s to playlist %s at position %s",
                sound_id,
                playlist_id,
                position,
            )
            return playlist_sound

    async def remove_sound_from_playlist(self, playlist_id: int, sound_id: int) -> None:
        """Remove a sound from a playlist and commit; no-op if it is absent.

        Remaining positions are left as-is (gaps are not compacted).
        """
        try:
            statement = select(PlaylistSound).where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.sound_id == sound_id,
            )
            result = await self.session.exec(statement)
            playlist_sound = result.first()
            if playlist_sound:
                await self.session.delete(playlist_sound)
                await self._update_playlist_timestamp(playlist_id)
                await self.session.commit()
                logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to remove sound %s from playlist %s",
                sound_id,
                playlist_id,
            )
            raise

    async def reorder_playlist_sounds(
        self,
        playlist_id: int,
        sound_positions: list[tuple[int, int]],
    ) -> None:
        """Reorder sounds in a playlist and commit.

        Args:
            playlist_id: The playlist ID.
            sound_positions: List of ``(sound_id, new_position)`` tuples.
                Sound IDs with no entry in the playlist are skipped.
        """
        try:
            # Load every targeted entry with one query instead of one per sound.
            entries: dict[int, PlaylistSound] = {}
            sound_ids = [sound_id for sound_id, _ in sound_positions]
            if sound_ids:
                statement = select(PlaylistSound).where(
                    PlaylistSound.playlist_id == playlist_id,
                    PlaylistSound.sound_id.in_(sound_ids),
                )
                result = await self.session.exec(statement)
                for entry in result.all():
                    entries.setdefault(entry.sound_id, entry)

            # Phase 1: park each targeted entry at a distinct negative position
            # and flush, so no final position collides with an old one still
            # held by another targeted entry.
            for index, (sound_id, _) in enumerate(sound_positions):
                entry = entries.get(sound_id)
                if entry is not None:
                    entry.position = -(index + 1)
            await self.session.flush()

            # Phase 2: assign the requested final positions.
            for sound_id, new_position in sound_positions:
                entry = entries.get(sound_id)
                if entry is not None:
                    entry.position = new_position

            await self._update_playlist_timestamp(playlist_id)
            await self.session.commit()
            logger.info("Reordered sounds in playlist %s", playlist_id)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to reorder sounds in playlist %s", playlist_id)
            raise

    async def get_playlist_sound_count(self, playlist_id: int) -> int:
        """Return the number of entries in a playlist (0 if none)."""
        try:
            statement = select(func.count(PlaylistSound.id)).where(
                PlaylistSound.playlist_id == playlist_id,
            )
            result = await self.session.exec(statement)
            return result.first() or 0
        except Exception:
            logger.exception("Failed to get sound count for playlist: %s", playlist_id)
            raise

    async def is_sound_in_playlist(self, playlist_id: int, sound_id: int) -> bool:
        """Return whether *sound_id* already has an entry in the playlist."""
        try:
            statement = (
                select(PlaylistSound)
                .where(
                    PlaylistSound.playlist_id == playlist_id,
                    PlaylistSound.sound_id == sound_id,
                )
                .limit(1)
            )
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception(
                "Failed to check if sound %s is in playlist %s",
                sound_id,
                playlist_id,
            )
            raise

    async def search_and_sort(  # noqa: PLR0913
        self,
        search_query: str | None = None,
        sort_by: PlaylistSortField | None = None,
        sort_order: SortOrder = SortOrder.ASC,
        user_id: int | None = None,
        include_stats: bool = False,  # noqa: FBT001, FBT002
        limit: int | None = None,
        offset: int = 0,
        favorites_only: bool = False,
        current_user_id: int | None = None,
    ) -> list[dict]:
        """Search, filter, sort and paginate playlists with aggregate stats.

        Every row carries ``sound_count`` and ``total_duration``, computed
        over LEFT OUTER JOINs to ``PlaylistSound``/``Sound``, and the owner's
        ``user_name``.

        Args:
            search_query: Case-insensitive substring matched literally against
                the playlist name. ``None`` or whitespace-only disables it.
            sort_by: Field to sort by. ``None`` sorts by name ascending.
            sort_order: Direction applied when ``sort_by`` is given.
            user_id: Restrict to playlists owned by this user.
            include_stats: Must be ``True`` for ``SOUND_COUNT`` /
                ``TOTAL_DURATION`` sorting to apply. Otherwise those fields
                fall back to sorting by name.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.
            favorites_only: Only return playlists favorited by
                ``current_user_id``. Ignored when that is ``None``.
            current_user_id: User whose favorites ``favorites_only`` uses.

        Returns:
            One dict per playlist with keys ``id``, ``name``, ``description``,
            ``genre``, ``user_id``, ``user_name``, ``is_main``, ``is_current``,
            ``is_deletable``, ``created_at``, ``updated_at``, ``sound_count``
            and ``total_duration``.
        """
        try:
            sound_count = func.count(PlaylistSound.id)
            total_duration = func.coalesce(func.sum(Sound.duration), 0)
            statement = (
                select(
                    Playlist.id,
                    Playlist.name,
                    Playlist.description,
                    Playlist.genre,
                    Playlist.user_id,
                    Playlist.is_main,
                    Playlist.is_current,
                    Playlist.is_deletable,
                    Playlist.created_at,
                    Playlist.updated_at,
                    sound_count.label("sound_count"),
                    total_duration.label("total_duration"),
                    User.name.label("user_name"),
                )
                .select_from(Playlist)
                .join(User, Playlist.user_id == User.id, isouter=True)
                .join(
                    PlaylistSound,
                    Playlist.id == PlaylistSound.playlist_id,
                    isouter=True,
                )
                .join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
                .group_by(Playlist.id, User.name)
            )

            # Filters
            if search_query and search_query.strip():
                statement = statement.where(
                    func.lower(Playlist.name).like(
                        self._contains_pattern(search_query.strip().lower()),
                        escape="/",
                    ),
                )
            if user_id is not None:
                statement = statement.where(Playlist.user_id == user_id)
            if favorites_only and current_user_id is not None:
                # EXISTS subquery avoids JOIN conflicts with the GROUP BY above.
                favorites_subquery = select(1).select_from(Favorite).where(
                    Favorite.user_id == current_user_id,
                    Favorite.playlist_id == Playlist.id,
                )
                statement = statement.where(favorites_subquery.exists())

            # Sorting
            if sort_by:
                if include_stats and sort_by == PlaylistSortField.SOUND_COUNT:
                    sort_column = sound_count
                elif include_stats and sort_by == PlaylistSortField.TOTAL_DURATION:
                    sort_column = total_duration
                elif sort_by == PlaylistSortField.GENRE:
                    sort_column = Playlist.genre
                elif sort_by == PlaylistSortField.CREATED_AT:
                    sort_column = Playlist.created_at
                elif sort_by == PlaylistSortField.UPDATED_AT:
                    sort_column = Playlist.updated_at
                else:
                    # NAME, plus stat fields when include_stats is False.
                    # NOTE(review): stats are always computed above, so the
                    # include_stats gate may be obsolete; kept for compatibility.
                    sort_column = Playlist.name
                primary_order = (
                    sort_column.desc()
                    if sort_order == SortOrder.DESC
                    else sort_column.asc()
                )
            else:
                primary_order = Playlist.name.asc()
            # id tie-breaker keeps offset/limit pages stable when sort keys tie.
            statement = statement.order_by(primary_order, Playlist.id.asc())

            # Pagination
            if offset > 0:
                statement = statement.offset(offset)
            if limit is not None:
                statement = statement.limit(limit)

            result = await self.session.exec(statement)
            rows = result.all()
            playlists = [
                {
                    "id": row.id,
                    "name": row.name,
                    "description": row.description,
                    "genre": row.genre,
                    "user_id": row.user_id,
                    "user_name": row.user_name,
                    "is_main": row.is_main,
                    "is_current": row.is_current,
                    "is_deletable": row.is_deletable,
                    "created_at": row.created_at,
                    "updated_at": row.updated_at,
                    "sound_count": row.sound_count or 0,
                    "total_duration": row.total_duration or 0,
                }
                for row in rows
            ]
        except Exception:
            logger.exception(
                (
                    "Failed to search and sort playlists: "
                    "query=%s, sort_by=%s, sort_order=%s"
                ),
                search_query,
                sort_by,
                sort_order,
            )
            raise
        else:
            return playlists