Add tests for sound repository, user OAuth repository, credit service, and credit decorators
- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities. - Create tests for UserOauthRepository to validate OAuth record management. - Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits. - Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
132
app/repositories/base.py
Normal file
132
app/repositories/base.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
# Type variable for the model
|
||||
ModelType = TypeVar("ModelType")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
|
||||
"""Initialize the repository.
|
||||
|
||||
Args:
|
||||
model: The SQLModel class
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
self.model = model
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, entity_id: int) -> ModelType | None:
|
||||
"""Get an entity by ID.
|
||||
|
||||
Args:
|
||||
entity_id: The entity ID
|
||||
|
||||
Returns:
|
||||
The entity if found, None otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
|
||||
raise
|
||||
|
||||
async def get_all(
|
||||
self,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ModelType]:
|
||||
"""Get all entities with pagination.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entities to return
|
||||
offset: Number of entities to skip
|
||||
|
||||
Returns:
|
||||
List of entities
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).limit(limit).offset(offset)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get all %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||
"""Create a new entity.
|
||||
|
||||
Args:
|
||||
entity_data: Dictionary of entity data
|
||||
|
||||
Returns:
|
||||
The created entity
|
||||
|
||||
"""
|
||||
try:
|
||||
entity = self.model(**entity_data)
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
|
||||
"""Update an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to update
|
||||
update_data: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
The updated entity
|
||||
|
||||
"""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(entity, field, value)
|
||||
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def delete(self, entity: ModelType) -> None:
|
||||
"""Delete an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to delete
|
||||
|
||||
"""
|
||||
try:
|
||||
await self.session.delete(entity)
|
||||
await self.session.commit()
|
||||
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete %s", self.model.__name__)
|
||||
raise
|
||||
Reference in New Issue
Block a user