Add tests for sound repository, user OAuth repository, credit service, and credit decorators

- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities.
- Create tests for UserOauthRepository to validate OAuth record management.
- Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits.
- Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
JSC
2025-07-30 21:33:55 +02:00
parent dd10ef5d41
commit e43650c26c
14 changed files with 2692 additions and 1 deletions

View File

@@ -7,9 +7,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.user import User from app.models.user import User
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
from app.services.extraction import ExtractionInfo, ExtractionService from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.extraction_processor import extraction_processor from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService from app.services.sound_scanner import ScanResults, SoundScannerService
@@ -45,6 +47,12 @@ def get_vlc_player() -> VLCPlayerService:
return get_vlc_player_service(get_session_factory()) return get_vlc_player_service(get_session_factory())
def get_credit_service() -> CreditService:
"""Get the credit service."""
from app.core.database import get_session_factory
return CreditService(get_session_factory())
async def get_sound_repository( async def get_sound_repository(
session: Annotated[AsyncSession, Depends(get_db)], session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundRepository: ) -> SoundRepository:
@@ -373,8 +381,9 @@ async def play_sound_with_vlc(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)], vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)], sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
credit_service: Annotated[CreditService, Depends(get_credit_service)],
) -> dict[str, str | int | bool]: ) -> dict[str, str | int | bool]:
"""Play a sound using VLC subprocess.""" """Play a sound using VLC subprocess (requires 1 credit)."""
try: try:
# Get the sound # Get the sound
sound = await sound_repo.get_by_id(sound_id) sound = await sound_repo.get_by_id(sound_id)
@@ -384,9 +393,30 @@ async def play_sound_with_vlc(
detail=f"Sound with ID {sound_id} not found", detail=f"Sound with ID {sound_id} not found",
) )
# Check and validate credits before playing
try:
await credit_service.validate_and_reserve_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
{"sound_id": sound_id, "sound_name": sound.name}
)
except InsufficientCreditsError as e:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Insufficient credits: {e.required} required, {e.available} available",
) from e
# Play the sound using VLC # Play the sound using VLC
success = await vlc_player.play_sound(sound) success = await vlc_player.play_sound(sound)
# Deduct credits based on success
await credit_service.deduct_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
success,
{"sound_id": sound_id, "sound_name": sound.name},
)
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -398,6 +428,7 @@ async def play_sound_with_vlc(
"sound_id": sound_id, "sound_id": sound_id,
"sound_name": sound.name, "sound_name": sound.name,
"success": True, "success": True,
"credits_deducted": 1,
} }
except HTTPException: except HTTPException:

121
app/models/credit_action.py Normal file
View File

@@ -0,0 +1,121 @@
"""Credit action definitions for the credit system."""
from enum import Enum
from typing import Any
class CreditActionType(str, Enum):
"""Types of actions that consume credits."""
VLC_PLAY_SOUND = "vlc_play_sound"
AUDIO_EXTRACTION = "audio_extraction"
TEXT_TO_SPEECH = "text_to_speech"
SOUND_NORMALIZATION = "sound_normalization"
API_REQUEST = "api_request"
PLAYLIST_CREATION = "playlist_creation"
class CreditAction:
"""Definition of a credit-consuming action."""
def __init__(
self,
action_type: CreditActionType,
cost: int,
description: str,
*,
requires_success: bool = True,
) -> None:
"""Initialize a credit action.
Args:
action_type: The type of action
cost: Number of credits required
description: Human-readable description
requires_success: Whether credits are only deducted on successful completion
"""
self.action_type = action_type
self.cost = cost
self.description = description
self.requires_success = requires_success
def __str__(self) -> str:
"""Return string representation of the action."""
return f"{self.action_type.value} ({self.cost} credits)"
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"action_type": self.action_type.value,
"cost": self.cost,
"description": self.description,
"requires_success": self.requires_success,
}
# Predefined credit actions
CREDIT_ACTIONS = {
CreditActionType.VLC_PLAY_SOUND: CreditAction(
action_type=CreditActionType.VLC_PLAY_SOUND,
cost=1,
description="Play a sound using VLC player",
requires_success=True,
),
CreditActionType.AUDIO_EXTRACTION: CreditAction(
action_type=CreditActionType.AUDIO_EXTRACTION,
cost=5,
description="Extract audio from external URL",
requires_success=True,
),
CreditActionType.TEXT_TO_SPEECH: CreditAction(
action_type=CreditActionType.TEXT_TO_SPEECH,
cost=2,
description="Generate speech from text",
requires_success=True,
),
CreditActionType.SOUND_NORMALIZATION: CreditAction(
action_type=CreditActionType.SOUND_NORMALIZATION,
cost=1,
description="Normalize audio levels",
requires_success=True,
),
CreditActionType.API_REQUEST: CreditAction(
action_type=CreditActionType.API_REQUEST,
cost=1,
description="API request (rate limiting)",
requires_success=False, # Charged even if request fails
),
CreditActionType.PLAYLIST_CREATION: CreditAction(
action_type=CreditActionType.PLAYLIST_CREATION,
cost=3,
description="Create a new playlist",
requires_success=True,
),
}
def get_credit_action(action_type: CreditActionType) -> CreditAction:
"""Get a credit action definition by type.
Args:
action_type: The action type to look up
Returns:
The credit action definition
Raises:
KeyError: If action type is not found
"""
return CREDIT_ACTIONS[action_type]
def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
"""Get all available credit actions.
Returns:
Dictionary of all credit actions
"""
return CREDIT_ACTIONS.copy()

View File

@@ -0,0 +1,29 @@
"""Credit transaction model for tracking credit usage."""
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.user import User
class CreditTransaction(BaseModel, table=True):
"""Database model for credit transactions."""
__tablename__ = "credit_transaction" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="user.id", nullable=False)
action_type: str = Field(nullable=False)
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
balance_before: int = Field(nullable=False)
balance_after: int = Field(nullable=False)
description: str = Field(nullable=False)
success: bool = Field(nullable=False, default=True)
# JSON string for additional data
metadata_json: str | None = Field(default=None)
# relationships
user: "User" = Relationship(back_populates="credit_transactions")

View File

@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship
from app.models.base import BaseModel from app.models.base import BaseModel
if TYPE_CHECKING: if TYPE_CHECKING:
from app.models.credit_transaction import CreditTransaction
from app.models.extraction import Extraction from app.models.extraction import Extraction
from app.models.plan import Plan from app.models.plan import Plan
from app.models.playlist import Playlist from app.models.playlist import Playlist
@@ -35,3 +36,4 @@ class User(BaseModel, table=True):
playlists: list["Playlist"] = Relationship(back_populates="user") playlists: list["Playlist"] = Relationship(back_populates="user")
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user") sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
extractions: list["Extraction"] = Relationship(back_populates="user") extractions: list["Extraction"] = Relationship(back_populates="user")
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")

132
app/repositories/base.py Normal file
View File

@@ -0,0 +1,132 @@
"""Base repository with common CRUD operations."""
from typing import Any, Generic, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
# Type variable for the model
ModelType = TypeVar("ModelType")
logger = get_logger(__name__)
class BaseRepository(Generic[ModelType]):
"""Base repository with common CRUD operations."""
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
"""Initialize the repository.
Args:
model: The SQLModel class
session: Database session
"""
self.model = model
self.session = session
async def get_by_id(self, entity_id: int) -> ModelType | None:
"""Get an entity by ID.
Args:
entity_id: The entity ID
Returns:
The entity if found, None otherwise
"""
try:
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
raise
async def get_all(
self,
limit: int = 100,
offset: int = 0,
) -> list[ModelType]:
"""Get all entities with pagination.
Args:
limit: Maximum number of entities to return
offset: Number of entities to skip
Returns:
List of entities
"""
try:
statement = select(self.model).limit(limit).offset(offset)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all %s", self.model.__name__)
raise
async def create(self, entity_data: dict[str, Any]) -> ModelType:
"""Create a new entity.
Args:
entity_data: Dictionary of entity data
Returns:
The created entity
"""
try:
entity = self.model(**entity_data)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to create %s", self.model.__name__)
raise
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
"""Update an entity.
Args:
entity: The entity to update
update_data: Dictionary of fields to update
Returns:
The updated entity
"""
try:
for field, value in update_data.items():
setattr(entity, field, value)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to update %s", self.model.__name__)
raise
async def delete(self, entity: ModelType) -> None:
"""Delete an entity.
Args:
entity: The entity to delete
"""
try:
await self.session.delete(entity)
await self.session.commit()
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
except Exception:
await self.session.rollback()
logger.exception("Failed to delete %s", self.model.__name__)
raise

View File

@@ -0,0 +1,108 @@
"""Repository for credit transaction database operations."""
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_transaction import CreditTransaction
from app.repositories.base import BaseRepository
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
"""Repository for credit transaction operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the repository.
Args:
session: Database session
"""
super().__init__(CreditTransaction, session)
async def get_by_user_id(
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions for a user.
Args:
user_id: The user ID
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.user_id == user_id)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_by_action_type(
self,
action_type: str,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions by action type.
Args:
action_type: The action type to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.action_type == action_type)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_successful_transactions(
self,
user_id: int | None = None,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get successful credit transactions.
Args:
user_id: Optional user ID to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of successful credit transactions
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.success == True) # noqa: E712
)
if user_id is not None:
stmt = stmt.where(CreditTransaction.user_id == user_id)
stmt = (
stmt.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())

383
app/services/credit.py Normal file
View File

@@ -0,0 +1,383 @@
"""Credit management service for tracking and validating user credit usage."""
import json
from collections.abc import Callable
from typing import Any
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
logger = get_logger(__name__)
class InsufficientCreditsError(Exception):
"""Raised when user has insufficient credits for an action."""
def __init__(self, required: int, available: int) -> None:
"""Initialize the error.
Args:
required: Number of credits required
available: Number of credits available
"""
self.required = required
self.available = available
super().__init__(
f"Insufficient credits: {required} required, {available} available"
)
class CreditService:
"""Service for managing user credits and transactions."""
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
"""Initialize the credit service.
Args:
db_session_factory: Factory function to create database sessions
"""
self.db_session_factory = db_session_factory
async def check_credits(
self,
user_id: int,
action_type: CreditActionType,
) -> bool:
"""Check if user has sufficient credits for an action.
Args:
user_id: The user ID
action_type: The type of action to check
Returns:
True if user has sufficient credits, False otherwise
"""
action = get_credit_action(action_type)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
return False
return user.credits >= action.cost
finally:
await session.close()
async def validate_and_reserve_credits(
self,
user_id: int,
action_type: CreditActionType,
metadata: dict[str, Any] | None = None,
) -> tuple[User, CreditAction]:
"""Validate user has sufficient credits and optionally reserve them.
Args:
user_id: The user ID
action_type: The type of action
metadata: Optional metadata to store with transaction
Returns:
Tuple of (user, credit_action)
Raises:
InsufficientCreditsError: If user has insufficient credits
"""
action = get_credit_action(action_type)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if user.credits < action.cost:
raise InsufficientCreditsError(action.cost, user.credits)
logger.info(
"Credits validated for user %s: %s credits available, %s required",
user_id,
user.credits,
action.cost,
)
return user, action
finally:
await session.close()
async def deduct_credits(
self,
user_id: int,
action_type: CreditActionType,
success: bool = True,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Deduct credits from user account and record transaction.
Args:
user_id: The user ID
action_type: The type of action
success: Whether the action was successful
metadata: Optional metadata to store with transaction
Returns:
The created credit transaction
Raises:
InsufficientCreditsError: If user has insufficient credits
ValueError: If user not found
"""
action = get_credit_action(action_type)
# Only deduct if action requires success and was successful, or doesn't require success
should_deduct = (action.requires_success and success) or not action.requires_success
if not should_deduct:
logger.info(
"Skipping credit deduction for user %s: action %s failed and requires success",
user_id,
action_type.value,
)
# Still create a transaction record for auditing
return await self._create_transaction_record(
user_id, action, 0, success, metadata
)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if user.credits < action.cost:
raise InsufficientCreditsError(action.cost, user.credits)
# Record transaction
balance_before = user.credits
balance_after = user.credits - action.cost
transaction = CreditTransaction(
user_id=user_id,
action_type=action_type.value,
amount=-action.cost,
balance_before=balance_before,
balance_after=balance_after,
description=action.description,
success=success,
metadata_json=json.dumps(metadata) if metadata else None,
)
# Update user credits
await user_repo.update(user, {"credits": balance_after})
# Save transaction
session.add(transaction)
await session.commit()
logger.info(
"Credits deducted for user %s: %s credits (action: %s, success: %s)",
user_id,
action.cost,
action_type.value,
success,
)
# Emit user_credits_changed event via WebSocket
try:
event_data = {
"user_id": str(user_id),
"credits_before": balance_before,
"credits_after": balance_after,
"credits_deducted": action.cost,
"action_type": action_type.value,
"success": success,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
)
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def add_credits(
self,
user_id: int,
amount: int,
description: str,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Add credits to user account.
Args:
user_id: The user ID
amount: Number of credits to add
description: Description of the credit addition
metadata: Optional metadata to store with transaction
Returns:
The created credit transaction
Raises:
ValueError: If user not found or amount is negative
"""
if amount <= 0:
msg = "Amount must be positive"
raise ValueError(msg)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
# Record transaction
balance_before = user.credits
balance_after = user.credits + amount
transaction = CreditTransaction(
user_id=user_id,
action_type="credit_addition",
amount=amount,
balance_before=balance_before,
balance_after=balance_after,
description=description,
success=True,
metadata_json=json.dumps(metadata) if metadata else None,
)
# Update user credits
await user_repo.update(user, {"credits": balance_after})
# Save transaction
session.add(transaction)
await session.commit()
logger.info(
"Credits added for user %s: %s credits (description: %s)",
user_id,
amount,
description,
)
# Emit user_credits_changed event via WebSocket
try:
event_data = {
"user_id": str(user_id),
"credits_before": balance_before,
"credits_after": balance_after,
"credits_added": amount,
"description": description,
"success": True,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
)
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def _create_transaction_record(
self,
user_id: int,
action: CreditAction,
amount: int,
success: bool,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Create a transaction record without modifying credits.
Args:
user_id: The user ID
action: The credit action
amount: Amount to record (typically 0 for failed actions)
success: Whether the action was successful
metadata: Optional metadata
Returns:
The created transaction record
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
transaction = CreditTransaction(
user_id=user_id,
action_type=action.action_type.value,
amount=amount,
balance_before=user.credits,
balance_after=user.credits,
description=f"{action.description} (failed)" if not success else action.description,
success=success,
metadata_json=json.dumps(metadata) if metadata else None,
)
session.add(transaction)
await session.commit()
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def get_user_balance(self, user_id: int) -> int:
"""Get current credit balance for a user.
Args:
user_id: The user ID
Returns:
Current credit balance
Raises:
ValueError: If user not found
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
return user.credits
finally:
await session.close()

View File

@@ -0,0 +1,192 @@
"""Decorators for credit management and validation."""
import functools
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
def requires_credits(
action_type: CreditActionType,
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
"""Decorator to enforce credit requirements for actions.
Args:
action_type: The type of action that requires credits
credit_service_factory: Factory to create credit service instance
user_id_param: Name of the parameter containing user ID
metadata_extractor: Optional function to extract metadata from function args
Returns:
Decorated function that validates and deducts credits
Example:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
lambda: get_credit_service(),
user_id_param="user_id"
)
async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
# Implementation here
return True
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
param_index = param_names.index(user_id_param)
if param_index < len(args):
user_id = args[param_index]
if user_id is None:
msg = f"Could not extract user_id from parameter '{user_id_param}'"
raise ValueError(msg)
# Extract metadata if extractor provided
metadata = None
if metadata_extractor:
metadata = metadata_extractor(*args, **kwargs)
# Get credit service
credit_service = credit_service_factory()
# Validate credits before execution
await credit_service.validate_and_reserve_credits(
user_id, action_type, metadata
)
# Execute the function
success = False
result = None
try:
result = await func(*args, **kwargs)
success = bool(result) # Consider function result as success indicator
return result
except Exception:
success = False
raise
finally:
# Deduct credits based on success
await credit_service.deduct_credits(
user_id, action_type, success, metadata
)
return wrapper # type: ignore[return-value]
return decorator
def validate_credits_only(
action_type: CreditActionType,
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
) -> Callable[[F], F]:
"""Decorator to only validate credits without deducting them.
Useful for checking if a user can perform an action before actual execution.
Args:
action_type: The type of action that requires credits
credit_service_factory: Factory to create credit service instance
user_id_param: Name of the parameter containing user ID
Returns:
Decorated function that validates credits only
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
param_index = param_names.index(user_id_param)
if param_index < len(args):
user_id = args[param_index]
if user_id is None:
msg = f"Could not extract user_id from parameter '{user_id_param}'"
raise ValueError(msg)
# Get credit service
credit_service = credit_service_factory()
# Validate credits only
await credit_service.validate_and_reserve_credits(user_id, action_type)
# Execute the function
return await func(*args, **kwargs)
return wrapper # type: ignore[return-value]
return decorator
class CreditManager:
"""Context manager for credit operations."""
def __init__(
self,
credit_service: CreditService,
user_id: int,
action_type: CreditActionType,
metadata: dict[str, Any] | None = None,
) -> None:
"""Initialize credit manager.
Args:
credit_service: Credit service instance
user_id: User ID
action_type: Action type
metadata: Optional metadata
"""
self.credit_service = credit_service
self.user_id = user_id
self.action_type = action_type
self.metadata = metadata
self.validated = False
self.success = False
async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.metadata
)
self.validated = True
return self
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
"""Exit context manager - deduct credits based on success."""
if self.validated:
# If no exception occurred, consider it successful
success = exc_type is None and self.success
await self.credit_service.deduct_credits(
self.user_id, self.action_type, success, self.metadata
)
def mark_success(self) -> None:
"""Mark the operation as successful."""
self.success = True

View File

@@ -13,8 +13,10 @@ from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
from app.utils.auth import JWTUtils, PasswordUtils from app.utils.auth import JWTUtils, PasswordUtils

View File

@@ -0,0 +1,412 @@
"""Tests for credit transaction repository."""
import json
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.credit_transaction import CreditTransactionRepository
class TestCreditTransactionRepository:
"""Test credit transaction repository operations."""
@pytest_asyncio.fixture
async def credit_transaction_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[CreditTransactionRepository, None]: # type: ignore[misc]
"""Create a credit transaction repository instance."""
yield CreditTransactionRepository(test_session)
@pytest_asyncio.fixture
async def test_user_id(
self,
test_user: User,
) -> int:
"""Get test user ID to avoid lazy loading issues."""
return test_user.id
@pytest_asyncio.fixture
async def test_transactions(
self,
test_session: AsyncSession,
test_user_id: int,
) -> AsyncGenerator[list[CreditTransaction], None]: # type: ignore[misc]
"""Create test credit transactions."""
transactions = []
user_id = test_user_id
# Create various types of transactions
transaction_data = [
{
"user_id": user_id,
"action_type": "vlc_play_sound",
"amount": -1,
"balance_before": 10,
"balance_after": 9,
"description": "Play sound via VLC",
"success": True,
"metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}),
},
{
"user_id": user_id,
"action_type": "audio_extraction",
"amount": -5,
"balance_before": 9,
"balance_after": 4,
"description": "Extract audio from URL",
"success": True,
"metadata_json": json.dumps({"url": "https://example.com/video"}),
},
{
"user_id": user_id,
"action_type": "vlc_play_sound",
"amount": 0,
"balance_before": 4,
"balance_after": 4,
"description": "Play sound via VLC (failed)",
"success": False,
"metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}),
},
{
"user_id": user_id,
"action_type": "credit_addition",
"amount": 50,
"balance_before": 4,
"balance_after": 54,
"description": "Bonus credits",
"success": True,
"metadata_json": json.dumps({"reason": "signup_bonus"}),
},
]
for data in transaction_data:
transaction = CreditTransaction(**data)
test_session.add(transaction)
transactions.append(transaction)
await test_session.commit()
for transaction in transactions:
await test_session.refresh(transaction)
yield transactions
@pytest_asyncio.fixture
async def other_user_transaction(
self,
test_session: AsyncSession,
ensure_plans: tuple, # noqa: ARG002
) -> AsyncGenerator[CreditTransaction, None]: # type: ignore[misc]
"""Create a transaction for a different user."""
from app.models.plan import Plan
from app.repositories.user import UserRepository
# Create another user
user_repo = UserRepository(test_session)
other_user_data = {
"email": "other@example.com",
"name": "Other User",
"password_hash": "hashed_password",
"role": "user",
"is_active": True,
}
other_user = await user_repo.create(other_user_data)
# Create transaction for the other user
transaction_data = {
"user_id": other_user.id,
"action_type": "vlc_play_sound",
"amount": -1,
"balance_before": 100,
"balance_after": 99,
"description": "Other user play sound",
"success": True,
"metadata_json": None,
}
transaction = CreditTransaction(**transaction_data)
test_session.add(transaction)
await test_session.commit()
await test_session.refresh(transaction)
yield transaction
@pytest.mark.asyncio
async def test_get_by_id_existing(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transaction by ID when it exists."""
transaction = await credit_transaction_repository.get_by_id(test_transactions[0].id)
assert transaction is not None
assert transaction.id == test_transactions[0].id
assert transaction.action_type == "vlc_play_sound"
assert transaction.amount == -1
@pytest.mark.asyncio
async def test_get_by_id_nonexistent(
self,
credit_transaction_repository: CreditTransactionRepository,
) -> None:
"""Test getting transaction by ID when it doesn't exist."""
transaction = await credit_transaction_repository.get_by_id(99999)
assert transaction is None
@pytest.mark.asyncio
async def test_get_by_user_id(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
test_user_id: int,
) -> None:
"""Test getting transactions by user ID."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should return all transactions for test_user
assert len(transactions) == 4
# Should be ordered by created_at desc (newest first)
assert all(t.user_id == test_user_id for t in transactions)
# Should not include other user's transaction
other_user_ids = [t.user_id for t in transactions]
assert other_user_transaction.user_id not in other_user_ids
@pytest.mark.asyncio
async def test_get_by_user_id_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test getting transactions by user ID with pagination."""
# Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0
)
assert len(first_page) == 2
# Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2
)
assert len(second_page) == 2
# Should not overlap
first_page_ids = {t.id for t in first_page}
second_page_ids = {t.id for t in second_page}
assert first_page_ids.isdisjoint(second_page_ids)
@pytest.mark.asyncio
async def test_get_by_action_type(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transactions by action type."""
vlc_transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound"
)
# Should return 2 VLC transactions (1 successful, 1 failed)
assert len(vlc_transactions) >= 2
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
extraction_transactions = await credit_transaction_repository.get_by_action_type(
"audio_extraction"
)
# Should return 1 extraction transaction
assert len(extraction_transactions) >= 1
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
@pytest.mark.asyncio
async def test_get_by_action_type_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transactions by action type with pagination."""
# Test with limit
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1
)
assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound"
# Test with offset
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1
)
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@pytest.mark.asyncio
async def test_get_successful_transactions(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting only successful transactions."""
successful_transactions = await credit_transaction_repository.get_successful_transactions()
# Should only return successful transactions
assert all(t.success is True for t in successful_transactions)
# Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition)
assert len(successful_transactions) >= 3
@pytest.mark.asyncio
async def test_get_successful_transactions_by_user(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
test_user_id: int,
) -> None:
"""Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id
)
# Should only return successful transactions for test_user
assert all(t.success is True for t in successful_transactions)
assert all(t.user_id == test_user_id for t in successful_transactions)
# Should be 3 successful transactions for test_user
assert len(successful_transactions) == 3
@pytest.mark.asyncio
async def test_get_successful_transactions_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test getting successful transactions with pagination."""
# Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0
)
assert len(first_page) == 2
assert all(t.success is True for t in first_page)
# Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2
)
assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page)
@pytest.mark.asyncio
async def test_get_all_transactions(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
) -> None:
"""Test getting all transactions."""
all_transactions = await credit_transaction_repository.get_all()
# Should return all transactions
assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction
@pytest.mark.asyncio
async def test_create_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_user_id: int,
) -> None:
"""Test creating a new transaction."""
transaction_data = {
"user_id": test_user_id,
"action_type": "test_action",
"amount": -10,
"balance_before": 100,
"balance_after": 90,
"description": "Test transaction",
"success": True,
"metadata_json": json.dumps({"test": "data"}),
}
transaction = await credit_transaction_repository.create(transaction_data)
assert transaction.id is not None
assert transaction.user_id == test_user_id
assert transaction.action_type == "test_action"
assert transaction.amount == -10
assert transaction.balance_before == 100
assert transaction.balance_after == 90
assert transaction.success is True
assert json.loads(transaction.metadata_json) == {"test": "data"}
@pytest.mark.asyncio
async def test_update_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test updating a transaction."""
transaction = test_transactions[0]
update_data = {
"description": "Updated description",
"metadata_json": json.dumps({"updated": True}),
}
updated_transaction = await credit_transaction_repository.update(
transaction, update_data
)
assert updated_transaction.id == transaction.id
assert updated_transaction.description == "Updated description"
assert json.loads(updated_transaction.metadata_json) == {"updated": True}
# Other fields should remain unchanged
assert updated_transaction.amount == transaction.amount
assert updated_transaction.action_type == transaction.action_type
@pytest.mark.asyncio
async def test_delete_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_session: AsyncSession,
test_user_id: int,
) -> None:
"""Test deleting a transaction."""
# Create a transaction to delete
transaction_data = {
"user_id": test_user_id,
"action_type": "to_delete",
"amount": -1,
"balance_before": 10,
"balance_after": 9,
"description": "To be deleted",
"success": True,
"metadata_json": None,
}
transaction = await credit_transaction_repository.create(transaction_data)
transaction_id = transaction.id
# Delete the transaction
await credit_transaction_repository.delete(transaction)
# Verify transaction is deleted
deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id)
assert deleted_transaction is None
@pytest.mark.asyncio
async def test_transaction_ordering(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test that transactions are ordered by created_at desc."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should be ordered by created_at desc (newest first)
for i in range(len(transactions) - 1):
assert transactions[i].created_at >= transactions[i + 1].created_at

View File

@@ -0,0 +1,376 @@
"""Tests for sound repository."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound
from app.repositories.sound import SoundRepository
class TestSoundRepository:
"""Test sound repository operations."""
@pytest_asyncio.fixture
async def sound_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[SoundRepository, None]: # type: ignore[misc]
"""Create a sound repository instance."""
yield SoundRepository(test_session)
@pytest_asyncio.fixture
async def test_sound(
self,
test_session: AsyncSession,
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
"""Create a test sound."""
sound_data = {
"name": "Test Sound",
"filename": "test_sound.mp3",
"type": "SDB",
"duration": 5000,
"size": 1024000,
"hash": "test_hash_123",
"play_count": 0,
"is_normalized": False,
}
sound = Sound(**sound_data)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
yield sound
@pytest_asyncio.fixture
async def normalized_sound(
self,
test_session: AsyncSession,
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
"""Create a normalized test sound."""
sound_data = {
"name": "Normalized Sound",
"filename": "normalized_sound.mp3",
"type": "TTS",
"duration": 3000,
"size": 512000,
"hash": "normalized_hash_456",
"play_count": 5,
"is_normalized": True,
"normalized_filename": "normalized_sound_norm.mp3",
"normalized_duration": 3000,
"normalized_size": 480000,
"normalized_hash": "normalized_hash_norm_456",
}
sound = Sound(**sound_data)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
yield sound
@pytest.mark.asyncio
async def test_get_by_id_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by ID when it exists."""
sound = await sound_repository.get_by_id(test_sound.id)
assert sound is not None
assert sound.id == test_sound.id
assert sound.name == test_sound.name
assert sound.filename == test_sound.filename
assert sound.type == test_sound.type
@pytest.mark.asyncio
async def test_get_by_id_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by ID when it doesn't exist."""
sound = await sound_repository.get_by_id(99999)
assert sound is None
@pytest.mark.asyncio
async def test_get_by_filename_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by filename when it exists."""
sound = await sound_repository.get_by_filename(test_sound.filename)
assert sound is not None
assert sound.id == test_sound.id
assert sound.filename == test_sound.filename
@pytest.mark.asyncio
async def test_get_by_filename_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by filename when it doesn't exist."""
sound = await sound_repository.get_by_filename("nonexistent.mp3")
assert sound is None
@pytest.mark.asyncio
async def test_get_by_hash_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by hash when it exists."""
sound = await sound_repository.get_by_hash(test_sound.hash)
assert sound is not None
assert sound.id == test_sound.id
assert sound.hash == test_sound.hash
@pytest.mark.asyncio
async def test_get_by_hash_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by hash when it doesn't exist."""
sound = await sound_repository.get_by_hash("nonexistent_hash")
assert sound is None
@pytest.mark.asyncio
async def test_get_by_type(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting sounds by type."""
sdb_sounds = await sound_repository.get_by_type("SDB")
tts_sounds = await sound_repository.get_by_type("TTS")
ext_sounds = await sound_repository.get_by_type("EXT")
# Should find the SDB sound
assert len(sdb_sounds) >= 1
assert any(sound.id == test_sound.id for sound in sdb_sounds)
# Should find the TTS sound
assert len(tts_sounds) >= 1
assert any(sound.id == normalized_sound.id for sound in tts_sounds)
# Should not find any EXT sounds
assert len(ext_sounds) == 0
@pytest.mark.asyncio
async def test_create_sound(
self,
sound_repository: SoundRepository,
) -> None:
"""Test creating a new sound."""
sound_data = {
"name": "New Sound",
"filename": "new_sound.wav",
"type": "EXT",
"duration": 7500,
"size": 2048000,
"hash": "new_hash_789",
"play_count": 0,
"is_normalized": False,
}
sound = await sound_repository.create(sound_data)
assert sound.id is not None
assert sound.name == sound_data["name"]
assert sound.filename == sound_data["filename"]
assert sound.type == sound_data["type"]
assert sound.duration == sound_data["duration"]
assert sound.size == sound_data["size"]
assert sound.hash == sound_data["hash"]
assert sound.play_count == 0
assert sound.is_normalized is False
@pytest.mark.asyncio
async def test_update_sound(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test updating a sound."""
update_data = {
"name": "Updated Sound Name",
"play_count": 10,
"is_normalized": True,
"normalized_filename": "updated_norm.mp3",
}
updated_sound = await sound_repository.update(test_sound, update_data)
assert updated_sound.id == test_sound.id
assert updated_sound.name == "Updated Sound Name"
assert updated_sound.play_count == 10
assert updated_sound.is_normalized is True
assert updated_sound.normalized_filename == "updated_norm.mp3"
assert updated_sound.filename == test_sound.filename # Unchanged
@pytest.mark.asyncio
async def test_delete_sound(
self,
sound_repository: SoundRepository,
test_session: AsyncSession,
) -> None:
"""Test deleting a sound."""
# Create a sound to delete
sound_data = {
"name": "To Delete",
"filename": "to_delete.mp3",
"type": "SDB",
"duration": 1000,
"size": 256000,
"hash": "delete_hash",
"play_count": 0,
"is_normalized": False,
}
sound = await sound_repository.create(sound_data)
sound_id = sound.id
# Delete the sound
await sound_repository.delete(sound)
# Verify sound is deleted
deleted_sound = await sound_repository.get_by_id(sound_id)
assert deleted_sound is None
@pytest.mark.asyncio
async def test_search_by_name(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test searching sounds by name."""
# Search for "test" should find test_sound
results = await sound_repository.search_by_name("test")
assert len(results) >= 1
assert any(sound.id == test_sound.id for sound in results)
# Search for "normalized" should find normalized_sound
results = await sound_repository.search_by_name("normalized")
assert len(results) >= 1
assert any(sound.id == normalized_sound.id for sound in results)
# Case insensitive search
results = await sound_repository.search_by_name("TEST")
assert len(results) >= 1
assert any(sound.id == test_sound.id for sound in results)
# Partial match
results = await sound_repository.search_by_name("norm")
assert len(results) >= 1
assert any(sound.id == normalized_sound.id for sound in results)
# No matches
results = await sound_repository.search_by_name("nonexistent")
assert len(results) == 0
@pytest.mark.asyncio
async def test_get_popular_sounds(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting popular sounds."""
# Update play counts to test ordering
await sound_repository.update(test_sound, {"play_count": 15})
await sound_repository.update(normalized_sound, {"play_count": 5})
# Create another sound with higher play count
high_play_sound_data = {
"name": "Popular Sound",
"filename": "popular.mp3",
"type": "SDB",
"duration": 2000,
"size": 300000,
"hash": "popular_hash",
"play_count": 25,
"is_normalized": False,
}
high_play_sound = await sound_repository.create(high_play_sound_data)
# Get popular sounds
popular_sounds = await sound_repository.get_popular_sounds(limit=10)
assert len(popular_sounds) >= 3
# Should be ordered by play_count desc
assert popular_sounds[0].play_count >= popular_sounds[1].play_count
# The highest play count sound should be first
assert popular_sounds[0].id == high_play_sound.id
@pytest.mark.asyncio
async def test_get_unnormalized_sounds(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting unnormalized sounds."""
unnormalized_sounds = await sound_repository.get_unnormalized_sounds()
# Should include test_sound (not normalized)
assert any(sound.id == test_sound.id for sound in unnormalized_sounds)
# Should not include normalized_sound (already normalized)
assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds)
@pytest.mark.asyncio
async def test_get_unnormalized_sounds_by_type(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting unnormalized sounds by type."""
# Get unnormalized SDB sounds
sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB")
# Should include test_sound (SDB, not normalized)
assert any(sound.id == test_sound.id for sound in sdb_unnormalized)
# Get unnormalized TTS sounds
tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS")
# Should not include normalized_sound (TTS, but already normalized)
assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized)
# Get unnormalized EXT sounds
ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT")
# Should be empty
assert len(ext_unnormalized) == 0
@pytest.mark.asyncio
async def test_create_duplicate_hash(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test creating sound with duplicate hash is allowed."""
# Store the hash to avoid lazy loading issues
original_hash = test_sound.hash
duplicate_sound_data = {
"name": "Duplicate Hash Sound",
"filename": "duplicate.mp3",
"type": "SDB",
"duration": 1000,
"size": 100000,
"hash": original_hash, # Same hash as test_sound
"play_count": 0,
"is_normalized": False,
}
# Should succeed - duplicate hashes are allowed
duplicate_sound = await sound_repository.create(duplicate_sound_data)
assert duplicate_sound.id is not None
assert duplicate_sound.name == "Duplicate Hash Sound"
assert duplicate_sound.hash == original_hash # Same hash is allowed

View File

@@ -0,0 +1,268 @@
"""Tests for user OAuth repository."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.user import User
from app.models.user_oauth import UserOauth
from app.repositories.user_oauth import UserOauthRepository
class TestUserOauthRepository:
"""Test user OAuth repository operations."""
@pytest_asyncio.fixture
async def user_oauth_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[UserOauthRepository, None]: # type: ignore[misc]
"""Create a user OAuth repository instance."""
yield UserOauthRepository(test_session)
@pytest_asyncio.fixture
async def test_user_id(
self,
test_user: User,
) -> int:
"""Get test user ID to avoid lazy loading issues."""
return test_user.id
@pytest_asyncio.fixture
async def test_oauth(
self,
test_session: AsyncSession,
test_user_id: int,
) -> AsyncGenerator[UserOauth, None]: # type: ignore[misc]
"""Create a test OAuth record."""
oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_123456",
"email": "test@gmail.com",
"name": "Test User Google",
"picture": None,
}
oauth = UserOauth(**oauth_data)
test_session.add(oauth)
await test_session.commit()
await test_session.refresh(oauth)
yield oauth
@pytest.mark.asyncio
async def test_get_by_provider_user_id_existing(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
) -> None:
"""Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456"
)
assert oauth is not None
assert oauth.id == test_oauth.id
assert oauth.provider == "google"
assert oauth.provider_user_id == "google_123456"
assert oauth.user_id == test_oauth.user_id
@pytest.mark.asyncio
async def test_get_by_provider_user_id_nonexistent(
self,
user_oauth_repository: UserOauthRepository,
) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id"
)
assert oauth is None
@pytest.mark.asyncio
async def test_get_by_user_id_and_provider_existing(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
test_user_id: int,
) -> None:
"""Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
)
assert oauth is not None
assert oauth.id == test_oauth.id
assert oauth.provider == "google"
assert oauth.user_id == test_user_id
@pytest.mark.asyncio
async def test_get_by_user_id_and_provider_nonexistent(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
)
assert oauth is None
@pytest.mark.asyncio
async def test_create_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test creating a new OAuth record."""
oauth_data = {
"user_id": test_user_id,
"provider": "github",
"provider_user_id": "github_789",
"email": "test@github.com",
"name": "Test User GitHub",
"picture": None,
}
oauth = await user_oauth_repository.create(oauth_data)
assert oauth.id is not None
assert oauth.user_id == test_user_id
assert oauth.provider == "github"
assert oauth.provider_user_id == "github_789"
assert oauth.email == "test@github.com"
assert oauth.name == "Test User GitHub"
@pytest.mark.asyncio
async def test_update_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
) -> None:
"""Test updating an OAuth record."""
update_data = {
"email": "updated@gmail.com",
"name": "Updated User Name",
"picture": "https://example.com/photo.jpg",
}
updated_oauth = await user_oauth_repository.update(test_oauth, update_data)
assert updated_oauth.id == test_oauth.id
assert updated_oauth.email == "updated@gmail.com"
assert updated_oauth.name == "Updated User Name"
assert updated_oauth.picture == "https://example.com/photo.jpg"
assert updated_oauth.provider == test_oauth.provider # Unchanged
assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged
@pytest.mark.asyncio
async def test_delete_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_session: AsyncSession,
test_user_id: int,
) -> None:
"""Test deleting an OAuth record."""
# Create an OAuth record to delete
oauth_data = {
"user_id": test_user_id,
"provider": "twitter",
"provider_user_id": "twitter_456",
"email": "test@twitter.com",
"name": "Test User Twitter",
"picture": None,
}
oauth = await user_oauth_repository.create(oauth_data)
oauth_id = oauth.id
# Delete the OAuth record
await user_oauth_repository.delete(oauth)
# Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456"
)
assert deleted_oauth is None
@pytest.mark.asyncio
async def test_create_duplicate_provider_user_id(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
test_user_id: int,
) -> None:
"""Test creating OAuth with duplicate provider user ID should fail."""
# Try to create another OAuth with the same provider and provider_user_id
duplicate_oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_123456", # Same as test_oauth
"email": "another@gmail.com",
"name": "Another User",
"picture": None,
}
# This should fail due to unique constraint
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
await user_oauth_repository.create(duplicate_oauth_data)
@pytest.mark.asyncio
async def test_multiple_providers_same_user(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test that a user can have multiple OAuth providers."""
# Create Google OAuth
google_oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_user_1",
"email": "user@gmail.com",
"name": "Test User Google",
"picture": None,
}
google_oauth = await user_oauth_repository.create(google_oauth_data)
# Create GitHub OAuth for the same user
github_oauth_data = {
"user_id": test_user_id,
"provider": "github",
"provider_user_id": "github_user_1",
"email": "user@github.com",
"name": "Test User GitHub",
"picture": None,
}
github_oauth = await user_oauth_repository.create(github_oauth_data)
# Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
)
found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
)
assert found_google is not None
assert found_github is not None
assert found_google.provider == "google"
assert found_github.provider == "github"
assert found_google.user_id == test_user_id
assert found_github.user_id == test_user_id
assert found_google.provider_user_id == "google_user_1"
assert found_github.provider_user_id == "github_user_1"
# Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1"
)
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1"
)
assert found_google_by_provider is not None
assert found_github_by_provider is not None
assert found_google_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id

View File

@@ -0,0 +1,358 @@
"""Tests for credit service."""
import json
from unittest.mock import AsyncMock, Mock, patch
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_action import CreditActionType
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.services.credit import CreditService, InsufficientCreditsError
class TestCreditService:
"""Test credit service functionality."""
@pytest.fixture
def mock_db_session_factory(self):
"""Create a mock database session factory."""
session = AsyncMock(spec=AsyncSession)
return lambda: session
@pytest.fixture
def credit_service(self, mock_db_session_factory):
"""Create a credit service instance for testing."""
return CreditService(mock_db_session_factory)
@pytest.fixture
def sample_user(self):
"""Create a sample user for testing."""
return User(
id=1,
name="Test User",
email="test@example.com",
role="user",
credits=10,
plan_id=1,
)
@pytest.mark.asyncio
async def test_check_credits_sufficient(self, credit_service, sample_user):
"""Test checking credits when user has sufficient credits."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is True
mock_repo.get_by_id.assert_called_once_with(1)
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_check_credits_insufficient(self, credit_service):
"""Test checking credits when user has insufficient credits."""
mock_session = credit_service.db_session_factory()
poor_user = User(
id=1,
name="Poor User",
email="poor@example.com",
role="user",
credits=0, # No credits
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is False
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_check_credits_user_not_found(self, credit_service):
"""Test checking credits when user is not found."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND)
assert result is False
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user):
"""Test successful credit validation and reservation."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
user, action = await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND
)
assert user == sample_user
assert action.action_type == CreditActionType.VLC_PLAY_SOUND
assert action.cost == 1
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_validate_and_reserve_credits_insufficient(self, credit_service):
"""Test credit validation with insufficient credits."""
mock_session = credit_service.db_session_factory()
poor_user = User(
id=1,
name="Poor User",
email="poor@example.com",
role="user",
credits=0,
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
with pytest.raises(InsufficientCreditsError) as exc_info:
await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND
)
assert exc_info.value.required == 1
assert exc_info.value.available == 0
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_validate_and_reserve_credits_user_not_found(self, credit_service):
"""Test credit validation when user is not found."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.validate_and_reserve_credits(
999, CreditActionType.VLC_PLAY_SOUND
)
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_deduct_credits_success(self, credit_service, sample_user):
"""Test successful credit deduction."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
)
# Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 9})
# Verify transaction was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", {
"user_id": "1",
"credits_before": 10,
"credits_after": 9,
"credits_deducted": 1,
"action_type": "vlc_play_sound",
"success": True,
}
)
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert isinstance(added_transaction, CreditTransaction)
assert added_transaction.user_id == 1
assert added_transaction.action_type == "vlc_play_sound"
assert added_transaction.amount == -1
assert added_transaction.balance_before == 10
assert added_transaction.balance_after == 9
assert added_transaction.success is True
assert json.loads(added_transaction.metadata_json) == {"test": "data"}
@pytest.mark.asyncio
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user):
"""Test credit deduction when action failed but requires success."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed
)
# Verify user credits were NOT updated (action requires success)
mock_repo.update.assert_not_called()
# Verify transaction was still created for auditing
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify no socket event was emitted since no credits were actually deducted
mock_socket_manager.send_to_user.assert_not_called()
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 0 # No deduction for failed action
assert added_transaction.balance_before == 10
assert added_transaction.balance_after == 10 # No change
assert added_transaction.success is False
@pytest.mark.asyncio
async def test_deduct_credits_insufficient(self, credit_service):
"""Test credit deduction with insufficient credits."""
mock_session = credit_service.db_session_factory()
poor_user = User(
id=1,
name="Poor User",
email="poor@example.com",
role="user",
credits=0,
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
mock_socket_manager.send_to_user = AsyncMock()
with pytest.raises(InsufficientCreditsError):
await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True
)
# Verify no socket event was emitted since credits could not be deducted
mock_socket_manager.send_to_user.assert_not_called()
mock_session.rollback.assert_called_once()
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_add_credits(self, credit_service, sample_user):
"""Test adding credits to user account."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.add_credits(
1, 5, "Bonus credits", {"reason": "signup"}
)
# Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 15})
# Verify transaction was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", {
"user_id": "1",
"credits_before": 10,
"credits_after": 15,
"credits_added": 5,
"description": "Bonus credits",
"success": True,
}
)
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 5
assert added_transaction.balance_before == 10
assert added_transaction.balance_after == 15
assert added_transaction.description == "Bonus credits"
@pytest.mark.asyncio
async def test_add_credits_invalid_amount(self, credit_service):
"""Test adding invalid amount of credits."""
with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, 0, "Invalid")
with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, -5, "Invalid")
@pytest.mark.asyncio
async def test_get_user_balance(self, credit_service, sample_user):
"""Test getting user credit balance."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
balance = await credit_service.get_user_balance(1)
assert balance == 10
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_balance_user_not_found(self, credit_service):
"""Test getting balance for non-existent user."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.get_user_balance(999)
mock_session.close.assert_called_once()
class TestInsufficientCreditsError:
"""Test InsufficientCreditsError exception."""
def test_insufficient_credits_error_creation(self):
"""Test creating InsufficientCreditsError."""
error = InsufficientCreditsError(5, 2)
assert error.required == 5
assert error.available == 2
assert str(error) == "Insufficient credits: 5 required, 2 available"

View File

@@ -0,0 +1,277 @@
"""Tests for credit decorators."""
from unittest.mock import AsyncMock, Mock
import pytest
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only
class TestRequiresCreditsDecorator:
"""Test requires_credits decorator."""
@pytest.fixture
def mock_credit_service(self):
"""Create a mock credit service."""
service = AsyncMock(spec=CreditService)
service.validate_and_reserve_credits = AsyncMock()
service.deduct_credits = AsyncMock()
return service
@pytest.fixture
def credit_service_factory(self, mock_credit_service):
"""Create a credit service factory."""
return lambda: mock_credit_service
@pytest.mark.asyncio
async def test_decorator_success(self, credit_service_factory, mock_credit_service):
"""Test decorator with successful action."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int, message: str) -> str:
return f"Success: {message}"
result = await test_action(user_id=123, message="test")
assert result == "Success: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, None
)
@pytest.mark.asyncio
async def test_decorator_with_metadata(self, credit_service_factory, mock_credit_service):
"""Test decorator with metadata extraction."""
def extract_metadata(user_id: int, sound_name: str) -> dict:
return {"sound_name": sound_name}
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id",
metadata_extractor=extract_metadata
)
async def test_action(user_id: int, sound_name: str) -> bool:
return True
await test_action(user_id=123, sound_name="test.mp3")
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"}
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"}
)
@pytest.mark.asyncio
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service):
"""Test decorator with failed action."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int) -> bool:
return False # Action fails
result = await test_action(user_id=123)
assert result is False
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service):
"""Test decorator when action raises exception."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int) -> str:
raise ValueError("Test error")
with pytest.raises(ValueError, match="Test error"):
await test_action(user_id=123)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service):
"""Test decorator with insufficient credits."""
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int) -> str:
return "Should not execute"
with pytest.raises(InsufficientCreditsError):
await test_action(user_id=123)
# Should not call deduct_credits since validation failed
mock_credit_service.deduct_credits.assert_not_called()
@pytest.mark.asyncio
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service):
"""Test decorator extracting user_id from positional args."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int, message: str) -> str:
return message
result = await test_action(123, "test")
assert result == "test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None
)
@pytest.mark.asyncio
async def test_decorator_missing_user_id(self, credit_service_factory):
"""Test decorator when user_id cannot be extracted."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(other_param: str) -> str:
return other_param
with pytest.raises(ValueError, match="Could not extract user_id"):
await test_action(other_param="test")
class TestValidateCreditsOnlyDecorator:
"""Test validate_credits_only decorator."""
@pytest.fixture
def mock_credit_service(self):
"""Create a mock credit service."""
service = AsyncMock(spec=CreditService)
service.validate_and_reserve_credits = AsyncMock()
return service
@pytest.fixture
def credit_service_factory(self, mock_credit_service):
"""Create a credit service factory."""
return lambda: mock_credit_service
@pytest.mark.asyncio
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service):
"""Test validate_credits_only decorator."""
@validate_credits_only(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int, message: str) -> str:
return f"Validated: {message}"
result = await test_action(user_id=123, message="test")
assert result == "Validated: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND
)
# Should not deduct credits, only validate
mock_credit_service.deduct_credits.assert_not_called()
class TestCreditManager:
"""Test CreditManager context manager."""
@pytest.fixture
def mock_credit_service(self):
"""Create a mock credit service."""
service = AsyncMock(spec=CreditService)
service.validate_and_reserve_credits = AsyncMock()
service.deduct_credits = AsyncMock()
return service
@pytest.mark.asyncio
async def test_credit_manager_success(self, mock_credit_service):
"""Test CreditManager with successful operation."""
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND,
{"test": "data"}
) as manager:
manager.mark_success()
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"}
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
)
@pytest.mark.asyncio
async def test_credit_manager_failure(self, mock_credit_service):
"""Test CreditManager with failed operation."""
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
):
# Don't mark as success - should be considered failed
pass
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_credit_manager_exception(self, mock_credit_service):
"""Test CreditManager when exception occurs."""
with pytest.raises(ValueError, match="Test error"):
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
):
raise ValueError("Test error")
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_credit_manager_validation_failure(self, mock_credit_service):
"""Test CreditManager when validation fails."""
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
with pytest.raises(InsufficientCreditsError):
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
):
pass
# Should not call deduct_credits since validation failed
mock_credit_service.deduct_credits.assert_not_called()