Add tests for sound repository, user OAuth repository, credit service, and credit decorators

- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities.
- Create tests for UserOauthRepository to validate OAuth record management.
- Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits.
- Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
JSC
2025-07-30 21:33:55 +02:00
parent dd10ef5d41
commit e43650c26c
14 changed files with 2692 additions and 1 deletions

View File

@@ -7,9 +7,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService
@@ -45,6 +47,12 @@ def get_vlc_player() -> VLCPlayerService:
return get_vlc_player_service(get_session_factory())
def get_credit_service() -> CreditService:
"""Get the credit service."""
from app.core.database import get_session_factory
return CreditService(get_session_factory())
async def get_sound_repository(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundRepository:
@@ -373,8 +381,9 @@ async def play_sound_with_vlc(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
credit_service: Annotated[CreditService, Depends(get_credit_service)],
) -> dict[str, str | int | bool]:
"""Play a sound using VLC subprocess."""
"""Play a sound using VLC subprocess (requires 1 credit)."""
try:
# Get the sound
sound = await sound_repo.get_by_id(sound_id)
@@ -384,9 +393,30 @@ async def play_sound_with_vlc(
detail=f"Sound with ID {sound_id} not found",
)
# Check and validate credits before playing
try:
await credit_service.validate_and_reserve_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
{"sound_id": sound_id, "sound_name": sound.name}
)
except InsufficientCreditsError as e:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Insufficient credits: {e.required} required, {e.available} available",
) from e
# Play the sound using VLC
success = await vlc_player.play_sound(sound)
# Deduct credits based on success
await credit_service.deduct_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
success,
{"sound_id": sound_id, "sound_name": sound.name},
)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -398,6 +428,7 @@ async def play_sound_with_vlc(
"sound_id": sound_id,
"sound_name": sound.name,
"success": True,
"credits_deducted": 1,
}
except HTTPException:

121
app/models/credit_action.py Normal file
View File

@@ -0,0 +1,121 @@
"""Credit action definitions for the credit system."""
from enum import Enum
from typing import Any
class CreditActionType(str, Enum):
"""Types of actions that consume credits."""
VLC_PLAY_SOUND = "vlc_play_sound"
AUDIO_EXTRACTION = "audio_extraction"
TEXT_TO_SPEECH = "text_to_speech"
SOUND_NORMALIZATION = "sound_normalization"
API_REQUEST = "api_request"
PLAYLIST_CREATION = "playlist_creation"
class CreditAction:
"""Definition of a credit-consuming action."""
def __init__(
self,
action_type: CreditActionType,
cost: int,
description: str,
*,
requires_success: bool = True,
) -> None:
"""Initialize a credit action.
Args:
action_type: The type of action
cost: Number of credits required
description: Human-readable description
requires_success: Whether credits are only deducted on successful completion
"""
self.action_type = action_type
self.cost = cost
self.description = description
self.requires_success = requires_success
def __str__(self) -> str:
"""Return string representation of the action."""
return f"{self.action_type.value} ({self.cost} credits)"
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"action_type": self.action_type.value,
"cost": self.cost,
"description": self.description,
"requires_success": self.requires_success,
}
# Predefined credit actions
CREDIT_ACTIONS = {
CreditActionType.VLC_PLAY_SOUND: CreditAction(
action_type=CreditActionType.VLC_PLAY_SOUND,
cost=1,
description="Play a sound using VLC player",
requires_success=True,
),
CreditActionType.AUDIO_EXTRACTION: CreditAction(
action_type=CreditActionType.AUDIO_EXTRACTION,
cost=5,
description="Extract audio from external URL",
requires_success=True,
),
CreditActionType.TEXT_TO_SPEECH: CreditAction(
action_type=CreditActionType.TEXT_TO_SPEECH,
cost=2,
description="Generate speech from text",
requires_success=True,
),
CreditActionType.SOUND_NORMALIZATION: CreditAction(
action_type=CreditActionType.SOUND_NORMALIZATION,
cost=1,
description="Normalize audio levels",
requires_success=True,
),
CreditActionType.API_REQUEST: CreditAction(
action_type=CreditActionType.API_REQUEST,
cost=1,
description="API request (rate limiting)",
requires_success=False, # Charged even if request fails
),
CreditActionType.PLAYLIST_CREATION: CreditAction(
action_type=CreditActionType.PLAYLIST_CREATION,
cost=3,
description="Create a new playlist",
requires_success=True,
),
}
def get_credit_action(action_type: CreditActionType) -> CreditAction:
"""Get a credit action definition by type.
Args:
action_type: The action type to look up
Returns:
The credit action definition
Raises:
KeyError: If action type is not found
"""
return CREDIT_ACTIONS[action_type]
def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
"""Get all available credit actions.
Returns:
Dictionary of all credit actions
"""
return CREDIT_ACTIONS.copy()

View File

@@ -0,0 +1,29 @@
"""Credit transaction model for tracking credit usage."""
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.user import User
class CreditTransaction(BaseModel, table=True):
"""Database model for credit transactions."""
__tablename__ = "credit_transaction" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="user.id", nullable=False)
action_type: str = Field(nullable=False)
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
balance_before: int = Field(nullable=False)
balance_after: int = Field(nullable=False)
description: str = Field(nullable=False)
success: bool = Field(nullable=False, default=True)
# JSON string for additional data
metadata_json: str | None = Field(default=None)
# relationships
user: "User" = Relationship(back_populates="credit_transactions")

View File

@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.credit_transaction import CreditTransaction
from app.models.extraction import Extraction
from app.models.plan import Plan
from app.models.playlist import Playlist
@@ -35,3 +36,4 @@ class User(BaseModel, table=True):
playlists: list["Playlist"] = Relationship(back_populates="user")
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
extractions: list["Extraction"] = Relationship(back_populates="user")
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")

132
app/repositories/base.py Normal file
View File

@@ -0,0 +1,132 @@
"""Base repository with common CRUD operations."""
from typing import Any, Generic, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
# Type variable for the model
ModelType = TypeVar("ModelType")
logger = get_logger(__name__)
class BaseRepository(Generic[ModelType]):
"""Base repository with common CRUD operations."""
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
"""Initialize the repository.
Args:
model: The SQLModel class
session: Database session
"""
self.model = model
self.session = session
async def get_by_id(self, entity_id: int) -> ModelType | None:
"""Get an entity by ID.
Args:
entity_id: The entity ID
Returns:
The entity if found, None otherwise
"""
try:
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
raise
async def get_all(
self,
limit: int = 100,
offset: int = 0,
) -> list[ModelType]:
"""Get all entities with pagination.
Args:
limit: Maximum number of entities to return
offset: Number of entities to skip
Returns:
List of entities
"""
try:
statement = select(self.model).limit(limit).offset(offset)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all %s", self.model.__name__)
raise
async def create(self, entity_data: dict[str, Any]) -> ModelType:
"""Create a new entity.
Args:
entity_data: Dictionary of entity data
Returns:
The created entity
"""
try:
entity = self.model(**entity_data)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to create %s", self.model.__name__)
raise
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
"""Update an entity.
Args:
entity: The entity to update
update_data: Dictionary of fields to update
Returns:
The updated entity
"""
try:
for field, value in update_data.items():
setattr(entity, field, value)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to update %s", self.model.__name__)
raise
async def delete(self, entity: ModelType) -> None:
"""Delete an entity.
Args:
entity: The entity to delete
"""
try:
await self.session.delete(entity)
await self.session.commit()
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
except Exception:
await self.session.rollback()
logger.exception("Failed to delete %s", self.model.__name__)
raise

View File

@@ -0,0 +1,108 @@
"""Repository for credit transaction database operations."""
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_transaction import CreditTransaction
from app.repositories.base import BaseRepository
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
"""Repository for credit transaction operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the repository.
Args:
session: Database session
"""
super().__init__(CreditTransaction, session)
async def get_by_user_id(
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions for a user.
Args:
user_id: The user ID
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.user_id == user_id)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_by_action_type(
self,
action_type: str,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions by action type.
Args:
action_type: The action type to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.action_type == action_type)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_successful_transactions(
self,
user_id: int | None = None,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get successful credit transactions.
Args:
user_id: Optional user ID to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of successful credit transactions
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.success == True) # noqa: E712
)
if user_id is not None:
stmt = stmt.where(CreditTransaction.user_id == user_id)
stmt = (
stmt.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())

383
app/services/credit.py Normal file
View File

@@ -0,0 +1,383 @@
"""Credit management service for tracking and validating user credit usage."""
import json
from collections.abc import Callable
from typing import Any
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
logger = get_logger(__name__)
class InsufficientCreditsError(Exception):
"""Raised when user has insufficient credits for an action."""
def __init__(self, required: int, available: int) -> None:
"""Initialize the error.
Args:
required: Number of credits required
available: Number of credits available
"""
self.required = required
self.available = available
super().__init__(
f"Insufficient credits: {required} required, {available} available"
)
class CreditService:
"""Service for managing user credits and transactions."""
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
"""Initialize the credit service.
Args:
db_session_factory: Factory function to create database sessions
"""
self.db_session_factory = db_session_factory
async def check_credits(
self,
user_id: int,
action_type: CreditActionType,
) -> bool:
"""Check if user has sufficient credits for an action.
Args:
user_id: The user ID
action_type: The type of action to check
Returns:
True if user has sufficient credits, False otherwise
"""
action = get_credit_action(action_type)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
return False
return user.credits >= action.cost
finally:
await session.close()
async def validate_and_reserve_credits(
self,
user_id: int,
action_type: CreditActionType,
metadata: dict[str, Any] | None = None,
) -> tuple[User, CreditAction]:
"""Validate user has sufficient credits and optionally reserve them.
Args:
user_id: The user ID
action_type: The type of action
metadata: Optional metadata to store with transaction
Returns:
Tuple of (user, credit_action)
Raises:
InsufficientCreditsError: If user has insufficient credits
"""
action = get_credit_action(action_type)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if user.credits < action.cost:
raise InsufficientCreditsError(action.cost, user.credits)
logger.info(
"Credits validated for user %s: %s credits available, %s required",
user_id,
user.credits,
action.cost,
)
return user, action
finally:
await session.close()
async def deduct_credits(
self,
user_id: int,
action_type: CreditActionType,
success: bool = True,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Deduct credits from user account and record transaction.
Args:
user_id: The user ID
action_type: The type of action
success: Whether the action was successful
metadata: Optional metadata to store with transaction
Returns:
The created credit transaction
Raises:
InsufficientCreditsError: If user has insufficient credits
ValueError: If user not found
"""
action = get_credit_action(action_type)
# Only deduct if action requires success and was successful, or doesn't require success
should_deduct = (action.requires_success and success) or not action.requires_success
if not should_deduct:
logger.info(
"Skipping credit deduction for user %s: action %s failed and requires success",
user_id,
action_type.value,
)
# Still create a transaction record for auditing
return await self._create_transaction_record(
user_id, action, 0, success, metadata
)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if user.credits < action.cost:
raise InsufficientCreditsError(action.cost, user.credits)
# Record transaction
balance_before = user.credits
balance_after = user.credits - action.cost
transaction = CreditTransaction(
user_id=user_id,
action_type=action_type.value,
amount=-action.cost,
balance_before=balance_before,
balance_after=balance_after,
description=action.description,
success=success,
metadata_json=json.dumps(metadata) if metadata else None,
)
# Update user credits
await user_repo.update(user, {"credits": balance_after})
# Save transaction
session.add(transaction)
await session.commit()
logger.info(
"Credits deducted for user %s: %s credits (action: %s, success: %s)",
user_id,
action.cost,
action_type.value,
success,
)
# Emit user_credits_changed event via WebSocket
try:
event_data = {
"user_id": str(user_id),
"credits_before": balance_before,
"credits_after": balance_after,
"credits_deducted": action.cost,
"action_type": action_type.value,
"success": success,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
)
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def add_credits(
self,
user_id: int,
amount: int,
description: str,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Add credits to user account.
Args:
user_id: The user ID
amount: Number of credits to add
description: Description of the credit addition
metadata: Optional metadata to store with transaction
Returns:
The created credit transaction
Raises:
ValueError: If user not found or amount is negative
"""
if amount <= 0:
msg = "Amount must be positive"
raise ValueError(msg)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
# Record transaction
balance_before = user.credits
balance_after = user.credits + amount
transaction = CreditTransaction(
user_id=user_id,
action_type="credit_addition",
amount=amount,
balance_before=balance_before,
balance_after=balance_after,
description=description,
success=True,
metadata_json=json.dumps(metadata) if metadata else None,
)
# Update user credits
await user_repo.update(user, {"credits": balance_after})
# Save transaction
session.add(transaction)
await session.commit()
logger.info(
"Credits added for user %s: %s credits (description: %s)",
user_id,
amount,
description,
)
# Emit user_credits_changed event via WebSocket
try:
event_data = {
"user_id": str(user_id),
"credits_before": balance_before,
"credits_after": balance_after,
"credits_added": amount,
"description": description,
"success": True,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
)
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def _create_transaction_record(
self,
user_id: int,
action: CreditAction,
amount: int,
success: bool,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Create a transaction record without modifying credits.
Args:
user_id: The user ID
action: The credit action
amount: Amount to record (typically 0 for failed actions)
success: Whether the action was successful
metadata: Optional metadata
Returns:
The created transaction record
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
transaction = CreditTransaction(
user_id=user_id,
action_type=action.action_type.value,
amount=amount,
balance_before=user.credits,
balance_after=user.credits,
description=f"{action.description} (failed)" if not success else action.description,
success=success,
metadata_json=json.dumps(metadata) if metadata else None,
)
session.add(transaction)
await session.commit()
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def get_user_balance(self, user_id: int) -> int:
"""Get current credit balance for a user.
Args:
user_id: The user ID
Returns:
Current credit balance
Raises:
ValueError: If user not found
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
return user.credits
finally:
await session.close()

View File

@@ -0,0 +1,192 @@
"""Decorators for credit management and validation."""
import functools
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
def requires_credits(
action_type: CreditActionType,
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
"""Decorator to enforce credit requirements for actions.
Args:
action_type: The type of action that requires credits
credit_service_factory: Factory to create credit service instance
user_id_param: Name of the parameter containing user ID
metadata_extractor: Optional function to extract metadata from function args
Returns:
Decorated function that validates and deducts credits
Example:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
lambda: get_credit_service(),
user_id_param="user_id"
)
async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
# Implementation here
return True
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
param_index = param_names.index(user_id_param)
if param_index < len(args):
user_id = args[param_index]
if user_id is None:
msg = f"Could not extract user_id from parameter '{user_id_param}'"
raise ValueError(msg)
# Extract metadata if extractor provided
metadata = None
if metadata_extractor:
metadata = metadata_extractor(*args, **kwargs)
# Get credit service
credit_service = credit_service_factory()
# Validate credits before execution
await credit_service.validate_and_reserve_credits(
user_id, action_type, metadata
)
# Execute the function
success = False
result = None
try:
result = await func(*args, **kwargs)
success = bool(result) # Consider function result as success indicator
return result
except Exception:
success = False
raise
finally:
# Deduct credits based on success
await credit_service.deduct_credits(
user_id, action_type, success, metadata
)
return wrapper # type: ignore[return-value]
return decorator
def validate_credits_only(
action_type: CreditActionType,
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
) -> Callable[[F], F]:
"""Decorator to only validate credits without deducting them.
Useful for checking if a user can perform an action before actual execution.
Args:
action_type: The type of action that requires credits
credit_service_factory: Factory to create credit service instance
user_id_param: Name of the parameter containing user ID
Returns:
Decorated function that validates credits only
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
param_index = param_names.index(user_id_param)
if param_index < len(args):
user_id = args[param_index]
if user_id is None:
msg = f"Could not extract user_id from parameter '{user_id_param}'"
raise ValueError(msg)
# Get credit service
credit_service = credit_service_factory()
# Validate credits only
await credit_service.validate_and_reserve_credits(user_id, action_type)
# Execute the function
return await func(*args, **kwargs)
return wrapper # type: ignore[return-value]
return decorator
class CreditManager:
"""Context manager for credit operations."""
def __init__(
self,
credit_service: CreditService,
user_id: int,
action_type: CreditActionType,
metadata: dict[str, Any] | None = None,
) -> None:
"""Initialize credit manager.
Args:
credit_service: Credit service instance
user_id: User ID
action_type: Action type
metadata: Optional metadata
"""
self.credit_service = credit_service
self.user_id = user_id
self.action_type = action_type
self.metadata = metadata
self.validated = False
self.success = False
async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.metadata
)
self.validated = True
return self
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
"""Exit context manager - deduct credits based on success."""
if self.validated:
# If no exception occurred, consider it successful
success = exc_type is None and self.success
await self.credit_service.deduct_credits(
self.user_id, self.action_type, success, self.metadata
)
def mark_success(self) -> None:
"""Mark the operation as successful."""
self.success = True