Add tests for sound repository, user OAuth repository, credit service, and credit decorators
- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities. - Create tests for UserOauthRepository to validate OAuth record management. - Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits. - Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
383
app/services/credit.py
Normal file
383
app/services/credit.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""Credit management service for tracking and validating user credit usage."""
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class InsufficientCreditsError(Exception):
|
||||
"""Raised when user has insufficient credits for an action."""
|
||||
|
||||
def __init__(self, required: int, available: int) -> None:
|
||||
"""Initialize the error.
|
||||
|
||||
Args:
|
||||
required: Number of credits required
|
||||
available: Number of credits available
|
||||
|
||||
"""
|
||||
self.required = required
|
||||
self.available = available
|
||||
super().__init__(
|
||||
f"Insufficient credits: {required} required, {available} available"
|
||||
)
|
||||
|
||||
|
||||
class CreditService:
|
||||
"""Service for managing user credits and transactions."""
|
||||
|
||||
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
||||
"""Initialize the credit service.
|
||||
|
||||
Args:
|
||||
db_session_factory: Factory function to create database sessions
|
||||
|
||||
"""
|
||||
self.db_session_factory = db_session_factory
|
||||
|
||||
async def check_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
) -> bool:
|
||||
"""Check if user has sufficient credits for an action.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action to check
|
||||
|
||||
Returns:
|
||||
True if user has sufficient credits, False otherwise
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
return user.credits >= action.cost
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def validate_and_reserve_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> tuple[User, CreditAction]:
|
||||
"""Validate user has sufficient credits and optionally reserve them.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
Tuple of (user, credit_action)
|
||||
|
||||
Raises:
|
||||
InsufficientCreditsError: If user has insufficient credits
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if user.credits < action.cost:
|
||||
raise InsufficientCreditsError(action.cost, user.credits)
|
||||
|
||||
logger.info(
|
||||
"Credits validated for user %s: %s credits available, %s required",
|
||||
user_id,
|
||||
user.credits,
|
||||
action.cost,
|
||||
)
|
||||
return user, action
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def deduct_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
success: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Deduct credits from user account and record transaction.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action
|
||||
success: Whether the action was successful
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
The created credit transaction
|
||||
|
||||
Raises:
|
||||
InsufficientCreditsError: If user has insufficient credits
|
||||
ValueError: If user not found
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
|
||||
# Only deduct if action requires success and was successful, or doesn't require success
|
||||
should_deduct = (action.requires_success and success) or not action.requires_success
|
||||
|
||||
if not should_deduct:
|
||||
logger.info(
|
||||
"Skipping credit deduction for user %s: action %s failed and requires success",
|
||||
user_id,
|
||||
action_type.value,
|
||||
)
|
||||
# Still create a transaction record for auditing
|
||||
return await self._create_transaction_record(
|
||||
user_id, action, 0, success, metadata
|
||||
)
|
||||
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if user.credits < action.cost:
|
||||
raise InsufficientCreditsError(action.cost, user.credits)
|
||||
|
||||
# Record transaction
|
||||
balance_before = user.credits
|
||||
balance_after = user.credits - action.cost
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type=action_type.value,
|
||||
amount=-action.cost,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance_after,
|
||||
description=action.description,
|
||||
success=success,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
# Update user credits
|
||||
await user_repo.update(user, {"credits": balance_after})
|
||||
|
||||
# Save transaction
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Credits deducted for user %s: %s credits (action: %s, success: %s)",
|
||||
user_id,
|
||||
action.cost,
|
||||
action_type.value,
|
||||
success,
|
||||
)
|
||||
|
||||
# Emit user_credits_changed event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"user_id": str(user_id),
|
||||
"credits_before": balance_before,
|
||||
"credits_after": balance_after,
|
||||
"credits_deducted": action.cost,
|
||||
"action_type": action_type.value,
|
||||
"success": success,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def add_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
amount: int,
|
||||
description: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Add credits to user account.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
amount: Number of credits to add
|
||||
description: Description of the credit addition
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
The created credit transaction
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found or amount is negative
|
||||
|
||||
"""
|
||||
if amount <= 0:
|
||||
msg = "Amount must be positive"
|
||||
raise ValueError(msg)
|
||||
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Record transaction
|
||||
balance_before = user.credits
|
||||
balance_after = user.credits + amount
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type="credit_addition",
|
||||
amount=amount,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance_after,
|
||||
description=description,
|
||||
success=True,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
# Update user credits
|
||||
await user_repo.update(user, {"credits": balance_after})
|
||||
|
||||
# Save transaction
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Credits added for user %s: %s credits (description: %s)",
|
||||
user_id,
|
||||
amount,
|
||||
description,
|
||||
)
|
||||
|
||||
# Emit user_credits_changed event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"user_id": str(user_id),
|
||||
"credits_before": balance_before,
|
||||
"credits_after": balance_after,
|
||||
"credits_added": amount,
|
||||
"description": description,
|
||||
"success": True,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def _create_transaction_record(
|
||||
self,
|
||||
user_id: int,
|
||||
action: CreditAction,
|
||||
amount: int,
|
||||
success: bool,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Create a transaction record without modifying credits.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action: The credit action
|
||||
amount: Amount to record (typically 0 for failed actions)
|
||||
success: Whether the action was successful
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
The created transaction record
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type=action.action_type.value,
|
||||
amount=amount,
|
||||
balance_before=user.credits,
|
||||
balance_after=user.credits,
|
||||
description=f"{action.description} (failed)" if not success else action.description,
|
||||
success=success,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def get_user_balance(self, user_id: int) -> int:
|
||||
"""Get current credit balance for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Current credit balance
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
return user.credits
|
||||
finally:
|
||||
await session.close()
|
||||
Reference in New Issue
Block a user