- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionality.
- Create tests for UserOauthRepository to validate OAuth record management.
- Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits.
- Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
132 lines
3.8 KiB
Python
132 lines
3.8 KiB
Python
"""Base repository with common CRUD operations."""
|
|
|
|
from typing import Any, Generic, TypeVar
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
|
|
# Type variable standing in for the model class a repository manages.
# It is declared without a bound and is the type parameter of BaseRepository.
|
|
ModelType = TypeVar("ModelType")
|
|
|
|
# Module-level logger obtained from app.core.logging.get_logger, keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class BaseRepository(Generic[ModelType]):
    """Base repository with common CRUD operations.

    A repository is bound to one model class and one ``AsyncSession``.
    Write operations (``create``, ``update``, ``delete``) commit immediately;
    when they fail, the error is logged, the session is rolled back and the
    exception is re-raised. Read operations log and re-raise without touching
    the transaction.
    """

    def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
        """Initialize the repository.

        Args:
            model: The SQLModel class
            session: Database session

        """
        self.model = model
        self.session = session

    async def get_by_id(self, entity_id: int) -> ModelType | None:
        """Get an entity by ID.

        Args:
            entity_id: The entity ID

        Returns:
            The entity if found, None otherwise

        Raises:
            Exception: Any error raised while building or executing the query
                is logged and re-raised unchanged.

        """
        try:
            # ModelType is an unbounded TypeVar, so the ``id`` column is looked up
            # dynamically on the model class rather than accessed as ``self.model.id``.
            statement = select(self.model).where(getattr(self.model, "id") == entity_id)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
            raise

    async def get_all(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ModelType]:
        """Get all entities with pagination.

        Rows are ordered by the model's ``id`` attribute when the model has
        one, so that consecutive pages are stable and do not overlap.

        Args:
            limit: Maximum number of entities to return
            offset: Number of entities to skip

        Returns:
            List of entities

        Raises:
            Exception: Any error raised while building or executing the query
                is logged and re-raised unchanged.

        """
        try:
            statement = select(self.model)
            # LIMIT/OFFSET without ORDER BY lets the database return rows in any
            # order, so two pages could repeat or skip rows. Models without an
            # ``id`` attribute keep the previous (unordered) behaviour.
            id_column = getattr(self.model, "id", None)
            if id_column is not None:
                statement = statement.order_by(id_column)
            statement = statement.limit(limit).offset(offset)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all %s", self.model.__name__)
            raise

    async def create(self, entity_data: dict[str, Any]) -> ModelType:
        """Create a new entity.

        The entity is built from ``entity_data``, added to the session,
        committed, and refreshed before being returned.

        Args:
            entity_data: Dictionary of entity data

        Returns:
            The created entity

        Raises:
            Exception: Any error raised while building, committing or
                refreshing the entity; it is logged, the session is rolled
                back, and the error is re-raised.

        """
        try:
            entity = self.model(**entity_data)
            self.session.add(entity)
            await self.session.commit()
            await self.session.refresh(entity)
            logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
            return entity
        except Exception:
            # Log before rolling back: if the rollback itself fails, the
            # original error has still been recorded.
            logger.exception("Failed to create %s", self.model.__name__)
            await self.session.rollback()
            raise

    async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
        """Update an entity.

        Every key/value pair in ``update_data`` is assigned onto ``entity``
        with ``setattr``; the entity is then committed and refreshed.

        Args:
            entity: The entity to update
            update_data: Dictionary of fields to update

        Returns:
            The updated entity

        Raises:
            Exception: Any error raised while assigning fields, committing or
                refreshing; it is logged, the session is rolled back, and the
                error is re-raised.

        """
        try:
            for field, value in update_data.items():
                setattr(entity, field, value)

            self.session.add(entity)
            await self.session.commit()
            await self.session.refresh(entity)
            logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
            return entity
        except Exception:
            # Log before rolling back: if the rollback itself fails, the
            # original error has still been recorded.
            logger.exception("Failed to update %s", self.model.__name__)
            await self.session.rollback()
            raise

    async def delete(self, entity: ModelType) -> None:
        """Delete an entity.

        The entity is marked for deletion and the session is committed.

        Args:
            entity: The entity to delete

        Raises:
            Exception: Any error raised while deleting or committing; it is
                logged, the session is rolled back, and the error is
                re-raised.

        """
        try:
            await self.session.delete(entity)
            await self.session.commit()
            logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
        except Exception:
            # Log before rolling back: if the rollback itself fails, the
            # original error has still been recorded.
            logger.exception("Failed to delete %s", self.model.__name__)
            await self.session.rollback()
            raise