refactor: Simplify repository classes by inheriting from BaseRepository and removing redundant methods

This commit is contained in:
JSC
2025-07-31 21:32:46 +02:00
parent c63997f591
commit 3405d817d5
8 changed files with 55 additions and 293 deletions

View File

@@ -38,7 +38,7 @@ class BaseRepository(Generic[ModelType]):
"""
try:
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
statement = select(self.model).where(self.model.id == entity_id)
result = await self.session.exec(statement)
return result.first()
except Exception:

View File

@@ -1,42 +1,29 @@
"""Extraction repository for database operations."""
from sqlalchemy import desc
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction
from app.repositories.base import BaseRepository
class ExtractionRepository:
class ExtractionRepository(BaseRepository[Extraction]):
"""Repository for extraction database operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the extraction repository."""
self.session = session
async def create(self, extraction_data: dict) -> Extraction:
"""Create a new extraction."""
extraction = Extraction(**extraction_data)
self.session.add(extraction)
await self.session.commit()
await self.session.refresh(extraction)
return extraction
async def get_by_id(self, extraction_id: int) -> Extraction | None:
"""Get an extraction by ID."""
result = await self.session.exec(
select(Extraction).where(Extraction.id == extraction_id)
)
return result.first()
super().__init__(Extraction, session)
async def get_by_service_and_id(
self, service: str, service_id: str
self, service: str, service_id: str,
) -> Extraction | None:
"""Get an extraction by service and service_id."""
result = await self.session.exec(
select(Extraction).where(
Extraction.service == service, Extraction.service_id == service_id
)
Extraction.service == service, Extraction.service_id == service_id,
),
)
return result.first()
@@ -45,7 +32,7 @@ class ExtractionRepository:
result = await self.session.exec(
select(Extraction)
.where(Extraction.user_id == user_id)
.order_by(desc(Extraction.created_at))
.order_by(desc(Extraction.created_at)),
)
return list(result.all())
@@ -54,29 +41,15 @@ class ExtractionRepository:
result = await self.session.exec(
select(Extraction)
.where(Extraction.status == "pending")
.order_by(Extraction.created_at)
.order_by(Extraction.created_at),
)
return list(result.all())
async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
"""Update an extraction."""
for key, value in update_data.items():
setattr(extraction, key, value)
await self.session.commit()
await self.session.refresh(extraction)
return extraction
async def delete(self, extraction: Extraction) -> None:
"""Delete an extraction."""
await self.session.delete(extraction)
await self.session.commit()
async def get_extractions_by_status(self, status: str) -> list[Extraction]:
"""Get extractions by status."""
result = await self.session.exec(
select(Extraction)
.where(Extraction.status == status)
.order_by(desc(Extraction.created_at))
.order_by(desc(Extraction.created_at)),
)
return list(result.all())

View File

@@ -1,6 +1,5 @@
"""Playlist repository for database operations."""
from typing import Any
from sqlalchemy import func
from sqlmodel import select
@@ -10,26 +9,17 @@ from app.core.logging import get_logger
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class PlaylistRepository:
class PlaylistRepository(BaseRepository[Playlist]):
"""Repository for playlist operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the playlist repository."""
self.session = session
async def get_by_id(self, playlist_id: int) -> Playlist | None:
"""Get a playlist by ID."""
try:
statement = select(Playlist).where(Playlist.id == playlist_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get playlist by ID: %s", playlist_id)
raise
super().__init__(Playlist, session)
async def get_by_name(self, name: str) -> Playlist | None:
"""Get a playlist by name."""
@@ -51,16 +41,6 @@ class PlaylistRepository:
logger.exception("Failed to get playlists for user: %s", user_id)
raise
async def get_all(self) -> list[Playlist]:
"""Get all playlists from all users."""
try:
statement = select(Playlist)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all playlists")
raise
async def get_main_playlist(self) -> Playlist | None:
"""Get the global main playlist."""
try:
@@ -86,50 +66,8 @@ class PlaylistRepository:
logger.exception("Failed to get current playlist for user: %s", user_id)
raise
async def create(self, playlist_data: dict[str, Any]) -> Playlist:
"""Create a new playlist."""
try:
playlist = Playlist(**playlist_data)
self.session.add(playlist)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to create playlist")
raise
else:
logger.info("Created new playlist: %s", playlist.name)
return playlist
async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
"""Update a playlist."""
try:
for field, value in update_data.items():
setattr(playlist, field, value)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to update playlist")
raise
else:
logger.info("Updated playlist: %s", playlist.name)
return playlist
async def delete(self, playlist: Playlist) -> None:
"""Delete a playlist."""
try:
await self.session.delete(playlist)
await self.session.commit()
logger.info("Deleted playlist: %s", playlist.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete playlist")
raise
async def search_by_name(
self, query: str, user_id: int | None = None
self, query: str, user_id: int | None = None,
) -> list[Playlist]:
"""Search playlists by name (case-insensitive)."""
try:
@@ -161,14 +99,14 @@ class PlaylistRepository:
raise
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, position: int | None = None
self, playlist_id: int, sound_id: int, position: int | None = None,
) -> PlaylistSound:
"""Add a sound to a playlist."""
try:
if position is None:
# Get the next available position
statement = select(
func.coalesce(func.max(PlaylistSound.position), -1) + 1
func.coalesce(func.max(PlaylistSound.position), -1) + 1,
).where(PlaylistSound.playlist_id == playlist_id)
result = await self.session.exec(statement)
position = result.first() or 0
@@ -184,7 +122,7 @@ class PlaylistRepository:
except Exception:
await self.session.rollback()
logger.exception(
"Failed to add sound %s to playlist %s", sound_id, playlist_id
"Failed to add sound %s to playlist %s", sound_id, playlist_id,
)
raise
else:
@@ -213,18 +151,19 @@ class PlaylistRepository:
except Exception:
await self.session.rollback()
logger.exception(
"Failed to remove sound %s from playlist %s", sound_id, playlist_id
"Failed to remove sound %s from playlist %s", sound_id, playlist_id,
)
raise
async def reorder_playlist_sounds(
self, playlist_id: int, sound_positions: list[tuple[int, int]]
self, playlist_id: int, sound_positions: list[tuple[int, int]],
) -> None:
"""Reorder sounds in a playlist.
Args:
playlist_id: The playlist ID
sound_positions: List of (sound_id, new_position) tuples
"""
try:
for sound_id, new_position in sound_positions:
@@ -249,7 +188,7 @@ class PlaylistRepository:
"""Get the number of sounds in a playlist."""
try:
statement = select(func.count(PlaylistSound.id)).where(
PlaylistSound.playlist_id == playlist_id
PlaylistSound.playlist_id == playlist_id,
)
result = await self.session.exec(statement)
return result.first() or 0
@@ -268,6 +207,6 @@ class PlaylistRepository:
return result.first() is not None
except Exception:
logger.exception(
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id,
)
raise

View File

@@ -1,33 +1,22 @@
"""Sound repository for database operations."""
from typing import Any
from sqlalchemy import desc, func
from sqlalchemy import func
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.sound import Sound
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class SoundRepository:
class SoundRepository(BaseRepository[Sound]):
"""Repository for sound operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the sound repository."""
self.session = session
async def get_by_id(self, sound_id: int) -> Sound | None:
"""Get a sound by ID."""
try:
statement = select(Sound).where(Sound.id == sound_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get sound by ID: %s", sound_id)
raise
super().__init__(Sound, session)
async def get_by_filename(self, filename: str) -> Sound | None:
"""Get a sound by filename."""
@@ -59,48 +48,6 @@ class SoundRepository:
logger.exception("Failed to get sounds by type: %s", sound_type)
raise
async def create(self, sound_data: dict[str, Any]) -> Sound:
"""Create a new sound."""
try:
sound = Sound(**sound_data)
self.session.add(sound)
await self.session.commit()
await self.session.refresh(sound)
except Exception:
await self.session.rollback()
logger.exception("Failed to create sound")
raise
else:
logger.info("Created new sound: %s", sound.name)
return sound
async def update(self, sound: Sound, update_data: dict[str, Any]) -> Sound:
"""Update a sound."""
try:
for field, value in update_data.items():
setattr(sound, field, value)
await self.session.commit()
await self.session.refresh(sound)
except Exception:
await self.session.rollback()
logger.exception("Failed to update sound")
raise
else:
logger.info("Updated sound: %s", sound.name)
return sound
async def delete(self, sound: Sound) -> None:
"""Delete a sound."""
try:
await self.session.delete(sound)
await self.session.commit()
logger.info("Deleted sound: %s", sound.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete sound")
raise
async def search_by_name(self, query: str) -> list[Sound]:
"""Search sounds by name (case-insensitive)."""
try:
@@ -144,6 +91,6 @@ class SoundRepository:
return list(result.all())
except Exception:
logger.exception(
"Failed to get unnormalized sounds by type: %s", sound_type
"Failed to get unnormalized sounds by type: %s", sound_type,
)
raise

View File

@@ -8,26 +8,17 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.plan import Plan
from app.models.user import User
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class UserRepository:
class UserRepository(BaseRepository[User]):
"""Repository for user operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the user repository."""
self.session = session
async def get_by_id(self, user_id: int) -> User | None:
"""Get a user by ID."""
try:
statement = select(User).where(User.id == user_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get user by ID: %s", user_id)
raise
super().__init__(User, session)
async def get_by_email(self, email: str) -> User | None:
"""Get a user by email address."""
@@ -50,7 +41,7 @@ class UserRepository:
raise
async def create(self, user_data: dict[str, Any]) -> User:
"""Create a new user."""
"""Create a new user with plan assignment and first user admin logic."""
def _raise_plan_not_found() -> None:
msg = "Default plan not found"
@@ -84,45 +75,11 @@ class UserRepository:
user_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits
user = User(**user_data)
self.session.add(user)
await self.session.commit()
await self.session.refresh(user)
# Use BaseRepository's create method
return await super().create(user_data)
except Exception:
await self.session.rollback()
logger.exception("Failed to create user")
raise
else:
logger.info("Created new user with email: %s", user.email)
return user
async def update(self, user: User, update_data: dict[str, Any]) -> User:
"""Update a user."""
try:
for field, value in update_data.items():
setattr(user, field, value)
await self.session.commit()
await self.session.refresh(user)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user")
raise
else:
logger.info("Updated user: %s", user.email)
return user
async def delete(self, user: User) -> None:
"""Delete a user."""
try:
await self.session.delete(user)
await self.session.commit()
logger.info("Deleted user: %s", user.email)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user")
raise
async def email_exists(self, email: str) -> bool:
"""Check if an email address is already registered."""

View File

@@ -1,22 +1,22 @@
"""Repository for user OAuth operations."""
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.user_oauth import UserOauth
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class UserOauthRepository:
class UserOauthRepository(BaseRepository[UserOauth]):
"""Repository for user OAuth operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize repository with database session."""
self.session = session
super().__init__(UserOauth, session)
async def get_by_provider_user_id(
self,
@@ -61,57 +61,3 @@ class UserOauthRepository:
else:
return result.first()
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
"""Create a new user OAuth record."""
try:
oauth = UserOauth(**oauth_data)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Created OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to create user OAuth")
raise
else:
return oauth
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
"""Update a user OAuth record."""
try:
for key, value in update_data.items():
setattr(oauth, key, value)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Updated OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user OAuth")
raise
else:
return oauth
async def delete(self, oauth: UserOauth) -> None:
"""Delete a user OAuth record."""
try:
await self.session.delete(oauth)
await self.session.commit()
logger.info(
"Deleted OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user OAuth")
raise

View File

@@ -39,24 +39,24 @@ class TestExtractionRepository:
}
# Mock the session operations
mock_extraction = Extraction(**extraction_data, id=1)
extraction_repo.session.add = Mock()
extraction_repo.session.commit = AsyncMock()
extraction_repo.session.refresh = AsyncMock()
# Mock the Extraction constructor to return our mock
with pytest.MonkeyPatch().context() as m:
m.setattr(
"app.repositories.extraction.Extraction",
lambda **kwargs: mock_extraction,
)
result = await extraction_repo.create(extraction_data)
assert result == mock_extraction
# Verify the result has the expected attributes
assert result.url == extraction_data["url"]
assert result.user_id == extraction_data["user_id"]
assert result.service == extraction_data["service"]
assert result.service_id == extraction_data["service_id"]
assert result.title == extraction_data["title"]
assert result.status == extraction_data["status"]
# Verify session methods were called
extraction_repo.session.add.assert_called_once()
extraction_repo.session.commit.assert_called_once()
extraction_repo.session.refresh.assert_called_once_with(mock_extraction)
extraction_repo.session.refresh.assert_called_once()
@pytest.mark.asyncio
async def test_get_by_service_and_id(self, extraction_repo):