Add tests for sound repository, user OAuth repository, credit service, and credit decorators
- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities. - Create tests for UserOauthRepository to validate OAuth record management. - Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits. - Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
This commit is contained in:
@@ -7,9 +7,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.dependencies import get_current_active_user_flexible
|
from app.core.dependencies import get_current_active_user_flexible
|
||||||
|
from app.models.credit_action import CreditActionType
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.sound import SoundRepository
|
from app.repositories.sound import SoundRepository
|
||||||
from app.services.extraction import ExtractionInfo, ExtractionService
|
from app.services.extraction import ExtractionInfo, ExtractionService
|
||||||
|
from app.services.credit import CreditService, InsufficientCreditsError
|
||||||
from app.services.extraction_processor import extraction_processor
|
from app.services.extraction_processor import extraction_processor
|
||||||
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
||||||
from app.services.sound_scanner import ScanResults, SoundScannerService
|
from app.services.sound_scanner import ScanResults, SoundScannerService
|
||||||
@@ -45,6 +47,12 @@ def get_vlc_player() -> VLCPlayerService:
|
|||||||
return get_vlc_player_service(get_session_factory())
|
return get_vlc_player_service(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
|
def get_credit_service() -> CreditService:
|
||||||
|
"""Get the credit service."""
|
||||||
|
from app.core.database import get_session_factory
|
||||||
|
return CreditService(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
async def get_sound_repository(
|
async def get_sound_repository(
|
||||||
session: Annotated[AsyncSession, Depends(get_db)],
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
) -> SoundRepository:
|
) -> SoundRepository:
|
||||||
@@ -373,8 +381,9 @@ async def play_sound_with_vlc(
|
|||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
|
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
|
||||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||||
|
credit_service: Annotated[CreditService, Depends(get_credit_service)],
|
||||||
) -> dict[str, str | int | bool]:
|
) -> dict[str, str | int | bool]:
|
||||||
"""Play a sound using VLC subprocess."""
|
"""Play a sound using VLC subprocess (requires 1 credit)."""
|
||||||
try:
|
try:
|
||||||
# Get the sound
|
# Get the sound
|
||||||
sound = await sound_repo.get_by_id(sound_id)
|
sound = await sound_repo.get_by_id(sound_id)
|
||||||
@@ -384,9 +393,30 @@ async def play_sound_with_vlc(
|
|||||||
detail=f"Sound with ID {sound_id} not found",
|
detail=f"Sound with ID {sound_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check and validate credits before playing
|
||||||
|
try:
|
||||||
|
await credit_service.validate_and_reserve_credits(
|
||||||
|
current_user.id,
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
{"sound_id": sound_id, "sound_name": sound.name}
|
||||||
|
)
|
||||||
|
except InsufficientCreditsError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Insufficient credits: {e.required} required, {e.available} available",
|
||||||
|
) from e
|
||||||
|
|
||||||
# Play the sound using VLC
|
# Play the sound using VLC
|
||||||
success = await vlc_player.play_sound(sound)
|
success = await vlc_player.play_sound(sound)
|
||||||
|
|
||||||
|
# Deduct credits based on success
|
||||||
|
await credit_service.deduct_credits(
|
||||||
|
current_user.id,
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
success,
|
||||||
|
{"sound_id": sound_id, "sound_name": sound.name},
|
||||||
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
@@ -398,6 +428,7 @@ async def play_sound_with_vlc(
|
|||||||
"sound_id": sound_id,
|
"sound_id": sound_id,
|
||||||
"sound_name": sound.name,
|
"sound_name": sound.name,
|
||||||
"success": True,
|
"success": True,
|
||||||
|
"credits_deducted": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
121
app/models/credit_action.py
Normal file
121
app/models/credit_action.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Credit action definitions for the credit system."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class CreditActionType(str, Enum):
|
||||||
|
"""Types of actions that consume credits."""
|
||||||
|
|
||||||
|
VLC_PLAY_SOUND = "vlc_play_sound"
|
||||||
|
AUDIO_EXTRACTION = "audio_extraction"
|
||||||
|
TEXT_TO_SPEECH = "text_to_speech"
|
||||||
|
SOUND_NORMALIZATION = "sound_normalization"
|
||||||
|
API_REQUEST = "api_request"
|
||||||
|
PLAYLIST_CREATION = "playlist_creation"
|
||||||
|
|
||||||
|
|
||||||
|
class CreditAction:
|
||||||
|
"""Definition of a credit-consuming action."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
action_type: CreditActionType,
|
||||||
|
cost: int,
|
||||||
|
description: str,
|
||||||
|
*,
|
||||||
|
requires_success: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize a credit action.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_type: The type of action
|
||||||
|
cost: Number of credits required
|
||||||
|
description: Human-readable description
|
||||||
|
requires_success: Whether credits are only deducted on successful completion
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.action_type = action_type
|
||||||
|
self.cost = cost
|
||||||
|
self.description = description
|
||||||
|
self.requires_success = requires_success
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
"""Return string representation of the action."""
|
||||||
|
return f"{self.action_type.value} ({self.cost} credits)"
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for serialization."""
|
||||||
|
return {
|
||||||
|
"action_type": self.action_type.value,
|
||||||
|
"cost": self.cost,
|
||||||
|
"description": self.description,
|
||||||
|
"requires_success": self.requires_success,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Predefined credit actions
|
||||||
|
CREDIT_ACTIONS = {
|
||||||
|
CreditActionType.VLC_PLAY_SOUND: CreditAction(
|
||||||
|
action_type=CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
cost=1,
|
||||||
|
description="Play a sound using VLC player",
|
||||||
|
requires_success=True,
|
||||||
|
),
|
||||||
|
CreditActionType.AUDIO_EXTRACTION: CreditAction(
|
||||||
|
action_type=CreditActionType.AUDIO_EXTRACTION,
|
||||||
|
cost=5,
|
||||||
|
description="Extract audio from external URL",
|
||||||
|
requires_success=True,
|
||||||
|
),
|
||||||
|
CreditActionType.TEXT_TO_SPEECH: CreditAction(
|
||||||
|
action_type=CreditActionType.TEXT_TO_SPEECH,
|
||||||
|
cost=2,
|
||||||
|
description="Generate speech from text",
|
||||||
|
requires_success=True,
|
||||||
|
),
|
||||||
|
CreditActionType.SOUND_NORMALIZATION: CreditAction(
|
||||||
|
action_type=CreditActionType.SOUND_NORMALIZATION,
|
||||||
|
cost=1,
|
||||||
|
description="Normalize audio levels",
|
||||||
|
requires_success=True,
|
||||||
|
),
|
||||||
|
CreditActionType.API_REQUEST: CreditAction(
|
||||||
|
action_type=CreditActionType.API_REQUEST,
|
||||||
|
cost=1,
|
||||||
|
description="API request (rate limiting)",
|
||||||
|
requires_success=False, # Charged even if request fails
|
||||||
|
),
|
||||||
|
CreditActionType.PLAYLIST_CREATION: CreditAction(
|
||||||
|
action_type=CreditActionType.PLAYLIST_CREATION,
|
||||||
|
cost=3,
|
||||||
|
description="Create a new playlist",
|
||||||
|
requires_success=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_credit_action(action_type: CreditActionType) -> CreditAction:
|
||||||
|
"""Get a credit action definition by type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_type: The action type to look up
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The credit action definition
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If action type is not found
|
||||||
|
|
||||||
|
"""
|
||||||
|
return CREDIT_ACTIONS[action_type]
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
|
||||||
|
"""Get all available credit actions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of all credit actions
|
||||||
|
|
||||||
|
"""
|
||||||
|
return CREDIT_ACTIONS.copy()
|
||||||
29
app/models/credit_transaction.py
Normal file
29
app/models/credit_transaction.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""Credit transaction model for tracking credit usage."""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
class CreditTransaction(BaseModel, table=True):
|
||||||
|
"""Database model for credit transactions."""
|
||||||
|
|
||||||
|
__tablename__ = "credit_transaction" # pyright: ignore[reportAssignmentType]
|
||||||
|
|
||||||
|
user_id: int = Field(foreign_key="user.id", nullable=False)
|
||||||
|
action_type: str = Field(nullable=False)
|
||||||
|
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
|
||||||
|
balance_before: int = Field(nullable=False)
|
||||||
|
balance_after: int = Field(nullable=False)
|
||||||
|
description: str = Field(nullable=False)
|
||||||
|
success: bool = Field(nullable=False, default=True)
|
||||||
|
# JSON string for additional data
|
||||||
|
metadata_json: str | None = Field(default=None)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
user: "User" = Relationship(back_populates="credit_transactions")
|
||||||
@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship
|
|||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from app.models.credit_transaction import CreditTransaction
|
||||||
from app.models.extraction import Extraction
|
from app.models.extraction import Extraction
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
from app.models.playlist import Playlist
|
from app.models.playlist import Playlist
|
||||||
@@ -35,3 +36,4 @@ class User(BaseModel, table=True):
|
|||||||
playlists: list["Playlist"] = Relationship(back_populates="user")
|
playlists: list["Playlist"] = Relationship(back_populates="user")
|
||||||
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
|
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
|
||||||
extractions: list["Extraction"] = Relationship(back_populates="user")
|
extractions: list["Extraction"] = Relationship(back_populates="user")
|
||||||
|
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
|
||||||
|
|||||||
132
app/repositories/base.py
Normal file
132
app/repositories/base.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Base repository with common CRUD operations."""
|
||||||
|
|
||||||
|
from typing import Any, Generic, TypeVar
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
# Type variable for the model
|
||||||
|
ModelType = TypeVar("ModelType")
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRepository(Generic[ModelType]):
|
||||||
|
"""Base repository with common CRUD operations."""
|
||||||
|
|
||||||
|
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
|
||||||
|
"""Initialize the repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The SQLModel class
|
||||||
|
session: Database session
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def get_by_id(self, entity_id: int) -> ModelType | None:
|
||||||
|
"""Get an entity by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_id: The entity ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The entity if found, None otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return result.first()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_all(
|
||||||
|
self,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[ModelType]:
|
||||||
|
"""Get all entities with pagination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
limit: Maximum number of entities to return
|
||||||
|
offset: Number of entities to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of entities
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = select(self.model).limit(limit).offset(offset)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get all %s", self.model.__name__)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||||
|
"""Create a new entity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_data: Dictionary of entity data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created entity
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
entity = self.model(**entity_data)
|
||||||
|
self.session.add(entity)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(entity)
|
||||||
|
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||||
|
return entity
|
||||||
|
except Exception:
|
||||||
|
await self.session.rollback()
|
||||||
|
logger.exception("Failed to create %s", self.model.__name__)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
|
||||||
|
"""Update an entity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: The entity to update
|
||||||
|
update_data: Dictionary of fields to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated entity
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(entity, field, value)
|
||||||
|
|
||||||
|
self.session.add(entity)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(entity)
|
||||||
|
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||||
|
return entity
|
||||||
|
except Exception:
|
||||||
|
await self.session.rollback()
|
||||||
|
logger.exception("Failed to update %s", self.model.__name__)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def delete(self, entity: ModelType) -> None:
|
||||||
|
"""Delete an entity.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: The entity to delete
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self.session.delete(entity)
|
||||||
|
await self.session.commit()
|
||||||
|
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||||
|
except Exception:
|
||||||
|
await self.session.rollback()
|
||||||
|
logger.exception("Failed to delete %s", self.model.__name__)
|
||||||
|
raise
|
||||||
108
app/repositories/credit_transaction.py
Normal file
108
app/repositories/credit_transaction.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Repository for credit transaction database operations."""
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.credit_transaction import CreditTransaction
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
|
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
|
||||||
|
"""Repository for credit transaction operations."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize the repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__(CreditTransaction, session)
|
||||||
|
|
||||||
|
async def get_by_user_id(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[CreditTransaction]:
|
||||||
|
"""Get credit transactions for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
limit: Maximum number of transactions to return
|
||||||
|
offset: Number of transactions to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of credit transactions ordered by creation date (newest first)
|
||||||
|
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
select(CreditTransaction)
|
||||||
|
.where(CreditTransaction.user_id == user_id)
|
||||||
|
.order_by(CreditTransaction.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
)
|
||||||
|
result = await self.session.exec(stmt)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_by_action_type(
|
||||||
|
self,
|
||||||
|
action_type: str,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[CreditTransaction]:
|
||||||
|
"""Get credit transactions by action type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_type: The action type to filter by
|
||||||
|
limit: Maximum number of transactions to return
|
||||||
|
offset: Number of transactions to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of credit transactions ordered by creation date (newest first)
|
||||||
|
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
select(CreditTransaction)
|
||||||
|
.where(CreditTransaction.action_type == action_type)
|
||||||
|
.order_by(CreditTransaction.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
)
|
||||||
|
result = await self.session.exec(stmt)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_successful_transactions(
|
||||||
|
self,
|
||||||
|
user_id: int | None = None,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[CreditTransaction]:
|
||||||
|
"""Get successful credit transactions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Optional user ID to filter by
|
||||||
|
limit: Maximum number of transactions to return
|
||||||
|
offset: Number of transactions to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of successful credit transactions
|
||||||
|
|
||||||
|
"""
|
||||||
|
stmt = (
|
||||||
|
select(CreditTransaction)
|
||||||
|
.where(CreditTransaction.success == True) # noqa: E712
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_id is not None:
|
||||||
|
stmt = stmt.where(CreditTransaction.user_id == user_id)
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
stmt.order_by(CreditTransaction.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self.session.exec(stmt)
|
||||||
|
return list(result.all())
|
||||||
383
app/services/credit.py
Normal file
383
app/services/credit.py
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
"""Credit management service for tracking and validating user credit usage."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
|
||||||
|
from app.models.credit_transaction import CreditTransaction
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.user import UserRepository
|
||||||
|
from app.services.socket import socket_manager
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InsufficientCreditsError(Exception):
|
||||||
|
"""Raised when user has insufficient credits for an action."""
|
||||||
|
|
||||||
|
def __init__(self, required: int, available: int) -> None:
|
||||||
|
"""Initialize the error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required: Number of credits required
|
||||||
|
available: Number of credits available
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.required = required
|
||||||
|
self.available = available
|
||||||
|
super().__init__(
|
||||||
|
f"Insufficient credits: {required} required, {available} available"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreditService:
|
||||||
|
"""Service for managing user credits and transactions."""
|
||||||
|
|
||||||
|
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
||||||
|
"""Initialize the credit service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_factory: Factory function to create database sessions
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.db_session_factory = db_session_factory
|
||||||
|
|
||||||
|
async def check_credits(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
action_type: CreditActionType,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if user has sufficient credits for an action.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
action_type: The type of action to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if user has sufficient credits, False otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
action = get_credit_action(action_type)
|
||||||
|
session = self.db_session_factory()
|
||||||
|
try:
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
return user.credits >= action.cost
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
async def validate_and_reserve_credits(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
action_type: CreditActionType,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[User, CreditAction]:
|
||||||
|
"""Validate user has sufficient credits and optionally reserve them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
action_type: The type of action
|
||||||
|
metadata: Optional metadata to store with transaction
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (user, credit_action)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InsufficientCreditsError: If user has insufficient credits
|
||||||
|
|
||||||
|
"""
|
||||||
|
action = get_credit_action(action_type)
|
||||||
|
session = self.db_session_factory()
|
||||||
|
try:
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
if user.credits < action.cost:
|
||||||
|
raise InsufficientCreditsError(action.cost, user.credits)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Credits validated for user %s: %s credits available, %s required",
|
||||||
|
user_id,
|
||||||
|
user.credits,
|
||||||
|
action.cost,
|
||||||
|
)
|
||||||
|
return user, action
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
async def deduct_credits(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
action_type: CreditActionType,
|
||||||
|
success: bool = True,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> CreditTransaction:
|
||||||
|
"""Deduct credits from user account and record transaction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
action_type: The type of action
|
||||||
|
success: Whether the action was successful
|
||||||
|
metadata: Optional metadata to store with transaction
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created credit transaction
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InsufficientCreditsError: If user has insufficient credits
|
||||||
|
ValueError: If user not found
|
||||||
|
|
||||||
|
"""
|
||||||
|
action = get_credit_action(action_type)
|
||||||
|
|
||||||
|
# Only deduct if action requires success and was successful, or doesn't require success
|
||||||
|
should_deduct = (action.requires_success and success) or not action.requires_success
|
||||||
|
|
||||||
|
if not should_deduct:
|
||||||
|
logger.info(
|
||||||
|
"Skipping credit deduction for user %s: action %s failed and requires success",
|
||||||
|
user_id,
|
||||||
|
action_type.value,
|
||||||
|
)
|
||||||
|
# Still create a transaction record for auditing
|
||||||
|
return await self._create_transaction_record(
|
||||||
|
user_id, action, 0, success, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
session = self.db_session_factory()
|
||||||
|
try:
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
if user.credits < action.cost:
|
||||||
|
raise InsufficientCreditsError(action.cost, user.credits)
|
||||||
|
|
||||||
|
# Record transaction
|
||||||
|
balance_before = user.credits
|
||||||
|
balance_after = user.credits - action.cost
|
||||||
|
|
||||||
|
transaction = CreditTransaction(
|
||||||
|
user_id=user_id,
|
||||||
|
action_type=action_type.value,
|
||||||
|
amount=-action.cost,
|
||||||
|
balance_before=balance_before,
|
||||||
|
balance_after=balance_after,
|
||||||
|
description=action.description,
|
||||||
|
success=success,
|
||||||
|
metadata_json=json.dumps(metadata) if metadata else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update user credits
|
||||||
|
await user_repo.update(user, {"credits": balance_after})
|
||||||
|
|
||||||
|
# Save transaction
|
||||||
|
session.add(transaction)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Credits deducted for user %s: %s credits (action: %s, success: %s)",
|
||||||
|
user_id,
|
||||||
|
action.cost,
|
||||||
|
action_type.value,
|
||||||
|
success,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit user_credits_changed event via WebSocket
|
||||||
|
try:
|
||||||
|
event_data = {
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"credits_before": balance_before,
|
||||||
|
"credits_after": balance_after,
|
||||||
|
"credits_deducted": action.cost,
|
||||||
|
"action_type": action_type.value,
|
||||||
|
"success": success,
|
||||||
|
}
|
||||||
|
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||||
|
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return transaction
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
async def add_credits(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
amount: int,
|
||||||
|
description: str,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> CreditTransaction:
|
||||||
|
"""Add credits to user account.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
amount: Number of credits to add
|
||||||
|
description: Description of the credit addition
|
||||||
|
metadata: Optional metadata to store with transaction
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created credit transaction
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user not found or amount is negative
|
||||||
|
|
||||||
|
"""
|
||||||
|
if amount <= 0:
|
||||||
|
msg = "Amount must be positive"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
session = self.db_session_factory()
|
||||||
|
try:
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Record transaction
|
||||||
|
balance_before = user.credits
|
||||||
|
balance_after = user.credits + amount
|
||||||
|
|
||||||
|
transaction = CreditTransaction(
|
||||||
|
user_id=user_id,
|
||||||
|
action_type="credit_addition",
|
||||||
|
amount=amount,
|
||||||
|
balance_before=balance_before,
|
||||||
|
balance_after=balance_after,
|
||||||
|
description=description,
|
||||||
|
success=True,
|
||||||
|
metadata_json=json.dumps(metadata) if metadata else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update user credits
|
||||||
|
await user_repo.update(user, {"credits": balance_after})
|
||||||
|
|
||||||
|
# Save transaction
|
||||||
|
session.add(transaction)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Credits added for user %s: %s credits (description: %s)",
|
||||||
|
user_id,
|
||||||
|
amount,
|
||||||
|
description,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit user_credits_changed event via WebSocket
|
||||||
|
try:
|
||||||
|
event_data = {
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"credits_before": balance_before,
|
||||||
|
"credits_after": balance_after,
|
||||||
|
"credits_added": amount,
|
||||||
|
"description": description,
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||||
|
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return transaction
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
async def _create_transaction_record(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
action: CreditAction,
|
||||||
|
amount: int,
|
||||||
|
success: bool,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> CreditTransaction:
|
||||||
|
"""Create a transaction record without modifying credits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
action: The credit action
|
||||||
|
amount: Amount to record (typically 0 for failed actions)
|
||||||
|
success: Whether the action was successful
|
||||||
|
metadata: Optional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created transaction record
|
||||||
|
|
||||||
|
"""
|
||||||
|
session = self.db_session_factory()
|
||||||
|
try:
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
transaction = CreditTransaction(
|
||||||
|
user_id=user_id,
|
||||||
|
action_type=action.action_type.value,
|
||||||
|
amount=amount,
|
||||||
|
balance_before=user.credits,
|
||||||
|
balance_after=user.credits,
|
||||||
|
description=f"{action.description} (failed)" if not success else action.description,
|
||||||
|
success=success,
|
||||||
|
metadata_json=json.dumps(metadata) if metadata else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(transaction)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return transaction
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
async def get_user_balance(self, user_id: int) -> int:
|
||||||
|
"""Get current credit balance for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current credit balance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user not found
|
||||||
|
|
||||||
|
"""
|
||||||
|
session = self.db_session_factory()
|
||||||
|
try:
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return user.credits
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
192
app/utils/credit_decorators.py
Normal file
192
app/utils/credit_decorators.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""Decorators for credit management and validation."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from app.models.credit_action import CreditActionType
|
||||||
|
from app.services.credit import CreditService, InsufficientCreditsError
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
|
||||||
|
|
||||||
|
|
||||||
|
def requires_credits(
|
||||||
|
action_type: CreditActionType,
|
||||||
|
credit_service_factory: Callable[[], CreditService],
|
||||||
|
user_id_param: str = "user_id",
|
||||||
|
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
|
||||||
|
) -> Callable[[F], F]:
|
||||||
|
"""Decorator to enforce credit requirements for actions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_type: The type of action that requires credits
|
||||||
|
credit_service_factory: Factory to create credit service instance
|
||||||
|
user_id_param: Name of the parameter containing user ID
|
||||||
|
metadata_extractor: Optional function to extract metadata from function args
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorated function that validates and deducts credits
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
lambda: get_credit_service(),
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
|
||||||
|
# Implementation here
|
||||||
|
return True
|
||||||
|
|
||||||
|
"""
|
||||||
|
def decorator(func: F) -> F:
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
# Extract user ID from parameters
|
||||||
|
user_id = None
|
||||||
|
if user_id_param in kwargs:
|
||||||
|
user_id = kwargs[user_id_param]
|
||||||
|
else:
|
||||||
|
# Try to find user_id in function signature
|
||||||
|
import inspect
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
param_names = list(sig.parameters.keys())
|
||||||
|
if user_id_param in param_names:
|
||||||
|
param_index = param_names.index(user_id_param)
|
||||||
|
if param_index < len(args):
|
||||||
|
user_id = args[param_index]
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
msg = f"Could not extract user_id from parameter '{user_id_param}'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Extract metadata if extractor provided
|
||||||
|
metadata = None
|
||||||
|
if metadata_extractor:
|
||||||
|
metadata = metadata_extractor(*args, **kwargs)
|
||||||
|
|
||||||
|
# Get credit service
|
||||||
|
credit_service = credit_service_factory()
|
||||||
|
|
||||||
|
# Validate credits before execution
|
||||||
|
await credit_service.validate_and_reserve_credits(
|
||||||
|
user_id, action_type, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the function
|
||||||
|
success = False
|
||||||
|
result = None
|
||||||
|
try:
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
success = bool(result) # Consider function result as success indicator
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
success = False
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Deduct credits based on success
|
||||||
|
await credit_service.deduct_credits(
|
||||||
|
user_id, action_type, success, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapper # type: ignore[return-value]
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def validate_credits_only(
|
||||||
|
action_type: CreditActionType,
|
||||||
|
credit_service_factory: Callable[[], CreditService],
|
||||||
|
user_id_param: str = "user_id",
|
||||||
|
) -> Callable[[F], F]:
|
||||||
|
"""Decorator to only validate credits without deducting them.
|
||||||
|
|
||||||
|
Useful for checking if a user can perform an action before actual execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_type: The type of action that requires credits
|
||||||
|
credit_service_factory: Factory to create credit service instance
|
||||||
|
user_id_param: Name of the parameter containing user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorated function that validates credits only
|
||||||
|
|
||||||
|
"""
|
||||||
|
def decorator(func: F) -> F:
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
# Extract user ID from parameters
|
||||||
|
user_id = None
|
||||||
|
if user_id_param in kwargs:
|
||||||
|
user_id = kwargs[user_id_param]
|
||||||
|
else:
|
||||||
|
# Try to find user_id in function signature
|
||||||
|
import inspect
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
param_names = list(sig.parameters.keys())
|
||||||
|
if user_id_param in param_names:
|
||||||
|
param_index = param_names.index(user_id_param)
|
||||||
|
if param_index < len(args):
|
||||||
|
user_id = args[param_index]
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
msg = f"Could not extract user_id from parameter '{user_id_param}'"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Get credit service
|
||||||
|
credit_service = credit_service_factory()
|
||||||
|
|
||||||
|
# Validate credits only
|
||||||
|
await credit_service.validate_and_reserve_credits(user_id, action_type)
|
||||||
|
|
||||||
|
# Execute the function
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper # type: ignore[return-value]
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class CreditManager:
|
||||||
|
"""Context manager for credit operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
credit_service: CreditService,
|
||||||
|
user_id: int,
|
||||||
|
action_type: CreditActionType,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize credit manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
credit_service: Credit service instance
|
||||||
|
user_id: User ID
|
||||||
|
action_type: Action type
|
||||||
|
metadata: Optional metadata
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.credit_service = credit_service
|
||||||
|
self.user_id = user_id
|
||||||
|
self.action_type = action_type
|
||||||
|
self.metadata = metadata
|
||||||
|
self.validated = False
|
||||||
|
self.success = False
|
||||||
|
|
||||||
|
async def __aenter__(self) -> "CreditManager":
|
||||||
|
"""Enter context manager - validate credits."""
|
||||||
|
await self.credit_service.validate_and_reserve_credits(
|
||||||
|
self.user_id, self.action_type, self.metadata
|
||||||
|
)
|
||||||
|
self.validated = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
|
||||||
|
"""Exit context manager - deduct credits based on success."""
|
||||||
|
if self.validated:
|
||||||
|
# If no exception occurred, consider it successful
|
||||||
|
success = exc_type is None and self.success
|
||||||
|
await self.credit_service.deduct_credits(
|
||||||
|
self.user_id, self.action_type, success, self.metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
def mark_success(self) -> None:
|
||||||
|
"""Mark the operation as successful."""
|
||||||
|
self.success = True
|
||||||
@@ -13,8 +13,10 @@ from sqlmodel import SQLModel, select
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
|
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
|
||||||
from app.utils.auth import JWTUtils, PasswordUtils
|
from app.utils.auth import JWTUtils, PasswordUtils
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
412
tests/repositories/test_credit_transaction.py
Normal file
412
tests/repositories/test_credit_transaction.py
Normal file
@@ -0,0 +1,412 @@
|
|||||||
|
"""Tests for credit transaction repository."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.credit_transaction import CreditTransaction
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.credit_transaction import CreditTransactionRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreditTransactionRepository:
|
||||||
|
"""Test credit transaction repository operations."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def credit_transaction_repository(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
) -> AsyncGenerator[CreditTransactionRepository, None]: # type: ignore[misc]
|
||||||
|
"""Create a credit transaction repository instance."""
|
||||||
|
yield CreditTransactionRepository(test_session)
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_user_id(
|
||||||
|
self,
|
||||||
|
test_user: User,
|
||||||
|
) -> int:
|
||||||
|
"""Get test user ID to avoid lazy loading issues."""
|
||||||
|
return test_user.id
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_transactions(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> AsyncGenerator[list[CreditTransaction], None]: # type: ignore[misc]
|
||||||
|
"""Create test credit transactions."""
|
||||||
|
transactions = []
|
||||||
|
user_id = test_user_id
|
||||||
|
|
||||||
|
# Create various types of transactions
|
||||||
|
transaction_data = [
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"action_type": "vlc_play_sound",
|
||||||
|
"amount": -1,
|
||||||
|
"balance_before": 10,
|
||||||
|
"balance_after": 9,
|
||||||
|
"description": "Play sound via VLC",
|
||||||
|
"success": True,
|
||||||
|
"metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"action_type": "audio_extraction",
|
||||||
|
"amount": -5,
|
||||||
|
"balance_before": 9,
|
||||||
|
"balance_after": 4,
|
||||||
|
"description": "Extract audio from URL",
|
||||||
|
"success": True,
|
||||||
|
"metadata_json": json.dumps({"url": "https://example.com/video"}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"action_type": "vlc_play_sound",
|
||||||
|
"amount": 0,
|
||||||
|
"balance_before": 4,
|
||||||
|
"balance_after": 4,
|
||||||
|
"description": "Play sound via VLC (failed)",
|
||||||
|
"success": False,
|
||||||
|
"metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"action_type": "credit_addition",
|
||||||
|
"amount": 50,
|
||||||
|
"balance_before": 4,
|
||||||
|
"balance_after": 54,
|
||||||
|
"description": "Bonus credits",
|
||||||
|
"success": True,
|
||||||
|
"metadata_json": json.dumps({"reason": "signup_bonus"}),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for data in transaction_data:
|
||||||
|
transaction = CreditTransaction(**data)
|
||||||
|
test_session.add(transaction)
|
||||||
|
transactions.append(transaction)
|
||||||
|
|
||||||
|
await test_session.commit()
|
||||||
|
for transaction in transactions:
|
||||||
|
await test_session.refresh(transaction)
|
||||||
|
|
||||||
|
yield transactions
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def other_user_transaction(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
ensure_plans: tuple, # noqa: ARG002
|
||||||
|
) -> AsyncGenerator[CreditTransaction, None]: # type: ignore[misc]
|
||||||
|
"""Create a transaction for a different user."""
|
||||||
|
from app.models.plan import Plan
|
||||||
|
from app.repositories.user import UserRepository
|
||||||
|
|
||||||
|
# Create another user
|
||||||
|
user_repo = UserRepository(test_session)
|
||||||
|
other_user_data = {
|
||||||
|
"email": "other@example.com",
|
||||||
|
"name": "Other User",
|
||||||
|
"password_hash": "hashed_password",
|
||||||
|
"role": "user",
|
||||||
|
"is_active": True,
|
||||||
|
}
|
||||||
|
other_user = await user_repo.create(other_user_data)
|
||||||
|
|
||||||
|
# Create transaction for the other user
|
||||||
|
transaction_data = {
|
||||||
|
"user_id": other_user.id,
|
||||||
|
"action_type": "vlc_play_sound",
|
||||||
|
"amount": -1,
|
||||||
|
"balance_before": 100,
|
||||||
|
"balance_after": 99,
|
||||||
|
"description": "Other user play sound",
|
||||||
|
"success": True,
|
||||||
|
"metadata_json": None,
|
||||||
|
}
|
||||||
|
transaction = CreditTransaction(**transaction_data)
|
||||||
|
test_session.add(transaction)
|
||||||
|
await test_session.commit()
|
||||||
|
await test_session.refresh(transaction)
|
||||||
|
|
||||||
|
yield transaction
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_id_existing(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
) -> None:
|
||||||
|
"""Test getting transaction by ID when it exists."""
|
||||||
|
transaction = await credit_transaction_repository.get_by_id(test_transactions[0].id)
|
||||||
|
|
||||||
|
assert transaction is not None
|
||||||
|
assert transaction.id == test_transactions[0].id
|
||||||
|
assert transaction.action_type == "vlc_play_sound"
|
||||||
|
assert transaction.amount == -1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_id_nonexistent(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting transaction by ID when it doesn't exist."""
|
||||||
|
transaction = await credit_transaction_repository.get_by_id(99999)
|
||||||
|
|
||||||
|
assert transaction is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_user_id(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
other_user_transaction: CreditTransaction,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting transactions by user ID."""
|
||||||
|
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
|
||||||
|
|
||||||
|
# Should return all transactions for test_user
|
||||||
|
assert len(transactions) == 4
|
||||||
|
# Should be ordered by created_at desc (newest first)
|
||||||
|
assert all(t.user_id == test_user_id for t in transactions)
|
||||||
|
|
||||||
|
# Should not include other user's transaction
|
||||||
|
other_user_ids = [t.user_id for t in transactions]
|
||||||
|
assert other_user_transaction.user_id not in other_user_ids
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_user_id_with_pagination(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting transactions by user ID with pagination."""
|
||||||
|
# Get first 2 transactions
|
||||||
|
first_page = await credit_transaction_repository.get_by_user_id(
|
||||||
|
test_user_id, limit=2, offset=0
|
||||||
|
)
|
||||||
|
assert len(first_page) == 2
|
||||||
|
|
||||||
|
# Get next 2 transactions
|
||||||
|
second_page = await credit_transaction_repository.get_by_user_id(
|
||||||
|
test_user_id, limit=2, offset=2
|
||||||
|
)
|
||||||
|
assert len(second_page) == 2
|
||||||
|
|
||||||
|
# Should not overlap
|
||||||
|
first_page_ids = {t.id for t in first_page}
|
||||||
|
second_page_ids = {t.id for t in second_page}
|
||||||
|
assert first_page_ids.isdisjoint(second_page_ids)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_action_type(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
) -> None:
|
||||||
|
"""Test getting transactions by action type."""
|
||||||
|
vlc_transactions = await credit_transaction_repository.get_by_action_type(
|
||||||
|
"vlc_play_sound"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 2 VLC transactions (1 successful, 1 failed)
|
||||||
|
assert len(vlc_transactions) >= 2
|
||||||
|
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
|
||||||
|
|
||||||
|
extraction_transactions = await credit_transaction_repository.get_by_action_type(
|
||||||
|
"audio_extraction"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 1 extraction transaction
|
||||||
|
assert len(extraction_transactions) >= 1
|
||||||
|
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_action_type_with_pagination(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
) -> None:
|
||||||
|
"""Test getting transactions by action type with pagination."""
|
||||||
|
# Test with limit
|
||||||
|
transactions = await credit_transaction_repository.get_by_action_type(
|
||||||
|
"vlc_play_sound", limit=1
|
||||||
|
)
|
||||||
|
assert len(transactions) == 1
|
||||||
|
assert transactions[0].action_type == "vlc_play_sound"
|
||||||
|
|
||||||
|
# Test with offset
|
||||||
|
transactions = await credit_transaction_repository.get_by_action_type(
|
||||||
|
"vlc_play_sound", limit=1, offset=1
|
||||||
|
)
|
||||||
|
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_successful_transactions(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
) -> None:
|
||||||
|
"""Test getting only successful transactions."""
|
||||||
|
successful_transactions = await credit_transaction_repository.get_successful_transactions()
|
||||||
|
|
||||||
|
# Should only return successful transactions
|
||||||
|
assert all(t.success is True for t in successful_transactions)
|
||||||
|
# Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition)
|
||||||
|
assert len(successful_transactions) >= 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_successful_transactions_by_user(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
other_user_transaction: CreditTransaction,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting successful transactions filtered by user."""
|
||||||
|
successful_transactions = await credit_transaction_repository.get_successful_transactions(
|
||||||
|
user_id=test_user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should only return successful transactions for test_user
|
||||||
|
assert all(t.success is True for t in successful_transactions)
|
||||||
|
assert all(t.user_id == test_user_id for t in successful_transactions)
|
||||||
|
# Should be 3 successful transactions for test_user
|
||||||
|
assert len(successful_transactions) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_successful_transactions_with_pagination(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting successful transactions with pagination."""
|
||||||
|
# Get first 2 successful transactions
|
||||||
|
first_page = await credit_transaction_repository.get_successful_transactions(
|
||||||
|
user_id=test_user_id, limit=2, offset=0
|
||||||
|
)
|
||||||
|
assert len(first_page) == 2
|
||||||
|
assert all(t.success is True for t in first_page)
|
||||||
|
|
||||||
|
# Get next successful transaction
|
||||||
|
second_page = await credit_transaction_repository.get_successful_transactions(
|
||||||
|
user_id=test_user_id, limit=2, offset=2
|
||||||
|
)
|
||||||
|
assert len(second_page) == 1 # Should be 1 remaining
|
||||||
|
assert all(t.success is True for t in second_page)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_all_transactions(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
other_user_transaction: CreditTransaction,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting all transactions."""
|
||||||
|
all_transactions = await credit_transaction_repository.get_all()
|
||||||
|
|
||||||
|
# Should return all transactions
|
||||||
|
assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_transaction(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a new transaction."""
|
||||||
|
transaction_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"action_type": "test_action",
|
||||||
|
"amount": -10,
|
||||||
|
"balance_before": 100,
|
||||||
|
"balance_after": 90,
|
||||||
|
"description": "Test transaction",
|
||||||
|
"success": True,
|
||||||
|
"metadata_json": json.dumps({"test": "data"}),
|
||||||
|
}
|
||||||
|
|
||||||
|
transaction = await credit_transaction_repository.create(transaction_data)
|
||||||
|
|
||||||
|
assert transaction.id is not None
|
||||||
|
assert transaction.user_id == test_user_id
|
||||||
|
assert transaction.action_type == "test_action"
|
||||||
|
assert transaction.amount == -10
|
||||||
|
assert transaction.balance_before == 100
|
||||||
|
assert transaction.balance_after == 90
|
||||||
|
assert transaction.success is True
|
||||||
|
assert json.loads(transaction.metadata_json) == {"test": "data"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_transaction(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
) -> None:
|
||||||
|
"""Test updating a transaction."""
|
||||||
|
transaction = test_transactions[0]
|
||||||
|
update_data = {
|
||||||
|
"description": "Updated description",
|
||||||
|
"metadata_json": json.dumps({"updated": True}),
|
||||||
|
}
|
||||||
|
|
||||||
|
updated_transaction = await credit_transaction_repository.update(
|
||||||
|
transaction, update_data
|
||||||
|
)
|
||||||
|
|
||||||
|
assert updated_transaction.id == transaction.id
|
||||||
|
assert updated_transaction.description == "Updated description"
|
||||||
|
assert json.loads(updated_transaction.metadata_json) == {"updated": True}
|
||||||
|
# Other fields should remain unchanged
|
||||||
|
assert updated_transaction.amount == transaction.amount
|
||||||
|
assert updated_transaction.action_type == transaction.action_type
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_transaction(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a transaction."""
|
||||||
|
# Create a transaction to delete
|
||||||
|
transaction_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"action_type": "to_delete",
|
||||||
|
"amount": -1,
|
||||||
|
"balance_before": 10,
|
||||||
|
"balance_after": 9,
|
||||||
|
"description": "To be deleted",
|
||||||
|
"success": True,
|
||||||
|
"metadata_json": None,
|
||||||
|
}
|
||||||
|
transaction = await credit_transaction_repository.create(transaction_data)
|
||||||
|
transaction_id = transaction.id
|
||||||
|
|
||||||
|
# Delete the transaction
|
||||||
|
await credit_transaction_repository.delete(transaction)
|
||||||
|
|
||||||
|
# Verify transaction is deleted
|
||||||
|
deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id)
|
||||||
|
assert deleted_transaction is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transaction_ordering(
|
||||||
|
self,
|
||||||
|
credit_transaction_repository: CreditTransactionRepository,
|
||||||
|
test_transactions: list[CreditTransaction],
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test that transactions are ordered by created_at desc."""
|
||||||
|
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
|
||||||
|
|
||||||
|
# Should be ordered by created_at desc (newest first)
|
||||||
|
for i in range(len(transactions) - 1):
|
||||||
|
assert transactions[i].created_at >= transactions[i + 1].created_at
|
||||||
376
tests/repositories/test_sound.py
Normal file
376
tests/repositories/test_sound.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""Tests for sound repository."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.sound import Sound
|
||||||
|
from app.repositories.sound import SoundRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TestSoundRepository:
|
||||||
|
"""Test sound repository operations."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def sound_repository(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
) -> AsyncGenerator[SoundRepository, None]: # type: ignore[misc]
|
||||||
|
"""Create a sound repository instance."""
|
||||||
|
yield SoundRepository(test_session)
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_sound(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
|
||||||
|
"""Create a test sound."""
|
||||||
|
sound_data = {
|
||||||
|
"name": "Test Sound",
|
||||||
|
"filename": "test_sound.mp3",
|
||||||
|
"type": "SDB",
|
||||||
|
"duration": 5000,
|
||||||
|
"size": 1024000,
|
||||||
|
"hash": "test_hash_123",
|
||||||
|
"play_count": 0,
|
||||||
|
"is_normalized": False,
|
||||||
|
}
|
||||||
|
sound = Sound(**sound_data)
|
||||||
|
test_session.add(sound)
|
||||||
|
await test_session.commit()
|
||||||
|
await test_session.refresh(sound)
|
||||||
|
yield sound
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def normalized_sound(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
|
||||||
|
"""Create a normalized test sound."""
|
||||||
|
sound_data = {
|
||||||
|
"name": "Normalized Sound",
|
||||||
|
"filename": "normalized_sound.mp3",
|
||||||
|
"type": "TTS",
|
||||||
|
"duration": 3000,
|
||||||
|
"size": 512000,
|
||||||
|
"hash": "normalized_hash_456",
|
||||||
|
"play_count": 5,
|
||||||
|
"is_normalized": True,
|
||||||
|
"normalized_filename": "normalized_sound_norm.mp3",
|
||||||
|
"normalized_duration": 3000,
|
||||||
|
"normalized_size": 480000,
|
||||||
|
"normalized_hash": "normalized_hash_norm_456",
|
||||||
|
}
|
||||||
|
sound = Sound(**sound_data)
|
||||||
|
test_session.add(sound)
|
||||||
|
await test_session.commit()
|
||||||
|
await test_session.refresh(sound)
|
||||||
|
yield sound
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_id_existing(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting sound by ID when it exists."""
|
||||||
|
sound = await sound_repository.get_by_id(test_sound.id)
|
||||||
|
|
||||||
|
assert sound is not None
|
||||||
|
assert sound.id == test_sound.id
|
||||||
|
assert sound.name == test_sound.name
|
||||||
|
assert sound.filename == test_sound.filename
|
||||||
|
assert sound.type == test_sound.type
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_id_nonexistent(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting sound by ID when it doesn't exist."""
|
||||||
|
sound = await sound_repository.get_by_id(99999)
|
||||||
|
|
||||||
|
assert sound is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_filename_existing(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting sound by filename when it exists."""
|
||||||
|
sound = await sound_repository.get_by_filename(test_sound.filename)
|
||||||
|
|
||||||
|
assert sound is not None
|
||||||
|
assert sound.id == test_sound.id
|
||||||
|
assert sound.filename == test_sound.filename
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_filename_nonexistent(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting sound by filename when it doesn't exist."""
|
||||||
|
sound = await sound_repository.get_by_filename("nonexistent.mp3")
|
||||||
|
|
||||||
|
assert sound is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_hash_existing(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting sound by hash when it exists."""
|
||||||
|
sound = await sound_repository.get_by_hash(test_sound.hash)
|
||||||
|
|
||||||
|
assert sound is not None
|
||||||
|
assert sound.id == test_sound.id
|
||||||
|
assert sound.hash == test_sound.hash
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_hash_nonexistent(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting sound by hash when it doesn't exist."""
|
||||||
|
sound = await sound_repository.get_by_hash("nonexistent_hash")
|
||||||
|
|
||||||
|
assert sound is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_type(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
normalized_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting sounds by type."""
|
||||||
|
sdb_sounds = await sound_repository.get_by_type("SDB")
|
||||||
|
tts_sounds = await sound_repository.get_by_type("TTS")
|
||||||
|
ext_sounds = await sound_repository.get_by_type("EXT")
|
||||||
|
|
||||||
|
# Should find the SDB sound
|
||||||
|
assert len(sdb_sounds) >= 1
|
||||||
|
assert any(sound.id == test_sound.id for sound in sdb_sounds)
|
||||||
|
|
||||||
|
# Should find the TTS sound
|
||||||
|
assert len(tts_sounds) >= 1
|
||||||
|
assert any(sound.id == normalized_sound.id for sound in tts_sounds)
|
||||||
|
|
||||||
|
# Should not find any EXT sounds
|
||||||
|
assert len(ext_sounds) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sound(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a new sound."""
|
||||||
|
sound_data = {
|
||||||
|
"name": "New Sound",
|
||||||
|
"filename": "new_sound.wav",
|
||||||
|
"type": "EXT",
|
||||||
|
"duration": 7500,
|
||||||
|
"size": 2048000,
|
||||||
|
"hash": "new_hash_789",
|
||||||
|
"play_count": 0,
|
||||||
|
"is_normalized": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
sound = await sound_repository.create(sound_data)
|
||||||
|
|
||||||
|
assert sound.id is not None
|
||||||
|
assert sound.name == sound_data["name"]
|
||||||
|
assert sound.filename == sound_data["filename"]
|
||||||
|
assert sound.type == sound_data["type"]
|
||||||
|
assert sound.duration == sound_data["duration"]
|
||||||
|
assert sound.size == sound_data["size"]
|
||||||
|
assert sound.hash == sound_data["hash"]
|
||||||
|
assert sound.play_count == 0
|
||||||
|
assert sound.is_normalized is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_sound(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test updating a sound."""
|
||||||
|
update_data = {
|
||||||
|
"name": "Updated Sound Name",
|
||||||
|
"play_count": 10,
|
||||||
|
"is_normalized": True,
|
||||||
|
"normalized_filename": "updated_norm.mp3",
|
||||||
|
}
|
||||||
|
|
||||||
|
updated_sound = await sound_repository.update(test_sound, update_data)
|
||||||
|
|
||||||
|
assert updated_sound.id == test_sound.id
|
||||||
|
assert updated_sound.name == "Updated Sound Name"
|
||||||
|
assert updated_sound.play_count == 10
|
||||||
|
assert updated_sound.is_normalized is True
|
||||||
|
assert updated_sound.normalized_filename == "updated_norm.mp3"
|
||||||
|
assert updated_sound.filename == test_sound.filename # Unchanged
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_sound(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting a sound."""
|
||||||
|
# Create a sound to delete
|
||||||
|
sound_data = {
|
||||||
|
"name": "To Delete",
|
||||||
|
"filename": "to_delete.mp3",
|
||||||
|
"type": "SDB",
|
||||||
|
"duration": 1000,
|
||||||
|
"size": 256000,
|
||||||
|
"hash": "delete_hash",
|
||||||
|
"play_count": 0,
|
||||||
|
"is_normalized": False,
|
||||||
|
}
|
||||||
|
sound = await sound_repository.create(sound_data)
|
||||||
|
sound_id = sound.id
|
||||||
|
|
||||||
|
# Delete the sound
|
||||||
|
await sound_repository.delete(sound)
|
||||||
|
|
||||||
|
# Verify sound is deleted
|
||||||
|
deleted_sound = await sound_repository.get_by_id(sound_id)
|
||||||
|
assert deleted_sound is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_by_name(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
normalized_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test searching sounds by name."""
|
||||||
|
# Search for "test" should find test_sound
|
||||||
|
results = await sound_repository.search_by_name("test")
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert any(sound.id == test_sound.id for sound in results)
|
||||||
|
|
||||||
|
# Search for "normalized" should find normalized_sound
|
||||||
|
results = await sound_repository.search_by_name("normalized")
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert any(sound.id == normalized_sound.id for sound in results)
|
||||||
|
|
||||||
|
# Case insensitive search
|
||||||
|
results = await sound_repository.search_by_name("TEST")
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert any(sound.id == test_sound.id for sound in results)
|
||||||
|
|
||||||
|
# Partial match
|
||||||
|
results = await sound_repository.search_by_name("norm")
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert any(sound.id == normalized_sound.id for sound in results)
|
||||||
|
|
||||||
|
# No matches
|
||||||
|
results = await sound_repository.search_by_name("nonexistent")
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_popular_sounds(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
normalized_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting popular sounds."""
|
||||||
|
# Update play counts to test ordering
|
||||||
|
await sound_repository.update(test_sound, {"play_count": 15})
|
||||||
|
await sound_repository.update(normalized_sound, {"play_count": 5})
|
||||||
|
|
||||||
|
# Create another sound with higher play count
|
||||||
|
high_play_sound_data = {
|
||||||
|
"name": "Popular Sound",
|
||||||
|
"filename": "popular.mp3",
|
||||||
|
"type": "SDB",
|
||||||
|
"duration": 2000,
|
||||||
|
"size": 300000,
|
||||||
|
"hash": "popular_hash",
|
||||||
|
"play_count": 25,
|
||||||
|
"is_normalized": False,
|
||||||
|
}
|
||||||
|
high_play_sound = await sound_repository.create(high_play_sound_data)
|
||||||
|
|
||||||
|
# Get popular sounds
|
||||||
|
popular_sounds = await sound_repository.get_popular_sounds(limit=10)
|
||||||
|
|
||||||
|
assert len(popular_sounds) >= 3
|
||||||
|
# Should be ordered by play_count desc
|
||||||
|
assert popular_sounds[0].play_count >= popular_sounds[1].play_count
|
||||||
|
# The highest play count sound should be first
|
||||||
|
assert popular_sounds[0].id == high_play_sound.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_unnormalized_sounds(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
normalized_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting unnormalized sounds."""
|
||||||
|
unnormalized_sounds = await sound_repository.get_unnormalized_sounds()
|
||||||
|
|
||||||
|
# Should include test_sound (not normalized)
|
||||||
|
assert any(sound.id == test_sound.id for sound in unnormalized_sounds)
|
||||||
|
# Should not include normalized_sound (already normalized)
|
||||||
|
assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_unnormalized_sounds_by_type(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
normalized_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting unnormalized sounds by type."""
|
||||||
|
# Get unnormalized SDB sounds
|
||||||
|
sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB")
|
||||||
|
# Should include test_sound (SDB, not normalized)
|
||||||
|
assert any(sound.id == test_sound.id for sound in sdb_unnormalized)
|
||||||
|
|
||||||
|
# Get unnormalized TTS sounds
|
||||||
|
tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS")
|
||||||
|
# Should not include normalized_sound (TTS, but already normalized)
|
||||||
|
assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized)
|
||||||
|
|
||||||
|
# Get unnormalized EXT sounds
|
||||||
|
ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT")
|
||||||
|
# Should be empty
|
||||||
|
assert len(ext_unnormalized) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_duplicate_hash(
|
||||||
|
self,
|
||||||
|
sound_repository: SoundRepository,
|
||||||
|
test_sound: Sound,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating sound with duplicate hash is allowed."""
|
||||||
|
# Store the hash to avoid lazy loading issues
|
||||||
|
original_hash = test_sound.hash
|
||||||
|
|
||||||
|
duplicate_sound_data = {
|
||||||
|
"name": "Duplicate Hash Sound",
|
||||||
|
"filename": "duplicate.mp3",
|
||||||
|
"type": "SDB",
|
||||||
|
"duration": 1000,
|
||||||
|
"size": 100000,
|
||||||
|
"hash": original_hash, # Same hash as test_sound
|
||||||
|
"play_count": 0,
|
||||||
|
"is_normalized": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should succeed - duplicate hashes are allowed
|
||||||
|
duplicate_sound = await sound_repository.create(duplicate_sound_data)
|
||||||
|
|
||||||
|
assert duplicate_sound.id is not None
|
||||||
|
assert duplicate_sound.name == "Duplicate Hash Sound"
|
||||||
|
assert duplicate_sound.hash == original_hash # Same hash is allowed
|
||||||
268
tests/repositories/test_user_oauth.py
Normal file
268
tests/repositories/test_user_oauth.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""Tests for user OAuth repository."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.user_oauth import UserOauth
|
||||||
|
from app.repositories.user_oauth import UserOauthRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserOauthRepository:
|
||||||
|
"""Test user OAuth repository operations."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def user_oauth_repository(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
) -> AsyncGenerator[UserOauthRepository, None]: # type: ignore[misc]
|
||||||
|
"""Create a user OAuth repository instance."""
|
||||||
|
yield UserOauthRepository(test_session)
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_user_id(
|
||||||
|
self,
|
||||||
|
test_user: User,
|
||||||
|
) -> int:
|
||||||
|
"""Get test user ID to avoid lazy loading issues."""
|
||||||
|
return test_user.id
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_oauth(
|
||||||
|
self,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> AsyncGenerator[UserOauth, None]: # type: ignore[misc]
|
||||||
|
"""Create a test OAuth record."""
|
||||||
|
oauth_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"provider": "google",
|
||||||
|
"provider_user_id": "google_123456",
|
||||||
|
"email": "test@gmail.com",
|
||||||
|
"name": "Test User Google",
|
||||||
|
"picture": None,
|
||||||
|
}
|
||||||
|
oauth = UserOauth(**oauth_data)
|
||||||
|
test_session.add(oauth)
|
||||||
|
await test_session.commit()
|
||||||
|
await test_session.refresh(oauth)
|
||||||
|
yield oauth
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_provider_user_id_existing(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_oauth: UserOauth,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting OAuth by provider user ID when it exists."""
|
||||||
|
oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||||
|
"google", "google_123456"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert oauth is not None
|
||||||
|
assert oauth.id == test_oauth.id
|
||||||
|
assert oauth.provider == "google"
|
||||||
|
assert oauth.provider_user_id == "google_123456"
|
||||||
|
assert oauth.user_id == test_oauth.user_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_provider_user_id_nonexistent(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting OAuth by provider user ID when it doesn't exist."""
|
||||||
|
oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||||
|
"google", "nonexistent_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert oauth is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_user_id_and_provider_existing(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_oauth: UserOauth,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting OAuth by user ID and provider when it exists."""
|
||||||
|
oauth = await user_oauth_repository.get_by_user_id_and_provider(
|
||||||
|
test_user_id, "google"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert oauth is not None
|
||||||
|
assert oauth.id == test_oauth.id
|
||||||
|
assert oauth.provider == "google"
|
||||||
|
assert oauth.user_id == test_user_id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_user_id_and_provider_nonexistent(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting OAuth by user ID and provider when it doesn't exist."""
|
||||||
|
oauth = await user_oauth_repository.get_by_user_id_and_provider(
|
||||||
|
test_user_id, "github"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert oauth is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_oauth(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a new OAuth record."""
|
||||||
|
oauth_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"provider": "github",
|
||||||
|
"provider_user_id": "github_789",
|
||||||
|
"email": "test@github.com",
|
||||||
|
"name": "Test User GitHub",
|
||||||
|
"picture": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth = await user_oauth_repository.create(oauth_data)
|
||||||
|
|
||||||
|
assert oauth.id is not None
|
||||||
|
assert oauth.user_id == test_user_id
|
||||||
|
assert oauth.provider == "github"
|
||||||
|
assert oauth.provider_user_id == "github_789"
|
||||||
|
assert oauth.email == "test@github.com"
|
||||||
|
assert oauth.name == "Test User GitHub"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_oauth(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_oauth: UserOauth,
|
||||||
|
) -> None:
|
||||||
|
"""Test updating an OAuth record."""
|
||||||
|
update_data = {
|
||||||
|
"email": "updated@gmail.com",
|
||||||
|
"name": "Updated User Name",
|
||||||
|
"picture": "https://example.com/photo.jpg",
|
||||||
|
}
|
||||||
|
|
||||||
|
updated_oauth = await user_oauth_repository.update(test_oauth, update_data)
|
||||||
|
|
||||||
|
assert updated_oauth.id == test_oauth.id
|
||||||
|
assert updated_oauth.email == "updated@gmail.com"
|
||||||
|
assert updated_oauth.name == "Updated User Name"
|
||||||
|
assert updated_oauth.picture == "https://example.com/photo.jpg"
|
||||||
|
assert updated_oauth.provider == test_oauth.provider # Unchanged
|
||||||
|
assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_oauth(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_session: AsyncSession,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test deleting an OAuth record."""
|
||||||
|
# Create an OAuth record to delete
|
||||||
|
oauth_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"provider": "twitter",
|
||||||
|
"provider_user_id": "twitter_456",
|
||||||
|
"email": "test@twitter.com",
|
||||||
|
"name": "Test User Twitter",
|
||||||
|
"picture": None,
|
||||||
|
}
|
||||||
|
oauth = await user_oauth_repository.create(oauth_data)
|
||||||
|
oauth_id = oauth.id
|
||||||
|
|
||||||
|
# Delete the OAuth record
|
||||||
|
await user_oauth_repository.delete(oauth)
|
||||||
|
|
||||||
|
# Verify it's deleted by trying to find it
|
||||||
|
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||||
|
"twitter", "twitter_456"
|
||||||
|
)
|
||||||
|
assert deleted_oauth is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_duplicate_provider_user_id(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_oauth: UserOauth,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating OAuth with duplicate provider user ID should fail."""
|
||||||
|
# Try to create another OAuth with the same provider and provider_user_id
|
||||||
|
duplicate_oauth_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"provider": "google",
|
||||||
|
"provider_user_id": "google_123456", # Same as test_oauth
|
||||||
|
"email": "another@gmail.com",
|
||||||
|
"name": "Another User",
|
||||||
|
"picture": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# This should fail due to unique constraint
|
||||||
|
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
|
||||||
|
await user_oauth_repository.create(duplicate_oauth_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_providers_same_user(
|
||||||
|
self,
|
||||||
|
user_oauth_repository: UserOauthRepository,
|
||||||
|
test_user_id: int,
|
||||||
|
) -> None:
|
||||||
|
"""Test that a user can have multiple OAuth providers."""
|
||||||
|
# Create Google OAuth
|
||||||
|
google_oauth_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"provider": "google",
|
||||||
|
"provider_user_id": "google_user_1",
|
||||||
|
"email": "user@gmail.com",
|
||||||
|
"name": "Test User Google",
|
||||||
|
"picture": None,
|
||||||
|
}
|
||||||
|
google_oauth = await user_oauth_repository.create(google_oauth_data)
|
||||||
|
|
||||||
|
# Create GitHub OAuth for the same user
|
||||||
|
github_oauth_data = {
|
||||||
|
"user_id": test_user_id,
|
||||||
|
"provider": "github",
|
||||||
|
"provider_user_id": "github_user_1",
|
||||||
|
"email": "user@github.com",
|
||||||
|
"name": "Test User GitHub",
|
||||||
|
"picture": None,
|
||||||
|
}
|
||||||
|
github_oauth = await user_oauth_repository.create(github_oauth_data)
|
||||||
|
|
||||||
|
# Verify both exist by querying back from database
|
||||||
|
found_google = await user_oauth_repository.get_by_user_id_and_provider(
|
||||||
|
test_user_id, "google"
|
||||||
|
)
|
||||||
|
found_github = await user_oauth_repository.get_by_user_id_and_provider(
|
||||||
|
test_user_id, "github"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert found_google is not None
|
||||||
|
assert found_github is not None
|
||||||
|
assert found_google.provider == "google"
|
||||||
|
assert found_github.provider == "github"
|
||||||
|
assert found_google.user_id == test_user_id
|
||||||
|
assert found_github.user_id == test_user_id
|
||||||
|
assert found_google.provider_user_id == "google_user_1"
|
||||||
|
assert found_github.provider_user_id == "github_user_1"
|
||||||
|
|
||||||
|
# Verify we can also find them by provider_user_id
|
||||||
|
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
|
||||||
|
"google", "google_user_1"
|
||||||
|
)
|
||||||
|
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
|
||||||
|
"github", "github_user_1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert found_google_by_provider is not None
|
||||||
|
assert found_github_by_provider is not None
|
||||||
|
assert found_google_by_provider.user_id == test_user_id
|
||||||
|
assert found_github_by_provider.user_id == test_user_id
|
||||||
358
tests/services/test_credit.py
Normal file
358
tests/services/test_credit.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
"""Tests for credit service."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.credit_action import CreditActionType
|
||||||
|
from app.models.credit_transaction import CreditTransaction
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.credit import CreditService, InsufficientCreditsError
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreditService:
|
||||||
|
"""Test credit service functionality."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_db_session_factory(self):
|
||||||
|
"""Create a mock database session factory."""
|
||||||
|
session = AsyncMock(spec=AsyncSession)
|
||||||
|
return lambda: session
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def credit_service(self, mock_db_session_factory):
|
||||||
|
"""Create a credit service instance for testing."""
|
||||||
|
return CreditService(mock_db_session_factory)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_user(self):
|
||||||
|
"""Create a sample user for testing."""
|
||||||
|
return User(
|
||||||
|
id=1,
|
||||||
|
name="Test User",
|
||||||
|
email="test@example.com",
|
||||||
|
role="user",
|
||||||
|
credits=10,
|
||||||
|
plan_id=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_credits_sufficient(self, credit_service, sample_user):
|
||||||
|
"""Test checking credits when user has sufficient credits."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = sample_user
|
||||||
|
|
||||||
|
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_repo.get_by_id.assert_called_once_with(1)
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_credits_insufficient(self, credit_service):
|
||||||
|
"""Test checking credits when user has insufficient credits."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
poor_user = User(
|
||||||
|
id=1,
|
||||||
|
name="Poor User",
|
||||||
|
email="poor@example.com",
|
||||||
|
role="user",
|
||||||
|
credits=0, # No credits
|
||||||
|
plan_id=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = poor_user
|
||||||
|
|
||||||
|
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_credits_user_not_found(self, credit_service):
|
||||||
|
"""Test checking credits when user is not found."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = None
|
||||||
|
|
||||||
|
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user):
|
||||||
|
"""Test successful credit validation and reservation."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = sample_user
|
||||||
|
|
||||||
|
user, action = await credit_service.validate_and_reserve_credits(
|
||||||
|
1, CreditActionType.VLC_PLAY_SOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user == sample_user
|
||||||
|
assert action.action_type == CreditActionType.VLC_PLAY_SOUND
|
||||||
|
assert action.cost == 1
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_and_reserve_credits_insufficient(self, credit_service):
|
||||||
|
"""Test credit validation with insufficient credits."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
poor_user = User(
|
||||||
|
id=1,
|
||||||
|
name="Poor User",
|
||||||
|
email="poor@example.com",
|
||||||
|
role="user",
|
||||||
|
credits=0,
|
||||||
|
plan_id=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = poor_user
|
||||||
|
|
||||||
|
with pytest.raises(InsufficientCreditsError) as exc_info:
|
||||||
|
await credit_service.validate_and_reserve_credits(
|
||||||
|
1, CreditActionType.VLC_PLAY_SOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.required == 1
|
||||||
|
assert exc_info.value.available == 0
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_and_reserve_credits_user_not_found(self, credit_service):
|
||||||
|
"""Test credit validation when user is not found."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="User 999 not found"):
|
||||||
|
await credit_service.validate_and_reserve_credits(
|
||||||
|
999, CreditActionType.VLC_PLAY_SOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduct_credits_success(self, credit_service, sample_user):
|
||||||
|
"""Test successful credit deduction."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||||
|
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = sample_user
|
||||||
|
mock_socket_manager.send_to_user = AsyncMock()
|
||||||
|
|
||||||
|
transaction = await credit_service.deduct_credits(
|
||||||
|
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify user credits were updated
|
||||||
|
mock_repo.update.assert_called_once_with(sample_user, {"credits": 9})
|
||||||
|
|
||||||
|
# Verify transaction was created
|
||||||
|
mock_session.add.assert_called_once()
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
# Verify socket event was emitted
|
||||||
|
mock_socket_manager.send_to_user.assert_called_once_with(
|
||||||
|
"1", "user_credits_changed", {
|
||||||
|
"user_id": "1",
|
||||||
|
"credits_before": 10,
|
||||||
|
"credits_after": 9,
|
||||||
|
"credits_deducted": 1,
|
||||||
|
"action_type": "vlc_play_sound",
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check transaction details
|
||||||
|
added_transaction = mock_session.add.call_args[0][0]
|
||||||
|
assert isinstance(added_transaction, CreditTransaction)
|
||||||
|
assert added_transaction.user_id == 1
|
||||||
|
assert added_transaction.action_type == "vlc_play_sound"
|
||||||
|
assert added_transaction.amount == -1
|
||||||
|
assert added_transaction.balance_before == 10
|
||||||
|
assert added_transaction.balance_after == 9
|
||||||
|
assert added_transaction.success is True
|
||||||
|
assert json.loads(added_transaction.metadata_json) == {"test": "data"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user):
|
||||||
|
"""Test credit deduction when action failed but requires success."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||||
|
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = sample_user
|
||||||
|
mock_socket_manager.send_to_user = AsyncMock()
|
||||||
|
|
||||||
|
transaction = await credit_service.deduct_credits(
|
||||||
|
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify user credits were NOT updated (action requires success)
|
||||||
|
mock_repo.update.assert_not_called()
|
||||||
|
|
||||||
|
# Verify transaction was still created for auditing
|
||||||
|
mock_session.add.assert_called_once()
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
# Verify no socket event was emitted since no credits were actually deducted
|
||||||
|
mock_socket_manager.send_to_user.assert_not_called()
|
||||||
|
|
||||||
|
# Check transaction details
|
||||||
|
added_transaction = mock_session.add.call_args[0][0]
|
||||||
|
assert added_transaction.amount == 0 # No deduction for failed action
|
||||||
|
assert added_transaction.balance_before == 10
|
||||||
|
assert added_transaction.balance_after == 10 # No change
|
||||||
|
assert added_transaction.success is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deduct_credits_insufficient(self, credit_service):
|
||||||
|
"""Test credit deduction with insufficient credits."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
poor_user = User(
|
||||||
|
id=1,
|
||||||
|
name="Poor User",
|
||||||
|
email="poor@example.com",
|
||||||
|
role="user",
|
||||||
|
credits=0,
|
||||||
|
plan_id=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||||
|
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = poor_user
|
||||||
|
mock_socket_manager.send_to_user = AsyncMock()
|
||||||
|
|
||||||
|
with pytest.raises(InsufficientCreditsError):
|
||||||
|
await credit_service.deduct_credits(
|
||||||
|
1, CreditActionType.VLC_PLAY_SOUND, True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify no socket event was emitted since credits could not be deducted
|
||||||
|
mock_socket_manager.send_to_user.assert_not_called()
|
||||||
|
|
||||||
|
mock_session.rollback.assert_called_once()
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_credits(self, credit_service, sample_user):
|
||||||
|
"""Test adding credits to user account."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||||
|
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = sample_user
|
||||||
|
mock_socket_manager.send_to_user = AsyncMock()
|
||||||
|
|
||||||
|
transaction = await credit_service.add_credits(
|
||||||
|
1, 5, "Bonus credits", {"reason": "signup"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify user credits were updated
|
||||||
|
mock_repo.update.assert_called_once_with(sample_user, {"credits": 15})
|
||||||
|
|
||||||
|
# Verify transaction was created
|
||||||
|
mock_session.add.assert_called_once()
|
||||||
|
mock_session.commit.assert_called_once()
|
||||||
|
|
||||||
|
# Verify socket event was emitted
|
||||||
|
mock_socket_manager.send_to_user.assert_called_once_with(
|
||||||
|
"1", "user_credits_changed", {
|
||||||
|
"user_id": "1",
|
||||||
|
"credits_before": 10,
|
||||||
|
"credits_after": 15,
|
||||||
|
"credits_added": 5,
|
||||||
|
"description": "Bonus credits",
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check transaction details
|
||||||
|
added_transaction = mock_session.add.call_args[0][0]
|
||||||
|
assert added_transaction.amount == 5
|
||||||
|
assert added_transaction.balance_before == 10
|
||||||
|
assert added_transaction.balance_after == 15
|
||||||
|
assert added_transaction.description == "Bonus credits"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_credits_invalid_amount(self, credit_service):
|
||||||
|
"""Test adding invalid amount of credits."""
|
||||||
|
with pytest.raises(ValueError, match="Amount must be positive"):
|
||||||
|
await credit_service.add_credits(1, 0, "Invalid")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Amount must be positive"):
|
||||||
|
await credit_service.add_credits(1, -5, "Invalid")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_balance(self, credit_service, sample_user):
|
||||||
|
"""Test getting user credit balance."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = sample_user
|
||||||
|
|
||||||
|
balance = await credit_service.get_user_balance(1)
|
||||||
|
|
||||||
|
assert balance == 10
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_balance_user_not_found(self, credit_service):
|
||||||
|
"""Test getting balance for non-existent user."""
|
||||||
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|
||||||
|
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
mock_repo.get_by_id.return_value = None
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="User 999 not found"):
|
||||||
|
await credit_service.get_user_balance(999)
|
||||||
|
|
||||||
|
mock_session.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestInsufficientCreditsError:
|
||||||
|
"""Test InsufficientCreditsError exception."""
|
||||||
|
|
||||||
|
def test_insufficient_credits_error_creation(self):
|
||||||
|
"""Test creating InsufficientCreditsError."""
|
||||||
|
error = InsufficientCreditsError(5, 2)
|
||||||
|
assert error.required == 5
|
||||||
|
assert error.available == 2
|
||||||
|
assert str(error) == "Insufficient credits: 5 required, 2 available"
|
||||||
277
tests/utils/test_credit_decorators.py
Normal file
277
tests/utils/test_credit_decorators.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""Tests for credit decorators."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.credit_action import CreditActionType
|
||||||
|
from app.services.credit import CreditService, InsufficientCreditsError
|
||||||
|
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequiresCreditsDecorator:
|
||||||
|
"""Test requires_credits decorator."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_credit_service(self):
|
||||||
|
"""Create a mock credit service."""
|
||||||
|
service = AsyncMock(spec=CreditService)
|
||||||
|
service.validate_and_reserve_credits = AsyncMock()
|
||||||
|
service.deduct_credits = AsyncMock()
|
||||||
|
return service
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def credit_service_factory(self, mock_credit_service):
|
||||||
|
"""Create a credit service factory."""
|
||||||
|
return lambda: mock_credit_service
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decorator_success(self, credit_service_factory, mock_credit_service):
|
||||||
|
"""Test decorator with successful action."""
|
||||||
|
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def test_action(user_id: int, message: str) -> str:
|
||||||
|
return f"Success: {message}"
|
||||||
|
|
||||||
|
result = await test_action(user_id=123, message="test")
|
||||||
|
|
||||||
|
assert result == "Success: test"
|
||||||
|
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, None
|
||||||
|
)
|
||||||
|
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, True, None
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decorator_with_metadata(self, credit_service_factory, mock_credit_service):
|
||||||
|
"""Test decorator with metadata extraction."""
|
||||||
|
|
||||||
|
def extract_metadata(user_id: int, sound_name: str) -> dict:
|
||||||
|
return {"sound_name": sound_name}
|
||||||
|
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id",
|
||||||
|
metadata_extractor=extract_metadata
|
||||||
|
)
|
||||||
|
async def test_action(user_id: int, sound_name: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
await test_action(user_id=123, sound_name="test.mp3")
|
||||||
|
|
||||||
|
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"}
|
||||||
|
)
|
||||||
|
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service):
|
||||||
|
"""Test decorator with failed action."""
|
||||||
|
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def test_action(user_id: int) -> bool:
|
||||||
|
return False # Action fails
|
||||||
|
|
||||||
|
result = await test_action(user_id=123)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service):
|
||||||
|
"""Test decorator when action raises exception."""
|
||||||
|
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def test_action(user_id: int) -> str:
|
||||||
|
raise ValueError("Test error")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Test error"):
|
||||||
|
await test_action(user_id=123)
|
||||||
|
|
||||||
|
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service):
|
||||||
|
"""Test decorator with insufficient credits."""
|
||||||
|
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||||
|
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def test_action(user_id: int) -> str:
|
||||||
|
return "Should not execute"
|
||||||
|
|
||||||
|
with pytest.raises(InsufficientCreditsError):
|
||||||
|
await test_action(user_id=123)
|
||||||
|
|
||||||
|
# Should not call deduct_credits since validation failed
|
||||||
|
mock_credit_service.deduct_credits.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service):
|
||||||
|
"""Test decorator extracting user_id from positional args."""
|
||||||
|
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def test_action(user_id: int, message: str) -> str:
|
||||||
|
return message
|
||||||
|
|
||||||
|
result = await test_action(123, "test")
|
||||||
|
|
||||||
|
assert result == "test"
|
||||||
|
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, None
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_decorator_missing_user_id(self, credit_service_factory):
|
||||||
|
"""Test decorator when user_id cannot be extracted."""
|
||||||
|
|
||||||
|
@requires_credits(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def test_action(other_param: str) -> str:
|
||||||
|
return other_param
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Could not extract user_id"):
|
||||||
|
await test_action(other_param="test")
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCreditsOnlyDecorator:
|
||||||
|
"""Test validate_credits_only decorator."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_credit_service(self):
|
||||||
|
"""Create a mock credit service."""
|
||||||
|
service = AsyncMock(spec=CreditService)
|
||||||
|
service.validate_and_reserve_credits = AsyncMock()
|
||||||
|
return service
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def credit_service_factory(self, mock_credit_service):
|
||||||
|
"""Create a credit service factory."""
|
||||||
|
return lambda: mock_credit_service
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service):
|
||||||
|
"""Test validate_credits_only decorator."""
|
||||||
|
|
||||||
|
@validate_credits_only(
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
credit_service_factory,
|
||||||
|
user_id_param="user_id"
|
||||||
|
)
|
||||||
|
async def test_action(user_id: int, message: str) -> str:
|
||||||
|
return f"Validated: {message}"
|
||||||
|
|
||||||
|
result = await test_action(user_id=123, message="test")
|
||||||
|
|
||||||
|
assert result == "Validated: test"
|
||||||
|
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND
|
||||||
|
)
|
||||||
|
# Should not deduct credits, only validate
|
||||||
|
mock_credit_service.deduct_credits.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreditManager:
|
||||||
|
"""Test CreditManager context manager."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_credit_service(self):
|
||||||
|
"""Create a mock credit service."""
|
||||||
|
service = AsyncMock(spec=CreditService)
|
||||||
|
service.validate_and_reserve_credits = AsyncMock()
|
||||||
|
service.deduct_credits = AsyncMock()
|
||||||
|
return service
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credit_manager_success(self, mock_credit_service):
|
||||||
|
"""Test CreditManager with successful operation."""
|
||||||
|
async with CreditManager(
|
||||||
|
mock_credit_service,
|
||||||
|
123,
|
||||||
|
CreditActionType.VLC_PLAY_SOUND,
|
||||||
|
{"test": "data"}
|
||||||
|
) as manager:
|
||||||
|
manager.mark_success()
|
||||||
|
|
||||||
|
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"}
|
||||||
|
)
|
||||||
|
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credit_manager_failure(self, mock_credit_service):
|
||||||
|
"""Test CreditManager with failed operation."""
|
||||||
|
async with CreditManager(
|
||||||
|
mock_credit_service,
|
||||||
|
123,
|
||||||
|
CreditActionType.VLC_PLAY_SOUND
|
||||||
|
):
|
||||||
|
# Don't mark as success - should be considered failed
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credit_manager_exception(self, mock_credit_service):
|
||||||
|
"""Test CreditManager when exception occurs."""
|
||||||
|
with pytest.raises(ValueError, match="Test error"):
|
||||||
|
async with CreditManager(
|
||||||
|
mock_credit_service,
|
||||||
|
123,
|
||||||
|
CreditActionType.VLC_PLAY_SOUND
|
||||||
|
):
|
||||||
|
raise ValueError("Test error")
|
||||||
|
|
||||||
|
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||||
|
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credit_manager_validation_failure(self, mock_credit_service):
|
||||||
|
"""Test CreditManager when validation fails."""
|
||||||
|
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||||
|
|
||||||
|
with pytest.raises(InsufficientCreditsError):
|
||||||
|
async with CreditManager(
|
||||||
|
mock_credit_service,
|
||||||
|
123,
|
||||||
|
CreditActionType.VLC_PLAY_SOUND
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Should not call deduct_credits since validation failed
|
||||||
|
mock_credit_service.deduct_credits.assert_not_called()
|
||||||
Reference in New Issue
Block a user