"""Credit management service for tracking and validating user credit usage."""

import json
from collections.abc import Callable
from typing import Any

from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.user import UserRepository
from app.services.socket import socket_manager

logger = get_logger(__name__)


class InsufficientCreditsError(Exception):
    """Raised when user has insufficient credits for an action."""

    def __init__(self, required: int, available: int) -> None:
        """Initialize the error.

        Args:
            required: Number of credits required
            available: Number of credits available
        """
        self.required = required
        self.available = available
        super().__init__(
            f"Insufficient credits: {required} required, {available} available",
        )


class CreditService:
    """Service for managing user credits and transactions.

    Every method opens its own session from ``db_session_factory`` and closes
    it before returning, so any ORM object handed back to the caller is
    detached from its session.
    """

    def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the credit service.

        Args:
            db_session_factory: Factory function to create database sessions
        """
        self.db_session_factory = db_session_factory

    # ------------------------------------------------------------------ helpers

    @staticmethod
    async def _get_user_or_raise(user_repo: UserRepository, user_id: int) -> User:
        """Fetch a user through ``user_repo``, raising ``ValueError`` if absent."""
        user = await user_repo.get_by_id(user_id)
        if not user:
            msg = f"User {user_id} not found"
            raise ValueError(msg)
        return user

    @staticmethod
    def _serialize_metadata(metadata: dict[str, Any] | None) -> str | None:
        """Encode metadata as JSON; falsy metadata (None or {}) is stored as NULL."""
        return json.dumps(metadata) if metadata else None

    @staticmethod
    async def _emit_credits_changed(user_id: int, event_data: dict[str, Any]) -> None:
        """Best-effort push of a ``user_credits_changed`` WebSocket event.

        Errors are logged and swallowed: callers invoke this only after the
        credit change has been committed, so a notification failure must not
        surface as a failed credit operation.
        """
        try:
            await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
            logger.info("Emitted user_credits_changed event for user %s", user_id)
        except Exception:
            logger.exception(
                "Failed to emit user_credits_changed event for user %s",
                user_id,
            )

    # ----------------------------------------------------------------- queries

    async def check_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
    ) -> bool:
        """Check if user has sufficient credits for an action.

        Args:
            user_id: The user ID
            action_type: The type of action to check

        Returns:
            True if the user exists and has at least the action's cost in
            credits, False otherwise (including when the user does not exist)
        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user = await UserRepository(session).get_by_id(user_id)
            if not user:
                return False
            return user.credits >= action.cost
        finally:
            await session.close()

    async def validate_and_reserve_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
        metadata: dict[str, Any] | None = None,
    ) -> tuple[User, CreditAction]:
        """Validate that a user has sufficient credits for an action.

        Despite the name, no credits are reserved or deducted here; this is a
        read-only check. Call ``deduct_credits`` to actually charge the user.

        Args:
            user_id: The user ID
            action_type: The type of action
            metadata: Currently unused; accepted for API compatibility

        Returns:
            Tuple of (user, credit_action). The user object is detached from
            its (already closed) session.

        Raises:
            ValueError: If user not found
            InsufficientCreditsError: If user has insufficient credits
        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)
            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)

            logger.info(
                "Credits validated for user %s: %s credits available, %s required",
                user_id,
                user.credits,
                action.cost,
            )
            return user, action
        finally:
            await session.close()

    # --------------------------------------------------------------- mutations

    async def deduct_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
        success: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Deduct credits from user account and record transaction.

        Actions whose ``requires_success`` flag is set are only charged when
        ``success`` is true; in that failed case a zero-amount audit record is
        written instead. All other actions are charged regardless of outcome.

        Args:
            user_id: The user ID
            action_type: The type of action
            success: Whether the action was successful
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            InsufficientCreditsError: If user has insufficient credits
            ValueError: If user not found
        """
        action = get_credit_action(action_type)

        # Equivalent to ``(requires_success and success) or not requires_success``.
        should_deduct = success or not action.requires_success
        if not should_deduct:
            logger.info(
                "Skipping credit deduction for user %s: action %s failed and requires success",
                user_id,
                action_type.value,
            )
            # Still create a transaction record for auditing
            return await self._create_transaction_record(
                user_id,
                action,
                0,
                success,
                metadata,
            )

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await self._get_user_or_raise(user_repo, user_id)

            # NOTE(review): this is a read-then-write with no row lock, so two
            # concurrent deductions can both pass the check and overdraw the
            # balance. Consider a SELECT ... FOR UPDATE / atomic UPDATE — confirm
            # what UserRepository supports.
            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)

            balance_before = user.credits
            balance_after = balance_before - action.cost

            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action_type.value,
                amount=-action.cost,
                balance_before=balance_before,
                balance_after=balance_after,
                description=action.description,
                success=success,
                metadata_json=self._serialize_metadata(metadata),
            )

            # NOTE(review): if UserRepository.update commits by itself, the
            # balance change and the transaction insert below are not atomic.
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()
            # Reload after commit so the returned object (incl. generated PK) is
            # fully populated even when the session expires attributes on commit;
            # otherwise reading it after the session closes would fail.
            await session.refresh(transaction)

            logger.info(
                "Credits deducted for user %s: %s credits (action: %s, success: %s)",
                user_id,
                action.cost,
                action_type.value,
                success,
            )

            await self._emit_credits_changed(
                user_id,
                {
                    "user_id": str(user_id),
                    "credits_before": balance_before,
                    "credits_after": balance_after,
                    "credits_deducted": action.cost,
                    "action_type": action_type.value,
                    "success": success,
                },
            )
            return transaction
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def add_credits(
        self,
        user_id: int,
        amount: int,
        description: str,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Add credits to user account.

        Args:
            user_id: The user ID
            amount: Number of credits to add; must be strictly positive
            description: Description of the credit addition
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            ValueError: If user not found or amount is zero or negative
        """
        if amount <= 0:
            msg = "Amount must be positive"
            raise ValueError(msg)

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await self._get_user_or_raise(user_repo, user_id)

            balance_before = user.credits
            balance_after = balance_before + amount

            transaction = CreditTransaction(
                user_id=user_id,
                action_type="credit_addition",
                amount=amount,
                balance_before=balance_before,
                balance_after=balance_after,
                description=description,
                success=True,
                metadata_json=self._serialize_metadata(metadata),
            )

            # NOTE(review): same atomicity caveat as deduct_credits if
            # UserRepository.update commits on its own.
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()
            # Reload so the returned, soon-to-be-detached object is usable.
            await session.refresh(transaction)

            logger.info(
                "Credits added for user %s: %s credits (description: %s)",
                user_id,
                amount,
                description,
            )

            await self._emit_credits_changed(
                user_id,
                {
                    "user_id": str(user_id),
                    "credits_before": balance_before,
                    "credits_after": balance_after,
                    "credits_added": amount,
                    "description": description,
                    "success": True,
                },
            )
            return transaction
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def _create_transaction_record(
        self,
        user_id: int,
        action: CreditAction,
        amount: int,
        success: bool,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Create a transaction record without modifying credits.

        Args:
            user_id: The user ID
            action: The credit action
            amount: Amount to record (typically 0 for failed actions)
            success: Whether the action was successful; when false the
                description is suffixed with " (failed)"
            metadata: Optional metadata

        Returns:
            The created transaction record

        Raises:
            ValueError: If user not found
        """
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)

            description = action.description if success else f"{action.description} (failed)"
            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action.action_type.value,
                amount=amount,
                # Balance is unchanged: this record is audit-only.
                balance_before=user.credits,
                balance_after=user.credits,
                description=description,
                success=success,
                metadata_json=self._serialize_metadata(metadata),
            )
            session.add(transaction)
            await session.commit()
            # Reload so the returned, soon-to-be-detached object is usable.
            await session.refresh(transaction)
            return transaction
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    async def get_user_balance(self, user_id: int) -> int:
        """Get current credit balance for a user.

        Args:
            user_id: The user ID

        Returns:
            Current credit balance

        Raises:
            ValueError: If user not found
        """
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)
            return user.credits
        finally:
            await session.close()