diff --git a/app/repositories/base.py b/app/repositories/base.py index 0040b1b..2d2b1fb 100644 --- a/app/repositories/base.py +++ b/app/repositories/base.py @@ -38,7 +38,7 @@ class BaseRepository(Generic[ModelType]): """ try: - statement = select(self.model).where(getattr(self.model, "id") == entity_id) + statement = select(self.model).where(self.model.id == entity_id) result = await self.session.exec(statement) return result.first() except Exception: @@ -129,4 +129,4 @@ class BaseRepository(Generic[ModelType]): except Exception: await self.session.rollback() logger.exception("Failed to delete %s", self.model.__name__) - raise \ No newline at end of file + raise diff --git a/app/repositories/credit_transaction.py b/app/repositories/credit_transaction.py index ecad094..b2f53b4 100644 --- a/app/repositories/credit_transaction.py +++ b/app/repositories/credit_transaction.py @@ -94,15 +94,15 @@ class CreditTransactionRepository(BaseRepository[CreditTransaction]): select(CreditTransaction) .where(CreditTransaction.success == True) # noqa: E712 ) - + if user_id is not None: stmt = stmt.where(CreditTransaction.user_id == user_id) - + stmt = ( stmt.order_by(CreditTransaction.created_at.desc()) .limit(limit) .offset(offset) ) - + result = await self.session.exec(stmt) - return list(result.all()) \ No newline at end of file + return list(result.all()) diff --git a/app/repositories/extraction.py b/app/repositories/extraction.py index e15ca93..b8cb9ba 100644 --- a/app/repositories/extraction.py +++ b/app/repositories/extraction.py @@ -1,42 +1,29 @@ """Extraction repository for database operations.""" + from sqlalchemy import desc from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.models.extraction import Extraction +from app.repositories.base import BaseRepository -class ExtractionRepository: +class ExtractionRepository(BaseRepository[Extraction]): """Repository for extraction database operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize the extraction repository.""" - self.session = session - - async def create(self, extraction_data: dict) -> Extraction: - """Create a new extraction.""" - extraction = Extraction(**extraction_data) - self.session.add(extraction) - await self.session.commit() - await self.session.refresh(extraction) - return extraction - - async def get_by_id(self, extraction_id: int) -> Extraction | None: - """Get an extraction by ID.""" - result = await self.session.exec( - select(Extraction).where(Extraction.id == extraction_id) - ) - return result.first() + super().__init__(Extraction, session) async def get_by_service_and_id( - self, service: str, service_id: str + self, service: str, service_id: str, ) -> Extraction | None: """Get an extraction by service and service_id.""" result = await self.session.exec( select(Extraction).where( - Extraction.service == service, Extraction.service_id == service_id - ) + Extraction.service == service, Extraction.service_id == service_id, + ), ) return result.first() @@ -45,7 +32,7 @@ class ExtractionRepository: result = await self.session.exec( select(Extraction) .where(Extraction.user_id == user_id) - .order_by(desc(Extraction.created_at)) + .order_by(desc(Extraction.created_at)), ) return list(result.all()) @@ -54,29 +41,15 @@ class ExtractionRepository: result = await self.session.exec( select(Extraction) .where(Extraction.status == "pending") - .order_by(Extraction.created_at) + .order_by(Extraction.created_at), ) return list(result.all()) - async def update(self, extraction: Extraction, update_data: dict) -> Extraction: - """Update an extraction.""" - for key, value in update_data.items(): - setattr(extraction, key, value) - - await self.session.commit() - await self.session.refresh(extraction) - return extraction - - async def delete(self, extraction: Extraction) -> None: - """Delete an extraction.""" - await self.session.delete(extraction) - await self.session.commit() - async def get_extractions_by_status(self, status: str) -> list[Extraction]: """Get extractions by status.""" result = await self.session.exec( select(Extraction) .where(Extraction.status == status) - .order_by(desc(Extraction.created_at)) + .order_by(desc(Extraction.created_at)), ) return list(result.all()) diff --git a/app/repositories/playlist.py b/app/repositories/playlist.py index 4a46278..53a7f56 100644 --- a/app/repositories/playlist.py +++ b/app/repositories/playlist.py @@ -1,6 +1,5 @@ """Playlist repository for database operations.""" -from typing import Any from sqlalchemy import func from sqlmodel import select @@ -10,26 +9,17 @@ from app.core.logging import get_logger from app.models.playlist import Playlist from app.models.playlist_sound import PlaylistSound from app.models.sound import Sound +from app.repositories.base import BaseRepository logger = get_logger(__name__) -class PlaylistRepository: +class PlaylistRepository(BaseRepository[Playlist]): """Repository for playlist operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize the playlist repository.""" - self.session = session - - async def get_by_id(self, playlist_id: int) -> Playlist | None: - """Get a playlist by ID.""" - try: - statement = select(Playlist).where(Playlist.id == playlist_id) - result = await self.session.exec(statement) - return result.first() - except Exception: - logger.exception("Failed to get playlist by ID: %s", playlist_id) - raise + super().__init__(Playlist, session) async def get_by_name(self, name: str) -> Playlist | None: """Get a playlist by name.""" @@ -51,16 +41,6 @@ class PlaylistRepository: logger.exception("Failed to get playlists for user: %s", user_id) raise - async def get_all(self) -> list[Playlist]: - """Get all playlists from all users.""" - try: - statement = select(Playlist) - result = await self.session.exec(statement) - return list(result.all()) - except Exception: - logger.exception("Failed to get all playlists") - raise - async def get_main_playlist(self) -> Playlist | None: """Get the global main playlist.""" try: @@ -86,50 +66,8 @@ class PlaylistRepository: logger.exception("Failed to get current playlist for user: %s", user_id) raise - async def create(self, playlist_data: dict[str, Any]) -> Playlist: - """Create a new playlist.""" - try: - playlist = Playlist(**playlist_data) - self.session.add(playlist) - await self.session.commit() - await self.session.refresh(playlist) - except Exception: - await self.session.rollback() - logger.exception("Failed to create playlist") - raise - else: - logger.info("Created new playlist: %s", playlist.name) - return playlist - - async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist: - """Update a playlist.""" - try: - for field, value in update_data.items(): - setattr(playlist, field, value) - - await self.session.commit() - await self.session.refresh(playlist) - except Exception: - await self.session.rollback() - logger.exception("Failed to update playlist") - raise - else: - logger.info("Updated playlist: %s", playlist.name) - return playlist - - async def delete(self, playlist: Playlist) -> None: - """Delete a playlist.""" - try: - await self.session.delete(playlist) - await self.session.commit() - logger.info("Deleted playlist: %s", playlist.name) - except Exception: - await self.session.rollback() - logger.exception("Failed to delete playlist") - raise - async def search_by_name( - self, query: str, user_id: int | None = None + self, query: str, user_id: int | None = None, ) -> list[Playlist]: """Search playlists by name (case-insensitive).""" try: @@ -161,14 +99,14 @@ class PlaylistRepository: raise async def add_sound_to_playlist( - self, playlist_id: int, sound_id: int, position: int | None = None + self, playlist_id: int, sound_id: int, position: int | None = None, ) -> PlaylistSound: """Add a sound to a playlist.""" try: if position is None: # Get the next available position statement = select( - func.coalesce(func.max(PlaylistSound.position), -1) + 1 + func.coalesce(func.max(PlaylistSound.position), -1) + 1, ).where(PlaylistSound.playlist_id == playlist_id) result = await self.session.exec(statement) position = result.first() or 0 @@ -184,7 +122,7 @@ class PlaylistRepository: except Exception: await self.session.rollback() logger.exception( - "Failed to add sound %s to playlist %s", sound_id, playlist_id + "Failed to add sound %s to playlist %s", sound_id, playlist_id, ) raise else: @@ -213,18 +151,19 @@ class PlaylistRepository: except Exception: await self.session.rollback() logger.exception( - "Failed to remove sound %s from playlist %s", sound_id, playlist_id + "Failed to remove sound %s from playlist %s", sound_id, playlist_id, ) raise async def reorder_playlist_sounds( - self, playlist_id: int, sound_positions: list[tuple[int, int]] + self, playlist_id: int, sound_positions: list[tuple[int, int]], ) -> None: """Reorder sounds in a playlist. Args: playlist_id: The playlist ID sound_positions: List of (sound_id, new_position) tuples + """ try: for sound_id, new_position in sound_positions: @@ -249,7 +188,7 @@ class PlaylistRepository: """Get the number of sounds in a playlist.""" try: statement = select(func.count(PlaylistSound.id)).where( - PlaylistSound.playlist_id == playlist_id + PlaylistSound.playlist_id == playlist_id, ) result = await self.session.exec(statement) return result.first() or 0 @@ -268,6 +207,6 @@ class PlaylistRepository: return result.first() is not None except Exception: logger.exception( - "Failed to check if sound %s is in playlist %s", sound_id, playlist_id + "Failed to check if sound %s is in playlist %s", sound_id, playlist_id, ) raise diff --git a/app/repositories/sound.py b/app/repositories/sound.py index afce0ad..98ad061 100644 --- a/app/repositories/sound.py +++ b/app/repositories/sound.py @@ -1,33 +1,22 @@ """Sound repository for database operations.""" -from typing import Any - -from sqlalchemy import desc, func +from sqlalchemy import func from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.logging import get_logger from app.models.sound import Sound +from app.repositories.base import BaseRepository logger = get_logger(__name__) -class SoundRepository: +class SoundRepository(BaseRepository[Sound]): """Repository for sound operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize the sound repository.""" - self.session = session - - async def get_by_id(self, sound_id: int) -> Sound | None: - """Get a sound by ID.""" - try: - statement = select(Sound).where(Sound.id == sound_id) - result = await self.session.exec(statement) - return result.first() - except Exception: - logger.exception("Failed to get sound by ID: %s", sound_id) - raise + super().__init__(Sound, session) async def get_by_filename(self, filename: str) -> Sound | None: """Get a sound by filename.""" @@ -59,48 +48,6 @@ class SoundRepository: logger.exception("Failed to get sounds by type: %s", sound_type) raise - async def create(self, sound_data: dict[str, Any]) -> Sound: - """Create a new sound.""" - try: - sound = Sound(**sound_data) - self.session.add(sound) - await self.session.commit() - await self.session.refresh(sound) - except Exception: - await self.session.rollback() - logger.exception("Failed to create sound") - raise - else: - logger.info("Created new sound: %s", sound.name) - return sound - - async def update(self, sound: Sound, update_data: dict[str, Any]) -> Sound: - """Update a sound.""" - try: - for field, value in update_data.items(): - setattr(sound, field, value) - - await self.session.commit() - await self.session.refresh(sound) - except Exception: - await self.session.rollback() - logger.exception("Failed to update sound") - raise - else: - logger.info("Updated sound: %s", sound.name) - return sound - - async def delete(self, sound: Sound) -> None: - """Delete a sound.""" - try: - await self.session.delete(sound) - await self.session.commit() - logger.info("Deleted sound: %s", sound.name) - except Exception: - await self.session.rollback() - logger.exception("Failed to delete sound") - raise - async def search_by_name(self, query: str) -> list[Sound]: """Search sounds by name (case-insensitive).""" try: @@ -144,6 +91,6 @@ class SoundRepository: return list(result.all()) except Exception: logger.exception( - "Failed to get unnormalized sounds by type: %s", sound_type + "Failed to get unnormalized sounds by type: %s", sound_type, ) raise diff --git a/app/repositories/user.py b/app/repositories/user.py index bbcef80..1ac4bcb 100644 --- a/app/repositories/user.py +++ b/app/repositories/user.py @@ -8,26 +8,17 @@ from sqlmodel.ext.asyncio.session import AsyncSession from app.core.logging import get_logger from app.models.plan import Plan from app.models.user import User +from app.repositories.base import BaseRepository logger = get_logger(__name__) -class UserRepository: +class UserRepository(BaseRepository[User]): """Repository for user operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize the user repository.""" - self.session = session - - async def get_by_id(self, user_id: int) -> User | None: - """Get a user by ID.""" - try: - statement = select(User).where(User.id == user_id) - result = await self.session.exec(statement) - return result.first() - except Exception: - logger.exception("Failed to get user by ID: %s", user_id) - raise + super().__init__(User, session) async def get_by_email(self, email: str) -> User | None: """Get a user by email address.""" @@ -50,7 +41,7 @@ class UserRepository: raise async def create(self, user_data: dict[str, Any]) -> User: - """Create a new user.""" + """Create a new user with plan assignment and first user admin logic.""" def _raise_plan_not_found() -> None: msg = "Default plan not found" @@ -84,45 +75,11 @@ class UserRepository: user_data["plan_id"] = default_plan.id user_data["credits"] = default_plan.credits - user = User(**user_data) - self.session.add(user) - await self.session.commit() - await self.session.refresh(user) + # Use BaseRepository's create method + return await super().create(user_data) except Exception: - await self.session.rollback() logger.exception("Failed to create user") raise - else: - logger.info("Created new user with email: %s", user.email) - return user - - async def update(self, user: User, update_data: dict[str, Any]) -> User: - """Update a user.""" - try: - for field, value in update_data.items(): - setattr(user, field, value) - - await self.session.commit() - await self.session.refresh(user) - except Exception: - await self.session.rollback() - logger.exception("Failed to update user") - raise - else: - logger.info("Updated user: %s", user.email) - return user - - async def delete(self, user: User) -> None: - """Delete a user.""" - try: - await self.session.delete(user) - await self.session.commit() - - logger.info("Deleted user: %s", user.email) - except Exception: - await self.session.rollback() - logger.exception("Failed to delete user") - raise async def email_exists(self, email: str) -> bool: """Check if an email address is already registered.""" diff --git a/app/repositories/user_oauth.py b/app/repositories/user_oauth.py index 7bf76b6..bcc1d14 100644 --- a/app/repositories/user_oauth.py +++ b/app/repositories/user_oauth.py @@ -1,22 +1,22 @@ """Repository for user OAuth operations.""" -from typing import Any from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.logging import get_logger from app.models.user_oauth import UserOauth +from app.repositories.base import BaseRepository logger = get_logger(__name__) -class UserOauthRepository: +class UserOauthRepository(BaseRepository[UserOauth]): """Repository for user OAuth operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize repository with database session.""" - self.session = session + super().__init__(UserOauth, session) async def get_by_provider_user_id( self, @@ -61,57 +61,3 @@ class UserOauthRepository: else: return result.first() - async def create(self, oauth_data: dict[str, Any]) -> UserOauth: - """Create a new user OAuth record.""" - try: - oauth = UserOauth(**oauth_data) - self.session.add(oauth) - await self.session.commit() - await self.session.refresh(oauth) - logger.info( - "Created OAuth link for user %s with provider %s", - oauth.user_id, - oauth.provider, - ) - except Exception: - await self.session.rollback() - logger.exception("Failed to create user OAuth") - raise - else: - return oauth - - async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth: - """Update a user OAuth record.""" - try: - for key, value in update_data.items(): - setattr(oauth, key, value) - - self.session.add(oauth) - await self.session.commit() - await self.session.refresh(oauth) - logger.info( - "Updated OAuth link for user %s with provider %s", - oauth.user_id, - oauth.provider, - ) - except Exception: - await self.session.rollback() - logger.exception("Failed to update user OAuth") - raise - else: - return oauth - - async def delete(self, oauth: UserOauth) -> None: - """Delete a user OAuth record.""" - try: - await self.session.delete(oauth) - await self.session.commit() - logger.info( - "Deleted OAuth link for user %s with provider %s", - oauth.user_id, - oauth.provider, - ) - except Exception: - await self.session.rollback() - logger.exception("Failed to delete user OAuth") - raise diff --git a/tests/repositories/test_extraction.py b/tests/repositories/test_extraction.py index dffc609..dac3019 100644 --- a/tests/repositories/test_extraction.py +++ b/tests/repositories/test_extraction.py @@ -39,24 +39,24 @@ class TestExtractionRepository: } # Mock the session operations - mock_extraction = Extraction(**extraction_data, id=1) extraction_repo.session.add = Mock() extraction_repo.session.commit = AsyncMock() extraction_repo.session.refresh = AsyncMock() - # Mock the Extraction constructor to return our mock - with pytest.MonkeyPatch().context() as m: - m.setattr( - "app.repositories.extraction.Extraction", - lambda **kwargs: mock_extraction, - ) + result = await extraction_repo.create(extraction_data) - result = await extraction_repo.create(extraction_data) - - assert result == mock_extraction - extraction_repo.session.add.assert_called_once() - extraction_repo.session.commit.assert_called_once() - extraction_repo.session.refresh.assert_called_once_with(mock_extraction) + # Verify the result has the expected attributes + assert result.url == extraction_data["url"] + assert result.user_id == extraction_data["user_id"] + assert result.service == extraction_data["service"] + assert result.service_id == extraction_data["service_id"] + assert result.title == extraction_data["title"] + assert result.status == extraction_data["status"] + + # Verify session methods were called + extraction_repo.session.add.assert_called_once() + extraction_repo.session.commit.assert_called_once() + extraction_repo.session.refresh.assert_called_once() @pytest.mark.asyncio async def test_get_by_service_and_id(self, extraction_repo):