Refactor test files for improved readability and consistency

- Removed unnecessary blank lines and adjusted formatting in test files.
- Ensured consistent use of commas in function calls and assertions across various test cases.
- Updated import statements for better organization and clarity.
- Enhanced mock setups in tests for better isolation and reliability.
- Improved assertions to follow a consistent style for better readability.
This commit is contained in:
JSC
2025-07-31 21:37:04 +02:00
parent e69098d633
commit 8847131f24
42 changed files with 602 additions and 616 deletions

View File

@@ -1,6 +1,6 @@
"""Player API endpoints.""" """Player API endpoints."""
from typing import Annotated, Any from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status

View File

@@ -10,12 +10,12 @@ from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType from app.models.credit_action import CreditActionType
from app.models.user import User from app.models.user import User
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.credit import CreditService, InsufficientCreditsError from app.services.credit import CreditService, InsufficientCreditsError
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.extraction_processor import extraction_processor from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService from app.services.sound_scanner import ScanResults, SoundScannerService
from app.services.vlc_player import get_vlc_player_service, VLCPlayerService from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
router = APIRouter(prefix="/sounds", tags=["sounds"]) router = APIRouter(prefix="/sounds", tags=["sounds"])
@@ -125,13 +125,13 @@ async def scan_custom_directory(
async def normalize_all_sounds( async def normalize_all_sounds(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service) SoundNormalizerService, Depends(get_sound_normalizer_service),
], ],
force: bool = Query( force: bool = Query(
False, description="Force normalization of already normalized sounds" False, description="Force normalization of already normalized sounds",
), ),
one_pass: bool | None = Query( one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)" None, description="Use one-pass normalization (overrides config)",
), ),
) -> dict[str, NormalizationResults | str]: ) -> dict[str, NormalizationResults | str]:
"""Normalize all unnormalized sounds.""" """Normalize all unnormalized sounds."""
@@ -163,13 +163,13 @@ async def normalize_sounds_by_type(
sound_type: str, sound_type: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service) SoundNormalizerService, Depends(get_sound_normalizer_service),
], ],
force: bool = Query( force: bool = Query(
False, description="Force normalization of already normalized sounds" False, description="Force normalization of already normalized sounds",
), ),
one_pass: bool | None = Query( one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)" None, description="Use one-pass normalization (overrides config)",
), ),
) -> dict[str, NormalizationResults | str]: ) -> dict[str, NormalizationResults | str]:
"""Normalize all sounds of a specific type (SDB, TTS, EXT).""" """Normalize all sounds of a specific type (SDB, TTS, EXT)."""
@@ -210,13 +210,13 @@ async def normalize_sound_by_id(
sound_id: int, sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service) SoundNormalizerService, Depends(get_sound_normalizer_service),
], ],
force: bool = Query( force: bool = Query(
False, description="Force normalization of already normalized sound" False, description="Force normalization of already normalized sound",
), ),
one_pass: bool | None = Query( one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)" None, description="Use one-pass normalization (overrides config)",
), ),
) -> dict[str, str]: ) -> dict[str, str]:
"""Normalize a specific sound by ID.""" """Normalize a specific sound by ID."""
@@ -283,7 +283,7 @@ async def create_extraction(
) )
extraction_info = await extraction_service.create_extraction( extraction_info = await extraction_service.create_extraction(
url, current_user.id url, current_user.id,
) )
# Queue the extraction for background processing # Queue the extraction for background processing
@@ -398,7 +398,7 @@ async def play_sound_with_vlc(
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
current_user.id, current_user.id,
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
{"sound_id": sound_id, "sound_name": sound.name} {"sound_id": sound_id, "sound_name": sound.name},
) )
except InsufficientCreditsError as e: except InsufficientCreditsError as e:
raise HTTPException( raise HTTPException(

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint from sqlmodel import Field, Relationship
from app.models.base import BaseModel from app.models.base import BaseModel

View File

@@ -30,10 +30,10 @@ class PlayerStateResponse(BaseModel):
status: str = Field(description="Player status (playing, paused, stopped)") status: str = Field(description="Player status (playing, paused, stopped)")
current_sound: dict[str, Any] | None = Field( current_sound: dict[str, Any] | None = Field(
None, description="Current sound information" None, description="Current sound information",
) )
playlist: dict[str, Any] | None = Field( playlist: dict[str, Any] | None = Field(
None, description="Current playlist information" None, description="Current playlist information",
) )
position: int = Field(description="Current position in milliseconds") position: int = Field(description="Current position in milliseconds")
duration: int | None = Field( duration: int | None = Field(

View File

@@ -1,6 +1,6 @@
"""Playlist schemas.""" """Playlist schemas."""
from pydantic import BaseModel, Field from pydantic import BaseModel
from app.models.playlist import Playlist from app.models.playlist import Playlist
from app.models.sound import Sound from app.models.sound import Sound

View File

@@ -30,7 +30,7 @@ class InsufficientCreditsError(Exception):
self.required = required self.required = required
self.available = available self.available = available
super().__init__( super().__init__(
f"Insufficient credits: {required} required, {available} available" f"Insufficient credits: {required} required, {available} available",
) )
@@ -150,7 +150,7 @@ class CreditService:
) )
# Still create a transaction record for auditing # Still create a transaction record for auditing
return await self._create_transaction_record( return await self._create_transaction_record(
user_id, action, 0, success, metadata user_id, action, 0, success, metadata,
) )
session = self.db_session_factory() session = self.db_session_factory()

View File

@@ -10,7 +10,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.extraction import Extraction
from app.models.sound import Sound from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
@@ -155,7 +154,7 @@ class ExtractionService:
# Check if extraction already exists for this service # Check if extraction already exists for this service
existing = await self.extraction_repo.get_by_service_and_id( existing = await self.extraction_repo.get_by_service_and_id(
service_info["service"], service_info["service_id"] service_info["service"], service_info["service_id"],
) )
if existing and existing.id != extraction_id: if existing and existing.id != extraction_id:
error_msg = ( error_msg = (
@@ -180,7 +179,7 @@ class ExtractionService:
# Extract audio and thumbnail # Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media( audio_file, thumbnail_file = await self._extract_media(
extraction_id, extraction_url extraction_id, extraction_url,
) )
# Move files to final locations # Move files to final locations
@@ -238,7 +237,7 @@ class ExtractionService:
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
logger.exception( logger.exception(
"Failed to process extraction %d: %s", extraction_id, error_msg "Failed to process extraction %d: %s", extraction_id, error_msg,
) )
# Update extraction with error # Update extraction with error
@@ -262,14 +261,14 @@ class ExtractionService:
} }
async def _extract_media( async def _extract_media(
self, extraction_id: int, extraction_url: str self, extraction_id: int, extraction_url: str,
) -> tuple[Path, Path | None]: ) -> tuple[Path, Path | None]:
"""Extract audio and thumbnail using yt-dlp.""" """Extract audio and thumbnail using yt-dlp."""
temp_dir = Path(settings.EXTRACTION_TEMP_DIR) temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
# Create unique filename based on extraction ID # Create unique filename based on extraction ID
output_template = str( output_template = str(
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s" temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s",
) )
# Configure yt-dlp options # Configure yt-dlp options
@@ -304,8 +303,8 @@ class ExtractionService:
# Find the extracted files # Find the extracted files
audio_files = list( audio_files = list(
temp_dir.glob( temp_dir.glob(
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}" f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}",
) ),
) )
thumbnail_files = ( thumbnail_files = (
list(temp_dir.glob(f"extraction_{extraction_id}_*.webp")) list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
@@ -342,7 +341,7 @@ class ExtractionService:
"""Move extracted files to their final locations.""" """Move extracted files to their final locations."""
# Generate clean filename based on title and service # Generate clean filename based on title and service
safe_title = self._sanitize_filename( safe_title = self._sanitize_filename(
title or f"{service or 'unknown'}_{service_id or 'unknown'}" title or f"{service or 'unknown'}_{service_id or 'unknown'}",
) )
# Move audio file # Move audio file

View File

@@ -46,9 +46,9 @@ class ExtractionProcessor:
if self.processor_task and not self.processor_task.done(): if self.processor_task and not self.processor_task.done():
try: try:
await asyncio.wait_for(self.processor_task, timeout=30.0) await asyncio.wait_for(self.processor_task, timeout=30.0)
except asyncio.TimeoutError: except TimeoutError:
logger.warning( logger.warning(
"Extraction processor did not stop gracefully, cancelling..." "Extraction processor did not stop gracefully, cancelling...",
) )
self.processor_task.cancel() self.processor_task.cancel()
try: try:
@@ -66,7 +66,7 @@ class ExtractionProcessor:
# The processor will pick it up on the next cycle # The processor will pick it up on the next cycle
else: else:
logger.warning( logger.warning(
"Extraction %d is already being processed", extraction_id "Extraction %d is already being processed", extraction_id,
) )
async def _process_queue(self) -> None: async def _process_queue(self) -> None:
@@ -81,7 +81,7 @@ class ExtractionProcessor:
try: try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0) await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
break # Shutdown requested break # Shutdown requested
except asyncio.TimeoutError: except TimeoutError:
continue # Continue processing continue # Continue processing
except Exception as e: except Exception as e:
@@ -90,7 +90,7 @@ class ExtractionProcessor:
try: try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0) await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
break # Shutdown requested break # Shutdown requested
except asyncio.TimeoutError: except TimeoutError:
continue continue
logger.info("Extraction queue processor stopped") logger.info("Extraction queue processor stopped")
@@ -125,13 +125,13 @@ class ExtractionProcessor:
# Start processing this extraction in the background # Start processing this extraction in the background
task = asyncio.create_task( task = asyncio.create_task(
self._process_single_extraction(extraction_id) self._process_single_extraction(extraction_id),
) )
task.add_done_callback( task.add_done_callback(
lambda t, eid=extraction_id: self._on_extraction_completed( lambda t, eid=extraction_id: self._on_extraction_completed(
eid, eid,
t, t,
) ),
) )
logger.info( logger.info(

View File

@@ -49,7 +49,7 @@ class PlaylistService:
if not main_playlist: if not main_playlist:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Main playlist not found. Make sure to run database seeding." detail="Main playlist not found. Make sure to run database seeding.",
) )
return main_playlist return main_playlist
@@ -179,7 +179,7 @@ class PlaylistService:
return await self.playlist_repo.get_playlist_sounds(playlist_id) return await self.playlist_repo.get_playlist_sounds(playlist_id)
async def add_sound_to_playlist( async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None,
) -> None: ) -> None:
"""Add a sound to a playlist.""" """Add a sound to a playlist."""
# Verify playlist exists # Verify playlist exists
@@ -202,11 +202,11 @@ class PlaylistService:
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position) await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info( logger.info(
"Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id "Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id,
) )
async def remove_sound_from_playlist( async def remove_sound_from_playlist(
self, playlist_id: int, sound_id: int, user_id: int self, playlist_id: int, sound_id: int, user_id: int,
) -> None: ) -> None:
"""Remove a sound from a playlist.""" """Remove a sound from a playlist."""
# Verify playlist exists # Verify playlist exists
@@ -228,7 +228,7 @@ class PlaylistService:
) )
async def reorder_playlist_sounds( async def reorder_playlist_sounds(
self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]] self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]],
) -> None: ) -> None:
"""Reorder sounds in a playlist.""" """Reorder sounds in a playlist."""
# Verify playlist exists # Verify playlist exists
@@ -262,7 +262,7 @@ class PlaylistService:
await self._unset_current_playlist(user_id) await self._unset_current_playlist(user_id)
await self._set_main_as_current(user_id) await self._set_main_as_current(user_id)
logger.info( logger.info(
"Unset current playlist and set main as current for user %s", user_id "Unset current playlist and set main as current for user %s", user_id,
) )
async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]: async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]:
@@ -290,7 +290,7 @@ class PlaylistService:
# Check if sound is already in main playlist # Check if sound is already in main playlist
if not await self.playlist_repo.is_sound_in_playlist( if not await self.playlist_repo.is_sound_in_playlist(
main_playlist.id, sound_id main_playlist.id, sound_id,
): ):
await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id) await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id)
logger.info( logger.info(

View File

@@ -141,7 +141,7 @@ class SoundNormalizerService:
ffmpeg.run(stream, quiet=True, overwrite_output=True) ffmpeg.run(stream, quiet=True, overwrite_output=True)
logger.info("One-pass normalization completed: %s", output_path) logger.info("One-pass normalization completed: %s", output_path)
except Exception as e: except Exception:
logger.exception("One-pass normalization failed for %s", input_path) logger.exception("One-pass normalization failed for %s", input_path)
raise raise
@@ -153,7 +153,7 @@ class SoundNormalizerService:
"""Normalize audio using two-pass loudnorm for better quality.""" """Normalize audio using two-pass loudnorm for better quality."""
try: try:
logger.info( logger.info(
"Starting two-pass normalization: %s -> %s", input_path, output_path "Starting two-pass normalization: %s -> %s", input_path, output_path,
) )
# First pass: analyze # First pass: analyze
@@ -193,7 +193,7 @@ class SoundNormalizerService:
json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output) json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output)
if not json_match: if not json_match:
logger.error( logger.error(
"Could not find JSON in loudnorm output: %s", analysis_output "Could not find JSON in loudnorm output: %s", analysis_output,
) )
raise ValueError("Could not extract loudnorm analysis data") raise ValueError("Could not extract loudnorm analysis data")
@@ -260,7 +260,7 @@ class SoundNormalizerService:
) )
raise raise
except Exception as e: except Exception:
logger.exception("Two-pass normalization failed for %s", input_path) logger.exception("Two-pass normalization failed for %s", input_path)
raise raise
@@ -428,7 +428,7 @@ class SoundNormalizerService:
"type": sound.type, "type": sound.type,
"is_normalized": sound.is_normalized, "is_normalized": sound.is_normalized,
"name": sound.name, "name": sound.name,
} },
) )
# Process each sound using captured data # Process each sound using captured data
@@ -476,7 +476,7 @@ class SoundNormalizerService:
"normalized_hash": None, "normalized_hash": None,
"id": sound_id, "id": sound_id,
"error": str(e), "error": str(e),
} },
) )
logger.info("Normalization completed: %s", results) logger.info("Normalization completed: %s", results)
@@ -517,7 +517,7 @@ class SoundNormalizerService:
"type": sound.type, "type": sound.type,
"is_normalized": sound.is_normalized, "is_normalized": sound.is_normalized,
"name": sound.name, "name": sound.name,
} },
) )
# Process each sound using captured data # Process each sound using captured data
@@ -565,7 +565,7 @@ class SoundNormalizerService:
"normalized_hash": None, "normalized_hash": None,
"id": sound_id, "id": sound_id,
"error": str(e), "error": str(e),
} },
) )
logger.info("Type normalization completed: %s", results) logger.info("Type normalization completed: %s", results)

View File

@@ -132,7 +132,7 @@ class SoundScannerService:
"id": None, "id": None,
"error": str(e), "error": str(e),
"changes": None, "changes": None,
} },
) )
# Delete sounds that no longer exist in directory # Delete sounds that no longer exist in directory
@@ -153,7 +153,7 @@ class SoundScannerService:
"id": sound.id, "id": sound.id,
"error": None, "error": None,
"changes": None, "changes": None,
} },
) )
except Exception as e: except Exception as e:
logger.exception("Error deleting sound %s", filename) logger.exception("Error deleting sound %s", filename)
@@ -169,7 +169,7 @@ class SoundScannerService:
"id": sound.id, "id": sound.id,
"error": str(e), "error": str(e),
"changes": None, "changes": None,
} },
) )
logger.info("Sync completed: %s", results) logger.info("Sync completed: %s", results)
@@ -219,7 +219,7 @@ class SoundScannerService:
"id": sound.id, "id": sound.id,
"error": None, "error": None,
"changes": None, "changes": None,
} },
) )
elif existing_sound.hash != file_hash: elif existing_sound.hash != file_hash:
@@ -246,7 +246,7 @@ class SoundScannerService:
"id": existing_sound.id, "id": existing_sound.id,
"error": None, "error": None,
"changes": ["hash", "duration", "size", "name"], "changes": ["hash", "duration", "size", "name"],
} },
) )
else: else:
@@ -264,7 +264,7 @@ class SoundScannerService:
"id": existing_sound.id, "id": existing_sound.id,
"error": None, "error": None,
"changes": None, "changes": None,
} },
) )
async def scan_soundboard_directory(self) -> ScanResults: async def scan_soundboard_directory(self) -> ScanResults:

View File

@@ -6,7 +6,6 @@ from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger

View File

@@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable
from typing import Any, TypeVar from typing import Any, TypeVar
from app.models.credit_action import CreditActionType from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError from app.services.credit import CreditService
F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
@@ -69,7 +69,7 @@ def requires_credits(
# Validate credits before execution # Validate credits before execution
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
user_id, action_type, metadata user_id, action_type, metadata,
) )
# Execute the function # Execute the function
@@ -85,7 +85,7 @@ def requires_credits(
finally: finally:
# Deduct credits based on success # Deduct credits based on success
await credit_service.deduct_credits( await credit_service.deduct_credits(
user_id, action_type, success, metadata user_id, action_type, success, metadata,
) )
return wrapper # type: ignore[return-value] return wrapper # type: ignore[return-value]
@@ -173,7 +173,7 @@ class CreditManager:
async def __aenter__(self) -> "CreditManager": async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits.""" """Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits( await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.metadata self.user_id, self.action_type, self.metadata,
) )
self.validated = True self.validated = True
return self return self
@@ -184,7 +184,7 @@ class CreditManager:
# If no exception occurred, consider it successful # If no exception occurred, consider it successful
success = exc_type is None and self.success success = exc_type is None and self.success
await self.credit_service.deduct_credits( await self.credit_service.deduct_credits(
self.user_id, self.action_type, success, self.metadata self.user_id, self.action_type, success, self.metadata,
) )
def mark_success(self) -> None: def mark_success(self) -> None:

View File

@@ -177,7 +177,7 @@ class TestApiTokenEndpoints:
# Set a token on the user # Set a token on the user
authenticated_user.api_token = "expired_token" authenticated_user.api_token = "expired_token"
authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta( authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta(
days=1 days=1,
) )
response = await authenticated_client.get("/api/v1/auth/api-token/status") response = await authenticated_client.get("/api/v1/auth/api-token/status")
@@ -209,7 +209,7 @@ class TestApiTokenEndpoints:
# Verify token exists # Verify token exists
status_response = await authenticated_client.get( status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status" "/api/v1/auth/api-token/status",
) )
assert status_response.json()["has_token"] is True assert status_response.json()["has_token"] is True
@@ -222,7 +222,7 @@ class TestApiTokenEndpoints:
# Verify token is gone # Verify token is gone
status_response = await authenticated_client.get( status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status" "/api/v1/auth/api-token/status",
) )
assert status_response.json()["has_token"] is False assert status_response.json()["has_token"] is False

View File

@@ -1,20 +1,16 @@
"""Tests for extraction API endpoints.""" """Tests for extraction API endpoints."""
from unittest.mock import AsyncMock, Mock
import pytest import pytest
import pytest_asyncio
from httpx import AsyncClient from httpx import AsyncClient
from app.models.user import User
class TestExtractionEndpoints: class TestExtractionEndpoints:
"""Test extraction API endpoints.""" """Test extraction API endpoints."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_extraction_success( async def test_create_extraction_success(
self, test_client: AsyncClient, auth_cookies: dict[str, str] self, test_client: AsyncClient, auth_cookies: dict[str, str],
): ):
"""Test successful extraction creation.""" """Test successful extraction creation."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -50,7 +46,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_processor_status_admin( async def test_get_processor_status_admin(
self, test_client: AsyncClient, admin_cookies: dict[str, str] self, test_client: AsyncClient, admin_cookies: dict[str, str],
): ):
"""Test getting processor status as admin.""" """Test getting processor status as admin."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -66,7 +62,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_processor_status_non_admin( async def test_get_processor_status_non_admin(
self, test_client: AsyncClient, auth_cookies: dict[str, str] self, test_client: AsyncClient, auth_cookies: dict[str, str],
): ):
"""Test getting processor status as non-admin user.""" """Test getting processor status as non-admin user."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -80,7 +76,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_extractions( async def test_get_user_extractions(
self, test_client: AsyncClient, auth_cookies: dict[str, str] self, test_client: AsyncClient, auth_cookies: dict[str, str],
): ):
"""Test getting user extractions.""" """Test getting user extractions."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning

View File

@@ -1,7 +1,5 @@
"""Tests for playlist API endpoints.""" """Tests for playlist API endpoints."""
import json
from typing import Any
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@@ -298,7 +296,7 @@ class TestPlaylistEndpoints:
playlist_name = test_playlist.name playlist_name = test_playlist.name
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}" f"/api/v1/playlists/{playlist_id}",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -350,7 +348,7 @@ class TestPlaylistEndpoints:
} }
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload f"/api/v1/playlists/{playlist_id}", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -400,7 +398,7 @@ class TestPlaylistEndpoints:
payload = {"is_current": True} payload = {"is_current": True}
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload f"/api/v1/playlists/{playlist_id}", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -434,7 +432,7 @@ class TestPlaylistEndpoints:
playlist_id = test_playlist.id playlist_id = test_playlist.id
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}" f"/api/v1/playlists/{playlist_id}",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -442,7 +440,7 @@ class TestPlaylistEndpoints:
# Verify playlist is deleted # Verify playlist is deleted
get_response = await authenticated_client.get( get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}" f"/api/v1/playlists/{playlist_id}",
) )
assert get_response.status_code == 404 assert get_response.status_code == 404
@@ -470,7 +468,7 @@ class TestPlaylistEndpoints:
main_playlist_id = main_playlist.id main_playlist_id = main_playlist.id
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{main_playlist_id}" f"/api/v1/playlists/{main_playlist_id}",
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -573,7 +571,7 @@ class TestPlaylistEndpoints:
await test_session.commit() await test_session.commit()
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds" f"/api/v1/playlists/{playlist_id}/sounds",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -624,7 +622,7 @@ class TestPlaylistEndpoints:
payload = {"sound_id": sound_id} payload = {"sound_id": sound_id}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -632,7 +630,7 @@ class TestPlaylistEndpoints:
# Verify sound was added # Verify sound was added
get_response = await authenticated_client.get( get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds" f"/api/v1/playlists/{playlist_id}/sounds",
) )
assert get_response.status_code == 200 assert get_response.status_code == 200
sounds = get_response.json() sounds = get_response.json()
@@ -681,7 +679,7 @@ class TestPlaylistEndpoints:
payload = {"sound_id": sound_id, "position": 5} payload = {"sound_id": sound_id, "position": 5}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -729,13 +727,13 @@ class TestPlaylistEndpoints:
# Add sound first time # Add sound first time
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
# Try to add same sound again # Try to add same sound again
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 400 assert response.status_code == 400
assert "already in this playlist" in response.json()["detail"] assert "already in this playlist" in response.json()["detail"]
@@ -769,7 +767,7 @@ class TestPlaylistEndpoints:
payload = {"sound_id": 99999} payload = {"sound_id": 99999}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -817,12 +815,12 @@ class TestPlaylistEndpoints:
# Add sound first # Add sound first
payload = {"sound_id": sound_id} payload = {"sound_id": sound_id}
await authenticated_client.post( await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
) )
# Remove sound # Remove sound
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}" f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -830,7 +828,7 @@ class TestPlaylistEndpoints:
# Verify sound was removed # Verify sound was removed
get_response = await authenticated_client.get( get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds" f"/api/v1/playlists/{playlist_id}/sounds",
) )
sounds = get_response.json() sounds = get_response.json()
assert len(sounds) == 0 assert len(sounds) == 0
@@ -875,7 +873,7 @@ class TestPlaylistEndpoints:
sound_id = test_sound.id sound_id = test_sound.id
response = await authenticated_client.delete( response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}" f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}",
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -929,11 +927,11 @@ class TestPlaylistEndpoints:
# Reorder sounds - use positions that don't cause constraints # Reorder sounds - use positions that don't cause constraints
# When swapping, we need to be careful about unique constraints # When swapping, we need to be careful about unique constraints
payload = { payload = {
"sound_positions": [[sound1_id, 10], [sound2_id, 5]] # Use different positions to avoid constraints "sound_positions": [[sound1_id, 10], [sound2_id, 5]], # Use different positions to avoid constraints
} }
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}/sounds/reorder", json=payload f"/api/v1/playlists/{playlist_id}/sounds/reorder", json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -976,7 +974,7 @@ class TestPlaylistEndpoints:
playlist_id = test_playlist.id playlist_id = test_playlist.id
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}/set-current" f"/api/v1/playlists/{playlist_id}/set-current",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -1076,7 +1074,7 @@ class TestPlaylistEndpoints:
# Initially empty # Initially empty
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/stats" f"/api/v1/playlists/{playlist_id}/stats",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -1093,7 +1091,7 @@ class TestPlaylistEndpoints:
# Check stats again # Check stats again
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/stats" f"/api/v1/playlists/{playlist_id}/stats",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -1110,8 +1108,8 @@ class TestPlaylistEndpoints:
test_session: AsyncSession, test_session: AsyncSession,
) -> None: ) -> None:
"""Test that users can only access their own playlists.""" """Test that users can only access their own playlists."""
from app.utils.auth import JWTUtils, PasswordUtils
from app.models.plan import Plan from app.models.plan import Plan
from app.utils.auth import PasswordUtils
# Create plan within this test to avoid session issues # Create plan within this test to avoid session issues
plan = Plan( plan = Plan(
@@ -1159,7 +1157,7 @@ class TestPlaylistEndpoints:
# Try to access other user's playlist # Try to access other user's playlist
response = await authenticated_client.get( response = await authenticated_client.get(
f"/api/v1/playlists/{other_playlist_id}" f"/api/v1/playlists/{other_playlist_id}",
) )
# Currently the implementation allows access to all playlists # Currently the implementation allows access to all playlists

View File

@@ -158,7 +158,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_missing_parameters( async def test_send_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User self, authenticated_client: AsyncClient, authenticated_user: User,
): ):
"""Test sending message with missing parameters.""" """Test sending message with missing parameters."""
# Missing target_user_id # Missing target_user_id
@@ -177,7 +177,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_broadcast_message_missing_parameters( async def test_broadcast_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User self, authenticated_client: AsyncClient, authenticated_user: User,
): ):
"""Test broadcasting message with missing parameters.""" """Test broadcasting message with missing parameters."""
response = await authenticated_client.post("/api/v1/socket/broadcast") response = await authenticated_client.post("/api/v1/socket/broadcast")
@@ -185,7 +185,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_invalid_user_id( async def test_send_message_invalid_user_id(
self, authenticated_client: AsyncClient, authenticated_user: User self, authenticated_client: AsyncClient, authenticated_user: User,
): ):
"""Test sending message with invalid user ID.""" """Test sending message with invalid user ID."""
response = await authenticated_client.post( response = await authenticated_client.post(

View File

@@ -66,7 +66,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -167,7 +167,7 @@ class TestSoundEndpoints:
headers = {"API-TOKEN": "admin_api_token"} headers = {"API-TOKEN": "admin_api_token"}
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -192,7 +192,7 @@ class TestSoundEndpoints:
): ):
"""Test scanning sounds when service raises an error.""" """Test scanning sounds when service raises an error."""
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.side_effect = Exception("Directory not found") mock_scan.side_effect = Exception("Directory not found")
@@ -244,7 +244,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -285,7 +285,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -307,14 +307,14 @@ class TestSoundEndpoints:
): ):
"""Test custom directory scanning with invalid path.""" """Test custom directory scanning with invalid path."""
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.side_effect = ValueError( mock_scan.side_effect = ValueError(
"Directory does not exist: /invalid/path" "Directory does not exist: /invalid/path",
) )
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/invalid/path"} "/api/v1/sounds/scan/custom", params={"directory": "/invalid/path"},
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -325,7 +325,7 @@ class TestSoundEndpoints:
async def test_scan_custom_directory_unauthenticated(self, client: AsyncClient): async def test_scan_custom_directory_unauthenticated(self, client: AsyncClient):
"""Test custom directory scanning without authentication.""" """Test custom directory scanning without authentication."""
response = await client.post( response = await client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/some/path"} "/api/v1/sounds/scan/custom", params={"directory": "/some/path"},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -377,12 +377,12 @@ class TestSoundEndpoints:
): ):
"""Test custom directory scanning when service raises an error.""" """Test custom directory scanning when service raises an error."""
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory" "app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan: ) as mock_scan:
mock_scan.side_effect = Exception("Permission denied") mock_scan.side_effect = Exception("Permission denied")
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/restricted/path"} "/api/v1/sounds/scan/custom", params={"directory": "/restricted/path"},
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -442,7 +442,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -480,7 +480,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory" "app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
@@ -570,12 +570,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all" "/api/v1/sounds/normalize/all",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -608,12 +608,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all", params={"force": True} "/api/v1/sounds/normalize/all", params={"force": True},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -637,12 +637,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all", params={"one_pass": True} "/api/v1/sounds/normalize/all", params={"one_pass": True},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -684,7 +684,7 @@ class TestSoundEndpoints:
base_url="http://test", base_url="http://test",
) as client: ) as client:
response = await client.post( response = await client.post(
"/api/v1/sounds/normalize/all", headers=headers "/api/v1/sounds/normalize/all", headers=headers,
) )
assert response.status_code == 403 assert response.status_code == 403
@@ -702,12 +702,12 @@ class TestSoundEndpoints:
): ):
"""Test normalization when service raises an error.""" """Test normalization when service raises an error."""
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds" "app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize: ) as mock_normalize:
mock_normalize.side_effect = Exception("Normalization service failed") mock_normalize.side_effect = Exception("Normalization service failed")
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all" "/api/v1/sounds/normalize/all",
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -758,12 +758,12 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type" "app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/type/SDB" "/api/v1/sounds/normalize/type/SDB",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -779,7 +779,7 @@ class TestSoundEndpoints:
# Verify the service was called with correct type # Verify the service was called with correct type
mock_normalize.assert_called_once_with( mock_normalize.assert_called_once_with(
sound_type="SDB", force=False, one_pass=None sound_type="SDB", force=False, one_pass=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -790,7 +790,7 @@ class TestSoundEndpoints:
): ):
"""Test normalization with invalid sound type.""" """Test normalization with invalid sound type."""
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/type/INVALID" "/api/v1/sounds/normalize/type/INVALID",
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -814,7 +814,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type" "app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type",
) as mock_normalize: ) as mock_normalize:
mock_normalize.return_value = mock_results mock_normalize.return_value = mock_results
@@ -827,7 +827,7 @@ class TestSoundEndpoints:
# Verify parameters were passed correctly # Verify parameters were passed correctly
mock_normalize.assert_called_once_with( mock_normalize.assert_called_once_with(
sound_type="TTS", force=True, one_pass=False sound_type="TTS", force=True, one_pass=False,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -866,7 +866,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -874,7 +874,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42" "/api/v1/sounds/normalize/42",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -897,12 +897,12 @@ class TestSoundEndpoints:
): ):
"""Test normalization of non-existent sound.""" """Test normalization of non-existent sound."""
with patch( with patch(
"app.repositories.sound.SoundRepository.get_by_id" "app.repositories.sound.SoundRepository.get_by_id",
) as mock_get_sound: ) as mock_get_sound:
mock_get_sound.return_value = None mock_get_sound.return_value = None
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/999" "/api/v1/sounds/normalize/999",
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -945,7 +945,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -953,7 +953,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42" "/api/v1/sounds/normalize/42",
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -997,7 +997,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -1052,7 +1052,7 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound" "app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
@@ -1060,7 +1060,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post( response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42" "/api/v1/sounds/normalize/42",
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -1,16 +1,14 @@
"""Tests for VLC player API endpoints.""" """Tests for VLC player API endpoints."""
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock
import pytest import pytest
from httpx import AsyncClient
from fastapi import FastAPI from fastapi import FastAPI
from httpx import AsyncClient
from app.api.v1.sounds import get_credit_service, get_sound_repository, get_vlc_player
from app.models.sound import Sound from app.models.sound import Sound
from app.models.user import User from app.models.user import User
from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service
class TestVLCEndpoints: class TestVLCEndpoints:

View File

@@ -13,10 +13,8 @@ from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
from app.utils.auth import JWTUtils, PasswordUtils from app.utils.auth import JWTUtils, PasswordUtils

View File

@@ -49,7 +49,7 @@ class TestApiTokenDependencies:
assert result == test_user assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with( mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123" "test_api_token_123",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -135,11 +135,11 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_service_exception( async def test_get_current_user_api_token_service_exception(
self, mock_auth_service self, mock_auth_service,
): ):
"""Test API token authentication with service exception.""" """Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception( mock_auth_service.get_user_by_api_token.side_effect = Exception(
"Database error" "Database error",
) )
api_token_header = "test_token" api_token_header = "test_token"
@@ -170,7 +170,7 @@ class TestApiTokenDependencies:
assert result == test_user assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with( mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123" "test_api_token_123",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -184,7 +184,7 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_no_expiry_never_expires( async def test_api_token_no_expiry_never_expires(
self, mock_auth_service, test_user self, mock_auth_service, test_user,
): ):
"""Test API token with no expiry date never expires.""" """Test API token with no expiry date never expires."""
test_user.api_token_expires_at = None test_user.api_token_expires_at = None

View File

@@ -105,7 +105,6 @@ class TestCreditTransactionRepository:
ensure_plans: tuple[Any, ...], # noqa: ARG002 ensure_plans: tuple[Any, ...], # noqa: ARG002
) -> AsyncGenerator[CreditTransaction, None]: ) -> AsyncGenerator[CreditTransaction, None]:
"""Create a transaction for a different user.""" """Create a transaction for a different user."""
from app.models.plan import Plan
from app.repositories.user import UserRepository from app.repositories.user import UserRepository
# Create another user # Create another user
@@ -193,13 +192,13 @@ class TestCreditTransactionRepository:
"""Test getting transactions by user ID with pagination.""" """Test getting transactions by user ID with pagination."""
# Get first 2 transactions # Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id( first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0 test_user_id, limit=2, offset=0,
) )
assert len(first_page) == 2 assert len(first_page) == 2
# Get next 2 transactions # Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id( second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2 test_user_id, limit=2, offset=2,
) )
assert len(second_page) == 2 assert len(second_page) == 2
@@ -216,7 +215,7 @@ class TestCreditTransactionRepository:
) -> None: ) -> None:
"""Test getting transactions by action type.""" """Test getting transactions by action type."""
vlc_transactions = await credit_transaction_repository.get_by_action_type( vlc_transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound" "vlc_play_sound",
) )
# Should return 2 VLC transactions (1 successful, 1 failed) # Should return 2 VLC transactions (1 successful, 1 failed)
@@ -224,7 +223,7 @@ class TestCreditTransactionRepository:
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions) assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
extraction_transactions = await credit_transaction_repository.get_by_action_type( extraction_transactions = await credit_transaction_repository.get_by_action_type(
"audio_extraction" "audio_extraction",
) )
# Should return 1 extraction transaction # Should return 1 extraction transaction
@@ -240,14 +239,14 @@ class TestCreditTransactionRepository:
"""Test getting transactions by action type with pagination.""" """Test getting transactions by action type with pagination."""
# Test with limit # Test with limit
transactions = await credit_transaction_repository.get_by_action_type( transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1 "vlc_play_sound", limit=1,
) )
assert len(transactions) == 1 assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound" assert transactions[0].action_type == "vlc_play_sound"
# Test with offset # Test with offset
transactions = await credit_transaction_repository.get_by_action_type( transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1 "vlc_play_sound", limit=1, offset=1,
) )
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@@ -275,7 +274,7 @@ class TestCreditTransactionRepository:
) -> None: ) -> None:
"""Test getting successful transactions filtered by user.""" """Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions( successful_transactions = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id user_id=test_user_id,
) )
# Should only return successful transactions for test_user # Should only return successful transactions for test_user
@@ -294,14 +293,14 @@ class TestCreditTransactionRepository:
"""Test getting successful transactions with pagination.""" """Test getting successful transactions with pagination."""
# Get first 2 successful transactions # Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions( first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0 user_id=test_user_id, limit=2, offset=0,
) )
assert len(first_page) == 2 assert len(first_page) == 2
assert all(t.success is True for t in first_page) assert all(t.success is True for t in first_page)
# Get next successful transaction # Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions( second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2 user_id=test_user_id, limit=2, offset=2,
) )
assert len(second_page) == 1 # Should be 1 remaining assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page) assert all(t.success is True for t in second_page)
@@ -363,7 +362,7 @@ class TestCreditTransactionRepository:
} }
updated_transaction = await credit_transaction_repository.update( updated_transaction = await credit_transaction_repository.update(
transaction, update_data transaction, update_data,
) )
assert updated_transaction.id == transaction.id assert updated_transaction.id == transaction.id

View File

@@ -416,7 +416,7 @@ class TestPlaylistRepository:
# Test the repository method # Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist( playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
assert playlist_sound.playlist_id == playlist_id assert playlist_sound.playlist_id == playlist_id
@@ -482,7 +482,7 @@ class TestPlaylistRepository:
# Test the repository method # Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist( playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id, position=5 playlist_id, sound_id, position=5,
) )
assert playlist_sound.position == 5 assert playlist_sound.position == 5
@@ -546,17 +546,17 @@ class TestPlaylistRepository:
# Verify it was added # Verify it was added
assert await playlist_repository.is_sound_in_playlist( assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
# Remove the sound # Remove the sound
await playlist_repository.remove_sound_from_playlist( await playlist_repository.remove_sound_from_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
# Verify it was removed # Verify it was removed
assert not await playlist_repository.is_sound_in_playlist( assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -746,7 +746,7 @@ class TestPlaylistRepository:
# Initially not in playlist # Initially not in playlist
assert not await playlist_repository.is_sound_in_playlist( assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
# Add sound # Add sound
@@ -754,7 +754,7 @@ class TestPlaylistRepository:
# Now in playlist # Now in playlist
assert await playlist_repository.is_sound_in_playlist( assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id playlist_id, sound_id,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -809,16 +809,16 @@ class TestPlaylistRepository:
# Add sounds to playlist # Add sounds to playlist
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound1_id, position=0 playlist_id, sound1_id, position=0,
) )
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound2_id, position=1 playlist_id, sound2_id, position=1,
) )
# Reorder sounds - use different positions to avoid constraint issues # Reorder sounds - use different positions to avoid constraint issues
sound_positions = [(sound1_id, 10), (sound2_id, 5)] sound_positions = [(sound1_id, 10), (sound2_id, 5)]
await playlist_repository.reorder_playlist_sounds( await playlist_repository.reorder_playlist_sounds(
playlist_id, sound_positions playlist_id, sound_positions,
) )
# Verify new order # Verify new order

View File

@@ -60,7 +60,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by provider user ID when it exists.""" """Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id( oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456" "google", "google_123456",
) )
assert oauth is not None assert oauth is not None
@@ -76,7 +76,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist.""" """Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id( oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id" "google", "nonexistent_id",
) )
assert oauth is None assert oauth is None
@@ -90,7 +90,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by user ID and provider when it exists.""" """Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider( oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google" test_user_id, "google",
) )
assert oauth is not None assert oauth is not None
@@ -106,7 +106,7 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist.""" """Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider( oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github" test_user_id, "github",
) )
assert oauth is None assert oauth is None
@@ -183,7 +183,7 @@ class TestUserOauthRepository:
# Verify it's deleted by trying to find it # Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id( deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456" "twitter", "twitter_456",
) )
assert deleted_oauth is None assert deleted_oauth is None
@@ -240,10 +240,10 @@ class TestUserOauthRepository:
# Verify both exist by querying back from database # Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider( found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google" test_user_id, "google",
) )
found_github = await user_oauth_repository.get_by_user_id_and_provider( found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github" test_user_id, "github",
) )
assert found_google is not None assert found_google is not None
@@ -257,10 +257,10 @@ class TestUserOauthRepository:
# Verify we can also find them by provider_user_id # Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id( found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1" "google", "google_user_1",
) )
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id( found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1" "github", "github_user_1",
) )
assert found_google_by_provider is not None assert found_google_by_provider is not None

View File

@@ -1,7 +1,7 @@
"""Tests for credit service.""" """Tests for credit service."""
import json import json
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -103,7 +103,7 @@ class TestCreditService:
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
user, action = await credit_service.validate_and_reserve_credits( user, action = await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND 1, CreditActionType.VLC_PLAY_SOUND,
) )
assert user == sample_user assert user == sample_user
@@ -131,7 +131,7 @@ class TestCreditService:
with pytest.raises(InsufficientCreditsError) as exc_info: with pytest.raises(InsufficientCreditsError) as exc_info:
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND 1, CreditActionType.VLC_PLAY_SOUND,
) )
assert exc_info.value.required == 1 assert exc_info.value.required == 1
@@ -150,7 +150,7 @@ class TestCreditService:
with pytest.raises(ValueError, match="User 999 not found"): with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
999, CreditActionType.VLC_PLAY_SOUND 999, CreditActionType.VLC_PLAY_SOUND,
) )
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -168,7 +168,7 @@ class TestCreditService:
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits( transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"} 1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"},
) )
# Verify user credits were updated # Verify user credits were updated
@@ -187,7 +187,7 @@ class TestCreditService:
"credits_deducted": 1, "credits_deducted": 1,
"action_type": "vlc_play_sound", "action_type": "vlc_play_sound",
"success": True, "success": True,
} },
) )
# Check transaction details # Check transaction details
@@ -214,7 +214,7 @@ class TestCreditService:
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits( transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed 1, CreditActionType.VLC_PLAY_SOUND, False, # Action failed
) )
# Verify user credits were NOT updated (action requires success) # Verify user credits were NOT updated (action requires success)
@@ -256,7 +256,7 @@ class TestCreditService:
with pytest.raises(InsufficientCreditsError): with pytest.raises(InsufficientCreditsError):
await credit_service.deduct_credits( await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True 1, CreditActionType.VLC_PLAY_SOUND, True,
) )
# Verify no socket event was emitted since credits could not be deducted # Verify no socket event was emitted since credits could not be deducted
@@ -278,7 +278,7 @@ class TestCreditService:
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.add_credits( transaction = await credit_service.add_credits(
1, 5, "Bonus credits", {"reason": "signup"} 1, 5, "Bonus credits", {"reason": "signup"},
) )
# Verify user credits were updated # Verify user credits were updated
@@ -297,7 +297,7 @@ class TestCreditService:
"credits_added": 5, "credits_added": 5,
"description": "Bonus credits", "description": "Bonus credits",
"success": True, "success": True,
} },
) )
# Check transaction details # Check transaction details

View File

@@ -53,7 +53,7 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_youtube( async def test_detect_service_info_youtube(
self, mock_ydl_class, extraction_service self, mock_ydl_class, extraction_service,
): ):
"""Test service detection for YouTube.""" """Test service detection for YouTube."""
mock_ydl = Mock() mock_ydl = Mock()
@@ -67,7 +67,7 @@ class TestExtractionService:
} }
result = await extraction_service._detect_service_info( result = await extraction_service._detect_service_info(
"https://www.youtube.com/watch?v=test123" "https://www.youtube.com/watch?v=test123",
) )
assert result is not None assert result is not None
@@ -78,7 +78,7 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_failure( async def test_detect_service_info_failure(
self, mock_ydl_class, extraction_service self, mock_ydl_class, extraction_service,
): ):
"""Test service detection failure.""" """Test service detection failure."""
mock_ydl = Mock() mock_ydl = Mock()
@@ -106,7 +106,7 @@ class TestExtractionService:
status="pending", status="pending",
) )
extraction_service.extraction_repo.create = AsyncMock( extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
result = await extraction_service.create_extraction(url, user_id) result = await extraction_service.create_extraction(url, user_id)
@@ -134,7 +134,7 @@ class TestExtractionService:
status="pending", status="pending",
) )
extraction_service.extraction_repo.create = AsyncMock( extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
result = await extraction_service.create_extraction(url, user_id) result = await extraction_service.create_extraction(url, user_id)
@@ -160,7 +160,7 @@ class TestExtractionService:
status="pending", status="pending",
) )
extraction_service.extraction_repo.create = AsyncMock( extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
result = await extraction_service.create_extraction(url, user_id) result = await extraction_service.create_extraction(url, user_id)
@@ -186,11 +186,11 @@ class TestExtractionService:
) )
extraction_service.extraction_repo.get_by_id = AsyncMock( extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=mock_extraction return_value=mock_extraction,
) )
extraction_service.extraction_repo.update = AsyncMock() extraction_service.extraction_repo.update = AsyncMock()
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock( extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
return_value=None return_value=None,
) )
# Mock service detection # Mock service detection
@@ -202,14 +202,14 @@ class TestExtractionService:
with ( with (
patch.object( patch.object(
extraction_service, "_detect_service_info", return_value=service_info extraction_service, "_detect_service_info", return_value=service_info,
), ),
patch.object(extraction_service, "_extract_media") as mock_extract, patch.object(extraction_service, "_extract_media") as mock_extract,
patch.object( patch.object(
extraction_service, "_move_files_to_final_location" extraction_service, "_move_files_to_final_location",
) as mock_move, ) as mock_move,
patch.object( patch.object(
extraction_service, "_create_sound_record" extraction_service, "_create_sound_record",
) as mock_create_sound, ) as mock_create_sound,
patch.object(extraction_service, "_normalize_sound") as mock_normalize, patch.object(extraction_service, "_normalize_sound") as mock_normalize,
patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist, patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist,
@@ -223,7 +223,7 @@ class TestExtractionService:
# Verify service detection was called # Verify service detection was called
extraction_service._detect_service_info.assert_called_once_with( extraction_service._detect_service_info.assert_called_once_with(
"https://www.youtube.com/watch?v=test123" "https://www.youtube.com/watch?v=test123",
) )
# Verify extraction was updated with service info # Verify extraction was updated with service info
@@ -289,15 +289,15 @@ class TestExtractionService:
with ( with (
patch( patch(
"app.services.extraction.get_audio_duration", return_value=240000 "app.services.extraction.get_audio_duration", return_value=240000,
), ),
patch("app.services.extraction.get_file_size", return_value=1024), patch("app.services.extraction.get_file_size", return_value=1024),
patch( patch(
"app.services.extraction.get_file_hash", return_value="test_hash" "app.services.extraction.get_file_hash", return_value="test_hash",
), ),
): ):
extraction_service.sound_repo.create = AsyncMock( extraction_service.sound_repo.create = AsyncMock(
return_value=mock_sound return_value=mock_sound,
) )
result = await extraction_service._create_sound_record( result = await extraction_service._create_sound_record(
@@ -336,7 +336,7 @@ class TestExtractionService:
mock_normalizer = Mock() mock_normalizer = Mock()
mock_normalizer.normalize_sound = AsyncMock( mock_normalizer.normalize_sound = AsyncMock(
return_value={"status": "normalized"} return_value={"status": "normalized"},
) )
with patch( with patch(
@@ -368,7 +368,7 @@ class TestExtractionService:
mock_normalizer = Mock() mock_normalizer = Mock()
mock_normalizer.normalize_sound = AsyncMock( mock_normalizer.normalize_sound = AsyncMock(
return_value={"status": "error", "error": "Test error"} return_value={"status": "error", "error": "Test error"},
) )
with patch( with patch(
@@ -395,7 +395,7 @@ class TestExtractionService:
) )
extraction_service.extraction_repo.get_by_id = AsyncMock( extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=extraction return_value=extraction,
) )
result = await extraction_service.get_extraction_by_id(1) result = await extraction_service.get_extraction_by_id(1)
@@ -443,7 +443,7 @@ class TestExtractionService:
] ]
extraction_service.extraction_repo.get_by_user = AsyncMock( extraction_service.extraction_repo.get_by_user = AsyncMock(
return_value=extractions return_value=extractions,
) )
result = await extraction_service.get_user_extractions(1) result = await extraction_service.get_user_extractions(1)
@@ -470,7 +470,7 @@ class TestExtractionService:
] ]
extraction_service.extraction_repo.get_pending_extractions = AsyncMock( extraction_service.extraction_repo.get_pending_extractions = AsyncMock(
return_value=pending_extractions return_value=pending_extractions,
) )
result = await extraction_service.get_pending_extractions() result = await extraction_service.get_pending_extractions()

View File

@@ -1,6 +1,5 @@
"""Tests for extraction background processor.""" """Tests for extraction background processor."""
import asyncio
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
@@ -30,7 +29,7 @@ class TestExtractionProcessor:
"""Test starting and stopping the processor.""" """Test starting and stopping the processor."""
# Mock the _process_queue method to avoid actual processing # Mock the _process_queue method to avoid actual processing
with patch.object( with patch.object(
processor, "_process_queue", new_callable=AsyncMock processor, "_process_queue", new_callable=AsyncMock,
) as mock_process: ) as mock_process:
# Start the processor # Start the processor
await processor.start() await processor.start()
@@ -138,12 +137,12 @@ class TestExtractionProcessor:
# Mock the extraction service # Mock the extraction service
mock_service = Mock() mock_service = Mock()
mock_service.process_extraction = AsyncMock( mock_service.process_extraction = AsyncMock(
return_value={"status": "completed", "id": extraction_id} return_value={"status": "completed", "id": extraction_id},
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -168,7 +167,7 @@ class TestExtractionProcessor:
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -193,12 +192,12 @@ class TestExtractionProcessor:
# Mock extraction service # Mock extraction service
mock_service = Mock() mock_service = Mock()
mock_service.get_pending_extractions = AsyncMock( mock_service.get_pending_extractions = AsyncMock(
return_value=[{"id": 100, "status": "pending"}] return_value=[{"id": 100, "status": "pending"}],
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -222,15 +221,15 @@ class TestExtractionProcessor:
return_value=[ return_value=[
{"id": 100, "status": "pending"}, {"id": 100, "status": "pending"},
{"id": 101, "status": "pending"}, {"id": 101, "status": "pending"},
] ],
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch.object( patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock processor, "_process_single_extraction", new_callable=AsyncMock,
) as mock_process, ) as mock_process,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -267,15 +266,15 @@ class TestExtractionProcessor:
{"id": 100, "status": "pending"}, {"id": 100, "status": "pending"},
{"id": 101, "status": "pending"}, {"id": 101, "status": "pending"},
{"id": 102, "status": "pending"}, {"id": 102, "status": "pending"},
] ],
) )
with ( with (
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch.object( patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock processor, "_process_single_extraction", new_callable=AsyncMock,
) as mock_process, ) as mock_process,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",

View File

@@ -4,14 +4,12 @@ import asyncio
import threading import threading
import time import time
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.models.user import User
from app.services.player import ( from app.services.player import (
PlayerMode, PlayerMode,
PlayerService, PlayerService,
@@ -21,7 +19,6 @@ from app.services.player import (
initialize_player_service, initialize_player_service,
shutdown_player_service, shutdown_player_service,
) )
from app.utils.audio import get_sound_file_path
class TestPlayerState: class TestPlayerState:
@@ -738,7 +735,7 @@ class TestPlayerService:
# Verify sound play count was updated # Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
mock_sound, {"play_count": 6} mock_sound, {"play_count": 6},
) )
# Verify SoundPlayed record was created with None user_id for player # Verify SoundPlayed record was created with None user_id for player

View File

@@ -153,7 +153,6 @@ class TestPlaylistService:
test_user: User, test_user: User,
) -> None: ) -> None:
"""Test getting non-existent playlist.""" """Test getting non-existent playlist."""
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_playlist_by_id(99999) await playlist_service.get_playlist_by_id(99999)
@@ -168,7 +167,6 @@ class TestPlaylistService:
test_session: AsyncSession, test_session: AsyncSession,
) -> None: ) -> None:
"""Test getting existing main playlist.""" """Test getting existing main playlist."""
# Create main playlist manually # Create main playlist manually
main_playlist = Playlist( main_playlist = Playlist(
user_id=None, user_id=None,
@@ -193,7 +191,6 @@ class TestPlaylistService:
test_user: User, test_user: User,
) -> None: ) -> None:
"""Test that service fails if main playlist doesn't exist.""" """Test that service fails if main playlist doesn't exist."""
# Should raise an HTTPException if no main playlist exists # Should raise an HTTPException if no main playlist exists
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_main_playlist() await playlist_service.get_main_playlist()
@@ -207,7 +204,6 @@ class TestPlaylistService:
test_user: User, test_user: User,
) -> None: ) -> None:
"""Test creating a new playlist successfully.""" """Test creating a new playlist successfully."""
user_id = test_user.id # Extract user_id while session is available user_id = test_user.id # Extract user_id while session is available
playlist = await playlist_service.create_playlist( playlist = await playlist_service.create_playlist(
user_id=user_id, user_id=user_id,
@@ -758,7 +754,7 @@ class TestPlaylistService:
# Set test_playlist as current # Set test_playlist as current
updated_playlist = await playlist_service.set_current_playlist( updated_playlist = await playlist_service.set_current_playlist(
test_playlist_id, user_id test_playlist_id, user_id,
) )
assert updated_playlist.is_current is True assert updated_playlist.is_current is True

View File

@@ -98,7 +98,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_success( async def test_connect_handler_success(
self, mock_decode, mock_extract_token, socket_manager, mock_sio self, mock_decode, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test successful connection with valid token.""" """Test successful connection with valid token."""
# Setup mocks # Setup mocks
@@ -133,7 +133,7 @@ class TestSocketManager:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
async def test_connect_handler_no_token( async def test_connect_handler_no_token(
self, mock_extract_token, socket_manager, mock_sio self, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test connection with no access token.""" """Test connection with no access token."""
# Setup mocks # Setup mocks
@@ -167,7 +167,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_invalid_token( async def test_connect_handler_invalid_token(
self, mock_decode, mock_extract_token, socket_manager, mock_sio self, mock_decode, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test connection with invalid token.""" """Test connection with invalid token."""
# Setup mocks # Setup mocks
@@ -202,7 +202,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_missing_user_id( async def test_connect_handler_missing_user_id(
self, mock_decode, mock_extract_token, socket_manager, mock_sio self, mock_decode, mock_extract_token, socket_manager, mock_sio,
): ):
"""Test connection with token missing user ID.""" """Test connection with token missing user ID."""
# Setup mocks # Setup mocks

View File

@@ -55,7 +55,7 @@ class TestSoundNormalizerService:
normalized_path = normalizer_service._get_normalized_path(sound) normalized_path = normalizer_service._get_normalized_path(sound)
assert "sounds/normalized/soundboard" in str(normalized_path) assert "sounds/normalized/soundboard" in str(normalized_path)
assert "test_audio.mp3" == normalized_path.name assert normalized_path.name == "test_audio.mp3"
def test_get_original_path(self, normalizer_service): def test_get_original_path(self, normalizer_service):
"""Test original path generation.""" """Test original path generation."""
@@ -72,7 +72,7 @@ class TestSoundNormalizerService:
original_path = normalizer_service._get_original_path(sound) original_path = normalizer_service._get_original_path(sound)
assert "sounds/originals/soundboard" in str(original_path) assert "sounds/originals/soundboard" in str(original_path)
assert "test_audio.wav" == original_path.name assert original_path.name == "test_audio.wav"
def test_get_file_hash(self, normalizer_service): def test_get_file_hash(self, normalizer_service):
"""Test file hash calculation.""" """Test file hash calculation."""
@@ -172,14 +172,14 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path, patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path, patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object( patch.object(
normalizer_service, "_normalize_audio_two_pass" normalizer_service, "_normalize_audio_two_pass",
) as mock_normalize, ) as mock_normalize,
patch( patch(
"app.services.sound_normalizer.get_audio_duration", return_value=6000 "app.services.sound_normalizer.get_audio_duration", return_value=6000,
), ),
patch("app.services.sound_normalizer.get_file_size", return_value=2048), patch("app.services.sound_normalizer.get_file_size", return_value=2048),
patch( patch(
"app.services.sound_normalizer.get_file_hash", return_value="new_hash" "app.services.sound_normalizer.get_file_hash", return_value="new_hash",
), ),
): ):
# Setup path mocks # Setup path mocks
@@ -245,14 +245,14 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path, patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path, patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object( patch.object(
normalizer_service, "_normalize_audio_one_pass" normalizer_service, "_normalize_audio_one_pass",
) as mock_normalize, ) as mock_normalize,
patch( patch(
"app.services.sound_normalizer.get_audio_duration", return_value=5500 "app.services.sound_normalizer.get_audio_duration", return_value=5500,
), ),
patch("app.services.sound_normalizer.get_file_size", return_value=1500), patch("app.services.sound_normalizer.get_file_size", return_value=1500),
patch( patch(
"app.services.sound_normalizer.get_file_hash", return_value="norm_hash" "app.services.sound_normalizer.get_file_hash", return_value="norm_hash",
), ),
): ):
# Setup path mocks # Setup path mocks
@@ -300,7 +300,7 @@ class TestSoundNormalizerService:
with ( with (
patch("pathlib.Path.exists", return_value=True), patch("pathlib.Path.exists", return_value=True),
patch.object( patch.object(
normalizer_service, "_normalize_audio_two_pass" normalizer_service, "_normalize_audio_two_pass",
) as mock_normalize, ) as mock_normalize,
): ):
mock_normalize.side_effect = Exception("Normalization failed") mock_normalize.side_effect = Exception("Normalization failed")
@@ -339,7 +339,7 @@ class TestSoundNormalizerService:
# Mock repository calls # Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock( normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock(
return_value=sounds return_value=sounds,
) )
# Mock individual normalization # Mock individual normalization
@@ -399,7 +399,7 @@ class TestSoundNormalizerService:
# Mock repository calls # Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds_by_type = AsyncMock( normalizer_service.sound_repo.get_unnormalized_sounds_by_type = AsyncMock(
return_value=sdb_sounds return_value=sdb_sounds,
) )
# Mock individual normalization # Mock individual normalization
@@ -428,7 +428,7 @@ class TestSoundNormalizerService:
# Verify correct repository method was called # Verify correct repository method was called
normalizer_service.sound_repo.get_unnormalized_sounds_by_type.assert_called_once_with( normalizer_service.sound_repo.get_unnormalized_sounds_by_type.assert_called_once_with(
"SDB" "SDB",
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -459,7 +459,7 @@ class TestSoundNormalizerService:
# Mock repository calls # Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock( normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock(
return_value=sounds return_value=sounds,
) )
# Mock individual normalization with one success and one error # Mock individual normalization with one success and one error
@@ -529,7 +529,7 @@ class TestSoundNormalizerService:
# Verify ffmpeg chain was called correctly # Verify ffmpeg chain was called correctly
mock_ffmpeg.input.assert_called_once_with(str(input_path)) mock_ffmpeg.input.assert_called_once_with(str(input_path))
mock_ffmpeg.filter.assert_called_once_with( mock_ffmpeg.filter.assert_called_once_with(
mock_stream, "loudnorm", I=-23, TP=-2, LRA=7 mock_stream, "loudnorm", I=-23, TP=-2, LRA=7,
) )
mock_ffmpeg.output.assert_called_once() mock_ffmpeg.output.assert_called_once()
mock_ffmpeg.run.assert_called_once() mock_ffmpeg.run.assert_called_once()

View File

@@ -153,7 +153,7 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results temp_path, "SDB", existing_sound, results,
) )
assert results["skipped"] == 1 assert results["skipped"] == 1
@@ -257,7 +257,7 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results temp_path, "SDB", existing_sound, results,
) )
assert results["updated"] == 1 assert results["updated"] == 1
@@ -296,7 +296,7 @@ class TestSoundScannerService:
# Mock file operations # Mock file operations
with ( with (
patch( patch(
"app.services.sound_scanner.get_file_hash", return_value="custom_hash" "app.services.sound_scanner.get_file_hash", return_value="custom_hash",
), ),
patch("app.services.sound_scanner.get_audio_duration", return_value=60000), patch("app.services.sound_scanner.get_audio_duration", return_value=60000),
patch("app.services.sound_scanner.get_file_size", return_value=2048), patch("app.services.sound_scanner.get_file_size", return_value=2048),
@@ -316,7 +316,7 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "CUSTOM", None, results temp_path, "CUSTOM", None, results,
) )
assert results["added"] == 1 assert results["added"] == 1

View File

@@ -7,10 +7,8 @@ from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from app.models.sound import Sound from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.models.user import User from app.models.user import User
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
from app.utils.audio import get_sound_file_path
class TestVLCPlayerService: class TestVLCPlayerService:
@@ -111,7 +109,7 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_success( async def test_play_sound_success(
self, mock_subprocess, vlc_service, sample_sound self, mock_subprocess, vlc_service, sample_sound,
): ):
"""Test successful sound playback.""" """Test successful sound playback."""
# Mock subprocess # Mock subprocess
@@ -144,7 +142,7 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_file_not_found( async def test_play_sound_file_not_found(
self, vlc_service, sample_sound self, vlc_service, sample_sound,
): ):
"""Test sound playback when file doesn't exist.""" """Test sound playback when file doesn't exist."""
# Mock the file path utility to return a non-existent path # Mock the file path utility to return a non-existent path
@@ -160,7 +158,7 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_subprocess_error( async def test_play_sound_subprocess_error(
self, mock_subprocess, vlc_service, sample_sound self, mock_subprocess, vlc_service, sample_sound,
): ):
"""Test sound playback when subprocess fails.""" """Test sound playback when subprocess fails."""
# Mock the file path utility to return an existing path # Mock the file path utility to return an existing path
@@ -184,7 +182,7 @@ class TestVLCPlayerService:
mock_find_process = Mock() mock_find_process = Mock()
mock_find_process.returncode = 0 mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock( mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n", b"") return_value=(b"12345\n67890\n", b""),
) )
# Mock pkill process (kill VLC processes) # Mock pkill process (kill VLC processes)
@@ -214,7 +212,7 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_no_processes( async def test_stop_all_vlc_instances_no_processes(
self, mock_subprocess, vlc_service self, mock_subprocess, vlc_service,
): ):
"""Test stopping VLC instances when none are running.""" """Test stopping VLC instances when none are running."""
# Mock pgrep process (no VLC processes found) # Mock pgrep process (no VLC processes found)
@@ -234,14 +232,14 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_partial_kill( async def test_stop_all_vlc_instances_partial_kill(
self, mock_subprocess, vlc_service self, mock_subprocess, vlc_service,
): ):
"""Test stopping VLC instances when some processes remain.""" """Test stopping VLC instances when some processes remain."""
# Mock pgrep process (find VLC processes) # Mock pgrep process (find VLC processes)
mock_find_process = Mock() mock_find_process = Mock()
mock_find_process.returncode = 0 mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock( mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n11111\n", b"") return_value=(b"12345\n67890\n11111\n", b""),
) )
# Mock pkill process (kill VLC processes) # Mock pkill process (kill VLC processes)
@@ -306,7 +304,7 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_with_play_count_tracking( async def test_play_sound_with_play_count_tracking(
self, mock_subprocess, vlc_service_with_db, sample_sound self, mock_subprocess, vlc_service_with_db, sample_sound,
): ):
"""Test sound playback with play count tracking.""" """Test sound playback with play count tracking."""
# Mock subprocess # Mock subprocess
@@ -416,7 +414,7 @@ class TestVLCPlayerService:
# Verify sound repository calls # Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1) mock_sound_repo.get_by_id.assert_called_once_with(1)
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 1} test_sound, {"play_count": 1},
) )
# Verify user repository calls # Verify user repository calls
@@ -487,7 +485,7 @@ class TestVLCPlayerService:
# Verify sound play count was updated # Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 6} test_sound, {"play_count": 6},
) )
# Verify new SoundPlayed record was always added # Verify new SoundPlayed record was always added

View File

@@ -8,7 +8,12 @@ from unittest.mock import patch
import pytest import pytest
from app.models.sound import Sound from app.models.sound import Sound
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size, get_sound_file_path from app.utils.audio import (
get_audio_duration,
get_file_hash,
get_file_size,
get_sound_file_path,
)
class TestAudioUtils: class TestAudioUtils:

View File

@@ -1,12 +1,16 @@
"""Tests for credit decorators.""" """Tests for credit decorators."""
from unittest.mock import AsyncMock, Mock from unittest.mock import AsyncMock
import pytest import pytest
from app.models.credit_action import CreditActionType from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError from app.services.credit import CreditService, InsufficientCreditsError
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only from app.utils.credit_decorators import (
CreditManager,
requires_credits,
validate_credits_only,
)
class TestRequiresCreditsDecorator: class TestRequiresCreditsDecorator:
@@ -32,7 +36,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int, message: str) -> str: async def test_action(user_id: int, message: str) -> str:
return f"Success: {message}" return f"Success: {message}"
@@ -41,10 +45,10 @@ class TestRequiresCreditsDecorator:
assert result == "Success: test" assert result == "Success: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None 123, CreditActionType.VLC_PLAY_SOUND, None,
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, None 123, CreditActionType.VLC_PLAY_SOUND, True, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -58,7 +62,7 @@ class TestRequiresCreditsDecorator:
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id", user_id_param="user_id",
metadata_extractor=extract_metadata metadata_extractor=extract_metadata,
) )
async def test_action(user_id: int, sound_name: str) -> bool: async def test_action(user_id: int, sound_name: str) -> bool:
return True return True
@@ -66,10 +70,10 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123, sound_name="test.mp3") await test_action(user_id=123, sound_name="test.mp3")
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"} 123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"},
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"} 123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -79,7 +83,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int) -> bool: async def test_action(user_id: int) -> bool:
return False # Action fails return False # Action fails
@@ -88,7 +92,7 @@ class TestRequiresCreditsDecorator:
assert result is False assert result is False
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -98,7 +102,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int) -> str: async def test_action(user_id: int) -> str:
raise ValueError("Test error") raise ValueError("Test error")
@@ -107,7 +111,7 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123) await test_action(user_id=123)
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -118,7 +122,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int) -> str: async def test_action(user_id: int) -> str:
return "Should not execute" return "Should not execute"
@@ -136,7 +140,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int, message: str) -> str: async def test_action(user_id: int, message: str) -> str:
return message return message
@@ -145,7 +149,7 @@ class TestRequiresCreditsDecorator:
assert result == "test" assert result == "test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None 123, CreditActionType.VLC_PLAY_SOUND, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -155,7 +159,7 @@ class TestRequiresCreditsDecorator:
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(other_param: str) -> str: async def test_action(other_param: str) -> str:
return other_param return other_param
@@ -186,7 +190,7 @@ class TestValidateCreditsOnlyDecorator:
@validate_credits_only( @validate_credits_only(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
credit_service_factory, credit_service_factory,
user_id_param="user_id" user_id_param="user_id",
) )
async def test_action(user_id: int, message: str) -> str: async def test_action(user_id: int, message: str) -> str:
return f"Validated: {message}" return f"Validated: {message}"
@@ -195,7 +199,7 @@ class TestValidateCreditsOnlyDecorator:
assert result == "Validated: test" assert result == "Validated: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND 123, CreditActionType.VLC_PLAY_SOUND,
) )
# Should not deduct credits, only validate # Should not deduct credits, only validate
mock_credit_service.deduct_credits.assert_not_called() mock_credit_service.deduct_credits.assert_not_called()
@@ -219,15 +223,15 @@ class TestCreditManager:
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
{"test": "data"} {"test": "data"},
) as manager: ) as manager:
manager.mark_success() manager.mark_success()
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"} 123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"},
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"} 123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -236,13 +240,13 @@ class TestCreditManager:
async with CreditManager( async with CreditManager(
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND CreditActionType.VLC_PLAY_SOUND,
): ):
# Don't mark as success - should be considered failed # Don't mark as success - should be considered failed
pass pass
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -252,12 +256,12 @@ class TestCreditManager:
async with CreditManager( async with CreditManager(
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND CreditActionType.VLC_PLAY_SOUND,
): ):
raise ValueError("Test error") raise ValueError("Test error")
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None 123, CreditActionType.VLC_PLAY_SOUND, False, None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -269,7 +273,7 @@ class TestCreditManager:
async with CreditManager( async with CreditManager(
mock_credit_service, mock_credit_service,
123, 123,
CreditActionType.VLC_PLAY_SOUND CreditActionType.VLC_PLAY_SOUND,
): ):
pass pass