- Removed unnecessary blank lines and adjusted formatting in test files.
- Ensured consistent use of commas in function calls and assertions across various test cases.
- Updated import statements for better organization and clarity.
- Enhanced mock setups in tests for better isolation and reliability.
- Improved assertions to follow a consistent style for better readability.
384 lines
12 KiB
Python
384 lines
12 KiB
Python
"""Credit management service for tracking and validating user credit usage."""
|
|
|
|
import json
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
|
|
from app.models.credit_transaction import CreditTransaction
|
|
from app.models.user import User
|
|
from app.repositories.user import UserRepository
|
|
from app.services.socket import socket_manager
|
|
|
|
# Module-level logger, obtained from the project's logging helper under this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class InsufficientCreditsError(Exception):
    """Signal that a user's balance cannot cover the cost of an action.

    The two figures are kept on the instance (``required`` and ``available``)
    so callers can inspect them without parsing the message.
    """

    def __init__(self, required: int, available: int) -> None:
        """Build the error and its human-readable message.

        Args:
            required: Number of credits the action costs
            available: Number of credits the user currently holds

        """
        message = f"Insufficient credits: {required} required, {available} available"
        super().__init__(message)
        self.required = required
        self.available = available
|
|
|
|
|
|
class CreditService:
    """Service for managing user credits and transactions.

    Each public method obtains its own session from ``db_session_factory`` and
    closes it in a ``finally`` block before returning.

    NOTE(review): ``User`` and ``CreditTransaction`` objects are handed back to
    the caller after their session has been closed. Confirm that the session
    factory is configured so that attributes stay readable on such objects
    (e.g. ``expire_on_commit=False``).
    """

    def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the credit service.

        Args:
            db_session_factory: Factory function to create database sessions

        """
        self.db_session_factory = db_session_factory

    @staticmethod
    def _serialize_metadata(metadata: dict[str, Any] | None) -> str | None:
        """Return *metadata* as a JSON string, or ``None`` when it is missing or empty."""
        return json.dumps(metadata) if metadata else None

    @staticmethod
    async def _get_user_or_raise(user_repo: UserRepository, user_id: int) -> User:
        """Load a user through *user_repo*, raising ``ValueError`` when the lookup is falsy."""
        user = await user_repo.get_by_id(user_id)
        if not user:
            msg = f"User {user_id} not found"
            raise ValueError(msg)
        return user

    @staticmethod
    async def _emit_credits_changed(user_id: int, event_data: dict[str, Any]) -> None:
        """Send a ``user_credits_changed`` event to the user; failures are logged, never raised."""
        try:
            await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
            logger.info("Emitted user_credits_changed event for user %s", user_id)
        except Exception:
            # Best-effort notification: the balance change is already committed,
            # so a delivery failure must not fail the caller.
            logger.exception(
                "Failed to emit user_credits_changed event for user %s", user_id,
            )

    async def check_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
    ) -> bool:
        """Check if user has sufficient credits for an action.

        Args:
            user_id: The user ID
            action_type: The type of action to check

        Returns:
            True if user has sufficient credits, False otherwise
            (including when the user does not exist)

        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                # Unlike the other methods, an unknown user is not an error here.
                return False
            return user.credits >= action.cost
        finally:
            await session.close()

    async def validate_and_reserve_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
        metadata: dict[str, Any] | None = None,
    ) -> tuple[User, CreditAction]:
        """Validate that the user has sufficient credits for an action.

        Despite the name, this method only validates: it neither reserves nor
        deducts credits and it writes nothing to the database.

        Args:
            user_id: The user ID
            action_type: The type of action
            metadata: Accepted for interface compatibility; currently unused

        Returns:
            Tuple of (user, credit_action)

        Raises:
            InsufficientCreditsError: If user has insufficient credits
            ValueError: If user not found

        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)

            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)

            logger.info(
                "Credits validated for user %s: %s credits available, %s required",
                user_id,
                user.credits,
                action.cost,
            )
            return user, action
        finally:
            await session.close()

    async def deduct_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
        success: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Deduct credits from user account and record transaction.

        When the action requires success and ``success`` is False, no credits
        are deducted; a zero-amount transaction is still recorded for auditing.

        Args:
            user_id: The user ID
            action_type: The type of action
            success: Whether the action was successful
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            InsufficientCreditsError: If user has insufficient credits
            ValueError: If user not found

        """
        action = get_credit_action(action_type)

        # Charge unless the action only bills on success and it failed.
        should_deduct = success or not action.requires_success

        if not should_deduct:
            logger.info(
                "Skipping credit deduction for user %s: action %s failed and requires success",
                user_id,
                action_type.value,
            )
            # Still create a transaction record for auditing
            return await self._create_transaction_record(
                user_id, action, 0, success, metadata,
            )

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await self._get_user_or_raise(user_repo, user_id)

            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)

            # NOTE(review): the balance is read here and written back below with
            # no locking visible in this method; confirm that concurrent
            # deductions for the same user cannot overwrite each other.
            balance_before = user.credits
            balance_after = balance_before - action.cost

            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action_type.value,
                amount=-action.cost,
                balance_before=balance_before,
                balance_after=balance_after,
                description=action.description,
                success=success,
                metadata_json=self._serialize_metadata(metadata),
            )

            # Balance update and ledger entry are committed together.
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()

            logger.info(
                "Credits deducted for user %s: %s credits (action: %s, success: %s)",
                user_id,
                action.cost,
                action_type.value,
                success,
            )

            # Emit user_credits_changed event via WebSocket
            await self._emit_credits_changed(
                user_id,
                {
                    "user_id": str(user_id),
                    "credits_before": balance_before,
                    "credits_after": balance_after,
                    "credits_deducted": action.cost,
                    "action_type": action_type.value,
                    "success": success,
                },
            )

            return transaction

        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def add_credits(
        self,
        user_id: int,
        amount: int,
        description: str,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Add credits to user account.

        Args:
            user_id: The user ID
            amount: Number of credits to add; must be greater than zero
            description: Description of the credit addition
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            ValueError: If user not found or amount is not positive (zero included)

        """
        if amount <= 0:
            msg = "Amount must be positive"
            raise ValueError(msg)

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await self._get_user_or_raise(user_repo, user_id)

            balance_before = user.credits
            balance_after = balance_before + amount

            transaction = CreditTransaction(
                user_id=user_id,
                action_type="credit_addition",
                amount=amount,
                balance_before=balance_before,
                balance_after=balance_after,
                description=description,
                success=True,
                metadata_json=self._serialize_metadata(metadata),
            )

            # Balance update and ledger entry are committed together.
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()

            logger.info(
                "Credits added for user %s: %s credits (description: %s)",
                user_id,
                amount,
                description,
            )

            # Emit user_credits_changed event via WebSocket
            await self._emit_credits_changed(
                user_id,
                {
                    "user_id": str(user_id),
                    "credits_before": balance_before,
                    "credits_after": balance_after,
                    "credits_added": amount,
                    "description": description,
                    "success": True,
                },
            )

            return transaction

        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def _create_transaction_record(
        self,
        user_id: int,
        action: CreditAction,
        amount: int,
        success: bool,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Create a transaction record without modifying credits.

        The user's current balance is stored as both ``balance_before`` and
        ``balance_after``. When ``success`` is False the description gets a
        " (failed)" suffix.

        Args:
            user_id: The user ID
            action: The credit action
            amount: Amount to record (typically 0 for failed actions)
            success: Whether the action was successful
            metadata: Optional metadata

        Returns:
            The created transaction record

        Raises:
            ValueError: If user not found

        """
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)

            description = action.description if success else f"{action.description} (failed)"
            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action.action_type.value,
                amount=amount,
                balance_before=user.credits,
                balance_after=user.credits,
                description=description,
                success=success,
                metadata_json=self._serialize_metadata(metadata),
            )

            session.add(transaction)
            await session.commit()

            return transaction

        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def get_user_balance(self, user_id: int) -> int:
        """Get current credit balance for a user.

        Args:
            user_id: The user ID

        Returns:
            Current credit balance

        Raises:
            ValueError: If user not found

        """
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)
            return user.credits
        finally:
            await session.close()
|