fix: Utils lint fixes
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""Decorators for credit management and validation."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import types
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
@@ -16,7 +18,7 @@ def requires_credits(
|
||||
user_id_param: str = "user_id",
|
||||
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to enforce credit requirements for actions.
|
||||
"""Enforce credit requirements for actions.
|
||||
|
||||
Args:
|
||||
action_type: The type of action that requires credits
|
||||
@@ -40,14 +42,13 @@ def requires_credits(
|
||||
"""
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
@@ -74,14 +75,14 @@ def requires_credits(
|
||||
|
||||
# Execute the function
|
||||
success = False
|
||||
result = None
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
success = bool(result) # Consider function result as success indicator
|
||||
return result
|
||||
except Exception:
|
||||
success = False
|
||||
raise
|
||||
else:
|
||||
return result
|
||||
finally:
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
@@ -97,7 +98,7 @@ def validate_credits_only(
|
||||
credit_service_factory: Callable[[], CreditService],
|
||||
user_id_param: str = "user_id",
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to only validate credits without deducting them.
|
||||
"""Validate credits without deducting them.
|
||||
|
||||
Useful for checking if a user can perform an action before actual execution.
|
||||
|
||||
@@ -112,14 +113,13 @@ def validate_credits_only(
|
||||
"""
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
@@ -178,7 +178,12 @@ class CreditManager:
|
||||
self.validated = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: types.TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit context manager - deduct credits based on success."""
|
||||
if self.validated:
|
||||
# If no exception occurred, consider it successful
|
||||
|
||||
Reference in New Issue
Block a user