diff --git a/app/api/v1/sounds.py b/app/api/v1/sounds.py index 4be2619..45abb72 100644 --- a/app/api/v1/sounds.py +++ b/app/api/v1/sounds.py @@ -7,9 +7,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession from app.core.database import get_db from app.core.dependencies import get_current_active_user_flexible +from app.models.credit_action import CreditActionType from app.models.user import User from app.repositories.sound import SoundRepository from app.services.extraction import ExtractionInfo, ExtractionService +from app.services.credit import CreditService, InsufficientCreditsError from app.services.extraction_processor import extraction_processor from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService from app.services.sound_scanner import ScanResults, SoundScannerService @@ -45,6 +47,12 @@ def get_vlc_player() -> VLCPlayerService: return get_vlc_player_service(get_session_factory()) +def get_credit_service() -> CreditService: + """Get the credit service.""" + from app.core.database import get_session_factory + return CreditService(get_session_factory()) + + async def get_sound_repository( session: Annotated[AsyncSession, Depends(get_db)], ) -> SoundRepository: @@ -373,8 +381,9 @@ async def play_sound_with_vlc( current_user: Annotated[User, Depends(get_current_active_user_flexible)], vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)], sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)], + credit_service: Annotated[CreditService, Depends(get_credit_service)], ) -> dict[str, str | int | bool]: - """Play a sound using VLC subprocess.""" + """Play a sound using VLC subprocess (requires 1 credit).""" try: # Get the sound sound = await sound_repo.get_by_id(sound_id) @@ -384,9 +393,30 @@ async def play_sound_with_vlc( detail=f"Sound with ID {sound_id} not found", ) + # Check and validate credits before playing + try: + await credit_service.validate_and_reserve_credits( + current_user.id, + CreditActionType.VLC_PLAY_SOUND, + {"sound_id": sound_id, "sound_name": sound.name} + ) + except InsufficientCreditsError as e: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Insufficient credits: {e.required} required, {e.available} available", + ) from e + # Play the sound using VLC success = await vlc_player.play_sound(sound) + # Deduct credits based on success + await credit_service.deduct_credits( + current_user.id, + CreditActionType.VLC_PLAY_SOUND, + success, + {"sound_id": sound_id, "sound_name": sound.name}, + ) + if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -398,6 +428,7 @@ async def play_sound_with_vlc( "sound_id": sound_id, "sound_name": sound.name, "success": True, + "credits_deducted": 1, } except HTTPException: diff --git a/app/models/credit_action.py b/app/models/credit_action.py new file mode 100644 index 0000000..b65bb41 --- /dev/null +++ b/app/models/credit_action.py @@ -0,0 +1,121 @@ +"""Credit action definitions for the credit system.""" + +from enum import Enum +from typing import Any + + +class CreditActionType(str, Enum): + """Types of actions that consume credits.""" + + VLC_PLAY_SOUND = "vlc_play_sound" + AUDIO_EXTRACTION = "audio_extraction" + TEXT_TO_SPEECH = "text_to_speech" + SOUND_NORMALIZATION = "sound_normalization" + API_REQUEST = "api_request" + PLAYLIST_CREATION = "playlist_creation" + + +class CreditAction: + """Definition of a credit-consuming action.""" + + def __init__( + self, + action_type: CreditActionType, + cost: int, + description: str, + *, + requires_success: bool = True, + ) -> None: + """Initialize a credit action. + + Args: + action_type: The type of action + cost: Number of credits required + description: Human-readable description + requires_success: Whether credits are only deducted on successful completion + + """ + self.action_type = action_type + self.cost = cost + self.description = description + self.requires_success = requires_success + + def __str__(self) -> str: + """Return string representation of the action.""" + return f"{self.action_type.value} ({self.cost} credits)" + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "action_type": self.action_type.value, + "cost": self.cost, + "description": self.description, + "requires_success": self.requires_success, + } + + +# Predefined credit actions +CREDIT_ACTIONS = { + CreditActionType.VLC_PLAY_SOUND: CreditAction( + action_type=CreditActionType.VLC_PLAY_SOUND, + cost=1, + description="Play a sound using VLC player", + requires_success=True, + ), + CreditActionType.AUDIO_EXTRACTION: CreditAction( + action_type=CreditActionType.AUDIO_EXTRACTION, + cost=5, + description="Extract audio from external URL", + requires_success=True, + ), + CreditActionType.TEXT_TO_SPEECH: CreditAction( + action_type=CreditActionType.TEXT_TO_SPEECH, + cost=2, + description="Generate speech from text", + requires_success=True, + ), + CreditActionType.SOUND_NORMALIZATION: CreditAction( + action_type=CreditActionType.SOUND_NORMALIZATION, + cost=1, + description="Normalize audio levels", + requires_success=True, + ), + CreditActionType.API_REQUEST: CreditAction( + action_type=CreditActionType.API_REQUEST, + cost=1, + description="API request (rate limiting)", + requires_success=False, # Charged even if request fails + ), + CreditActionType.PLAYLIST_CREATION: CreditAction( + action_type=CreditActionType.PLAYLIST_CREATION, + cost=3, + description="Create a new playlist", + requires_success=True, + ), +} + + +def get_credit_action(action_type: CreditActionType) -> CreditAction: + """Get a credit action definition by type. + + Args: + action_type: The action type to look up + + Returns: + The credit action definition + + Raises: + KeyError: If action type is not found + + """ + return CREDIT_ACTIONS[action_type] + + +def get_all_credit_actions() -> dict[CreditActionType, CreditAction]: + """Get all available credit actions. + + Returns: + Dictionary of all credit actions + + """ + return CREDIT_ACTIONS.copy() \ No newline at end of file diff --git a/app/models/credit_transaction.py b/app/models/credit_transaction.py new file mode 100644 index 0000000..4102872 --- /dev/null +++ b/app/models/credit_transaction.py @@ -0,0 +1,29 @@ +"""Credit transaction model for tracking credit usage.""" + +from typing import TYPE_CHECKING + +from sqlmodel import Field, Relationship + +from app.models.base import BaseModel + +if TYPE_CHECKING: + from app.models.user import User + + +class CreditTransaction(BaseModel, table=True): + """Database model for credit transactions.""" + + __tablename__ = "credit_transaction" # pyright: ignore[reportAssignmentType] + + user_id: int = Field(foreign_key="user.id", nullable=False) + action_type: str = Field(nullable=False) + amount: int = Field(nullable=False) # Negative for deductions, positive for additions + balance_before: int = Field(nullable=False) + balance_after: int = Field(nullable=False) + description: str = Field(nullable=False) + success: bool = Field(nullable=False, default=True) + # JSON string for additional data + metadata_json: str | None = Field(default=None) + + # relationships + user: "User" = Relationship(back_populates="credit_transactions") \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py index a1c3cbb..db1d6f7 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship from app.models.base import BaseModel if TYPE_CHECKING: + from app.models.credit_transaction import CreditTransaction from app.models.extraction import Extraction from app.models.plan import Plan from app.models.playlist import Playlist @@ -35,3 +36,4 @@ class User(BaseModel, table=True): playlists: list["Playlist"] = Relationship(back_populates="user") sounds_played: list["SoundPlayed"] = Relationship(back_populates="user") extractions: list["Extraction"] = Relationship(back_populates="user") + credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user") diff --git a/app/repositories/base.py b/app/repositories/base.py new file mode 100644 index 0000000..0040b1b --- /dev/null +++ b/app/repositories/base.py @@ -0,0 +1,132 @@ +"""Base repository with common CRUD operations.""" + +from typing import Any, Generic, TypeVar + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.logging import get_logger + +# Type variable for the model +ModelType = TypeVar("ModelType") + +logger = get_logger(__name__) + + +class BaseRepository(Generic[ModelType]): + """Base repository with common CRUD operations.""" + + def __init__(self, model: type[ModelType], session: AsyncSession) -> None: + """Initialize the repository. + + Args: + model: The SQLModel class + session: Database session + + """ + self.model = model + self.session = session + + async def get_by_id(self, entity_id: int) -> ModelType | None: + """Get an entity by ID. + + Args: + entity_id: The entity ID + + Returns: + The entity if found, None otherwise + + """ + try: + statement = select(self.model).where(getattr(self.model, "id") == entity_id) + result = await self.session.exec(statement) + return result.first() + except Exception: + logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id) + raise + + async def get_all( + self, + limit: int = 100, + offset: int = 0, + ) -> list[ModelType]: + """Get all entities with pagination. + + Args: + limit: Maximum number of entities to return + offset: Number of entities to skip + + Returns: + List of entities + + """ + try: + statement = select(self.model).limit(limit).offset(offset) + result = await self.session.exec(statement) + return list(result.all()) + except Exception: + logger.exception("Failed to get all %s", self.model.__name__) + raise + + async def create(self, entity_data: dict[str, Any]) -> ModelType: + """Create a new entity. + + Args: + entity_data: Dictionary of entity data + + Returns: + The created entity + + """ + try: + entity = self.model(**entity_data) + self.session.add(entity) + await self.session.commit() + await self.session.refresh(entity) + logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown")) + return entity + except Exception: + await self.session.rollback() + logger.exception("Failed to create %s", self.model.__name__) + raise + + async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType: + """Update an entity. + + Args: + entity: The entity to update + update_data: Dictionary of fields to update + + Returns: + The updated entity + + """ + try: + for field, value in update_data.items(): + setattr(entity, field, value) + + self.session.add(entity) + await self.session.commit() + await self.session.refresh(entity) + logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown")) + return entity + except Exception: + await self.session.rollback() + logger.exception("Failed to update %s", self.model.__name__) + raise + + async def delete(self, entity: ModelType) -> None: + """Delete an entity. + + Args: + entity: The entity to delete + + """ + try: + await self.session.delete(entity) + await self.session.commit() + logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown")) + except Exception: + await self.session.rollback() + logger.exception("Failed to delete %s", self.model.__name__) + raise \ No newline at end of file diff --git a/app/repositories/credit_transaction.py b/app/repositories/credit_transaction.py new file mode 100644 index 0000000..ecad094 --- /dev/null +++ b/app/repositories/credit_transaction.py @@ -0,0 +1,108 @@ +"""Repository for credit transaction database operations.""" + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.credit_transaction import CreditTransaction +from app.repositories.base import BaseRepository + + +class CreditTransactionRepository(BaseRepository[CreditTransaction]): + """Repository for credit transaction operations.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize the repository. + + Args: + session: Database session + + """ + super().__init__(CreditTransaction, session) + + async def get_by_user_id( + self, + user_id: int, + limit: int = 50, + offset: int = 0, + ) -> list[CreditTransaction]: + """Get credit transactions for a user. + + Args: + user_id: The user ID + limit: Maximum number of transactions to return + offset: Number of transactions to skip + + Returns: + List of credit transactions ordered by creation date (newest first) + + """ + stmt = ( + select(CreditTransaction) + .where(CreditTransaction.user_id == user_id) + .order_by(CreditTransaction.created_at.desc()) + .limit(limit) + .offset(offset) + ) + result = await self.session.exec(stmt) + return list(result.all()) + + async def get_by_action_type( + self, + action_type: str, + limit: int = 50, + offset: int = 0, + ) -> list[CreditTransaction]: + """Get credit transactions by action type. + + Args: + action_type: The action type to filter by + limit: Maximum number of transactions to return + offset: Number of transactions to skip + + Returns: + List of credit transactions ordered by creation date (newest first) + + """ + stmt = ( + select(CreditTransaction) + .where(CreditTransaction.action_type == action_type) + .order_by(CreditTransaction.created_at.desc()) + .limit(limit) + .offset(offset) + ) + result = await self.session.exec(stmt) + return list(result.all()) + + async def get_successful_transactions( + self, + user_id: int | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[CreditTransaction]: + """Get successful credit transactions. + + Args: + user_id: Optional user ID to filter by + limit: Maximum number of transactions to return + offset: Number of transactions to skip + + Returns: + List of successful credit transactions + + """ + stmt = ( + select(CreditTransaction) + .where(CreditTransaction.success == True) # noqa: E712 + ) + + if user_id is not None: + stmt = stmt.where(CreditTransaction.user_id == user_id) + + stmt = ( + stmt.order_by(CreditTransaction.created_at.desc()) + .limit(limit) + .offset(offset) + ) + + result = await self.session.exec(stmt) + return list(result.all()) \ No newline at end of file diff --git a/app/services/credit.py b/app/services/credit.py new file mode 100644 index 0000000..792289f --- /dev/null +++ b/app/services/credit.py @@ -0,0 +1,383 @@ +"""Credit management service for tracking and validating user credit usage.""" + +import json +from collections.abc import Callable +from typing import Any + +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.logging import get_logger +from app.models.credit_action import CreditAction, CreditActionType, get_credit_action +from app.models.credit_transaction import CreditTransaction +from app.models.user import User +from app.repositories.user import UserRepository +from app.services.socket import socket_manager + +logger = get_logger(__name__) + + +class InsufficientCreditsError(Exception): + """Raised when user has insufficient credits for an action.""" + + def __init__(self, required: int, available: int) -> None: + """Initialize the error. + + Args: + required: Number of credits required + available: Number of credits available + + """ + self.required = required + self.available = available + super().__init__( + f"Insufficient credits: {required} required, {available} available" + ) + + +class CreditService: + """Service for managing user credits and transactions.""" + + def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None: + """Initialize the credit service. + + Args: + db_session_factory: Factory function to create database sessions + + """ + self.db_session_factory = db_session_factory + + async def check_credits( + self, + user_id: int, + action_type: CreditActionType, + ) -> bool: + """Check if user has sufficient credits for an action. + + Args: + user_id: The user ID + action_type: The type of action to check + + Returns: + True if user has sufficient credits, False otherwise + + """ + action = get_credit_action(action_type) + session = self.db_session_factory() + try: + user_repo = UserRepository(session) + user = await user_repo.get_by_id(user_id) + if not user: + return False + return user.credits >= action.cost + finally: + await session.close() + + async def validate_and_reserve_credits( + self, + user_id: int, + action_type: CreditActionType, + metadata: dict[str, Any] | None = None, + ) -> tuple[User, CreditAction]: + """Validate user has sufficient credits and optionally reserve them. + + Args: + user_id: The user ID + action_type: The type of action + metadata: Optional metadata to store with transaction + + Returns: + Tuple of (user, credit_action) + + Raises: + InsufficientCreditsError: If user has insufficient credits + + """ + action = get_credit_action(action_type) + session = self.db_session_factory() + try: + user_repo = UserRepository(session) + user = await user_repo.get_by_id(user_id) + if not user: + msg = f"User {user_id} not found" + raise ValueError(msg) + + if user.credits < action.cost: + raise InsufficientCreditsError(action.cost, user.credits) + + logger.info( + "Credits validated for user %s: %s credits available, %s required", + user_id, + user.credits, + action.cost, + ) + return user, action + finally: + await session.close() + + async def deduct_credits( + self, + user_id: int, + action_type: CreditActionType, + success: bool = True, + metadata: dict[str, Any] | None = None, + ) -> CreditTransaction: + """Deduct credits from user account and record transaction. + + Args: + user_id: The user ID + action_type: The type of action + success: Whether the action was successful + metadata: Optional metadata to store with transaction + + Returns: + The created credit transaction + + Raises: + InsufficientCreditsError: If user has insufficient credits + ValueError: If user not found + + """ + action = get_credit_action(action_type) + + # Only deduct if action requires success and was successful, or doesn't require success + should_deduct = (action.requires_success and success) or not action.requires_success + + if not should_deduct: + logger.info( + "Skipping credit deduction for user %s: action %s failed and requires success", + user_id, + action_type.value, + ) + # Still create a transaction record for auditing + return await self._create_transaction_record( + user_id, action, 0, success, metadata + ) + + session = self.db_session_factory() + try: + user_repo = UserRepository(session) + user = await user_repo.get_by_id(user_id) + if not user: + msg = f"User {user_id} not found" + raise ValueError(msg) + + if user.credits < action.cost: + raise InsufficientCreditsError(action.cost, user.credits) + + # Record transaction + balance_before = user.credits + balance_after = user.credits - action.cost + + transaction = CreditTransaction( + user_id=user_id, + action_type=action_type.value, + amount=-action.cost, + balance_before=balance_before, + balance_after=balance_after, + description=action.description, + success=success, + metadata_json=json.dumps(metadata) if metadata else None, + ) + + # Update user credits + await user_repo.update(user, {"credits": balance_after}) + + # Save transaction + session.add(transaction) + await session.commit() + + logger.info( + "Credits deducted for user %s: %s credits (action: %s, success: %s)", + user_id, + action.cost, + action_type.value, + success, + ) + + # Emit user_credits_changed event via WebSocket + try: + event_data = { + "user_id": str(user_id), + "credits_before": balance_before, + "credits_after": balance_after, + "credits_deducted": action.cost, + "action_type": action_type.value, + "success": success, + } + await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data) + logger.info("Emitted user_credits_changed event for user %s", user_id) + except Exception: + logger.exception( + "Failed to emit user_credits_changed event for user %s", user_id, + ) + + return transaction + + except Exception: + await session.rollback() + raise + finally: + await session.close() + + async def add_credits( + self, + user_id: int, + amount: int, + description: str, + metadata: dict[str, Any] | None = None, + ) -> CreditTransaction: + """Add credits to user account. + + Args: + user_id: The user ID + amount: Number of credits to add + description: Description of the credit addition + metadata: Optional metadata to store with transaction + + Returns: + The created credit transaction + + Raises: + ValueError: If user not found or amount is negative + + """ + if amount <= 0: + msg = "Amount must be positive" + raise ValueError(msg) + + session = self.db_session_factory() + try: + user_repo = UserRepository(session) + user = await user_repo.get_by_id(user_id) + if not user: + msg = f"User {user_id} not found" + raise ValueError(msg) + + # Record transaction + balance_before = user.credits + balance_after = user.credits + amount + + transaction = CreditTransaction( + user_id=user_id, + action_type="credit_addition", + amount=amount, + balance_before=balance_before, + balance_after=balance_after, + description=description, + success=True, + metadata_json=json.dumps(metadata) if metadata else None, + ) + + # Update user credits + await user_repo.update(user, {"credits": balance_after}) + + # Save transaction + session.add(transaction) + await session.commit() + + logger.info( + "Credits added for user %s: %s credits (description: %s)", + user_id, + amount, + description, + ) + + # Emit user_credits_changed event via WebSocket + try: + event_data = { + "user_id": str(user_id), + "credits_before": balance_before, + "credits_after": balance_after, + "credits_added": amount, + "description": description, + "success": True, + } + await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data) + logger.info("Emitted user_credits_changed event for user %s", user_id) + except Exception: + logger.exception( + "Failed to emit user_credits_changed event for user %s", user_id, + ) + + return transaction + + except Exception: + await session.rollback() + raise + finally: + await session.close() + + async def _create_transaction_record( + self, + user_id: int, + action: CreditAction, + amount: int, + success: bool, + metadata: dict[str, Any] | None = None, + ) -> CreditTransaction: + """Create a transaction record without modifying credits. + + Args: + user_id: The user ID + action: The credit action + amount: Amount to record (typically 0 for failed actions) + success: Whether the action was successful + metadata: Optional metadata + + Returns: + The created transaction record + + """ + session = self.db_session_factory() + try: + user_repo = UserRepository(session) + user = await user_repo.get_by_id(user_id) + if not user: + msg = f"User {user_id} not found" + raise ValueError(msg) + + transaction = CreditTransaction( + user_id=user_id, + action_type=action.action_type.value, + amount=amount, + balance_before=user.credits, + balance_after=user.credits, + description=f"{action.description} (failed)" if not success else action.description, + success=success, + metadata_json=json.dumps(metadata) if metadata else None, + ) + + session.add(transaction) + await session.commit() + + return transaction + + except Exception: + await session.rollback() + raise + finally: + await session.close() + + async def get_user_balance(self, user_id: int) -> int: + """Get current credit balance for a user. + + Args: + user_id: The user ID + + Returns: + Current credit balance + + Raises: + ValueError: If user not found + + """ + session = self.db_session_factory() + try: + user_repo = UserRepository(session) + user = await user_repo.get_by_id(user_id) + if not user: + msg = f"User {user_id} not found" + raise ValueError(msg) + return user.credits + finally: + await session.close() \ No newline at end of file diff --git a/app/utils/credit_decorators.py b/app/utils/credit_decorators.py new file mode 100644 index 0000000..5af4e85 --- /dev/null +++ b/app/utils/credit_decorators.py @@ -0,0 +1,192 @@ +"""Decorators for credit management and validation.""" + +import functools +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar + +from app.models.credit_action import CreditActionType +from app.services.credit import CreditService, InsufficientCreditsError + +F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) + + +def requires_credits( + action_type: CreditActionType, + credit_service_factory: Callable[[], CreditService], + user_id_param: str = "user_id", + metadata_extractor: Callable[..., dict[str, Any]] | None = None, +) -> Callable[[F], F]: + """Decorator to enforce credit requirements for actions. + + Args: + action_type: The type of action that requires credits + credit_service_factory: Factory to create credit service instance + user_id_param: Name of the parameter containing user ID + metadata_extractor: Optional function to extract metadata from function args + + Returns: + Decorated function that validates and deducts credits + + Example: + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + lambda: get_credit_service(), + user_id_param="user_id" + ) + async def play_sound_for_user(user_id: int, sound: Sound) -> bool: + # Implementation here + return True + + """ + def decorator(func: F) -> F: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + # Extract user ID from parameters + user_id = None + if user_id_param in kwargs: + user_id = kwargs[user_id_param] + else: + # Try to find user_id in function signature + import inspect + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + if user_id_param in param_names: + param_index = param_names.index(user_id_param) + if param_index < len(args): + user_id = args[param_index] + + if user_id is None: + msg = f"Could not extract user_id from parameter '{user_id_param}'" + raise ValueError(msg) + + # Extract metadata if extractor provided + metadata = None + if metadata_extractor: + metadata = metadata_extractor(*args, **kwargs) + + # Get credit service + credit_service = credit_service_factory() + + # Validate credits before execution + await credit_service.validate_and_reserve_credits( + user_id, action_type, metadata + ) + + # Execute the function + success = False + result = None + try: + result = await func(*args, **kwargs) + success = bool(result) # Consider function result as success indicator + return result + except Exception: + success = False + raise + finally: + # Deduct credits based on success + await credit_service.deduct_credits( + user_id, action_type, success, metadata + ) + + return wrapper # type: ignore[return-value] + return decorator + + +def validate_credits_only( + action_type: CreditActionType, + credit_service_factory: Callable[[], CreditService], + user_id_param: str = "user_id", +) -> Callable[[F], F]: + """Decorator to only validate credits without deducting them. + + Useful for checking if a user can perform an action before actual execution. + + Args: + action_type: The type of action that requires credits + credit_service_factory: Factory to create credit service instance + user_id_param: Name of the parameter containing user ID + + Returns: + Decorated function that validates credits only + + """ + def decorator(func: F) -> F: + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + # Extract user ID from parameters + user_id = None + if user_id_param in kwargs: + user_id = kwargs[user_id_param] + else: + # Try to find user_id in function signature + import inspect + sig = inspect.signature(func) + param_names = list(sig.parameters.keys()) + if user_id_param in param_names: + param_index = param_names.index(user_id_param) + if param_index < len(args): + user_id = args[param_index] + + if user_id is None: + msg = f"Could not extract user_id from parameter '{user_id_param}'" + raise ValueError(msg) + + # Get credit service + credit_service = credit_service_factory() + + # Validate credits only + await credit_service.validate_and_reserve_credits(user_id, action_type) + + # Execute the function + return await func(*args, **kwargs) + + return wrapper # type: ignore[return-value] + return decorator + + +class CreditManager: + """Context manager for credit operations.""" + + def __init__( + self, + credit_service: CreditService, + user_id: int, + action_type: CreditActionType, + metadata: dict[str, Any] | None = None, + ) -> None: + """Initialize credit manager. + + Args: + credit_service: Credit service instance + user_id: User ID + action_type: Action type + metadata: Optional metadata + + """ + self.credit_service = credit_service + self.user_id = user_id + self.action_type = action_type + self.metadata = metadata + self.validated = False + self.success = False + + async def __aenter__(self) -> "CreditManager": + """Enter context manager - validate credits.""" + await self.credit_service.validate_and_reserve_credits( + self.user_id, self.action_type, self.metadata + ) + self.validated = True + return self + + async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None: + """Exit context manager - deduct credits based on success.""" + if self.validated: + # If no exception occurred, consider it successful + success = exc_type is None and self.success + await self.credit_service.deduct_credits( + self.user_id, self.action_type, success, self.metadata + ) + + def mark_success(self) -> None: + """Mark the operation as successful.""" + self.success = True \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 7842828..9b744fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,8 +13,10 @@ from sqlmodel import SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.database import get_db +from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy from app.models.plan import Plan from app.models.user import User +from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy from app.utils.auth import JWTUtils, PasswordUtils diff --git a/tests/repositories/test_credit_transaction.py b/tests/repositories/test_credit_transaction.py new file mode 100644 index 0000000..3dde7ba --- /dev/null +++ b/tests/repositories/test_credit_transaction.py @@ -0,0 +1,412 @@ +"""Tests for credit transaction repository.""" + +import json +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.credit_transaction import CreditTransaction +from app.models.user import User +from app.repositories.credit_transaction import CreditTransactionRepository + + +class TestCreditTransactionRepository: + """Test credit transaction repository operations.""" + + @pytest_asyncio.fixture + async def credit_transaction_repository( + self, + test_session: AsyncSession, + ) -> AsyncGenerator[CreditTransactionRepository, None]: # type: ignore[misc] + """Create a credit transaction repository instance.""" + yield CreditTransactionRepository(test_session) + + @pytest_asyncio.fixture + async def test_user_id( + self, + test_user: User, + ) -> int: + """Get test user ID to avoid lazy loading issues.""" + return test_user.id + + @pytest_asyncio.fixture + async def test_transactions( + self, + test_session: AsyncSession, + test_user_id: int, + ) -> AsyncGenerator[list[CreditTransaction], None]: # type: ignore[misc] + """Create test credit transactions.""" + transactions = [] + user_id = test_user_id + + # Create various types of transactions + transaction_data = [ + { + "user_id": user_id, + "action_type": "vlc_play_sound", + "amount": -1, + "balance_before": 10, + "balance_after": 9, + "description": "Play sound via VLC", + "success": True, + "metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}), + }, + { + "user_id": user_id, + "action_type": "audio_extraction", + "amount": -5, + "balance_before": 9, + "balance_after": 4, + "description": "Extract audio from URL", + "success": True, + "metadata_json": json.dumps({"url": "https://example.com/video"}), + }, + { + "user_id": user_id, + "action_type": "vlc_play_sound", + "amount": 0, + "balance_before": 4, + "balance_after": 4, + "description": "Play sound via VLC (failed)", + "success": False, + "metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}), + }, + { + "user_id": user_id, + "action_type": "credit_addition", + "amount": 50, + "balance_before": 4, + "balance_after": 54, + "description": "Bonus credits", + "success": True, + "metadata_json": json.dumps({"reason": "signup_bonus"}), + }, + ] + + for data in transaction_data: + transaction = CreditTransaction(**data) + test_session.add(transaction) + transactions.append(transaction) + + await test_session.commit() + for transaction in transactions: + await test_session.refresh(transaction) + + yield transactions + + @pytest_asyncio.fixture + async def other_user_transaction( + self, + test_session: AsyncSession, + ensure_plans: tuple, # noqa: ARG002 + ) -> AsyncGenerator[CreditTransaction, None]: # type: ignore[misc] + """Create a transaction for a different user.""" + from app.models.plan import Plan + from app.repositories.user import UserRepository + + # Create another user + user_repo = UserRepository(test_session) + other_user_data = { + "email": "other@example.com", + "name": "Other User", + "password_hash": "hashed_password", + "role": "user", + "is_active": True, + } + other_user = await user_repo.create(other_user_data) + + # Create transaction for the other user + transaction_data = { + "user_id": other_user.id, + "action_type": "vlc_play_sound", + "amount": -1, + "balance_before": 100, + "balance_after": 99, + "description": "Other user play sound", + "success": True, + "metadata_json": None, + } + transaction = CreditTransaction(**transaction_data) + test_session.add(transaction) + await test_session.commit() + await test_session.refresh(transaction) + + yield transaction + + @pytest.mark.asyncio + async def test_get_by_id_existing( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + ) -> None: + """Test getting transaction by ID when it exists.""" + transaction = await credit_transaction_repository.get_by_id(test_transactions[0].id) + + assert transaction is not None + assert transaction.id == test_transactions[0].id + assert transaction.action_type == "vlc_play_sound" + assert transaction.amount == -1 + + @pytest.mark.asyncio + async def test_get_by_id_nonexistent( + self, + credit_transaction_repository: CreditTransactionRepository, + ) -> None: + """Test getting transaction by ID when it doesn't exist.""" + transaction = await credit_transaction_repository.get_by_id(99999) + + assert transaction is None + + @pytest.mark.asyncio + async def test_get_by_user_id( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + other_user_transaction: CreditTransaction, + test_user_id: int, + ) -> None: + """Test getting transactions by user ID.""" + transactions = await credit_transaction_repository.get_by_user_id(test_user_id) + + # Should return all transactions for test_user + assert len(transactions) == 4 + # Should be ordered by created_at desc (newest first) + assert all(t.user_id == test_user_id for t in transactions) + + # Should not include other user's transaction + other_user_ids = [t.user_id for t in transactions] + assert other_user_transaction.user_id not in other_user_ids + + @pytest.mark.asyncio + async def test_get_by_user_id_with_pagination( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + test_user_id: int, + ) -> None: + """Test getting transactions by user ID with pagination.""" + # Get first 2 transactions + first_page = await credit_transaction_repository.get_by_user_id( + test_user_id, limit=2, offset=0 + ) + assert len(first_page) == 2 + + # Get next 2 transactions + second_page = await credit_transaction_repository.get_by_user_id( + test_user_id, limit=2, offset=2 + ) + assert len(second_page) == 2 + + # Should not overlap + first_page_ids = {t.id for t in first_page} + second_page_ids = {t.id for t in second_page} + assert first_page_ids.isdisjoint(second_page_ids) + + @pytest.mark.asyncio + async def test_get_by_action_type( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + ) -> None: + """Test getting transactions by action type.""" + vlc_transactions = await credit_transaction_repository.get_by_action_type( + "vlc_play_sound" + ) + + # Should return 2 VLC transactions (1 successful, 1 failed) + assert len(vlc_transactions) >= 2 + assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions) + + extraction_transactions = await credit_transaction_repository.get_by_action_type( + "audio_extraction" + ) + + # Should return 1 extraction transaction + assert len(extraction_transactions) >= 1 + assert all(t.action_type == "audio_extraction" for t in extraction_transactions) + + @pytest.mark.asyncio + async def test_get_by_action_type_with_pagination( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + ) -> None: + """Test getting transactions by action type with pagination.""" + # Test with limit + transactions = await credit_transaction_repository.get_by_action_type( + "vlc_play_sound", limit=1 + ) + assert len(transactions) == 1 + assert transactions[0].action_type == "vlc_play_sound" + + # Test with offset + transactions = await credit_transaction_repository.get_by_action_type( + "vlc_play_sound", limit=1, offset=1 + ) + assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total + + @pytest.mark.asyncio + async def test_get_successful_transactions( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + ) -> None: + """Test getting only successful transactions.""" + successful_transactions = await credit_transaction_repository.get_successful_transactions() + + # Should only return successful transactions + assert all(t.success is True for t in successful_transactions) + # Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition) + assert len(successful_transactions) >= 3 + + @pytest.mark.asyncio + async def test_get_successful_transactions_by_user( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + other_user_transaction: CreditTransaction, + test_user_id: int, + ) -> None: + """Test getting successful transactions filtered by user.""" + successful_transactions = await credit_transaction_repository.get_successful_transactions( + user_id=test_user_id + ) + + # Should only return successful transactions for test_user + assert all(t.success is True for t in successful_transactions) + assert all(t.user_id == test_user_id for t in successful_transactions) + # Should be 3 successful transactions for test_user + assert len(successful_transactions) == 3 + + @pytest.mark.asyncio + async def test_get_successful_transactions_with_pagination( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + test_user_id: int, + ) -> None: + """Test getting successful transactions with pagination.""" + # Get first 2 successful transactions + first_page = await credit_transaction_repository.get_successful_transactions( + user_id=test_user_id, limit=2, offset=0 + ) + assert len(first_page) == 2 + assert all(t.success is True for t in first_page) + + # Get next successful transaction + second_page = await credit_transaction_repository.get_successful_transactions( + user_id=test_user_id, limit=2, offset=2 + ) + assert len(second_page) == 1 # Should be 1 remaining + assert all(t.success is True for t in second_page) + + @pytest.mark.asyncio + async def test_get_all_transactions( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + other_user_transaction: CreditTransaction, + ) -> None: + """Test getting all transactions.""" + all_transactions = await credit_transaction_repository.get_all() + + # Should return all transactions + assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction + + @pytest.mark.asyncio + async def test_create_transaction( + self, + credit_transaction_repository: CreditTransactionRepository, + test_user_id: int, + ) -> None: + """Test creating a new transaction.""" + transaction_data = { + "user_id": test_user_id, + "action_type": "test_action", + "amount": -10, + "balance_before": 100, + "balance_after": 90, + "description": "Test transaction", + "success": True, + "metadata_json": json.dumps({"test": "data"}), + } + + transaction = await credit_transaction_repository.create(transaction_data) + + assert transaction.id is not None + assert transaction.user_id == test_user_id + assert transaction.action_type == "test_action" + assert transaction.amount == -10 + assert transaction.balance_before == 100 + assert transaction.balance_after == 90 + assert transaction.success is True + assert json.loads(transaction.metadata_json) == {"test": "data"} + + @pytest.mark.asyncio + async def test_update_transaction( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + ) -> None: + """Test updating a transaction.""" + transaction = test_transactions[0] + update_data = { + "description": "Updated description", + "metadata_json": json.dumps({"updated": True}), + } + + updated_transaction = await credit_transaction_repository.update( + transaction, update_data + ) + + assert updated_transaction.id == transaction.id + assert updated_transaction.description == "Updated description" + assert json.loads(updated_transaction.metadata_json) == {"updated": True} + # Other fields should remain unchanged + assert updated_transaction.amount == transaction.amount + assert updated_transaction.action_type == transaction.action_type + + @pytest.mark.asyncio + async def test_delete_transaction( + self, + credit_transaction_repository: CreditTransactionRepository, + test_session: AsyncSession, + test_user_id: int, + ) -> None: + """Test deleting a transaction.""" + # Create a transaction to delete + transaction_data = { + "user_id": test_user_id, + "action_type": "to_delete", + "amount": -1, + "balance_before": 10, + "balance_after": 9, + "description": "To be deleted", + "success": True, + "metadata_json": None, + } + transaction = await credit_transaction_repository.create(transaction_data) + transaction_id = transaction.id + + # Delete the transaction + await credit_transaction_repository.delete(transaction) + + # Verify transaction is deleted + deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id) + assert deleted_transaction is None + + @pytest.mark.asyncio + async def test_transaction_ordering( + self, + credit_transaction_repository: CreditTransactionRepository, + test_transactions: list[CreditTransaction], + test_user_id: int, + ) -> None: + """Test that transactions are ordered by created_at desc.""" + transactions = await credit_transaction_repository.get_by_user_id(test_user_id) + + # Should be ordered by created_at desc (newest first) + for i in range(len(transactions) - 1): + assert transactions[i].created_at >= transactions[i + 1].created_at \ No newline at end of file diff --git a/tests/repositories/test_sound.py b/tests/repositories/test_sound.py new file mode 100644 index 0000000..fe90f49 --- /dev/null +++ b/tests/repositories/test_sound.py @@ -0,0 +1,376 @@ +"""Tests for sound repository.""" + +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.sound import Sound +from app.repositories.sound import SoundRepository + + +class TestSoundRepository: + """Test sound repository operations.""" + + @pytest_asyncio.fixture + async def sound_repository( + self, + test_session: AsyncSession, + ) -> AsyncGenerator[SoundRepository, None]: # type: ignore[misc] + """Create a sound repository instance.""" + yield SoundRepository(test_session) + + @pytest_asyncio.fixture + async def test_sound( + self, + test_session: AsyncSession, + ) -> AsyncGenerator[Sound, None]: # type: ignore[misc] + """Create a test sound.""" + sound_data = { + "name": "Test Sound", + "filename": "test_sound.mp3", + "type": "SDB", + "duration": 5000, + "size": 1024000, + "hash": "test_hash_123", + "play_count": 0, + "is_normalized": False, + } + sound = Sound(**sound_data) + test_session.add(sound) + await test_session.commit() + await test_session.refresh(sound) + yield sound + + @pytest_asyncio.fixture + async def normalized_sound( + self, + test_session: AsyncSession, + ) -> AsyncGenerator[Sound, None]: # type: ignore[misc] + """Create a normalized test sound.""" + sound_data = { + "name": "Normalized Sound", + "filename": "normalized_sound.mp3", + "type": "TTS", + "duration": 3000, + "size": 512000, + "hash": "normalized_hash_456", + "play_count": 5, + "is_normalized": True, + "normalized_filename": "normalized_sound_norm.mp3", + "normalized_duration": 3000, + "normalized_size": 480000, + "normalized_hash": "normalized_hash_norm_456", + } + sound = Sound(**sound_data) + test_session.add(sound) + await test_session.commit() + await test_session.refresh(sound) + yield sound + + @pytest.mark.asyncio + async def test_get_by_id_existing( + self, + sound_repository: SoundRepository, + test_sound: Sound, + ) -> None: + """Test getting sound by ID when it exists.""" + sound = await sound_repository.get_by_id(test_sound.id) + + assert sound is not None + assert sound.id == test_sound.id + assert sound.name == test_sound.name + assert sound.filename == test_sound.filename + assert sound.type == test_sound.type + + @pytest.mark.asyncio + async def test_get_by_id_nonexistent( + self, + sound_repository: SoundRepository, + ) -> None: + """Test getting sound by ID when it doesn't exist.""" + sound = await sound_repository.get_by_id(99999) + + assert sound is None + + @pytest.mark.asyncio + async def test_get_by_filename_existing( + self, + sound_repository: SoundRepository, + test_sound: Sound, + ) -> None: + """Test getting sound by filename when it exists.""" + sound = await sound_repository.get_by_filename(test_sound.filename) + + assert sound is not None + assert sound.id == test_sound.id + assert sound.filename == test_sound.filename + + @pytest.mark.asyncio + async def test_get_by_filename_nonexistent( + self, + sound_repository: SoundRepository, + ) -> None: + """Test getting sound by filename when it doesn't exist.""" + sound = await sound_repository.get_by_filename("nonexistent.mp3") + + assert sound is None + + @pytest.mark.asyncio + async def test_get_by_hash_existing( + self, + sound_repository: SoundRepository, + test_sound: Sound, + ) -> None: + """Test getting sound by hash when it exists.""" + sound = await sound_repository.get_by_hash(test_sound.hash) + + assert sound is not None + assert sound.id == test_sound.id + assert sound.hash == test_sound.hash + + @pytest.mark.asyncio + async def test_get_by_hash_nonexistent( + self, + sound_repository: SoundRepository, + ) -> None: + """Test getting sound by hash when it doesn't exist.""" + sound = await sound_repository.get_by_hash("nonexistent_hash") + + assert sound is None + + @pytest.mark.asyncio + async def test_get_by_type( + self, + sound_repository: SoundRepository, + test_sound: Sound, + normalized_sound: Sound, + ) -> None: + """Test getting sounds by type.""" + sdb_sounds = await sound_repository.get_by_type("SDB") + tts_sounds = await sound_repository.get_by_type("TTS") + ext_sounds = await sound_repository.get_by_type("EXT") + + # Should find the SDB sound + assert len(sdb_sounds) >= 1 + assert any(sound.id == test_sound.id for sound in sdb_sounds) + + # Should find the TTS sound + assert len(tts_sounds) >= 1 + assert any(sound.id == normalized_sound.id for sound in tts_sounds) + + # Should not find any EXT sounds + assert len(ext_sounds) == 0 + + @pytest.mark.asyncio + async def test_create_sound( + self, + sound_repository: SoundRepository, + ) -> None: + """Test creating a new sound.""" + sound_data = { + "name": "New Sound", + "filename": "new_sound.wav", + "type": "EXT", + "duration": 7500, + "size": 2048000, + "hash": "new_hash_789", + "play_count": 0, + "is_normalized": False, + } + + sound = await sound_repository.create(sound_data) + + assert sound.id is not None + assert sound.name == sound_data["name"] + assert sound.filename == sound_data["filename"] + assert sound.type == sound_data["type"] + assert sound.duration == sound_data["duration"] + assert sound.size == sound_data["size"] + assert sound.hash == sound_data["hash"] + assert sound.play_count == 0 + assert sound.is_normalized is False + + @pytest.mark.asyncio + async def test_update_sound( + self, + sound_repository: SoundRepository, + test_sound: Sound, + ) -> None: + """Test updating a sound.""" + update_data = { + "name": "Updated Sound Name", + "play_count": 10, + "is_normalized": True, + "normalized_filename": "updated_norm.mp3", + } + + updated_sound = await sound_repository.update(test_sound, update_data) + + assert updated_sound.id == test_sound.id + assert updated_sound.name == "Updated Sound Name" + assert updated_sound.play_count == 10 + assert updated_sound.is_normalized is True + assert updated_sound.normalized_filename == "updated_norm.mp3" + assert updated_sound.filename == test_sound.filename # Unchanged + + @pytest.mark.asyncio + async def test_delete_sound( + self, + sound_repository: SoundRepository, + test_session: AsyncSession, + ) -> None: + """Test deleting a sound.""" + # Create a sound to delete + sound_data = { + "name": "To Delete", + "filename": "to_delete.mp3", + "type": "SDB", + "duration": 1000, + "size": 256000, + "hash": "delete_hash", + "play_count": 0, + "is_normalized": False, + } + sound = await sound_repository.create(sound_data) + sound_id = sound.id + + # Delete the sound + await sound_repository.delete(sound) + + # Verify sound is deleted + deleted_sound = await sound_repository.get_by_id(sound_id) + assert deleted_sound is None + + @pytest.mark.asyncio + async def test_search_by_name( + self, + sound_repository: SoundRepository, + test_sound: Sound, + normalized_sound: Sound, + ) -> None: + """Test searching sounds by name.""" + # Search for "test" should find test_sound + results = await sound_repository.search_by_name("test") + assert len(results) >= 1 + assert any(sound.id == test_sound.id for sound in results) + + # Search for "normalized" should find normalized_sound + results = await sound_repository.search_by_name("normalized") + assert len(results) >= 1 + assert any(sound.id == normalized_sound.id for sound in results) + + # Case insensitive search + results = await sound_repository.search_by_name("TEST") + assert len(results) >= 1 + assert any(sound.id == test_sound.id for sound in results) + + # Partial match + results = await sound_repository.search_by_name("norm") + assert len(results) >= 1 + assert any(sound.id == normalized_sound.id for sound in results) + + # No matches + results = await sound_repository.search_by_name("nonexistent") + assert len(results) == 0 + + @pytest.mark.asyncio + async def test_get_popular_sounds( + self, + sound_repository: SoundRepository, + test_sound: Sound, + normalized_sound: Sound, + ) -> None: + """Test getting popular sounds.""" + # Update play counts to test ordering + await sound_repository.update(test_sound, {"play_count": 15}) + await sound_repository.update(normalized_sound, {"play_count": 5}) + + # Create another sound with higher play count + high_play_sound_data = { + "name": "Popular Sound", + "filename": "popular.mp3", + "type": "SDB", + "duration": 2000, + "size": 300000, + "hash": "popular_hash", + "play_count": 25, + "is_normalized": False, + } + high_play_sound = await sound_repository.create(high_play_sound_data) + + # Get popular sounds + popular_sounds = await sound_repository.get_popular_sounds(limit=10) + + assert len(popular_sounds) >= 3 + # Should be ordered by play_count desc + assert popular_sounds[0].play_count >= popular_sounds[1].play_count + # The highest play count sound should be first + assert popular_sounds[0].id == high_play_sound.id + + @pytest.mark.asyncio + async def test_get_unnormalized_sounds( + self, + sound_repository: SoundRepository, + test_sound: Sound, + normalized_sound: Sound, + ) -> None: + """Test getting unnormalized sounds.""" + unnormalized_sounds = await sound_repository.get_unnormalized_sounds() + + # Should include test_sound (not normalized) + assert any(sound.id == test_sound.id for sound in unnormalized_sounds) + # Should not include normalized_sound (already normalized) + assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds) + + @pytest.mark.asyncio + async def test_get_unnormalized_sounds_by_type( + self, + sound_repository: SoundRepository, + test_sound: Sound, + normalized_sound: Sound, + ) -> None: + """Test getting unnormalized sounds by type.""" + # Get unnormalized SDB sounds + sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB") + # Should include test_sound (SDB, not normalized) + assert any(sound.id == test_sound.id for sound in sdb_unnormalized) + + # Get unnormalized TTS sounds + tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS") + # Should not include normalized_sound (TTS, but already normalized) + assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized) + + # Get unnormalized EXT sounds + ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT") + # Should be empty + assert len(ext_unnormalized) == 0 + + @pytest.mark.asyncio + async def test_create_duplicate_hash( + self, + sound_repository: SoundRepository, + test_sound: Sound, + ) -> None: + """Test creating sound with duplicate hash is allowed.""" + # Store the hash to avoid lazy loading issues + original_hash = test_sound.hash + + duplicate_sound_data = { + "name": "Duplicate Hash Sound", + "filename": "duplicate.mp3", + "type": "SDB", + "duration": 1000, + "size": 100000, + "hash": original_hash, # Same hash as test_sound + "play_count": 0, + "is_normalized": False, + } + + # Should succeed - duplicate hashes are allowed + duplicate_sound = await sound_repository.create(duplicate_sound_data) + + assert duplicate_sound.id is not None + assert duplicate_sound.name == "Duplicate Hash Sound" + assert duplicate_sound.hash == original_hash # Same hash is allowed \ No newline at end of file diff --git a/tests/repositories/test_user_oauth.py b/tests/repositories/test_user_oauth.py new file mode 100644 index 0000000..bbacb64 --- /dev/null +++ b/tests/repositories/test_user_oauth.py @@ -0,0 +1,268 @@ +"""Tests for user OAuth repository.""" + +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.user import User +from app.models.user_oauth import UserOauth +from app.repositories.user_oauth import UserOauthRepository + + +class TestUserOauthRepository: + """Test user OAuth repository operations.""" + + @pytest_asyncio.fixture + async def user_oauth_repository( + self, + test_session: AsyncSession, + ) -> AsyncGenerator[UserOauthRepository, None]: # type: ignore[misc] + """Create a user OAuth repository instance.""" + yield UserOauthRepository(test_session) + + @pytest_asyncio.fixture + async def test_user_id( + self, + test_user: User, + ) -> int: + """Get test user ID to avoid lazy loading issues.""" + return test_user.id + + @pytest_asyncio.fixture + async def test_oauth( + self, + test_session: AsyncSession, + test_user_id: int, + ) -> AsyncGenerator[UserOauth, None]: # type: ignore[misc] + """Create a test OAuth record.""" + oauth_data = { + "user_id": test_user_id, + "provider": "google", + "provider_user_id": "google_123456", + "email": "test@gmail.com", + "name": "Test User Google", + "picture": None, + } + oauth = UserOauth(**oauth_data) + test_session.add(oauth) + await test_session.commit() + await test_session.refresh(oauth) + yield oauth + + @pytest.mark.asyncio + async def test_get_by_provider_user_id_existing( + self, + user_oauth_repository: UserOauthRepository, + test_oauth: UserOauth, + ) -> None: + """Test getting OAuth by provider user ID when it exists.""" + oauth = await user_oauth_repository.get_by_provider_user_id( + "google", "google_123456" + ) + + assert oauth is not None + assert oauth.id == test_oauth.id + assert oauth.provider == "google" + assert oauth.provider_user_id == "google_123456" + assert oauth.user_id == test_oauth.user_id + + @pytest.mark.asyncio + async def test_get_by_provider_user_id_nonexistent( + self, + user_oauth_repository: UserOauthRepository, + ) -> None: + """Test getting OAuth by provider user ID when it doesn't exist.""" + oauth = await user_oauth_repository.get_by_provider_user_id( + "google", "nonexistent_id" + ) + + assert oauth is None + + @pytest.mark.asyncio + async def test_get_by_user_id_and_provider_existing( + self, + user_oauth_repository: UserOauthRepository, + test_oauth: UserOauth, + test_user_id: int, + ) -> None: + """Test getting OAuth by user ID and provider when it exists.""" + oauth = await user_oauth_repository.get_by_user_id_and_provider( + test_user_id, "google" + ) + + assert oauth is not None + assert oauth.id == test_oauth.id + assert oauth.provider == "google" + assert oauth.user_id == test_user_id + + @pytest.mark.asyncio + async def test_get_by_user_id_and_provider_nonexistent( + self, + user_oauth_repository: UserOauthRepository, + test_user_id: int, + ) -> None: + """Test getting OAuth by user ID and provider when it doesn't exist.""" + oauth = await user_oauth_repository.get_by_user_id_and_provider( + test_user_id, "github" + ) + + assert oauth is None + + @pytest.mark.asyncio + async def test_create_oauth( + self, + user_oauth_repository: UserOauthRepository, + test_user_id: int, + ) -> None: + """Test creating a new OAuth record.""" + oauth_data = { + "user_id": test_user_id, + "provider": "github", + "provider_user_id": "github_789", + "email": "test@github.com", + "name": "Test User GitHub", + "picture": None, + } + + oauth = await user_oauth_repository.create(oauth_data) + + assert oauth.id is not None + assert oauth.user_id == test_user_id + assert oauth.provider == "github" + assert oauth.provider_user_id == "github_789" + assert oauth.email == "test@github.com" + assert oauth.name == "Test User GitHub" + + @pytest.mark.asyncio + async def test_update_oauth( + self, + user_oauth_repository: UserOauthRepository, + test_oauth: UserOauth, + ) -> None: + """Test updating an OAuth record.""" + update_data = { + "email": "updated@gmail.com", + "name": "Updated User Name", + "picture": "https://example.com/photo.jpg", + } + + updated_oauth = await user_oauth_repository.update(test_oauth, update_data) + + assert updated_oauth.id == test_oauth.id + assert updated_oauth.email == "updated@gmail.com" + assert updated_oauth.name == "Updated User Name" + assert updated_oauth.picture == "https://example.com/photo.jpg" + assert updated_oauth.provider == test_oauth.provider # Unchanged + assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged + + @pytest.mark.asyncio + async def test_delete_oauth( + self, + user_oauth_repository: UserOauthRepository, + test_session: AsyncSession, + test_user_id: int, + ) -> None: + """Test deleting an OAuth record.""" + # Create an OAuth record to delete + oauth_data = { + "user_id": test_user_id, + "provider": "twitter", + "provider_user_id": "twitter_456", + "email": "test@twitter.com", + "name": "Test User Twitter", + "picture": None, + } + oauth = await user_oauth_repository.create(oauth_data) + oauth_id = oauth.id + + # Delete the OAuth record + await user_oauth_repository.delete(oauth) + + # Verify it's deleted by trying to find it + deleted_oauth = await user_oauth_repository.get_by_provider_user_id( + "twitter", "twitter_456" + ) + assert deleted_oauth is None + + @pytest.mark.asyncio + async def test_create_duplicate_provider_user_id( + self, + user_oauth_repository: UserOauthRepository, + test_oauth: UserOauth, + test_user_id: int, + ) -> None: + """Test creating OAuth with duplicate provider user ID should fail.""" + # Try to create another OAuth with the same provider and provider_user_id + duplicate_oauth_data = { + "user_id": test_user_id, + "provider": "google", + "provider_user_id": "google_123456", # Same as test_oauth + "email": "another@gmail.com", + "name": "Another User", + "picture": None, + } + + # This should fail due to unique constraint + with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar + await user_oauth_repository.create(duplicate_oauth_data) + + @pytest.mark.asyncio + async def test_multiple_providers_same_user( + self, + user_oauth_repository: UserOauthRepository, + test_user_id: int, + ) -> None: + """Test that a user can have multiple OAuth providers.""" + # Create Google OAuth + google_oauth_data = { + "user_id": test_user_id, + "provider": "google", + "provider_user_id": "google_user_1", + "email": "user@gmail.com", + "name": "Test User Google", + "picture": None, + } + google_oauth = await user_oauth_repository.create(google_oauth_data) + + # Create GitHub OAuth for the same user + github_oauth_data = { + "user_id": test_user_id, + "provider": "github", + "provider_user_id": "github_user_1", + "email": "user@github.com", + "name": "Test User GitHub", + "picture": None, + } + github_oauth = await user_oauth_repository.create(github_oauth_data) + + # Verify both exist by querying back from database + found_google = await user_oauth_repository.get_by_user_id_and_provider( + test_user_id, "google" + ) + found_github = await user_oauth_repository.get_by_user_id_and_provider( + test_user_id, "github" + ) + + assert found_google is not None + assert found_github is not None + assert found_google.provider == "google" + assert found_github.provider == "github" + assert found_google.user_id == test_user_id + assert found_github.user_id == test_user_id + assert found_google.provider_user_id == "google_user_1" + assert found_github.provider_user_id == "github_user_1" + + # Verify we can also find them by provider_user_id + found_google_by_provider = await user_oauth_repository.get_by_provider_user_id( + "google", "google_user_1" + ) + found_github_by_provider = await user_oauth_repository.get_by_provider_user_id( + "github", "github_user_1" + ) + + assert found_google_by_provider is not None + assert found_github_by_provider is not None + assert found_google_by_provider.user_id == test_user_id + assert found_github_by_provider.user_id == test_user_id \ No newline at end of file diff --git a/tests/services/test_credit.py b/tests/services/test_credit.py new file mode 100644 index 0000000..6be40eb --- /dev/null +++ b/tests/services/test_credit.py @@ -0,0 +1,358 @@ +"""Tests for credit service.""" + +import json +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.models.credit_action import CreditActionType +from app.models.credit_transaction import CreditTransaction +from app.models.user import User +from app.services.credit import CreditService, InsufficientCreditsError + + +class TestCreditService: + """Test credit service functionality.""" + + @pytest.fixture + def mock_db_session_factory(self): + """Create a mock database session factory.""" + session = AsyncMock(spec=AsyncSession) + return lambda: session + + @pytest.fixture + def credit_service(self, mock_db_session_factory): + """Create a credit service instance for testing.""" + return CreditService(mock_db_session_factory) + + @pytest.fixture + def sample_user(self): + """Create a sample user for testing.""" + return User( + id=1, + name="Test User", + email="test@example.com", + role="user", + credits=10, + plan_id=1, + ) + + @pytest.mark.asyncio + async def test_check_credits_sufficient(self, credit_service, sample_user): + """Test checking credits when user has sufficient credits.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = sample_user + + result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND) + + assert result is True + mock_repo.get_by_id.assert_called_once_with(1) + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_check_credits_insufficient(self, credit_service): + """Test checking credits when user has insufficient credits.""" + mock_session = credit_service.db_session_factory() + poor_user = User( + id=1, + name="Poor User", + email="poor@example.com", + role="user", + credits=0, # No credits + plan_id=1, + ) + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = poor_user + + result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND) + + assert result is False + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_check_credits_user_not_found(self, credit_service): + """Test checking credits when user is not found.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = None + + result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND) + + assert result is False + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_validate_and_reserve_credits_success(self, credit_service, sample_user): + """Test successful credit validation and reservation.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = sample_user + + user, action = await credit_service.validate_and_reserve_credits( + 1, CreditActionType.VLC_PLAY_SOUND + ) + + assert user == sample_user + assert action.action_type == CreditActionType.VLC_PLAY_SOUND + assert action.cost == 1 + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_validate_and_reserve_credits_insufficient(self, credit_service): + """Test credit validation with insufficient credits.""" + mock_session = credit_service.db_session_factory() + poor_user = User( + id=1, + name="Poor User", + email="poor@example.com", + role="user", + credits=0, + plan_id=1, + ) + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = poor_user + + with pytest.raises(InsufficientCreditsError) as exc_info: + await credit_service.validate_and_reserve_credits( + 1, CreditActionType.VLC_PLAY_SOUND + ) + + assert exc_info.value.required == 1 + assert exc_info.value.available == 0 + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_validate_and_reserve_credits_user_not_found(self, credit_service): + """Test credit validation when user is not found.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = None + + with pytest.raises(ValueError, match="User 999 not found"): + await credit_service.validate_and_reserve_credits( + 999, CreditActionType.VLC_PLAY_SOUND + ) + + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_deduct_credits_success(self, credit_service, sample_user): + """Test successful credit deduction.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class, \ + patch("app.services.credit.socket_manager") as mock_socket_manager: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = sample_user + mock_socket_manager.send_to_user = AsyncMock() + + transaction = await credit_service.deduct_credits( + 1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"} + ) + + # Verify user credits were updated + mock_repo.update.assert_called_once_with(sample_user, {"credits": 9}) + + # Verify transaction was created + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + # Verify socket event was emitted + mock_socket_manager.send_to_user.assert_called_once_with( + "1", "user_credits_changed", { + "user_id": "1", + "credits_before": 10, + "credits_after": 9, + "credits_deducted": 1, + "action_type": "vlc_play_sound", + "success": True, + } + ) + + # Check transaction details + added_transaction = mock_session.add.call_args[0][0] + assert isinstance(added_transaction, CreditTransaction) + assert added_transaction.user_id == 1 + assert added_transaction.action_type == "vlc_play_sound" + assert added_transaction.amount == -1 + assert added_transaction.balance_before == 10 + assert added_transaction.balance_after == 9 + assert added_transaction.success is True + assert json.loads(added_transaction.metadata_json) == {"test": "data"} + + @pytest.mark.asyncio + async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user): + """Test credit deduction when action failed but requires success.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class, \ + patch("app.services.credit.socket_manager") as mock_socket_manager: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = sample_user + mock_socket_manager.send_to_user = AsyncMock() + + transaction = await credit_service.deduct_credits( + 1, CreditActionType.VLC_PLAY_SOUND, False # Action failed + ) + + # Verify user credits were NOT updated (action requires success) + mock_repo.update.assert_not_called() + + # Verify transaction was still created for auditing + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + # Verify no socket event was emitted since no credits were actually deducted + mock_socket_manager.send_to_user.assert_not_called() + + # Check transaction details + added_transaction = mock_session.add.call_args[0][0] + assert added_transaction.amount == 0 # No deduction for failed action + assert added_transaction.balance_before == 10 + assert added_transaction.balance_after == 10 # No change + assert added_transaction.success is False + + @pytest.mark.asyncio + async def test_deduct_credits_insufficient(self, credit_service): + """Test credit deduction with insufficient credits.""" + mock_session = credit_service.db_session_factory() + poor_user = User( + id=1, + name="Poor User", + email="poor@example.com", + role="user", + credits=0, + plan_id=1, + ) + + with patch("app.services.credit.UserRepository") as mock_repo_class, \ + patch("app.services.credit.socket_manager") as mock_socket_manager: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = poor_user + mock_socket_manager.send_to_user = AsyncMock() + + with pytest.raises(InsufficientCreditsError): + await credit_service.deduct_credits( + 1, CreditActionType.VLC_PLAY_SOUND, True + ) + + # Verify no socket event was emitted since credits could not be deducted + mock_socket_manager.send_to_user.assert_not_called() + + mock_session.rollback.assert_called_once() + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_add_credits(self, credit_service, sample_user): + """Test adding credits to user account.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class, \ + patch("app.services.credit.socket_manager") as mock_socket_manager: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = sample_user + mock_socket_manager.send_to_user = AsyncMock() + + transaction = await credit_service.add_credits( + 1, 5, "Bonus credits", {"reason": "signup"} + ) + + # Verify user credits were updated + mock_repo.update.assert_called_once_with(sample_user, {"credits": 15}) + + # Verify transaction was created + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + # Verify socket event was emitted + mock_socket_manager.send_to_user.assert_called_once_with( + "1", "user_credits_changed", { + "user_id": "1", + "credits_before": 10, + "credits_after": 15, + "credits_added": 5, + "description": "Bonus credits", + "success": True, + } + ) + + # Check transaction details + added_transaction = mock_session.add.call_args[0][0] + assert added_transaction.amount == 5 + assert added_transaction.balance_before == 10 + assert added_transaction.balance_after == 15 + assert added_transaction.description == "Bonus credits" + + @pytest.mark.asyncio + async def test_add_credits_invalid_amount(self, credit_service): + """Test adding invalid amount of credits.""" + with pytest.raises(ValueError, match="Amount must be positive"): + await credit_service.add_credits(1, 0, "Invalid") + + with pytest.raises(ValueError, match="Amount must be positive"): + await credit_service.add_credits(1, -5, "Invalid") + + @pytest.mark.asyncio + async def test_get_user_balance(self, credit_service, sample_user): + """Test getting user credit balance.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = sample_user + + balance = await credit_service.get_user_balance(1) + + assert balance == 10 + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_get_user_balance_user_not_found(self, credit_service): + """Test getting balance for non-existent user.""" + mock_session = credit_service.db_session_factory() + + with patch("app.services.credit.UserRepository") as mock_repo_class: + mock_repo = AsyncMock() + mock_repo_class.return_value = mock_repo + mock_repo.get_by_id.return_value = None + + with pytest.raises(ValueError, match="User 999 not found"): + await credit_service.get_user_balance(999) + + mock_session.close.assert_called_once() + + +class TestInsufficientCreditsError: + """Test InsufficientCreditsError exception.""" + + def test_insufficient_credits_error_creation(self): + """Test creating InsufficientCreditsError.""" + error = InsufficientCreditsError(5, 2) + assert error.required == 5 + assert error.available == 2 + assert str(error) == "Insufficient credits: 5 required, 2 available" \ No newline at end of file diff --git a/tests/utils/test_credit_decorators.py b/tests/utils/test_credit_decorators.py new file mode 100644 index 0000000..9f86f6b --- /dev/null +++ b/tests/utils/test_credit_decorators.py @@ -0,0 +1,277 @@ +"""Tests for credit decorators.""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from app.models.credit_action import CreditActionType +from app.services.credit import CreditService, InsufficientCreditsError +from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only + + +class TestRequiresCreditsDecorator: + """Test requires_credits decorator.""" + + @pytest.fixture + def mock_credit_service(self): + """Create a mock credit service.""" + service = AsyncMock(spec=CreditService) + service.validate_and_reserve_credits = AsyncMock() + service.deduct_credits = AsyncMock() + return service + + @pytest.fixture + def credit_service_factory(self, mock_credit_service): + """Create a credit service factory.""" + return lambda: mock_credit_service + + @pytest.mark.asyncio + async def test_decorator_success(self, credit_service_factory, mock_credit_service): + """Test decorator with successful action.""" + + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id" + ) + async def test_action(user_id: int, message: str) -> str: + return f"Success: {message}" + + result = await test_action(user_id=123, message="test") + + assert result == "Success: test" + mock_credit_service.validate_and_reserve_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, None + ) + mock_credit_service.deduct_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, True, None + ) + + @pytest.mark.asyncio + async def test_decorator_with_metadata(self, credit_service_factory, mock_credit_service): + """Test decorator with metadata extraction.""" + + def extract_metadata(user_id: int, sound_name: str) -> dict: + return {"sound_name": sound_name} + + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id", + metadata_extractor=extract_metadata + ) + async def test_action(user_id: int, sound_name: str) -> bool: + return True + + await test_action(user_id=123, sound_name="test.mp3") + + mock_credit_service.validate_and_reserve_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"} + ) + mock_credit_service.deduct_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"} + ) + + @pytest.mark.asyncio + async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service): + """Test decorator with failed action.""" + + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id" + ) + async def test_action(user_id: int) -> bool: + return False # Action fails + + result = await test_action(user_id=123) + + assert result is False + mock_credit_service.deduct_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, False, None + ) + + @pytest.mark.asyncio + async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service): + """Test decorator when action raises exception.""" + + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id" + ) + async def test_action(user_id: int) -> str: + raise ValueError("Test error") + + with pytest.raises(ValueError, match="Test error"): + await test_action(user_id=123) + + mock_credit_service.deduct_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, False, None + ) + + @pytest.mark.asyncio + async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service): + """Test decorator with insufficient credits.""" + mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0) + + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id" + ) + async def test_action(user_id: int) -> str: + return "Should not execute" + + with pytest.raises(InsufficientCreditsError): + await test_action(user_id=123) + + # Should not call deduct_credits since validation failed + mock_credit_service.deduct_credits.assert_not_called() + + @pytest.mark.asyncio + async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service): + """Test decorator extracting user_id from positional args.""" + + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id" + ) + async def test_action(user_id: int, message: str) -> str: + return message + + result = await test_action(123, "test") + + assert result == "test" + mock_credit_service.validate_and_reserve_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, None + ) + + @pytest.mark.asyncio + async def test_decorator_missing_user_id(self, credit_service_factory): + """Test decorator when user_id cannot be extracted.""" + + @requires_credits( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id" + ) + async def test_action(other_param: str) -> str: + return other_param + + with pytest.raises(ValueError, match="Could not extract user_id"): + await test_action(other_param="test") + + +class TestValidateCreditsOnlyDecorator: + """Test validate_credits_only decorator.""" + + @pytest.fixture + def mock_credit_service(self): + """Create a mock credit service.""" + service = AsyncMock(spec=CreditService) + service.validate_and_reserve_credits = AsyncMock() + return service + + @pytest.fixture + def credit_service_factory(self, mock_credit_service): + """Create a credit service factory.""" + return lambda: mock_credit_service + + @pytest.mark.asyncio + async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service): + """Test validate_credits_only decorator.""" + + @validate_credits_only( + CreditActionType.VLC_PLAY_SOUND, + credit_service_factory, + user_id_param="user_id" + ) + async def test_action(user_id: int, message: str) -> str: + return f"Validated: {message}" + + result = await test_action(user_id=123, message="test") + + assert result == "Validated: test" + mock_credit_service.validate_and_reserve_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND + ) + # Should not deduct credits, only validate + mock_credit_service.deduct_credits.assert_not_called() + + +class TestCreditManager: + """Test CreditManager context manager.""" + + @pytest.fixture + def mock_credit_service(self): + """Create a mock credit service.""" + service = AsyncMock(spec=CreditService) + service.validate_and_reserve_credits = AsyncMock() + service.deduct_credits = AsyncMock() + return service + + @pytest.mark.asyncio + async def test_credit_manager_success(self, mock_credit_service): + """Test CreditManager with successful operation.""" + async with CreditManager( + mock_credit_service, + 123, + CreditActionType.VLC_PLAY_SOUND, + {"test": "data"} + ) as manager: + manager.mark_success() + + mock_credit_service.validate_and_reserve_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"} + ) + mock_credit_service.deduct_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"} + ) + + @pytest.mark.asyncio + async def test_credit_manager_failure(self, mock_credit_service): + """Test CreditManager with failed operation.""" + async with CreditManager( + mock_credit_service, + 123, + CreditActionType.VLC_PLAY_SOUND + ): + # Don't mark as success - should be considered failed + pass + + mock_credit_service.deduct_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, False, None + ) + + @pytest.mark.asyncio + async def test_credit_manager_exception(self, mock_credit_service): + """Test CreditManager when exception occurs.""" + with pytest.raises(ValueError, match="Test error"): + async with CreditManager( + mock_credit_service, + 123, + CreditActionType.VLC_PLAY_SOUND + ): + raise ValueError("Test error") + + mock_credit_service.deduct_credits.assert_called_once_with( + 123, CreditActionType.VLC_PLAY_SOUND, False, None + ) + + @pytest.mark.asyncio + async def test_credit_manager_validation_failure(self, mock_credit_service): + """Test CreditManager when validation fails.""" + mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0) + + with pytest.raises(InsufficientCreditsError): + async with CreditManager( + mock_credit_service, + 123, + CreditActionType.VLC_PLAY_SOUND + ): + pass + + # Should not call deduct_credits since validation failed + mock_credit_service.deduct_credits.assert_not_called() \ No newline at end of file