- Introduced a new test suite for the PlaylistService covering various functionalities, including creation, retrieval, updating, and deletion of playlists.
- Added tests for handling sounds within playlists, ensuring correct behavior when adding/removing sounds and managing current playlists.
- Refactored socket service tests for improved readability by adjusting function signatures.
- Cleaned up unnecessary whitespace in sound normalizer and sound scanner tests for consistency.
- Enhanced audio utility tests to ensure accurate hash and size calculations, including edge cases for nonexistent files.
- Removed redundant blank lines in cookie utility tests for cleaner code.
274 lines
10 KiB
Python
274 lines
10 KiB
Python
"""Playlist repository for database operations."""
|
|
|
|
from typing import Any
|
|
|
|
from sqlalchemy import func
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.models.playlist import Playlist
|
|
from app.models.playlist_sound import PlaylistSound
|
|
from app.models.sound import Sound
|
|
|
|
# Module-level logger used by PlaylistRepository for info and exception logging.
logger = get_logger(__name__)
|
|
|
|
|
|
class PlaylistRepository:
    """Repository for playlist operations.

    Wraps an ``AsyncSession`` and provides query/CRUD helpers for
    ``Playlist`` rows and for the ``PlaylistSound`` association that links
    sounds to playlists at a given position.

    Every method logs failures with ``logger.exception`` and re-raises the
    original exception. Methods that write (create/update/delete/add/remove/
    reorder) roll the session back before re-raising.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the playlist repository.

        Args:
            session: Async database session used for every query and commit.
        """
        self.session = session

    @staticmethod
    def _escape_like(value: str, escape: str = "/") -> str:
        """Escape SQL LIKE wildcards in *value* so they match literally.

        The escape character itself is doubled first (so it cannot combine
        with a following character), then ``%`` and ``_`` are prefixed with
        it. The result must be used together with ``.like(..., escape=escape)``.
        "/" is the default rather than a backslash to sidestep dialects that
        treat backslashes specially inside string literals.

        Args:
            value: Raw text to embed in a LIKE pattern.
            escape: Single escape character to use.

        Returns:
            The escaped text.
        """
        return (
            value.replace(escape, escape + escape)
            .replace("%", escape + "%")
            .replace("_", escape + "_")
        )

    async def get_by_id(self, playlist_id: int) -> Playlist | None:
        """Get a playlist by ID.

        Args:
            playlist_id: Primary key of the playlist.

        Returns:
            The matching playlist, or ``None`` if no row matches.

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(Playlist).where(Playlist.id == playlist_id)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get playlist by ID: %s", playlist_id)
            raise

    async def get_by_name(self, name: str) -> Playlist | None:
        """Get a playlist by exact name.

        Args:
            name: Playlist name to match exactly.

        Returns:
            The first matching playlist, or ``None`` if none matches.

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(Playlist).where(Playlist.name == name)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get playlist by name: %s", name)
            raise

    async def get_by_user_id(self, user_id: int) -> list[Playlist]:
        """Get all playlists for a user.

        Args:
            user_id: Owner whose playlists are returned.

        Returns:
            The user's playlists (possibly empty; order not specified).

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(Playlist).where(Playlist.user_id == user_id)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get playlists for user: %s", user_id)
            raise

    async def get_all(self) -> list[Playlist]:
        """Get all playlists from all users.

        Returns:
            Every playlist row (possibly empty; order not specified).

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(Playlist)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all playlists")
            raise

    async def get_main_playlist(self) -> Playlist | None:
        """Get the global main playlist (the row with ``is_main`` set).

        Returns:
            The first playlist flagged ``is_main``, or ``None``.

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(Playlist).where(
                Playlist.is_main == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get main playlist")
            raise

    async def get_current_playlist(self, user_id: int) -> Playlist | None:
        """Get the user's current playlist (the row with ``is_current`` set).

        Args:
            user_id: Owner of the playlist.

        Returns:
            The first of the user's playlists flagged ``is_current``, or ``None``.

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(Playlist).where(
                Playlist.user_id == user_id,
                Playlist.is_current == True,  # noqa: E712
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get current playlist for user: %s", user_id)
            raise

    async def create(self, playlist_data: dict[str, Any]) -> Playlist:
        """Create a new playlist, commit it, and return the refreshed row.

        Args:
            playlist_data: Keyword arguments passed to the ``Playlist``
                constructor.

        Returns:
            The persisted playlist, refreshed from the database.

        Raises:
            Exception: Any construction or database error; the session is
                rolled back and the error is logged first.
        """
        try:
            playlist = Playlist(**playlist_data)
            self.session.add(playlist)
            await self.session.commit()
            await self.session.refresh(playlist)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to create playlist")
            raise
        else:
            logger.info("Created new playlist: %s", playlist.name)
            return playlist

    async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
        """Update a playlist's attributes, commit, and return the refreshed row.

        Args:
            playlist: Playlist instance to modify.
            update_data: Mapping of attribute name to new value; each pair is
                applied with ``setattr``.

        Returns:
            The same playlist instance, refreshed after commit.

        Raises:
            Exception: Any attribute or database error; the session is rolled
                back and the error is logged first.
        """
        try:
            for field, value in update_data.items():
                setattr(playlist, field, value)

            await self.session.commit()
            await self.session.refresh(playlist)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to update playlist")
            raise
        else:
            logger.info("Updated playlist: %s", playlist.name)
            return playlist

    async def delete(self, playlist: Playlist) -> None:
        """Delete a playlist and commit.

        Args:
            playlist: Playlist instance to delete.

        Raises:
            Exception: Any database error; the session is rolled back and the
                error is logged first.
        """
        try:
            await self.session.delete(playlist)
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to delete playlist")
            raise
        else:
            # Logged only after a successful commit, matching create/update.
            logger.info("Deleted playlist: %s", playlist.name)

    async def search_by_name(
        self, query: str, user_id: int | None = None
    ) -> list[Playlist]:
        """Search playlists by name (case-insensitive substring match).

        ``%`` and ``_`` in *query* are matched literally instead of acting as
        SQL LIKE wildcards.

        Args:
            query: Substring to look for in playlist names.
            user_id: If given, restrict results to this user's playlists.

        Returns:
            Matching playlists (possibly empty; order not specified).

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            # Escape user-supplied wildcards so e.g. "50%" does not match
            # every name starting with "50".
            pattern = f"%{self._escape_like(query.lower())}%"
            statement = select(Playlist).where(
                func.lower(Playlist.name).like(pattern, escape="/"),
            )
            if user_id is not None:
                statement = statement.where(Playlist.user_id == user_id)

            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to search playlists by name: %s", query)
            raise

    async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
        """Get all sounds in a playlist, ordered by position.

        Args:
            playlist_id: Playlist whose sounds are returned.

        Returns:
            Sounds joined through ``PlaylistSound``, ascending by position.

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = (
                select(Sound)
                .join(PlaylistSound)
                .where(PlaylistSound.playlist_id == playlist_id)
                .order_by(PlaylistSound.position)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get sounds for playlist: %s", playlist_id)
            raise

    async def add_sound_to_playlist(
        self, playlist_id: int, sound_id: int, position: int | None = None
    ) -> PlaylistSound:
        """Add a sound to a playlist and commit.

        Existing entries are not shifted. When *position* is omitted, the
        sound is appended after the current highest position (0 for an empty
        playlist).

        Args:
            playlist_id: Target playlist.
            sound_id: Sound to add.
            position: Explicit position, or ``None`` to append.

        Returns:
            The persisted ``PlaylistSound`` association, refreshed.

        Raises:
            Exception: Any database error; the session is rolled back and the
                error is logged first.
        """
        try:
            if position is None:
                # MAX(position) + 1, or -1 + 1 == 0 when the playlist is empty.
                statement = select(
                    func.coalesce(func.max(PlaylistSound.position), -1) + 1
                ).where(PlaylistSound.playlist_id == playlist_id)
                result = await self.session.exec(statement)
                position = result.first() or 0

            playlist_sound = PlaylistSound(
                playlist_id=playlist_id,
                sound_id=sound_id,
                position=position,
            )
            self.session.add(playlist_sound)
            await self.session.commit()
            await self.session.refresh(playlist_sound)
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to add sound %s to playlist %s", sound_id, playlist_id
            )
            raise
        else:
            logger.info(
                "Added sound %s to playlist %s at position %s",
                sound_id,
                playlist_id,
                position,
            )
            return playlist_sound

    async def remove_sound_from_playlist(self, playlist_id: int, sound_id: int) -> None:
        """Remove a sound from a playlist.

        Only the first matching association row is deleted. If the sound is
        not in the playlist, nothing is changed and nothing is committed.
        Positions of remaining entries are not renumbered.

        Args:
            playlist_id: Playlist to remove from.
            sound_id: Sound to remove.

        Raises:
            Exception: Any database error; the session is rolled back and the
                error is logged first.
        """
        try:
            statement = select(PlaylistSound).where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.sound_id == sound_id,
            )
            result = await self.session.exec(statement)
            playlist_sound = result.first()

            if playlist_sound:
                await self.session.delete(playlist_sound)
                await self.session.commit()
                # Only report removal when a row was actually deleted.
                logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)
        except Exception:
            await self.session.rollback()
            logger.exception(
                "Failed to remove sound %s from playlist %s", sound_id, playlist_id
            )
            raise

    async def reorder_playlist_sounds(
        self, playlist_id: int, sound_positions: list[tuple[int, int]]
    ) -> None:
        """Reorder sounds in a playlist.

        All position changes are committed together at the end. Sound IDs
        that are not in the playlist are silently skipped.

        Args:
            playlist_id: The playlist ID
            sound_positions: List of (sound_id, new_position) tuples

        Raises:
            Exception: Any database error; the session is rolled back and the
                error is logged first.
        """
        try:
            for sound_id, new_position in sound_positions:
                statement = select(PlaylistSound).where(
                    PlaylistSound.playlist_id == playlist_id,
                    PlaylistSound.sound_id == sound_id,
                )
                result = await self.session.exec(statement)
                playlist_sound = result.first()

                if playlist_sound:
                    playlist_sound.position = new_position

            await self.session.commit()
            logger.info("Reordered sounds in playlist %s", playlist_id)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to reorder sounds in playlist %s", playlist_id)
            raise

    async def get_playlist_sound_count(self, playlist_id: int) -> int:
        """Get the number of sounds in a playlist.

        Args:
            playlist_id: Playlist to count entries for.

        Returns:
            Number of ``PlaylistSound`` rows for the playlist (0 if none).

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(func.count(PlaylistSound.id)).where(
                PlaylistSound.playlist_id == playlist_id
            )
            result = await self.session.exec(statement)
            return result.first() or 0
        except Exception:
            logger.exception("Failed to get sound count for playlist: %s", playlist_id)
            raise

    async def is_sound_in_playlist(self, playlist_id: int, sound_id: int) -> bool:
        """Check if a sound is already in a playlist.

        Args:
            playlist_id: Playlist to check.
            sound_id: Sound to look for.

        Returns:
            ``True`` if at least one association row exists.

        Raises:
            Exception: Any database error, after it has been logged.
        """
        try:
            statement = select(PlaylistSound).where(
                PlaylistSound.playlist_id == playlist_id,
                PlaylistSound.sound_id == sound_id,
            )
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception(
                "Failed to check if sound %s is in playlist %s", sound_id, playlist_id
            )
            raise
|