Refactor test files for improved readability and consistency

- Removed unnecessary blank lines and adjusted formatting in test files.
- Ensured consistent use of commas in function calls and assertions across various test cases.
- Updated import statements for better organization and clarity.
- Enhanced mock setups in tests for better isolation and reliability.
- Improved assertions to follow a consistent style for better readability.
This commit is contained in:
JSC
2025-07-31 21:37:04 +02:00
parent e69098d633
commit 8847131f24
42 changed files with 602 additions and 616 deletions

View File

@@ -1,6 +1,6 @@
"""Player API endpoints.""" """Player API endpoints."""
from typing import Annotated, Any from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
@@ -214,4 +214,4 @@ async def get_state(
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get player state", detail="Failed to get player state",
) from e ) from e

View File

@@ -10,12 +10,12 @@ from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType from app.models.credit_action import CreditActionType
from app.models.user import User from app.models.user import User
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.credit import CreditService, InsufficientCreditsError from app.services.credit import CreditService, InsufficientCreditsError
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.extraction_processor import extraction_processor from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService from app.services.sound_scanner import ScanResults, SoundScannerService
from app.services.vlc_player import get_vlc_player_service, VLCPlayerService from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
router = APIRouter(prefix="/sounds", tags=["sounds"]) router = APIRouter(prefix="/sounds", tags=["sounds"])
@@ -125,13 +125,13 @@ async def scan_custom_directory(
async def normalize_all_sounds( async def normalize_all_sounds(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service) SoundNormalizerService, Depends(get_sound_normalizer_service),
], ],
force: bool = Query( force: bool = Query(
False, description="Force normalization of already normalized sounds" False, description="Force normalization of already normalized sounds",
), ),
one_pass: bool | None = Query( one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)" None, description="Use one-pass normalization (overrides config)",
), ),
) -> dict[str, NormalizationResults | str]: ) -> dict[str, NormalizationResults | str]:
"""Normalize all unnormalized sounds.""" """Normalize all unnormalized sounds."""
@@ -163,13 +163,13 @@ async def normalize_sounds_by_type(
sound_type: str, sound_type: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service) SoundNormalizerService, Depends(get_sound_normalizer_service),
], ],
force: bool = Query( force: bool = Query(
False, description="Force normalization of already normalized sounds" False, description="Force normalization of already normalized sounds",
), ),
one_pass: bool | None = Query( one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)" None, description="Use one-pass normalization (overrides config)",
), ),
) -> dict[str, NormalizationResults | str]: ) -> dict[str, NormalizationResults | str]:
"""Normalize all sounds of a specific type (SDB, TTS, EXT).""" """Normalize all sounds of a specific type (SDB, TTS, EXT)."""
@@ -210,13 +210,13 @@ async def normalize_sound_by_id(
sound_id: int, sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service) SoundNormalizerService, Depends(get_sound_normalizer_service),
], ],
force: bool = Query( force: bool = Query(
False, description="Force normalization of already normalized sound" False, description="Force normalization of already normalized sound",
), ),
one_pass: bool | None = Query( one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)" None, description="Use one-pass normalization (overrides config)",
), ),
) -> dict[str, str]: ) -> dict[str, str]:
"""Normalize a specific sound by ID.""" """Normalize a specific sound by ID."""
@@ -283,7 +283,7 @@ async def create_extraction(
) )
extraction_info = await extraction_service.create_extraction( extraction_info = await extraction_service.create_extraction(
url, current_user.id url, current_user.id,
) )
# Queue the extraction for background processing # Queue the extraction for background processing
@@ -398,7 +398,7 @@ async def play_sound_with_vlc(
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
current_user.id, current_user.id,
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
{"sound_id": sound_id, "sound_name": sound.name} {"sound_id": sound_id, "sound_name": sound.name},
) )
except InsufficientCreditsError as e: except InsufficientCreditsError as e:
raise HTTPException( raise HTTPException(
@@ -408,7 +408,7 @@ async def play_sound_with_vlc(
# Play the sound using VLC # Play the sound using VLC
success = await vlc_player.play_sound(sound) success = await vlc_player.play_sound(sound)
# Deduct credits based on success # Deduct credits based on success
await credit_service.deduct_credits( await credit_service.deduct_credits(
current_user.id, current_user.id,
@@ -416,7 +416,7 @@ async def play_sound_with_vlc(
success, success,
{"sound_id": sound_id, "sound_name": sound.name}, {"sound_id": sound_id, "sound_name": sound.name},
) )
if not success: if not success:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View File

@@ -118,4 +118,4 @@ def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
Dictionary of all credit actions Dictionary of all credit actions
""" """
return CREDIT_ACTIONS.copy() return CREDIT_ACTIONS.copy()

View File

@@ -26,4 +26,4 @@ class CreditTransaction(BaseModel, table=True):
metadata_json: str | None = Field(default=None) metadata_json: str | None = Field(default=None)
# relationships # relationships
user: "User" = Relationship(back_populates="credit_transactions") user: "User" = Relationship(back_populates="credit_transactions")

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint from sqlmodel import Field, Relationship
from app.models.base import BaseModel from app.models.base import BaseModel

View File

@@ -39,7 +39,7 @@ __all__ = [
"UserResponse", "UserResponse",
# Common schemas # Common schemas
"HealthResponse", "HealthResponse",
"MessageResponse", "MessageResponse",
"StatusResponse", "StatusResponse",
# Player schemas # Player schemas
"PlayerModeRequest", "PlayerModeRequest",

View File

@@ -18,4 +18,4 @@ class StatusResponse(BaseModel):
class HealthResponse(BaseModel): class HealthResponse(BaseModel):
"""Health check response.""" """Health check response."""
status: str = Field(description="Health status") status: str = Field(description="Health status")

View File

@@ -30,10 +30,10 @@ class PlayerStateResponse(BaseModel):
status: str = Field(description="Player status (playing, paused, stopped)") status: str = Field(description="Player status (playing, paused, stopped)")
current_sound: dict[str, Any] | None = Field( current_sound: dict[str, Any] | None = Field(
None, description="Current sound information" None, description="Current sound information",
) )
playlist: dict[str, Any] | None = Field( playlist: dict[str, Any] | None = Field(
None, description="Current playlist information" None, description="Current playlist information",
) )
position: int = Field(description="Current position in milliseconds") position: int = Field(description="Current position in milliseconds")
duration: int | None = Field( duration: int | None = Field(

View File

@@ -1,6 +1,6 @@
"""Playlist schemas.""" """Playlist schemas."""
from pydantic import BaseModel, Field from pydantic import BaseModel
from app.models.playlist import Playlist from app.models.playlist import Playlist
from app.models.sound import Sound from app.models.sound import Sound

View File

@@ -30,7 +30,7 @@ class InsufficientCreditsError(Exception):
self.required = required self.required = required
self.available = available self.available = available
super().__init__( super().__init__(
f"Insufficient credits: {required} required, {available} available" f"Insufficient credits: {required} required, {available} available",
) )
@@ -138,10 +138,10 @@ class CreditService:
""" """
action = get_credit_action(action_type) action = get_credit_action(action_type)
# Only deduct if action requires success and was successful, or doesn't require success # Only deduct if action requires success and was successful, or doesn't require success
should_deduct = (action.requires_success and success) or not action.requires_success should_deduct = (action.requires_success and success) or not action.requires_success
if not should_deduct: if not should_deduct:
logger.info( logger.info(
"Skipping credit deduction for user %s: action %s failed and requires success", "Skipping credit deduction for user %s: action %s failed and requires success",
@@ -150,7 +150,7 @@ class CreditService:
) )
# Still create a transaction record for auditing # Still create a transaction record for auditing
return await self._create_transaction_record( return await self._create_transaction_record(
user_id, action, 0, success, metadata user_id, action, 0, success, metadata,
) )
session = self.db_session_factory() session = self.db_session_factory()
@@ -380,4 +380,4 @@ class CreditService:
raise ValueError(msg) raise ValueError(msg)
return user.credits return user.credits
finally: finally:
await session.close() await session.close()

View File

@@ -10,7 +10,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.extraction import Extraction
from app.models.sound import Sound from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
@@ -155,7 +154,7 @@ class ExtractionService:
# Check if extraction already exists for this service # Check if extraction already exists for this service
existing = await self.extraction_repo.get_by_service_and_id( existing = await self.extraction_repo.get_by_service_and_id(
service_info["service"], service_info["service_id"] service_info["service"], service_info["service_id"],
) )
if existing and existing.id != extraction_id: if existing and existing.id != extraction_id:
error_msg = ( error_msg = (
@@ -180,7 +179,7 @@ class ExtractionService:
# Extract audio and thumbnail # Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media( audio_file, thumbnail_file = await self._extract_media(
extraction_id, extraction_url extraction_id, extraction_url,
) )
# Move files to final locations # Move files to final locations
@@ -238,7 +237,7 @@ class ExtractionService:
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
logger.exception( logger.exception(
"Failed to process extraction %d: %s", extraction_id, error_msg "Failed to process extraction %d: %s", extraction_id, error_msg,
) )
# Update extraction with error # Update extraction with error
@@ -262,14 +261,14 @@ class ExtractionService:
} }
async def _extract_media( async def _extract_media(
self, extraction_id: int, extraction_url: str self, extraction_id: int, extraction_url: str,
) -> tuple[Path, Path | None]: ) -> tuple[Path, Path | None]:
"""Extract audio and thumbnail using yt-dlp.""" """Extract audio and thumbnail using yt-dlp."""
temp_dir = Path(settings.EXTRACTION_TEMP_DIR) temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
# Create unique filename based on extraction ID # Create unique filename based on extraction ID
output_template = str( output_template = str(
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s" temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s",
) )
# Configure yt-dlp options # Configure yt-dlp options
@@ -304,8 +303,8 @@ class ExtractionService:
# Find the extracted files # Find the extracted files
audio_files = list( audio_files = list(
temp_dir.glob( temp_dir.glob(
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}" f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}",
) ),
) )
thumbnail_files = ( thumbnail_files = (
list(temp_dir.glob(f"extraction_{extraction_id}_*.webp")) list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
@@ -342,7 +341,7 @@ class ExtractionService:
"""Move extracted files to their final locations.""" """Move extracted files to their final locations."""
# Generate clean filename based on title and service # Generate clean filename based on title and service
safe_title = self._sanitize_filename( safe_title = self._sanitize_filename(
title or f"{service or 'unknown'}_{service_id or 'unknown'}" title or f"{service or 'unknown'}_{service_id or 'unknown'}",
) )
# Move audio file # Move audio file

View File

@@ -46,9 +46,9 @@ class ExtractionProcessor:
if self.processor_task and not self.processor_task.done(): if self.processor_task and not self.processor_task.done():
try: try:
await asyncio.wait_for(self.processor_task, timeout=30.0) await asyncio.wait_for(self.processor_task, timeout=30.0)
except asyncio.TimeoutError: except TimeoutError:
logger.warning( logger.warning(
"Extraction processor did not stop gracefully, cancelling..." "Extraction processor did not stop gracefully, cancelling...",
) )
self.processor_task.cancel() self.processor_task.cancel()
try: try:
@@ -66,7 +66,7 @@ class ExtractionProcessor:
# The processor will pick it up on the next cycle # The processor will pick it up on the next cycle
else: else:
logger.warning( logger.warning(
"Extraction %d is already being processed", extraction_id "Extraction %d is already being processed", extraction_id,
) )
async def _process_queue(self) -> None: async def _process_queue(self) -> None:
@@ -81,7 +81,7 @@ class ExtractionProcessor:
try: try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0) await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
break # Shutdown requested break # Shutdown requested
except asyncio.TimeoutError: except TimeoutError:
continue # Continue processing continue # Continue processing
except Exception as e: except Exception as e:
@@ -90,7 +90,7 @@ class ExtractionProcessor:
try: try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0) await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
break # Shutdown requested break # Shutdown requested
except asyncio.TimeoutError: except TimeoutError:
continue continue
logger.info("Extraction queue processor stopped") logger.info("Extraction queue processor stopped")
@@ -125,13 +125,13 @@ class ExtractionProcessor:
# Start processing this extraction in the background # Start processing this extraction in the background
task = asyncio.create_task( task = asyncio.create_task(
self._process_single_extraction(extraction_id) self._process_single_extraction(extraction_id),
) )
task.add_done_callback( task.add_done_callback(
lambda t, eid=extraction_id: self._on_extraction_completed( lambda t, eid=extraction_id: self._on_extraction_completed(
eid, eid,
t, t,
) ),
) )
logger.info( logger.info(

View File

@@ -49,7 +49,7 @@ class PlaylistService:
if not main_playlist: if not main_playlist:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Main playlist not found. Make sure to run database seeding." detail="Main playlist not found. Make sure to run database seeding.",
) )
return main_playlist return main_playlist
@@ -179,7 +179,7 @@ class PlaylistService:
return await self.playlist_repo.get_playlist_sounds(playlist_id) return await self.playlist_repo.get_playlist_sounds(playlist_id)
async def add_sound_to_playlist( async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None,
) -> None: ) -> None:
"""Add a sound to a playlist.""" """Add a sound to a playlist."""
# Verify playlist exists # Verify playlist exists
@@ -202,11 +202,11 @@ class PlaylistService:
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position) await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info( logger.info(
"Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id "Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id,
) )
async def remove_sound_from_playlist( async def remove_sound_from_playlist(
self, playlist_id: int, sound_id: int, user_id: int self, playlist_id: int, sound_id: int, user_id: int,
) -> None: ) -> None:
"""Remove a sound from a playlist.""" """Remove a sound from a playlist."""
# Verify playlist exists # Verify playlist exists
@@ -228,7 +228,7 @@ class PlaylistService:
) )
async def reorder_playlist_sounds( async def reorder_playlist_sounds(
self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]] self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]],
) -> None: ) -> None:
"""Reorder sounds in a playlist.""" """Reorder sounds in a playlist."""
# Verify playlist exists # Verify playlist exists
@@ -262,7 +262,7 @@ class PlaylistService:
await self._unset_current_playlist(user_id) await self._unset_current_playlist(user_id)
await self._set_main_as_current(user_id) await self._set_main_as_current(user_id)
logger.info( logger.info(
"Unset current playlist and set main as current for user %s", user_id "Unset current playlist and set main as current for user %s", user_id,
) )
async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]: async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]:
@@ -290,7 +290,7 @@ class PlaylistService:
# Check if sound is already in main playlist # Check if sound is already in main playlist
if not await self.playlist_repo.is_sound_in_playlist( if not await self.playlist_repo.is_sound_in_playlist(
main_playlist.id, sound_id main_playlist.id, sound_id,
): ):
await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id) await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id)
logger.info( logger.info(

View File

@@ -141,7 +141,7 @@ class SoundNormalizerService:
ffmpeg.run(stream, quiet=True, overwrite_output=True) ffmpeg.run(stream, quiet=True, overwrite_output=True)
logger.info("One-pass normalization completed: %s", output_path) logger.info("One-pass normalization completed: %s", output_path)
except Exception as e: except Exception:
logger.exception("One-pass normalization failed for %s", input_path) logger.exception("One-pass normalization failed for %s", input_path)
raise raise
@@ -153,7 +153,7 @@ class SoundNormalizerService:
"""Normalize audio using two-pass loudnorm for better quality.""" """Normalize audio using two-pass loudnorm for better quality."""
try: try:
logger.info( logger.info(
"Starting two-pass normalization: %s -> %s", input_path, output_path "Starting two-pass normalization: %s -> %s", input_path, output_path,
) )
# First pass: analyze # First pass: analyze
@@ -193,7 +193,7 @@ class SoundNormalizerService:
json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output) json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output)
if not json_match: if not json_match:
logger.error( logger.error(
"Could not find JSON in loudnorm output: %s", analysis_output "Could not find JSON in loudnorm output: %s", analysis_output,
) )
raise ValueError("Could not extract loudnorm analysis data") raise ValueError("Could not extract loudnorm analysis data")
@@ -260,7 +260,7 @@ class SoundNormalizerService:
) )
raise raise
except Exception as e: except Exception:
logger.exception("Two-pass normalization failed for %s", input_path) logger.exception("Two-pass normalization failed for %s", input_path)
raise raise
@@ -428,7 +428,7 @@ class SoundNormalizerService:
"type": sound.type, "type": sound.type,
"is_normalized": sound.is_normalized, "is_normalized": sound.is_normalized,
"name": sound.name, "name": sound.name,
} },
) )
# Process each sound using captured data # Process each sound using captured data
@@ -476,7 +476,7 @@ class SoundNormalizerService:
"normalized_hash": None, "normalized_hash": None,
"id": sound_id, "id": sound_id,
"error": str(e), "error": str(e),
} },
) )
logger.info("Normalization completed: %s", results) logger.info("Normalization completed: %s", results)
@@ -517,7 +517,7 @@ class SoundNormalizerService:
"type": sound.type, "type": sound.type,
"is_normalized": sound.is_normalized, "is_normalized": sound.is_normalized,
"name": sound.name, "name": sound.name,
} },
) )
# Process each sound using captured data # Process each sound using captured data
@@ -565,7 +565,7 @@ class SoundNormalizerService:
"normalized_hash": None, "normalized_hash": None,
"id": sound_id, "id": sound_id,
"error": str(e), "error": str(e),
} },
) )
logger.info("Type normalization completed: %s", results) logger.info("Type normalization completed: %s", results)

View File

@@ -132,7 +132,7 @@ class SoundScannerService:
"id": None, "id": None,
"error": str(e), "error": str(e),
"changes": None, "changes": None,
} },
) )
# Delete sounds that no longer exist in directory # Delete sounds that no longer exist in directory
@@ -153,7 +153,7 @@ class SoundScannerService:
"id": sound.id, "id": sound.id,
"error": None, "error": None,
"changes": None, "changes": None,
} },
) )
except Exception as e: except Exception as e:
logger.exception("Error deleting sound %s", filename) logger.exception("Error deleting sound %s", filename)
@@ -169,7 +169,7 @@ class SoundScannerService:
"id": sound.id, "id": sound.id,
"error": str(e), "error": str(e),
"changes": None, "changes": None,
} },
) )
logger.info("Sync completed: %s", results) logger.info("Sync completed: %s", results)
@@ -219,7 +219,7 @@ class SoundScannerService:
"id": sound.id, "id": sound.id,
"error": None, "error": None,
"changes": None, "changes": None,
} },
) )
elif existing_sound.hash != file_hash: elif existing_sound.hash != file_hash:
@@ -246,7 +246,7 @@ class SoundScannerService:
"id": existing_sound.id, "id": existing_sound.id,
"error": None, "error": None,
"changes": ["hash", "duration", "size", "name"], "changes": ["hash", "duration", "size", "name"],
} },
) )
else: else:
@@ -264,7 +264,7 @@ class SoundScannerService:
"id": existing_sound.id, "id": existing_sound.id,
"error": None, "error": None,
"changes": None, "changes": None,
} },
) )
async def scan_soundboard_directory(self) -> ScanResults: async def scan_soundboard_directory(self) -> ScanResults:

View File

@@ -6,7 +6,6 @@ from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger

View File

@@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable
from typing import Any, TypeVar from typing import Any, TypeVar
from app.models.credit_action import CreditActionType from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError from app.services.credit import CreditService
F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
@@ -69,7 +69,7 @@ def requires_credits(
# Validate credits before execution # Validate credits before execution
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
user_id, action_type, metadata user_id, action_type, metadata,
) )
# Execute the function # Execute the function
@@ -85,7 +85,7 @@ def requires_credits(
finally: finally:
# Deduct credits based on success # Deduct credits based on success
await credit_service.deduct_credits( await credit_service.deduct_credits(
user_id, action_type, success, metadata user_id, action_type, success, metadata,
) )
return wrapper # type: ignore[return-value] return wrapper # type: ignore[return-value]
@@ -173,7 +173,7 @@ class CreditManager:
async def __aenter__(self) -> "CreditManager": async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits.""" """Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits( await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.metadata self.user_id, self.action_type, self.metadata,
) )
self.validated = True self.validated = True
return self return self
@@ -184,9 +184,9 @@ class CreditManager:
# If no exception occurred, consider it successful # If no exception occurred, consider it successful
success = exc_type is None and self.success success = exc_type is None and self.success
await self.credit_service.deduct_credits( await self.credit_service.deduct_credits(
self.user_id, self.action_type, success, self.metadata self.user_id, self.action_type, success, self.metadata,
) )
def mark_success(self) -> None: def mark_success(self) -> None:
"""Mark the operation as successful.""" """Mark the operation as successful."""
self.success = True self.success = True

View File

@@ -177,7 +177,7 @@ class TestApiTokenEndpoints:
# Set a token on the user # Set a token on the user
authenticated_user.api_token = "expired_token" authenticated_user.api_token = "expired_token"
authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta( authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta(
days=1 days=1,
) )
response = await authenticated_client.get("/api/v1/auth/api-token/status") response = await authenticated_client.get("/api/v1/auth/api-token/status")
@@ -209,7 +209,7 @@ class TestApiTokenEndpoints:
# Verify token exists # Verify token exists
status_response = await authenticated_client.get( status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status" "/api/v1/auth/api-token/status",
) )
assert status_response.json()["has_token"] is True assert status_response.json()["has_token"] is True
@@ -222,7 +222,7 @@ class TestApiTokenEndpoints:
# Verify token is gone # Verify token is gone
status_response = await authenticated_client.get( status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status" "/api/v1/auth/api-token/status",
) )
assert status_response.json()["has_token"] is False assert status_response.json()["has_token"] is False

View File

@@ -1,20 +1,16 @@
"""Tests for extraction API endpoints.""" """Tests for extraction API endpoints."""
from unittest.mock import AsyncMock, Mock
import pytest import pytest
import pytest_asyncio
from httpx import AsyncClient from httpx import AsyncClient
from app.models.user import User
class TestExtractionEndpoints: class TestExtractionEndpoints:
"""Test extraction API endpoints.""" """Test extraction API endpoints."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_extraction_success( async def test_create_extraction_success(
self, test_client: AsyncClient, auth_cookies: dict[str, str] self, test_client: AsyncClient, auth_cookies: dict[str, str],
): ):
"""Test successful extraction creation.""" """Test successful extraction creation."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -50,7 +46,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_processor_status_admin( async def test_get_processor_status_admin(
self, test_client: AsyncClient, admin_cookies: dict[str, str] self, test_client: AsyncClient, admin_cookies: dict[str, str],
): ):
"""Test getting processor status as admin.""" """Test getting processor status as admin."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -66,7 +62,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_processor_status_non_admin( async def test_get_processor_status_non_admin(
self, test_client: AsyncClient, auth_cookies: dict[str, str] self, test_client: AsyncClient, auth_cookies: dict[str, str],
): ):
"""Test getting processor status as non-admin user.""" """Test getting processor status as non-admin user."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -80,7 +76,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_extractions( async def test_get_user_extractions(
self, test_client: AsyncClient, auth_cookies: dict[str, str] self, test_client: AsyncClient, auth_cookies: dict[str, str],
): ):
"""Test getting user extractions.""" """Test getting user extractions."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning

View File

@@ -656,4 +656,4 @@ class TestPlayerEndpoints:
json={"volume": 100}, json={"volume": 100},
) )
assert response.status_code == 200 assert response.status_code == 200
mock_player_service.set_volume.assert_called_with(100) mock_player_service.set_volume.assert_called_with(100)

View File

@@ -1,7 +1,5 @@
"""Tests for playlist API endpoints.""" """Tests for playlist API endpoints."""
import json
from typing import Any
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@@ -96,7 +94,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
name="Main Playlist", name="Main Playlist",
@@ -107,7 +105,7 @@ class TestPlaylistEndpoints:
) )
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
response = await authenticated_client.get("/api/v1/playlists/") response = await authenticated_client.get("/api/v1/playlists/")
assert response.status_code == 200 assert response.status_code == 200
@@ -146,11 +144,11 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(main_playlist) await test_session.refresh(main_playlist)
# Extract ID before HTTP request # Extract ID before HTTP request
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id
main_playlist_name = main_playlist.name main_playlist_name = main_playlist.name
response = await authenticated_client.get("/api/v1/playlists/main") response = await authenticated_client.get("/api/v1/playlists/main")
assert response.status_code == 200 assert response.status_code == 200
@@ -189,10 +187,10 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(main_playlist) await test_session.refresh(main_playlist)
# Extract ID before HTTP request # Extract ID before HTTP request
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id
response = await authenticated_client.get("/api/v1/playlists/current") response = await authenticated_client.get("/api/v1/playlists/current")
assert response.status_code == 200 assert response.status_code == 200
@@ -256,10 +254,10 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist) test_session.add(test_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
# Extract name before HTTP request # Extract name before HTTP request
playlist_name = test_playlist.name playlist_name = test_playlist.name
payload = { payload = {
"name": playlist_name, "name": playlist_name,
"description": "Duplicate name", "description": "Duplicate name",
@@ -292,13 +290,13 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist) test_session.add(test_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
# Extract values before HTTP request # Extract values before HTTP request
playlist_id = test_playlist.id playlist_id = test_playlist.id
playlist_name = test_playlist.name playlist_name = test_playlist.name
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}" f"/api/v1/playlists/{playlist_id}",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -339,10 +337,10 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist) test_session.add(test_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
# Extract ID before HTTP request # Extract ID before HTTP request
playlist_id = test_playlist.id playlist_id = test_playlist.id
payload = { payload = {
"name": "Updated Playlist", "name": "Updated Playlist",
"description": "Updated description", "description": "Updated description",
@@ -350,7 +348,7 @@ class TestPlaylistEndpoints:
} }
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload f"/api/v1/playlists/{playlist_id}", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -379,7 +377,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
# Note: main_playlist doesn't need to be current=True for this test # Note: main_playlist doesn't need to be current=True for this test
# The service logic handles current playlist management # The service logic handles current playlist management
main_playlist = Playlist( main_playlist = Playlist(
@@ -393,14 +391,14 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
# Extract ID before HTTP request # Extract ID before HTTP request
playlist_id = test_playlist.id playlist_id = test_playlist.id
payload = {"is_current": True} payload = {"is_current": True}
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload f"/api/v1/playlists/{playlist_id}", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -429,12 +427,12 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist) test_session.add(test_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
# Extract ID before HTTP requests # Extract ID before HTTP requests
playlist_id = test_playlist.id playlist_id = test_playlist.id
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}" f"/api/v1/playlists/{playlist_id}",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -442,7 +440,7 @@ class TestPlaylistEndpoints:
# Verify playlist is deleted # Verify playlist is deleted
get_response = await authenticated_client.get( get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}" f"/api/v1/playlists/{playlist_id}",
) )
assert get_response.status_code == 404 assert get_response.status_code == 404
@@ -465,12 +463,12 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(main_playlist) await test_session.refresh(main_playlist)
# Extract ID before HTTP request # Extract ID before HTTP request
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{main_playlist_id}" f"/api/v1/playlists/{main_playlist_id}",
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -496,7 +494,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
name="Main Playlist", name="Main Playlist",
@@ -507,7 +505,7 @@ class TestPlaylistEndpoints:
) )
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
response = await authenticated_client.get("/api/v1/playlists/search/playlist") response = await authenticated_client.get("/api/v1/playlists/search/playlist")
assert response.status_code == 200 assert response.status_code == 200
@@ -541,7 +539,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
test_sound = Sound( test_sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -555,12 +553,12 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(test_sound) await test_session.refresh(test_sound)
# Extract IDs before creating playlist_sound # Extract IDs before creating playlist_sound
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound_id = test_sound.id sound_id = test_sound.id
sound_name = test_sound.name sound_name = test_sound.name
# Add sound to playlist manually for testing # Add sound to playlist manually for testing
from app.models.playlist_sound import PlaylistSound from app.models.playlist_sound import PlaylistSound
@@ -573,7 +571,7 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds" f"/api/v1/playlists/{playlist_id}/sounds",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -602,7 +600,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
test_sound = Sound( test_sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -616,15 +614,15 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(test_sound) await test_session.refresh(test_sound)
# Extract IDs before HTTP requests # Extract IDs before HTTP requests
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound_id = test_sound.id sound_id = test_sound.id
payload = {"sound_id": sound_id} payload = {"sound_id": sound_id}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -632,7 +630,7 @@ class TestPlaylistEndpoints:
# Verify sound was added # Verify sound was added
get_response = await authenticated_client.get( get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds" f"/api/v1/playlists/{playlist_id}/sounds",
) )
assert get_response.status_code == 200 assert get_response.status_code == 200
sounds = get_response.json() sounds = get_response.json()
@@ -659,7 +657,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
test_sound = Sound( test_sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -673,15 +671,15 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(test_sound) await test_session.refresh(test_sound)
# Extract IDs before HTTP request # Extract IDs before HTTP request
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound_id = test_sound.id sound_id = test_sound.id
payload = {"sound_id": sound_id, "position": 5} payload = {"sound_id": sound_id, "position": 5}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -706,7 +704,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
test_sound = Sound( test_sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -720,22 +718,22 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(test_sound) await test_session.refresh(test_sound)
# Extract IDs before HTTP requests # Extract IDs before HTTP requests
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound_id = test_sound.id sound_id = test_sound.id
payload = {"sound_id": sound_id} payload = {"sound_id": sound_id}
# Add sound first time # Add sound first time
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
# Try to add same sound again # Try to add same sound again
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 400 assert response.status_code == 400
assert "already in this playlist" in response.json()["detail"] assert "already in this playlist" in response.json()["detail"]
@@ -762,14 +760,14 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist) test_session.add(test_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
# Extract ID before HTTP request # Extract ID before HTTP request
playlist_id = test_playlist.id playlist_id = test_playlist.id
payload = {"sound_id": 99999} payload = {"sound_id": 99999}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -795,7 +793,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
test_sound = Sound( test_sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -809,20 +807,20 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(test_sound) await test_session.refresh(test_sound)
# Extract IDs before HTTP requests # Extract IDs before HTTP requests
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound_id = test_sound.id sound_id = test_sound.id
# Add sound first # Add sound first
payload = {"sound_id": sound_id} payload = {"sound_id": sound_id}
await authenticated_client.post( await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
# Remove sound # Remove sound
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}" f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -830,7 +828,7 @@ class TestPlaylistEndpoints:
# Verify sound was removed # Verify sound was removed
get_response = await authenticated_client.get( get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds" f"/api/v1/playlists/{playlist_id}/sounds",
) )
sounds = get_response.json() sounds = get_response.json()
assert len(sounds) == 0 assert len(sounds) == 0
@@ -855,7 +853,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
test_sound = Sound( test_sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -869,13 +867,13 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(test_sound) await test_session.refresh(test_sound)
# Extract IDs before HTTP request # Extract IDs before HTTP request
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound_id = test_sound.id sound_id = test_sound.id
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}" f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}",
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -901,7 +899,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
# Create multiple sounds # Create multiple sounds
sound1 = Sound(name="Sound 1", filename="sound1.mp3", type="SDB", hash="hash1") sound1 = Sound(name="Sound 1", filename="sound1.mp3", type="SDB", hash="hash1")
sound2 = Sound(name="Sound 2", filename="sound2.mp3", type="SDB", hash="hash2") sound2 = Sound(name="Sound 2", filename="sound2.mp3", type="SDB", hash="hash2")
@@ -910,7 +908,7 @@ class TestPlaylistEndpoints:
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(sound1) await test_session.refresh(sound1)
await test_session.refresh(sound2) await test_session.refresh(sound2)
# Extract IDs before HTTP requests # Extract IDs before HTTP requests
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound1_id = sound1.id sound1_id = sound1.id
@@ -929,11 +927,11 @@ class TestPlaylistEndpoints:
# Reorder sounds - use positions that don't cause constraints # Reorder sounds - use positions that don't cause constraints
# When swapping, we need to be careful about unique constraints # When swapping, we need to be careful about unique constraints
payload = { payload = {
"sound_positions": [[sound1_id, 10], [sound2_id, 5]] # Use different positions to avoid constraints "sound_positions": [[sound1_id, 10], [sound2_id, 5]], # Use different positions to avoid constraints
} }
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}/sounds/reorder", json=payload f"/api/v1/playlists/{playlist_id}/sounds/reorder", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -959,7 +957,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
name="Main Playlist", name="Main Playlist",
@@ -971,12 +969,12 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
# Extract ID before HTTP request # Extract ID before HTTP request
playlist_id = test_playlist.id playlist_id = test_playlist.id
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}/set-current" f"/api/v1/playlists/{playlist_id}/set-current",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -1001,7 +999,7 @@ class TestPlaylistEndpoints:
is_deletable=False, is_deletable=False,
) )
test_session.add(main_playlist) test_session.add(main_playlist)
# Create a current playlist for the user # Create a current playlist for the user
user_id = test_user.id user_id = test_user.id
current_playlist = Playlist( current_playlist = Playlist(
@@ -1025,7 +1023,7 @@ class TestPlaylistEndpoints:
# but something else is causing validation to fail # but something else is causing validation to fail
assert response.status_code == 422 assert response.status_code == 422
return return
assert response.status_code == 200 assert response.status_code == 200
assert "unset successfully" in response.json()["message"] assert "unset successfully" in response.json()["message"]
@@ -1055,7 +1053,7 @@ class TestPlaylistEndpoints:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
test_sound = Sound( test_sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -1069,14 +1067,14 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(test_sound) await test_session.refresh(test_sound)
# Extract IDs before HTTP requests # Extract IDs before HTTP requests
playlist_id = test_playlist.id playlist_id = test_playlist.id
sound_id = test_sound.id sound_id = test_sound.id
# Initially empty # Initially empty
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/stats" f"/api/v1/playlists/{playlist_id}/stats",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -1093,7 +1091,7 @@ class TestPlaylistEndpoints:
# Check stats again # Check stats again
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/stats" f"/api/v1/playlists/{playlist_id}/stats",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -1110,8 +1108,8 @@ class TestPlaylistEndpoints:
test_session: AsyncSession, test_session: AsyncSession,
) -> None: ) -> None:
"""Test that users can only access their own playlists.""" """Test that users can only access their own playlists."""
from app.utils.auth import JWTUtils, PasswordUtils
from app.models.plan import Plan from app.models.plan import Plan
from app.utils.auth import PasswordUtils
# Create plan within this test to avoid session issues # Create plan within this test to avoid session issues
plan = Plan( plan = Plan(
@@ -1124,10 +1122,10 @@ class TestPlaylistEndpoints:
test_session.add(plan) test_session.add(plan)
await test_session.commit() await test_session.commit()
await test_session.refresh(plan) await test_session.refresh(plan)
# Extract plan ID immediately to avoid session issues # Extract plan ID immediately to avoid session issues
plan_id = plan.id plan_id = plan.id
# Create another user with their own playlist # Create another user with their own playlist
other_user = User( other_user = User(
email="other@example.com", email="other@example.com",
@@ -1144,7 +1142,7 @@ class TestPlaylistEndpoints:
# Extract other user ID before creating playlist # Extract other user ID before creating playlist
other_user_id = other_user.id other_user_id = other_user.id
other_playlist = Playlist( other_playlist = Playlist(
user_id=other_user_id, user_id=other_user_id,
name="Other User's Playlist", name="Other User's Playlist",
@@ -1153,13 +1151,13 @@ class TestPlaylistEndpoints:
test_session.add(other_playlist) test_session.add(other_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(other_playlist) await test_session.refresh(other_playlist)
# Extract playlist ID before HTTP requests # Extract playlist ID before HTTP requests
other_playlist_id = other_playlist.id other_playlist_id = other_playlist.id
# Try to access other user's playlist # Try to access other user's playlist
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{other_playlist_id}" f"/api/v1/playlists/{other_playlist_id}",
) )
# Currently the implementation allows access to all playlists # Currently the implementation allows access to all playlists

View File

@@ -158,7 +158,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_missing_parameters( async def test_send_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User self, authenticated_client: AsyncClient, authenticated_user: User,
): ):
"""Test sending message with missing parameters.""" """Test sending message with missing parameters."""
# Missing target_user_id # Missing target_user_id
@@ -177,7 +177,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_broadcast_message_missing_parameters( async def test_broadcast_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User self, authenticated_client: AsyncClient, authenticated_user: User,
): ):
"""Test broadcasting message with missing parameters.""" """Test broadcasting message with missing parameters."""
response = await authenticated_client.post("/api/v1/socket/broadcast") response = await authenticated_client.post("/api/v1/socket/broadcast")
@@ -185,7 +185,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_invalid_user_id( async def test_send_message_invalid_user_id(
self, authenticated_client: AsyncClient, authenticated_user: User self, authenticated_client: AsyncClient, authenticated_user: User,
): ):
"""Test sending message with invalid user ID.""" """Test sending message with invalid user ID."""
response = await authenticated_client.post( response = await authenticated_client.post(

View File

@@ -66,7 +66,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -167,7 +167,7 @@ class TestSoundEndpoints:
headers = {"API-TOKEN": "admin_api_token"} headers = {"API-TOKEN": "admin_api_token"}
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -192,7 +192,7 @@ class TestSoundEndpoints:
): ):
"""Test scanning sounds when service raises an error.""" """Test scanning sounds when service raises an error."""
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.side_effect = Exception("Directory not found") mock_scan.side_effect = Exception("Directory not found")
@@ -244,7 +244,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -285,7 +285,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -307,14 +307,14 @@ class TestSoundEndpoints:
): ):
"""Test custom directory scanning with invalid path.""" """Test custom directory scanning with invalid path."""
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.side_effect = ValueError( mock_scan.side_effect = ValueError(
"Directory does not exist: /invalid/path" "Directory does not exist: /invalid/path",
) )
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/invalid/path"} "/api/v1/sounds/scan/custom", params={"directory": "/invalid/path"},
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -325,7 +325,7 @@ class TestSoundEndpoints:
async def test_scan_custom_directory_unauthenticated(self, client: AsyncClient): async def test_scan_custom_directory_unauthenticated(self, client: AsyncClient):
"""Test custom directory scanning without authentication.""" """Test custom directory scanning without authentication."""
response = await client.post( response = await client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/some/path"} "/api/v1/sounds/scan/custom", params={"directory": "/some/path"},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -377,12 +377,12 @@ class TestSoundEndpoints:
): ):
"""Test custom directory scanning when service raises an error.""" """Test custom directory scanning when service raises an error."""
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.side_effect = Exception("Permission denied") mock_scan.side_effect = Exception("Permission denied")
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/restricted/path"} "/api/v1/sounds/scan/custom", params={"directory": "/restricted/path"},
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -442,7 +442,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -480,7 +480,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -570,12 +570,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all" "/api/v1/sounds/normalize/all",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -608,12 +608,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all", params={"force": True} "/api/v1/sounds/normalize/all", params={"force": True},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -637,12 +637,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all", params={"one_pass": True} "/api/v1/sounds/normalize/all", params={"one_pass": True},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -684,7 +684,7 @@ class TestSoundEndpoints:
base_url="http://test", base_url="http://test",
) as client: ) as client:
response = await client.post( response = await client.post(
"/api/v1/sounds/normalize/all", headers=headers "/api/v1/sounds/normalize/all", headers=headers,
) )
assert response.status_code == 403 assert response.status_code == 403
@@ -702,12 +702,12 @@ class TestSoundEndpoints:
): ):
"""Test normalization when service raises an error.""" """Test normalization when service raises an error."""
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.side_effect = Exception("Normalization service failed") mock_normalize.side_effect = Exception("Normalization service failed")
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all" "/api/v1/sounds/normalize/all",
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -758,12 +758,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type" "app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/type/SDB" "/api/v1/sounds/normalize/type/SDB",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -779,7 +779,7 @@ class TestSoundEndpoints:
# Verify the service was called with correct type # Verify the service was called with correct type
mock_normalize.assert_called_once_with( mock_normalize.assert_called_once_with(
sound_type="SDB", force=False, one_pass=None sound_type="SDB", force=False, one_pass=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -790,7 +790,7 @@ class TestSoundEndpoints:
): ):
"""Test normalization with invalid sound type.""" """Test normalization with invalid sound type."""
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/type/INVALID" "/api/v1/sounds/normalize/type/INVALID",
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -814,7 +814,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type" "app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
@@ -827,7 +827,7 @@ class TestSoundEndpoints:
# Verify parameters were passed correctly # Verify parameters were passed correctly
mock_normalize.assert_called_once_with( mock_normalize.assert_called_once_with(
sound_type="TTS", force=True, one_pass=False sound_type="TTS", force=True, one_pass=False,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -866,7 +866,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -874,7 +874,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42" "/api/v1/sounds/normalize/42",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -897,12 +897,12 @@ class TestSoundEndpoints:
): ):
"""Test normalization of non-existent sound.""" """Test normalization of non-existent sound."""
with patch( with patch(
"app.repositories.sound.SoundRepository.get_by_id" "app.repositories.sound.SoundRepository.get_by_id",
) as mock_get_sound: ) as mock_get_sound:
mock_get_sound.return_value = None mock_get_sound.return_value = None
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/999" "/api/v1/sounds/normalize/999",
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -945,7 +945,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -953,7 +953,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42" "/api/v1/sounds/normalize/42",
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -997,7 +997,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -1052,7 +1052,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -1060,7 +1060,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42" "/api/v1/sounds/normalize/42",
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -1,16 +1,14 @@
"""Tests for VLC player API endpoints.""" """Tests for VLC player API endpoints."""
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock
import pytest import pytest
from httpx import AsyncClient
from fastapi import FastAPI from fastapi import FastAPI
from httpx import AsyncClient
from app.api.v1.sounds import get_credit_service, get_sound_repository, get_vlc_player
from app.models.sound import Sound from app.models.sound import Sound
from app.models.user import User from app.models.user import User
from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service
class TestVLCEndpoints: class TestVLCEndpoints:
@@ -28,7 +26,7 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_credit_service = AsyncMock() mock_credit_service = AsyncMock()
# Set up test data # Set up test data
mock_sound = Sound( mock_sound = Sound(
id=1, id=1,
@@ -39,27 +37,27 @@ class TestVLCEndpoints:
size=1024, size=1024,
hash="test_hash", hash="test_hash",
) )
# Configure mocks # Configure mocks
mock_repo.get_by_id.return_value = mock_sound mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = True mock_vlc_service.play_sound.return_value = True
# Override dependencies # Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["sound_id"] == 1 assert data["sound_id"] == 1
assert data["sound_name"] == "Test Sound" assert data["sound_name"] == "Test Sound"
assert "Test Sound" in data["message"] assert "Test Sound" in data["message"]
# Verify service calls # Verify service calls
mock_repo.get_by_id.assert_called_once_with(1) mock_repo.get_by_id.assert_called_once_with(1)
mock_vlc_service.play_sound.assert_called_once_with(mock_sound) mock_vlc_service.play_sound.assert_called_once_with(mock_sound)
@@ -81,18 +79,18 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_credit_service = AsyncMock() mock_credit_service = AsyncMock()
# Configure mocks # Configure mocks
mock_repo.get_by_id.return_value = None mock_repo.get_by_id.return_value = None
# Override dependencies # Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999") response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
assert response.status_code == 404 assert response.status_code == 404
data = response.json() data = response.json()
assert "Sound with ID 999 not found" in data["detail"] assert "Sound with ID 999 not found" in data["detail"]
@@ -114,7 +112,7 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_credit_service = AsyncMock() mock_credit_service = AsyncMock()
# Set up test data # Set up test data
mock_sound = Sound( mock_sound = Sound(
id=1, id=1,
@@ -125,21 +123,21 @@ class TestVLCEndpoints:
size=1024, size=1024,
hash="test_hash", hash="test_hash",
) )
# Configure mocks # Configure mocks
mock_repo.get_by_id.return_value = mock_sound mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = False mock_vlc_service.play_sound.return_value = False
# Override dependencies # Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
assert "Failed to launch VLC for sound playback" in data["detail"] assert "Failed to launch VLC for sound playback" in data["detail"]
@@ -161,18 +159,18 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_credit_service = AsyncMock() mock_credit_service = AsyncMock()
# Configure mocks # Configure mocks
mock_repo.get_by_id.side_effect = Exception("Database error") mock_repo.get_by_id.side_effect = Exception("Database error")
# Override dependencies # Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
assert "Failed to play sound" in data["detail"] assert "Failed to play sound" in data["detail"]
@@ -209,13 +207,13 @@ class TestVLCEndpoints:
"message": "Killed 3 VLC processes", "message": "Killed 3 VLC processes",
} }
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency # Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -223,7 +221,7 @@ class TestVLCEndpoints:
assert data["processes_killed"] == 3 assert data["processes_killed"] == 3
assert data["processes_remaining"] == 0 assert data["processes_remaining"] == 0
assert "Killed 3 VLC processes" in data["message"] assert "Killed 3 VLC processes" in data["message"]
# Verify service call # Verify service call
mock_vlc_service.stop_all_vlc_instances.assert_called_once() mock_vlc_service.stop_all_vlc_instances.assert_called_once()
finally: finally:
@@ -247,13 +245,13 @@ class TestVLCEndpoints:
"message": "No VLC processes found", "message": "No VLC processes found",
} }
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency # Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -282,13 +280,13 @@ class TestVLCEndpoints:
"message": "Killed 2 VLC processes", "message": "Killed 2 VLC processes",
} }
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency # Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
@@ -317,13 +315,13 @@ class TestVLCEndpoints:
"message": "Failed to stop VLC processes", "message": "Failed to stop VLC processes",
} }
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency # Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is False assert data["success"] is False
@@ -344,13 +342,13 @@ class TestVLCEndpoints:
# Set up mock to raise an exception # Set up mock to raise an exception
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_vlc_service.stop_all_vlc_instances.side_effect = Exception("Service error") mock_vlc_service.stop_all_vlc_instances.side_effect = Exception("Service error")
# Override dependency # Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try: try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
assert "Failed to stop VLC instances" in data["detail"] assert "Failed to stop VLC instances" in data["detail"]
@@ -379,7 +377,7 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_credit_service = AsyncMock() mock_credit_service = AsyncMock()
# Set up test data # Set up test data
mock_sound = Sound( mock_sound = Sound(
id=1, id=1,
@@ -390,21 +388,21 @@ class TestVLCEndpoints:
size=512, size=512,
hash="admin_hash", hash="admin_hash",
) )
# Configure mocks # Configure mocks
mock_repo.get_by_id.return_value = mock_sound mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = True mock_vlc_service.play_sound.return_value = True
# Override dependencies # Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try: try:
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["sound_name"] == "Admin Test Sound" assert data["sound_name"] == "Admin Test Sound"
@@ -424,17 +422,17 @@ class TestVLCEndpoints:
"message": "Killed 1 VLC processes", "message": "Killed 1 VLC processes",
} }
mock_vlc_service_2.stop_all_vlc_instances.return_value = mock_result mock_vlc_service_2.stop_all_vlc_instances.return_value = mock_result
# Override dependency for stop-all test # Override dependency for stop-all test
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service_2 test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service_2
try: try:
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
assert data["processes_killed"] == 1 assert data["processes_killed"] == 1
finally: finally:
# Clean up dependency override # Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None) test_app.dependency_overrides.pop(get_vlc_player, None)

View File

@@ -13,10 +13,8 @@ from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
from app.utils.auth import JWTUtils, PasswordUtils from app.utils.auth import JWTUtils, PasswordUtils

View File

@@ -49,7 +49,7 @@ class TestApiTokenDependencies:
assert result == test_user assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with( mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123" "test_api_token_123",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -135,11 +135,11 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_service_exception( async def test_get_current_user_api_token_service_exception(
self, mock_auth_service self, mock_auth_service,
): ):
"""Test API token authentication with service exception.""" """Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception( mock_auth_service.get_user_by_api_token.side_effect = Exception(
"Database error" "Database error",
) )
api_token_header = "test_token" api_token_header = "test_token"
@@ -170,7 +170,7 @@ class TestApiTokenDependencies:
assert result == test_user assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with( mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123" "test_api_token_123",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -184,7 +184,7 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_no_expiry_never_expires( async def test_api_token_no_expiry_never_expires(
self, mock_auth_service, test_user self, mock_auth_service, test_user,
): ):
"""Test API token with no expiry date never expires.""" """Test API token with no expiry date never expires."""
test_user.api_token_expires_at = None test_user.api_token_expires_at = None

View File

@@ -42,7 +42,7 @@ class TestCreditTransactionRepository:
"""Create test credit transactions.""" """Create test credit transactions."""
transactions = [] transactions = []
user_id = test_user_id user_id = test_user_id
# Create various types of transactions # Create various types of transactions
transaction_data = [ transaction_data = [
{ {
@@ -105,9 +105,8 @@ class TestCreditTransactionRepository:
ensure_plans: tuple[Any, ...], # noqa: ARG002 ensure_plans: tuple[Any, ...], # noqa: ARG002
) -> AsyncGenerator[CreditTransaction, None]: ) -> AsyncGenerator[CreditTransaction, None]:
"""Create a transaction for a different user.""" """Create a transaction for a different user."""
from app.models.plan import Plan
from app.repositories.user import UserRepository from app.repositories.user import UserRepository
# Create another user # Create another user
user_repo = UserRepository(test_session) user_repo = UserRepository(test_session)
other_user_data = { other_user_data = {
@@ -134,7 +133,7 @@ class TestCreditTransactionRepository:
test_session.add(transaction) test_session.add(transaction)
await test_session.commit() await test_session.commit()
await test_session.refresh(transaction) await test_session.refresh(transaction)
yield transaction yield transaction
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -178,7 +177,7 @@ class TestCreditTransactionRepository:
assert len(transactions) == 4 assert len(transactions) == 4
# Should be ordered by created_at desc (newest first) # Should be ordered by created_at desc (newest first)
assert all(t.user_id == test_user_id for t in transactions) assert all(t.user_id == test_user_id for t in transactions)
# Should not include other user's transaction # Should not include other user's transaction
other_user_ids = [t.user_id for t in transactions] other_user_ids = [t.user_id for t in transactions]
assert other_user_transaction.user_id not in other_user_ids assert other_user_transaction.user_id not in other_user_ids
@@ -193,13 +192,13 @@ class TestCreditTransactionRepository:
"""Test getting transactions by user ID with pagination.""" """Test getting transactions by user ID with pagination."""
# Get first 2 transactions # Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id( first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0 test_user_id, limit=2, offset=0,
) )
assert len(first_page) == 2 assert len(first_page) == 2
# Get next 2 transactions # Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id( second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2 test_user_id, limit=2, offset=2,
) )
assert len(second_page) == 2 assert len(second_page) == 2
@@ -216,17 +215,17 @@ class TestCreditTransactionRepository:
) -> None: ) -> None:
"""Test getting transactions by action type.""" """Test getting transactions by action type."""
vlc_transactions = await credit_transaction_repository.get_by_action_type( vlc_transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound" "vlc_play_sound",
) )
# Should return 2 VLC transactions (1 successful, 1 failed) # Should return 2 VLC transactions (1 successful, 1 failed)
assert len(vlc_transactions) >= 2 assert len(vlc_transactions) >= 2
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions) assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
extraction_transactions = await credit_transaction_repository.get_by_action_type( extraction_transactions = await credit_transaction_repository.get_by_action_type(
"audio_extraction" "audio_extraction",
) )
# Should return 1 extraction transaction # Should return 1 extraction transaction
assert len(extraction_transactions) >= 1 assert len(extraction_transactions) >= 1
assert all(t.action_type == "audio_extraction" for t in extraction_transactions) assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
@@ -240,14 +239,14 @@ class TestCreditTransactionRepository:
"""Test getting transactions by action type with pagination.""" """Test getting transactions by action type with pagination."""
# Test with limit # Test with limit
transactions = await credit_transaction_repository.get_by_action_type( transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1 "vlc_play_sound", limit=1,
) )
assert len(transactions) == 1 assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound" assert transactions[0].action_type == "vlc_play_sound"
# Test with offset # Test with offset
transactions = await credit_transaction_repository.get_by_action_type( transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1 "vlc_play_sound", limit=1, offset=1,
) )
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@@ -275,7 +274,7 @@ class TestCreditTransactionRepository:
) -> None: ) -> None:
"""Test getting successful transactions filtered by user.""" """Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions( successful_transactions = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id user_id=test_user_id,
) )
# Should only return successful transactions for test_user # Should only return successful transactions for test_user
@@ -294,14 +293,14 @@ class TestCreditTransactionRepository:
"""Test getting successful transactions with pagination.""" """Test getting successful transactions with pagination."""
# Get first 2 successful transactions # Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions( first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0 user_id=test_user_id, limit=2, offset=0,
) )
assert len(first_page) == 2 assert len(first_page) == 2
assert all(t.success is True for t in first_page) assert all(t.success is True for t in first_page)
# Get next successful transaction # Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions( second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2 user_id=test_user_id, limit=2, offset=2,
) )
assert len(second_page) == 1 # Should be 1 remaining assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page) assert all(t.success is True for t in second_page)
@@ -363,7 +362,7 @@ class TestCreditTransactionRepository:
} }
updated_transaction = await credit_transaction_repository.update( updated_transaction = await credit_transaction_repository.update(
transaction, update_data transaction, update_data,
) )
assert updated_transaction.id == transaction.id assert updated_transaction.id == transaction.id
@@ -413,7 +412,7 @@ class TestCreditTransactionRepository:
) -> None: ) -> None:
"""Test that transactions are ordered by created_at desc.""" """Test that transactions are ordered by created_at desc."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id) transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should be ordered by created_at desc (newest first) # Should be ordered by created_at desc (newest first)
for i in range(len(transactions) - 1): for i in range(len(transactions) - 1):
assert transactions[i].created_at >= transactions[i + 1].created_at assert transactions[i].created_at >= transactions[i + 1].created_at

View File

@@ -52,7 +52,7 @@ class TestExtractionRepository:
assert result.service_id == extraction_data["service_id"] assert result.service_id == extraction_data["service_id"]
assert result.title == extraction_data["title"] assert result.title == extraction_data["title"]
assert result.status == extraction_data["status"] assert result.status == extraction_data["status"]
# Verify session methods were called # Verify session methods were called
extraction_repo.session.add.assert_called_once() extraction_repo.session.add.assert_called_once()
extraction_repo.session.commit.assert_called_once() extraction_repo.session.commit.assert_called_once()

View File

@@ -151,10 +151,10 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
# Extract user ID immediately after refresh # Extract user ID immediately after refresh
user_id = user.id user_id = user.id
# Create test playlist for this user # Create test playlist for this user
playlist = Playlist( playlist = Playlist(
user_id=user_id, user_id=user_id,
@@ -167,10 +167,10 @@ class TestPlaylistRepository:
) )
test_session.add(playlist) test_session.add(playlist)
await test_session.commit() await test_session.commit()
# Test the repository method # Test the repository method
playlists = await playlist_repository.get_by_user_id(user_id) playlists = await playlist_repository.get_by_user_id(user_id)
# Should only return user's playlists, not the main playlist (user_id=None) # Should only return user's playlists, not the main playlist (user_id=None)
assert len(playlists) == 1 assert len(playlists) == 1
assert playlists[0].name == "Test Playlist" assert playlists[0].name == "Test Playlist"
@@ -194,13 +194,13 @@ class TestPlaylistRepository:
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(main_playlist) await test_session.refresh(main_playlist)
# Extract ID before async call # Extract ID before async call
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id
# Test the repository method # Test the repository method
playlist = await playlist_repository.get_main_playlist() playlist = await playlist_repository.get_main_playlist()
assert playlist is not None assert playlist is not None
assert playlist.id == main_playlist_id assert playlist.id == main_playlist_id
assert playlist.is_main is True assert playlist.is_main is True
@@ -227,13 +227,13 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
# Extract user ID immediately after refresh # Extract user ID immediately after refresh
user_id = user.id user_id = user.id
# Test the repository method - should return None when no current playlist # Test the repository method - should return None when no current playlist
playlist = await playlist_repository.get_current_playlist(user_id) playlist = await playlist_repository.get_current_playlist(user_id)
# Should return None since no user playlist is marked as current # Should return None since no user playlist is marked as current
assert playlist is None assert playlist is None
@@ -319,10 +319,10 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
# Extract user ID immediately after refresh # Extract user ID immediately after refresh
user_id = user.id user_id = user.id
# Create test playlist # Create test playlist
test_playlist = Playlist( test_playlist = Playlist(
user_id=user_id, user_id=user_id,
@@ -334,7 +334,7 @@ class TestPlaylistRepository:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
# Create main playlist # Create main playlist
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
@@ -346,7 +346,7 @@ class TestPlaylistRepository:
) )
test_session.add(main_playlist) test_session.add(main_playlist)
await test_session.commit() await test_session.commit()
# Search for all playlists (no user filter) # Search for all playlists (no user filter)
all_results = await playlist_repository.search_by_name("playlist") all_results = await playlist_repository.search_by_name("playlist")
assert len(all_results) >= 2 # Should include both user and main playlists assert len(all_results) >= 2 # Should include both user and main playlists
@@ -382,7 +382,7 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
# Create test playlist # Create test playlist
playlist = Playlist( playlist = Playlist(
user_id=user.id, user_id=user.id,
@@ -394,7 +394,7 @@ class TestPlaylistRepository:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
# Create test sound # Create test sound
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
@@ -409,14 +409,14 @@ class TestPlaylistRepository:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async call # Extract IDs before async call
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
# Test the repository method # Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist( playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
assert playlist_sound.playlist_id == playlist_id assert playlist_sound.playlist_id == playlist_id
@@ -445,10 +445,10 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
# Extract user ID immediately after refresh # Extract user ID immediately after refresh
user_id = user.id user_id = user.id
# Create test playlist # Create test playlist
playlist = Playlist( playlist = Playlist(
user_id=user_id, user_id=user_id,
@@ -460,7 +460,7 @@ class TestPlaylistRepository:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
# Create test sound # Create test sound
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
@@ -475,14 +475,14 @@ class TestPlaylistRepository:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async call # Extract IDs before async call
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
# Test the repository method # Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist( playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id, position=5 playlist_id, sound_id, position=5,
) )
assert playlist_sound.position == 5 assert playlist_sound.position == 5
@@ -509,9 +509,9 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
user_id = user.id user_id = user.id
playlist = Playlist( playlist = Playlist(
user_id=user_id, user_id=user_id,
name="Test Playlist", name="Test Playlist",
@@ -522,7 +522,7 @@ class TestPlaylistRepository:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -536,7 +536,7 @@ class TestPlaylistRepository:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -546,17 +546,17 @@ class TestPlaylistRepository:
# Verify it was added # Verify it was added
assert await playlist_repository.is_sound_in_playlist( assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
# Remove the sound # Remove the sound
await playlist_repository.remove_sound_from_playlist( await playlist_repository.remove_sound_from_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
# Verify it was removed # Verify it was removed
assert not await playlist_repository.is_sound_in_playlist( assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -581,9 +581,9 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
user_id = user.id user_id = user.id
playlist = Playlist( playlist = Playlist(
user_id=user_id, user_id=user_id,
name="Test Playlist", name="Test Playlist",
@@ -594,7 +594,7 @@ class TestPlaylistRepository:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -608,7 +608,7 @@ class TestPlaylistRepository:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -647,9 +647,9 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
user_id = user.id user_id = user.id
playlist = Playlist( playlist = Playlist(
user_id=user_id, user_id=user_id,
name="Test Playlist", name="Test Playlist",
@@ -660,7 +660,7 @@ class TestPlaylistRepository:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -674,7 +674,7 @@ class TestPlaylistRepository:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -712,9 +712,9 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
user_id = user.id user_id = user.id
playlist = Playlist( playlist = Playlist(
user_id=user_id, user_id=user_id,
name="Test Playlist", name="Test Playlist",
@@ -725,7 +725,7 @@ class TestPlaylistRepository:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -739,14 +739,14 @@ class TestPlaylistRepository:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
# Initially not in playlist # Initially not in playlist
assert not await playlist_repository.is_sound_in_playlist( assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
# Add sound # Add sound
@@ -754,7 +754,7 @@ class TestPlaylistRepository:
# Now in playlist # Now in playlist
assert await playlist_repository.is_sound_in_playlist( assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -779,9 +779,9 @@ class TestPlaylistRepository:
test_session.add(user) test_session.add(user)
await test_session.commit() await test_session.commit()
await test_session.refresh(user) await test_session.refresh(user)
user_id = user.id user_id = user.id
playlist = Playlist( playlist = Playlist(
user_id=user_id, user_id=user_id,
name="Test Playlist", name="Test Playlist",
@@ -801,7 +801,7 @@ class TestPlaylistRepository:
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound1) await test_session.refresh(sound1)
await test_session.refresh(sound2) await test_session.refresh(sound2)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound1_id = sound1.id sound1_id = sound1.id
@@ -809,16 +809,16 @@ class TestPlaylistRepository:
# Add sounds to playlist # Add sounds to playlist
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound1_id, position=0 playlist_id, sound1_id, position=0,
) )
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound2_id, position=1 playlist_id, sound2_id, position=1,
) )
# Reorder sounds - use different positions to avoid constraint issues # Reorder sounds - use different positions to avoid constraint issues
sound_positions = [(sound1_id, 10), (sound2_id, 5)] sound_positions = [(sound1_id, 10), (sound2_id, 5)]
await playlist_repository.reorder_playlist_sounds( await playlist_repository.reorder_playlist_sounds(
playlist_id, sound_positions playlist_id, sound_positions,
) )
# Verify new order # Verify new order

View File

@@ -359,7 +359,7 @@ class TestSoundRepository:
"""Test creating sound with duplicate hash should fail.""" """Test creating sound with duplicate hash should fail."""
# Store the hash to avoid lazy loading issues # Store the hash to avoid lazy loading issues
original_hash = test_sound.hash original_hash = test_sound.hash
duplicate_sound_data = { duplicate_sound_data = {
"name": "Duplicate Hash Sound", "name": "Duplicate Hash Sound",
"filename": "duplicate.mp3", "filename": "duplicate.mp3",
@@ -373,4 +373,4 @@ class TestSoundRepository:
# Should fail due to unique constraint on hash # Should fail due to unique constraint on hash
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
await sound_repository.create(duplicate_sound_data) await sound_repository.create(duplicate_sound_data)

View File

@@ -60,7 +60,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by provider user ID when it exists.""" """Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id( oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456" "google", "google_123456",
) )
assert oauth is not None assert oauth is not None
@@ -76,7 +76,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist.""" """Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id( oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id" "google", "nonexistent_id",
) )
assert oauth is None assert oauth is None
@@ -90,7 +90,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by user ID and provider when it exists.""" """Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider( oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google" test_user_id, "google",
) )
assert oauth is not None assert oauth is not None
@@ -106,7 +106,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist.""" """Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider( oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github" test_user_id, "github",
) )
assert oauth is None assert oauth is None
@@ -183,7 +183,7 @@ class TestUserOauthRepository:
# Verify it's deleted by trying to find it # Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id( deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456" "twitter", "twitter_456",
) )
assert deleted_oauth is None assert deleted_oauth is None
@@ -240,10 +240,10 @@ class TestUserOauthRepository:
# Verify both exist by querying back from database # Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider( found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google" test_user_id, "google",
) )
found_github = await user_oauth_repository.get_by_user_id_and_provider( found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github" test_user_id, "github",
) )
assert found_google is not None assert found_google is not None
@@ -257,13 +257,13 @@ class TestUserOauthRepository:
# Verify we can also find them by provider_user_id # Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id( found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1" "google", "google_user_1",
) )
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id( found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1" "github", "github_user_1",
) )
assert found_google_by_provider is not None assert found_google_by_provider is not None
assert found_github_by_provider is not None assert found_github_by_provider is not None
assert found_google_by_provider.user_id == test_user_id assert found_google_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id assert found_github_by_provider.user_id == test_user_id

View File

@@ -1,7 +1,7 @@
"""Tests for credit service.""" """Tests for credit service."""
import json import json
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -42,14 +42,14 @@ class TestCreditService:
async def test_check_credits_sufficient(self, credit_service, sample_user): async def test_check_credits_sufficient(self, credit_service, sample_user):
"""Test checking credits when user has sufficient credits.""" """Test checking credits when user has sufficient credits."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND) result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is True assert result is True
mock_repo.get_by_id.assert_called_once_with(1) mock_repo.get_by_id.assert_called_once_with(1)
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -66,14 +66,14 @@ class TestCreditService:
credits=0, # No credits credits=0, # No credits
plan_id=1, plan_id=1,
) )
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user mock_repo.get_by_id.return_value = poor_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND) result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is False assert result is False
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -81,14 +81,14 @@ class TestCreditService:
async def test_check_credits_user_not_found(self, credit_service): async def test_check_credits_user_not_found(self, credit_service):
"""Test checking credits when user is not found.""" """Test checking credits when user is not found."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None mock_repo.get_by_id.return_value = None
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND) result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND)
assert result is False assert result is False
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -96,16 +96,16 @@ class TestCreditService:
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user): async def test_validate_and_reserve_credits_success(self, credit_service, sample_user):
"""Test successful credit validation and reservation.""" """Test successful credit validation and reservation."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
user, action = await credit_service.validate_and_reserve_credits( user, action = await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND 1, CreditActionType.VLC_PLAY_SOUND,
) )
assert user == sample_user assert user == sample_user
assert action.action_type == CreditActionType.VLC_PLAY_SOUND assert action.action_type == CreditActionType.VLC_PLAY_SOUND
assert action.cost == 1 assert action.cost == 1
@@ -123,17 +123,17 @@ class TestCreditService:
credits=0, credits=0,
plan_id=1, plan_id=1,
) )
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user mock_repo.get_by_id.return_value = poor_user
with pytest.raises(InsufficientCreditsError) as exc_info: with pytest.raises(InsufficientCreditsError) as exc_info:
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND 1, CreditActionType.VLC_PLAY_SOUND,
) )
assert exc_info.value.required == 1 assert exc_info.value.required == 1
assert exc_info.value.available == 0 assert exc_info.value.available == 0
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -142,42 +142,42 @@ class TestCreditService:
async def test_validate_and_reserve_credits_user_not_found(self, credit_service): async def test_validate_and_reserve_credits_user_not_found(self, credit_service):
"""Test credit validation when user is not found.""" """Test credit validation when user is not found."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"): with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
999, CreditActionType.VLC_PLAY_SOUND 999, CreditActionType.VLC_PLAY_SOUND,
) )
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deduct_credits_success(self, credit_service, sample_user): async def test_deduct_credits_success(self, credit_service, sample_user):
"""Test successful credit deduction.""" """Test successful credit deduction."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits( transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"} 1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"},
) )
# Verify user credits were updated # Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 9}) mock_repo.update.assert_called_once_with(sample_user, {"credits": 9})
# Verify transaction was created # Verify transaction was created
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
# Verify socket event was emitted # Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with( mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", { "1", "user_credits_changed", {
@@ -187,9 +187,9 @@ class TestCreditService:
"credits_deducted": 1, "credits_deducted": 1,
"action_type": "vlc_play_sound", "action_type": "vlc_play_sound",
"success": True, "success": True,
} },
) )
# Check transaction details # Check transaction details
added_transaction = mock_session.add.call_args[0][0] added_transaction = mock_session.add.call_args[0][0]
assert isinstance(added_transaction, CreditTransaction) assert isinstance(added_transaction, CreditTransaction)
@@ -205,28 +205,28 @@ class TestCreditService:
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user): async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user):
"""Test credit deduction when action failed but requires success.""" """Test credit deduction when action failed but requires success."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits( transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed 1, CreditActionType.VLC_PLAY_SOUND, False, # Action failed
) )
# Verify user credits were NOT updated (action requires success) # Verify user credits were NOT updated (action requires success)
mock_repo.update.assert_not_called() mock_repo.update.assert_not_called()
# Verify transaction was still created for auditing # Verify transaction was still created for auditing
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
# Verify no socket event was emitted since no credits were actually deducted # Verify no socket event was emitted since no credits were actually deducted
mock_socket_manager.send_to_user.assert_not_called() mock_socket_manager.send_to_user.assert_not_called()
# Check transaction details # Check transaction details
added_transaction = mock_session.add.call_args[0][0] added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 0 # No deduction for failed action assert added_transaction.amount == 0 # No deduction for failed action
@@ -246,22 +246,22 @@ class TestCreditService:
credits=0, credits=0,
plan_id=1, plan_id=1,
) )
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user mock_repo.get_by_id.return_value = poor_user
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
with pytest.raises(InsufficientCreditsError): with pytest.raises(InsufficientCreditsError):
await credit_service.deduct_credits( await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True 1, CreditActionType.VLC_PLAY_SOUND, True,
) )
# Verify no socket event was emitted since credits could not be deducted # Verify no socket event was emitted since credits could not be deducted
mock_socket_manager.send_to_user.assert_not_called() mock_socket_manager.send_to_user.assert_not_called()
mock_session.rollback.assert_called_once() mock_session.rollback.assert_called_once()
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -269,25 +269,25 @@ class TestCreditService:
async def test_add_credits(self, credit_service, sample_user): async def test_add_credits(self, credit_service, sample_user):
"""Test adding credits to user account.""" """Test adding credits to user account."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.add_credits( transaction = await credit_service.add_credits(
1, 5, "Bonus credits", {"reason": "signup"} 1, 5, "Bonus credits", {"reason": "signup"},
) )
# Verify user credits were updated # Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 15}) mock_repo.update.assert_called_once_with(sample_user, {"credits": 15})
# Verify transaction was created # Verify transaction was created
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
# Verify socket event was emitted # Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with( mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", { "1", "user_credits_changed", {
@@ -297,9 +297,9 @@ class TestCreditService:
"credits_added": 5, "credits_added": 5,
"description": "Bonus credits", "description": "Bonus credits",
"success": True, "success": True,
} },
) )
# Check transaction details # Check transaction details
added_transaction = mock_session.add.call_args[0][0] added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 5 assert added_transaction.amount == 5
@@ -312,7 +312,7 @@ class TestCreditService:
"""Test adding invalid amount of credits.""" """Test adding invalid amount of credits."""
with pytest.raises(ValueError, match="Amount must be positive"): with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, 0, "Invalid") await credit_service.add_credits(1, 0, "Invalid")
with pytest.raises(ValueError, match="Amount must be positive"): with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, -5, "Invalid") await credit_service.add_credits(1, -5, "Invalid")
@@ -320,14 +320,14 @@ class TestCreditService:
async def test_get_user_balance(self, credit_service, sample_user): async def test_get_user_balance(self, credit_service, sample_user):
"""Test getting user credit balance.""" """Test getting user credit balance."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
balance = await credit_service.get_user_balance(1) balance = await credit_service.get_user_balance(1)
assert balance == 10 assert balance == 10
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -335,15 +335,15 @@ class TestCreditService:
async def test_get_user_balance_user_not_found(self, credit_service): async def test_get_user_balance_user_not_found(self, credit_service):
"""Test getting balance for non-existent user.""" """Test getting balance for non-existent user."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class: with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"): with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.get_user_balance(999) await credit_service.get_user_balance(999)
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -355,4 +355,4 @@ class TestInsufficientCreditsError:
error = InsufficientCreditsError(5, 2) error = InsufficientCreditsError(5, 2)
assert error.required == 5 assert error.required == 5
assert error.available == 2 assert error.available == 2
assert str(error) == "Insufficient credits: 5 required, 2 available" assert str(error) == "Insufficient credits: 5 required, 2 available"

View File

@@ -53,7 +53,7 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_youtube( async def test_detect_service_info_youtube(
self, mock_ydl_class, extraction_service self, mock_ydl_class, extraction_service,
): ):
"""Test service detection for YouTube.""" """Test service detection for YouTube."""
mock_ydl = Mock() mock_ydl = Mock()
@@ -67,7 +67,7 @@ class TestExtractionService:
} }
result = await extraction_service._detect_service_info( result = await extraction_service._detect_service_info(
"https://www.youtube.com/watch?v=test123" "https://www.youtube.com/watch?v=test123",
) )
assert result is not None assert result is not None
@@ -78,7 +78,7 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_failure( async def test_detect_service_info_failure(
self, mock_ydl_class, extraction_service self, mock_ydl_class, extraction_service,
): ):
"""Test service detection failure.""" """Test service detection failure."""
mock_ydl = Mock() mock_ydl = Mock()
@@ -106,7 +106,7 @@ class TestExtractionService:
status="pending", status="pending",
) )
extraction_service.extraction_repo.create = AsyncMock( extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
result = await extraction_service.create_extraction(url, user_id) result = await extraction_service.create_extraction(url, user_id)
@@ -134,7 +134,7 @@ class TestExtractionService:
status="pending", status="pending",
) )
extraction_service.extraction_repo.create = AsyncMock( extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
result = await extraction_service.create_extraction(url, user_id) result = await extraction_service.create_extraction(url, user_id)
@@ -160,7 +160,7 @@ class TestExtractionService:
status="pending", status="pending",
) )
extraction_service.extraction_repo.create = AsyncMock( extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
result = await extraction_service.create_extraction(url, user_id) result = await extraction_service.create_extraction(url, user_id)
@@ -186,11 +186,11 @@ class TestExtractionService:
) )
extraction_service.extraction_repo.get_by_id = AsyncMock( extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
extraction_service.extraction_repo.update = AsyncMock() extraction_service.extraction_repo.update = AsyncMock()
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock( extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
return_value=None return_value=None,
) )
# Mock service detection # Mock service detection
@@ -202,14 +202,14 @@ class TestExtractionService:
with ( with (
patch.object( patch.object(
extraction_service, "_detect_service_info", return_value=service_info extraction_service, "_detect_service_info", return_value=service_info,
), ),
patch.object(extraction_service, "_extract_media") as mock_extract, patch.object(extraction_service, "_extract_media") as mock_extract,
patch.object( patch.object(
extraction_service, "_move_files_to_final_location" extraction_service, "_move_files_to_final_location",
) as mock_move, ) as mock_move,
patch.object( patch.object(
extraction_service, "_create_sound_record" extraction_service, "_create_sound_record",
) as mock_create_sound, ) as mock_create_sound,
patch.object(extraction_service, "_normalize_sound") as mock_normalize, patch.object(extraction_service, "_normalize_sound") as mock_normalize,
patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist, patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist,
@@ -223,7 +223,7 @@ class TestExtractionService:
# Verify service detection was called # Verify service detection was called
extraction_service._detect_service_info.assert_called_once_with( extraction_service._detect_service_info.assert_called_once_with(
"https://www.youtube.com/watch?v=test123" "https://www.youtube.com/watch?v=test123",
) )
# Verify extraction was updated with service info # Verify extraction was updated with service info
@@ -289,15 +289,15 @@ class TestExtractionService:
with ( with (
patch( patch(
"app.services.extraction.get_audio_duration", return_value=240000 "app.services.extraction.get_audio_duration", return_value=240000,
), ),
patch("app.services.extraction.get_file_size", return_value=1024), patch("app.services.extraction.get_file_size", return_value=1024),
patch( patch(
"app.services.extraction.get_file_hash", return_value="test_hash" "app.services.extraction.get_file_hash", return_value="test_hash",
), ),
): ):
extraction_service.sound_repo.create = AsyncMock( extraction_service.sound_repo.create = AsyncMock(
return_value=mock_sound return_value=mock_sound,
) )
result = await extraction_service._create_sound_record( result = await extraction_service._create_sound_record(
@@ -336,7 +336,7 @@ class TestExtractionService:
mock_normalizer = Mock() mock_normalizer = Mock()
mock_normalizer.normalize_sound = AsyncMock( mock_normalizer.normalize_sound = AsyncMock(
return_value={"status": "normalized"} return_value={"status": "normalized"},
) )
with patch( with patch(
@@ -368,7 +368,7 @@ class TestExtractionService:
mock_normalizer = Mock() mock_normalizer = Mock()
mock_normalizer.normalize_sound = AsyncMock( mock_normalizer.normalize_sound = AsyncMock(
return_value={"status": "error", "error": "Test error"} return_value={"status": "error", "error": "Test error"},
) )
with patch( with patch(
@@ -395,7 +395,7 @@ class TestExtractionService:
) )
extraction_service.extraction_repo.get_by_id = AsyncMock( extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=extraction return_value=extraction,
) )
result = await extraction_service.get_extraction_by_id(1) result = await extraction_service.get_extraction_by_id(1)
@@ -443,7 +443,7 @@ class TestExtractionService:
] ]
extraction_service.extraction_repo.get_by_user = AsyncMock( extraction_service.extraction_repo.get_by_user = AsyncMock(
return_value=extractions return_value=extractions,
) )
result = await extraction_service.get_user_extractions(1) result = await extraction_service.get_user_extractions(1)
@@ -470,7 +470,7 @@ class TestExtractionService:
] ]
extraction_service.extraction_repo.get_pending_extractions = AsyncMock( extraction_service.extraction_repo.get_pending_extractions = AsyncMock(
return_value=pending_extractions return_value=pending_extractions,
) )
result = await extraction_service.get_pending_extractions() result = await extraction_service.get_pending_extractions()

View File

@@ -1,6 +1,5 @@
"""Tests for extraction background processor.""" """Tests for extraction background processor."""
import asyncio
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
@@ -30,7 +29,7 @@ class TestExtractionProcessor:
"""Test starting and stopping the processor.""" """Test starting and stopping the processor."""
# Mock the _process_queue method to avoid actual processing # Mock the _process_queue method to avoid actual processing
with patch.object( with patch.object(
processor, "_process_queue", new_callable=AsyncMock processor, "_process_queue", new_callable=AsyncMock,
) as mock_process: ) as mock_process:
# Start the processor # Start the processor
await processor.start() await processor.start()
@@ -138,12 +137,12 @@ class TestExtractionProcessor:
# Mock the extraction service # Mock the extraction service
mock_service = Mock() mock_service = Mock()
mock_service.process_extraction = AsyncMock( mock_service.process_extraction = AsyncMock(
return_value={"status": "completed", "id": extraction_id} return_value={"status": "completed", "id": extraction_id},
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -168,7 +167,7 @@ class TestExtractionProcessor:
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -193,12 +192,12 @@ class TestExtractionProcessor:
# Mock extraction service # Mock extraction service
mock_service = Mock() mock_service = Mock()
mock_service.get_pending_extractions = AsyncMock( mock_service.get_pending_extractions = AsyncMock(
return_value=[{"id": 100, "status": "pending"}] return_value=[{"id": 100, "status": "pending"}],
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -222,15 +221,15 @@ class TestExtractionProcessor:
return_value=[ return_value=[
{"id": 100, "status": "pending"}, {"id": 100, "status": "pending"},
{"id": 101, "status": "pending"}, {"id": 101, "status": "pending"},
] ],
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch.object( patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock processor, "_process_single_extraction", new_callable=AsyncMock,
) as mock_process, ) as mock_process,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -267,15 +266,15 @@ class TestExtractionProcessor:
{"id": 100, "status": "pending"}, {"id": 100, "status": "pending"},
{"id": 101, "status": "pending"}, {"id": 101, "status": "pending"},
{"id": 102, "status": "pending"}, {"id": 102, "status": "pending"},
] ],
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch.object( patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock processor, "_process_single_extraction", new_callable=AsyncMock,
) as mock_process, ) as mock_process,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",

View File

@@ -4,14 +4,12 @@ import asyncio
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.models.user import User
from app.services.player import ( from app.services.player import (
PlayerMode, PlayerMode,
PlayerService, PlayerService,
@@ -21,7 +19,6 @@ from app.services.player import (
initialize_player_service, initialize_player_service,
shutdown_player_service, shutdown_player_service,
) )
from app.utils.audio import get_sound_file_path
class TestPlayerState: class TestPlayerState:
@@ -200,7 +197,7 @@ class TestPlayerService:
mock_file_path = Mock(spec=Path) mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = True mock_file_path.exists.return_value = True
mock_path.return_value = mock_file_path mock_path.return_value = mock_file_path
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock): with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock):
mock_media = Mock() mock_media = Mock()
player_service._vlc_instance.media_new.return_value = mock_media player_service._vlc_instance.media_new.return_value = mock_media
@@ -738,7 +735,7 @@ class TestPlayerService:
# Verify sound play count was updated # Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
mock_sound, {"play_count": 6} mock_sound, {"play_count": 6},
) )
# Verify SoundPlayed record was created with None user_id for player # Verify SoundPlayed record was created with None user_id for player
@@ -768,7 +765,7 @@ class TestPlayerService:
mock_file_path = Mock(spec=Path) mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = False # File doesn't exist mock_file_path.exists.return_value = False # File doesn't exist
mock_path.return_value = mock_file_path mock_path.return_value = mock_file_path
# This should fail because file doesn't exist # This should fail because file doesn't exist
result = asyncio.run(player_service.play(0)) result = asyncio.run(player_service.play(0))
# Verify the utility was called # Verify the utility was called
@@ -817,4 +814,4 @@ class TestPlayerServiceGlobalFunctions:
"""Test getting player service when not initialized.""" """Test getting player service when not initialized."""
with patch("app.services.player.player_service", None): with patch("app.services.player.player_service", None):
with pytest.raises(RuntimeError, match="Player service not initialized"): with pytest.raises(RuntimeError, match="Player service not initialized"):
get_player_service() get_player_service()

View File

@@ -153,7 +153,6 @@ class TestPlaylistService:
test_user: User, test_user: User,
) -> None: ) -> None:
"""Test getting non-existent playlist.""" """Test getting non-existent playlist."""
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_playlist_by_id(99999) await playlist_service.get_playlist_by_id(99999)
@@ -168,7 +167,6 @@ class TestPlaylistService:
test_session: AsyncSession, test_session: AsyncSession,
) -> None: ) -> None:
"""Test getting existing main playlist.""" """Test getting existing main playlist."""
# Create main playlist manually # Create main playlist manually
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
@@ -193,7 +191,6 @@ class TestPlaylistService:
test_user: User, test_user: User,
) -> None: ) -> None:
"""Test that service fails if main playlist doesn't exist.""" """Test that service fails if main playlist doesn't exist."""
# Should raise an HTTPException if no main playlist exists # Should raise an HTTPException if no main playlist exists
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_main_playlist() await playlist_service.get_main_playlist()
@@ -207,7 +204,6 @@ class TestPlaylistService:
test_user: User, test_user: User,
) -> None: ) -> None:
"""Test creating a new playlist successfully.""" """Test creating a new playlist successfully."""
user_id = test_user.id # Extract user_id while session is available user_id = test_user.id # Extract user_id while session is available
playlist = await playlist_service.create_playlist( playlist = await playlist_service.create_playlist(
user_id=user_id, user_id=user_id,
@@ -246,7 +242,7 @@ class TestPlaylistService:
test_session.add(playlist) test_session.add(playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
# Extract name before async call # Extract name before async call
playlist_name = playlist.name playlist_name = playlist.name
@@ -280,10 +276,10 @@ class TestPlaylistService:
test_session.add(current_playlist) test_session.add(current_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(current_playlist) await test_session.refresh(current_playlist)
# Verify the existing current playlist # Verify the existing current playlist
assert current_playlist.is_current is True assert current_playlist.is_current is True
# Extract ID before async call # Extract ID before async call
current_playlist_id = current_playlist.id current_playlist_id = current_playlist.id
@@ -323,10 +319,10 @@ class TestPlaylistService:
test_session.add(playlist) test_session.add(playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
# Extract IDs before async call # Extract IDs before async call
playlist_id = playlist.id playlist_id = playlist.id
updated_playlist = await playlist_service.update_playlist( updated_playlist = await playlist_service.update_playlist(
playlist_id=playlist_id, playlist_id=playlist_id,
user_id=user_id, user_id=user_id,
@@ -359,7 +355,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
current_playlist = Playlist( current_playlist = Playlist(
user_id=user_id, user_id=user_id,
name="Current Playlist", name="Current Playlist",
@@ -372,7 +368,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(current_playlist) await test_session.refresh(current_playlist)
# Extract IDs before async calls # Extract IDs before async calls
test_playlist_id = test_playlist.id test_playlist_id = test_playlist.id
current_playlist_id = current_playlist.id current_playlist_id = current_playlist.id
@@ -416,7 +412,7 @@ class TestPlaylistService:
test_session.add(playlist) test_session.add(playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
# Extract ID before async call # Extract ID before async call
playlist_id = playlist.id playlist_id = playlist.id
@@ -445,7 +441,7 @@ class TestPlaylistService:
is_deletable=False, is_deletable=False,
) )
test_session.add(main_playlist) test_session.add(main_playlist)
# Create current playlist within this test # Create current playlist within this test
user_id = test_user.id user_id = test_user.id
current_playlist = Playlist( current_playlist = Playlist(
@@ -459,7 +455,7 @@ class TestPlaylistService:
test_session.add(current_playlist) test_session.add(current_playlist)
await test_session.commit() await test_session.commit()
await test_session.refresh(current_playlist) await test_session.refresh(current_playlist)
# Extract ID before async call # Extract ID before async call
current_playlist_id = current_playlist.id current_playlist_id = current_playlist.id
@@ -481,7 +477,7 @@ class TestPlaylistService:
"""Test deleting non-deletable playlist fails.""" """Test deleting non-deletable playlist fails."""
# Extract user ID immediately # Extract user ID immediately
user_id = test_user.id user_id = test_user.id
# Create non-deletable playlist # Create non-deletable playlist
non_deletable = Playlist( non_deletable = Playlist(
user_id=user_id, user_id=user_id,
@@ -491,7 +487,7 @@ class TestPlaylistService:
test_session.add(non_deletable) test_session.add(non_deletable)
await test_session.commit() await test_session.commit()
await test_session.refresh(non_deletable) await test_session.refresh(non_deletable)
# Extract ID before async call # Extract ID before async call
non_deletable_id = non_deletable.id non_deletable_id = non_deletable.id
@@ -521,7 +517,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -535,7 +531,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -571,7 +567,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -585,7 +581,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -628,7 +624,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -642,7 +638,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -685,7 +681,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -699,7 +695,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -734,7 +730,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(test_playlist) test_session.add(test_playlist)
current_playlist = Playlist( current_playlist = Playlist(
user_id=user_id, user_id=user_id,
name="Current Playlist", name="Current Playlist",
@@ -747,7 +743,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(test_playlist) await test_session.refresh(test_playlist)
await test_session.refresh(current_playlist) await test_session.refresh(current_playlist)
# Extract IDs before async calls # Extract IDs before async calls
test_playlist_id = test_playlist.id test_playlist_id = test_playlist.id
current_playlist_id = current_playlist.id current_playlist_id = current_playlist.id
@@ -758,7 +754,7 @@ class TestPlaylistService:
# Set test_playlist as current # Set test_playlist as current
updated_playlist = await playlist_service.set_current_playlist( updated_playlist = await playlist_service.set_current_playlist(
test_playlist_id, user_id test_playlist_id, user_id,
) )
assert updated_playlist.is_current is True assert updated_playlist.is_current is True
@@ -786,7 +782,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(current_playlist) test_session.add(current_playlist)
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
name="Main Playlist", name="Main Playlist",
@@ -799,7 +795,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(current_playlist) await test_session.refresh(current_playlist)
await test_session.refresh(main_playlist) await test_session.refresh(main_playlist)
# Extract IDs before async calls # Extract IDs before async calls
current_playlist_id = current_playlist.id current_playlist_id = current_playlist.id
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id
@@ -839,7 +835,7 @@ class TestPlaylistService:
is_deletable=True, is_deletable=True,
) )
test_session.add(playlist) test_session.add(playlist)
sound = Sound( sound = Sound(
name="Test Sound", name="Test Sound",
filename="test.mp3", filename="test.mp3",
@@ -853,7 +849,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
await test_session.refresh(sound) await test_session.refresh(sound)
# Extract IDs before async calls # Extract IDs before async calls
playlist_id = playlist.id playlist_id = playlist.id
sound_id = sound.id sound_id = sound.id
@@ -897,7 +893,7 @@ class TestPlaylistService:
play_count=10, play_count=10,
) )
test_session.add(sound) test_session.add(sound)
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
name="Main Playlist", name="Main Playlist",
@@ -910,7 +906,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(sound) await test_session.refresh(sound)
await test_session.refresh(main_playlist) await test_session.refresh(main_playlist)
# Extract IDs before async calls # Extract IDs before async calls
sound_id = sound.id sound_id = sound.id
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id
@@ -943,7 +939,7 @@ class TestPlaylistService:
play_count=10, play_count=10,
) )
test_session.add(sound) test_session.add(sound)
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
name="Main Playlist", name="Main Playlist",
@@ -956,7 +952,7 @@ class TestPlaylistService:
await test_session.commit() await test_session.commit()
await test_session.refresh(sound) await test_session.refresh(sound)
await test_session.refresh(main_playlist) await test_session.refresh(main_playlist)
# Extract IDs before async calls # Extract IDs before async calls
sound_id = sound.id sound_id = sound.id
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id

View File

@@ -98,7 +98,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_success( async def test_connect_handler_success(
self, mock_decode, mock_extract_token, socket_manager, mock_sio self, mock_decode, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test successful connection with valid token.""" """Test successful connection with valid token."""
# Setup mocks # Setup mocks
@@ -133,7 +133,7 @@ class TestSocketManager:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
async def test_connect_handler_no_token( async def test_connect_handler_no_token(
self, mock_extract_token, socket_manager, mock_sio self, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test connection with no access token.""" """Test connection with no access token."""
# Setup mocks # Setup mocks
@@ -167,7 +167,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_invalid_token( async def test_connect_handler_invalid_token(
self, mock_decode, mock_extract_token, socket_manager, mock_sio self, mock_decode, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test connection with invalid token.""" """Test connection with invalid token."""
# Setup mocks # Setup mocks
@@ -202,7 +202,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_missing_user_id( async def test_connect_handler_missing_user_id(
self, mock_decode, mock_extract_token, socket_manager, mock_sio self, mock_decode, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test connection with token missing user ID.""" """Test connection with token missing user ID."""
# Setup mocks # Setup mocks

View File

@@ -55,7 +55,7 @@ class TestSoundNormalizerService:
normalized_path = normalizer_service._get_normalized_path(sound) normalized_path = normalizer_service._get_normalized_path(sound)
assert "sounds/normalized/soundboard" in str(normalized_path) assert "sounds/normalized/soundboard" in str(normalized_path)
assert "test_audio.mp3" == normalized_path.name assert normalized_path.name == "test_audio.mp3"
def test_get_original_path(self, normalizer_service): def test_get_original_path(self, normalizer_service):
"""Test original path generation.""" """Test original path generation."""
@@ -72,7 +72,7 @@ class TestSoundNormalizerService:
original_path = normalizer_service._get_original_path(sound) original_path = normalizer_service._get_original_path(sound)
assert "sounds/originals/soundboard" in str(original_path) assert "sounds/originals/soundboard" in str(original_path)
assert "test_audio.wav" == original_path.name assert original_path.name == "test_audio.wav"
def test_get_file_hash(self, normalizer_service): def test_get_file_hash(self, normalizer_service):
"""Test file hash calculation.""" """Test file hash calculation."""
@@ -172,14 +172,14 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path, patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path, patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object( patch.object(
normalizer_service, "_normalize_audio_two_pass" normalizer_service, "_normalize_audio_two_pass",
) as mock_normalize, ) as mock_normalize,
patch( patch(
"app.services.sound_normalizer.get_audio_duration", return_value=6000 "app.services.sound_normalizer.get_audio_duration", return_value=6000,
), ),
patch("app.services.sound_normalizer.get_file_size", return_value=2048), patch("app.services.sound_normalizer.get_file_size", return_value=2048),
patch( patch(
"app.services.sound_normalizer.get_file_hash", return_value="new_hash" "app.services.sound_normalizer.get_file_hash", return_value="new_hash",
), ),
): ):
# Setup path mocks # Setup path mocks
@@ -245,14 +245,14 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path, patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path, patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object( patch.object(
normalizer_service, "_normalize_audio_one_pass" normalizer_service, "_normalize_audio_one_pass",
) as mock_normalize, ) as mock_normalize,
patch( patch(
"app.services.sound_normalizer.get_audio_duration", return_value=5500 "app.services.sound_normalizer.get_audio_duration", return_value=5500,
), ),
patch("app.services.sound_normalizer.get_file_size", return_value=1500), patch("app.services.sound_normalizer.get_file_size", return_value=1500),
patch( patch(
"app.services.sound_normalizer.get_file_hash", return_value="norm_hash" "app.services.sound_normalizer.get_file_hash", return_value="norm_hash",
), ),
): ):
# Setup path mocks # Setup path mocks
@@ -300,7 +300,7 @@ class TestSoundNormalizerService:
with ( with (
patch("pathlib.Path.exists", return_value=True), patch("pathlib.Path.exists", return_value=True),
patch.object( patch.object(
normalizer_service, "_normalize_audio_two_pass" normalizer_service, "_normalize_audio_two_pass",
) as mock_normalize, ) as mock_normalize,
): ):
mock_normalize.side_effect = Exception("Normalization failed") mock_normalize.side_effect = Exception("Normalization failed")
@@ -339,7 +339,7 @@ class TestSoundNormalizerService:
# Mock repository calls # Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock( normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock(
return_value=sounds return_value=sounds,
) )
# Mock individual normalization # Mock individual normalization
@@ -399,7 +399,7 @@ class TestSoundNormalizerService:
# Mock repository calls # Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds_by_type = AsyncMock( normalizer_service.sound_repo.get_unnormalized_sounds_by_type = AsyncMock(
return_value=sdb_sounds return_value=sdb_sounds,
) )
# Mock individual normalization # Mock individual normalization
@@ -428,7 +428,7 @@ class TestSoundNormalizerService:
# Verify correct repository method was called # Verify correct repository method was called
normalizer_service.sound_repo.get_unnormalized_sounds_by_type.assert_called_once_with( normalizer_service.sound_repo.get_unnormalized_sounds_by_type.assert_called_once_with(
"SDB" "SDB",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -459,7 +459,7 @@ class TestSoundNormalizerService:
# Mock repository calls # Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock( normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock(
return_value=sounds return_value=sounds,
) )
# Mock individual normalization with one success and one error # Mock individual normalization with one success and one error
@@ -529,7 +529,7 @@ class TestSoundNormalizerService:
# Verify ffmpeg chain was called correctly # Verify ffmpeg chain was called correctly
mock_ffmpeg.input.assert_called_once_with(str(input_path)) mock_ffmpeg.input.assert_called_once_with(str(input_path))
mock_ffmpeg.filter.assert_called_once_with( mock_ffmpeg.filter.assert_called_once_with(
mock_stream, "loudnorm", I=-23, TP=-2, LRA=7 mock_stream, "loudnorm", I=-23, TP=-2, LRA=7,
) )
mock_ffmpeg.output.assert_called_once() mock_ffmpeg.output.assert_called_once()
mock_ffmpeg.run.assert_called_once() mock_ffmpeg.run.assert_called_once()

View File

@@ -153,7 +153,7 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results temp_path, "SDB", existing_sound, results,
) )
assert results["skipped"] == 1 assert results["skipped"] == 1
@@ -257,7 +257,7 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results temp_path, "SDB", existing_sound, results,
) )
assert results["updated"] == 1 assert results["updated"] == 1
@@ -296,7 +296,7 @@ class TestSoundScannerService:
# Mock file operations # Mock file operations
with ( with (
patch( patch(
"app.services.sound_scanner.get_file_hash", return_value="custom_hash" "app.services.sound_scanner.get_file_hash", return_value="custom_hash",
), ),
patch("app.services.sound_scanner.get_audio_duration", return_value=60000), patch("app.services.sound_scanner.get_audio_duration", return_value=60000),
patch("app.services.sound_scanner.get_file_size", return_value=2048), patch("app.services.sound_scanner.get_file_size", return_value=2048),
@@ -316,7 +316,7 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "CUSTOM", None, results temp_path, "CUSTOM", None, results,
) )
assert results["added"] == 1 assert results["added"] == 1

View File

@@ -7,10 +7,8 @@ from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from app.models.sound import Sound from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.models.user import User from app.models.user import User
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
from app.utils.audio import get_sound_file_path
class TestVLCPlayerService: class TestVLCPlayerService:
@@ -79,16 +77,16 @@ class TestVLCPlayerService:
def test_find_vlc_executable_found_by_path(self, mock_run): def test_find_vlc_executable_found_by_path(self, mock_run):
"""Test VLC executable detection when found by absolute path.""" """Test VLC executable detection when found by absolute path."""
mock_run.return_value.returncode = 1 # which command fails mock_run.return_value.returncode = 1 # which command fails
# Mock Path to return True for the first absolute path # Mock Path to return True for the first absolute path
with patch("app.services.vlc_player.Path") as mock_path: with patch("app.services.vlc_player.Path") as mock_path:
def path_side_effect(path_str): def path_side_effect(path_str):
mock_instance = Mock() mock_instance = Mock()
mock_instance.exists.return_value = str(path_str) == "/usr/bin/vlc" mock_instance.exists.return_value = str(path_str) == "/usr/bin/vlc"
return mock_instance return mock_instance
mock_path.side_effect = path_side_effect mock_path.side_effect = path_side_effect
service = VLCPlayerService() service = VLCPlayerService()
assert service.vlc_executable == "/usr/bin/vlc" assert service.vlc_executable == "/usr/bin/vlc"
@@ -100,10 +98,10 @@ class TestVLCPlayerService:
mock_path_instance = Mock() mock_path_instance = Mock()
mock_path_instance.exists.return_value = False mock_path_instance.exists.return_value = False
mock_path.return_value = mock_path_instance mock_path.return_value = mock_path_instance
# Mock which command as failing # Mock which command as failing
mock_run.return_value.returncode = 1 mock_run.return_value.returncode = 1
service = VLCPlayerService() service = VLCPlayerService()
assert service.vlc_executable == "vlc" assert service.vlc_executable == "vlc"
@@ -111,26 +109,26 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_success( async def test_play_sound_success(
self, mock_subprocess, vlc_service, sample_sound self, mock_subprocess, vlc_service, sample_sound,
): ):
"""Test successful sound playback.""" """Test successful sound playback."""
# Mock subprocess # Mock subprocess
mock_process = Mock() mock_process = Mock()
mock_process.pid = 12345 mock_process.pid = 12345
mock_subprocess.return_value = mock_process mock_subprocess.return_value = mock_process
# Mock the file path utility to avoid Path issues # Mock the file path utility to avoid Path issues
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path: with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock() mock_path = Mock()
mock_path.exists.return_value = True mock_path.exists.return_value = True
mock_get_path.return_value = mock_path mock_get_path.return_value = mock_path
result = await vlc_service.play_sound(sample_sound) result = await vlc_service.play_sound(sample_sound)
assert result is True assert result is True
mock_subprocess.assert_called_once() mock_subprocess.assert_called_once()
args = mock_subprocess.call_args args = mock_subprocess.call_args
# Check command arguments # Check command arguments
cmd_args = args[1] # keyword arguments cmd_args = args[1] # keyword arguments
assert "--play-and-exit" in args[0] assert "--play-and-exit" in args[0]
@@ -144,7 +142,7 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_file_not_found( async def test_play_sound_file_not_found(
self, vlc_service, sample_sound self, vlc_service, sample_sound,
): ):
"""Test sound playback when file doesn't exist.""" """Test sound playback when file doesn't exist."""
# Mock the file path utility to return a non-existent path # Mock the file path utility to return a non-existent path
@@ -152,15 +150,15 @@ class TestVLCPlayerService:
mock_path = Mock() mock_path = Mock()
mock_path.exists.return_value = False mock_path.exists.return_value = False
mock_get_path.return_value = mock_path mock_get_path.return_value = mock_path
result = await vlc_service.play_sound(sample_sound) result = await vlc_service.play_sound(sample_sound)
assert result is False assert result is False
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_subprocess_error( async def test_play_sound_subprocess_error(
self, mock_subprocess, vlc_service, sample_sound self, mock_subprocess, vlc_service, sample_sound,
): ):
"""Test sound playback when subprocess fails.""" """Test sound playback when subprocess fails."""
# Mock the file path utility to return an existing path # Mock the file path utility to return an existing path
@@ -168,12 +166,12 @@ class TestVLCPlayerService:
mock_path = Mock() mock_path = Mock()
mock_path.exists.return_value = True mock_path.exists.return_value = True
mock_get_path.return_value = mock_path mock_get_path.return_value = mock_path
# Mock subprocess exception # Mock subprocess exception
mock_subprocess.side_effect = Exception("Subprocess failed") mock_subprocess.side_effect = Exception("Subprocess failed")
result = await vlc_service.play_sound(sample_sound) result = await vlc_service.play_sound(sample_sound)
assert result is False assert result is False
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -184,27 +182,27 @@ class TestVLCPlayerService:
mock_find_process = Mock() mock_find_process = Mock()
mock_find_process.returncode = 0 mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock( mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n", b"") return_value=(b"12345\n67890\n", b""),
) )
# Mock pkill process (kill VLC processes) # Mock pkill process (kill VLC processes)
mock_kill_process = Mock() mock_kill_process = Mock()
mock_kill_process.communicate = AsyncMock(return_value=(b"", b"")) mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
# Mock verify process (check remaining processes) # Mock verify process (check remaining processes)
mock_verify_process = Mock() mock_verify_process = Mock()
mock_verify_process.returncode = 1 # No processes found mock_verify_process.returncode = 1 # No processes found
mock_verify_process.communicate = AsyncMock(return_value=(b"", b"")) mock_verify_process.communicate = AsyncMock(return_value=(b"", b""))
# Set up subprocess mock to return different processes for each call # Set up subprocess mock to return different processes for each call
mock_subprocess.side_effect = [ mock_subprocess.side_effect = [
mock_find_process, mock_find_process,
mock_kill_process, mock_kill_process,
mock_verify_process, mock_verify_process,
] ]
result = await vlc_service.stop_all_vlc_instances() result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True assert result["success"] is True
assert result["processes_found"] == 2 assert result["processes_found"] == 2
assert result["processes_killed"] == 2 assert result["processes_killed"] == 2
@@ -214,18 +212,18 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_no_processes( async def test_stop_all_vlc_instances_no_processes(
self, mock_subprocess, vlc_service self, mock_subprocess, vlc_service,
): ):
"""Test stopping VLC instances when none are running.""" """Test stopping VLC instances when none are running."""
# Mock pgrep process (no VLC processes found) # Mock pgrep process (no VLC processes found)
mock_find_process = Mock() mock_find_process = Mock()
mock_find_process.returncode = 1 # No processes found mock_find_process.returncode = 1 # No processes found
mock_find_process.communicate = AsyncMock(return_value=(b"", b"")) mock_find_process.communicate = AsyncMock(return_value=(b"", b""))
mock_subprocess.return_value = mock_find_process mock_subprocess.return_value = mock_find_process
result = await vlc_service.stop_all_vlc_instances() result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True assert result["success"] is True
assert result["processes_found"] == 0 assert result["processes_found"] == 0
assert result["processes_killed"] == 0 assert result["processes_killed"] == 0
@@ -234,33 +232,33 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_partial_kill( async def test_stop_all_vlc_instances_partial_kill(
self, mock_subprocess, vlc_service self, mock_subprocess, vlc_service,
): ):
"""Test stopping VLC instances when some processes remain.""" """Test stopping VLC instances when some processes remain."""
# Mock pgrep process (find VLC processes) # Mock pgrep process (find VLC processes)
mock_find_process = Mock() mock_find_process = Mock()
mock_find_process.returncode = 0 mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock( mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n11111\n", b"") return_value=(b"12345\n67890\n11111\n", b""),
) )
# Mock pkill process (kill VLC processes) # Mock pkill process (kill VLC processes)
mock_kill_process = Mock() mock_kill_process = Mock()
mock_kill_process.communicate = AsyncMock(return_value=(b"", b"")) mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
# Mock verify process (one process remains) # Mock verify process (one process remains)
mock_verify_process = Mock() mock_verify_process = Mock()
mock_verify_process.returncode = 0 mock_verify_process.returncode = 0
mock_verify_process.communicate = AsyncMock(return_value=(b"11111\n", b"")) mock_verify_process.communicate = AsyncMock(return_value=(b"11111\n", b""))
mock_subprocess.side_effect = [ mock_subprocess.side_effect = [
mock_find_process, mock_find_process,
mock_kill_process, mock_kill_process,
mock_verify_process, mock_verify_process,
] ]
result = await vlc_service.stop_all_vlc_instances() result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True assert result["success"] is True
assert result["processes_found"] == 3 assert result["processes_found"] == 3
assert result["processes_killed"] == 2 assert result["processes_killed"] == 2
@@ -272,9 +270,9 @@ class TestVLCPlayerService:
"""Test stopping VLC instances when an error occurs.""" """Test stopping VLC instances when an error occurs."""
# Mock subprocess exception # Mock subprocess exception
mock_subprocess.side_effect = Exception("Command failed") mock_subprocess.side_effect = Exception("Command failed")
result = await vlc_service.stop_all_vlc_instances() result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is False assert result["success"] is False
assert result["processes_found"] == 0 assert result["processes_found"] == 0
assert result["processes_killed"] == 0 assert result["processes_killed"] == 0
@@ -286,16 +284,16 @@ class TestVLCPlayerService:
with patch("app.services.vlc_player.VLCPlayerService") as mock_service_class: with patch("app.services.vlc_player.VLCPlayerService") as mock_service_class:
mock_instance = Mock() mock_instance = Mock()
mock_service_class.return_value = mock_instance mock_service_class.return_value = mock_instance
# Clear the global instance # Clear the global instance
import app.services.vlc_player import app.services.vlc_player
app.services.vlc_player.vlc_player_service = None app.services.vlc_player.vlc_player_service = None
# First call should create new instance # First call should create new instance
service1 = get_vlc_player_service() service1 = get_vlc_player_service()
assert service1 == mock_instance assert service1 == mock_instance
mock_service_class.assert_called_once() mock_service_class.assert_called_once()
# Second call should return same instance # Second call should return same instance
service2 = get_vlc_player_service() service2 = get_vlc_player_service()
assert service2 == mock_instance assert service2 == mock_instance
@@ -306,22 +304,22 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_with_play_count_tracking( async def test_play_sound_with_play_count_tracking(
self, mock_subprocess, vlc_service_with_db, sample_sound self, mock_subprocess, vlc_service_with_db, sample_sound,
): ):
"""Test sound playback with play count tracking.""" """Test sound playback with play count tracking."""
# Mock subprocess # Mock subprocess
mock_process = Mock() mock_process = Mock()
mock_process.pid = 12345 mock_process.pid = 12345
mock_subprocess.return_value = mock_process mock_subprocess.return_value = mock_process
# Mock session and repositories # Mock session and repositories
mock_session = AsyncMock() mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session vlc_service_with_db.db_session_factory.return_value = mock_session
# Mock repositories # Mock repositories
mock_sound_repo = AsyncMock() mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock() mock_user_repo = AsyncMock()
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo): with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo): with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket: with patch("app.services.vlc_player.socket_manager") as mock_socket:
@@ -331,7 +329,7 @@ class TestVLCPlayerService:
mock_path = Mock() mock_path = Mock()
mock_path.exists.return_value = True mock_path.exists.return_value = True
mock_get_path.return_value = mock_path mock_get_path.return_value = mock_path
# Mock sound repository responses # Mock sound repository responses
updated_sound = Sound( updated_sound = Sound(
id=1, id=1,
@@ -345,7 +343,7 @@ class TestVLCPlayerService:
) )
mock_sound_repo.get_by_id.return_value = sample_sound mock_sound_repo.get_by_id.return_value = sample_sound
mock_sound_repo.update.return_value = updated_sound mock_sound_repo.update.return_value = updated_sound
# Mock admin user # Mock admin user
admin_user = User( admin_user = User(
id=1, id=1,
@@ -354,20 +352,20 @@ class TestVLCPlayerService:
role="admin", role="admin",
) )
mock_user_repo.get_by_id.return_value = admin_user mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast # Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock() mock_socket.broadcast_to_all = AsyncMock()
result = await vlc_service_with_db.play_sound(sample_sound) result = await vlc_service_with_db.play_sound(sample_sound)
# Wait a bit for the async task to complete # Wait a bit for the async task to complete
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
assert result is True assert result is True
# Verify subprocess was called # Verify subprocess was called
mock_subprocess.assert_called_once() mock_subprocess.assert_called_once()
# Note: The async task runs in the background, so we can't easily # Note: The async task runs in the background, so we can't easily
# verify the database operations in this test without more complex # verify the database operations in this test without more complex
# mocking or using a real async test framework setup # mocking or using a real async test framework setup
@@ -378,10 +376,10 @@ class TestVLCPlayerService:
# Mock session and repositories # Mock session and repositories
mock_session = AsyncMock() mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session vlc_service_with_db.db_session_factory.return_value = mock_session
mock_sound_repo = AsyncMock() mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock() mock_user_repo = AsyncMock()
# Create test sound and user # Create test sound and user
test_sound = Sound( test_sound = Sound(
id=1, id=1,
@@ -399,7 +397,7 @@ class TestVLCPlayerService:
name="Admin User", name="Admin User",
role="admin", role="admin",
) )
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo): with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo): with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket: with patch("app.services.vlc_player.socket_manager") as mock_socket:
@@ -407,26 +405,26 @@ class TestVLCPlayerService:
# Setup mocks # Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast # Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock() mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound") await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound repository calls # Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1) mock_sound_repo.get_by_id.assert_called_once_with(1)
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 1} test_sound, {"play_count": 1},
) )
# Verify user repository calls # Verify user repository calls
mock_user_repo.get_by_id.assert_called_once_with(1) mock_user_repo.get_by_id.assert_called_once_with(1)
# Verify session operations # Verify session operations
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
# Verify socket broadcast # Verify socket broadcast
mock_socket.broadcast_to_all.assert_called_once_with( mock_socket.broadcast_to_all.assert_called_once_with(
"sound_played", "sound_played",
@@ -451,10 +449,10 @@ class TestVLCPlayerService:
# Mock session and repositories # Mock session and repositories
mock_session = AsyncMock() mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session vlc_service_with_db.db_session_factory.return_value = mock_session
mock_sound_repo = AsyncMock() mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock() mock_user_repo = AsyncMock()
# Create test sound and user # Create test sound and user
test_sound = Sound( test_sound = Sound(
id=1, id=1,
@@ -472,27 +470,27 @@ class TestVLCPlayerService:
name="Admin User", name="Admin User",
role="admin", role="admin",
) )
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo): with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo): with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket: with patch("app.services.vlc_player.socket_manager") as mock_socket:
# Setup mocks # Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast # Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock() mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound") await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound play count was updated # Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 6} test_sound, {"play_count": 6},
) )
# Verify new SoundPlayed record was always added # Verify new SoundPlayed record was always added
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
# Verify commit happened # Verify commit happened
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
@@ -502,10 +500,10 @@ class TestVLCPlayerService:
mock_file_path = Mock(spec=Path) mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = False # File doesn't exist mock_file_path.exists.return_value = False # File doesn't exist
mock_path.return_value = mock_file_path mock_path.return_value = mock_file_path
# This should fail because file doesn't exist # This should fail because file doesn't exist
result = asyncio.run(vlc_service.play_sound(sample_sound)) result = asyncio.run(vlc_service.play_sound(sample_sound))
# Verify the utility was called and returned False # Verify the utility was called and returned False
mock_path.assert_called_once_with(sample_sound) mock_path.assert_called_once_with(sample_sound)
assert result is False assert result is False

View File

@@ -8,7 +8,12 @@ from unittest.mock import patch
import pytest import pytest
from app.models.sound import Sound from app.models.sound import Sound
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size, get_sound_file_path from app.utils.audio import (
get_audio_duration,
get_file_hash,
get_file_size,
get_sound_file_path,
)
class TestAudioUtils: class TestAudioUtils:
@@ -301,7 +306,7 @@ class TestAudioUtils:
type="SDB", type="SDB",
is_normalized=False, is_normalized=False,
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
expected = Path("sounds/originals/soundboard/test.mp3") expected = Path("sounds/originals/soundboard/test.mp3")
assert result == expected assert result == expected
@@ -310,13 +315,13 @@ class TestAudioUtils:
"""Test getting sound file path for SDB type normalized file.""" """Test getting sound file path for SDB type normalized file."""
sound = Sound( sound = Sound(
id=1, id=1,
name="Test Sound", name="Test Sound",
filename="original.mp3", filename="original.mp3",
normalized_filename="normalized.mp3", normalized_filename="normalized.mp3",
type="SDB", type="SDB",
is_normalized=True, is_normalized=True,
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
expected = Path("sounds/normalized/soundboard/normalized.mp3") expected = Path("sounds/normalized/soundboard/normalized.mp3")
assert result == expected assert result == expected
@@ -326,11 +331,11 @@ class TestAudioUtils:
sound = Sound( sound = Sound(
id=2, id=2,
name="TTS Sound", name="TTS Sound",
filename="tts_file.wav", filename="tts_file.wav",
type="TTS", type="TTS",
is_normalized=False, is_normalized=False,
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
expected = Path("sounds/originals/text_to_speech/tts_file.wav") expected = Path("sounds/originals/text_to_speech/tts_file.wav")
assert result == expected assert result == expected
@@ -342,10 +347,10 @@ class TestAudioUtils:
name="TTS Sound", name="TTS Sound",
filename="original.wav", filename="original.wav",
normalized_filename="normalized.mp3", normalized_filename="normalized.mp3",
type="TTS", type="TTS",
is_normalized=True, is_normalized=True,
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
expected = Path("sounds/normalized/text_to_speech/normalized.mp3") expected = Path("sounds/normalized/text_to_speech/normalized.mp3")
assert result == expected assert result == expected
@@ -359,7 +364,7 @@ class TestAudioUtils:
type="EXT", type="EXT",
is_normalized=False, is_normalized=False,
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
expected = Path("sounds/originals/extracted/extracted.mp3") expected = Path("sounds/originals/extracted/extracted.mp3")
assert result == expected assert result == expected
@@ -370,11 +375,11 @@ class TestAudioUtils:
id=3, id=3,
name="Extracted Sound", name="Extracted Sound",
filename="original.mp3", filename="original.mp3",
normalized_filename="normalized.mp3", normalized_filename="normalized.mp3",
type="EXT", type="EXT",
is_normalized=True, is_normalized=True,
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
expected = Path("sounds/normalized/extracted/normalized.mp3") expected = Path("sounds/normalized/extracted/normalized.mp3")
assert result == expected assert result == expected
@@ -388,7 +393,7 @@ class TestAudioUtils:
type="CUSTOM", type="CUSTOM",
is_normalized=False, is_normalized=False,
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
expected = Path("sounds/originals/custom/unknown.mp3") expected = Path("sounds/originals/custom/unknown.mp3")
assert result == expected assert result == expected
@@ -403,7 +408,7 @@ class TestAudioUtils:
type="SDB", type="SDB",
is_normalized=True, # True but no normalized_filename is_normalized=True, # True but no normalized_filename
) )
result = get_sound_file_path(sound) result = get_sound_file_path(sound)
# Should fall back to original file # Should fall back to original file
expected = Path("sounds/originals/soundboard/original.mp3") expected = Path("sounds/originals/soundboard/original.mp3")

View File

@@ -1,12 +1,16 @@
"""Tests for credit decorators.""" """Tests for credit decorators."""
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock
import pytest import pytest
from app.models.credit_action import CreditActionType from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError from app.services.credit import CreditService, InsufficientCreditsError
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only from app.utils.credit_decorators import (
CreditManager,
requires_credits,
validate_credits_only,
)
class TestRequiresCreditsDecorator: class TestRequiresCreditsDecorator:
@@ -32,7 +36,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int, message: str) -> str: async def test_action(user_id: int, message: str) -> str:
return f"Success: {message}" return f"Success: {message}"
@@ -41,10 +45,10 @@ class TestRequiresCreditsDecorator:
assert result == "Success: test" assert result == "Success: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None 123, CreditActionType.VLC_PLAY_SOUND, None,
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, None 123, CreditActionType.VLC_PLAY_SOUND, True, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -58,7 +62,7 @@ class TestRequiresCreditsDecorator:
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id", user_id_param="user_id",
metadata_extractor=extract_metadata metadata_extractor=extract_metadata,
) )
async def test_action(user_id: int, sound_name: str) -> bool: async def test_action(user_id: int, sound_name: str) -> bool:
return True return True
@@ -66,10 +70,10 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123, sound_name="test.mp3") await test_action(user_id=123, sound_name="test.mp3")
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"} 123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"},
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"} 123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -79,7 +83,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int) -> bool: async def test_action(user_id: int) -> bool:
return False # Action fails return False # Action fails
@@ -88,7 +92,7 @@ class TestRequiresCreditsDecorator:
assert result is False assert result is False
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -98,7 +102,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int) -> str: async def test_action(user_id: int) -> str:
raise ValueError("Test error") raise ValueError("Test error")
@@ -107,7 +111,7 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123) await test_action(user_id=123)
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -118,7 +122,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int) -> str: async def test_action(user_id: int) -> str:
return "Should not execute" return "Should not execute"
@@ -136,7 +140,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int, message: str) -> str: async def test_action(user_id: int, message: str) -> str:
return message return message
@@ -145,7 +149,7 @@ class TestRequiresCreditsDecorator:
assert result == "test" assert result == "test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None 123, CreditActionType.VLC_PLAY_SOUND, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -155,7 +159,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(other_param: str) -> str: async def test_action(other_param: str) -> str:
return other_param return other_param
@@ -186,7 +190,7 @@ class TestValidateCreditsOnlyDecorator:
@validate_credits_only( @validate_credits_only(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int, message: str) -> str: async def test_action(user_id: int, message: str) -> str:
return f"Validated: {message}" return f"Validated: {message}"
@@ -195,7 +199,7 @@ class TestValidateCreditsOnlyDecorator:
assert result == "Validated: test" assert result == "Validated: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND 123, CreditActionType.VLC_PLAY_SOUND,
) )
# Should not deduct credits, only validate # Should not deduct credits, only validate
mock_credit_service.deduct_credits.assert_not_called() mock_credit_service.deduct_credits.assert_not_called()
@@ -219,15 +223,15 @@ class TestCreditManager:
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
{"test": "data"} {"test": "data"},
) as manager: ) as manager:
manager.mark_success() manager.mark_success()
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"} 123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"},
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"} 123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -236,13 +240,13 @@ class TestCreditManager:
async with CreditManager( async with CreditManager(
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND CreditActionType.VLC_PLAY_SOUND,
): ):
# Don't mark as success - should be considered failed # Don't mark as success - should be considered failed
pass pass
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -252,12 +256,12 @@ class TestCreditManager:
async with CreditManager( async with CreditManager(
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND CreditActionType.VLC_PLAY_SOUND,
): ):
raise ValueError("Test error") raise ValueError("Test error")
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -269,9 +273,9 @@ class TestCreditManager:
async with CreditManager( async with CreditManager(
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND CreditActionType.VLC_PLAY_SOUND,
): ):
pass pass
# Should not call deduct_credits since validation failed # Should not call deduct_credits since validation failed
mock_credit_service.deduct_credits.assert_not_called() mock_credit_service.deduct_credits.assert_not_called()