628 lines
20 KiB
Python
628 lines
20 KiB
Python
"""Credit management service for tracking and validating user credit usage."""
|
|
|
|
import json
|
|
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
|
|
from app.models.credit_transaction import CreditTransaction
|
|
from app.models.user import User
|
|
from app.repositories.user import UserRepository
|
|
from app.services.socket import socket_manager
|
|
|
|
# Module-level logger from the application's logging helper, named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
class InsufficientCreditsError(Exception):
    """Signal that a user's balance cannot cover the cost of an action.

    Attributes:
        required: Credits the action costs.
        available: Credits the user currently holds.
    """

    def __init__(self, required: int, available: int) -> None:
        """Record the shortfall details and build a human-readable message.

        Args:
            required: Number of credits required
            available: Number of credits available
        """
        self.required, self.available = required, available
        message = f"Insufficient credits: {required} required, {available} available"
        super().__init__(message)
|
|
|
|
|
|
class CreditService:
    """Service for managing user credits and transactions.

    Each operation opens its own session(s) via ``db_session_factory`` and
    closes them before returning. Balance changes are additionally pushed to
    the user over WebSocket on a best-effort basis.
    """

    def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the credit service.

        Args:
            db_session_factory: Zero-argument callable returning a new
                database session; invoked at the start of each operation.
        """
        self.db_session_factory = db_session_factory

    async def check_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
    ) -> bool:
        """Check if user has sufficient credits for an action.

        Args:
            user_id: The user ID
            action_type: The type of action to check

        Returns:
            True if the user exists and holds at least the action's cost in
            credits; False otherwise (including when the user is not found).
        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                return False
            return user.credits >= action.cost
        finally:
            await session.close()

    async def validate_and_reserve_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
    ) -> tuple[User, CreditAction]:
        """Validate that a user can afford an action.

        Despite the name, nothing is reserved or deducted here: this is a
        read-only check. Use ``deduct_credits`` to actually charge the user.

        Args:
            user_id: The user ID
            action_type: The type of action

        Returns:
            Tuple of (user, credit_action). The user was loaded in a session
            that is closed before this method returns.

        Raises:
            ValueError: If user not found
            InsufficientCreditsError: If user has insufficient credits
        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)

            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)

            logger.info(
                "Credits validated for user %s: %s credits available, %s required",
                user_id,
                user.credits,
                action.cost,
            )
            return user, action
        finally:
            await session.close()

    async def deduct_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
        *,
        success: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Deduct credits from user account and record transaction.

        Actions flagged ``requires_success`` are only charged when ``success``
        is true. For a failed action of that kind, a zero-amount transaction
        is still recorded for auditing and the balance is left unchanged.

        Args:
            user_id: The user ID
            action_type: The type of action
            success: Whether the action was successful
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            InsufficientCreditsError: If user has insufficient credits
            ValueError: If user not found
        """
        action = get_credit_action(action_type)

        # Charge iff the action succeeded or does not require success.
        if action.requires_success and not success:
            logger.info(
                "Skipping credit deduction for user %s: "
                "action %s failed and requires success",
                user_id,
                action_type.value,
            )
            # Still create a transaction record for auditing
            return await self._create_transaction_record(
                user_id,
                action,
                0,
                success=success,
                metadata=metadata,
            )

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)

            # NOTE(review): read-check-write with no row lock visible here
            # (unless UserRepository.get_by_id locks). Two concurrent
            # deductions for the same user could both pass this check, and the
            # later update would overwrite the earlier one. Consider
            # SELECT ... FOR UPDATE or an atomic UPDATE in the repository.
            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)

            balance_before = user.credits
            balance_after = balance_before - action.cost

            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action_type.value,
                amount=-action.cost,
                balance_before=balance_before,
                balance_after=balance_after,
                description=action.description,
                success=success,
                metadata_json=json.dumps(metadata) if metadata else None,
            )

            # Balance update and ledger row are committed together.
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

        logger.info(
            "Credits deducted for user %s: %s credits (action: %s, success: %s)",
            user_id,
            action.cost,
            action_type.value,
            success,
        )

        # Notify after the session is closed so the DB connection is not held
        # while waiting on network I/O.
        await self._emit_credits_changed(
            user_id,
            {
                "user_id": str(user_id),
                "credits_before": balance_before,
                "credits_after": balance_after,
                "credits_deducted": action.cost,
                "action_type": action_type.value,
                "success": success,
            },
        )
        return transaction

    async def add_credits(
        self,
        user_id: int,
        amount: int,
        description: str,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Add credits to user account.

        Args:
            user_id: The user ID
            amount: Number of credits to add (must be > 0)
            description: Description of the credit addition
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            ValueError: If user not found or amount is not positive (zero or
                negative)
        """
        if amount <= 0:
            msg = "Amount must be positive"
            raise ValueError(msg)

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)

            balance_before = user.credits
            balance_after = balance_before + amount

            transaction = CreditTransaction(
                user_id=user_id,
                action_type="credit_addition",
                amount=amount,
                balance_before=balance_before,
                balance_after=balance_after,
                description=description,
                success=True,
                metadata_json=json.dumps(metadata) if metadata else None,
            )

            # Balance update and ledger row are committed together.
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

        logger.info(
            "Credits added for user %s: %s credits (description: %s)",
            user_id,
            amount,
            description,
        )

        await self._emit_credits_changed(
            user_id,
            {
                "user_id": str(user_id),
                "credits_before": balance_before,
                "credits_after": balance_after,
                "credits_added": amount,
                "description": description,
                "success": True,
            },
        )
        return transaction

    async def _emit_credits_changed(
        self,
        user_id: int,
        event_data: dict[str, Any],
    ) -> None:
        """Best-effort push of a ``user_credits_changed`` event to one user.

        Any failure is logged and swallowed. The balance change has already
        been committed, so a notification problem must not surface as an
        error to the caller.

        Args:
            user_id: The user to notify
            event_data: Event payload sent as-is
        """
        try:
            await socket_manager.send_to_user(
                str(user_id),
                "user_credits_changed",
                event_data,
            )
            logger.info("Emitted user_credits_changed event for user %s", user_id)
        except Exception:
            logger.exception(
                "Failed to emit user_credits_changed event for user %s",
                user_id,
            )

    async def _create_transaction_record(
        self,
        user_id: int,
        action: CreditAction,
        amount: int,
        *,
        success: bool,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Create a transaction record without modifying credits.

        Both ``balance_before`` and ``balance_after`` are set to the user's
        current balance. The description gets a " (failed)" suffix when
        ``success`` is false.

        Args:
            user_id: The user ID
            action: The credit action
            amount: Amount to record (typically 0 for failed actions)
            success: Whether the action was successful
            metadata: Optional metadata

        Returns:
            The created transaction record

        Raises:
            ValueError: If user not found
        """
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)

            description = (
                action.description if success else f"{action.description} (failed)"
            )
            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action.action_type.value,
                amount=amount,
                balance_before=user.credits,
                balance_after=user.credits,
                description=description,
                success=success,
                metadata_json=json.dumps(metadata) if metadata else None,
            )

            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
        return transaction

    async def get_user_balance(self, user_id: int) -> int:
        """Get current credit balance for a user.

        Args:
            user_id: The user ID

        Returns:
            Current credit balance

        Raises:
            ValueError: If user not found
        """
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)
            return user.credits
        finally:
            await session.close()

    async def recharge_user_credits_auto(
        self,
        user_id: int,
    ) -> CreditTransaction | None:
        """Recharge credits for a user automatically based on their plan.

        Args:
            user_id: The user ID

        Returns:
            The created credit transaction if credits were added, None if no
            recharge needed

        Raises:
            ValueError: If user not found or has no plan
        """
        # Read the plan limits, then release this session before the write
        # below opens its own, so one call never holds two connections.
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id_with_plan(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)

            if not user.plan:
                msg = f"User {user_id} has no plan assigned"
                raise ValueError(msg)

            plan_credits = user.plan.credits
            max_credits = user.plan.max_credits
        finally:
            await session.close()

        return await self.recharge_user_credits(user_id, plan_credits, max_credits)

    async def recharge_user_credits(
        self,
        user_id: int,
        plan_credits: int,
        max_credits: int,
    ) -> CreditTransaction | None:
        """Recharge credits for a user based on their plan.

        Adds ``plan_credits`` to the balance, capped at ``max_credits``.

        Args:
            user_id: The user ID
            plan_credits: Number of credits from the plan
            max_credits: Maximum credits allowed for the plan

        Returns:
            The created credit transaction if credits were added, None if no
            recharge needed

        Raises:
            ValueError: If user not found
        """
        transaction, _ = await self._apply_recharge(user_id, plan_credits, max_credits)
        return transaction

    async def _apply_recharge(
        self,
        user_id: int,
        plan_credits: int,
        max_credits: int,
    ) -> tuple[CreditTransaction | None, int]:
        """Top up one user's balance and report how much was actually added.

        Shared by ``recharge_user_credits`` and ``recharge_all_users_credits``.
        Returning the added amount as a plain int lets the batch job total it
        without reading attributes off a transaction whose session is closed.

        Returns:
            ``(transaction, credits_added)``, or ``(None, 0)`` when the user
            is already at or above ``max_credits``.

        Raises:
            ValueError: If user not found
        """
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await user_repo.get_by_id(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)

            # Calculate credits to add (can't exceed max_credits)
            current_credits = user.credits
            target_credits = min(current_credits + plan_credits, max_credits)
            credits_to_add = target_credits - current_credits

            if credits_to_add <= 0:
                logger.info(
                    "No credits to add for user %s: current=%s, "
                    "plan_credits=%s, max=%s",
                    user_id,
                    current_credits,
                    plan_credits,
                    max_credits,
                )
                return None, 0

            transaction = CreditTransaction(
                user_id=user_id,
                action_type=CreditActionType.DAILY_RECHARGE.value,
                amount=credits_to_add,
                balance_before=current_credits,
                balance_after=target_credits,
                description="Daily credit recharge",
                success=True,
                metadata_json=json.dumps(
                    {
                        "plan_credits": plan_credits,
                        "max_credits": max_credits,
                    },
                ),
            )

            # Balance update and ledger row are committed together.
            await user_repo.update(user, {"credits": target_credits})
            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

        logger.info(
            "Credits recharged for user %s: %s credits added (balance: %s → %s)",
            user_id,
            credits_to_add,
            current_credits,
            target_credits,
        )

        await self._emit_credits_changed(
            user_id,
            {
                "user_id": str(user_id),
                "credits_before": current_credits,
                "credits_after": target_credits,
                "credits_added": credits_to_add,
                "description": "Daily credit recharge",
                "success": True,
            },
        )
        return transaction, credits_to_add

    async def recharge_all_users_credits(self) -> dict[str, int]:
        """Recharge credits for all users based on their plans.

        Users are read in pages of 100. Each user is recharged in its own
        session, so one user's failure does not roll back the others.

        A user is counted as skipped when any of the following holds:
        - it has no id;
        - it has no plan;
        - it no longer exists by the time it is recharged;
        - it is already at its plan's cap.

        Returns:
            Dictionary with ``total_users``, ``recharged_users``,
            ``skipped_users`` and ``total_credits_added``.
        """
        session = self.db_session_factory()
        stats = {
            "total_users": 0,
            "recharged_users": 0,
            "skipped_users": 0,
            "total_credits_added": 0,
        }

        try:
            user_repo = UserRepository(session)

            # Process users in batches to avoid memory issues
            offset = 0
            batch_size = 100

            while True:
                users = await user_repo.get_all_with_plan(
                    limit=batch_size,
                    offset=offset,
                )
                if not users:
                    break

                for user in users:
                    stats["total_users"] += 1

                    # Users without an id or a plan cannot be recharged.
                    # Mirrors the plan check in recharge_user_credits_auto.
                    if user.id is None or user.plan is None:
                        stats["skipped_users"] += 1
                        continue

                    try:
                        transaction, credits_added = await self._apply_recharge(
                            user.id,
                            user.plan.credits,
                            user.plan.max_credits,
                        )
                    except ValueError as exc:
                        # The user disappeared between the batch read and the
                        # recharge; don't abort the whole run for it.
                        logger.warning(
                            "Skipping credit recharge for user %s: %s",
                            user.id,
                            exc,
                        )
                        stats["skipped_users"] += 1
                        continue

                    if transaction is None:
                        stats["skipped_users"] += 1
                    else:
                        stats["recharged_users"] += 1
                        stats["total_credits_added"] += credits_added

                offset += batch_size

                # Break if we got fewer users than batch_size (last batch)
                if len(users) < batch_size:
                    break

            logger.info(
                "Daily credit recharge completed: %s total users, "
                "%s recharged, %s skipped, %s total credits added",
                stats["total_users"],
                stats["recharged_users"],
                stats["skipped_users"],
                stats["total_credits_added"],
            )

            return stats

        finally:
            await session.close()
|