Files
sdb2-backend/app/utils/credit_decorators.py
2025-07-31 21:56:03 +02:00

198 lines
6.6 KiB
Python

"""Decorators for credit management and validation."""
import functools
import inspect
import types
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService
# Type variable for the async callables wrapped by the decorators in this
# module. Bounding it to ``Callable[..., Awaitable[Any]]`` and annotating the
# decorators as ``Callable[[F], F]`` lets them hand back the same callable type
# they received.
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
def requires_credits(
    action_type: CreditActionType,
    credit_service_factory: Callable[[], CreditService],
    user_id_param: str = "user_id",
    metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
    """Enforce credit requirements for actions.

    Before the decorated coroutine runs, ``validate_and_reserve_credits`` is
    awaited on a fresh credit service. After it finishes, whether it returns
    or raises, ``deduct_credits`` is awaited with a success flag.

    Args:
        action_type: The type of action that requires credits
        credit_service_factory: Factory to create credit service instance;
            it is called once per invocation of the decorated function.
        user_id_param: Name of the parameter containing user ID. It is read
            from keyword arguments first. Otherwise it is read from the
            matching positional argument, but only if that parameter can
            actually be passed positionally.
        metadata_extractor: Optional function to extract metadata from function
            args. It receives the same ``*args, **kwargs`` as the decorated
            function.

    Returns:
        Decorated function that validates and deducts credits. The deduction
        is reported as successful only when the function returns a truthy
        value. A falsy return or a raised exception is reported as failure.

    Raises:
        ValueError: If no user ID can be extracted from the call arguments, or
            if the extracted value is ``None``.

    Example:
        @requires_credits(
            CreditActionType.VLC_PLAY_SOUND,
            lambda: get_credit_service(),
            user_id_param="user_id"
        )
        async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
            # Implementation here
            return True
    """

    def decorator(func: F) -> F:
        # Resolve the positional slot of the user-id parameter once, at
        # decoration time, rather than re-inspecting the signature per call.
        # Only positional-capable parameters get a slot. A keyword-only
        # parameter (e.g. declared after ``*args``) has no meaningful index
        # into ``args``, and using its list index would read an unrelated
        # positional argument as the user ID.
        positional_index: int | None = None
        for index, param in enumerate(inspect.signature(func).parameters.values()):
            if param.name == user_id_param:
                if param.kind in (
                    inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                ):
                    positional_index = index
                break

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
            # Extract user ID: an explicit keyword wins, then the positional slot.
            if user_id_param in kwargs:
                user_id = kwargs[user_id_param]
            elif positional_index is not None and positional_index < len(args):
                user_id = args[positional_index]
            else:
                user_id = None
            if user_id is None:
                msg = f"Could not extract user_id from parameter '{user_id_param}'"
                raise ValueError(msg)

            # Extract metadata if extractor provided
            metadata = None
            if metadata_extractor:
                metadata = metadata_extractor(*args, **kwargs)

            credit_service = credit_service_factory()
            # Validate credits before execution; if this raises, func never runs
            # and nothing is deducted.
            await credit_service.validate_and_reserve_credits(
                user_id, action_type, metadata,
            )

            # ``success`` stays False if func (or the truthiness check) raises.
            success = False
            try:
                result = await func(*args, **kwargs)
                # Consider function result as success indicator
                success = bool(result)
            finally:
                # Always settle the credits, then let any exception propagate.
                await credit_service.deduct_credits(
                    user_id, action_type, success, metadata,
                )
            return result

        return wrapper  # type: ignore[return-value]

    return decorator
def validate_credits_only(
    action_type: CreditActionType,
    credit_service_factory: Callable[[], CreditService],
    user_id_param: str = "user_id",
) -> Callable[[F], F]:
    """Validate credits without deducting them.

    Useful for checking if a user can perform an action before actual execution.
    The wrapper awaits ``validate_and_reserve_credits(user_id, action_type)``
    and then runs the decorated coroutine. It never calls ``deduct_credits``.

    Args:
        action_type: The type of action that requires credits
        credit_service_factory: Factory to create credit service instance;
            it is called once per invocation of the decorated function.
        user_id_param: Name of the parameter containing user ID. It is read
            from keyword arguments first. Otherwise it is read from the
            matching positional argument, but only if that parameter can
            actually be passed positionally.

    Returns:
        Decorated function that validates credits only and returns whatever
        the wrapped coroutine returns.

    Raises:
        ValueError: If no user ID can be extracted from the call arguments, or
            if the extracted value is ``None``.
    """

    def decorator(func: F) -> F:
        # Compute the positional slot once at decoration time. Keyword-only
        # parameters (e.g. after ``*args``) get no slot, because their index
        # in the signature does not correspond to an index into ``args``.
        positional_index: int | None = None
        for index, param in enumerate(inspect.signature(func).parameters.values()):
            if param.name == user_id_param:
                if param.kind in (
                    inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                ):
                    positional_index = index
                break

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
            # Extract user ID: an explicit keyword wins, then the positional slot.
            if user_id_param in kwargs:
                user_id = kwargs[user_id_param]
            elif positional_index is not None and positional_index < len(args):
                user_id = args[positional_index]
            else:
                user_id = None
            if user_id is None:
                msg = f"Could not extract user_id from parameter '{user_id_param}'"
                raise ValueError(msg)

            credit_service = credit_service_factory()
            # Validate credits only; if this raises, func is not executed.
            await credit_service.validate_and_reserve_credits(user_id, action_type)
            return await func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator
class CreditManager:
    """Async context manager that brackets an operation with credit calls.

    Entering awaits ``credit_service.validate_and_reserve_credits``. Exiting,
    provided entry finished, awaits ``credit_service.deduct_credits``. The
    success flag passed on exit is true only when the ``async with`` body
    raised no exception AND :meth:`mark_success` was called. Exceptions from
    the body are never suppressed.
    """

    def __init__(
        self,
        credit_service: CreditService,
        user_id: int,
        action_type: CreditActionType,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Remember what to charge; no credit call is made here.

        Args:
            credit_service: Service whose validate/deduct methods are awaited
            user_id: User whose credits are checked and charged
            action_type: Action being paid for
            metadata: Optional metadata forwarded to both service calls
        """
        self.credit_service = credit_service
        self.user_id = user_id
        self.action_type = action_type
        self.metadata = metadata
        # True once __aenter__ has validated; exit only deducts after that.
        self.validated = False
        # Set by mark_success(); still ignored if the body raised.
        self.success = False

    async def __aenter__(self) -> "CreditManager":
        """Validate (and reserve) credits, then hand back this manager."""
        service = self.credit_service
        await service.validate_and_reserve_credits(
            self.user_id,
            self.action_type,
            self.metadata,
        )
        self.validated = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Deduct credits, reporting success only for a clean, marked run."""
        if not self.validated:
            return
        # Requires both: no exception escaped the body, and mark_success() ran.
        completed_cleanly = exc_type is None and self.success
        await self.credit_service.deduct_credits(
            self.user_id,
            self.action_type,
            completed_cleanly,
            self.metadata,
        )

    def mark_success(self) -> None:
        """Flag the guarded operation as having succeeded."""
        self.success = True