Compare commits

...

2 Commits

Author SHA1 Message Date
JSC
e43650c26c Add tests for sound repository, user OAuth repository, credit service, and credit decorators
- Implement comprehensive tests for SoundRepository covering CRUD operations and search functionalities.
- Create tests for UserOauthRepository to validate OAuth record management.
- Develop tests for CreditService to ensure proper credit management, including validation, deduction, and addition of credits.
- Add tests for credit-related decorators to verify correct behavior in credit management scenarios.
2025-07-30 21:33:55 +02:00
JSC
dd10ef5d41 feat: Add VLC player API endpoints and associated tests
- Implemented VLC player API endpoints for playing and stopping sounds.
- Added tests for successful playback, error handling, and authentication scenarios.
- Created utility function to get sound file paths based on sound properties.
- Refactored player service to utilize shared sound path utility.
- Enhanced test coverage for sound file path utility with various sound types.
- Introduced tests for VLC player service, including subprocess handling and play count tracking.
2025-07-30 20:46:49 +02:00
22 changed files with 4104 additions and 134 deletions

View File

@@ -7,11 +7,15 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService
from app.services.vlc_player import get_vlc_player_service, VLCPlayerService
router = APIRouter(prefix="/sounds", tags=["sounds"])
@@ -37,6 +41,25 @@ async def get_extraction_service(
return ExtractionService(session)
def get_vlc_player() -> VLCPlayerService:
"""Get the VLC player service."""
from app.core.database import get_session_factory
return get_vlc_player_service(get_session_factory())
def get_credit_service() -> CreditService:
"""Get the credit service."""
from app.core.database import get_session_factory
return CreditService(get_session_factory())
async def get_sound_repository(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundRepository:
"""Get the sound repository."""
return SoundRepository(session)
# SCAN
@router.post("/scan")
async def scan_sounds(
@@ -349,3 +372,87 @@ async def get_user_extractions(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
# VLC PLAYER
@router.post("/vlc/play/{sound_id}")
async def play_sound_with_vlc(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
credit_service: Annotated[CreditService, Depends(get_credit_service)],
) -> dict[str, str | int | bool]:
"""Play a sound using VLC subprocess (requires 1 credit)."""
try:
# Get the sound
sound = await sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Sound with ID {sound_id} not found",
)
# Check and validate credits before playing
try:
await credit_service.validate_and_reserve_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
{"sound_id": sound_id, "sound_name": sound.name}
)
except InsufficientCreditsError as e:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Insufficient credits: {e.required} required, {e.available} available",
) from e
# Play the sound using VLC
success = await vlc_player.play_sound(sound)
# Deduct credits based on success
await credit_service.deduct_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
success,
{"sound_id": sound_id, "sound_name": sound.name},
)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to launch VLC for sound playback",
)
return {
"message": f"Sound '{sound.name}' is now playing via VLC",
"sound_id": sound_id,
"sound_name": sound.name,
"success": True,
"credits_deducted": 1,
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to play sound: {e!s}",
) from e
@router.post("/vlc/stop-all")
async def stop_all_vlc_instances(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
vlc_player: Annotated[VLCPlayerService, Depends(get_vlc_player)],
) -> dict:
"""Stop all running VLC instances."""
try:
result = await vlc_player.stop_all_vlc_instances()
return result
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to stop VLC instances: {e!s}",
) from e

121
app/models/credit_action.py Normal file
View File

@@ -0,0 +1,121 @@
"""Credit action definitions for the credit system."""
from enum import Enum
from typing import Any
class CreditActionType(str, Enum):
"""Types of actions that consume credits."""
VLC_PLAY_SOUND = "vlc_play_sound"
AUDIO_EXTRACTION = "audio_extraction"
TEXT_TO_SPEECH = "text_to_speech"
SOUND_NORMALIZATION = "sound_normalization"
API_REQUEST = "api_request"
PLAYLIST_CREATION = "playlist_creation"
class CreditAction:
"""Definition of a credit-consuming action."""
def __init__(
self,
action_type: CreditActionType,
cost: int,
description: str,
*,
requires_success: bool = True,
) -> None:
"""Initialize a credit action.
Args:
action_type: The type of action
cost: Number of credits required
description: Human-readable description
requires_success: Whether credits are only deducted on successful completion
"""
self.action_type = action_type
self.cost = cost
self.description = description
self.requires_success = requires_success
def __str__(self) -> str:
"""Return string representation of the action."""
return f"{self.action_type.value} ({self.cost} credits)"
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"action_type": self.action_type.value,
"cost": self.cost,
"description": self.description,
"requires_success": self.requires_success,
}
# Predefined credit actions
CREDIT_ACTIONS = {
CreditActionType.VLC_PLAY_SOUND: CreditAction(
action_type=CreditActionType.VLC_PLAY_SOUND,
cost=1,
description="Play a sound using VLC player",
requires_success=True,
),
CreditActionType.AUDIO_EXTRACTION: CreditAction(
action_type=CreditActionType.AUDIO_EXTRACTION,
cost=5,
description="Extract audio from external URL",
requires_success=True,
),
CreditActionType.TEXT_TO_SPEECH: CreditAction(
action_type=CreditActionType.TEXT_TO_SPEECH,
cost=2,
description="Generate speech from text",
requires_success=True,
),
CreditActionType.SOUND_NORMALIZATION: CreditAction(
action_type=CreditActionType.SOUND_NORMALIZATION,
cost=1,
description="Normalize audio levels",
requires_success=True,
),
CreditActionType.API_REQUEST: CreditAction(
action_type=CreditActionType.API_REQUEST,
cost=1,
description="API request (rate limiting)",
requires_success=False, # Charged even if request fails
),
CreditActionType.PLAYLIST_CREATION: CreditAction(
action_type=CreditActionType.PLAYLIST_CREATION,
cost=3,
description="Create a new playlist",
requires_success=True,
),
}
def get_credit_action(action_type: CreditActionType) -> CreditAction:
"""Get a credit action definition by type.
Args:
action_type: The action type to look up
Returns:
The credit action definition
Raises:
KeyError: If action type is not found
"""
return CREDIT_ACTIONS[action_type]
def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
"""Get all available credit actions.
Returns:
Dictionary of all credit actions
"""
return CREDIT_ACTIONS.copy()

View File

@@ -0,0 +1,29 @@
"""Credit transaction model for tracking credit usage."""
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.user import User
class CreditTransaction(BaseModel, table=True):
"""Database model for credit transactions."""
__tablename__ = "credit_transaction" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="user.id", nullable=False)
action_type: str = Field(nullable=False)
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
balance_before: int = Field(nullable=False)
balance_after: int = Field(nullable=False)
description: str = Field(nullable=False)
success: bool = Field(nullable=False, default=True)
# JSON string for additional data
metadata_json: str | None = Field(default=None)
# relationships
user: "User" = Relationship(back_populates="credit_transactions")

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel import Field, Relationship
from app.models.base import BaseModel
@@ -14,18 +14,9 @@ class SoundPlayed(BaseModel, table=True):
__tablename__ = "sound_played" # pyright: ignore[reportAssignmentType]
user_id: int = Field(foreign_key="user.id", nullable=False)
user_id: int | None = Field(foreign_key="user.id", nullable=True)
sound_id: int = Field(foreign_key="sound.id", nullable=False)
# constraints
__table_args__ = (
UniqueConstraint(
"user_id",
"sound_id",
name="uq_sound_played_user_sound",
),
)
# relationships
user: "User" = Relationship(back_populates="sounds_played")
sound: "Sound" = Relationship(back_populates="play_history")

View File

@@ -6,6 +6,7 @@ from sqlmodel import Field, Relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.credit_transaction import CreditTransaction
from app.models.extraction import Extraction
from app.models.plan import Plan
from app.models.playlist import Playlist
@@ -35,3 +36,4 @@ class User(BaseModel, table=True):
playlists: list["Playlist"] = Relationship(back_populates="user")
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
extractions: list["Extraction"] = Relationship(back_populates="user")
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")

132
app/repositories/base.py Normal file
View File

@@ -0,0 +1,132 @@
"""Base repository with common CRUD operations."""
from typing import Any, Generic, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
# Type variable for the model
ModelType = TypeVar("ModelType")
logger = get_logger(__name__)
class BaseRepository(Generic[ModelType]):
"""Base repository with common CRUD operations."""
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
"""Initialize the repository.
Args:
model: The SQLModel class
session: Database session
"""
self.model = model
self.session = session
async def get_by_id(self, entity_id: int) -> ModelType | None:
"""Get an entity by ID.
Args:
entity_id: The entity ID
Returns:
The entity if found, None otherwise
"""
try:
statement = select(self.model).where(getattr(self.model, "id") == entity_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
raise
async def get_all(
self,
limit: int = 100,
offset: int = 0,
) -> list[ModelType]:
"""Get all entities with pagination.
Args:
limit: Maximum number of entities to return
offset: Number of entities to skip
Returns:
List of entities
"""
try:
statement = select(self.model).limit(limit).offset(offset)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all %s", self.model.__name__)
raise
async def create(self, entity_data: dict[str, Any]) -> ModelType:
"""Create a new entity.
Args:
entity_data: Dictionary of entity data
Returns:
The created entity
"""
try:
entity = self.model(**entity_data)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to create %s", self.model.__name__)
raise
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
"""Update an entity.
Args:
entity: The entity to update
update_data: Dictionary of fields to update
Returns:
The updated entity
"""
try:
for field, value in update_data.items():
setattr(entity, field, value)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
return entity
except Exception:
await self.session.rollback()
logger.exception("Failed to update %s", self.model.__name__)
raise
async def delete(self, entity: ModelType) -> None:
"""Delete an entity.
Args:
entity: The entity to delete
"""
try:
await self.session.delete(entity)
await self.session.commit()
logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
except Exception:
await self.session.rollback()
logger.exception("Failed to delete %s", self.model.__name__)
raise

View File

@@ -0,0 +1,108 @@
"""Repository for credit transaction database operations."""
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_transaction import CreditTransaction
from app.repositories.base import BaseRepository
class CreditTransactionRepository(BaseRepository[CreditTransaction]):
"""Repository for credit transaction operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the repository.
Args:
session: Database session
"""
super().__init__(CreditTransaction, session)
async def get_by_user_id(
self,
user_id: int,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions for a user.
Args:
user_id: The user ID
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.user_id == user_id)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_by_action_type(
self,
action_type: str,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get credit transactions by action type.
Args:
action_type: The action type to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of credit transactions ordered by creation date (newest first)
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.action_type == action_type)
.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())
async def get_successful_transactions(
self,
user_id: int | None = None,
limit: int = 50,
offset: int = 0,
) -> list[CreditTransaction]:
"""Get successful credit transactions.
Args:
user_id: Optional user ID to filter by
limit: Maximum number of transactions to return
offset: Number of transactions to skip
Returns:
List of successful credit transactions
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.success == True) # noqa: E712
)
if user_id is not None:
stmt = stmt.where(CreditTransaction.user_id == user_id)
stmt = (
stmt.order_by(CreditTransaction.created_at.desc())
.limit(limit)
.offset(offset)
)
result = await self.session.exec(stmt)
return list(result.all())

383
app/services/credit.py Normal file
View File

@@ -0,0 +1,383 @@
"""Credit management service for tracking and validating user credit usage."""
import json
from collections.abc import Callable
from typing import Any
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.credit_action import CreditAction, CreditActionType, get_credit_action
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
logger = get_logger(__name__)
class InsufficientCreditsError(Exception):
"""Raised when user has insufficient credits for an action."""
def __init__(self, required: int, available: int) -> None:
"""Initialize the error.
Args:
required: Number of credits required
available: Number of credits available
"""
self.required = required
self.available = available
super().__init__(
f"Insufficient credits: {required} required, {available} available"
)
class CreditService:
"""Service for managing user credits and transactions."""
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
"""Initialize the credit service.
Args:
db_session_factory: Factory function to create database sessions
"""
self.db_session_factory = db_session_factory
async def check_credits(
self,
user_id: int,
action_type: CreditActionType,
) -> bool:
"""Check if user has sufficient credits for an action.
Args:
user_id: The user ID
action_type: The type of action to check
Returns:
True if user has sufficient credits, False otherwise
"""
action = get_credit_action(action_type)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
return False
return user.credits >= action.cost
finally:
await session.close()
async def validate_and_reserve_credits(
self,
user_id: int,
action_type: CreditActionType,
metadata: dict[str, Any] | None = None,
) -> tuple[User, CreditAction]:
"""Validate user has sufficient credits and optionally reserve them.
Args:
user_id: The user ID
action_type: The type of action
metadata: Optional metadata to store with transaction
Returns:
Tuple of (user, credit_action)
Raises:
InsufficientCreditsError: If user has insufficient credits
"""
action = get_credit_action(action_type)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if user.credits < action.cost:
raise InsufficientCreditsError(action.cost, user.credits)
logger.info(
"Credits validated for user %s: %s credits available, %s required",
user_id,
user.credits,
action.cost,
)
return user, action
finally:
await session.close()
async def deduct_credits(
self,
user_id: int,
action_type: CreditActionType,
success: bool = True,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Deduct credits from user account and record transaction.
Args:
user_id: The user ID
action_type: The type of action
success: Whether the action was successful
metadata: Optional metadata to store with transaction
Returns:
The created credit transaction
Raises:
InsufficientCreditsError: If user has insufficient credits
ValueError: If user not found
"""
action = get_credit_action(action_type)
# Only deduct if action requires success and was successful, or doesn't require success
should_deduct = (action.requires_success and success) or not action.requires_success
if not should_deduct:
logger.info(
"Skipping credit deduction for user %s: action %s failed and requires success",
user_id,
action_type.value,
)
# Still create a transaction record for auditing
return await self._create_transaction_record(
user_id, action, 0, success, metadata
)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
if user.credits < action.cost:
raise InsufficientCreditsError(action.cost, user.credits)
# Record transaction
balance_before = user.credits
balance_after = user.credits - action.cost
transaction = CreditTransaction(
user_id=user_id,
action_type=action_type.value,
amount=-action.cost,
balance_before=balance_before,
balance_after=balance_after,
description=action.description,
success=success,
metadata_json=json.dumps(metadata) if metadata else None,
)
# Update user credits
await user_repo.update(user, {"credits": balance_after})
# Save transaction
session.add(transaction)
await session.commit()
logger.info(
"Credits deducted for user %s: %s credits (action: %s, success: %s)",
user_id,
action.cost,
action_type.value,
success,
)
# Emit user_credits_changed event via WebSocket
try:
event_data = {
"user_id": str(user_id),
"credits_before": balance_before,
"credits_after": balance_after,
"credits_deducted": action.cost,
"action_type": action_type.value,
"success": success,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
)
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def add_credits(
self,
user_id: int,
amount: int,
description: str,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Add credits to user account.
Args:
user_id: The user ID
amount: Number of credits to add
description: Description of the credit addition
metadata: Optional metadata to store with transaction
Returns:
The created credit transaction
Raises:
ValueError: If user not found or amount is negative
"""
if amount <= 0:
msg = "Amount must be positive"
raise ValueError(msg)
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
# Record transaction
balance_before = user.credits
balance_after = user.credits + amount
transaction = CreditTransaction(
user_id=user_id,
action_type="credit_addition",
amount=amount,
balance_before=balance_before,
balance_after=balance_after,
description=description,
success=True,
metadata_json=json.dumps(metadata) if metadata else None,
)
# Update user credits
await user_repo.update(user, {"credits": balance_after})
# Save transaction
session.add(transaction)
await session.commit()
logger.info(
"Credits added for user %s: %s credits (description: %s)",
user_id,
amount,
description,
)
# Emit user_credits_changed event via WebSocket
try:
event_data = {
"user_id": str(user_id),
"credits_before": balance_before,
"credits_after": balance_after,
"credits_added": amount,
"description": description,
"success": True,
}
await socket_manager.send_to_user(str(user_id), "user_credits_changed", event_data)
logger.info("Emitted user_credits_changed event for user %s", user_id)
except Exception:
logger.exception(
"Failed to emit user_credits_changed event for user %s", user_id,
)
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def _create_transaction_record(
self,
user_id: int,
action: CreditAction,
amount: int,
success: bool,
metadata: dict[str, Any] | None = None,
) -> CreditTransaction:
"""Create a transaction record without modifying credits.
Args:
user_id: The user ID
action: The credit action
amount: Amount to record (typically 0 for failed actions)
success: Whether the action was successful
metadata: Optional metadata
Returns:
The created transaction record
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
transaction = CreditTransaction(
user_id=user_id,
action_type=action.action_type.value,
amount=amount,
balance_before=user.credits,
balance_after=user.credits,
description=f"{action.description} (failed)" if not success else action.description,
success=success,
metadata_json=json.dumps(metadata) if metadata else None,
)
session.add(transaction)
await session.commit()
return transaction
except Exception:
await session.rollback()
raise
finally:
await session.close()
async def get_user_balance(self, user_id: int) -> int:
"""Get current credit balance for a user.
Args:
user_id: The user ID
Returns:
Current credit balance
Raises:
ValueError: If user not found
"""
session = self.db_session_factory()
try:
user_repo = UserRepository(session)
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
return user.credits
finally:
await session.close()

View File

@@ -9,15 +9,14 @@ from pathlib import Path
from typing import Any
import vlc # type: ignore[import-untyped]
from sqlmodel import select
from app.core.logging import get_logger
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
from app.utils.audio import get_sound_file_path
logger = get_logger(__name__)
@@ -198,7 +197,7 @@ class PlayerService:
return
# Get sound file path
sound_path = self._get_sound_file_path(self.state.current_sound)
sound_path = get_sound_file_path(self.state.current_sound)
if not sound_path.exists():
logger.error("Sound file not found: %s", sound_path)
return
@@ -344,6 +343,12 @@ class PlayerService:
if self.state.status != PlayerStatus.STOPPED:
await self._stop_playback()
# Set first track as current if no current track and playlist has sounds
if not self.state.current_sound_id and sounds:
self.state.current_sound_index = 0
self.state.current_sound = sounds[0]
self.state.current_sound_id = sounds[0].id
logger.info(
"Loaded playlist: %s (%s sounds)",
current_playlist.name,
@@ -360,21 +365,6 @@ class PlayerService:
"""Get current player state."""
return self.state.to_dict()
def _get_sound_file_path(self, sound: Sound) -> Path:
"""Get the file path for a sound."""
# Determine the correct subdirectory based on sound type
subdir = "extracted" if sound.type.upper() == "EXT" else sound.type.lower()
# Use normalized file if available, otherwise original
if sound.is_normalized and sound.normalized_filename:
return (
Path("sounds/normalized")
/ subdir
/ sound.normalized_filename
)
return (
Path("sounds/originals") / subdir / sound.filename
)
def _get_next_index(self, current_index: int) -> int | None:
"""Get next track index based on current mode."""
@@ -501,7 +491,6 @@ class PlayerService:
session = self.db_session_factory()
try:
sound_repo = SoundRepository(session)
user_repo = UserRepository(session)
# Update sound play count
sound = await sound_repo.get_by_id(sound_id)
@@ -519,37 +508,17 @@ class PlayerService:
else:
logger.warning("Sound %s not found for play count update", sound_id)
# Record play history for admin user (ID 1) as placeholder
# This could be refined to track per-user play history
admin_user = await user_repo.get_by_id(1)
if admin_user:
# Check if already recorded for this user using proper query
stmt = select(SoundPlayed).where(
SoundPlayed.user_id == admin_user.id,
SoundPlayed.sound_id == sound_id,
)
result = await session.exec(stmt)
existing = result.first()
if not existing:
sound_played = SoundPlayed(
user_id=admin_user.id,
sound_id=sound_id,
)
session.add(sound_played)
logger.info(
"Created SoundPlayed record for user %s, sound %s",
admin_user.id,
sound_id,
)
else:
logger.info(
"SoundPlayed record already exists for user %s, sound %s",
admin_user.id,
sound_id,
)
else:
logger.warning("Admin user (ID 1) not found for play history")
# Record play history without user_id for player-based plays
# Always create a new SoundPlayed record for each play event
sound_played = SoundPlayed(
user_id=None, # No user_id for player-based plays
sound_id=sound_id,
)
session.add(sound_played)
logger.info(
"Created SoundPlayed record for player play, sound %s",
sound_id,
)
await session.commit()
logger.info("Successfully recorded play count for sound %s", sound_id)

313
app/services/vlc_player.py Normal file
View File

@@ -0,0 +1,313 @@
"""VLC subprocess-based player service for immediate sound playback."""
import asyncio
import subprocess
from collections.abc import Callable
from pathlib import Path
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
from app.utils.audio import get_sound_file_path
logger = get_logger(__name__)
class VLCPlayerService:
"""Service for launching VLC instances via subprocess to play sounds."""
def __init__(
self, db_session_factory: Callable[[], AsyncSession] | None = None,
) -> None:
"""Initialize the VLC player service."""
self.vlc_executable = self._find_vlc_executable()
self.db_session_factory = db_session_factory
logger.info(
"VLC Player Service initialized with executable: %s",
self.vlc_executable,
)
def _find_vlc_executable(self) -> str:
"""Find VLC executable path based on the operating system."""
# Common VLC executable paths
possible_paths = [
"vlc", # Linux/Mac with VLC in PATH
"/usr/bin/vlc", # Linux
"/usr/local/bin/vlc", # Linux/Mac
"/Applications/VLC.app/Contents/MacOS/VLC", # macOS
"C:\\Program Files\\VideoLAN\\VLC\\vlc.exe", # Windows
"C:\\Program Files (x86)\\VideoLAN\\VLC\\vlc.exe", # Windows 32-bit
]
for path in possible_paths:
try:
if Path(path).exists():
return path
# For "vlc", try to find it in PATH
if path == "vlc":
result = subprocess.run(
["which", "vlc"],
capture_output=True,
check=False,
text=True,
)
if result.returncode == 0:
return path
except (OSError, subprocess.SubprocessError):
continue
# Default to 'vlc' and let the system handle it
logger.warning(
"VLC executable not found in common paths, using 'vlc' from PATH",
)
return "vlc"
async def play_sound(self, sound: Sound) -> bool:
"""Play a sound using a new VLC subprocess instance.
Args:
sound: The Sound object to play
Returns:
bool: True if VLC process was launched successfully, False otherwise
"""
try:
sound_path = get_sound_file_path(sound)
if not sound_path.exists():
logger.error("Sound file not found: %s", sound_path)
return False
# VLC command arguments for immediate playback
cmd = [
self.vlc_executable,
str(sound_path),
"--play-and-exit", # Exit VLC when playback finishes
"--intf",
"dummy", # No interface
"--no-video", # Audio only
"--no-repeat", # Don't repeat
"--no-loop", # Don't loop
]
# Launch VLC process asynchronously without waiting
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
logger.info(
"Launched VLC process (PID: %s) for sound: %s",
process.pid,
sound.name,
)
# Record play count and emit event
if self.db_session_factory and sound.id:
asyncio.create_task(self._record_play_count(sound.id, sound.name))
return True
except Exception:
logger.exception("Failed to launch VLC for sound %s", sound.name)
return False
async def stop_all_vlc_instances(self) -> dict[str, Any]:
"""Stop all running VLC processes by killing them.
Returns:
dict: Results of the stop operation including counts and any errors
"""
try:
# Find all VLC processes
find_cmd = ["pgrep", "-f", "vlc"]
find_process = await asyncio.create_subprocess_exec(
*find_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await find_process.communicate()
if find_process.returncode != 0:
# No VLC processes found
logger.info("No VLC processes found to stop")
return {
"success": True,
"processes_found": 0,
"processes_killed": 0,
"message": "No VLC processes found",
}
# Parse PIDs from output
pids = []
if stdout:
pids = [
pid.strip()
for pid in stdout.decode().strip().split("\n")
if pid.strip()
]
if not pids:
logger.info("No VLC processes found to stop")
return {
"success": True,
"processes_found": 0,
"processes_killed": 0,
"message": "No VLC processes found",
}
logger.info("Found %s VLC processes: %s", len(pids), ", ".join(pids))
# Kill all VLC processes
kill_cmd = ["pkill", "-f", "vlc"]
kill_process = await asyncio.create_subprocess_exec(
*kill_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await kill_process.communicate()
# Verify processes were killed
verify_process = await asyncio.create_subprocess_exec(
*find_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout_verify, _ = await verify_process.communicate()
remaining_pids = []
if verify_process.returncode == 0 and stdout_verify:
remaining_pids = [
pid.strip()
for pid in stdout_verify.decode().strip().split("\n")
if pid.strip()
]
processes_killed = len(pids) - len(remaining_pids)
logger.info(
"Kill operation completed. Found: %s, Killed: %s, Remaining: %s",
len(pids),
processes_killed,
len(remaining_pids),
)
return {
"success": True,
"processes_found": len(pids),
"processes_killed": processes_killed,
"processes_remaining": len(remaining_pids),
"message": f"Killed {processes_killed} VLC processes",
}
except Exception as e:
logger.exception("Failed to stop VLC processes")
return {
"success": False,
"processes_found": 0,
"processes_killed": 0,
"error": str(e),
"message": "Failed to stop VLC processes",
}
async def _record_play_count(self, sound_id: int, sound_name: str) -> None:
"""Record a play count for a sound and emit sound_played event."""
if not self.db_session_factory:
logger.warning(
"No database session factory available for play count recording",
)
return
logger.info("Recording play count for sound %s", sound_id)
session = self.db_session_factory()
try:
sound_repo = SoundRepository(session)
user_repo = UserRepository(session)
# Update sound play count
sound = await sound_repo.get_by_id(sound_id)
old_count = 0
if sound:
old_count = sound.play_count
await sound_repo.update(
sound,
{"play_count": sound.play_count + 1},
)
logger.info(
"Updated sound %s play_count: %s -> %s",
sound_id,
old_count,
old_count + 1,
)
else:
logger.warning("Sound %s not found for play count update", sound_id)
# Record play history for admin user (ID 1) as placeholder
# This could be refined to track per-user play history
admin_user = await user_repo.get_by_id(1)
admin_user_id = None
if admin_user:
admin_user_id = admin_user.id
# Always create a new SoundPlayed record for each play event
sound_played = SoundPlayed(
user_id=admin_user_id, # Can be None for player-based plays
sound_id=sound_id,
)
session.add(sound_played)
logger.info(
"Created SoundPlayed record for user %s, sound %s",
admin_user_id,
sound_id,
)
await session.commit()
logger.info("Successfully recorded play count for sound %s", sound_id)
# Emit sound_played event via WebSocket
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": admin_user_id,
"play_count": (old_count + 1) if sound else None,
}
await socket_manager.broadcast_to_all("sound_played", event_data)
logger.info("Broadcasted sound_played event for sound %s", sound_id)
except Exception:
logger.exception(
"Failed to broadcast sound_played event for sound %s", sound_id,
)
except Exception:
logger.exception("Error recording play count for sound %s", sound_id)
await session.rollback()
finally:
await session.close()
# Global VLC player service instance
vlc_player_service: VLCPlayerService | None = None
def get_vlc_player_service(
db_session_factory: Callable[[], AsyncSession] | None = None,
) -> VLCPlayerService:
"""Get the global VLC player service instance."""
global vlc_player_service # noqa: PLW0603
if vlc_player_service is None:
vlc_player_service = VLCPlayerService(db_session_factory)
return vlc_player_service

View File

@@ -2,11 +2,15 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING
import ffmpeg # type: ignore[import-untyped]
from app.core.logging import get_logger
if TYPE_CHECKING:
from app.models.sound import Sound
logger = get_logger(__name__)
@@ -33,3 +37,30 @@ def get_audio_duration(file_path: Path) -> int:
except Exception as e:
logger.warning("Failed to get duration for %s: %s", file_path, e)
return 0
def get_sound_file_path(sound: "Sound") -> Path:
"""Get the file path for a sound based on its type and normalization status.
Args:
sound: The Sound object to get the path for
Returns:
Path: The full path to the sound file
"""
# Determine the correct subdirectory based on sound type
if sound.type.upper() == "EXT":
subdir = "extracted"
elif sound.type.upper() == "SDB":
subdir = "soundboard"
elif sound.type.upper() == "TTS":
subdir = "text_to_speech"
else:
# Fallback to lowercase type
subdir = sound.type.lower()
# Use normalized file if available, otherwise original
if sound.is_normalized and sound.normalized_filename:
return Path("sounds/normalized") / subdir / sound.normalized_filename
return Path("sounds/originals") / subdir / sound.filename

View File

@@ -0,0 +1,192 @@
"""Decorators for credit management and validation."""
import functools
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
def requires_credits(
action_type: CreditActionType,
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
"""Decorator to enforce credit requirements for actions.
Args:
action_type: The type of action that requires credits
credit_service_factory: Factory to create credit service instance
user_id_param: Name of the parameter containing user ID
metadata_extractor: Optional function to extract metadata from function args
Returns:
Decorated function that validates and deducts credits
Example:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
lambda: get_credit_service(),
user_id_param="user_id"
)
async def play_sound_for_user(user_id: int, sound: Sound) -> bool:
# Implementation here
return True
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
param_index = param_names.index(user_id_param)
if param_index < len(args):
user_id = args[param_index]
if user_id is None:
msg = f"Could not extract user_id from parameter '{user_id_param}'"
raise ValueError(msg)
# Extract metadata if extractor provided
metadata = None
if metadata_extractor:
metadata = metadata_extractor(*args, **kwargs)
# Get credit service
credit_service = credit_service_factory()
# Validate credits before execution
await credit_service.validate_and_reserve_credits(
user_id, action_type, metadata
)
# Execute the function
success = False
result = None
try:
result = await func(*args, **kwargs)
success = bool(result) # Consider function result as success indicator
return result
except Exception:
success = False
raise
finally:
# Deduct credits based on success
await credit_service.deduct_credits(
user_id, action_type, success, metadata
)
return wrapper # type: ignore[return-value]
return decorator
def validate_credits_only(
action_type: CreditActionType,
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
) -> Callable[[F], F]:
"""Decorator to only validate credits without deducting them.
Useful for checking if a user can perform an action before actual execution.
Args:
action_type: The type of action that requires credits
credit_service_factory: Factory to create credit service instance
user_id_param: Name of the parameter containing user ID
Returns:
Decorated function that validates credits only
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
param_index = param_names.index(user_id_param)
if param_index < len(args):
user_id = args[param_index]
if user_id is None:
msg = f"Could not extract user_id from parameter '{user_id_param}'"
raise ValueError(msg)
# Get credit service
credit_service = credit_service_factory()
# Validate credits only
await credit_service.validate_and_reserve_credits(user_id, action_type)
# Execute the function
return await func(*args, **kwargs)
return wrapper # type: ignore[return-value]
return decorator
class CreditManager:
"""Context manager for credit operations."""
def __init__(
self,
credit_service: CreditService,
user_id: int,
action_type: CreditActionType,
metadata: dict[str, Any] | None = None,
) -> None:
"""Initialize credit manager.
Args:
credit_service: Credit service instance
user_id: User ID
action_type: Action type
metadata: Optional metadata
"""
self.credit_service = credit_service
self.user_id = user_id
self.action_type = action_type
self.metadata = metadata
self.validated = False
self.success = False
async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.metadata
)
self.validated = True
return self
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
"""Exit context manager - deduct credits based on success."""
if self.validated:
# If no exception occurred, consider it successful
success = exc_type is None and self.success
await self.credit_service.deduct_credits(
self.user_id, self.action_type, success, self.metadata
)
def mark_success(self) -> None:
"""Mark the operation as successful."""
self.success = True

View File

@@ -0,0 +1,305 @@
"""Tests for VLC player API endpoints."""
from unittest.mock import AsyncMock, patch
import pytest
from httpx import AsyncClient
from app.models.sound import Sound
from app.models.user import User
class TestVLCEndpoints:
"""Test VLC player API endpoints."""
@pytest.mark.asyncio
async def test_play_sound_with_vlc_success(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test successful sound playback via VLC."""
# Mock the VLC player service and sound repository methods
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
mock_play_sound.return_value = True
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
mock_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
)
mock_get_by_id.return_value = mock_sound
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["sound_id"] == 1
assert data["sound_name"] == "Test Sound"
assert "Test Sound" in data["message"]
# Verify service calls
mock_get_by_id.assert_called_once_with(1)
mock_play_sound.assert_called_once_with(mock_sound)
@pytest.mark.asyncio
async def test_play_sound_with_vlc_sound_not_found(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test VLC playback when sound is not found."""
# Mock the sound repository to return None
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
mock_get_by_id.return_value = None
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
assert response.status_code == 404
data = response.json()
assert "Sound with ID 999 not found" in data["detail"]
@pytest.mark.asyncio
async def test_play_sound_with_vlc_launch_failure(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test VLC playback when VLC launch fails."""
# Mock the VLC player service to fail
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
mock_play_sound.return_value = False
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
mock_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
)
mock_get_by_id.return_value = mock_sound
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500
data = response.json()
assert "Failed to launch VLC for sound playback" in data["detail"]
@pytest.mark.asyncio
async def test_play_sound_with_vlc_service_exception(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test VLC playback when service raises an exception."""
# Mock the sound repository to raise an exception
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
mock_get_by_id.side_effect = Exception("Database error")
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500
data = response.json()
assert "Failed to play sound" in data["detail"]
@pytest.mark.asyncio
async def test_play_sound_with_vlc_unauthenticated(
self,
client: AsyncClient,
):
"""Test VLC playback without authentication."""
response = await client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_success(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test successful stopping of all VLC instances."""
# Mock the VLC player service
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
mock_result = {
"success": True,
"processes_found": 3,
"processes_killed": 3,
"processes_remaining": 0,
"message": "Killed 3 VLC processes",
}
mock_stop_all.return_value = mock_result
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["processes_found"] == 3
assert data["processes_killed"] == 3
assert data["processes_remaining"] == 0
assert "Killed 3 VLC processes" in data["message"]
# Verify service call
mock_stop_all.assert_called_once()
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_no_processes(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test stopping VLC instances when none are running."""
# Mock the VLC player service
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
mock_result = {
"success": True,
"processes_found": 0,
"processes_killed": 0,
"message": "No VLC processes found",
}
mock_stop_all.return_value = mock_result
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["processes_found"] == 0
assert data["processes_killed"] == 0
assert data["message"] == "No VLC processes found"
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_partial_success(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test stopping VLC instances with partial success."""
# Mock the VLC player service
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
mock_result = {
"success": True,
"processes_found": 3,
"processes_killed": 2,
"processes_remaining": 1,
"message": "Killed 2 VLC processes",
}
mock_stop_all.return_value = mock_result
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["processes_found"] == 3
assert data["processes_killed"] == 2
assert data["processes_remaining"] == 1
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_failure(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test stopping VLC instances when service fails."""
# Mock the VLC player service
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
mock_result = {
"success": False,
"processes_found": 0,
"processes_killed": 0,
"error": "Command failed",
"message": "Failed to stop VLC processes",
}
mock_stop_all.return_value = mock_result
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is False
assert data["error"] == "Command failed"
assert data["message"] == "Failed to stop VLC processes"
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_service_exception(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
):
"""Test stopping VLC instances when service raises an exception."""
# Mock the VLC player service to raise an exception
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
mock_stop_all.side_effect = Exception("Service error")
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 500
data = response.json()
assert "Failed to stop VLC instances" in data["detail"]
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_unauthenticated(
self,
client: AsyncClient,
):
"""Test stopping VLC instances without authentication."""
response = await client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_vlc_endpoints_with_admin_user(
self,
authenticated_admin_client: AsyncClient,
admin_user: User,
):
"""Test VLC endpoints work with admin user."""
# Test play endpoint with admin
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
mock_play_sound.return_value = True
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
mock_sound = Sound(
id=1,
type="SDB",
name="Admin Test Sound",
filename="admin_test.mp3",
duration=3000,
size=512,
hash="admin_hash",
)
mock_get_by_id.return_value = mock_sound
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["sound_name"] == "Admin Test Sound"
# Test stop-all endpoint with admin
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
mock_result = {
"success": True,
"processes_found": 1,
"processes_killed": 1,
"processes_remaining": 0,
"message": "Killed 1 VLC processes",
}
mock_stop_all.return_value = mock_result
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["processes_killed"] == 1

View File

@@ -13,8 +13,10 @@ from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
from app.models.plan import Plan
from app.models.user import User
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
from app.utils.auth import JWTUtils, PasswordUtils

View File

@@ -0,0 +1,412 @@
"""Tests for credit transaction repository."""
import json
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.repositories.credit_transaction import CreditTransactionRepository
class TestCreditTransactionRepository:
"""Test credit transaction repository operations."""
@pytest_asyncio.fixture
async def credit_transaction_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[CreditTransactionRepository, None]: # type: ignore[misc]
"""Create a credit transaction repository instance."""
yield CreditTransactionRepository(test_session)
@pytest_asyncio.fixture
async def test_user_id(
self,
test_user: User,
) -> int:
"""Get test user ID to avoid lazy loading issues."""
return test_user.id
@pytest_asyncio.fixture
async def test_transactions(
self,
test_session: AsyncSession,
test_user_id: int,
) -> AsyncGenerator[list[CreditTransaction], None]: # type: ignore[misc]
"""Create test credit transactions."""
transactions = []
user_id = test_user_id
# Create various types of transactions
transaction_data = [
{
"user_id": user_id,
"action_type": "vlc_play_sound",
"amount": -1,
"balance_before": 10,
"balance_after": 9,
"description": "Play sound via VLC",
"success": True,
"metadata_json": json.dumps({"sound_id": 1, "sound_name": "test.mp3"}),
},
{
"user_id": user_id,
"action_type": "audio_extraction",
"amount": -5,
"balance_before": 9,
"balance_after": 4,
"description": "Extract audio from URL",
"success": True,
"metadata_json": json.dumps({"url": "https://example.com/video"}),
},
{
"user_id": user_id,
"action_type": "vlc_play_sound",
"amount": 0,
"balance_before": 4,
"balance_after": 4,
"description": "Play sound via VLC (failed)",
"success": False,
"metadata_json": json.dumps({"sound_id": 2, "error": "File not found"}),
},
{
"user_id": user_id,
"action_type": "credit_addition",
"amount": 50,
"balance_before": 4,
"balance_after": 54,
"description": "Bonus credits",
"success": True,
"metadata_json": json.dumps({"reason": "signup_bonus"}),
},
]
for data in transaction_data:
transaction = CreditTransaction(**data)
test_session.add(transaction)
transactions.append(transaction)
await test_session.commit()
for transaction in transactions:
await test_session.refresh(transaction)
yield transactions
@pytest_asyncio.fixture
async def other_user_transaction(
self,
test_session: AsyncSession,
ensure_plans: tuple, # noqa: ARG002
) -> AsyncGenerator[CreditTransaction, None]: # type: ignore[misc]
"""Create a transaction for a different user."""
from app.models.plan import Plan
from app.repositories.user import UserRepository
# Create another user
user_repo = UserRepository(test_session)
other_user_data = {
"email": "other@example.com",
"name": "Other User",
"password_hash": "hashed_password",
"role": "user",
"is_active": True,
}
other_user = await user_repo.create(other_user_data)
# Create transaction for the other user
transaction_data = {
"user_id": other_user.id,
"action_type": "vlc_play_sound",
"amount": -1,
"balance_before": 100,
"balance_after": 99,
"description": "Other user play sound",
"success": True,
"metadata_json": None,
}
transaction = CreditTransaction(**transaction_data)
test_session.add(transaction)
await test_session.commit()
await test_session.refresh(transaction)
yield transaction
@pytest.mark.asyncio
async def test_get_by_id_existing(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transaction by ID when it exists."""
transaction = await credit_transaction_repository.get_by_id(test_transactions[0].id)
assert transaction is not None
assert transaction.id == test_transactions[0].id
assert transaction.action_type == "vlc_play_sound"
assert transaction.amount == -1
@pytest.mark.asyncio
async def test_get_by_id_nonexistent(
self,
credit_transaction_repository: CreditTransactionRepository,
) -> None:
"""Test getting transaction by ID when it doesn't exist."""
transaction = await credit_transaction_repository.get_by_id(99999)
assert transaction is None
@pytest.mark.asyncio
async def test_get_by_user_id(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
test_user_id: int,
) -> None:
"""Test getting transactions by user ID."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should return all transactions for test_user
assert len(transactions) == 4
# Should be ordered by created_at desc (newest first)
assert all(t.user_id == test_user_id for t in transactions)
# Should not include other user's transaction
other_user_ids = [t.user_id for t in transactions]
assert other_user_transaction.user_id not in other_user_ids
@pytest.mark.asyncio
async def test_get_by_user_id_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test getting transactions by user ID with pagination."""
# Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0
)
assert len(first_page) == 2
# Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2
)
assert len(second_page) == 2
# Should not overlap
first_page_ids = {t.id for t in first_page}
second_page_ids = {t.id for t in second_page}
assert first_page_ids.isdisjoint(second_page_ids)
@pytest.mark.asyncio
async def test_get_by_action_type(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transactions by action type."""
vlc_transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound"
)
# Should return 2 VLC transactions (1 successful, 1 failed)
assert len(vlc_transactions) >= 2
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
extraction_transactions = await credit_transaction_repository.get_by_action_type(
"audio_extraction"
)
# Should return 1 extraction transaction
assert len(extraction_transactions) >= 1
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
@pytest.mark.asyncio
async def test_get_by_action_type_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting transactions by action type with pagination."""
# Test with limit
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1
)
assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound"
# Test with offset
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1
)
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@pytest.mark.asyncio
async def test_get_successful_transactions(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test getting only successful transactions."""
successful_transactions = await credit_transaction_repository.get_successful_transactions()
# Should only return successful transactions
assert all(t.success is True for t in successful_transactions)
# Should be at least 3 (vlc_play_sound, audio_extraction, credit_addition)
assert len(successful_transactions) >= 3
@pytest.mark.asyncio
async def test_get_successful_transactions_by_user(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
test_user_id: int,
) -> None:
"""Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id
)
# Should only return successful transactions for test_user
assert all(t.success is True for t in successful_transactions)
assert all(t.user_id == test_user_id for t in successful_transactions)
# Should be 3 successful transactions for test_user
assert len(successful_transactions) == 3
@pytest.mark.asyncio
async def test_get_successful_transactions_with_pagination(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test getting successful transactions with pagination."""
# Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0
)
assert len(first_page) == 2
assert all(t.success is True for t in first_page)
# Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2
)
assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page)
@pytest.mark.asyncio
async def test_get_all_transactions(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
other_user_transaction: CreditTransaction,
) -> None:
"""Test getting all transactions."""
all_transactions = await credit_transaction_repository.get_all()
# Should return all transactions
assert len(all_transactions) >= 5 # 4 from test_transactions + 1 other_user_transaction
@pytest.mark.asyncio
async def test_create_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_user_id: int,
) -> None:
"""Test creating a new transaction."""
transaction_data = {
"user_id": test_user_id,
"action_type": "test_action",
"amount": -10,
"balance_before": 100,
"balance_after": 90,
"description": "Test transaction",
"success": True,
"metadata_json": json.dumps({"test": "data"}),
}
transaction = await credit_transaction_repository.create(transaction_data)
assert transaction.id is not None
assert transaction.user_id == test_user_id
assert transaction.action_type == "test_action"
assert transaction.amount == -10
assert transaction.balance_before == 100
assert transaction.balance_after == 90
assert transaction.success is True
assert json.loads(transaction.metadata_json) == {"test": "data"}
@pytest.mark.asyncio
async def test_update_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
) -> None:
"""Test updating a transaction."""
transaction = test_transactions[0]
update_data = {
"description": "Updated description",
"metadata_json": json.dumps({"updated": True}),
}
updated_transaction = await credit_transaction_repository.update(
transaction, update_data
)
assert updated_transaction.id == transaction.id
assert updated_transaction.description == "Updated description"
assert json.loads(updated_transaction.metadata_json) == {"updated": True}
# Other fields should remain unchanged
assert updated_transaction.amount == transaction.amount
assert updated_transaction.action_type == transaction.action_type
@pytest.mark.asyncio
async def test_delete_transaction(
self,
credit_transaction_repository: CreditTransactionRepository,
test_session: AsyncSession,
test_user_id: int,
) -> None:
"""Test deleting a transaction."""
# Create a transaction to delete
transaction_data = {
"user_id": test_user_id,
"action_type": "to_delete",
"amount": -1,
"balance_before": 10,
"balance_after": 9,
"description": "To be deleted",
"success": True,
"metadata_json": None,
}
transaction = await credit_transaction_repository.create(transaction_data)
transaction_id = transaction.id
# Delete the transaction
await credit_transaction_repository.delete(transaction)
# Verify transaction is deleted
deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id)
assert deleted_transaction is None
@pytest.mark.asyncio
async def test_transaction_ordering(
self,
credit_transaction_repository: CreditTransactionRepository,
test_transactions: list[CreditTransaction],
test_user_id: int,
) -> None:
"""Test that transactions are ordered by created_at desc."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should be ordered by created_at desc (newest first)
for i in range(len(transactions) - 1):
assert transactions[i].created_at >= transactions[i + 1].created_at

View File

@@ -0,0 +1,376 @@
"""Tests for sound repository."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound
from app.repositories.sound import SoundRepository
class TestSoundRepository:
"""Test sound repository operations."""
@pytest_asyncio.fixture
async def sound_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[SoundRepository, None]: # type: ignore[misc]
"""Create a sound repository instance."""
yield SoundRepository(test_session)
@pytest_asyncio.fixture
async def test_sound(
self,
test_session: AsyncSession,
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
"""Create a test sound."""
sound_data = {
"name": "Test Sound",
"filename": "test_sound.mp3",
"type": "SDB",
"duration": 5000,
"size": 1024000,
"hash": "test_hash_123",
"play_count": 0,
"is_normalized": False,
}
sound = Sound(**sound_data)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
yield sound
@pytest_asyncio.fixture
async def normalized_sound(
self,
test_session: AsyncSession,
) -> AsyncGenerator[Sound, None]: # type: ignore[misc]
"""Create a normalized test sound."""
sound_data = {
"name": "Normalized Sound",
"filename": "normalized_sound.mp3",
"type": "TTS",
"duration": 3000,
"size": 512000,
"hash": "normalized_hash_456",
"play_count": 5,
"is_normalized": True,
"normalized_filename": "normalized_sound_norm.mp3",
"normalized_duration": 3000,
"normalized_size": 480000,
"normalized_hash": "normalized_hash_norm_456",
}
sound = Sound(**sound_data)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
yield sound
@pytest.mark.asyncio
async def test_get_by_id_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by ID when it exists."""
sound = await sound_repository.get_by_id(test_sound.id)
assert sound is not None
assert sound.id == test_sound.id
assert sound.name == test_sound.name
assert sound.filename == test_sound.filename
assert sound.type == test_sound.type
@pytest.mark.asyncio
async def test_get_by_id_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by ID when it doesn't exist."""
sound = await sound_repository.get_by_id(99999)
assert sound is None
@pytest.mark.asyncio
async def test_get_by_filename_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by filename when it exists."""
sound = await sound_repository.get_by_filename(test_sound.filename)
assert sound is not None
assert sound.id == test_sound.id
assert sound.filename == test_sound.filename
@pytest.mark.asyncio
async def test_get_by_filename_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by filename when it doesn't exist."""
sound = await sound_repository.get_by_filename("nonexistent.mp3")
assert sound is None
@pytest.mark.asyncio
async def test_get_by_hash_existing(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test getting sound by hash when it exists."""
sound = await sound_repository.get_by_hash(test_sound.hash)
assert sound is not None
assert sound.id == test_sound.id
assert sound.hash == test_sound.hash
@pytest.mark.asyncio
async def test_get_by_hash_nonexistent(
self,
sound_repository: SoundRepository,
) -> None:
"""Test getting sound by hash when it doesn't exist."""
sound = await sound_repository.get_by_hash("nonexistent_hash")
assert sound is None
@pytest.mark.asyncio
async def test_get_by_type(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting sounds by type."""
sdb_sounds = await sound_repository.get_by_type("SDB")
tts_sounds = await sound_repository.get_by_type("TTS")
ext_sounds = await sound_repository.get_by_type("EXT")
# Should find the SDB sound
assert len(sdb_sounds) >= 1
assert any(sound.id == test_sound.id for sound in sdb_sounds)
# Should find the TTS sound
assert len(tts_sounds) >= 1
assert any(sound.id == normalized_sound.id for sound in tts_sounds)
# Should not find any EXT sounds
assert len(ext_sounds) == 0
@pytest.mark.asyncio
async def test_create_sound(
self,
sound_repository: SoundRepository,
) -> None:
"""Test creating a new sound."""
sound_data = {
"name": "New Sound",
"filename": "new_sound.wav",
"type": "EXT",
"duration": 7500,
"size": 2048000,
"hash": "new_hash_789",
"play_count": 0,
"is_normalized": False,
}
sound = await sound_repository.create(sound_data)
assert sound.id is not None
assert sound.name == sound_data["name"]
assert sound.filename == sound_data["filename"]
assert sound.type == sound_data["type"]
assert sound.duration == sound_data["duration"]
assert sound.size == sound_data["size"]
assert sound.hash == sound_data["hash"]
assert sound.play_count == 0
assert sound.is_normalized is False
@pytest.mark.asyncio
async def test_update_sound(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test updating a sound."""
update_data = {
"name": "Updated Sound Name",
"play_count": 10,
"is_normalized": True,
"normalized_filename": "updated_norm.mp3",
}
updated_sound = await sound_repository.update(test_sound, update_data)
assert updated_sound.id == test_sound.id
assert updated_sound.name == "Updated Sound Name"
assert updated_sound.play_count == 10
assert updated_sound.is_normalized is True
assert updated_sound.normalized_filename == "updated_norm.mp3"
assert updated_sound.filename == test_sound.filename # Unchanged
@pytest.mark.asyncio
async def test_delete_sound(
self,
sound_repository: SoundRepository,
test_session: AsyncSession,
) -> None:
"""Test deleting a sound."""
# Create a sound to delete
sound_data = {
"name": "To Delete",
"filename": "to_delete.mp3",
"type": "SDB",
"duration": 1000,
"size": 256000,
"hash": "delete_hash",
"play_count": 0,
"is_normalized": False,
}
sound = await sound_repository.create(sound_data)
sound_id = sound.id
# Delete the sound
await sound_repository.delete(sound)
# Verify sound is deleted
deleted_sound = await sound_repository.get_by_id(sound_id)
assert deleted_sound is None
@pytest.mark.asyncio
async def test_search_by_name(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test searching sounds by name."""
# Search for "test" should find test_sound
results = await sound_repository.search_by_name("test")
assert len(results) >= 1
assert any(sound.id == test_sound.id for sound in results)
# Search for "normalized" should find normalized_sound
results = await sound_repository.search_by_name("normalized")
assert len(results) >= 1
assert any(sound.id == normalized_sound.id for sound in results)
# Case insensitive search
results = await sound_repository.search_by_name("TEST")
assert len(results) >= 1
assert any(sound.id == test_sound.id for sound in results)
# Partial match
results = await sound_repository.search_by_name("norm")
assert len(results) >= 1
assert any(sound.id == normalized_sound.id for sound in results)
# No matches
results = await sound_repository.search_by_name("nonexistent")
assert len(results) == 0
@pytest.mark.asyncio
async def test_get_popular_sounds(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting popular sounds."""
# Update play counts to test ordering
await sound_repository.update(test_sound, {"play_count": 15})
await sound_repository.update(normalized_sound, {"play_count": 5})
# Create another sound with higher play count
high_play_sound_data = {
"name": "Popular Sound",
"filename": "popular.mp3",
"type": "SDB",
"duration": 2000,
"size": 300000,
"hash": "popular_hash",
"play_count": 25,
"is_normalized": False,
}
high_play_sound = await sound_repository.create(high_play_sound_data)
# Get popular sounds
popular_sounds = await sound_repository.get_popular_sounds(limit=10)
assert len(popular_sounds) >= 3
# Should be ordered by play_count desc
assert popular_sounds[0].play_count >= popular_sounds[1].play_count
# The highest play count sound should be first
assert popular_sounds[0].id == high_play_sound.id
@pytest.mark.asyncio
async def test_get_unnormalized_sounds(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting unnormalized sounds."""
unnormalized_sounds = await sound_repository.get_unnormalized_sounds()
# Should include test_sound (not normalized)
assert any(sound.id == test_sound.id for sound in unnormalized_sounds)
# Should not include normalized_sound (already normalized)
assert not any(sound.id == normalized_sound.id for sound in unnormalized_sounds)
@pytest.mark.asyncio
async def test_get_unnormalized_sounds_by_type(
self,
sound_repository: SoundRepository,
test_sound: Sound,
normalized_sound: Sound,
) -> None:
"""Test getting unnormalized sounds by type."""
# Get unnormalized SDB sounds
sdb_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("SDB")
# Should include test_sound (SDB, not normalized)
assert any(sound.id == test_sound.id for sound in sdb_unnormalized)
# Get unnormalized TTS sounds
tts_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("TTS")
# Should not include normalized_sound (TTS, but already normalized)
assert not any(sound.id == normalized_sound.id for sound in tts_unnormalized)
# Get unnormalized EXT sounds
ext_unnormalized = await sound_repository.get_unnormalized_sounds_by_type("EXT")
# Should be empty
assert len(ext_unnormalized) == 0
@pytest.mark.asyncio
async def test_create_duplicate_hash(
self,
sound_repository: SoundRepository,
test_sound: Sound,
) -> None:
"""Test creating sound with duplicate hash is allowed."""
# Store the hash to avoid lazy loading issues
original_hash = test_sound.hash
duplicate_sound_data = {
"name": "Duplicate Hash Sound",
"filename": "duplicate.mp3",
"type": "SDB",
"duration": 1000,
"size": 100000,
"hash": original_hash, # Same hash as test_sound
"play_count": 0,
"is_normalized": False,
}
# Should succeed - duplicate hashes are allowed
duplicate_sound = await sound_repository.create(duplicate_sound_data)
assert duplicate_sound.id is not None
assert duplicate_sound.name == "Duplicate Hash Sound"
assert duplicate_sound.hash == original_hash # Same hash is allowed

View File

@@ -0,0 +1,268 @@
"""Tests for user OAuth repository."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.user import User
from app.models.user_oauth import UserOauth
from app.repositories.user_oauth import UserOauthRepository
class TestUserOauthRepository:
"""Test user OAuth repository operations."""
@pytest_asyncio.fixture
async def user_oauth_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[UserOauthRepository, None]: # type: ignore[misc]
"""Create a user OAuth repository instance."""
yield UserOauthRepository(test_session)
@pytest_asyncio.fixture
async def test_user_id(
self,
test_user: User,
) -> int:
"""Get test user ID to avoid lazy loading issues."""
return test_user.id
@pytest_asyncio.fixture
async def test_oauth(
self,
test_session: AsyncSession,
test_user_id: int,
) -> AsyncGenerator[UserOauth, None]: # type: ignore[misc]
"""Create a test OAuth record."""
oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_123456",
"email": "test@gmail.com",
"name": "Test User Google",
"picture": None,
}
oauth = UserOauth(**oauth_data)
test_session.add(oauth)
await test_session.commit()
await test_session.refresh(oauth)
yield oauth
@pytest.mark.asyncio
async def test_get_by_provider_user_id_existing(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
) -> None:
"""Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456"
)
assert oauth is not None
assert oauth.id == test_oauth.id
assert oauth.provider == "google"
assert oauth.provider_user_id == "google_123456"
assert oauth.user_id == test_oauth.user_id
@pytest.mark.asyncio
async def test_get_by_provider_user_id_nonexistent(
self,
user_oauth_repository: UserOauthRepository,
) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id"
)
assert oauth is None
@pytest.mark.asyncio
async def test_get_by_user_id_and_provider_existing(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
test_user_id: int,
) -> None:
"""Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
)
assert oauth is not None
assert oauth.id == test_oauth.id
assert oauth.provider == "google"
assert oauth.user_id == test_user_id
@pytest.mark.asyncio
async def test_get_by_user_id_and_provider_nonexistent(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
)
assert oauth is None
@pytest.mark.asyncio
async def test_create_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test creating a new OAuth record."""
oauth_data = {
"user_id": test_user_id,
"provider": "github",
"provider_user_id": "github_789",
"email": "test@github.com",
"name": "Test User GitHub",
"picture": None,
}
oauth = await user_oauth_repository.create(oauth_data)
assert oauth.id is not None
assert oauth.user_id == test_user_id
assert oauth.provider == "github"
assert oauth.provider_user_id == "github_789"
assert oauth.email == "test@github.com"
assert oauth.name == "Test User GitHub"
@pytest.mark.asyncio
async def test_update_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
) -> None:
"""Test updating an OAuth record."""
update_data = {
"email": "updated@gmail.com",
"name": "Updated User Name",
"picture": "https://example.com/photo.jpg",
}
updated_oauth = await user_oauth_repository.update(test_oauth, update_data)
assert updated_oauth.id == test_oauth.id
assert updated_oauth.email == "updated@gmail.com"
assert updated_oauth.name == "Updated User Name"
assert updated_oauth.picture == "https://example.com/photo.jpg"
assert updated_oauth.provider == test_oauth.provider # Unchanged
assert updated_oauth.provider_user_id == test_oauth.provider_user_id # Unchanged
@pytest.mark.asyncio
async def test_delete_oauth(
self,
user_oauth_repository: UserOauthRepository,
test_session: AsyncSession,
test_user_id: int,
) -> None:
"""Test deleting an OAuth record."""
# Create an OAuth record to delete
oauth_data = {
"user_id": test_user_id,
"provider": "twitter",
"provider_user_id": "twitter_456",
"email": "test@twitter.com",
"name": "Test User Twitter",
"picture": None,
}
oauth = await user_oauth_repository.create(oauth_data)
oauth_id = oauth.id
# Delete the OAuth record
await user_oauth_repository.delete(oauth)
# Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456"
)
assert deleted_oauth is None
@pytest.mark.asyncio
async def test_create_duplicate_provider_user_id(
self,
user_oauth_repository: UserOauthRepository,
test_oauth: UserOauth,
test_user_id: int,
) -> None:
"""Test creating OAuth with duplicate provider user ID should fail."""
# Try to create another OAuth with the same provider and provider_user_id
duplicate_oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_123456", # Same as test_oauth
"email": "another@gmail.com",
"name": "Another User",
"picture": None,
}
# This should fail due to unique constraint
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
await user_oauth_repository.create(duplicate_oauth_data)
@pytest.mark.asyncio
async def test_multiple_providers_same_user(
self,
user_oauth_repository: UserOauthRepository,
test_user_id: int,
) -> None:
"""Test that a user can have multiple OAuth providers."""
# Create Google OAuth
google_oauth_data = {
"user_id": test_user_id,
"provider": "google",
"provider_user_id": "google_user_1",
"email": "user@gmail.com",
"name": "Test User Google",
"picture": None,
}
google_oauth = await user_oauth_repository.create(google_oauth_data)
# Create GitHub OAuth for the same user
github_oauth_data = {
"user_id": test_user_id,
"provider": "github",
"provider_user_id": "github_user_1",
"email": "user@github.com",
"name": "Test User GitHub",
"picture": None,
}
github_oauth = await user_oauth_repository.create(github_oauth_data)
# Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
)
found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
)
assert found_google is not None
assert found_github is not None
assert found_google.provider == "google"
assert found_github.provider == "github"
assert found_google.user_id == test_user_id
assert found_github.user_id == test_user_id
assert found_google.provider_user_id == "google_user_1"
assert found_github.provider_user_id == "github_user_1"
# Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1"
)
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1"
)
assert found_google_by_provider is not None
assert found_github_by_provider is not None
assert found_google_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id

View File

@@ -0,0 +1,358 @@
"""Tests for credit service."""
import json
from unittest.mock import AsyncMock, Mock, patch
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.credit_action import CreditActionType
from app.models.credit_transaction import CreditTransaction
from app.models.user import User
from app.services.credit import CreditService, InsufficientCreditsError
class TestCreditService:
"""Test credit service functionality."""
@pytest.fixture
def mock_db_session_factory(self):
"""Create a mock database session factory."""
session = AsyncMock(spec=AsyncSession)
return lambda: session
@pytest.fixture
def credit_service(self, mock_db_session_factory):
"""Create a credit service instance for testing."""
return CreditService(mock_db_session_factory)
@pytest.fixture
def sample_user(self):
"""Create a sample user for testing."""
return User(
id=1,
name="Test User",
email="test@example.com",
role="user",
credits=10,
plan_id=1,
)
@pytest.mark.asyncio
async def test_check_credits_sufficient(self, credit_service, sample_user):
"""Test checking credits when user has sufficient credits."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is True
mock_repo.get_by_id.assert_called_once_with(1)
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_check_credits_insufficient(self, credit_service):
"""Test checking credits when user has insufficient credits."""
mock_session = credit_service.db_session_factory()
poor_user = User(
id=1,
name="Poor User",
email="poor@example.com",
role="user",
credits=0, # No credits
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is False
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_check_credits_user_not_found(self, credit_service):
"""Test checking credits when user is not found."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND)
assert result is False
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user):
"""Test successful credit validation and reservation."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
user, action = await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND
)
assert user == sample_user
assert action.action_type == CreditActionType.VLC_PLAY_SOUND
assert action.cost == 1
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_validate_and_reserve_credits_insufficient(self, credit_service):
"""Test credit validation with insufficient credits."""
mock_session = credit_service.db_session_factory()
poor_user = User(
id=1,
name="Poor User",
email="poor@example.com",
role="user",
credits=0,
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
with pytest.raises(InsufficientCreditsError) as exc_info:
await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND
)
assert exc_info.value.required == 1
assert exc_info.value.available == 0
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_validate_and_reserve_credits_user_not_found(self, credit_service):
"""Test credit validation when user is not found."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.validate_and_reserve_credits(
999, CreditActionType.VLC_PLAY_SOUND
)
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_deduct_credits_success(self, credit_service, sample_user):
"""Test successful credit deduction."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
)
# Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 9})
# Verify transaction was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", {
"user_id": "1",
"credits_before": 10,
"credits_after": 9,
"credits_deducted": 1,
"action_type": "vlc_play_sound",
"success": True,
}
)
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert isinstance(added_transaction, CreditTransaction)
assert added_transaction.user_id == 1
assert added_transaction.action_type == "vlc_play_sound"
assert added_transaction.amount == -1
assert added_transaction.balance_before == 10
assert added_transaction.balance_after == 9
assert added_transaction.success is True
assert json.loads(added_transaction.metadata_json) == {"test": "data"}
@pytest.mark.asyncio
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user):
"""Test credit deduction when action failed but requires success."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed
)
# Verify user credits were NOT updated (action requires success)
mock_repo.update.assert_not_called()
# Verify transaction was still created for auditing
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify no socket event was emitted since no credits were actually deducted
mock_socket_manager.send_to_user.assert_not_called()
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 0 # No deduction for failed action
assert added_transaction.balance_before == 10
assert added_transaction.balance_after == 10 # No change
assert added_transaction.success is False
@pytest.mark.asyncio
async def test_deduct_credits_insufficient(self, credit_service):
"""Test credit deduction with insufficient credits."""
mock_session = credit_service.db_session_factory()
poor_user = User(
id=1,
name="Poor User",
email="poor@example.com",
role="user",
credits=0,
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
mock_socket_manager.send_to_user = AsyncMock()
with pytest.raises(InsufficientCreditsError):
await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True
)
# Verify no socket event was emitted since credits could not be deducted
mock_socket_manager.send_to_user.assert_not_called()
mock_session.rollback.assert_called_once()
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_add_credits(self, credit_service, sample_user):
"""Test adding credits to user account."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.add_credits(
1, 5, "Bonus credits", {"reason": "signup"}
)
# Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 15})
# Verify transaction was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", {
"user_id": "1",
"credits_before": 10,
"credits_after": 15,
"credits_added": 5,
"description": "Bonus credits",
"success": True,
}
)
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 5
assert added_transaction.balance_before == 10
assert added_transaction.balance_after == 15
assert added_transaction.description == "Bonus credits"
@pytest.mark.asyncio
async def test_add_credits_invalid_amount(self, credit_service):
"""Test adding invalid amount of credits."""
with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, 0, "Invalid")
with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, -5, "Invalid")
@pytest.mark.asyncio
async def test_get_user_balance(self, credit_service, sample_user):
"""Test getting user credit balance."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
balance = await credit_service.get_user_balance(1)
assert balance == 10
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_get_user_balance_user_not_found(self, credit_service):
"""Test getting balance for non-existent user."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.get_user_balance(999)
mock_session.close.assert_called_once()
class TestInsufficientCreditsError:
"""Test InsufficientCreditsError exception."""
def test_insufficient_credits_error_creation(self):
"""Test creating InsufficientCreditsError."""
error = InsufficientCreditsError(5, 2)
assert error.required == 5
assert error.available == 2
assert str(error) == "Insufficient credits: 5 required, 2 available"

View File

@@ -21,6 +21,7 @@ from app.services.player import (
initialize_player_service,
shutdown_player_service,
)
from app.utils.audio import get_sound_file_path
class TestPlayerState:
@@ -196,7 +197,7 @@ class TestPlayerService:
)
player_service.state.playlist_sounds = [sound]
with patch.object(player_service, "_get_sound_file_path") as mock_path:
with patch("app.services.player.get_sound_file_path") as mock_path:
mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = True
mock_path.return_value = mock_file_path
@@ -385,51 +386,6 @@ class TestPlayerService:
assert player_service.state.playlist_length == 2
assert player_service.state.playlist_duration == 75000
def test_get_sound_file_path_normalized(self, player_service):
"""Test getting file path for normalized sound."""
sound = Sound(
id=1,
name="Test Song",
filename="original.mp3",
normalized_filename="normalized.mp3",
is_normalized=True,
type="SDB",
)
result = player_service._get_sound_file_path(sound)
expected = Path("sounds/normalized/sdb/normalized.mp3")
assert result == expected
def test_get_sound_file_path_original(self, player_service):
"""Test getting file path for original sound."""
sound = Sound(
id=1,
name="Test Song",
filename="original.mp3",
is_normalized=False,
type="SDB",
)
result = player_service._get_sound_file_path(sound)
expected = Path("sounds/originals/sdb/original.mp3")
assert result == expected
def test_get_sound_file_path_ext_type(self, player_service):
"""Test getting file path for EXT type sound."""
sound = Sound(
id=1,
name="Test Song",
filename="extracted.mp3",
is_normalized=False,
type="EXT",
)
result = player_service._get_sound_file_path(sound)
expected = Path("sounds/originals/extracted/extracted.mp3")
assert result == expected
def test_get_next_index_continuous_mode(self, player_service):
"""Test getting next index in continuous mode."""
@@ -538,36 +494,24 @@ class TestPlayerService:
# Mock repositories
with patch("app.services.player.SoundRepository") as mock_sound_repo_class:
with patch("app.services.player.UserRepository") as mock_user_repo_class:
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
mock_sound_repo_class.return_value = mock_sound_repo
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo = AsyncMock()
mock_sound_repo_class.return_value = mock_sound_repo
# Mock sound and user
mock_sound = Mock()
mock_sound.play_count = 5
mock_sound_repo.get_by_id.return_value = mock_sound
# Mock sound
mock_sound = Mock()
mock_sound.play_count = 5
mock_sound_repo.get_by_id.return_value = mock_sound
mock_user = Mock()
mock_user.id = 1
mock_user_repo.get_by_id.return_value = mock_user
await player_service._record_play_count(1)
# Mock no existing SoundPlayed record
mock_result = Mock()
mock_result.first.return_value = None
mock_session.exec.return_value = mock_result
# Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with(
mock_sound, {"play_count": 6}
)
await player_service._record_play_count(1)
# Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with(
mock_sound, {"play_count": 6}
)
# Verify SoundPlayed record was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify SoundPlayed record was created with None user_id for player
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
def test_get_state(self, player_service):
"""Test getting current player state."""
@@ -577,6 +521,27 @@ class TestPlayerService:
assert "mode" in result
assert "volume" in result
def test_uses_shared_sound_path_utility(self, player_service):
"""Test that player service uses the shared sound path utility."""
sound = Sound(
id=1,
name="Test Song",
filename="test.mp3",
type="SDB",
is_normalized=False,
)
player_service.state.playlist_sounds = [sound]
with patch("app.services.player.get_sound_file_path") as mock_path:
mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = False # File doesn't exist
mock_path.return_value = mock_file_path
# This should fail because file doesn't exist
result = asyncio.run(player_service.play(0))
# Verify the utility was called
mock_path.assert_called_once_with(sound)
class TestPlayerServiceGlobalFunctions:
"""Test global player service functions."""

View File

@@ -0,0 +1,511 @@
"""Tests for VLC player service."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
import pytest
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.models.user import User
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
from app.utils.audio import get_sound_file_path
class TestVLCPlayerService:
"""Test VLC player service."""
@pytest.fixture
def vlc_service(self):
"""Create a VLC service instance."""
with patch("app.services.vlc_player.subprocess.run") as mock_run:
# Mock VLC executable detection
mock_run.return_value.returncode = 0
return VLCPlayerService()
@pytest.fixture
def vlc_service_with_db(self):
"""Create a VLC service instance with database session factory."""
with patch("app.services.vlc_player.subprocess.run") as mock_run:
# Mock VLC executable detection
mock_run.return_value.returncode = 0
mock_session_factory = Mock()
return VLCPlayerService(mock_session_factory)
@pytest.fixture
def sample_sound(self):
"""Create a sample sound for testing."""
return Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test_audio.mp3",
duration=5000,
size=1024,
hash="test_hash",
is_normalized=False,
normalized_filename=None,
)
@pytest.fixture
def normalized_sound(self):
"""Create a normalized sound for testing."""
return Sound(
id=2,
type="TTS",
name="Normalized Sound",
filename="original.wav",
duration=7500,
size=2048,
hash="normalized_hash",
is_normalized=True,
normalized_filename="normalized.mp3",
)
def test_init(self, vlc_service):
"""Test VLC service initialization."""
assert vlc_service.vlc_executable is not None
assert isinstance(vlc_service.vlc_executable, str)
@patch("app.services.vlc_player.subprocess.run")
def test_find_vlc_executable_found_in_path(self, mock_run):
"""Test VLC executable detection when found in PATH."""
mock_run.return_value.returncode = 0
service = VLCPlayerService()
assert service.vlc_executable == "vlc"
@patch("app.services.vlc_player.subprocess.run")
def test_find_vlc_executable_found_by_path(self, mock_run):
"""Test VLC executable detection when found by absolute path."""
mock_run.return_value.returncode = 1 # which command fails
# Mock Path to return True for the first absolute path
with patch("app.services.vlc_player.Path") as mock_path:
def path_side_effect(path_str):
mock_instance = Mock()
mock_instance.exists.return_value = str(path_str) == "/usr/bin/vlc"
return mock_instance
mock_path.side_effect = path_side_effect
service = VLCPlayerService()
assert service.vlc_executable == "/usr/bin/vlc"
@patch("app.services.vlc_player.subprocess.run")
@patch("app.services.vlc_player.Path")
def test_find_vlc_executable_fallback(self, mock_path, mock_run):
"""Test VLC executable detection fallback to default."""
# Mock all paths as non-existent
mock_path_instance = Mock()
mock_path_instance.exists.return_value = False
mock_path.return_value = mock_path_instance
# Mock which command as failing
mock_run.return_value.returncode = 1
service = VLCPlayerService()
assert service.vlc_executable == "vlc"
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_success(
self, mock_subprocess, vlc_service, sample_sound
):
"""Test successful sound playback."""
# Mock subprocess
mock_process = Mock()
mock_process.pid = 12345
mock_subprocess.return_value = mock_process
# Mock the file path utility to avoid Path issues
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock()
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
result = await vlc_service.play_sound(sample_sound)
assert result is True
mock_subprocess.assert_called_once()
args = mock_subprocess.call_args
# Check command arguments
cmd_args = args[1] # keyword arguments
assert "--play-and-exit" in args[0]
assert "--intf" in args[0]
assert "dummy" in args[0]
assert "--no-video" in args[0]
assert "--no-repeat" in args[0]
assert "--no-loop" in args[0]
assert cmd_args["stdout"] == asyncio.subprocess.DEVNULL
assert cmd_args["stderr"] == asyncio.subprocess.DEVNULL
@pytest.mark.asyncio
async def test_play_sound_file_not_found(
self, vlc_service, sample_sound
):
"""Test sound playback when file doesn't exist."""
# Mock the file path utility to return a non-existent path
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock()
mock_path.exists.return_value = False
mock_get_path.return_value = mock_path
result = await vlc_service.play_sound(sample_sound)
assert result is False
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_subprocess_error(
self, mock_subprocess, vlc_service, sample_sound
):
"""Test sound playback when subprocess fails."""
# Mock the file path utility to return an existing path
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock()
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
# Mock subprocess exception
mock_subprocess.side_effect = Exception("Subprocess failed")
result = await vlc_service.play_sound(sample_sound)
assert result is False
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_success(self, mock_subprocess, vlc_service):
"""Test successful stopping of all VLC instances."""
# Mock pgrep process (find VLC processes)
mock_find_process = Mock()
mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n", b"")
)
# Mock pkill process (kill VLC processes)
mock_kill_process = Mock()
mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
# Mock verify process (check remaining processes)
mock_verify_process = Mock()
mock_verify_process.returncode = 1 # No processes found
mock_verify_process.communicate = AsyncMock(return_value=(b"", b""))
# Set up subprocess mock to return different processes for each call
mock_subprocess.side_effect = [
mock_find_process,
mock_kill_process,
mock_verify_process,
]
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True
assert result["processes_found"] == 2
assert result["processes_killed"] == 2
assert result["processes_remaining"] == 0
assert "Killed 2 VLC processes" in result["message"]
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_no_processes(
self, mock_subprocess, vlc_service
):
"""Test stopping VLC instances when none are running."""
# Mock pgrep process (no VLC processes found)
mock_find_process = Mock()
mock_find_process.returncode = 1 # No processes found
mock_find_process.communicate = AsyncMock(return_value=(b"", b""))
mock_subprocess.return_value = mock_find_process
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True
assert result["processes_found"] == 0
assert result["processes_killed"] == 0
assert result["message"] == "No VLC processes found"
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_partial_kill(
self, mock_subprocess, vlc_service
):
"""Test stopping VLC instances when some processes remain."""
# Mock pgrep process (find VLC processes)
mock_find_process = Mock()
mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n11111\n", b"")
)
# Mock pkill process (kill VLC processes)
mock_kill_process = Mock()
mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
# Mock verify process (one process remains)
mock_verify_process = Mock()
mock_verify_process.returncode = 0
mock_verify_process.communicate = AsyncMock(return_value=(b"11111\n", b""))
mock_subprocess.side_effect = [
mock_find_process,
mock_kill_process,
mock_verify_process,
]
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True
assert result["processes_found"] == 3
assert result["processes_killed"] == 2
assert result["processes_remaining"] == 1
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_error(self, mock_subprocess, vlc_service):
"""Test stopping VLC instances when an error occurs."""
# Mock subprocess exception
mock_subprocess.side_effect = Exception("Command failed")
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is False
assert result["processes_found"] == 0
assert result["processes_killed"] == 0
assert "error" in result
assert result["message"] == "Failed to stop VLC processes"
def test_get_vlc_player_service_singleton(self):
"""Test that get_vlc_player_service returns the same instance."""
with patch("app.services.vlc_player.VLCPlayerService") as mock_service_class:
mock_instance = Mock()
mock_service_class.return_value = mock_instance
# Clear the global instance
import app.services.vlc_player
app.services.vlc_player.vlc_player_service = None
# First call should create new instance
service1 = get_vlc_player_service()
assert service1 == mock_instance
mock_service_class.assert_called_once()
# Second call should return same instance
service2 = get_vlc_player_service()
assert service2 == mock_instance
assert service1 is service2
# Constructor should not be called again
mock_service_class.assert_called_once()
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_with_play_count_tracking(
self, mock_subprocess, vlc_service_with_db, sample_sound
):
"""Test sound playback with play count tracking."""
# Mock subprocess
mock_process = Mock()
mock_process.pid = 12345
mock_subprocess.return_value = mock_process
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
# Mock repositories
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket:
with patch("app.services.vlc_player.select") as mock_select:
# Mock the file path utility
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock()
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
# Mock sound repository responses
updated_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
play_count=1, # Updated count
)
mock_sound_repo.get_by_id.return_value = sample_sound
mock_sound_repo.update.return_value = updated_sound
# Mock admin user
admin_user = User(
id=1,
email="admin@test.com",
name="Admin User",
role="admin",
)
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
result = await vlc_service_with_db.play_sound(sample_sound)
# Wait a bit for the async task to complete
await asyncio.sleep(0.1)
assert result is True
# Verify subprocess was called
mock_subprocess.assert_called_once()
# Note: The async task runs in the background, so we can't easily
# verify the database operations in this test without more complex
# mocking or using a real async test framework setup
@pytest.mark.asyncio
async def test_record_play_count_success(self, vlc_service_with_db):
"""Test successful play count recording."""
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
# Create test sound and user
test_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
admin_user = User(
id=1,
email="admin@test.com",
name="Admin User",
role="admin",
)
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket:
with patch("app.services.vlc_player.select") as mock_select:
# Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1)
mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 1}
)
# Verify user repository calls
mock_user_repo.get_by_id.assert_called_once_with(1)
# Verify session operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.close.assert_called_once()
# Verify socket broadcast
mock_socket.broadcast_to_all.assert_called_once_with(
"sound_played",
{
"sound_id": 1,
"sound_name": "Test Sound",
"user_id": 1,
"play_count": 1,
},
)
@pytest.mark.asyncio
async def test_record_play_count_no_session_factory(self, vlc_service):
"""Test play count recording when no session factory is available."""
# This should not raise an error and should log a warning
await vlc_service._record_play_count(1, "Test Sound")
# The method should return early without doing anything
@pytest.mark.asyncio
async def test_record_play_count_always_creates_record(self, vlc_service_with_db):
"""Test play count recording always creates a new SoundPlayed record."""
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
# Create test sound and user
test_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
play_count=5,
)
admin_user = User(
id=1,
email="admin@test.com",
name="Admin User",
role="admin",
)
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket:
# Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 6}
)
# Verify new SoundPlayed record was always added
mock_session.add.assert_called_once()
# Verify commit happened
mock_session.commit.assert_called_once()
def test_uses_shared_sound_path_utility(self, vlc_service, sample_sound):
"""Test that VLC service uses the shared sound path utility."""
with patch("app.services.vlc_player.get_sound_file_path") as mock_path:
mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = False # File doesn't exist
mock_path.return_value = mock_file_path
# This should fail because file doesn't exist
result = asyncio.run(vlc_service.play_sound(sample_sound))
# Verify the utility was called and returned False
mock_path.assert_called_once_with(sample_sound)
assert result is False

View File

@@ -7,7 +7,8 @@ from unittest.mock import patch
import pytest
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
from app.models.sound import Sound
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size, get_sound_file_path
class TestAudioUtils:
@@ -290,3 +291,120 @@ class TestAudioUtils:
# Should raise FileNotFoundError for nonexistent file
with pytest.raises(FileNotFoundError):
get_file_size(nonexistent_path)
def test_get_sound_file_path_sdb_original(self):
"""Test getting sound file path for SDB type original file."""
sound = Sound(
id=1,
name="Test Sound",
filename="test.mp3",
type="SDB",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/soundboard/test.mp3")
assert result == expected
def test_get_sound_file_path_sdb_normalized(self):
"""Test getting sound file path for SDB type normalized file."""
sound = Sound(
id=1,
name="Test Sound",
filename="original.mp3",
normalized_filename="normalized.mp3",
type="SDB",
is_normalized=True,
)
result = get_sound_file_path(sound)
expected = Path("sounds/normalized/soundboard/normalized.mp3")
assert result == expected
def test_get_sound_file_path_tts_original(self):
"""Test getting sound file path for TTS type original file."""
sound = Sound(
id=2,
name="TTS Sound",
filename="tts_file.wav",
type="TTS",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/text_to_speech/tts_file.wav")
assert result == expected
def test_get_sound_file_path_tts_normalized(self):
"""Test getting sound file path for TTS type normalized file."""
sound = Sound(
id=2,
name="TTS Sound",
filename="original.wav",
normalized_filename="normalized.mp3",
type="TTS",
is_normalized=True,
)
result = get_sound_file_path(sound)
expected = Path("sounds/normalized/text_to_speech/normalized.mp3")
assert result == expected
def test_get_sound_file_path_ext_original(self):
"""Test getting sound file path for EXT type original file."""
sound = Sound(
id=3,
name="Extracted Sound",
filename="extracted.mp3",
type="EXT",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/extracted/extracted.mp3")
assert result == expected
def test_get_sound_file_path_ext_normalized(self):
"""Test getting sound file path for EXT type normalized file."""
sound = Sound(
id=3,
name="Extracted Sound",
filename="original.mp3",
normalized_filename="normalized.mp3",
type="EXT",
is_normalized=True,
)
result = get_sound_file_path(sound)
expected = Path("sounds/normalized/extracted/normalized.mp3")
assert result == expected
def test_get_sound_file_path_unknown_type_fallback(self):
"""Test getting sound file path for unknown type falls back to lowercase."""
sound = Sound(
id=4,
name="Unknown Type Sound",
filename="unknown.mp3",
type="CUSTOM",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/custom/unknown.mp3")
assert result == expected
def test_get_sound_file_path_normalized_without_filename(self):
"""Test getting sound file path when normalized but no normalized_filename."""
sound = Sound(
id=5,
name="Test Sound",
filename="original.mp3",
normalized_filename=None,
type="SDB",
is_normalized=True, # True but no normalized_filename
)
result = get_sound_file_path(sound)
# Should fall back to original file
expected = Path("sounds/originals/soundboard/original.mp3")
assert result == expected

View File

@@ -0,0 +1,277 @@
"""Tests for credit decorators."""
from unittest.mock import AsyncMock, Mock
import pytest
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only
class TestRequiresCreditsDecorator:
"""Test requires_credits decorator."""
@pytest.fixture
def mock_credit_service(self):
"""Create a mock credit service."""
service = AsyncMock(spec=CreditService)
service.validate_and_reserve_credits = AsyncMock()
service.deduct_credits = AsyncMock()
return service
@pytest.fixture
def credit_service_factory(self, mock_credit_service):
"""Create a credit service factory."""
return lambda: mock_credit_service
@pytest.mark.asyncio
async def test_decorator_success(self, credit_service_factory, mock_credit_service):
"""Test decorator with successful action."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int, message: str) -> str:
return f"Success: {message}"
result = await test_action(user_id=123, message="test")
assert result == "Success: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, None
)
@pytest.mark.asyncio
async def test_decorator_with_metadata(self, credit_service_factory, mock_credit_service):
"""Test decorator with metadata extraction."""
def extract_metadata(user_id: int, sound_name: str) -> dict:
return {"sound_name": sound_name}
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id",
metadata_extractor=extract_metadata
)
async def test_action(user_id: int, sound_name: str) -> bool:
return True
await test_action(user_id=123, sound_name="test.mp3")
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"}
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"}
)
@pytest.mark.asyncio
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service):
"""Test decorator with failed action."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int) -> bool:
return False # Action fails
result = await test_action(user_id=123)
assert result is False
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service):
"""Test decorator when action raises exception."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int) -> str:
raise ValueError("Test error")
with pytest.raises(ValueError, match="Test error"):
await test_action(user_id=123)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service):
"""Test decorator with insufficient credits."""
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int) -> str:
return "Should not execute"
with pytest.raises(InsufficientCreditsError):
await test_action(user_id=123)
# Should not call deduct_credits since validation failed
mock_credit_service.deduct_credits.assert_not_called()
@pytest.mark.asyncio
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service):
"""Test decorator extracting user_id from positional args."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int, message: str) -> str:
return message
result = await test_action(123, "test")
assert result == "test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None
)
@pytest.mark.asyncio
async def test_decorator_missing_user_id(self, credit_service_factory):
"""Test decorator when user_id cannot be extracted."""
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(other_param: str) -> str:
return other_param
with pytest.raises(ValueError, match="Could not extract user_id"):
await test_action(other_param="test")
class TestValidateCreditsOnlyDecorator:
"""Test validate_credits_only decorator."""
@pytest.fixture
def mock_credit_service(self):
"""Create a mock credit service."""
service = AsyncMock(spec=CreditService)
service.validate_and_reserve_credits = AsyncMock()
return service
@pytest.fixture
def credit_service_factory(self, mock_credit_service):
"""Create a credit service factory."""
return lambda: mock_credit_service
@pytest.mark.asyncio
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service):
"""Test validate_credits_only decorator."""
@validate_credits_only(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
)
async def test_action(user_id: int, message: str) -> str:
return f"Validated: {message}"
result = await test_action(user_id=123, message="test")
assert result == "Validated: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND
)
# Should not deduct credits, only validate
mock_credit_service.deduct_credits.assert_not_called()
class TestCreditManager:
"""Test CreditManager context manager."""
@pytest.fixture
def mock_credit_service(self):
"""Create a mock credit service."""
service = AsyncMock(spec=CreditService)
service.validate_and_reserve_credits = AsyncMock()
service.deduct_credits = AsyncMock()
return service
@pytest.mark.asyncio
async def test_credit_manager_success(self, mock_credit_service):
"""Test CreditManager with successful operation."""
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND,
{"test": "data"}
) as manager:
manager.mark_success()
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"}
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
)
@pytest.mark.asyncio
async def test_credit_manager_failure(self, mock_credit_service):
"""Test CreditManager with failed operation."""
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
):
# Don't mark as success - should be considered failed
pass
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_credit_manager_exception(self, mock_credit_service):
"""Test CreditManager when exception occurs."""
with pytest.raises(ValueError, match="Test error"):
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
):
raise ValueError("Test error")
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
)
@pytest.mark.asyncio
async def test_credit_manager_validation_failure(self, mock_credit_service):
"""Test CreditManager when validation fails."""
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0)
with pytest.raises(InsufficientCreditsError):
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
):
pass
# Should not call deduct_credits since validation failed
mock_credit_service.deduct_credits.assert_not_called()