- Adjusted function signatures in various test files to enhance clarity by aligning parameters.
- Updated patching syntax for better readability across test cases.
- Improved formatting and spacing in test assertions and mock setups.
- Ensured consistent use of async/await patterns in async test functions.
- Enhanced comments for better understanding of test intentions.
108 lines
3.1 KiB
Python
108 lines
3.1 KiB
Python
"""Repository for credit transaction database operations."""
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.models.credit_transaction import CreditTransaction
|
|
from app.repositories.base import BaseRepository
|
|
|
|
|
|
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
    """Data-access helpers for ``CreditTransaction`` records.

    Every listing method returns rows sorted by ``created_at`` descending
    (newest first) and pages them with ``limit``/``offset``.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Create a repository bound to ``session``.

        Args:
            session: Async database session, forwarded to ``BaseRepository``
                together with the ``CreditTransaction`` model.
        """
        super().__init__(CreditTransaction, session)

    async def _newest_first_page(
        self,
        query,
        limit: int,
        offset: int,
    ) -> list[CreditTransaction]:
        """Sort ``query`` newest-first, apply paging, run it and list the rows.

        Args:
            query: A ``select(CreditTransaction)`` statement with its filters
                already applied.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before the first returned row.

        Returns:
            The matching credit transactions as a list.
        """
        # NOTE(review): ``created_at`` is the only sort key. If two rows can
        # share a timestamp, their relative order is unspecified, so
        # limit/offset paging may repeat or skip such rows across pages.
        # Consider a unique tiebreaker column — TODO confirm against the model.
        paged_query = (
            query.order_by(CreditTransaction.created_at.desc())
            .limit(limit)
            .offset(offset)
        )
        rows = await self.session.exec(paged_query)
        return list(rows.all())

    async def get_by_user_id(
        self,
        user_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> list[CreditTransaction]:
        """Return one page of a single user's credit transactions.

        Args:
            user_id: ID of the user whose transactions are requested.
            limit: Maximum number of transactions to return.
            offset: Number of transactions to skip.

        Returns:
            The user's credit transactions, newest first.
        """
        by_user = select(CreditTransaction).where(
            CreditTransaction.user_id == user_id
        )
        return await self._newest_first_page(by_user, limit, offset)

    async def get_by_action_type(
        self,
        action_type: str,
        limit: int = 50,
        offset: int = 0,
    ) -> list[CreditTransaction]:
        """Return one page of credit transactions with a given action type.

        Args:
            action_type: Exact ``action_type`` value to match.
            limit: Maximum number of transactions to return.
            offset: Number of transactions to skip.

        Returns:
            Matching credit transactions, newest first.
        """
        by_action = select(CreditTransaction).where(
            CreditTransaction.action_type == action_type
        )
        return await self._newest_first_page(by_action, limit, offset)

    async def get_successful_transactions(
        self,
        user_id: int | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[CreditTransaction]:
        """Return one page of credit transactions whose ``success`` flag is true.

        Args:
            user_id: When given, only this user's transactions are included;
                when ``None``, transactions of all users are included.
            limit: Maximum number of transactions to return.
            offset: Number of transactions to skip.

        Returns:
            Successful credit transactions, newest first.
        """
        # ``== True`` is deliberate: it builds a SQL comparison on the column
        # rather than performing a Python truth test.
        filters = [CreditTransaction.success == True]  # noqa: E712
        if user_id is not None:
            filters.append(CreditTransaction.user_id == user_id)
        # Multiple criteria passed to ``where`` are combined with AND.
        successful = select(CreditTransaction).where(*filters)
        return await self._newest_first_page(successful, limit, offset)
|