Add tests for sound repository, user OAuth repository, credit service, and credit decorators

- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities.
- Create tests for UserOauthRepository to validate OAuth record management.
- Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits.
- Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
JSC
2025-07-30 21:33:55 +02:00
parent dd10ef5d41
commit e43650c26c
14 changed files with 2692 additions and 1 deletions

132
app/repositories/base.py Normal file
View File

@@ -0,0 +1,132 @@
"""Base repository with common CRUD operations."""
from typing import Any, Generic, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
# Type variable for the model
ModelType = TypeVar("ModelType")
logger = get_logger(__name__)
class BaseRepository(Generic[ModelType]):
"""Base repository with common CRUD operations."""
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
"""Initialize the repository.
Args:
model: The SQLModel class
session: Database session
"""
self.model = model
self.session = session
async def get_by_id(self, entity_id: int) -> ModelType | None:
"""Get an entity by ID.
Args:
entity_id: The entity ID
Returns:
The entity if found, None otherwise
"""
try:
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
raise
async def get_all(
self,
limit: int = 100,
offset: int = 0,
) -> list[ModelType]:
"""Get all entities with pagination.
Args:
limit: Maximum number of entities to return
offset: Number of entities to skip
Returns:
List of entities
"""
try:
statement = select(self.model).limit(limit).offset(offset)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all %s", self.model.__name__)
raise
async def create(self, entity_data: dict[str, Any]) -> ModelType:
"""Create a new entity.
Args:
entity_data: Dictionary of entity data
Returns:
The created entity
"""
try:
entity = self.model(**entity_data)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to create %s", self.model.__name__)
raise
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
"""Update an entity.
Args:
entity: The entity to update
update_data: Dictionary of fields to update
Returns:
The updated entity
"""
try:
for field, value in update_data.items():
setattr(entity, field, value)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to update %s", self.model.__name__)
raise
async def delete(self, entity: ModelType) -> None:
"""Delete an entity.
Args:
entity: The entity to delete
"""
try:
await self.session.delete(entity)
await self.session.commit()
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
except Exception:
await self.session.rollback()
logger.exception("Failed to delete %s", self.model.__name__)
raise

View File

@@ -0,0 +1,108 @@
"""Repository for credit transaction database operations."""
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_transaction import CreditTransaction
from app.repositories.base import BaseRepository
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
"""Repository for credit transaction operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the repository.
Args:
session: Database session
"""
super().__init__(CreditTransaction, session)
async def get_by_user_id(
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions for a user.
Args:
user_id: The user ID
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.user_id == user_id)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_by_action_type(
self,
action_type: str,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions by action type.
Args:
action_type: The action type to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.action_type == action_type)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_successful_transactions(
self,
user_id: int | None = None,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get successful credit transactions.
Args:
user_id: Optional user ID to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of successful credit transactions
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.success == True) # noqa: E712
)
if user_id is not None:
stmt = stmt.where(CreditTransaction.user_id == user_id)
stmt = (
stmt.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())