Files
sdb2-backend/app/services/credit.py
JSC bee1076239
All checks were successful
Backend CI / lint (push) Successful in 9m21s
Backend CI / test (push) Successful in 4m18s
refactor: Improve exception handling and logging in authentication and playlist services; enhance code readability and structure
2025-08-13 00:04:55 +02:00

582 lines
18 KiB
Python

"""Credit management service for tracking and validating user credit usage."""
import json
from collections.abc import Callable
from typing import Any
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)
class InsufficientCreditsError(Exception):
    """Raised when user has insufficient credits for an action.

    The required and available amounts are kept as attributes so callers can
    build their own messages. Instances survive ``copy`` and ``pickle``
    round-trips with both amounts intact.
    """

    def __init__(self, required: int, available: int) -> None:
        """Initialize the error.

        Args:
            required: Number of credits required
            available: Number of credits available
        """
        self.required = required
        self.available = available
        super().__init__(
            f"Insufficient credits: {required} required, {available} available",
        )

    def __reduce__(self) -> tuple[type, tuple[int, int]]:
        """Describe how to rebuild this error when it is copied or pickled.

        The default implementation replays ``self.args`` (the single
        formatted message) into ``__init__``, which needs two integers and
        would therefore fail with ``TypeError``.

        Returns:
            The class and the ``(required, available)`` constructor arguments
        """
        return (type(self), (self.required, self.available))
class CreditService:
    """Service that checks, deducts, adds and recharges user credits."""

    def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
        """Remember the factory used to open a database session per operation.

        Args:
            db_session_factory: Zero-argument callable returning a database
                session
        """
        self.db_session_factory = db_session_factory
async def check_credits(
    self,
    user_id: int,
    action_type: CreditActionType,
) -> bool:
    """Report whether a user can currently afford an action.

    Args:
        user_id: The user ID
        action_type: The type of action to check

    Returns:
        True if the user exists and their balance covers the action cost,
        False otherwise (including when the user cannot be found)
    """
    credit_action = get_credit_action(action_type)
    db = self.db_session_factory()
    try:
        account = await UserRepository(db).get_by_id(user_id)
        # An unknown user can never afford anything.
        if account:
            return account.credits >= credit_action.cost
        return False
    finally:
        await db.close()
async def validate_and_reserve_credits(
    self,
    user_id: int,
    action_type: CreditActionType,
) -> tuple[User, CreditAction]:
    """Validate that a user has sufficient credits for an action.

    Despite the method name, nothing is reserved or deducted here: the
    balance is only read and compared with the action cost.

    Args:
        user_id: The user ID
        action_type: The type of action

    Returns:
        Tuple of (user, credit_action)

    Raises:
        ValueError: If user not found
        InsufficientCreditsError: If user has insufficient credits
    """
    credit_action = get_credit_action(action_type)
    db = self.db_session_factory()
    try:
        account = await UserRepository(db).get_by_id(user_id)
        if not account:
            msg = f"User {user_id} not found"
            raise ValueError(msg)

        if account.credits < credit_action.cost:
            raise InsufficientCreditsError(credit_action.cost, account.credits)

        logger.info(
            "Credits validated for user %s: %s credits available, %s required",
            user_id,
            account.credits,
            credit_action.cost,
        )
        return account, credit_action
    finally:
        await db.close()
async def deduct_credits(
    self,
    user_id: int,
    action_type: CreditActionType,
    *,
    success: bool = True,
    metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
    """Charge a user for an action and persist the matching transaction.

    When the action only charges on success and it failed, no credits are
    taken; a zero-amount transaction is still written for auditing.

    Args:
        user_id: The user ID
        action_type: The type of action
        success: Whether the action was successful
        metadata: Optional metadata to store with transaction

    Returns:
        The created credit transaction

    Raises:
        InsufficientCreditsError: If user has insufficient credits
        ValueError: If user not found
    """
    credit_action = get_credit_action(action_type)

    # Guard clause: a failed action is free when it only charges on success.
    if credit_action.requires_success and not success:
        logger.info(
            "Skipping credit deduction for user %s: "
            "action %s failed and requires success",
            user_id,
            action_type.value,
        )
        # Keep an audit trail even though the balance is untouched.
        return await self._create_transaction_record(
            user_id,
            credit_action,
            0,
            success=success,
            metadata=metadata,
        )

    db = self.db_session_factory()
    try:
        users = UserRepository(db)
        account = await users.get_by_id(user_id)
        if not account:
            msg = f"User {user_id} not found"
            raise ValueError(msg)
        if account.credits < credit_action.cost:
            raise InsufficientCreditsError(credit_action.cost, account.credits)

        # Snapshot both balances so the ledger row and the event agree.
        previous_balance = account.credits
        new_balance = previous_balance - credit_action.cost
        record = CreditTransaction(
            user_id=user_id,
            action_type=action_type.value,
            amount=-credit_action.cost,
            balance_before=previous_balance,
            balance_after=new_balance,
            description=credit_action.description,
            success=success,
            metadata_json=json.dumps(metadata) if metadata else None,
        )

        # Balance update and ledger row are committed together.
        await users.update(account, {"credits": new_balance})
        db.add(record)
        await db.commit()

        logger.info(
            "Credits deducted for user %s: %s credits (action: %s, success: %s)",
            user_id,
            credit_action.cost,
            action_type.value,
            success,
        )

        # Notification is best-effort: a socket failure must not undo
        # the already-committed deduction.
        try:
            payload = {
                "user_id": str(user_id),
                "credits_before": previous_balance,
                "credits_after": new_balance,
                "credits_deducted": credit_action.cost,
                "action_type": action_type.value,
                "success": success,
            }
            await socket_manager.send_to_user(
                str(user_id),
                "user_credits_changed",
                payload,
            )
            logger.info("Emitted user_credits_changed event for user %s", user_id)
        except Exception:
            logger.exception(
                "Failed to emit user_credits_changed event for user %s",
                user_id,
            )
    except Exception:
        await db.rollback()
        raise
    else:
        return record
    finally:
        await db.close()
async def add_credits(
    self,
    user_id: int,
    amount: int,
    description: str,
    metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
    """Grant credits to a user and persist the matching transaction.

    Args:
        user_id: The user ID
        amount: Number of credits to add; must be greater than zero
        description: Description of the credit addition
        metadata: Optional metadata to store with transaction

    Returns:
        The created credit transaction

    Raises:
        ValueError: If user not found or amount is zero or negative
    """
    if amount <= 0:
        msg = "Amount must be positive"
        raise ValueError(msg)

    db = self.db_session_factory()
    try:
        users = UserRepository(db)
        account = await users.get_by_id(user_id)
        if not account:
            msg = f"User {user_id} not found"
            raise ValueError(msg)

        # Snapshot both balances so the ledger row and the event agree.
        previous_balance = account.credits
        new_balance = previous_balance + amount
        record = CreditTransaction(
            user_id=user_id,
            action_type="credit_addition",
            amount=amount,
            balance_before=previous_balance,
            balance_after=new_balance,
            description=description,
            success=True,
            metadata_json=json.dumps(metadata) if metadata else None,
        )

        # Balance update and ledger row are committed together.
        await users.update(account, {"credits": new_balance})
        db.add(record)
        await db.commit()

        logger.info(
            "Credits added for user %s: %s credits (description: %s)",
            user_id,
            amount,
            description,
        )

        # Notification is best-effort: a socket failure must not undo
        # the already-committed addition.
        try:
            payload = {
                "user_id": str(user_id),
                "credits_before": previous_balance,
                "credits_after": new_balance,
                "credits_added": amount,
                "description": description,
                "success": True,
            }
            await socket_manager.send_to_user(
                str(user_id),
                "user_credits_changed",
                payload,
            )
            logger.info("Emitted user_credits_changed event for user %s", user_id)
        except Exception:
            logger.exception(
                "Failed to emit user_credits_changed event for user %s",
                user_id,
            )
    except Exception:
        await db.rollback()
        raise
    else:
        return record
    finally:
        await db.close()
async def _create_transaction_record(
    self,
    user_id: int,
    action: CreditAction,
    amount: int,
    *,
    success: bool,
    metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
    """Persist an audit-only transaction that leaves the balance unchanged.

    The stored ``balance_before`` and ``balance_after`` are both the user's
    current balance; failed actions get " (failed)" appended to the
    description.

    Args:
        user_id: The user ID
        action: The credit action
        amount: Amount to record (typically 0 for failed actions)
        success: Whether the action was successful
        metadata: Optional metadata

    Returns:
        The created transaction record

    Raises:
        ValueError: If user not found
    """
    db = self.db_session_factory()
    try:
        account = await UserRepository(db).get_by_id(user_id)
        if not account:
            msg = f"User {user_id} not found"
            raise ValueError(msg)

        if success:
            label = action.description
        else:
            label = f"{action.description} (failed)"

        record = CreditTransaction(
            user_id=user_id,
            action_type=action.action_type.value,
            amount=amount,
            balance_before=account.credits,
            balance_after=account.credits,
            description=label,
            success=success,
            metadata_json=json.dumps(metadata) if metadata else None,
        )
        db.add(record)
        await db.commit()
    except Exception:
        await db.rollback()
        raise
    else:
        return record
    finally:
        await db.close()
async def get_user_balance(self, user_id: int) -> int:
    """Look up how many credits a user currently holds.

    Args:
        user_id: The user ID

    Returns:
        Current credit balance

    Raises:
        ValueError: If user not found
    """
    db = self.db_session_factory()
    try:
        account = await UserRepository(db).get_by_id(user_id)
        if account:
            return account.credits
        msg = f"User {user_id} not found"
        raise ValueError(msg)
    finally:
        await db.close()
async def recharge_user_credits(
    self,
    user_id: int,
    plan_credits: int,
    max_credits: int,
) -> CreditTransaction | None:
    """Recharge credits for a user based on their plan.

    The balance grows by ``plan_credits`` but never beyond ``max_credits``.
    When that cap leaves nothing to add (including a balance that already
    exceeds the cap), nothing is written and ``None`` is returned.

    Args:
        user_id: The user ID
        plan_credits: Number of credits from the plan
        max_credits: Maximum credits allowed for the plan

    Returns:
        The created credit transaction if credits were added, None if no
        recharge needed

    Raises:
        ValueError: If user not found
    """
    session = self.db_session_factory()
    try:
        user_repo = UserRepository(session)
        user = await user_repo.get_by_id(user_id)
        if not user:
            msg = f"User {user_id} not found"
            raise ValueError(msg)

        # Calculate credits to add (can't exceed max_credits)
        current_credits = user.credits
        target_credits = min(current_credits + plan_credits, max_credits)
        credits_to_add = target_credits - current_credits

        # Zero or negative means the user is already at or above the cap.
        if credits_to_add <= 0:
            logger.info(
                "No credits to add for user %s: current=%s, "
                "plan_credits=%s, max=%s",
                user_id,
                current_credits,
                plan_credits,
                max_credits,
            )
            return None

        # Record transaction
        transaction = CreditTransaction(
            user_id=user_id,
            action_type=CreditActionType.DAILY_RECHARGE.value,
            amount=credits_to_add,
            balance_before=current_credits,
            balance_after=target_credits,
            description="Daily credit recharge",
            success=True,
            metadata_json=json.dumps(
                {
                    "plan_credits": plan_credits,
                    "max_credits": max_credits,
                },
            ),
        )

        # Balance update and ledger row are committed together.
        await user_repo.update(user, {"credits": target_credits})
        session.add(transaction)
        await session.commit()

        # The two balances need a separator; "%s%s" rendered 10 and 20 as
        # the unreadable "1020".
        logger.info(
            "Credits recharged for user %s: %s credits added (balance: %s -> %s)",
            user_id,
            credits_to_add,
            current_credits,
            target_credits,
        )

        # Notification is best-effort: a socket failure must not undo
        # the already-committed recharge.
        try:
            event_data = {
                "user_id": str(user_id),
                "credits_before": current_credits,
                "credits_after": target_credits,
                "credits_added": credits_to_add,
                "description": "Daily credit recharge",
                "success": True,
            }
            await socket_manager.send_to_user(
                str(user_id),
                "user_credits_changed",
                event_data,
            )
            logger.info("Emitted user_credits_changed event for user %s", user_id)
        except Exception:
            logger.exception(
                "Failed to emit user_credits_changed event for user %s",
                user_id,
            )
    except Exception:
        await session.rollback()
        raise
    else:
        return transaction
    finally:
        await session.close()
async def recharge_all_users_credits(self) -> dict[str, int]:
    """Recharge credits for all users based on their plans.

    Users are fetched in fixed-size batches and each one is recharged via
    ``recharge_user_credits``, which opens its own session. An exception
    raised for any single user propagates and stops the run.

    Returns:
        Dictionary with statistics about the recharge operation, with keys
        ``total_users``, ``recharged_users``, ``skipped_users`` and
        ``total_credits_added``. Every counted user is either recharged or
        skipped, so ``total_users == recharged_users + skipped_users``.
    """
    session = self.db_session_factory()
    stats = {
        "total_users": 0,
        "recharged_users": 0,
        "skipped_users": 0,
        "total_credits_added": 0,
    }
    try:
        user_repo = UserRepository(session)

        # Process users in batches to avoid memory issues
        offset = 0
        batch_size = 100
        while True:
            users = await user_repo.get_all_with_plan(
                limit=batch_size,
                offset=offset,
            )
            if not users:
                break

            for user in users:
                stats["total_users"] += 1

                # Users without ID (shouldn't happen in practice) cannot be
                # recharged; count them as skipped so the totals add up.
                if user.id is None:
                    stats["skipped_users"] += 1
                    continue

                transaction = await self.recharge_user_credits(
                    user.id,
                    user.plan.credits,
                    user.plan.max_credits,
                )
                if transaction:
                    stats["recharged_users"] += 1
                    stats["total_credits_added"] += transaction.amount
                else:
                    stats["skipped_users"] += 1

            offset += batch_size
            # A short batch is the last one; skip the extra empty query.
            if len(users) < batch_size:
                break

        logger.info(
            "Daily credit recharge completed: %s total users, "
            "%s recharged, %s skipped, %s total credits added",
            stats["total_users"],
            stats["recharged_users"],
            stats["skipped_users"],
            stats["total_credits_added"],
        )
        return stats
    finally:
        await session.close()