refactor: Simplify repository classes by inheriting from BaseRepository and removing redundant methods

This commit is contained in:
JSC
2025-07-31 21:32:46 +02:00
parent c63997f591
commit 3405d817d5
8 changed files with 55 additions and 293 deletions

View File

@@ -1,6 +1,5 @@
"""Playlist repository for database operations."""
from typing import Any
from sqlalchemy import func
from sqlmodel import select
@@ -10,26 +9,17 @@ from app.core.logging import get_logger
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class PlaylistRepository:
class PlaylistRepository(BaseRepository[Playlist]):
"""Repository for playlist operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the playlist repository."""
self.session = session
async def get_by_id(self, playlist_id: int) -> Playlist | None:
"""Get a playlist by ID."""
try:
statement = select(Playlist).where(Playlist.id == playlist_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get playlist by ID: %s", playlist_id)
raise
super().__init__(Playlist, session)
async def get_by_name(self, name: str) -> Playlist | None:
"""Get a playlist by name."""
@@ -51,16 +41,6 @@ class PlaylistRepository:
logger.exception("Failed to get playlists for user: %s", user_id)
raise
async def get_all(self) -> list[Playlist]:
"""Get all playlists from all users."""
try:
statement = select(Playlist)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all playlists")
raise
async def get_main_playlist(self) -> Playlist | None:
"""Get the global main playlist."""
try:
@@ -86,50 +66,8 @@ class PlaylistRepository:
logger.exception("Failed to get current playlist for user: %s", user_id)
raise
async def create(self, playlist_data: dict[str, Any]) -> Playlist:
"""Create a new playlist."""
try:
playlist = Playlist(**playlist_data)
self.session.add(playlist)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to create playlist")
raise
else:
logger.info("Created new playlist: %s", playlist.name)
return playlist
async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
"""Update a playlist."""
try:
for field, value in update_data.items():
setattr(playlist, field, value)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to update playlist")
raise
else:
logger.info("Updated playlist: %s", playlist.name)
return playlist
async def delete(self, playlist: Playlist) -> None:
"""Delete a playlist."""
try:
await self.session.delete(playlist)
await self.session.commit()
logger.info("Deleted playlist: %s", playlist.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete playlist")
raise
async def search_by_name(
self, query: str, user_id: int | None = None
self, query: str, user_id: int | None = None,
) -> list[Playlist]:
"""Search playlists by name (case-insensitive)."""
try:
@@ -161,14 +99,14 @@ class PlaylistRepository:
raise
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, position: int | None = None
self, playlist_id: int, sound_id: int, position: int | None = None,
) -> PlaylistSound:
"""Add a sound to a playlist."""
try:
if position is None:
# Get the next available position
statement = select(
func.coalesce(func.max(PlaylistSound.position), -1) + 1
func.coalesce(func.max(PlaylistSound.position), -1) + 1,
).where(PlaylistSound.playlist_id == playlist_id)
result = await self.session.exec(statement)
position = result.first() or 0
@@ -184,7 +122,7 @@ class PlaylistRepository:
except Exception:
await self.session.rollback()
logger.exception(
"Failed to add sound %s to playlist %s", sound_id, playlist_id
"Failed to add sound %s to playlist %s", sound_id, playlist_id,
)
raise
else:
@@ -213,18 +151,19 @@ class PlaylistRepository:
except Exception:
await self.session.rollback()
logger.exception(
"Failed to remove sound %s from playlist %s", sound_id, playlist_id
"Failed to remove sound %s from playlist %s", sound_id, playlist_id,
)
raise
async def reorder_playlist_sounds(
self, playlist_id: int, sound_positions: list[tuple[int, int]]
self, playlist_id: int, sound_positions: list[tuple[int, int]],
) -> None:
"""Reorder sounds in a playlist.
Args:
playlist_id: The playlist ID
sound_positions: List of (sound_id, new_position) tuples
"""
try:
for sound_id, new_position in sound_positions:
@@ -249,7 +188,7 @@ class PlaylistRepository:
"""Get the number of sounds in a playlist."""
try:
statement = select(func.count(PlaylistSound.id)).where(
PlaylistSound.playlist_id == playlist_id
PlaylistSound.playlist_id == playlist_id,
)
result = await self.session.exec(statement)
return result.first() or 0
@@ -268,6 +207,6 @@ class PlaylistRepository:
return result.first() is not None
except Exception:
logger.exception(
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id,
)
raise