fix: Utils lint fixes

This commit is contained in:
JSC
2025-07-31 21:56:03 +02:00
parent 8847131f24
commit 01bb48c206
4 changed files with 85 additions and 81 deletions

View File

@@ -1,6 +1,8 @@
"""Decorators for credit management and validation."""
import functools
import inspect
import types
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
@@ -16,7 +18,7 @@ def requires_credits(
user_id_param: str = "user_id",
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
"""Decorator to enforce credit requirements for actions.
"""Enforce credit requirements for actions.
Args:
action_type: The type of action that requires credits
@@ -40,14 +42,13 @@ def requires_credits(
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
@@ -74,14 +75,14 @@ def requires_credits(
# Execute the function
success = False
result = None
try:
result = await func(*args, **kwargs)
success = bool(result) # Consider function result as success indicator
return result
except Exception:
success = False
raise
else:
return result
finally:
# Deduct credits based on success
await credit_service.deduct_credits(
@@ -97,7 +98,7 @@ def validate_credits_only(
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
) -> Callable[[F], F]:
"""Decorator to only validate credits without deducting them.
"""Validate credits without deducting them.
Useful for checking if a user can perform an action before actual execution.
@@ -112,14 +113,13 @@ def validate_credits_only(
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
@@ -178,7 +178,12 @@ class CreditManager:
self.validated = True
return self
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
"""Exit context manager - deduct credits based on success."""
if self.validated:
# If no exception occurred, consider it successful