"""Decorators for credit management and validation."""

import functools
import inspect
import types
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, cast

from app.models.credit_action import CreditActionType
from app.services.credit import CreditService

F = TypeVar("F", bound=Callable[..., Awaitable[Any]])

# Parameter kinds that can be filled from a positional argument.
_POSITIONAL_KINDS = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def _user_id_position(func: Callable[..., Any], user_id_param: str) -> int | None:
    """Return the positional index of ``user_id_param`` in ``func``'s signature.

    Only parameters that can actually receive a positional argument are
    considered. Scanning stops at the first ``*args``, keyword-only or
    ``**kwargs`` parameter. Anything declared after ``*args`` can never be
    bound positionally, and indexing past it would pick up an unrelated
    argument.

    Returns:
        The index into ``args`` where the user ID would appear, or ``None`` if
        the parameter cannot be passed positionally or ``func`` exposes no
        introspectable signature.
    """
    try:
        parameters = inspect.signature(func).parameters.values()
    except (TypeError, ValueError):
        # No signature available: only keyword extraction is possible.
        return None
    for index, param in enumerate(parameters):
        if param.kind not in _POSITIONAL_KINDS:
            break
        if param.name == user_id_param:
            return index
    return None


def _extract_user_id(
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    user_id_param: str,
    position: int | None,
) -> Any:  # noqa: ANN401
    """Pull the user ID out of a call's arguments.

    A keyword argument named ``user_id_param`` takes precedence. Otherwise the
    positional argument at ``position`` is used, where ``position`` is as
    computed by ``_user_id_position``.

    Raises:
        ValueError: If no user ID was supplied, or it was supplied as ``None``.
    """
    if user_id_param in kwargs:
        user_id = kwargs[user_id_param]
    elif position is not None and position < len(args):
        user_id = args[position]
    else:
        user_id = None
    if user_id is None:
        msg = f"Could not extract user_id from parameter '{user_id_param}'"
        raise ValueError(msg)
    return user_id


def requires_credits(
    action_type: CreditActionType,
    credit_service_factory: Callable[[], CreditService],
    user_id_param: str = "user_id",
    metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
    """Enforce credit requirements for actions.

    Before the wrapped coroutine runs, the credit service's
    ``validate_and_reserve_credits`` is awaited for ``action_type``. After the
    coroutine finishes, ``deduct_credits`` is always awaited, whether the
    coroutine returned or raised. ``success`` is the truthiness of the return
    value, or ``False`` if the call raised.

    Args:
        action_type: The type of action that requires credits
        credit_service_factory: Factory to create credit service instance
        user_id_param: Name of the parameter containing user ID
        metadata_extractor: Optional function to extract metadata from function args;
            it is called with the same arguments as the wrapped function

    Returns:
        Decorated function that validates and deducts credits

    Raises:
        ValueError: (from the wrapper) if the user ID cannot be found in the
            call's arguments.

    Example:
        @requires_credits(
            CreditActionType.VLC_PLAY_SOUND,
            lambda: get_credit_service(),
            user_id_param="user_id"
        )
        async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
            # Implementation here
            return True

    """

    def decorator(func: F) -> F:
        # Resolve the signature once per decorated function, not on every call.
        user_id_position = _user_id_position(func, user_id_param)

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
            user_id = _extract_user_id(args, kwargs, user_id_param, user_id_position)

            metadata = None
            if metadata_extractor:
                metadata = metadata_extractor(*args, **kwargs)

            credit_service = credit_service_factory()

            # Validate credits before execution.
            await credit_service.validate_and_reserve_credits(
                user_id,
                action_type,
            )

            # Stays False if func raises (including cancellation), so the
            # action is settled as a failure in the finally clause.
            success = False
            try:
                result = await func(*args, **kwargs)
                # The function's return value doubles as its success flag.
                success = bool(result)
            finally:
                await credit_service.deduct_credits(
                    user_id,
                    action_type,
                    success=success,
                    metadata=metadata,
                )
            return result

        return cast(F, wrapper)

    return decorator


def validate_credits_only(
    action_type: CreditActionType,
    credit_service_factory: Callable[[], CreditService],
    user_id_param: str = "user_id",
) -> Callable[[F], F]:
    """Validate credits without deducting them.

    Useful for checking if a user can perform an action before actual execution.
    The wrapped coroutine runs only after ``validate_and_reserve_credits``
    completes. Any exception from the service propagates, and the coroutine
    is not called.

    NOTE(review): this awaits ``validate_and_reserve_credits`` but never
    ``deduct_credits``. Confirm that the service does not leave a reservation
    outstanding in this case.

    Args:
        action_type: The type of action that requires credits
        credit_service_factory: Factory to create credit service instance
        user_id_param: Name of the parameter containing user ID

    Returns:
        Decorated function that validates credits only

    Raises:
        ValueError: (from the wrapper) if the user ID cannot be found in the
            call's arguments.

    """

    def decorator(func: F) -> F:
        # Resolve the signature once per decorated function, not on every call.
        user_id_position = _user_id_position(func, user_id_param)

        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
            user_id = _extract_user_id(args, kwargs, user_id_param, user_id_position)

            credit_service = credit_service_factory()

            # Validate credits only; nothing is deducted afterwards.
            await credit_service.validate_and_reserve_credits(user_id, action_type)

            return await func(*args, **kwargs)

        return cast(F, wrapper)

    return decorator


class CreditManager:
    """Async context manager for credit operations.

    On entry, credits for ``action_type`` are validated via
    ``validate_and_reserve_credits``. On exit, ``deduct_credits`` is awaited.
    ``success=True`` is passed only if the block raised no exception *and*
    ``mark_success()`` was called; otherwise ``success=False`` is passed.
    Exceptions raised inside the block are never suppressed.

    Example:
        async with CreditManager(service, user_id, action_type) as credits:
            ...
            credits.mark_success()

    """

    def __init__(
        self,
        credit_service: CreditService,
        user_id: int,
        action_type: CreditActionType,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Initialize credit manager.

        Args:
            credit_service: Credit service instance
            user_id: User ID
            action_type: Action type
            metadata: Optional metadata, forwarded to ``deduct_credits`` on exit

        """
        self.credit_service = credit_service
        self.user_id = user_id
        self.action_type = action_type
        self.metadata = metadata
        # Set once validation on entry succeeds; deduction happens only then.
        self.validated = False
        # Set by mark_success(); required for a successful deduction.
        self.success = False

    async def __aenter__(self) -> "CreditManager":
        """Enter context manager - validate credits.

        If validation raises, ``validated`` stays ``False``. In that case
        ``__aexit__`` is not invoked by ``async with`` and nothing is deducted.
        """
        await self.credit_service.validate_and_reserve_credits(
            self.user_id,
            self.action_type,
        )
        self.validated = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Exit context manager - deduct credits based on success.

        Returns ``None``, so any exception raised in the block propagates.
        """
        if self.validated:
            # Successful only if the block did not raise AND the caller
            # explicitly called mark_success().
            success = exc_type is None and self.success
            await self.credit_service.deduct_credits(
                self.user_id,
                self.action_type,
                success=success,
                metadata=self.metadata,
            )

    def mark_success(self) -> None:
        """Mark the operation as successful."""
        self.success = True