From 13e0db1fe960fe741fd9a0b6c54483ae39e1487a Mon Sep 17 00:00:00 2001 From: JSC Date: Sun, 10 Aug 2025 21:33:06 +0200 Subject: [PATCH] feat: Add position shifting logic for adding sounds to playlists in repository --- app/repositories/playlist.py | 45 +++++++- tests/repositories/test_playlist.py | 152 ++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 1 deletion(-) diff --git a/app/repositories/playlist.py b/app/repositories/playlist.py index 2c1917f..071c0de 100644 --- a/app/repositories/playlist.py +++ b/app/repositories/playlist.py @@ -1,7 +1,7 @@ """Playlist repository for database operations.""" from enum import Enum -from sqlalchemy import func +from sqlalchemy import func, update from sqlalchemy.orm import selectinload from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -120,6 +120,20 @@ class PlaylistRepository(BaseRepository[Playlist]): logger.exception("Failed to get sounds for playlist: %s", playlist_id) raise + async def get_playlist_sound_entries(self, playlist_id: int) -> list[PlaylistSound]: + """Get all PlaylistSound entries for a playlist, ordered by position.""" + try: + statement = ( + select(PlaylistSound) + .where(PlaylistSound.playlist_id == playlist_id) + .order_by(PlaylistSound.position) + ) + result = await self.session.exec(statement) + return list(result.all()) + except Exception: + logger.exception("Failed to get playlist sound entries for playlist: %s", playlist_id) + raise + async def add_sound_to_playlist( self, playlist_id: int, @@ -135,6 +149,35 @@ class PlaylistRepository(BaseRepository[Playlist]): ).where(PlaylistSound.playlist_id == playlist_id) result = await self.session.exec(statement) position = result.first() or 0 + else: + # Shift existing positions to make room for the new sound + # Use a two-step approach to avoid unique constraint violations: + # 1. Move all affected positions to negative temporary positions + # 2. Then move them to their final positions + + # Step 1: Move to temporary negative positions + update_to_negative = ( + update(PlaylistSound) + .where( + PlaylistSound.playlist_id == playlist_id, + PlaylistSound.position >= position, + ) + .values(position=PlaylistSound.position - 10000) + ) + await self.session.exec(update_to_negative) + await self.session.commit() + + # Step 2: Move from temporary negative positions to final positions + update_to_final = ( + update(PlaylistSound) + .where( + PlaylistSound.playlist_id == playlist_id, + PlaylistSound.position < 0, + ) + .values(position=PlaylistSound.position + 10001) + ) + await self.session.exec(update_to_final) + await self.session.commit() playlist_sound = PlaylistSound( playlist_id=playlist_id, diff --git a/tests/repositories/test_playlist.py b/tests/repositories/test_playlist.py index e884df8..9a23870 100644 --- a/tests/repositories/test_playlist.py +++ b/tests/repositories/test_playlist.py @@ -480,6 +480,158 @@ class TestPlaylistRepository: assert playlist_sound.position == TEST_POSITION + @pytest.mark.asyncio + async def test_add_sound_to_playlist_with_position_shifting( + self, + playlist_repository: PlaylistRepository, + test_session: AsyncSession, + ensure_plans: Any, + ) -> None: + """Test adding a sound to a playlist with position shifting when positions are occupied.""" + # Create test user + user = User( + email="test_shifting@example.com", + name="Test User Shifting", + password_hash=PasswordUtils.hash_password("password123"), + role="user", + is_active=True, + plan_id=ensure_plans[0].id, + credits=100, + ) + test_session.add(user) + await test_session.commit() + await test_session.refresh(user) + user_id = user.id + + # Create test playlist + playlist = Playlist( + user_id=user_id, + name="Test Playlist Shifting", + description="A test playlist for position shifting", + genre="test", + is_main=False, + is_current=False, + is_deletable=True, + ) + test_session.add(playlist) + + # Create multiple sounds + sounds = [] + for i in range(3): + sound = Sound( + name=f"Test Sound {i}", + filename=f"test_{i}.mp3", + type="SDB", + duration=5000, + size=1024, + hash=f"test_hash_{i}", + play_count=0, + ) + test_session.add(sound) + sounds.append(sound) + + await test_session.commit() + await test_session.refresh(playlist) + for sound in sounds: + await test_session.refresh(sound) + + playlist_id = playlist.id + sound_ids = [s.id for s in sounds] + + # Add first two sounds sequentially (positions 0, 1) + await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0 + await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1 + + # Now insert third sound at position 1 - should shift existing sound at position 1 to position 2 + await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=1) + + # Verify the final positions + playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) + + assert len(playlist_sounds) == 3 + assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0 + assert playlist_sounds[0].position == 0 + assert playlist_sounds[1].sound_id == sound_ids[2] # New sound 2 inserted at position 1 + assert playlist_sounds[1].position == 1 + assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2 + assert playlist_sounds[2].position == 2 + + @pytest.mark.asyncio + async def test_add_sound_to_playlist_at_position_zero( + self, + playlist_repository: PlaylistRepository, + test_session: AsyncSession, + ensure_plans: Any, + ) -> None: + """Test adding a sound at position 0 when playlist already has sounds.""" + # Create test user + user = User( + email="test_position_zero@example.com", + name="Test User Position Zero", + password_hash=PasswordUtils.hash_password("password123"), + role="user", + is_active=True, + plan_id=ensure_plans[0].id, + credits=100, + ) + test_session.add(user) + await test_session.commit() + await test_session.refresh(user) + user_id = user.id + + # Create test playlist + playlist = Playlist( + user_id=user_id, + name="Test Playlist Position Zero", + description="A test playlist for position zero insertion", + genre="test", + is_main=False, + is_current=False, + is_deletable=True, + ) + test_session.add(playlist) + + # Create multiple sounds + sounds = [] + for i in range(3): + sound = Sound( + name=f"Test Sound {i}", + filename=f"test_zero_{i}.mp3", + type="SDB", + duration=5000, + size=1024, + hash=f"test_hash_zero_{i}", + play_count=0, + ) + test_session.add(sound) + sounds.append(sound) + + await test_session.commit() + await test_session.refresh(playlist) + for sound in sounds: + await test_session.refresh(sound) + + playlist_id = playlist.id + sound_ids = [s.id for s in sounds] + + # Add first two sounds sequentially (positions 0, 1) + await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0 + await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1 + + # Now insert third sound at position 0 - should shift existing sounds to positions 1, 2 + await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=0) + + # Verify the final positions + playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) + + assert len(playlist_sounds) == 3 + assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0 + assert playlist_sounds[0].position == 0 + assert playlist_sounds[1].sound_id == sound_ids[0] # Original sound 0 shifted to position 1 + assert playlist_sounds[1].position == 1 + assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2 + assert playlist_sounds[2].position == 2 + @pytest.mark.asyncio async def test_remove_sound_from_playlist( self,