515 lines
19 KiB
Python
515 lines
19 KiB
Python
"""Playlist repository for database operations."""
|
|
|
|
from datetime import UTC, datetime
|
|
from enum import Enum
|
|
|
|
from sqlalchemy import func, update
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.models.playlist import Playlist
|
|
from app.models.playlist_sound import PlaylistSound
|
|
from app.models.sound import Sound
|
|
from app.models.user import User
|
|
from app.repositories.base import BaseRepository
|
|
|
|
# Module-level logger used by the repository methods below for info and error reporting.
logger = get_logger(__name__)
|
|
|
|
|
|
class PlaylistSortField(str, Enum):
    """Keys by which a playlist listing can be ordered.

    Members are ``str`` instances, so each one compares equal to its raw
    string value.
    """

    # Values stored directly on the playlist record.
    NAME = "name"
    GENRE = "genre"
    CREATED_AT = "created_at"
    UPDATED_AT = "updated_at"

    # Values aggregated over the sounds a playlist contains.
    SOUND_COUNT = "sound_count"
    TOTAL_DURATION = "total_duration"
|
|
|
|
|
|
class SortOrder(str, Enum):
    """Direction in which sorted results are returned.

    Members are ``str`` instances, so each one compares equal to its raw
    string value.
    """

    ASC = "asc"  # ascending
    DESC = "desc"  # descending
|
|
|
|
|
|
class PlaylistRepository(BaseRepository[Playlist]):
    """Data-access layer for playlists and the sounds attached to them."""
|
|
|
|
def __init__(self, session: AsyncSession) -> None:
    """Create a repository for the ``Playlist`` model.

    Args:
        session: Async database session handed to the base repository
            together with the ``Playlist`` model class.

    """
    super().__init__(Playlist, session)
|
|
|
|
async def _update_playlist_timestamp(self, playlist_id: int) -> None:
    """Set a playlist's ``updated_at`` column to the current UTC time.

    The UPDATE is executed on the repository session but deliberately not
    committed, so the calling method stays in charge of the transaction.

    Args:
        playlist_id: ID of the playlist whose timestamp is refreshed.

    Raises:
        Exception: Any error raised while executing the statement is logged
            and re-raised unchanged.

    """
    try:
        now = datetime.now(UTC)
        touch = update(Playlist).where(Playlist.id == playlist_id)
        await self.session.exec(touch.values(updated_at=now))
    except Exception:
        logger.exception(
            "Failed to update playlist timestamp for playlist: %s", playlist_id,
        )
        raise
|
|
|
|
async def get_by_name(self, name: str) -> Playlist | None:
    """Look up a playlist whose name equals ``name``.

    Args:
        name: Value compared against ``Playlist.name``.

    Returns:
        The first matching playlist, or ``None`` when there is no match.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        matches = await self.session.exec(
            select(Playlist).where(Playlist.name == name),
        )
        playlist = matches.first()
    except Exception:
        logger.exception("Failed to get playlist by name: %s", name)
        raise
    else:
        return playlist
|
|
|
|
async def get_by_user_id(self, user_id: int) -> list[Playlist]:
    """Return every playlist owned by the given user.

    Args:
        user_id: Value compared against ``Playlist.user_id``.

    Returns:
        A list of the matching playlists; empty when the user has none.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        owned = await self.session.exec(
            select(Playlist).where(Playlist.user_id == user_id),
        )
        playlists = list(owned.all())
    except Exception:
        logger.exception("Failed to get playlists for user: %s", user_id)
        raise
    else:
        return playlists
|
|
|
|
async def get_main_playlist(self) -> Playlist | None:
    """Return the playlist flagged as the main one.

    Returns:
        The first playlist whose ``is_main`` flag is true, or ``None`` when
        no playlist carries the flag.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        # ``== True`` is intentional: it builds a SQL comparison, not a
        # Python identity check.
        flagged_main = Playlist.is_main == True  # noqa: E712
        found = await self.session.exec(select(Playlist).where(flagged_main))
        playlist = found.first()
    except Exception:
        logger.exception("Failed to get main playlist")
        raise
    else:
        return playlist
|
|
|
|
async def get_current_playlist(self) -> Playlist | None:
    """Return the playlist flagged as the current one (app-wide).

    Returns:
        The first playlist whose ``is_current`` flag is true, or ``None``
        when no playlist carries the flag.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        # ``== True`` is intentional: it builds a SQL comparison, not a
        # Python identity check.
        flagged_current = Playlist.is_current == True  # noqa: E712
        found = await self.session.exec(
            select(Playlist).where(flagged_current),
        )
        playlist = found.first()
    except Exception:
        logger.exception("Failed to get current playlist")
        raise
    else:
        return playlist
|
|
|
|
async def search_by_name(
    self,
    query: str,
    user_id: int | None = None,
) -> list[Playlist]:
    """Search playlists whose name contains ``query`` (case-insensitive).

    The text is matched literally: LIKE wildcard characters typed by the
    caller are escaped instead of being interpreted as wildcards.

    Args:
        query: Substring to look for in the playlist name. An empty string
            matches every playlist.
        user_id: When given, only playlists owned by this user are returned.

    Returns:
        A list of the matching playlists; empty when nothing matches.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        # Escape the LIKE metacharacters so user text such as "100%" or
        # "a_b" is compared verbatim. The escape character itself must be
        # handled first, otherwise the escapes added below would be doubled.
        needle = (
            query.lower()
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        statement = select(Playlist).where(
            func.lower(Playlist.name).like(f"%{needle}%", escape="\\"),
        )
        if user_id is not None:
            statement = statement.where(Playlist.user_id == user_id)

        result = await self.session.exec(statement)
        return list(result.all())
    except Exception:
        logger.exception("Failed to search playlists by name: %s", query)
        raise
|
|
|
|
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
    """Fetch the sounds of a playlist in playlist order.

    Each sound's ``extractions`` relationship is eagerly loaded with
    ``selectinload``.

    Args:
        playlist_id: ID of the playlist whose sounds are fetched.

    Returns:
        The playlist's sounds ordered by ``PlaylistSound.position``.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        in_playlist = PlaylistSound.playlist_id == playlist_id
        statement = (
            select(Sound)
            .join(PlaylistSound)
            .where(in_playlist)
            .order_by(PlaylistSound.position)
            .options(selectinload(Sound.extractions))
        )
        fetched = await self.session.exec(statement)
        sounds = list(fetched.all())
    except Exception:
        logger.exception("Failed to get sounds for playlist: %s", playlist_id)
        raise
    else:
        return sounds
|
|
|
|
async def get_playlist_sound_entries(self, playlist_id: int) -> list[PlaylistSound]:
    """Fetch the ``PlaylistSound`` link rows of a playlist.

    Args:
        playlist_id: ID of the playlist whose link rows are fetched.

    Returns:
        The playlist's ``PlaylistSound`` rows ordered by ``position``.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        in_playlist = PlaylistSound.playlist_id == playlist_id
        fetched = await self.session.exec(
            select(PlaylistSound)
            .where(in_playlist)
            .order_by(PlaylistSound.position),
        )
        entries = list(fetched.all())
    except Exception:
        logger.exception(
            "Failed to get playlist sound entries for playlist: %s",
            playlist_id,
        )
        raise
    else:
        return entries
|
|
|
|
async def add_sound_to_playlist(
    self,
    playlist_id: int,
    sound_id: int,
    position: int | None = None,
) -> PlaylistSound:
    """Add a sound to a playlist, optionally at a specific position.

    The position shift, the insert and the playlist timestamp update all run
    in one transaction that is committed once at the end. If any step
    fails, the rollback restores the playlist exactly as it was.

    Args:
        playlist_id: ID of the playlist receiving the sound.
        sound_id: ID of the sound to add.
        position: Slot for the new entry. ``None`` appends after the current
            highest position (0 for an empty playlist). Otherwise every
            entry at or after ``position`` is moved one slot further first.

    Returns:
        The newly created ``PlaylistSound`` row, refreshed after the commit.

    Raises:
        Exception: Any database error triggers a rollback, is logged and is
            re-raised unchanged.

    """
    try:
        if position is None:
            # Append: one past the highest position, or 0 when empty.
            next_position_stmt = select(
                func.coalesce(func.max(PlaylistSound.position), -1) + 1,
            ).where(PlaylistSound.playlist_id == playlist_id)
            result = await self.session.exec(next_position_stmt)
            position = result.first() or 0
        else:
            # Make room for the new entry. Shifting rows by +1 in a single
            # UPDATE can collide with a neighbour that has not moved yet
            # when positions are unique, so the rows are parked on negative
            # slots first and then moved to their final place.
            #
            # Step 1: park p -> -(p + 1). The result is always negative and
            # distinct per row, regardless of how large p is.
            park_stmt = (
                update(PlaylistSound)
                .where(
                    PlaylistSound.playlist_id == playlist_id,
                    PlaylistSound.position >= position,
                )
                .values(position=-1 - PlaylistSound.position)
            )
            await self.session.exec(park_stmt)

            # Step 2: un-park -(p + 1) -> p + 1. Both statements run in the
            # same transaction, so no intermediate commit is needed and a
            # later failure rolls the shift back as well.
            unpark_stmt = (
                update(PlaylistSound)
                .where(
                    PlaylistSound.playlist_id == playlist_id,
                    PlaylistSound.position < 0,
                )
                .values(position=0 - PlaylistSound.position)
            )
            await self.session.exec(unpark_stmt)

        playlist_sound = PlaylistSound(
            playlist_id=playlist_id,
            sound_id=sound_id,
            position=position,
        )
        self.session.add(playlist_sound)

        # Update playlist timestamp before the single commit.
        await self._update_playlist_timestamp(playlist_id)
        await self.session.commit()
        await self.session.refresh(playlist_sound)
    except Exception:
        await self.session.rollback()
        logger.exception(
            "Failed to add sound %s to playlist %s",
            sound_id,
            playlist_id,
        )
        raise
    else:
        logger.info(
            "Added sound %s to playlist %s at position %s",
            sound_id,
            playlist_id,
            position,
        )
        return playlist_sound
|
|
|
|
async def remove_sound_from_playlist(self, playlist_id: int, sound_id: int) -> None:
    """Detach a sound from a playlist.

    The first link row matching the pair is deleted when one exists. The
    playlist timestamp is refreshed and the transaction committed in either
    case, and the success message is logged even if no row was found.

    Args:
        playlist_id: ID of the playlist to modify.
        sound_id: ID of the sound to detach.

    Raises:
        Exception: Any database error triggers a rollback, is logged and is
            re-raised unchanged.

    """
    try:
        lookup = select(PlaylistSound).where(
            PlaylistSound.playlist_id == playlist_id,
            PlaylistSound.sound_id == sound_id,
        )
        entry = (await self.session.exec(lookup)).first()
        if entry:
            await self.session.delete(entry)

        # The timestamp update joins the same transaction as the delete.
        await self._update_playlist_timestamp(playlist_id)
        await self.session.commit()
        logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)
    except Exception:
        await self.session.rollback()
        logger.exception(
            "Failed to remove sound %s from playlist %s",
            sound_id,
            playlist_id,
        )
        raise
|
|
|
|
async def reorder_playlist_sounds(
    self,
    playlist_id: int,
    sound_positions: list[tuple[int, int]],
) -> None:
    """Reorder sounds in a playlist.

    The playlist's entries are loaded with a single query. Entries named in
    ``sound_positions`` are first parked on temporary negative positions and
    flushed, then given their final positions, so two entries never share a
    position while they are being swapped. Sound IDs that are not in the
    playlist are ignored. Everything is committed once at the end.

    Args:
        playlist_id: The playlist ID
        sound_positions: List of (sound_id, new_position) tuples

    Raises:
        Exception: Any database error triggers a rollback, is logged and is
            re-raised unchanged.

    """
    try:
        # One query for the whole playlist instead of two per sound.
        fetched = await self.session.exec(
            select(PlaylistSound)
            .where(PlaylistSound.playlist_id == playlist_id)
            .order_by(PlaylistSound.position),
        )
        entry_by_sound: dict[int, PlaylistSound] = {}
        for entry in fetched.all():
            # Keep the first row per sound, mirroring a ``.first()`` lookup.
            entry_by_sound.setdefault(entry.sound_id, entry)

        # Phase 1: park every affected entry on a unique negative slot.
        # -(index + 1) stays negative however long the list is.
        for index, (sound_id, _) in enumerate(sound_positions):
            entry = entry_by_sound.get(sound_id)
            if entry is not None:
                entry.position = -(index + 1)

        # Write the parked positions now rather than relying on an implicit
        # autoflush; otherwise phase 2 would just overwrite them in memory.
        await self.session.flush()

        # Phase 2: assign the final positions.
        for sound_id, new_position in sound_positions:
            entry = entry_by_sound.get(sound_id)
            if entry is not None:
                entry.position = new_position

        # Update playlist timestamp before commit
        await self._update_playlist_timestamp(playlist_id)
        await self.session.commit()
        logger.info("Reordered sounds in playlist %s", playlist_id)
    except Exception:
        await self.session.rollback()
        logger.exception("Failed to reorder sounds in playlist %s", playlist_id)
        raise
|
|
|
|
async def get_playlist_sound_count(self, playlist_id: int) -> int:
    """Count the link rows that attach sounds to a playlist.

    Args:
        playlist_id: ID of the playlist to count sounds for.

    Returns:
        The number of ``PlaylistSound`` rows for the playlist; 0 when the
        query yields no value.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        counting = select(func.count(PlaylistSound.id)).where(
            PlaylistSound.playlist_id == playlist_id,
        )
        outcome = await self.session.exec(counting)
        total = outcome.first() or 0
    except Exception:
        logger.exception("Failed to get sound count for playlist: %s", playlist_id)
        raise
    else:
        return total
|
|
|
|
async def is_sound_in_playlist(self, playlist_id: int, sound_id: int) -> bool:
    """Tell whether a sound is already attached to a playlist.

    Args:
        playlist_id: ID of the playlist to inspect.
        sound_id: ID of the sound to look for.

    Returns:
        ``True`` when a ``PlaylistSound`` row links the pair, else ``False``.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        lookup = select(PlaylistSound).where(
            PlaylistSound.playlist_id == playlist_id,
            PlaylistSound.sound_id == sound_id,
        )
        entry = (await self.session.exec(lookup)).first()
        present = entry is not None
    except Exception:
        logger.exception(
            "Failed to check if sound %s is in playlist %s",
            sound_id,
            playlist_id,
        )
        raise
    else:
        return present
|
|
|
|
async def search_and_sort(  # noqa: C901, PLR0913
    self,
    search_query: str | None = None,
    sort_by: PlaylistSortField | None = None,
    sort_order: SortOrder = SortOrder.ASC,
    user_id: int | None = None,
    include_stats: bool = False,  # noqa: FBT001, FBT002
    limit: int | None = None,
    offset: int = 0,
) -> list[dict]:
    """Search and sort playlists, returning owner name and sound statistics.

    Every returned row carries ``sound_count`` and ``total_duration``;
    ``include_stats`` only decides whether those two values may be used as
    the sort key.

    Args:
        search_query: Case-insensitive substring to look for in the playlist
            name. ``None`` or a blank string disables the filter. LIKE
            wildcard characters in the text are matched literally.
        sort_by: Field to order by. ``SOUND_COUNT`` and ``TOTAL_DURATION``
            are honoured only when ``include_stats`` is true; otherwise they
            fall back to ordering by name. ``None`` orders by name
            ascending.
        sort_order: Direction applied to ``sort_by``; ignored when
            ``sort_by`` is ``None``.
        user_id: When given, only playlists owned by this user are returned.
        include_stats: Allow ordering by the aggregated statistics.
        limit: Maximum number of rows to return; ``None`` means no limit.
        offset: Number of rows to skip; values of zero or less are ignored.

    Returns:
        One dictionary per playlist with the keys ``id``, ``name``,
        ``description``, ``genre``, ``user_id``, ``user_name``, ``is_main``,
        ``is_current``, ``is_deletable``, ``created_at``, ``updated_at``,
        ``sound_count`` and ``total_duration``.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.

    """
    try:
        # Build the aggregates once so SELECT and ORDER BY share them.
        sound_count = func.count(PlaylistSound.id)
        total_duration = func.coalesce(func.sum(Sound.duration), 0)

        # Outer joins keep playlists that have no owner or no sounds.
        statement = (
            select(
                Playlist.id,
                Playlist.name,
                Playlist.description,
                Playlist.genre,
                Playlist.user_id,
                Playlist.is_main,
                Playlist.is_current,
                Playlist.is_deletable,
                Playlist.created_at,
                Playlist.updated_at,
                sound_count.label("sound_count"),
                total_duration.label("total_duration"),
                User.name.label("user_name"),
            )
            .select_from(Playlist)
            .join(User, Playlist.user_id == User.id, isouter=True)
            .join(
                PlaylistSound,
                Playlist.id == PlaylistSound.playlist_id,
                isouter=True,
            )
            .join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
            .group_by(Playlist.id, User.name)
        )

        # Apply filters
        needle = search_query.strip().lower() if search_query else ""
        if needle:
            # Escape LIKE metacharacters (escape character first) so the
            # user's text is compared verbatim.
            escaped = (
                needle.replace("\\", "\\\\")
                .replace("%", "\\%")
                .replace("_", "\\_")
            )
            statement = statement.where(
                func.lower(Playlist.name).like(f"%{escaped}%", escape="\\"),
            )

        if user_id is not None:
            statement = statement.where(Playlist.user_id == user_id)

        # Resolve the sort column
        stat_columns = {
            PlaylistSortField.SOUND_COUNT: sound_count,
            PlaylistSortField.TOTAL_DURATION: total_duration,
        }
        plain_columns = {
            PlaylistSortField.NAME: Playlist.name,
            PlaylistSortField.GENRE: Playlist.genre,
            PlaylistSortField.CREATED_AT: Playlist.created_at,
            PlaylistSortField.UPDATED_AT: Playlist.updated_at,
        }
        if include_stats and sort_by in stat_columns:
            sort_column = stat_columns[sort_by]
        elif sort_by:
            # Stat fields requested without include_stats fall back to name.
            sort_column = plain_columns.get(sort_by, Playlist.name)
        else:
            sort_column = Playlist.name

        # Without an explicit sort field the order is always ascending.
        descending = bool(sort_by) and sort_order == SortOrder.DESC
        primary = sort_column.desc() if descending else sort_column.asc()
        # The ID tie-breaker keeps limit/offset pages stable when several
        # playlists share the same sort value.
        statement = statement.order_by(primary, Playlist.id.asc())

        # Apply pagination
        if offset > 0:
            statement = statement.offset(offset)
        if limit is not None:
            statement = statement.limit(limit)

        result = await self.session.exec(statement)
        rows = result.all()

        # Convert to dictionary format
        playlists = [
            {
                "id": row.id,
                "name": row.name,
                "description": row.description,
                "genre": row.genre,
                "user_id": row.user_id,
                "user_name": row.user_name,
                "is_main": row.is_main,
                "is_current": row.is_current,
                "is_deletable": row.is_deletable,
                "created_at": row.created_at,
                "updated_at": row.updated_at,
                "sound_count": row.sound_count or 0,
                "total_duration": row.total_duration or 0,
            }
            for row in rows
        ]
    except Exception:
        logger.exception(
            (
                "Failed to search and sort playlists: "
                "query=%s, sort_by=%s, sort_order=%s"
            ),
            search_query,
            sort_by,
            sort_order,
        )
        raise
    else:
        return playlists
|