Files
sdb2-backend/app/services/credit.py
JSC bccfcafe0e
Some checks failed
Backend CI / lint (push) Failing after 10s
Backend CI / test (push) Failing after 1m37s
feat: Update CORS origins to allow Chrome extensions and improve logging in migration tool
2025-09-19 16:41:11 +02:00

628 lines
20 KiB
Python

"""Credit management service for tracking and validating user credit usage."""
import json
from collections.abc import Callable
from typing import Any
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
# Module-level logger used by the credit service classes below.
logger = get_logger(__name__)
class InsufficientCreditsError(Exception):
    """Signal that a user's balance cannot cover the cost of an action.

    Attributes:
        required: Credits the action costs.
        available: Credits the user held when the check was made.

    """

    def __init__(self, required: int, available: int) -> None:
        """Record the shortfall details and build the error message.

        Args:
            required: Number of credits required
            available: Number of credits available

        """
        self.required, self.available = required, available
        message = (
            f"Insufficient credits: {required} required, {available} available"
        )
        super().__init__(message)
class CreditService:
    """Service for managing user credits and transactions.

    Every public method creates its own session from ``db_session_factory``
    and closes it before returning.
    """

    def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
        """Initialize the credit service.

        Args:
            db_session_factory: Factory function to create database sessions

        """
        self.db_session_factory = db_session_factory

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    async def _get_user_or_raise(user_repo: UserRepository, user_id: int) -> User:
        """Load a user by ID through ``user_repo``.

        Raises:
            ValueError: If no user with ``user_id`` exists

        """
        user = await user_repo.get_by_id(user_id)
        if not user:
            msg = f"User {user_id} not found"
            raise ValueError(msg)
        return user

    @staticmethod
    def _compute_recharge(
        current_credits: int,
        plan_credits: int,
        max_credits: int,
    ) -> tuple[int, int]:
        """Return ``(target_credits, credits_to_add)`` for a plan recharge.

        The balance is topped up by ``plan_credits`` but capped at
        ``max_credits``. ``credits_to_add`` is <= 0 when no recharge applies
        (e.g. the balance is already at or above the cap).
        """
        target_credits = min(current_credits + plan_credits, max_credits)
        return target_credits, target_credits - current_credits

    @staticmethod
    async def _emit_credits_changed(user_id: int, event_data: dict[str, Any]) -> None:
        """Send a ``user_credits_changed`` WebSocket event, best effort.

        Any exception is logged and swallowed: callers invoke this only after
        the credit change has been committed, so a notification failure must
        not be reported as a failure of the credit operation itself.
        """
        try:
            await socket_manager.send_to_user(
                str(user_id),
                "user_credits_changed",
                event_data,
            )
            logger.info("Emitted user_credits_changed event for user %s", user_id)
        except Exception:
            logger.exception(
                "Failed to emit user_credits_changed event for user %s",
                user_id,
            )

    # ------------------------------------------------------------------
    # Read-only checks
    # ------------------------------------------------------------------

    async def check_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
    ) -> bool:
        """Check if user has sufficient credits for an action.

        Args:
            user_id: The user ID
            action_type: The type of action to check

        Returns:
            True if the user exists and has at least the action's cost in
            credits, False otherwise (including when the user is not found)

        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user = await UserRepository(session).get_by_id(user_id)
            if not user:
                return False
            return user.credits >= action.cost
        finally:
            await session.close()

    async def validate_and_reserve_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
    ) -> tuple[User, CreditAction]:
        """Validate that a user has sufficient credits for an action.

        Despite its name, this method does NOT reserve or deduct anything; it
        only checks the current balance. Call ``deduct_credits`` to charge.

        Args:
            user_id: The user ID
            action_type: The type of action

        Returns:
            Tuple of (user, credit_action)

        Raises:
            ValueError: If user not found
            InsufficientCreditsError: If user has insufficient credits

        """
        action = get_credit_action(action_type)
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)
            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)
            logger.info(
                "Credits validated for user %s: %s credits available, %s required",
                user_id,
                user.credits,
                action.cost,
            )
            return user, action
        finally:
            await session.close()

    # ------------------------------------------------------------------
    # Balance mutations
    # ------------------------------------------------------------------

    async def deduct_credits(
        self,
        user_id: int,
        action_type: CreditActionType,
        *,
        success: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Deduct credits from user account and record transaction.

        When the action requires success and ``success`` is False, no credits
        are deducted; a zero-amount transaction is still recorded for auditing.

        Args:
            user_id: The user ID
            action_type: The type of action
            success: Whether the action was successful
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            InsufficientCreditsError: If user has insufficient credits
            ValueError: If user not found

        """
        action = get_credit_action(action_type)
        # Charge when the action succeeded, or when the action is billed
        # regardless of its outcome.
        should_deduct = success or not action.requires_success
        if not should_deduct:
            logger.info(
                "Skipping credit deduction for user %s: "
                "action %s failed and requires success",
                user_id,
                action_type.value,
            )
            # Still create a transaction record for auditing
            return await self._create_transaction_record(
                user_id,
                action,
                0,
                success=success,
                metadata=metadata,
            )

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await self._get_user_or_raise(user_repo, user_id)
            # NOTE(review): the balance is read and written back without an
            # explicit row lock, so two concurrent deductions for the same user
            # could both pass this check. Confirm whether UserRepository or the
            # database isolation level guards against lost updates.
            if user.credits < action.cost:
                raise InsufficientCreditsError(action.cost, user.credits)
            balance_before = user.credits
            balance_after = balance_before - action.cost
            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action_type.value,
                amount=-action.cost,
                balance_before=balance_before,
                balance_after=balance_after,
                description=action.description,
                success=success,
                metadata_json=json.dumps(metadata) if metadata else None,
            )
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

        logger.info(
            "Credits deducted for user %s: %s credits (action: %s, success: %s)",
            user_id,
            action.cost,
            action_type.value,
            success,
        )
        # Notify after the session is released so socket I/O never holds a
        # database connection.
        await self._emit_credits_changed(
            user_id,
            {
                "user_id": str(user_id),
                "credits_before": balance_before,
                "credits_after": balance_after,
                "credits_deducted": action.cost,
                "action_type": action_type.value,
                "success": success,
            },
        )
        return transaction

    async def add_credits(
        self,
        user_id: int,
        amount: int,
        description: str,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Add credits to user account.

        Args:
            user_id: The user ID
            amount: Number of credits to add (must be > 0)
            description: Description of the credit addition
            metadata: Optional metadata to store with transaction

        Returns:
            The created credit transaction

        Raises:
            ValueError: If user not found or amount is not positive

        """
        if amount <= 0:
            msg = "Amount must be positive"
            raise ValueError(msg)

        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await self._get_user_or_raise(user_repo, user_id)
            balance_before = user.credits
            balance_after = balance_before + amount
            transaction = CreditTransaction(
                user_id=user_id,
                action_type="credit_addition",
                amount=amount,
                balance_before=balance_before,
                balance_after=balance_after,
                description=description,
                success=True,
                metadata_json=json.dumps(metadata) if metadata else None,
            )
            await user_repo.update(user, {"credits": balance_after})
            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

        logger.info(
            "Credits added for user %s: %s credits (description: %s)",
            user_id,
            amount,
            description,
        )
        await self._emit_credits_changed(
            user_id,
            {
                "user_id": str(user_id),
                "credits_before": balance_before,
                "credits_after": balance_after,
                "credits_added": amount,
                "description": description,
                "success": True,
            },
        )
        return transaction

    async def _create_transaction_record(
        self,
        user_id: int,
        action: CreditAction,
        amount: int,
        *,
        success: bool,
        metadata: dict[str, Any] | None = None,
    ) -> CreditTransaction:
        """Create a transaction record without modifying credits.

        The record's before/after balances are both the user's current
        balance. The description gets a " (failed)" suffix when ``success``
        is False.

        Args:
            user_id: The user ID
            action: The credit action
            amount: Amount to record (typically 0 for failed actions)
            success: Whether the action was successful
            metadata: Optional metadata

        Returns:
            The created transaction record

        Raises:
            ValueError: If user not found

        """
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)
            transaction = CreditTransaction(
                user_id=user_id,
                action_type=action.action_type.value,
                amount=amount,
                balance_before=user.credits,
                balance_after=user.credits,
                description=(
                    f"{action.description} (failed)"
                    if not success
                    else action.description
                ),
                success=success,
                metadata_json=json.dumps(metadata) if metadata else None,
            )
            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        else:
            return transaction
        finally:
            await session.close()

    async def get_user_balance(self, user_id: int) -> int:
        """Get current credit balance for a user.

        Args:
            user_id: The user ID

        Returns:
            Current credit balance

        Raises:
            ValueError: If user not found

        """
        session = self.db_session_factory()
        try:
            user = await self._get_user_or_raise(UserRepository(session), user_id)
            return user.credits
        finally:
            await session.close()

    # ------------------------------------------------------------------
    # Plan-based recharge
    # ------------------------------------------------------------------

    async def recharge_user_credits_auto(
        self,
        user_id: int,
    ) -> CreditTransaction | None:
        """Recharge credits for a user automatically based on their plan.

        Args:
            user_id: The user ID

        Returns:
            The created credit transaction if credits were added, None if no
            recharge needed

        Raises:
            ValueError: If user not found or has no plan

        """
        session = self.db_session_factory()
        try:
            user = await UserRepository(session).get_by_id_with_plan(user_id)
            if not user:
                msg = f"User {user_id} not found"
                raise ValueError(msg)
            if not user.plan:
                msg = f"User {user_id} has no plan assigned"
                raise ValueError(msg)
            plan_credits = user.plan.credits
            max_credits = user.plan.max_credits
        finally:
            # Release this read-only session before delegating:
            # recharge_user_credits opens its own session.
            await session.close()

        return await self.recharge_user_credits(user_id, plan_credits, max_credits)

    async def recharge_user_credits(
        self,
        user_id: int,
        plan_credits: int,
        max_credits: int,
    ) -> CreditTransaction | None:
        """Recharge credits for a user based on their plan.

        The balance grows by ``plan_credits`` but is capped at
        ``max_credits``; nothing is written when that yields no increase.

        Args:
            user_id: The user ID
            plan_credits: Number of credits from the plan
            max_credits: Maximum credits allowed for the plan

        Returns:
            The created credit transaction if credits were added, None if no
            recharge needed

        Raises:
            ValueError: If user not found

        """
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            user = await self._get_user_or_raise(user_repo, user_id)
            current_credits = user.credits
            target_credits, credits_to_add = self._compute_recharge(
                current_credits,
                plan_credits,
                max_credits,
            )
            if credits_to_add <= 0:
                logger.info(
                    "No credits to add for user %s: current=%s, "
                    "plan_credits=%s, max=%s",
                    user_id,
                    current_credits,
                    plan_credits,
                    max_credits,
                )
                return None
            transaction = CreditTransaction(
                user_id=user_id,
                action_type=CreditActionType.DAILY_RECHARGE.value,
                amount=credits_to_add,
                balance_before=current_credits,
                balance_after=target_credits,
                description="Daily credit recharge",
                success=True,
                metadata_json=json.dumps(
                    {
                        "plan_credits": plan_credits,
                        "max_credits": max_credits,
                    },
                ),
            )
            await user_repo.update(user, {"credits": target_credits})
            session.add(transaction)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

        logger.info(
            "Credits recharged for user %s: %s credits added (balance: %s -> %s)",
            user_id,
            credits_to_add,
            current_credits,
            target_credits,
        )
        await self._emit_credits_changed(
            user_id,
            {
                "user_id": str(user_id),
                "credits_before": current_credits,
                "credits_after": target_credits,
                "credits_added": credits_to_add,
                "description": "Daily credit recharge",
                "success": True,
            },
        )
        return transaction

    async def recharge_all_users_credits(self) -> dict[str, int]:
        """Recharge credits for all users based on their plans.

        Users are fetched in pages of 100. Users without an ID are counted in
        ``total_users`` only; users without a plan, or whose balance needs no
        top-up, are counted as skipped.

        Returns:
            Dictionary with statistics about the recharge operation
            (``total_users``, ``recharged_users``, ``skipped_users``,
            ``total_credits_added``)

        """
        stats = {
            "total_users": 0,
            "recharged_users": 0,
            "skipped_users": 0,
            "total_credits_added": 0,
        }
        batch_size = 100
        offset = 0
        session = self.db_session_factory()
        try:
            user_repo = UserRepository(session)
            while True:
                users = await user_repo.get_all_with_plan(
                    limit=batch_size,
                    offset=offset,
                )
                if not users:
                    break

                for user in users:
                    stats["total_users"] += 1
                    # Skip users without ID (shouldn't happen in practice)
                    if user.id is None:
                        continue
                    # A user without a plan cannot be recharged; previously
                    # this raised AttributeError and aborted the whole batch.
                    if not user.plan:
                        logger.warning(
                            "Skipping credit recharge for user %s: no plan assigned",
                            user.id,
                        )
                        stats["skipped_users"] += 1
                        continue

                    # Snapshot values from this session before delegating; the
                    # transaction returned by recharge_user_credits belongs to
                    # an already-closed session, so its attributes are not read.
                    current_credits = user.credits
                    plan_credits = user.plan.credits
                    max_credits = user.plan.max_credits
                    transaction = await self.recharge_user_credits(
                        user.id,
                        plan_credits,
                        max_credits,
                    )
                    if transaction:
                        stats["recharged_users"] += 1
                        _, credits_added = self._compute_recharge(
                            current_credits,
                            plan_credits,
                            max_credits,
                        )
                        stats["total_credits_added"] += credits_added
                    else:
                        stats["skipped_users"] += 1

                # A short page means this was the last one.
                if len(users) < batch_size:
                    break
                offset += batch_size
        finally:
            await session.close()

        logger.info(
            "Daily credit recharge completed: %s total users, "
            "%s recharged, %s skipped, %s total credits added",
            stats["total_users"],
            stats["recharged_users"],
            stats["skipped_users"],
            stats["total_credits_added"],
        )
        return stats