Add tests for sound repository, user OAuth repository, credit service, and credit decorators
- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities. - Create tests for UserOauthRepository to validate OAuth record management. - Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits. - Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
132
app/repositories/base.py
Normal file
132
app/repositories/base.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
# Type variable for the model
|
||||
ModelType = TypeVar("ModelType")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
|
||||
"""Initialize the repository.
|
||||
|
||||
Args:
|
||||
model: The SQLModel class
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
self.model = model
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, entity_id: int) -> ModelType | None:
|
||||
"""Get an entity by ID.
|
||||
|
||||
Args:
|
||||
entity_id: The entity ID
|
||||
|
||||
Returns:
|
||||
The entity if found, None otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
|
||||
raise
|
||||
|
||||
async def get_all(
|
||||
self,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ModelType]:
|
||||
"""Get all entities with pagination.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entities to return
|
||||
offset: Number of entities to skip
|
||||
|
||||
Returns:
|
||||
List of entities
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).limit(limit).offset(offset)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get all %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||
"""Create a new entity.
|
||||
|
||||
Args:
|
||||
entity_data: Dictionary of entity data
|
||||
|
||||
Returns:
|
||||
The created entity
|
||||
|
||||
"""
|
||||
try:
|
||||
entity = self.model(**entity_data)
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
|
||||
"""Update an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to update
|
||||
update_data: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
The updated entity
|
||||
|
||||
"""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(entity, field, value)
|
||||
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def delete(self, entity: ModelType) -> None:
|
||||
"""Delete an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to delete
|
||||
|
||||
"""
|
||||
try:
|
||||
await self.session.delete(entity)
|
||||
await self.session.commit()
|
||||
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete %s", self.model.__name__)
|
||||
raise
|
||||
108
app/repositories/credit_transaction.py
Normal file
108
app/repositories/credit_transaction.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Repository for credit transaction database operations."""
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
|
||||
"""Repository for credit transaction operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the repository.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
super().__init__(CreditTransaction, session)
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get credit transactions for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of credit transactions ordered by creation date (newest first)
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.user_id == user_id)
|
||||
.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def get_by_action_type(
|
||||
self,
|
||||
action_type: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get credit transactions by action type.
|
||||
|
||||
Args:
|
||||
action_type: The action type to filter by
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of credit transactions ordered by creation date (newest first)
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.action_type == action_type)
|
||||
.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def get_successful_transactions(
|
||||
self,
|
||||
user_id: int | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get successful credit transactions.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter by
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of successful credit transactions
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.success == True) # noqa: E712
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(CreditTransaction.user_id == user_id)
|
||||
|
||||
stmt = (
|
||||
stmt.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
Reference in New Issue
Block a user