Compare commits
2 Commits
1b0d291ad3
...
e43650c26c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e43650c26c | ||
|
|
dd10ef5d41 |
@@ -7,11 +7,15 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.extraction import ExtractionInfo, ExtractionService
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
||||
from app.services.sound_scanner import ScanResults, SoundScannerService
|
||||
from app.services.vlc_player import get_vlc_player_service, VLCPlayerService
|
||||
|
||||
router = APIRouter(prefix="/sounds", tags=["sounds"])
|
||||
|
||||
@@ -37,6 +41,25 @@ async def get_extraction_service(
|
||||
return ExtractionService(session)
|
||||
|
||||
|
||||
def get_vlc_player() -> VLCPlayerService:
|
||||
"""Get the VLC player service."""
|
||||
from app.core.database import get_session_factory
|
||||
return get_vlc_player_service(get_session_factory())
|
||||
|
||||
|
||||
def get_credit_service() -> CreditService:
|
||||
"""Get the credit service."""
|
||||
from app.core.database import get_session_factory
|
||||
return CreditService(get_session_factory())
|
||||
|
||||
|
||||
async def get_sound_repository(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> SoundRepository:
|
||||
"""Get the sound repository."""
|
||||
return SoundRepository(session)
|
||||
|
||||
|
||||
# SCAN
|
||||
@router.post("/scan")
|
||||
async def scan_sounds(
|
||||
@@ -349,3 +372,87 @@ async def get_user_extractions(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extractions: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
# VLC PLAYER
|
||||
@router.post("/vlc/play/{sound_id}")
|
||||
async def play_sound_with_vlc(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
|
||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||
credit_service: Annotated[CreditService, Depends(get_credit_service)],
|
||||
) -> dict[str, str | int | bool]:
|
||||
"""Play a sound using VLC subprocess (requires 1 credit)."""
|
||||
try:
|
||||
# Get the sound
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Sound with ID {sound_id} not found",
|
||||
)
|
||||
|
||||
# Check and validate credits before playing
|
||||
try:
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
current_user.id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
{"sound_id": sound_id, "sound_name": sound.name}
|
||||
)
|
||||
except InsufficientCreditsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Insufficient credits: {e.required} required, {e.available} available",
|
||||
) from e
|
||||
|
||||
# Play the sound using VLC
|
||||
success = await vlc_player.play_sound(sound)
|
||||
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
current_user.id,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
success,
|
||||
{"sound_id": sound_id, "sound_name": sound.name},
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to launch VLC for sound playback",
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Sound '{sound.name}' is now playing via VLC",
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound.name,
|
||||
"success": True,
|
||||
"credits_deducted": 1,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to play sound: {e!s}",
|
||||
) from e
|
||||
|
||||
|
||||
|
||||
@router.post("/vlc/stop-all")
|
||||
async def stop_all_vlc_instances(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
|
||||
) -> dict:
|
||||
"""Stop all running VLC instances."""
|
||||
try:
|
||||
result = await vlc_player.stop_all_vlc_instances()
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to stop VLC instances: {e!s}",
|
||||
) from e
|
||||
|
||||
121
app/models/credit_action.py
Normal file
121
app/models/credit_action.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Credit action definitions for the credit system."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class CreditActionType(str, Enum):
|
||||
"""Types of actions that consume credits."""
|
||||
|
||||
VLC_PLAY_SOUND = "vlc_play_sound"
|
||||
AUDIO_EXTRACTION = "audio_extraction"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
SOUND_NORMALIZATION = "sound_normalization"
|
||||
API_REQUEST = "api_request"
|
||||
PLAYLIST_CREATION = "playlist_creation"
|
||||
|
||||
|
||||
class CreditAction:
|
||||
"""Definition of a credit-consuming action."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_type: CreditActionType,
|
||||
cost: int,
|
||||
description: str,
|
||||
*,
|
||||
requires_success: bool = True,
|
||||
) -> None:
|
||||
"""Initialize a credit action.
|
||||
|
||||
Args:
|
||||
action_type: The type of action
|
||||
cost: Number of credits required
|
||||
description: Human-readable description
|
||||
requires_success: Whether credits are only deducted on successful completion
|
||||
|
||||
"""
|
||||
self.action_type = action_type
|
||||
self.cost = cost
|
||||
self.description = description
|
||||
self.requires_success = requires_success
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string representation of the action."""
|
||||
return f"{self.action_type.value} ({self.cost} credits)"
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"action_type": self.action_type.value,
|
||||
"cost": self.cost,
|
||||
"description": self.description,
|
||||
"requires_success": self.requires_success,
|
||||
}
|
||||
|
||||
|
||||
# Predefined credit actions
|
||||
CREDIT_ACTIONS = {
|
||||
CreditActionType.VLC_PLAY_SOUND: CreditAction(
|
||||
action_type=CreditActionType.VLC_PLAY_SOUND,
|
||||
cost=1,
|
||||
description="Play a sound using VLC player",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.AUDIO_EXTRACTION: CreditAction(
|
||||
action_type=CreditActionType.AUDIO_EXTRACTION,
|
||||
cost=5,
|
||||
description="Extract audio from external URL",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.TEXT_TO_SPEECH: CreditAction(
|
||||
action_type=CreditActionType.TEXT_TO_SPEECH,
|
||||
cost=2,
|
||||
description="Generate speech from text",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.SOUND_NORMALIZATION: CreditAction(
|
||||
action_type=CreditActionType.SOUND_NORMALIZATION,
|
||||
cost=1,
|
||||
description="Normalize audio levels",
|
||||
requires_success=True,
|
||||
),
|
||||
CreditActionType.API_REQUEST: CreditAction(
|
||||
action_type=CreditActionType.API_REQUEST,
|
||||
cost=1,
|
||||
description="API request (rate limiting)",
|
||||
requires_success=False, # Charged even if request fails
|
||||
),
|
||||
CreditActionType.PLAYLIST_CREATION: CreditAction(
|
||||
action_type=CreditActionType.PLAYLIST_CREATION,
|
||||
cost=3,
|
||||
description="Create a new playlist",
|
||||
requires_success=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_credit_action(action_type: CreditActionType) -> CreditAction:
|
||||
"""Get a credit action definition by type.
|
||||
|
||||
Args:
|
||||
action_type: The action type to look up
|
||||
|
||||
Returns:
|
||||
The credit action definition
|
||||
|
||||
Raises:
|
||||
KeyError: If action type is not found
|
||||
|
||||
"""
|
||||
return CREDIT_ACTIONS[action_type]
|
||||
|
||||
|
||||
def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
|
||||
"""Get all available credit actions.
|
||||
|
||||
Returns:
|
||||
Dictionary of all credit actions
|
||||
|
||||
"""
|
||||
return CREDIT_ACTIONS.copy()
|
||||
29
app/models/credit_transaction.py
Normal file
29
app/models/credit_transaction.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Credit transaction model for tracking credit usage."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class CreditTransaction(BaseModel, table=True):
|
||||
"""Database model for credit transactions."""
|
||||
|
||||
__tablename__ = "credit_transaction" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="user.id", nullable=False)
|
||||
action_type: str = Field(nullable=False)
|
||||
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
|
||||
balance_before: int = Field(nullable=False)
|
||||
balance_after: int = Field(nullable=False)
|
||||
description: str = Field(nullable=False)
|
||||
success: bool = Field(nullable=False, default=True)
|
||||
# JSON string for additional data
|
||||
metadata_json: str | None = Field(default=None)
|
||||
|
||||
# relationships
|
||||
user: "User" = Relationship(back_populates="credit_transactions")
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
@@ -14,18 +14,9 @@ class SoundPlayed(BaseModel, table=True):
|
||||
|
||||
__tablename__ = "sound_played" # pyright: ignore[reportAssignmentType]
|
||||
|
||||
user_id: int = Field(foreign_key="user.id", nullable=False)
|
||||
user_id: int | None = Field(foreign_key="user.id", nullable=True)
|
||||
sound_id: int = Field(foreign_key="sound.id", nullable=False)
|
||||
|
||||
# constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"user_id",
|
||||
"sound_id",
|
||||
name="uq_sound_played_user_sound",
|
||||
),
|
||||
)
|
||||
|
||||
# relationships
|
||||
user: "User" = Relationship(back_populates="sounds_played")
|
||||
sound: "Sound" = Relationship(back_populates="play_history")
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship
|
||||
from app.models.base import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.extraction import Extraction
|
||||
from app.models.plan import Plan
|
||||
from app.models.playlist import Playlist
|
||||
@@ -35,3 +36,4 @@ class User(BaseModel, table=True):
|
||||
playlists: list["Playlist"] = Relationship(back_populates="user")
|
||||
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
|
||||
extractions: list["Extraction"] = Relationship(back_populates="user")
|
||||
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
|
||||
|
||||
132
app/repositories/base.py
Normal file
132
app/repositories/base.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
# Type variable for the model
|
||||
ModelType = TypeVar("ModelType")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelType]):
|
||||
"""Base repository with common CRUD operations."""
|
||||
|
||||
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
|
||||
"""Initialize the repository.
|
||||
|
||||
Args:
|
||||
model: The SQLModel class
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
self.model = model
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, entity_id: int) -> ModelType | None:
|
||||
"""Get an entity by ID.
|
||||
|
||||
Args:
|
||||
entity_id: The entity ID
|
||||
|
||||
Returns:
|
||||
The entity if found, None otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
|
||||
raise
|
||||
|
||||
async def get_all(
|
||||
self,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ModelType]:
|
||||
"""Get all entities with pagination.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of entities to return
|
||||
offset: Number of entities to skip
|
||||
|
||||
Returns:
|
||||
List of entities
|
||||
|
||||
"""
|
||||
try:
|
||||
statement = select(self.model).limit(limit).offset(offset)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception("Failed to get all %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||
"""Create a new entity.
|
||||
|
||||
Args:
|
||||
entity_data: Dictionary of entity data
|
||||
|
||||
Returns:
|
||||
The created entity
|
||||
|
||||
"""
|
||||
try:
|
||||
entity = self.model(**entity_data)
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
|
||||
"""Update an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to update
|
||||
update_data: Dictionary of fields to update
|
||||
|
||||
Returns:
|
||||
The updated entity
|
||||
|
||||
"""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(entity, field, value)
|
||||
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
return entity
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def delete(self, entity: ModelType) -> None:
|
||||
"""Delete an entity.
|
||||
|
||||
Args:
|
||||
entity: The entity to delete
|
||||
|
||||
"""
|
||||
try:
|
||||
await self.session.delete(entity)
|
||||
await self.session.commit()
|
||||
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete %s", self.model.__name__)
|
||||
raise
|
||||
108
app/repositories/credit_transaction.py
Normal file
108
app/repositories/credit_transaction.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Repository for credit transaction database operations."""
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
|
||||
"""Repository for credit transaction operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the repository.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
||||
"""
|
||||
super().__init__(CreditTransaction, session)
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: int,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get credit transactions for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of credit transactions ordered by creation date (newest first)
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.user_id == user_id)
|
||||
.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def get_by_action_type(
|
||||
self,
|
||||
action_type: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get credit transactions by action type.
|
||||
|
||||
Args:
|
||||
action_type: The action type to filter by
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of credit transactions ordered by creation date (newest first)
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.action_type == action_type)
|
||||
.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
|
||||
async def get_successful_transactions(
|
||||
self,
|
||||
user_id: int | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[CreditTransaction]:
|
||||
"""Get successful credit transactions.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter by
|
||||
limit: Maximum number of transactions to return
|
||||
offset: Number of transactions to skip
|
||||
|
||||
Returns:
|
||||
List of successful credit transactions
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.success == True) # noqa: E712
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(CreditTransaction.user_id == user_id)
|
||||
|
||||
stmt = (
|
||||
stmt.order_by(CreditTransaction.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await self.session.exec(stmt)
|
||||
return list(result.all())
|
||||
383
app/services/credit.py
Normal file
383
app/services/credit.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""Credit management service for tracking and validating user credit usage."""
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class InsufficientCreditsError(Exception):
|
||||
"""Raised when user has insufficient credits for an action."""
|
||||
|
||||
def __init__(self, required: int, available: int) -> None:
|
||||
"""Initialize the error.
|
||||
|
||||
Args:
|
||||
required: Number of credits required
|
||||
available: Number of credits available
|
||||
|
||||
"""
|
||||
self.required = required
|
||||
self.available = available
|
||||
super().__init__(
|
||||
f"Insufficient credits: {required} required, {available} available"
|
||||
)
|
||||
|
||||
|
||||
class CreditService:
|
||||
"""Service for managing user credits and transactions."""
|
||||
|
||||
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
||||
"""Initialize the credit service.
|
||||
|
||||
Args:
|
||||
db_session_factory: Factory function to create database sessions
|
||||
|
||||
"""
|
||||
self.db_session_factory = db_session_factory
|
||||
|
||||
async def check_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
) -> bool:
|
||||
"""Check if user has sufficient credits for an action.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action to check
|
||||
|
||||
Returns:
|
||||
True if user has sufficient credits, False otherwise
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
return user.credits >= action.cost
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def validate_and_reserve_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> tuple[User, CreditAction]:
|
||||
"""Validate user has sufficient credits and optionally reserve them.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
Tuple of (user, credit_action)
|
||||
|
||||
Raises:
|
||||
InsufficientCreditsError: If user has insufficient credits
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if user.credits < action.cost:
|
||||
raise InsufficientCreditsError(action.cost, user.credits)
|
||||
|
||||
logger.info(
|
||||
"Credits validated for user %s: %s credits available, %s required",
|
||||
user_id,
|
||||
user.credits,
|
||||
action.cost,
|
||||
)
|
||||
return user, action
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def deduct_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
success: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Deduct credits from user account and record transaction.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action_type: The type of action
|
||||
success: Whether the action was successful
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
The created credit transaction
|
||||
|
||||
Raises:
|
||||
InsufficientCreditsError: If user has insufficient credits
|
||||
ValueError: If user not found
|
||||
|
||||
"""
|
||||
action = get_credit_action(action_type)
|
||||
|
||||
# Only deduct if action requires success and was successful, or doesn't require success
|
||||
should_deduct = (action.requires_success and success) or not action.requires_success
|
||||
|
||||
if not should_deduct:
|
||||
logger.info(
|
||||
"Skipping credit deduction for user %s: action %s failed and requires success",
|
||||
user_id,
|
||||
action_type.value,
|
||||
)
|
||||
# Still create a transaction record for auditing
|
||||
return await self._create_transaction_record(
|
||||
user_id, action, 0, success, metadata
|
||||
)
|
||||
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
if user.credits < action.cost:
|
||||
raise InsufficientCreditsError(action.cost, user.credits)
|
||||
|
||||
# Record transaction
|
||||
balance_before = user.credits
|
||||
balance_after = user.credits - action.cost
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type=action_type.value,
|
||||
amount=-action.cost,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance_after,
|
||||
description=action.description,
|
||||
success=success,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
# Update user credits
|
||||
await user_repo.update(user, {"credits": balance_after})
|
||||
|
||||
# Save transaction
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Credits deducted for user %s: %s credits (action: %s, success: %s)",
|
||||
user_id,
|
||||
action.cost,
|
||||
action_type.value,
|
||||
success,
|
||||
)
|
||||
|
||||
# Emit user_credits_changed event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"user_id": str(user_id),
|
||||
"credits_before": balance_before,
|
||||
"credits_after": balance_after,
|
||||
"credits_deducted": action.cost,
|
||||
"action_type": action_type.value,
|
||||
"success": success,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def add_credits(
|
||||
self,
|
||||
user_id: int,
|
||||
amount: int,
|
||||
description: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Add credits to user account.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
amount: Number of credits to add
|
||||
description: Description of the credit addition
|
||||
metadata: Optional metadata to store with transaction
|
||||
|
||||
Returns:
|
||||
The created credit transaction
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found or amount is negative
|
||||
|
||||
"""
|
||||
if amount <= 0:
|
||||
msg = "Amount must be positive"
|
||||
raise ValueError(msg)
|
||||
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Record transaction
|
||||
balance_before = user.credits
|
||||
balance_after = user.credits + amount
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type="credit_addition",
|
||||
amount=amount,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance_after,
|
||||
description=description,
|
||||
success=True,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
# Update user credits
|
||||
await user_repo.update(user, {"credits": balance_after})
|
||||
|
||||
# Save transaction
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Credits added for user %s: %s credits (description: %s)",
|
||||
user_id,
|
||||
amount,
|
||||
description,
|
||||
)
|
||||
|
||||
# Emit user_credits_changed event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"user_id": str(user_id),
|
||||
"credits_before": balance_before,
|
||||
"credits_after": balance_after,
|
||||
"credits_added": amount,
|
||||
"description": description,
|
||||
"success": True,
|
||||
}
|
||||
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
|
||||
logger.info("Emitted user_credits_changed event for user %s", user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to emit user_credits_changed event for user %s", user_id,
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def _create_transaction_record(
|
||||
self,
|
||||
user_id: int,
|
||||
action: CreditAction,
|
||||
amount: int,
|
||||
success: bool,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> CreditTransaction:
|
||||
"""Create a transaction record without modifying credits.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
action: The credit action
|
||||
amount: Amount to record (typically 0 for failed actions)
|
||||
success: Whether the action was successful
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
The created transaction record
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
transaction = CreditTransaction(
|
||||
user_id=user_id,
|
||||
action_type=action.action_type.value,
|
||||
amount=amount,
|
||||
balance_before=user.credits,
|
||||
balance_after=user.credits,
|
||||
description=f"{action.description} (failed)" if not success else action.description,
|
||||
success=success,
|
||||
metadata_json=json.dumps(metadata) if metadata else None,
|
||||
)
|
||||
|
||||
session.add(transaction)
|
||||
await session.commit()
|
||||
|
||||
return transaction
|
||||
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def get_user_balance(self, user_id: int) -> int:
|
||||
"""Get current credit balance for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Current credit balance
|
||||
|
||||
Raises:
|
||||
ValueError: If user not found
|
||||
|
||||
"""
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
user_repo = UserRepository(session)
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
msg = f"User {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
return user.credits
|
||||
finally:
|
||||
await session.close()
|
||||
@@ -9,15 +9,14 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import vlc # type: ignore[import-untyped]
|
||||
from sqlmodel import select
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.sound import Sound
|
||||
from app.models.sound_played import SoundPlayed
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.socket import socket_manager
|
||||
from app.utils.audio import get_sound_file_path
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -198,7 +197,7 @@ class PlayerService:
|
||||
return
|
||||
|
||||
# Get sound file path
|
||||
sound_path = self._get_sound_file_path(self.state.current_sound)
|
||||
sound_path = get_sound_file_path(self.state.current_sound)
|
||||
if not sound_path.exists():
|
||||
logger.error("Sound file not found: %s", sound_path)
|
||||
return
|
||||
@@ -344,6 +343,12 @@ class PlayerService:
|
||||
if self.state.status != PlayerStatus.STOPPED:
|
||||
await self._stop_playback()
|
||||
|
||||
# Set first track as current if no current track and playlist has sounds
|
||||
if not self.state.current_sound_id and sounds:
|
||||
self.state.current_sound_index = 0
|
||||
self.state.current_sound = sounds[0]
|
||||
self.state.current_sound_id = sounds[0].id
|
||||
|
||||
logger.info(
|
||||
"Loaded playlist: %s (%s sounds)",
|
||||
current_playlist.name,
|
||||
@@ -360,21 +365,6 @@ class PlayerService:
|
||||
"""Get current player state."""
|
||||
return self.state.to_dict()
|
||||
|
||||
def _get_sound_file_path(self, sound: Sound) -> Path:
|
||||
"""Get the file path for a sound."""
|
||||
# Determine the correct subdirectory based on sound type
|
||||
subdir = "extracted" if sound.type.upper() == "EXT" else sound.type.lower()
|
||||
|
||||
# Use normalized file if available, otherwise original
|
||||
if sound.is_normalized and sound.normalized_filename:
|
||||
return (
|
||||
Path("sounds/normalized")
|
||||
/ subdir
|
||||
/ sound.normalized_filename
|
||||
)
|
||||
return (
|
||||
Path("sounds/originals") / subdir / sound.filename
|
||||
)
|
||||
|
||||
def _get_next_index(self, current_index: int) -> int | None:
|
||||
"""Get next track index based on current mode."""
|
||||
@@ -501,7 +491,6 @@ class PlayerService:
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
sound_repo = SoundRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Update sound play count
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
@@ -519,37 +508,17 @@ class PlayerService:
|
||||
else:
|
||||
logger.warning("Sound %s not found for play count update", sound_id)
|
||||
|
||||
# Record play history for admin user (ID 1) as placeholder
|
||||
# This could be refined to track per-user play history
|
||||
admin_user = await user_repo.get_by_id(1)
|
||||
if admin_user:
|
||||
# Check if already recorded for this user using proper query
|
||||
stmt = select(SoundPlayed).where(
|
||||
SoundPlayed.user_id == admin_user.id,
|
||||
SoundPlayed.sound_id == sound_id,
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
existing = result.first()
|
||||
|
||||
if not existing:
|
||||
sound_played = SoundPlayed(
|
||||
user_id=admin_user.id,
|
||||
sound_id=sound_id,
|
||||
)
|
||||
session.add(sound_played)
|
||||
logger.info(
|
||||
"Created SoundPlayed record for user %s, sound %s",
|
||||
admin_user.id,
|
||||
sound_id,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"SoundPlayed record already exists for user %s, sound %s",
|
||||
admin_user.id,
|
||||
sound_id,
|
||||
)
|
||||
else:
|
||||
logger.warning("Admin user (ID 1) not found for play history")
|
||||
# Record play history without user_id for player-based plays
|
||||
# Always create a new SoundPlayed record for each play event
|
||||
sound_played = SoundPlayed(
|
||||
user_id=None, # No user_id for player-based plays
|
||||
sound_id=sound_id,
|
||||
)
|
||||
session.add(sound_played)
|
||||
logger.info(
|
||||
"Created SoundPlayed record for player play, sound %s",
|
||||
sound_id,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Successfully recorded play count for sound %s", sound_id)
|
||||
|
||||
313
app/services/vlc_player.py
Normal file
313
app/services/vlc_player.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""VLC subprocess-based player service for immediate sound playback."""
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.sound import Sound
|
||||
from app.models.sound_played import SoundPlayed
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.repositories.user import UserRepository
|
||||
from app.services.socket import socket_manager
|
||||
from app.utils.audio import get_sound_file_path
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class VLCPlayerService:
|
||||
"""Service for launching VLC instances via subprocess to play sounds."""
|
||||
|
||||
def __init__(
|
||||
self, db_session_factory: Callable[[], AsyncSession] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the VLC player service."""
|
||||
self.vlc_executable = self._find_vlc_executable()
|
||||
self.db_session_factory = db_session_factory
|
||||
logger.info(
|
||||
"VLC Player Service initialized with executable: %s",
|
||||
self.vlc_executable,
|
||||
)
|
||||
|
||||
def _find_vlc_executable(self) -> str:
|
||||
"""Find VLC executable path based on the operating system."""
|
||||
# Common VLC executable paths
|
||||
possible_paths = [
|
||||
"vlc", # Linux/Mac with VLC in PATH
|
||||
"/usr/bin/vlc", # Linux
|
||||
"/usr/local/bin/vlc", # Linux/Mac
|
||||
"/Applications/VLC.app/Contents/MacOS/VLC", # macOS
|
||||
"C:\\Program Files\\VideoLAN\\VLC\\vlc.exe", # Windows
|
||||
"C:\\Program Files (x86)\\VideoLAN\\VLC\\vlc.exe", # Windows 32-bit
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
try:
|
||||
if Path(path).exists():
|
||||
return path
|
||||
# For "vlc", try to find it in PATH
|
||||
if path == "vlc":
|
||||
result = subprocess.run(
|
||||
["which", "vlc"],
|
||||
capture_output=True,
|
||||
check=False,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return path
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
continue
|
||||
|
||||
# Default to 'vlc' and let the system handle it
|
||||
logger.warning(
|
||||
"VLC executable not found in common paths, using 'vlc' from PATH",
|
||||
)
|
||||
return "vlc"
|
||||
|
||||
async def play_sound(self, sound: Sound) -> bool:
|
||||
"""Play a sound using a new VLC subprocess instance.
|
||||
|
||||
Args:
|
||||
sound: The Sound object to play
|
||||
|
||||
Returns:
|
||||
bool: True if VLC process was launched successfully, False otherwise
|
||||
|
||||
"""
|
||||
try:
|
||||
sound_path = get_sound_file_path(sound)
|
||||
|
||||
if not sound_path.exists():
|
||||
logger.error("Sound file not found: %s", sound_path)
|
||||
return False
|
||||
|
||||
# VLC command arguments for immediate playback
|
||||
cmd = [
|
||||
self.vlc_executable,
|
||||
str(sound_path),
|
||||
"--play-and-exit", # Exit VLC when playback finishes
|
||||
"--intf",
|
||||
"dummy", # No interface
|
||||
"--no-video", # Audio only
|
||||
"--no-repeat", # Don't repeat
|
||||
"--no-loop", # Don't loop
|
||||
]
|
||||
|
||||
# Launch VLC process asynchronously without waiting
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.DEVNULL,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Launched VLC process (PID: %s) for sound: %s",
|
||||
process.pid,
|
||||
sound.name,
|
||||
)
|
||||
|
||||
# Record play count and emit event
|
||||
if self.db_session_factory and sound.id:
|
||||
asyncio.create_task(self._record_play_count(sound.id, sound.name))
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to launch VLC for sound %s", sound.name)
|
||||
return False
|
||||
|
||||
async def stop_all_vlc_instances(self) -> dict[str, Any]:
|
||||
"""Stop all running VLC processes by killing them.
|
||||
|
||||
Returns:
|
||||
dict: Results of the stop operation including counts and any errors
|
||||
|
||||
"""
|
||||
try:
|
||||
# Find all VLC processes
|
||||
find_cmd = ["pgrep", "-f", "vlc"]
|
||||
find_process = await asyncio.create_subprocess_exec(
|
||||
*find_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await find_process.communicate()
|
||||
|
||||
if find_process.returncode != 0:
|
||||
# No VLC processes found
|
||||
logger.info("No VLC processes found to stop")
|
||||
return {
|
||||
"success": True,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"message": "No VLC processes found",
|
||||
}
|
||||
|
||||
# Parse PIDs from output
|
||||
pids = []
|
||||
if stdout:
|
||||
pids = [
|
||||
pid.strip()
|
||||
for pid in stdout.decode().strip().split("\n")
|
||||
if pid.strip()
|
||||
]
|
||||
|
||||
if not pids:
|
||||
logger.info("No VLC processes found to stop")
|
||||
return {
|
||||
"success": True,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"message": "No VLC processes found",
|
||||
}
|
||||
|
||||
logger.info("Found %s VLC processes: %s", len(pids), ", ".join(pids))
|
||||
|
||||
# Kill all VLC processes
|
||||
kill_cmd = ["pkill", "-f", "vlc"]
|
||||
kill_process = await asyncio.create_subprocess_exec(
|
||||
*kill_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
await kill_process.communicate()
|
||||
|
||||
# Verify processes were killed
|
||||
verify_process = await asyncio.create_subprocess_exec(
|
||||
*find_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout_verify, _ = await verify_process.communicate()
|
||||
|
||||
remaining_pids = []
|
||||
if verify_process.returncode == 0 and stdout_verify:
|
||||
remaining_pids = [
|
||||
pid.strip()
|
||||
for pid in stdout_verify.decode().strip().split("\n")
|
||||
if pid.strip()
|
||||
]
|
||||
|
||||
processes_killed = len(pids) - len(remaining_pids)
|
||||
|
||||
logger.info(
|
||||
"Kill operation completed. Found: %s, Killed: %s, Remaining: %s",
|
||||
len(pids),
|
||||
processes_killed,
|
||||
len(remaining_pids),
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"processes_found": len(pids),
|
||||
"processes_killed": processes_killed,
|
||||
"processes_remaining": len(remaining_pids),
|
||||
"message": f"Killed {processes_killed} VLC processes",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to stop VLC processes")
|
||||
return {
|
||||
"success": False,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"error": str(e),
|
||||
"message": "Failed to stop VLC processes",
|
||||
}
|
||||
|
||||
async def _record_play_count(self, sound_id: int, sound_name: str) -> None:
|
||||
"""Record a play count for a sound and emit sound_played event."""
|
||||
if not self.db_session_factory:
|
||||
logger.warning(
|
||||
"No database session factory available for play count recording",
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("Recording play count for sound %s", sound_id)
|
||||
session = self.db_session_factory()
|
||||
try:
|
||||
sound_repo = SoundRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Update sound play count
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
old_count = 0
|
||||
if sound:
|
||||
old_count = sound.play_count
|
||||
await sound_repo.update(
|
||||
sound,
|
||||
{"play_count": sound.play_count + 1},
|
||||
)
|
||||
logger.info(
|
||||
"Updated sound %s play_count: %s -> %s",
|
||||
sound_id,
|
||||
old_count,
|
||||
old_count + 1,
|
||||
)
|
||||
else:
|
||||
logger.warning("Sound %s not found for play count update", sound_id)
|
||||
|
||||
# Record play history for admin user (ID 1) as placeholder
|
||||
# This could be refined to track per-user play history
|
||||
admin_user = await user_repo.get_by_id(1)
|
||||
admin_user_id = None
|
||||
if admin_user:
|
||||
admin_user_id = admin_user.id
|
||||
|
||||
# Always create a new SoundPlayed record for each play event
|
||||
sound_played = SoundPlayed(
|
||||
user_id=admin_user_id, # Can be None for player-based plays
|
||||
sound_id=sound_id,
|
||||
)
|
||||
session.add(sound_played)
|
||||
logger.info(
|
||||
"Created SoundPlayed record for user %s, sound %s",
|
||||
admin_user_id,
|
||||
sound_id,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Successfully recorded play count for sound %s", sound_id)
|
||||
|
||||
# Emit sound_played event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound_name,
|
||||
"user_id": admin_user_id,
|
||||
"play_count": (old_count + 1) if sound else None,
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_played", event_data)
|
||||
logger.info("Broadcasted sound_played event for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to broadcast sound_played event for sound %s", sound_id,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error recording play count for sound %s", sound_id)
|
||||
await session.rollback()
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
|
||||
# Global VLC player service instance
|
||||
vlc_player_service: VLCPlayerService | None = None
|
||||
|
||||
|
||||
def get_vlc_player_service(
|
||||
db_session_factory: Callable[[], AsyncSession] | None = None,
|
||||
) -> VLCPlayerService:
|
||||
"""Get the global VLC player service instance."""
|
||||
global vlc_player_service # noqa: PLW0603
|
||||
if vlc_player_service is None:
|
||||
vlc_player_service = VLCPlayerService(db_session_factory)
|
||||
return vlc_player_service
|
||||
@@ -2,11 +2,15 @@
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import ffmpeg # type: ignore[import-untyped]
|
||||
|
||||
from app.core.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.sound import Sound
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -33,3 +37,30 @@ def get_audio_duration(file_path: Path) -> int:
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get duration for %s: %s", file_path, e)
|
||||
return 0
|
||||
|
||||
|
||||
def get_sound_file_path(sound: "Sound") -> Path:
|
||||
"""Get the file path for a sound based on its type and normalization status.
|
||||
|
||||
Args:
|
||||
sound: The Sound object to get the path for
|
||||
|
||||
Returns:
|
||||
Path: The full path to the sound file
|
||||
|
||||
"""
|
||||
# Determine the correct subdirectory based on sound type
|
||||
if sound.type.upper() == "EXT":
|
||||
subdir = "extracted"
|
||||
elif sound.type.upper() == "SDB":
|
||||
subdir = "soundboard"
|
||||
elif sound.type.upper() == "TTS":
|
||||
subdir = "text_to_speech"
|
||||
else:
|
||||
# Fallback to lowercase type
|
||||
subdir = sound.type.lower()
|
||||
|
||||
# Use normalized file if available, otherwise original
|
||||
if sound.is_normalized and sound.normalized_filename:
|
||||
return Path("sounds/normalized") / subdir / sound.normalized_filename
|
||||
return Path("sounds/originals") / subdir / sound.filename
|
||||
|
||||
192
app/utils/credit_decorators.py
Normal file
192
app/utils/credit_decorators.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Decorators for credit management and validation."""
|
||||
|
||||
import functools
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
|
||||
|
||||
|
||||
def requires_credits(
|
||||
action_type: CreditActionType,
|
||||
credit_service_factory: Callable[[], CreditService],
|
||||
user_id_param: str = "user_id",
|
||||
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to enforce credit requirements for actions.
|
||||
|
||||
Args:
|
||||
action_type: The type of action that requires credits
|
||||
credit_service_factory: Factory to create credit service instance
|
||||
user_id_param: Name of the parameter containing user ID
|
||||
metadata_extractor: Optional function to extract metadata from function args
|
||||
|
||||
Returns:
|
||||
Decorated function that validates and deducts credits
|
||||
|
||||
Example:
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
lambda: get_credit_service(),
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
|
||||
# Implementation here
|
||||
return True
|
||||
|
||||
"""
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
param_index = param_names.index(user_id_param)
|
||||
if param_index < len(args):
|
||||
user_id = args[param_index]
|
||||
|
||||
if user_id is None:
|
||||
msg = f"Could not extract user_id from parameter '{user_id_param}'"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Extract metadata if extractor provided
|
||||
metadata = None
|
||||
if metadata_extractor:
|
||||
metadata = metadata_extractor(*args, **kwargs)
|
||||
|
||||
# Get credit service
|
||||
credit_service = credit_service_factory()
|
||||
|
||||
# Validate credits before execution
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
user_id, action_type, metadata
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
success = False
|
||||
result = None
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
success = bool(result) # Consider function result as success indicator
|
||||
return result
|
||||
except Exception:
|
||||
success = False
|
||||
raise
|
||||
finally:
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
user_id, action_type, success, metadata
|
||||
)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
return decorator
|
||||
|
||||
|
||||
def validate_credits_only(
|
||||
action_type: CreditActionType,
|
||||
credit_service_factory: Callable[[], CreditService],
|
||||
user_id_param: str = "user_id",
|
||||
) -> Callable[[F], F]:
|
||||
"""Decorator to only validate credits without deducting them.
|
||||
|
||||
Useful for checking if a user can perform an action before actual execution.
|
||||
|
||||
Args:
|
||||
action_type: The type of action that requires credits
|
||||
credit_service_factory: Factory to create credit service instance
|
||||
user_id_param: Name of the parameter containing user ID
|
||||
|
||||
Returns:
|
||||
Decorated function that validates credits only
|
||||
|
||||
"""
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
# Extract user ID from parameters
|
||||
user_id = None
|
||||
if user_id_param in kwargs:
|
||||
user_id = kwargs[user_id_param]
|
||||
else:
|
||||
# Try to find user_id in function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
if user_id_param in param_names:
|
||||
param_index = param_names.index(user_id_param)
|
||||
if param_index < len(args):
|
||||
user_id = args[param_index]
|
||||
|
||||
if user_id is None:
|
||||
msg = f"Could not extract user_id from parameter '{user_id_param}'"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get credit service
|
||||
credit_service = credit_service_factory()
|
||||
|
||||
# Validate credits only
|
||||
await credit_service.validate_and_reserve_credits(user_id, action_type)
|
||||
|
||||
# Execute the function
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
return decorator
|
||||
|
||||
|
||||
class CreditManager:
|
||||
"""Context manager for credit operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
credit_service: CreditService,
|
||||
user_id: int,
|
||||
action_type: CreditActionType,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize credit manager.
|
||||
|
||||
Args:
|
||||
credit_service: Credit service instance
|
||||
user_id: User ID
|
||||
action_type: Action type
|
||||
metadata: Optional metadata
|
||||
|
||||
"""
|
||||
self.credit_service = credit_service
|
||||
self.user_id = user_id
|
||||
self.action_type = action_type
|
||||
self.metadata = metadata
|
||||
self.validated = False
|
||||
self.success = False
|
||||
|
||||
async def __aenter__(self) -> "CreditManager":
|
||||
"""Enter context manager - validate credits."""
|
||||
await self.credit_service.validate_and_reserve_credits(
|
||||
self.user_id, self.action_type, self.metadata
|
||||
)
|
||||
self.validated = True
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
|
||||
"""Exit context manager - deduct credits based on success."""
|
||||
if self.validated:
|
||||
# If no exception occurred, consider it successful
|
||||
success = exc_type is None and self.success
|
||||
await self.credit_service.deduct_credits(
|
||||
self.user_id, self.action_type, success, self.metadata
|
||||
)
|
||||
|
||||
def mark_success(self) -> None:
|
||||
"""Mark the operation as successful."""
|
||||
self.success = True
|
||||
305
tests/api/v1/test_vlc_endpoints.py
Normal file
305
tests/api/v1/test_vlc_endpoints.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""Tests for VLC player API endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models.sound import Sound
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class TestVLCEndpoints:
|
||||
"""Test VLC player API endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_success(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test successful sound playback via VLC."""
|
||||
# Mock the VLC player service and sound repository methods
|
||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
||||
mock_play_sound.return_value = True
|
||||
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
)
|
||||
mock_get_by_id.return_value = mock_sound
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["sound_id"] == 1
|
||||
assert data["sound_name"] == "Test Sound"
|
||||
assert "Test Sound" in data["message"]
|
||||
|
||||
# Verify service calls
|
||||
mock_get_by_id.assert_called_once_with(1)
|
||||
mock_play_sound.assert_called_once_with(mock_sound)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_sound_not_found(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test VLC playback when sound is not found."""
|
||||
# Mock the sound repository to return None
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_get_by_id.return_value = None
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "Sound with ID 999 not found" in data["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_launch_failure(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test VLC playback when VLC launch fails."""
|
||||
# Mock the VLC player service to fail
|
||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
||||
mock_play_sound.return_value = False
|
||||
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
)
|
||||
mock_get_by_id.return_value = mock_sound
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.json()
|
||||
assert "Failed to launch VLC for sound playback" in data["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_service_exception(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test VLC playback when service raises an exception."""
|
||||
# Mock the sound repository to raise an exception
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_get_by_id.side_effect = Exception("Database error")
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.json()
|
||||
assert "Failed to play sound" in data["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_unauthenticated(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
):
|
||||
"""Test VLC playback without authentication."""
|
||||
response = await client.post("/api/v1/sounds/vlc/play/1")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_success(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test successful stopping of all VLC instances."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 3,
|
||||
"processes_killed": 3,
|
||||
"processes_remaining": 0,
|
||||
"message": "Killed 3 VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["processes_found"] == 3
|
||||
assert data["processes_killed"] == 3
|
||||
assert data["processes_remaining"] == 0
|
||||
assert "Killed 3 VLC processes" in data["message"]
|
||||
|
||||
# Verify service call
|
||||
mock_stop_all.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_no_processes(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances when none are running."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"message": "No VLC processes found",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["processes_found"] == 0
|
||||
assert data["processes_killed"] == 0
|
||||
assert data["message"] == "No VLC processes found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_partial_success(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances with partial success."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 3,
|
||||
"processes_killed": 2,
|
||||
"processes_remaining": 1,
|
||||
"message": "Killed 2 VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["processes_found"] == 3
|
||||
assert data["processes_killed"] == 2
|
||||
assert data["processes_remaining"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_failure(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances when service fails."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": False,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"error": "Command failed",
|
||||
"message": "Failed to stop VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["error"] == "Command failed"
|
||||
assert data["message"] == "Failed to stop VLC processes"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_service_exception(
|
||||
self,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances when service raises an exception."""
|
||||
# Mock the VLC player service to raise an exception
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_stop_all.side_effect = Exception("Service error")
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.json()
|
||||
assert "Failed to stop VLC instances" in data["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_unauthenticated(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
):
|
||||
"""Test stopping VLC instances without authentication."""
|
||||
response = await client.post("/api/v1/sounds/vlc/stop-all")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vlc_endpoints_with_admin_user(
|
||||
self,
|
||||
authenticated_admin_client: AsyncClient,
|
||||
admin_user: User,
|
||||
):
|
||||
"""Test VLC endpoints work with admin user."""
|
||||
# Test play endpoint with admin
|
||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
||||
mock_play_sound.return_value = True
|
||||
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Admin Test Sound",
|
||||
filename="admin_test.mp3",
|
||||
duration=3000,
|
||||
size=512,
|
||||
hash="admin_hash",
|
||||
)
|
||||
mock_get_by_id.return_value = mock_sound
|
||||
|
||||
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["sound_name"] == "Admin Test Sound"
|
||||
|
||||
# Test stop-all endpoint with admin
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 1,
|
||||
"processes_killed": 1,
|
||||
"processes_remaining": 0,
|
||||
"message": "Killed 1 VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
|
||||
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["processes_killed"] == 1
|
||||
@@ -13,8 +13,10 @@ from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
|
||||
from app.models.plan import Plan
|
||||
from app.models.user import User
|
||||
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
|
||||
from app.utils.auth import JWTUtils, PasswordUtils
|
||||
|
||||
|
||||
|
||||
412
tests/repositories/test_credit_transaction.py
Normal file
412
tests/repositories/test_credit_transaction.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""Tests for credit transaction repository."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.repositories.credit_transaction import CreditTransactionRepository
|
||||
|
||||
|
||||
class TestCreditTransactionRepository:
|
||||
"""Test credit transaction repository operations."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def credit_transaction_repository(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[CreditTransactionRepository, None]: # type: ignore[misc]
|
||||
"""Create a credit transaction repository instance."""
|
||||
yield CreditTransactionRepository(test_session)
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_id(
|
||||
self,
|
||||
test_user: User,
|
||||
) -> int:
|
||||
"""Get test user ID to avoid lazy loading issues."""
|
||||
return test_user.id
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_transactions(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> AsyncGenerator[list[CreditTransaction], None]: # type: ignore[misc]
|
||||
"""Create test credit transactions."""
|
||||
transactions = []
|
||||
user_id = test_user_id
|
||||
|
||||
# Create various types of transactions
|
||||
transaction_data = [
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "vlc_play_sound",
|
||||
"amount": -1,
|
||||
"balance_before": 10,
|
||||
"balance_after": 9,
|
||||
"description": "Play sound via VLC",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}),
|
||||
},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "audio_extraction",
|
||||
"amount": -5,
|
||||
"balance_before": 9,
|
||||
"balance_after": 4,
|
||||
"description": "Extract audio from URL",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"url": "https://example.com/video"}),
|
||||
},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "vlc_play_sound",
|
||||
"amount": 0,
|
||||
"balance_before": 4,
|
||||
"balance_after": 4,
|
||||
"description": "Play sound via VLC (failed)",
|
||||
"success": False,
|
||||
"metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}),
|
||||
},
|
||||
{
|
||||
"user_id": user_id,
|
||||
"action_type": "credit_addition",
|
||||
"amount": 50,
|
||||
"balance_before": 4,
|
||||
"balance_after": 54,
|
||||
"description": "Bonus credits",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"reason": "signup_bonus"}),
|
||||
},
|
||||
]
|
||||
|
||||
for data in transaction_data:
|
||||
transaction = CreditTransaction(**data)
|
||||
test_session.add(transaction)
|
||||
transactions.append(transaction)
|
||||
|
||||
await test_session.commit()
|
||||
for transaction in transactions:
|
||||
await test_session.refresh(transaction)
|
||||
|
||||
yield transactions
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def other_user_transaction(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
ensure_plans: tuple, # noqa: ARG002
|
||||
) -> AsyncGenerator[CreditTransaction, None]: # type: ignore[misc]
|
||||
"""Create a transaction for a different user."""
|
||||
from app.models.plan import Plan
|
||||
from app.repositories.user import UserRepository
|
||||
|
||||
# Create another user
|
||||
user_repo = UserRepository(test_session)
|
||||
other_user_data = {
|
||||
"email": "other@example.com",
|
||||
"name": "Other User",
|
||||
"password_hash": "hashed_password",
|
||||
"role": "user",
|
||||
"is_active": True,
|
||||
}
|
||||
other_user = await user_repo.create(other_user_data)
|
||||
|
||||
# Create transaction for the other user
|
||||
transaction_data = {
|
||||
"user_id": other_user.id,
|
||||
"action_type": "vlc_play_sound",
|
||||
"amount": -1,
|
||||
"balance_before": 100,
|
||||
"balance_after": 99,
|
||||
"description": "Other user play sound",
|
||||
"success": True,
|
||||
"metadata_json": None,
|
||||
}
|
||||
transaction = CreditTransaction(**transaction_data)
|
||||
test_session.add(transaction)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(transaction)
|
||||
|
||||
yield transaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_existing(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting transaction by ID when it exists."""
|
||||
transaction = await credit_transaction_repository.get_by_id(test_transactions[0].id)
|
||||
|
||||
assert transaction is not None
|
||||
assert transaction.id == test_transactions[0].id
|
||||
assert transaction.action_type == "vlc_play_sound"
|
||||
assert transaction.amount == -1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_nonexistent(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
) -> None:
|
||||
"""Test getting transaction by ID when it doesn't exist."""
|
||||
transaction = await credit_transaction_repository.get_by_id(99999)
|
||||
|
||||
assert transaction is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
other_user_transaction: CreditTransaction,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting transactions by user ID."""
|
||||
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
|
||||
|
||||
# Should return all transactions for test_user
|
||||
assert len(transactions) == 4
|
||||
# Should be ordered by created_at desc (newest first)
|
||||
assert all(t.user_id == test_user_id for t in transactions)
|
||||
|
||||
# Should not include other user's transaction
|
||||
other_user_ids = [t.user_id for t in transactions]
|
||||
assert other_user_transaction.user_id not in other_user_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id_with_pagination(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting transactions by user ID with pagination."""
|
||||
# Get first 2 transactions
|
||||
first_page = await credit_transaction_repository.get_by_user_id(
|
||||
test_user_id, limit=2, offset=0
|
||||
)
|
||||
assert len(first_page) == 2
|
||||
|
||||
# Get next 2 transactions
|
||||
second_page = await credit_transaction_repository.get_by_user_id(
|
||||
test_user_id, limit=2, offset=2
|
||||
)
|
||||
assert len(second_page) == 2
|
||||
|
||||
# Should not overlap
|
||||
first_page_ids = {t.id for t in first_page}
|
||||
second_page_ids = {t.id for t in second_page}
|
||||
assert first_page_ids.isdisjoint(second_page_ids)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_action_type(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting transactions by action type."""
|
||||
vlc_transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"vlc_play_sound"
|
||||
)
|
||||
|
||||
# Should return 2 VLC transactions (1 successful, 1 failed)
|
||||
assert len(vlc_transactions) >= 2
|
||||
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
|
||||
|
||||
extraction_transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"audio_extraction"
|
||||
)
|
||||
|
||||
# Should return 1 extraction transaction
|
||||
assert len(extraction_transactions) >= 1
|
||||
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_action_type_with_pagination(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting transactions by action type with pagination."""
|
||||
# Test with limit
|
||||
transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"vlc_play_sound", limit=1
|
||||
)
|
||||
assert len(transactions) == 1
|
||||
assert transactions[0].action_type == "vlc_play_sound"
|
||||
|
||||
# Test with offset
|
||||
transactions = await credit_transaction_repository.get_by_action_type(
|
||||
"vlc_play_sound", limit=1, offset=1
|
||||
)
|
||||
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test getting only successful transactions."""
|
||||
successful_transactions = await credit_transaction_repository.get_successful_transactions()
|
||||
|
||||
# Should only return successful transactions
|
||||
assert all(t.success is True for t in successful_transactions)
|
||||
# Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition)
|
||||
assert len(successful_transactions) >= 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions_by_user(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
other_user_transaction: CreditTransaction,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting successful transactions filtered by user."""
|
||||
successful_transactions = await credit_transaction_repository.get_successful_transactions(
|
||||
user_id=test_user_id
|
||||
)
|
||||
|
||||
# Should only return successful transactions for test_user
|
||||
assert all(t.success is True for t in successful_transactions)
|
||||
assert all(t.user_id == test_user_id for t in successful_transactions)
|
||||
# Should be 3 successful transactions for test_user
|
||||
assert len(successful_transactions) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_successful_transactions_with_pagination(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting successful transactions with pagination."""
|
||||
# Get first 2 successful transactions
|
||||
first_page = await credit_transaction_repository.get_successful_transactions(
|
||||
user_id=test_user_id, limit=2, offset=0
|
||||
)
|
||||
assert len(first_page) == 2
|
||||
assert all(t.success is True for t in first_page)
|
||||
|
||||
# Get next successful transaction
|
||||
second_page = await credit_transaction_repository.get_successful_transactions(
|
||||
user_id=test_user_id, limit=2, offset=2
|
||||
)
|
||||
assert len(second_page) == 1 # Should be 1 remaining
|
||||
assert all(t.success is True for t in second_page)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_transactions(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
other_user_transaction: CreditTransaction,
|
||||
) -> None:
|
||||
"""Test getting all transactions."""
|
||||
all_transactions = await credit_transaction_repository.get_all()
|
||||
|
||||
# Should return all transactions
|
||||
assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_transaction(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test creating a new transaction."""
|
||||
transaction_data = {
|
||||
"user_id": test_user_id,
|
||||
"action_type": "test_action",
|
||||
"amount": -10,
|
||||
"balance_before": 100,
|
||||
"balance_after": 90,
|
||||
"description": "Test transaction",
|
||||
"success": True,
|
||||
"metadata_json": json.dumps({"test": "data"}),
|
||||
}
|
||||
|
||||
transaction = await credit_transaction_repository.create(transaction_data)
|
||||
|
||||
assert transaction.id is not None
|
||||
assert transaction.user_id == test_user_id
|
||||
assert transaction.action_type == "test_action"
|
||||
assert transaction.amount == -10
|
||||
assert transaction.balance_before == 100
|
||||
assert transaction.balance_after == 90
|
||||
assert transaction.success is True
|
||||
assert json.loads(transaction.metadata_json) == {"test": "data"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_transaction(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
) -> None:
|
||||
"""Test updating a transaction."""
|
||||
transaction = test_transactions[0]
|
||||
update_data = {
|
||||
"description": "Updated description",
|
||||
"metadata_json": json.dumps({"updated": True}),
|
||||
}
|
||||
|
||||
updated_transaction = await credit_transaction_repository.update(
|
||||
transaction, update_data
|
||||
)
|
||||
|
||||
assert updated_transaction.id == transaction.id
|
||||
assert updated_transaction.description == "Updated description"
|
||||
assert json.loads(updated_transaction.metadata_json) == {"updated": True}
|
||||
# Other fields should remain unchanged
|
||||
assert updated_transaction.amount == transaction.amount
|
||||
assert updated_transaction.action_type == transaction.action_type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_transaction(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test deleting a transaction."""
|
||||
# Create a transaction to delete
|
||||
transaction_data = {
|
||||
"user_id": test_user_id,
|
||||
"action_type": "to_delete",
|
||||
"amount": -1,
|
||||
"balance_before": 10,
|
||||
"balance_after": 9,
|
||||
"description": "To be deleted",
|
||||
"success": True,
|
||||
"metadata_json": None,
|
||||
}
|
||||
transaction = await credit_transaction_repository.create(transaction_data)
|
||||
transaction_id = transaction.id
|
||||
|
||||
# Delete the transaction
|
||||
await credit_transaction_repository.delete(transaction)
|
||||
|
||||
# Verify transaction is deleted
|
||||
deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id)
|
||||
assert deleted_transaction is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_ordering(
|
||||
self,
|
||||
credit_transaction_repository: CreditTransactionRepository,
|
||||
test_transactions: list[CreditTransaction],
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test that transactions are ordered by created_at desc."""
|
||||
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
|
||||
|
||||
# Should be ordered by created_at desc (newest first)
|
||||
for i in range(len(transactions) - 1):
|
||||
assert transactions[i].created_at >= transactions[i + 1].created_at
|
||||
376
tests/repositories/test_sound.py
Normal file
376
tests/repositories/test_sound.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""Tests for sound repository."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.sound import Sound
|
||||
from app.repositories.sound import SoundRepository
|
||||
|
||||
|
||||
class TestSoundRepository:
|
||||
"""Test sound repository operations."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def sound_repository(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[SoundRepository, None]: # type: ignore[misc]
|
||||
"""Create a sound repository instance."""
|
||||
yield SoundRepository(test_session)
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_sound(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
|
||||
"""Create a test sound."""
|
||||
sound_data = {
|
||||
"name": "Test Sound",
|
||||
"filename": "test_sound.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 5000,
|
||||
"size": 1024000,
|
||||
"hash": "test_hash_123",
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
sound = Sound(**sound_data)
|
||||
test_session.add(sound)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(sound)
|
||||
yield sound
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def normalized_sound(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
|
||||
"""Create a normalized test sound."""
|
||||
sound_data = {
|
||||
"name": "Normalized Sound",
|
||||
"filename": "normalized_sound.mp3",
|
||||
"type": "TTS",
|
||||
"duration": 3000,
|
||||
"size": 512000,
|
||||
"hash": "normalized_hash_456",
|
||||
"play_count": 5,
|
||||
"is_normalized": True,
|
||||
"normalized_filename": "normalized_sound_norm.mp3",
|
||||
"normalized_duration": 3000,
|
||||
"normalized_size": 480000,
|
||||
"normalized_hash": "normalized_hash_norm_456",
|
||||
}
|
||||
sound = Sound(**sound_data)
|
||||
test_session.add(sound)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(sound)
|
||||
yield sound
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_existing(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sound by ID when it exists."""
|
||||
sound = await sound_repository.get_by_id(test_sound.id)
|
||||
|
||||
assert sound is not None
|
||||
assert sound.id == test_sound.id
|
||||
assert sound.name == test_sound.name
|
||||
assert sound.filename == test_sound.filename
|
||||
assert sound.type == test_sound.type
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_nonexistent(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test getting sound by ID when it doesn't exist."""
|
||||
sound = await sound_repository.get_by_id(99999)
|
||||
|
||||
assert sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_filename_existing(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sound by filename when it exists."""
|
||||
sound = await sound_repository.get_by_filename(test_sound.filename)
|
||||
|
||||
assert sound is not None
|
||||
assert sound.id == test_sound.id
|
||||
assert sound.filename == test_sound.filename
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_filename_nonexistent(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test getting sound by filename when it doesn't exist."""
|
||||
sound = await sound_repository.get_by_filename("nonexistent.mp3")
|
||||
|
||||
assert sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_hash_existing(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sound by hash when it exists."""
|
||||
sound = await sound_repository.get_by_hash(test_sound.hash)
|
||||
|
||||
assert sound is not None
|
||||
assert sound.id == test_sound.id
|
||||
assert sound.hash == test_sound.hash
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_hash_nonexistent(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test getting sound by hash when it doesn't exist."""
|
||||
sound = await sound_repository.get_by_hash("nonexistent_hash")
|
||||
|
||||
assert sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_type(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting sounds by type."""
|
||||
sdb_sounds = await sound_repository.get_by_type("SDB")
|
||||
tts_sounds = await sound_repository.get_by_type("TTS")
|
||||
ext_sounds = await sound_repository.get_by_type("EXT")
|
||||
|
||||
# Should find the SDB sound
|
||||
assert len(sdb_sounds) >= 1
|
||||
assert any(sound.id == test_sound.id for sound in sdb_sounds)
|
||||
|
||||
# Should find the TTS sound
|
||||
assert len(tts_sounds) >= 1
|
||||
assert any(sound.id == normalized_sound.id for sound in tts_sounds)
|
||||
|
||||
# Should not find any EXT sounds
|
||||
assert len(ext_sounds) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_sound(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
) -> None:
|
||||
"""Test creating a new sound."""
|
||||
sound_data = {
|
||||
"name": "New Sound",
|
||||
"filename": "new_sound.wav",
|
||||
"type": "EXT",
|
||||
"duration": 7500,
|
||||
"size": 2048000,
|
||||
"hash": "new_hash_789",
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
|
||||
sound = await sound_repository.create(sound_data)
|
||||
|
||||
assert sound.id is not None
|
||||
assert sound.name == sound_data["name"]
|
||||
assert sound.filename == sound_data["filename"]
|
||||
assert sound.type == sound_data["type"]
|
||||
assert sound.duration == sound_data["duration"]
|
||||
assert sound.size == sound_data["size"]
|
||||
assert sound.hash == sound_data["hash"]
|
||||
assert sound.play_count == 0
|
||||
assert sound.is_normalized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_sound(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test updating a sound."""
|
||||
update_data = {
|
||||
"name": "Updated Sound Name",
|
||||
"play_count": 10,
|
||||
"is_normalized": True,
|
||||
"normalized_filename": "updated_norm.mp3",
|
||||
}
|
||||
|
||||
updated_sound = await sound_repository.update(test_sound, update_data)
|
||||
|
||||
assert updated_sound.id == test_sound.id
|
||||
assert updated_sound.name == "Updated Sound Name"
|
||||
assert updated_sound.play_count == 10
|
||||
assert updated_sound.is_normalized is True
|
||||
assert updated_sound.normalized_filename == "updated_norm.mp3"
|
||||
assert updated_sound.filename == test_sound.filename # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_sound(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_session: AsyncSession,
|
||||
) -> None:
|
||||
"""Test deleting a sound."""
|
||||
# Create a sound to delete
|
||||
sound_data = {
|
||||
"name": "To Delete",
|
||||
"filename": "to_delete.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 1000,
|
||||
"size": 256000,
|
||||
"hash": "delete_hash",
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
sound = await sound_repository.create(sound_data)
|
||||
sound_id = sound.id
|
||||
|
||||
# Delete the sound
|
||||
await sound_repository.delete(sound)
|
||||
|
||||
# Verify sound is deleted
|
||||
deleted_sound = await sound_repository.get_by_id(sound_id)
|
||||
assert deleted_sound is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_name(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test searching sounds by name."""
|
||||
# Search for "test" should find test_sound
|
||||
results = await sound_repository.search_by_name("test")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == test_sound.id for sound in results)
|
||||
|
||||
# Search for "normalized" should find normalized_sound
|
||||
results = await sound_repository.search_by_name("normalized")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == normalized_sound.id for sound in results)
|
||||
|
||||
# Case insensitive search
|
||||
results = await sound_repository.search_by_name("TEST")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == test_sound.id for sound in results)
|
||||
|
||||
# Partial match
|
||||
results = await sound_repository.search_by_name("norm")
|
||||
assert len(results) >= 1
|
||||
assert any(sound.id == normalized_sound.id for sound in results)
|
||||
|
||||
# No matches
|
||||
results = await sound_repository.search_by_name("nonexistent")
|
||||
assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_popular_sounds(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting popular sounds."""
|
||||
# Update play counts to test ordering
|
||||
await sound_repository.update(test_sound, {"play_count": 15})
|
||||
await sound_repository.update(normalized_sound, {"play_count": 5})
|
||||
|
||||
# Create another sound with higher play count
|
||||
high_play_sound_data = {
|
||||
"name": "Popular Sound",
|
||||
"filename": "popular.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 2000,
|
||||
"size": 300000,
|
||||
"hash": "popular_hash",
|
||||
"play_count": 25,
|
||||
"is_normalized": False,
|
||||
}
|
||||
high_play_sound = await sound_repository.create(high_play_sound_data)
|
||||
|
||||
# Get popular sounds
|
||||
popular_sounds = await sound_repository.get_popular_sounds(limit=10)
|
||||
|
||||
assert len(popular_sounds) >= 3
|
||||
# Should be ordered by play_count desc
|
||||
assert popular_sounds[0].play_count >= popular_sounds[1].play_count
|
||||
# The highest play count sound should be first
|
||||
assert popular_sounds[0].id == high_play_sound.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unnormalized_sounds(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting unnormalized sounds."""
|
||||
unnormalized_sounds = await sound_repository.get_unnormalized_sounds()
|
||||
|
||||
# Should include test_sound (not normalized)
|
||||
assert any(sound.id == test_sound.id for sound in unnormalized_sounds)
|
||||
# Should not include normalized_sound (already normalized)
|
||||
assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unnormalized_sounds_by_type(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
normalized_sound: Sound,
|
||||
) -> None:
|
||||
"""Test getting unnormalized sounds by type."""
|
||||
# Get unnormalized SDB sounds
|
||||
sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB")
|
||||
# Should include test_sound (SDB, not normalized)
|
||||
assert any(sound.id == test_sound.id for sound in sdb_unnormalized)
|
||||
|
||||
# Get unnormalized TTS sounds
|
||||
tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS")
|
||||
# Should not include normalized_sound (TTS, but already normalized)
|
||||
assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized)
|
||||
|
||||
# Get unnormalized EXT sounds
|
||||
ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT")
|
||||
# Should be empty
|
||||
assert len(ext_unnormalized) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_hash(
|
||||
self,
|
||||
sound_repository: SoundRepository,
|
||||
test_sound: Sound,
|
||||
) -> None:
|
||||
"""Test creating sound with duplicate hash is allowed."""
|
||||
# Store the hash to avoid lazy loading issues
|
||||
original_hash = test_sound.hash
|
||||
|
||||
duplicate_sound_data = {
|
||||
"name": "Duplicate Hash Sound",
|
||||
"filename": "duplicate.mp3",
|
||||
"type": "SDB",
|
||||
"duration": 1000,
|
||||
"size": 100000,
|
||||
"hash": original_hash, # Same hash as test_sound
|
||||
"play_count": 0,
|
||||
"is_normalized": False,
|
||||
}
|
||||
|
||||
# Should succeed - duplicate hashes are allowed
|
||||
duplicate_sound = await sound_repository.create(duplicate_sound_data)
|
||||
|
||||
assert duplicate_sound.id is not None
|
||||
assert duplicate_sound.name == "Duplicate Hash Sound"
|
||||
assert duplicate_sound.hash == original_hash # Same hash is allowed
|
||||
268
tests/repositories/test_user_oauth.py
Normal file
268
tests/repositories/test_user_oauth.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Tests for user OAuth repository."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.user_oauth import UserOauth
|
||||
from app.repositories.user_oauth import UserOauthRepository
|
||||
|
||||
|
||||
class TestUserOauthRepository:
|
||||
"""Test user OAuth repository operations."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user_oauth_repository(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
) -> AsyncGenerator[UserOauthRepository, None]: # type: ignore[misc]
|
||||
"""Create a user OAuth repository instance."""
|
||||
yield UserOauthRepository(test_session)
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_id(
|
||||
self,
|
||||
test_user: User,
|
||||
) -> int:
|
||||
"""Get test user ID to avoid lazy loading issues."""
|
||||
return test_user.id
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_oauth(
|
||||
self,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> AsyncGenerator[UserOauth, None]: # type: ignore[misc]
|
||||
"""Create a test OAuth record."""
|
||||
oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "google",
|
||||
"provider_user_id": "google_123456",
|
||||
"email": "test@gmail.com",
|
||||
"name": "Test User Google",
|
||||
"picture": None,
|
||||
}
|
||||
oauth = UserOauth(**oauth_data)
|
||||
test_session.add(oauth)
|
||||
await test_session.commit()
|
||||
await test_session.refresh(oauth)
|
||||
yield oauth
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_provider_user_id_existing(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
) -> None:
|
||||
"""Test getting OAuth by provider user ID when it exists."""
|
||||
oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||
"google", "google_123456"
|
||||
)
|
||||
|
||||
assert oauth is not None
|
||||
assert oauth.id == test_oauth.id
|
||||
assert oauth.provider == "google"
|
||||
assert oauth.provider_user_id == "google_123456"
|
||||
assert oauth.user_id == test_oauth.user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_provider_user_id_nonexistent(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
) -> None:
|
||||
"""Test getting OAuth by provider user ID when it doesn't exist."""
|
||||
oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||
"google", "nonexistent_id"
|
||||
)
|
||||
|
||||
assert oauth is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id_and_provider_existing(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting OAuth by user ID and provider when it exists."""
|
||||
oauth = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "google"
|
||||
)
|
||||
|
||||
assert oauth is not None
|
||||
assert oauth.id == test_oauth.id
|
||||
assert oauth.provider == "google"
|
||||
assert oauth.user_id == test_user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_user_id_and_provider_nonexistent(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test getting OAuth by user ID and provider when it doesn't exist."""
|
||||
oauth = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "github"
|
||||
)
|
||||
|
||||
assert oauth is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_oauth(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test creating a new OAuth record."""
|
||||
oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "github",
|
||||
"provider_user_id": "github_789",
|
||||
"email": "test@github.com",
|
||||
"name": "Test User GitHub",
|
||||
"picture": None,
|
||||
}
|
||||
|
||||
oauth = await user_oauth_repository.create(oauth_data)
|
||||
|
||||
assert oauth.id is not None
|
||||
assert oauth.user_id == test_user_id
|
||||
assert oauth.provider == "github"
|
||||
assert oauth.provider_user_id == "github_789"
|
||||
assert oauth.email == "test@github.com"
|
||||
assert oauth.name == "Test User GitHub"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_oauth(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
) -> None:
|
||||
"""Test updating an OAuth record."""
|
||||
update_data = {
|
||||
"email": "updated@gmail.com",
|
||||
"name": "Updated User Name",
|
||||
"picture": "https://example.com/photo.jpg",
|
||||
}
|
||||
|
||||
updated_oauth = await user_oauth_repository.update(test_oauth, update_data)
|
||||
|
||||
assert updated_oauth.id == test_oauth.id
|
||||
assert updated_oauth.email == "updated@gmail.com"
|
||||
assert updated_oauth.name == "Updated User Name"
|
||||
assert updated_oauth.picture == "https://example.com/photo.jpg"
|
||||
assert updated_oauth.provider == test_oauth.provider # Unchanged
|
||||
assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_oauth(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_session: AsyncSession,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test deleting an OAuth record."""
|
||||
# Create an OAuth record to delete
|
||||
oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "twitter",
|
||||
"provider_user_id": "twitter_456",
|
||||
"email": "test@twitter.com",
|
||||
"name": "Test User Twitter",
|
||||
"picture": None,
|
||||
}
|
||||
oauth = await user_oauth_repository.create(oauth_data)
|
||||
oauth_id = oauth.id
|
||||
|
||||
# Delete the OAuth record
|
||||
await user_oauth_repository.delete(oauth)
|
||||
|
||||
# Verify it's deleted by trying to find it
|
||||
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
|
||||
"twitter", "twitter_456"
|
||||
)
|
||||
assert deleted_oauth is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_provider_user_id(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_oauth: UserOauth,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test creating OAuth with duplicate provider user ID should fail."""
|
||||
# Try to create another OAuth with the same provider and provider_user_id
|
||||
duplicate_oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "google",
|
||||
"provider_user_id": "google_123456", # Same as test_oauth
|
||||
"email": "another@gmail.com",
|
||||
"name": "Another User",
|
||||
"picture": None,
|
||||
}
|
||||
|
||||
# This should fail due to unique constraint
|
||||
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
|
||||
await user_oauth_repository.create(duplicate_oauth_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_providers_same_user(
|
||||
self,
|
||||
user_oauth_repository: UserOauthRepository,
|
||||
test_user_id: int,
|
||||
) -> None:
|
||||
"""Test that a user can have multiple OAuth providers."""
|
||||
# Create Google OAuth
|
||||
google_oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "google",
|
||||
"provider_user_id": "google_user_1",
|
||||
"email": "user@gmail.com",
|
||||
"name": "Test User Google",
|
||||
"picture": None,
|
||||
}
|
||||
google_oauth = await user_oauth_repository.create(google_oauth_data)
|
||||
|
||||
# Create GitHub OAuth for the same user
|
||||
github_oauth_data = {
|
||||
"user_id": test_user_id,
|
||||
"provider": "github",
|
||||
"provider_user_id": "github_user_1",
|
||||
"email": "user@github.com",
|
||||
"name": "Test User GitHub",
|
||||
"picture": None,
|
||||
}
|
||||
github_oauth = await user_oauth_repository.create(github_oauth_data)
|
||||
|
||||
# Verify both exist by querying back from database
|
||||
found_google = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "google"
|
||||
)
|
||||
found_github = await user_oauth_repository.get_by_user_id_and_provider(
|
||||
test_user_id, "github"
|
||||
)
|
||||
|
||||
assert found_google is not None
|
||||
assert found_github is not None
|
||||
assert found_google.provider == "google"
|
||||
assert found_github.provider == "github"
|
||||
assert found_google.user_id == test_user_id
|
||||
assert found_github.user_id == test_user_id
|
||||
assert found_google.provider_user_id == "google_user_1"
|
||||
assert found_github.provider_user_id == "github_user_1"
|
||||
|
||||
# Verify we can also find them by provider_user_id
|
||||
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
|
||||
"google", "google_user_1"
|
||||
)
|
||||
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
|
||||
"github", "github_user_1"
|
||||
)
|
||||
|
||||
assert found_google_by_provider is not None
|
||||
assert found_github_by_provider is not None
|
||||
assert found_google_by_provider.user_id == test_user_id
|
||||
assert found_github_by_provider.user_id == test_user_id
|
||||
358
tests/services/test_credit.py
Normal file
358
tests/services/test_credit.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Tests for credit service."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.credit_transaction import CreditTransaction
|
||||
from app.models.user import User
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
|
||||
|
||||
class TestCreditService:
|
||||
"""Test credit service functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session_factory(self):
|
||||
"""Create a mock database session factory."""
|
||||
session = AsyncMock(spec=AsyncSession)
|
||||
return lambda: session
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service(self, mock_db_session_factory):
|
||||
"""Create a credit service instance for testing."""
|
||||
return CreditService(mock_db_session_factory)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_user(self):
|
||||
"""Create a sample user for testing."""
|
||||
return User(
|
||||
id=1,
|
||||
name="Test User",
|
||||
email="test@example.com",
|
||||
role="user",
|
||||
credits=10,
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_credits_sufficient(self, credit_service, sample_user):
|
||||
"""Test checking credits when user has sufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
|
||||
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
|
||||
|
||||
assert result is True
|
||||
mock_repo.get_by_id.assert_called_once_with(1)
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_credits_insufficient(self, credit_service):
|
||||
"""Test checking credits when user has insufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
poor_user = User(
|
||||
id=1,
|
||||
name="Poor User",
|
||||
email="poor@example.com",
|
||||
role="user",
|
||||
credits=0, # No credits
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = poor_user
|
||||
|
||||
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
|
||||
|
||||
assert result is False
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_credits_user_not_found(self, credit_service):
|
||||
"""Test checking credits when user is not found."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = None
|
||||
|
||||
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND)
|
||||
|
||||
assert result is False
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user):
|
||||
"""Test successful credit validation and reservation."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
|
||||
user, action = await credit_service.validate_and_reserve_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
|
||||
assert user == sample_user
|
||||
assert action.action_type == CreditActionType.VLC_PLAY_SOUND
|
||||
assert action.cost == 1
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_reserve_credits_insufficient(self, credit_service):
|
||||
"""Test credit validation with insufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
poor_user = User(
|
||||
id=1,
|
||||
name="Poor User",
|
||||
email="poor@example.com",
|
||||
role="user",
|
||||
credits=0,
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = poor_user
|
||||
|
||||
with pytest.raises(InsufficientCreditsError) as exc_info:
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
|
||||
assert exc_info.value.required == 1
|
||||
assert exc_info.value.available == 0
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_and_reserve_credits_user_not_found(self, credit_service):
|
||||
"""Test credit validation when user is not found."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="User 999 not found"):
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
999, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduct_credits_success(self, credit_service, sample_user):
|
||||
"""Test successful credit deduction."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
transaction = await credit_service.deduct_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
|
||||
)
|
||||
|
||||
# Verify user credits were updated
|
||||
mock_repo.update.assert_called_once_with(sample_user, {"credits": 9})
|
||||
|
||||
# Verify transaction was created
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify socket event was emitted
|
||||
mock_socket_manager.send_to_user.assert_called_once_with(
|
||||
"1", "user_credits_changed", {
|
||||
"user_id": "1",
|
||||
"credits_before": 10,
|
||||
"credits_after": 9,
|
||||
"credits_deducted": 1,
|
||||
"action_type": "vlc_play_sound",
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Check transaction details
|
||||
added_transaction = mock_session.add.call_args[0][0]
|
||||
assert isinstance(added_transaction, CreditTransaction)
|
||||
assert added_transaction.user_id == 1
|
||||
assert added_transaction.action_type == "vlc_play_sound"
|
||||
assert added_transaction.amount == -1
|
||||
assert added_transaction.balance_before == 10
|
||||
assert added_transaction.balance_after == 9
|
||||
assert added_transaction.success is True
|
||||
assert json.loads(added_transaction.metadata_json) == {"test": "data"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user):
|
||||
"""Test credit deduction when action failed but requires success."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
transaction = await credit_service.deduct_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed
|
||||
)
|
||||
|
||||
# Verify user credits were NOT updated (action requires success)
|
||||
mock_repo.update.assert_not_called()
|
||||
|
||||
# Verify transaction was still created for auditing
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify no socket event was emitted since no credits were actually deducted
|
||||
mock_socket_manager.send_to_user.assert_not_called()
|
||||
|
||||
# Check transaction details
|
||||
added_transaction = mock_session.add.call_args[0][0]
|
||||
assert added_transaction.amount == 0 # No deduction for failed action
|
||||
assert added_transaction.balance_before == 10
|
||||
assert added_transaction.balance_after == 10 # No change
|
||||
assert added_transaction.success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduct_credits_insufficient(self, credit_service):
|
||||
"""Test credit deduction with insufficient credits."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
poor_user = User(
|
||||
id=1,
|
||||
name="Poor User",
|
||||
email="poor@example.com",
|
||||
role="user",
|
||||
credits=0,
|
||||
plan_id=1,
|
||||
)
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = poor_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
with pytest.raises(InsufficientCreditsError):
|
||||
await credit_service.deduct_credits(
|
||||
1, CreditActionType.VLC_PLAY_SOUND, True
|
||||
)
|
||||
|
||||
# Verify no socket event was emitted since credits could not be deducted
|
||||
mock_socket_manager.send_to_user.assert_not_called()
|
||||
|
||||
mock_session.rollback.assert_called_once()
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_credits(self, credit_service, sample_user):
|
||||
"""Test adding credits to user account."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class, \
|
||||
patch("app.services.credit.socket_manager") as mock_socket_manager:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
mock_socket_manager.send_to_user = AsyncMock()
|
||||
|
||||
transaction = await credit_service.add_credits(
|
||||
1, 5, "Bonus credits", {"reason": "signup"}
|
||||
)
|
||||
|
||||
# Verify user credits were updated
|
||||
mock_repo.update.assert_called_once_with(sample_user, {"credits": 15})
|
||||
|
||||
# Verify transaction was created
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Verify socket event was emitted
|
||||
mock_socket_manager.send_to_user.assert_called_once_with(
|
||||
"1", "user_credits_changed", {
|
||||
"user_id": "1",
|
||||
"credits_before": 10,
|
||||
"credits_after": 15,
|
||||
"credits_added": 5,
|
||||
"description": "Bonus credits",
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Check transaction details
|
||||
added_transaction = mock_session.add.call_args[0][0]
|
||||
assert added_transaction.amount == 5
|
||||
assert added_transaction.balance_before == 10
|
||||
assert added_transaction.balance_after == 15
|
||||
assert added_transaction.description == "Bonus credits"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_credits_invalid_amount(self, credit_service):
|
||||
"""Test adding invalid amount of credits."""
|
||||
with pytest.raises(ValueError, match="Amount must be positive"):
|
||||
await credit_service.add_credits(1, 0, "Invalid")
|
||||
|
||||
with pytest.raises(ValueError, match="Amount must be positive"):
|
||||
await credit_service.add_credits(1, -5, "Invalid")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_balance(self, credit_service, sample_user):
|
||||
"""Test getting user credit balance."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = sample_user
|
||||
|
||||
balance = await credit_service.get_user_balance(1)
|
||||
|
||||
assert balance == 10
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_balance_user_not_found(self, credit_service):
|
||||
"""Test getting balance for non-existent user."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
with patch("app.services.credit.UserRepository") as mock_repo_class:
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo_class.return_value = mock_repo
|
||||
mock_repo.get_by_id.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="User 999 not found"):
|
||||
await credit_service.get_user_balance(999)
|
||||
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
|
||||
class TestInsufficientCreditsError:
|
||||
"""Test InsufficientCreditsError exception."""
|
||||
|
||||
def test_insufficient_credits_error_creation(self):
|
||||
"""Test creating InsufficientCreditsError."""
|
||||
error = InsufficientCreditsError(5, 2)
|
||||
assert error.required == 5
|
||||
assert error.available == 2
|
||||
assert str(error) == "Insufficient credits: 5 required, 2 available"
|
||||
@@ -21,6 +21,7 @@ from app.services.player import (
|
||||
initialize_player_service,
|
||||
shutdown_player_service,
|
||||
)
|
||||
from app.utils.audio import get_sound_file_path
|
||||
|
||||
|
||||
class TestPlayerState:
|
||||
@@ -196,7 +197,7 @@ class TestPlayerService:
|
||||
)
|
||||
player_service.state.playlist_sounds = [sound]
|
||||
|
||||
with patch.object(player_service, "_get_sound_file_path") as mock_path:
|
||||
with patch("app.services.player.get_sound_file_path") as mock_path:
|
||||
mock_file_path = Mock(spec=Path)
|
||||
mock_file_path.exists.return_value = True
|
||||
mock_path.return_value = mock_file_path
|
||||
@@ -385,51 +386,6 @@ class TestPlayerService:
|
||||
assert player_service.state.playlist_length == 2
|
||||
assert player_service.state.playlist_duration == 75000
|
||||
|
||||
def test_get_sound_file_path_normalized(self, player_service):
|
||||
"""Test getting file path for normalized sound."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
name="Test Song",
|
||||
filename="original.mp3",
|
||||
normalized_filename="normalized.mp3",
|
||||
is_normalized=True,
|
||||
type="SDB",
|
||||
)
|
||||
|
||||
result = player_service._get_sound_file_path(sound)
|
||||
|
||||
expected = Path("sounds/normalized/sdb/normalized.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_original(self, player_service):
|
||||
"""Test getting file path for original sound."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
name="Test Song",
|
||||
filename="original.mp3",
|
||||
is_normalized=False,
|
||||
type="SDB",
|
||||
)
|
||||
|
||||
result = player_service._get_sound_file_path(sound)
|
||||
|
||||
expected = Path("sounds/originals/sdb/original.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_ext_type(self, player_service):
|
||||
"""Test getting file path for EXT type sound."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
name="Test Song",
|
||||
filename="extracted.mp3",
|
||||
is_normalized=False,
|
||||
type="EXT",
|
||||
)
|
||||
|
||||
result = player_service._get_sound_file_path(sound)
|
||||
|
||||
expected = Path("sounds/originals/extracted/extracted.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_next_index_continuous_mode(self, player_service):
|
||||
"""Test getting next index in continuous mode."""
|
||||
@@ -538,36 +494,24 @@ class TestPlayerService:
|
||||
|
||||
# Mock repositories
|
||||
with patch("app.services.player.SoundRepository") as mock_sound_repo_class:
|
||||
with patch("app.services.player.UserRepository") as mock_user_repo_class:
|
||||
mock_sound_repo = AsyncMock()
|
||||
mock_user_repo = AsyncMock()
|
||||
mock_sound_repo_class.return_value = mock_sound_repo
|
||||
mock_user_repo_class.return_value = mock_user_repo
|
||||
mock_sound_repo = AsyncMock()
|
||||
mock_sound_repo_class.return_value = mock_sound_repo
|
||||
|
||||
# Mock sound and user
|
||||
mock_sound = Mock()
|
||||
mock_sound.play_count = 5
|
||||
mock_sound_repo.get_by_id.return_value = mock_sound
|
||||
# Mock sound
|
||||
mock_sound = Mock()
|
||||
mock_sound.play_count = 5
|
||||
mock_sound_repo.get_by_id.return_value = mock_sound
|
||||
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user_repo.get_by_id.return_value = mock_user
|
||||
await player_service._record_play_count(1)
|
||||
|
||||
# Mock no existing SoundPlayed record
|
||||
mock_result = Mock()
|
||||
mock_result.first.return_value = None
|
||||
mock_session.exec.return_value = mock_result
|
||||
# Verify sound play count was updated
|
||||
mock_sound_repo.update.assert_called_once_with(
|
||||
mock_sound, {"play_count": 6}
|
||||
)
|
||||
|
||||
await player_service._record_play_count(1)
|
||||
|
||||
# Verify sound play count was updated
|
||||
mock_sound_repo.update.assert_called_once_with(
|
||||
mock_sound, {"play_count": 6}
|
||||
)
|
||||
|
||||
# Verify SoundPlayed record was created
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
# Verify SoundPlayed record was created with None user_id for player
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_get_state(self, player_service):
|
||||
"""Test getting current player state."""
|
||||
@@ -577,6 +521,27 @@ class TestPlayerService:
|
||||
assert "mode" in result
|
||||
assert "volume" in result
|
||||
|
||||
def test_uses_shared_sound_path_utility(self, player_service):
|
||||
"""Test that player service uses the shared sound path utility."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
name="Test Song",
|
||||
filename="test.mp3",
|
||||
type="SDB",
|
||||
is_normalized=False,
|
||||
)
|
||||
player_service.state.playlist_sounds = [sound]
|
||||
|
||||
with patch("app.services.player.get_sound_file_path") as mock_path:
|
||||
mock_file_path = Mock(spec=Path)
|
||||
mock_file_path.exists.return_value = False # File doesn't exist
|
||||
mock_path.return_value = mock_file_path
|
||||
|
||||
# This should fail because file doesn't exist
|
||||
result = asyncio.run(player_service.play(0))
|
||||
# Verify the utility was called
|
||||
mock_path.assert_called_once_with(sound)
|
||||
|
||||
|
||||
class TestPlayerServiceGlobalFunctions:
|
||||
"""Test global player service functions."""
|
||||
|
||||
511
tests/services/test_vlc_player.py
Normal file
511
tests/services/test_vlc_player.py
Normal file
@@ -0,0 +1,511 @@
|
||||
"""Tests for VLC player service."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.sound import Sound
|
||||
from app.models.sound_played import SoundPlayed
|
||||
from app.models.user import User
|
||||
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||
from app.utils.audio import get_sound_file_path
|
||||
|
||||
|
||||
class TestVLCPlayerService:
|
||||
"""Test VLC player service."""
|
||||
|
||||
@pytest.fixture
|
||||
def vlc_service(self):
|
||||
"""Create a VLC service instance."""
|
||||
with patch("app.services.vlc_player.subprocess.run") as mock_run:
|
||||
# Mock VLC executable detection
|
||||
mock_run.return_value.returncode = 0
|
||||
return VLCPlayerService()
|
||||
|
||||
@pytest.fixture
|
||||
def vlc_service_with_db(self):
|
||||
"""Create a VLC service instance with database session factory."""
|
||||
with patch("app.services.vlc_player.subprocess.run") as mock_run:
|
||||
# Mock VLC executable detection
|
||||
mock_run.return_value.returncode = 0
|
||||
mock_session_factory = Mock()
|
||||
return VLCPlayerService(mock_session_factory)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sound(self):
|
||||
"""Create a sample sound for testing."""
|
||||
return Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test_audio.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
is_normalized=False,
|
||||
normalized_filename=None,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def normalized_sound(self):
|
||||
"""Create a normalized sound for testing."""
|
||||
return Sound(
|
||||
id=2,
|
||||
type="TTS",
|
||||
name="Normalized Sound",
|
||||
filename="original.wav",
|
||||
duration=7500,
|
||||
size=2048,
|
||||
hash="normalized_hash",
|
||||
is_normalized=True,
|
||||
normalized_filename="normalized.mp3",
|
||||
)
|
||||
|
||||
def test_init(self, vlc_service):
|
||||
"""Test VLC service initialization."""
|
||||
assert vlc_service.vlc_executable is not None
|
||||
assert isinstance(vlc_service.vlc_executable, str)
|
||||
|
||||
@patch("app.services.vlc_player.subprocess.run")
|
||||
def test_find_vlc_executable_found_in_path(self, mock_run):
|
||||
"""Test VLC executable detection when found in PATH."""
|
||||
mock_run.return_value.returncode = 0
|
||||
service = VLCPlayerService()
|
||||
assert service.vlc_executable == "vlc"
|
||||
|
||||
@patch("app.services.vlc_player.subprocess.run")
|
||||
def test_find_vlc_executable_found_by_path(self, mock_run):
|
||||
"""Test VLC executable detection when found by absolute path."""
|
||||
mock_run.return_value.returncode = 1 # which command fails
|
||||
|
||||
# Mock Path to return True for the first absolute path
|
||||
with patch("app.services.vlc_player.Path") as mock_path:
|
||||
def path_side_effect(path_str):
|
||||
mock_instance = Mock()
|
||||
mock_instance.exists.return_value = str(path_str) == "/usr/bin/vlc"
|
||||
return mock_instance
|
||||
|
||||
mock_path.side_effect = path_side_effect
|
||||
|
||||
service = VLCPlayerService()
|
||||
assert service.vlc_executable == "/usr/bin/vlc"
|
||||
|
||||
@patch("app.services.vlc_player.subprocess.run")
|
||||
@patch("app.services.vlc_player.Path")
|
||||
def test_find_vlc_executable_fallback(self, mock_path, mock_run):
|
||||
"""Test VLC executable detection fallback to default."""
|
||||
# Mock all paths as non-existent
|
||||
mock_path_instance = Mock()
|
||||
mock_path_instance.exists.return_value = False
|
||||
mock_path.return_value = mock_path_instance
|
||||
|
||||
# Mock which command as failing
|
||||
mock_run.return_value.returncode = 1
|
||||
|
||||
service = VLCPlayerService()
|
||||
assert service.vlc_executable == "vlc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
|
||||
async def test_play_sound_success(
|
||||
self, mock_subprocess, vlc_service, sample_sound
|
||||
):
|
||||
"""Test successful sound playback."""
|
||||
# Mock subprocess
|
||||
mock_process = Mock()
|
||||
mock_process.pid = 12345
|
||||
mock_subprocess.return_value = mock_process
|
||||
|
||||
# Mock the file path utility to avoid Path issues
|
||||
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
|
||||
mock_path = Mock()
|
||||
mock_path.exists.return_value = True
|
||||
mock_get_path.return_value = mock_path
|
||||
|
||||
result = await vlc_service.play_sound(sample_sound)
|
||||
|
||||
assert result is True
|
||||
mock_subprocess.assert_called_once()
|
||||
args = mock_subprocess.call_args
|
||||
|
||||
# Check command arguments
|
||||
cmd_args = args[1] # keyword arguments
|
||||
assert "--play-and-exit" in args[0]
|
||||
assert "--intf" in args[0]
|
||||
assert "dummy" in args[0]
|
||||
assert "--no-video" in args[0]
|
||||
assert "--no-repeat" in args[0]
|
||||
assert "--no-loop" in args[0]
|
||||
assert cmd_args["stdout"] == asyncio.subprocess.DEVNULL
|
||||
assert cmd_args["stderr"] == asyncio.subprocess.DEVNULL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_file_not_found(
|
||||
self, vlc_service, sample_sound
|
||||
):
|
||||
"""Test sound playback when file doesn't exist."""
|
||||
# Mock the file path utility to return a non-existent path
|
||||
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
|
||||
mock_path = Mock()
|
||||
mock_path.exists.return_value = False
|
||||
mock_get_path.return_value = mock_path
|
||||
|
||||
result = await vlc_service.play_sound(sample_sound)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
|
||||
async def test_play_sound_subprocess_error(
|
||||
self, mock_subprocess, vlc_service, sample_sound
|
||||
):
|
||||
"""Test sound playback when subprocess fails."""
|
||||
# Mock the file path utility to return an existing path
|
||||
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
|
||||
mock_path = Mock()
|
||||
mock_path.exists.return_value = True
|
||||
mock_get_path.return_value = mock_path
|
||||
|
||||
# Mock subprocess exception
|
||||
mock_subprocess.side_effect = Exception("Subprocess failed")
|
||||
|
||||
result = await vlc_service.play_sound(sample_sound)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
|
||||
async def test_stop_all_vlc_instances_success(self, mock_subprocess, vlc_service):
|
||||
"""Test successful stopping of all VLC instances."""
|
||||
# Mock pgrep process (find VLC processes)
|
||||
mock_find_process = Mock()
|
||||
mock_find_process.returncode = 0
|
||||
mock_find_process.communicate = AsyncMock(
|
||||
return_value=(b"12345\n67890\n", b"")
|
||||
)
|
||||
|
||||
# Mock pkill process (kill VLC processes)
|
||||
mock_kill_process = Mock()
|
||||
mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
|
||||
|
||||
# Mock verify process (check remaining processes)
|
||||
mock_verify_process = Mock()
|
||||
mock_verify_process.returncode = 1 # No processes found
|
||||
mock_verify_process.communicate = AsyncMock(return_value=(b"", b""))
|
||||
|
||||
# Set up subprocess mock to return different processes for each call
|
||||
mock_subprocess.side_effect = [
|
||||
mock_find_process,
|
||||
mock_kill_process,
|
||||
mock_verify_process,
|
||||
]
|
||||
|
||||
result = await vlc_service.stop_all_vlc_instances()
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["processes_found"] == 2
|
||||
assert result["processes_killed"] == 2
|
||||
assert result["processes_remaining"] == 0
|
||||
assert "Killed 2 VLC processes" in result["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
|
||||
async def test_stop_all_vlc_instances_no_processes(
|
||||
self, mock_subprocess, vlc_service
|
||||
):
|
||||
"""Test stopping VLC instances when none are running."""
|
||||
# Mock pgrep process (no VLC processes found)
|
||||
mock_find_process = Mock()
|
||||
mock_find_process.returncode = 1 # No processes found
|
||||
mock_find_process.communicate = AsyncMock(return_value=(b"", b""))
|
||||
|
||||
mock_subprocess.return_value = mock_find_process
|
||||
|
||||
result = await vlc_service.stop_all_vlc_instances()
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["processes_found"] == 0
|
||||
assert result["processes_killed"] == 0
|
||||
assert result["message"] == "No VLC processes found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
|
||||
async def test_stop_all_vlc_instances_partial_kill(
|
||||
self, mock_subprocess, vlc_service
|
||||
):
|
||||
"""Test stopping VLC instances when some processes remain."""
|
||||
# Mock pgrep process (find VLC processes)
|
||||
mock_find_process = Mock()
|
||||
mock_find_process.returncode = 0
|
||||
mock_find_process.communicate = AsyncMock(
|
||||
return_value=(b"12345\n67890\n11111\n", b"")
|
||||
)
|
||||
|
||||
# Mock pkill process (kill VLC processes)
|
||||
mock_kill_process = Mock()
|
||||
mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
|
||||
|
||||
# Mock verify process (one process remains)
|
||||
mock_verify_process = Mock()
|
||||
mock_verify_process.returncode = 0
|
||||
mock_verify_process.communicate = AsyncMock(return_value=(b"11111\n", b""))
|
||||
|
||||
mock_subprocess.side_effect = [
|
||||
mock_find_process,
|
||||
mock_kill_process,
|
||||
mock_verify_process,
|
||||
]
|
||||
|
||||
result = await vlc_service.stop_all_vlc_instances()
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["processes_found"] == 3
|
||||
assert result["processes_killed"] == 2
|
||||
assert result["processes_remaining"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
|
||||
async def test_stop_all_vlc_instances_error(self, mock_subprocess, vlc_service):
|
||||
"""Test stopping VLC instances when an error occurs."""
|
||||
# Mock subprocess exception
|
||||
mock_subprocess.side_effect = Exception("Command failed")
|
||||
|
||||
result = await vlc_service.stop_all_vlc_instances()
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["processes_found"] == 0
|
||||
assert result["processes_killed"] == 0
|
||||
assert "error" in result
|
||||
assert result["message"] == "Failed to stop VLC processes"
|
||||
|
||||
def test_get_vlc_player_service_singleton(self):
|
||||
"""Test that get_vlc_player_service returns the same instance."""
|
||||
with patch("app.services.vlc_player.VLCPlayerService") as mock_service_class:
|
||||
mock_instance = Mock()
|
||||
mock_service_class.return_value = mock_instance
|
||||
|
||||
# Clear the global instance
|
||||
import app.services.vlc_player
|
||||
app.services.vlc_player.vlc_player_service = None
|
||||
|
||||
# First call should create new instance
|
||||
service1 = get_vlc_player_service()
|
||||
assert service1 == mock_instance
|
||||
mock_service_class.assert_called_once()
|
||||
|
||||
# Second call should return same instance
|
||||
service2 = get_vlc_player_service()
|
||||
assert service2 == mock_instance
|
||||
assert service1 is service2
|
||||
# Constructor should not be called again
|
||||
mock_service_class.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
|
||||
async def test_play_sound_with_play_count_tracking(
|
||||
self, mock_subprocess, vlc_service_with_db, sample_sound
|
||||
):
|
||||
"""Test sound playback with play count tracking."""
|
||||
# Mock subprocess
|
||||
mock_process = Mock()
|
||||
mock_process.pid = 12345
|
||||
mock_subprocess.return_value = mock_process
|
||||
|
||||
# Mock session and repositories
|
||||
mock_session = AsyncMock()
|
||||
vlc_service_with_db.db_session_factory.return_value = mock_session
|
||||
|
||||
# Mock repositories
|
||||
mock_sound_repo = AsyncMock()
|
||||
mock_user_repo = AsyncMock()
|
||||
|
||||
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
|
||||
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
|
||||
with patch("app.services.vlc_player.socket_manager") as mock_socket:
|
||||
with patch("app.services.vlc_player.select") as mock_select:
|
||||
# Mock the file path utility
|
||||
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
|
||||
mock_path = Mock()
|
||||
mock_path.exists.return_value = True
|
||||
mock_get_path.return_value = mock_path
|
||||
|
||||
# Mock sound repository responses
|
||||
updated_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
play_count=1, # Updated count
|
||||
)
|
||||
mock_sound_repo.get_by_id.return_value = sample_sound
|
||||
mock_sound_repo.update.return_value = updated_sound
|
||||
|
||||
# Mock admin user
|
||||
admin_user = User(
|
||||
id=1,
|
||||
email="admin@test.com",
|
||||
name="Admin User",
|
||||
role="admin",
|
||||
)
|
||||
mock_user_repo.get_by_id.return_value = admin_user
|
||||
|
||||
# Mock socket broadcast
|
||||
mock_socket.broadcast_to_all = AsyncMock()
|
||||
|
||||
result = await vlc_service_with_db.play_sound(sample_sound)
|
||||
|
||||
# Wait a bit for the async task to complete
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify subprocess was called
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Note: The async task runs in the background, so we can't easily
|
||||
# verify the database operations in this test without more complex
|
||||
# mocking or using a real async test framework setup
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_play_count_success(self, vlc_service_with_db):
|
||||
"""Test successful play count recording."""
|
||||
# Mock session and repositories
|
||||
mock_session = AsyncMock()
|
||||
vlc_service_with_db.db_session_factory.return_value = mock_session
|
||||
|
||||
mock_sound_repo = AsyncMock()
|
||||
mock_user_repo = AsyncMock()
|
||||
|
||||
# Create test sound and user
|
||||
test_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
play_count=0,
|
||||
)
|
||||
admin_user = User(
|
||||
id=1,
|
||||
email="admin@test.com",
|
||||
name="Admin User",
|
||||
role="admin",
|
||||
)
|
||||
|
||||
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
|
||||
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
|
||||
with patch("app.services.vlc_player.socket_manager") as mock_socket:
|
||||
with patch("app.services.vlc_player.select") as mock_select:
|
||||
# Setup mocks
|
||||
mock_sound_repo.get_by_id.return_value = test_sound
|
||||
mock_user_repo.get_by_id.return_value = admin_user
|
||||
|
||||
# Mock socket broadcast
|
||||
mock_socket.broadcast_to_all = AsyncMock()
|
||||
|
||||
await vlc_service_with_db._record_play_count(1, "Test Sound")
|
||||
|
||||
# Verify sound repository calls
|
||||
mock_sound_repo.get_by_id.assert_called_once_with(1)
|
||||
mock_sound_repo.update.assert_called_once_with(
|
||||
test_sound, {"play_count": 1}
|
||||
)
|
||||
|
||||
# Verify user repository calls
|
||||
mock_user_repo.get_by_id.assert_called_once_with(1)
|
||||
|
||||
# Verify session operations
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
# Verify socket broadcast
|
||||
mock_socket.broadcast_to_all.assert_called_once_with(
|
||||
"sound_played",
|
||||
{
|
||||
"sound_id": 1,
|
||||
"sound_name": "Test Sound",
|
||||
"user_id": 1,
|
||||
"play_count": 1,
|
||||
},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_play_count_no_session_factory(self, vlc_service):
|
||||
"""Test play count recording when no session factory is available."""
|
||||
# This should not raise an error and should log a warning
|
||||
await vlc_service._record_play_count(1, "Test Sound")
|
||||
# The method should return early without doing anything
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_play_count_always_creates_record(self, vlc_service_with_db):
|
||||
"""Test play count recording always creates a new SoundPlayed record."""
|
||||
# Mock session and repositories
|
||||
mock_session = AsyncMock()
|
||||
vlc_service_with_db.db_session_factory.return_value = mock_session
|
||||
|
||||
mock_sound_repo = AsyncMock()
|
||||
mock_user_repo = AsyncMock()
|
||||
|
||||
# Create test sound and user
|
||||
test_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
play_count=5,
|
||||
)
|
||||
admin_user = User(
|
||||
id=1,
|
||||
email="admin@test.com",
|
||||
name="Admin User",
|
||||
role="admin",
|
||||
)
|
||||
|
||||
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
|
||||
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
|
||||
with patch("app.services.vlc_player.socket_manager") as mock_socket:
|
||||
# Setup mocks
|
||||
mock_sound_repo.get_by_id.return_value = test_sound
|
||||
mock_user_repo.get_by_id.return_value = admin_user
|
||||
|
||||
# Mock socket broadcast
|
||||
mock_socket.broadcast_to_all = AsyncMock()
|
||||
|
||||
await vlc_service_with_db._record_play_count(1, "Test Sound")
|
||||
|
||||
# Verify sound play count was updated
|
||||
mock_sound_repo.update.assert_called_once_with(
|
||||
test_sound, {"play_count": 6}
|
||||
)
|
||||
|
||||
# Verify new SoundPlayed record was always added
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
# Verify commit happened
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_uses_shared_sound_path_utility(self, vlc_service, sample_sound):
|
||||
"""Test that VLC service uses the shared sound path utility."""
|
||||
with patch("app.services.vlc_player.get_sound_file_path") as mock_path:
|
||||
mock_file_path = Mock(spec=Path)
|
||||
mock_file_path.exists.return_value = False # File doesn't exist
|
||||
mock_path.return_value = mock_file_path
|
||||
|
||||
# This should fail because file doesn't exist
|
||||
result = asyncio.run(vlc_service.play_sound(sample_sound))
|
||||
|
||||
# Verify the utility was called and returned False
|
||||
mock_path.assert_called_once_with(sample_sound)
|
||||
assert result is False
|
||||
@@ -7,7 +7,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
from app.models.sound import Sound
|
||||
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size, get_sound_file_path
|
||||
|
||||
|
||||
class TestAudioUtils:
|
||||
@@ -290,3 +291,120 @@ class TestAudioUtils:
|
||||
# Should raise FileNotFoundError for nonexistent file
|
||||
with pytest.raises(FileNotFoundError):
|
||||
get_file_size(nonexistent_path)
|
||||
|
||||
def test_get_sound_file_path_sdb_original(self):
|
||||
"""Test getting sound file path for SDB type original file."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
type="SDB",
|
||||
is_normalized=False,
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
expected = Path("sounds/originals/soundboard/test.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_sdb_normalized(self):
|
||||
"""Test getting sound file path for SDB type normalized file."""
|
||||
sound = Sound(
|
||||
id=1,
|
||||
name="Test Sound",
|
||||
filename="original.mp3",
|
||||
normalized_filename="normalized.mp3",
|
||||
type="SDB",
|
||||
is_normalized=True,
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
expected = Path("sounds/normalized/soundboard/normalized.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_tts_original(self):
|
||||
"""Test getting sound file path for TTS type original file."""
|
||||
sound = Sound(
|
||||
id=2,
|
||||
name="TTS Sound",
|
||||
filename="tts_file.wav",
|
||||
type="TTS",
|
||||
is_normalized=False,
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
expected = Path("sounds/originals/text_to_speech/tts_file.wav")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_tts_normalized(self):
|
||||
"""Test getting sound file path for TTS type normalized file."""
|
||||
sound = Sound(
|
||||
id=2,
|
||||
name="TTS Sound",
|
||||
filename="original.wav",
|
||||
normalized_filename="normalized.mp3",
|
||||
type="TTS",
|
||||
is_normalized=True,
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
expected = Path("sounds/normalized/text_to_speech/normalized.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_ext_original(self):
|
||||
"""Test getting sound file path for EXT type original file."""
|
||||
sound = Sound(
|
||||
id=3,
|
||||
name="Extracted Sound",
|
||||
filename="extracted.mp3",
|
||||
type="EXT",
|
||||
is_normalized=False,
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
expected = Path("sounds/originals/extracted/extracted.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_ext_normalized(self):
|
||||
"""Test getting sound file path for EXT type normalized file."""
|
||||
sound = Sound(
|
||||
id=3,
|
||||
name="Extracted Sound",
|
||||
filename="original.mp3",
|
||||
normalized_filename="normalized.mp3",
|
||||
type="EXT",
|
||||
is_normalized=True,
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
expected = Path("sounds/normalized/extracted/normalized.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_unknown_type_fallback(self):
|
||||
"""Test getting sound file path for unknown type falls back to lowercase."""
|
||||
sound = Sound(
|
||||
id=4,
|
||||
name="Unknown Type Sound",
|
||||
filename="unknown.mp3",
|
||||
type="CUSTOM",
|
||||
is_normalized=False,
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
expected = Path("sounds/originals/custom/unknown.mp3")
|
||||
assert result == expected
|
||||
|
||||
def test_get_sound_file_path_normalized_without_filename(self):
|
||||
"""Test getting sound file path when normalized but no normalized_filename."""
|
||||
sound = Sound(
|
||||
id=5,
|
||||
name="Test Sound",
|
||||
filename="original.mp3",
|
||||
normalized_filename=None,
|
||||
type="SDB",
|
||||
is_normalized=True, # True but no normalized_filename
|
||||
)
|
||||
|
||||
result = get_sound_file_path(sound)
|
||||
# Should fall back to original file
|
||||
expected = Path("sounds/originals/soundboard/original.mp3")
|
||||
assert result == expected
|
||||
|
||||
277
tests/utils/test_credit_decorators.py
Normal file
277
tests/utils/test_credit_decorators.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Tests for credit decorators."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only
|
||||
|
||||
|
||||
class TestRequiresCreditsDecorator:
|
||||
"""Test requires_credits decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
service.deduct_credits = AsyncMock()
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service_factory(self, mock_credit_service):
|
||||
"""Create a credit service factory."""
|
||||
return lambda: mock_credit_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_success(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with successful action."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int, message: str) -> str:
|
||||
return f"Success: {message}"
|
||||
|
||||
result = await test_action(user_id=123, message="test")
|
||||
|
||||
assert result == "Success: test"
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, None
|
||||
)
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, True, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_with_metadata(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with metadata extraction."""
|
||||
|
||||
def extract_metadata(user_id: int, sound_name: str) -> dict:
|
||||
return {"sound_name": sound_name}
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id",
|
||||
metadata_extractor=extract_metadata
|
||||
)
|
||||
async def test_action(user_id: int, sound_name: str) -> bool:
|
||||
return True
|
||||
|
||||
await test_action(user_id=123, sound_name="test.mp3")
|
||||
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"}
|
||||
)
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with failed action."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int) -> bool:
|
||||
return False # Action fails
|
||||
|
||||
result = await test_action(user_id=123)
|
||||
|
||||
assert result is False
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator when action raises exception."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int) -> str:
|
||||
raise ValueError("Test error")
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await test_action(user_id=123)
|
||||
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator with insufficient credits."""
|
||||
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int) -> str:
|
||||
return "Should not execute"
|
||||
|
||||
with pytest.raises(InsufficientCreditsError):
|
||||
await test_action(user_id=123)
|
||||
|
||||
# Should not call deduct_credits since validation failed
|
||||
mock_credit_service.deduct_credits.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service):
|
||||
"""Test decorator extracting user_id from positional args."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int, message: str) -> str:
|
||||
return message
|
||||
|
||||
result = await test_action(123, "test")
|
||||
|
||||
assert result == "test"
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorator_missing_user_id(self, credit_service_factory):
|
||||
"""Test decorator when user_id cannot be extracted."""
|
||||
|
||||
@requires_credits(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(other_param: str) -> str:
|
||||
return other_param
|
||||
|
||||
with pytest.raises(ValueError, match="Could not extract user_id"):
|
||||
await test_action(other_param="test")
|
||||
|
||||
|
||||
class TestValidateCreditsOnlyDecorator:
|
||||
"""Test validate_credits_only decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def credit_service_factory(self, mock_credit_service):
|
||||
"""Create a credit service factory."""
|
||||
return lambda: mock_credit_service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service):
|
||||
"""Test validate_credits_only decorator."""
|
||||
|
||||
@validate_credits_only(
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
credit_service_factory,
|
||||
user_id_param="user_id"
|
||||
)
|
||||
async def test_action(user_id: int, message: str) -> str:
|
||||
return f"Validated: {message}"
|
||||
|
||||
result = await test_action(user_id=123, message="test")
|
||||
|
||||
assert result == "Validated: test"
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND
|
||||
)
|
||||
# Should not deduct credits, only validate
|
||||
mock_credit_service.deduct_credits.assert_not_called()
|
||||
|
||||
|
||||
class TestCreditManager:
|
||||
"""Test CreditManager context manager."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
"""Create a mock credit service."""
|
||||
service = AsyncMock(spec=CreditService)
|
||||
service.validate_and_reserve_credits = AsyncMock()
|
||||
service.deduct_credits = AsyncMock()
|
||||
return service
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_success(self, mock_credit_service):
|
||||
"""Test CreditManager with successful operation."""
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND,
|
||||
{"test": "data"}
|
||||
) as manager:
|
||||
manager.mark_success()
|
||||
|
||||
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"}
|
||||
)
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_failure(self, mock_credit_service):
|
||||
"""Test CreditManager with failed operation."""
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND
|
||||
):
|
||||
# Don't mark as success - should be considered failed
|
||||
pass
|
||||
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_exception(self, mock_credit_service):
|
||||
"""Test CreditManager when exception occurs."""
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND
|
||||
):
|
||||
raise ValueError("Test error")
|
||||
|
||||
mock_credit_service.deduct_credits.assert_called_once_with(
|
||||
123, CreditActionType.VLC_PLAY_SOUND, False, None
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_manager_validation_failure(self, mock_credit_service):
|
||||
"""Test CreditManager when validation fails."""
|
||||
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
|
||||
|
||||
with pytest.raises(InsufficientCreditsError):
|
||||
async with CreditManager(
|
||||
mock_credit_service,
|
||||
123,
|
||||
CreditActionType.VLC_PLAY_SOUND
|
||||
):
|
||||
pass
|
||||
|
||||
# Should not call deduct_credits since validation failed
|
||||
mock_credit_service.deduct_credits.assert_not_called()
|
||||
Reference in New Issue
Block a user