Compare commits

...

3 Commits

Author SHA1 Message Date
JSC
7ba52ad6fc fix: Lint fixes of core, models and schemas
All checks were successful
Backend CI / test (push) Successful in 4m5s
2025-07-31 22:06:31 +02:00
JSC
01bb48c206 fix: Utils lint fixes 2025-07-31 21:56:03 +02:00
JSC
8847131f24 Refactor test files for improved readability and consistency
- Removed unnecessary blank lines and adjusted formatting in test files.
- Ensured consistent use of commas in function calls and assertions across various test cases.
- Updated import statements for better organization and clarity.
- Enhanced mock setups in tests for better isolation and reliability.
- Improved assertions to follow a consistent style for better readability.
2025-07-31 21:37:04 +02:00
46 changed files with 704 additions and 714 deletions

View File

@@ -1,6 +1,6 @@
"""Player API endpoints."""
from typing import Annotated, Any
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
@@ -214,4 +214,4 @@ async def get_state(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get player state",
) from e
) from e

View File

@@ -10,12 +10,12 @@ from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService
from app.services.vlc_player import get_vlc_player_service, VLCPlayerService
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
router = APIRouter(prefix="/sounds", tags=["sounds"])
@@ -125,13 +125,13 @@ async def scan_custom_directory(
async def normalize_all_sounds(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service)
SoundNormalizerService, Depends(get_sound_normalizer_service),
],
force: bool = Query(
False, description="Force normalization of already normalized sounds"
False, description="Force normalization of already normalized sounds",
),
one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)"
None, description="Use one-pass normalization (overrides config)",
),
) -> dict[str, NormalizationResults | str]:
"""Normalize all unnormalized sounds."""
@@ -163,13 +163,13 @@ async def normalize_sounds_by_type(
sound_type: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service)
SoundNormalizerService, Depends(get_sound_normalizer_service),
],
force: bool = Query(
False, description="Force normalization of already normalized sounds"
False, description="Force normalization of already normalized sounds",
),
one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)"
None, description="Use one-pass normalization (overrides config)",
),
) -> dict[str, NormalizationResults | str]:
"""Normalize all sounds of a specific type (SDB, TTS, EXT)."""
@@ -210,13 +210,13 @@ async def normalize_sound_by_id(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
normalizer_service: Annotated[
SoundNormalizerService, Depends(get_sound_normalizer_service)
SoundNormalizerService, Depends(get_sound_normalizer_service),
],
force: bool = Query(
False, description="Force normalization of already normalized sound"
False, description="Force normalization of already normalized sound",
),
one_pass: bool | None = Query(
None, description="Use one-pass normalization (overrides config)"
None, description="Use one-pass normalization (overrides config)",
),
) -> dict[str, str]:
"""Normalize a specific sound by ID."""
@@ -283,7 +283,7 @@ async def create_extraction(
)
extraction_info = await extraction_service.create_extraction(
url, current_user.id
url, current_user.id,
)
# Queue the extraction for background processing
@@ -398,7 +398,7 @@ async def play_sound_with_vlc(
await credit_service.validate_and_reserve_credits(
current_user.id,
CreditActionType.VLC_PLAY_SOUND,
{"sound_id": sound_id, "sound_name": sound.name}
{"sound_id": sound_id, "sound_name": sound.name},
)
except InsufficientCreditsError as e:
raise HTTPException(
@@ -408,7 +408,7 @@ async def play_sound_with_vlc(
# Play the sound using VLC
success = await vlc_player.play_sound(sound)
# Deduct credits based on success
await credit_service.deduct_credits(
current_user.id,
@@ -416,7 +416,7 @@ async def play_sound_with_vlc(
success,
{"sound_id": sound_id, "sound_name": sound.name},
)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

View File

@@ -1,4 +1,4 @@
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel
@@ -38,9 +38,9 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
await session.close()
def get_session_factory():
def get_session_factory() -> Callable[[], AsyncSession]:
"""Get a session factory function for services."""
def session_factory():
def session_factory() -> AsyncSession:
return AsyncSession(engine)
return session_factory

View File

@@ -135,8 +135,6 @@ async def get_current_user_api_token(
detail="Account is deactivated",
)
return user
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
@@ -146,6 +144,8 @@ async def get_current_user_api_token(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate API token",
) from e
else:
return user
async def get_current_user_flexible(

View File

@@ -118,4 +118,4 @@ def get_all_credit_actions() -> dict[CreditActionType, CreditAction]:
Dictionary of all credit actions
"""
return CREDIT_ACTIONS.copy()
return CREDIT_ACTIONS.copy()

View File

@@ -17,7 +17,8 @@ class CreditTransaction(BaseModel, table=True):
user_id: int = Field(foreign_key="user.id", nullable=False)
action_type: str = Field(nullable=False)
amount: int = Field(nullable=False) # Negative for deductions, positive for additions
# Negative for deductions, positive for additions
amount: int = Field(nullable=False)
balance_before: int = Field(nullable=False)
balance_after: int = Field(nullable=False)
description: str = Field(nullable=False)
@@ -26,4 +27,4 @@ class CreditTransaction(BaseModel, table=True):
metadata_json: str | None = Field(default=None)
# relationships
user: "User" = Relationship(back_populates="credit_transactions")
user: "User" = Relationship(back_populates="credit_transactions")

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel import Field, Relationship
from app.models.base import BaseModel
@@ -25,7 +25,8 @@ class Extraction(BaseModel, table=True):
status: str = Field(nullable=False, default="pending")
error: str | None = Field(default=None)
# constraints - only enforce uniqueness when both service and service_id are not null
# constraints - only enforce uniqueness when both service and service_id
# are not null
__table_args__ = ()
# relationships

View File

@@ -28,25 +28,16 @@ from .playlist import (
)
__all__ = [
# Auth schemas
"ApiTokenRequest",
"ApiTokenResponse",
"ApiTokenStatusResponse",
"AuthResponse",
"TokenResponse",
"UserLoginRequest",
"UserRegisterRequest",
"UserResponse",
# Common schemas
"HealthResponse",
"MessageResponse",
"StatusResponse",
# Player schemas
"MessageResponse",
"PlayerModeRequest",
"PlayerSeekRequest",
"PlayerStateResponse",
"PlayerVolumeRequest",
# Playlist schemas
"PlaylistAddSoundRequest",
"PlaylistCreateRequest",
"PlaylistReorderRequest",
@@ -54,4 +45,9 @@ __all__ = [
"PlaylistSoundResponse",
"PlaylistStatsResponse",
"PlaylistUpdateRequest",
"StatusResponse",
"TokenResponse",
"UserLoginRequest",
"UserRegisterRequest",
"UserResponse",
]

View File

@@ -18,4 +18,4 @@ class StatusResponse(BaseModel):
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(description="Health status")
status: str = Field(description="Health status")

View File

@@ -30,10 +30,10 @@ class PlayerStateResponse(BaseModel):
status: str = Field(description="Player status (playing, paused, stopped)")
current_sound: dict[str, Any] | None = Field(
None, description="Current sound information"
None, description="Current sound information",
)
playlist: dict[str, Any] | None = Field(
None, description="Current playlist information"
None, description="Current playlist information",
)
position: int = Field(description="Current position in milliseconds")
duration: int | None = Field(

View File

@@ -1,6 +1,6 @@
"""Playlist schemas."""
from pydantic import BaseModel, Field
from pydantic import BaseModel
from app.models.playlist import Playlist
from app.models.sound import Sound
@@ -40,7 +40,8 @@ class PlaylistResponse(BaseModel):
def from_playlist(cls, playlist: Playlist) -> "PlaylistResponse":
"""Create response from playlist model."""
if playlist.id is None:
raise ValueError("Playlist ID cannot be None")
msg = "Playlist ID cannot be None"
raise ValueError(msg)
return cls(
id=playlist.id,
name=playlist.name,
@@ -70,7 +71,8 @@ class PlaylistSoundResponse(BaseModel):
def from_sound(cls, sound: Sound) -> "PlaylistSoundResponse":
"""Create response from sound model."""
if sound.id is None:
raise ValueError("Sound ID cannot be None")
msg = "Sound ID cannot be None"
raise ValueError(msg)
return cls(
id=sound.id,
name=sound.name,

View File

@@ -30,7 +30,7 @@ class InsufficientCreditsError(Exception):
self.required = required
self.available = available
super().__init__(
f"Insufficient credits: {required} required, {available} available"
f"Insufficient credits: {required} required, {available} available",
)
@@ -138,10 +138,10 @@ class CreditService:
"""
action = get_credit_action(action_type)
# Only deduct if action requires success and was successful, or doesn't require success
should_deduct = (action.requires_success and success) or not action.requires_success
if not should_deduct:
logger.info(
"Skipping credit deduction for user %s: action %s failed and requires success",
@@ -150,7 +150,7 @@ class CreditService:
)
# Still create a transaction record for auditing
return await self._create_transaction_record(
user_id, action, 0, success, metadata
user_id, action, 0, success, metadata,
)
session = self.db_session_factory()
@@ -380,4 +380,4 @@ class CreditService:
raise ValueError(msg)
return user.credits
finally:
await session.close()
await session.close()

View File

@@ -10,7 +10,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.models.extraction import Extraction
from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository
@@ -155,7 +154,7 @@ class ExtractionService:
# Check if extraction already exists for this service
existing = await self.extraction_repo.get_by_service_and_id(
service_info["service"], service_info["service_id"]
service_info["service"], service_info["service_id"],
)
if existing and existing.id != extraction_id:
error_msg = (
@@ -180,7 +179,7 @@ class ExtractionService:
# Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media(
extraction_id, extraction_url
extraction_id, extraction_url,
)
# Move files to final locations
@@ -238,7 +237,7 @@ class ExtractionService:
except Exception as e:
error_msg = str(e)
logger.exception(
"Failed to process extraction %d: %s", extraction_id, error_msg
"Failed to process extraction %d: %s", extraction_id, error_msg,
)
# Update extraction with error
@@ -262,14 +261,14 @@ class ExtractionService:
}
async def _extract_media(
self, extraction_id: int, extraction_url: str
self, extraction_id: int, extraction_url: str,
) -> tuple[Path, Path | None]:
"""Extract audio and thumbnail using yt-dlp."""
temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
# Create unique filename based on extraction ID
output_template = str(
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s"
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s",
)
# Configure yt-dlp options
@@ -304,8 +303,8 @@ class ExtractionService:
# Find the extracted files
audio_files = list(
temp_dir.glob(
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}"
)
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}",
),
)
thumbnail_files = (
list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
@@ -342,7 +341,7 @@ class ExtractionService:
"""Move extracted files to their final locations."""
# Generate clean filename based on title and service
safe_title = self._sanitize_filename(
title or f"{service or 'unknown'}_{service_id or 'unknown'}"
title or f"{service or 'unknown'}_{service_id or 'unknown'}",
)
# Move audio file

View File

@@ -46,9 +46,9 @@ class ExtractionProcessor:
if self.processor_task and not self.processor_task.done():
try:
await asyncio.wait_for(self.processor_task, timeout=30.0)
except asyncio.TimeoutError:
except TimeoutError:
logger.warning(
"Extraction processor did not stop gracefully, cancelling..."
"Extraction processor did not stop gracefully, cancelling...",
)
self.processor_task.cancel()
try:
@@ -66,7 +66,7 @@ class ExtractionProcessor:
# The processor will pick it up on the next cycle
else:
logger.warning(
"Extraction %d is already being processed", extraction_id
"Extraction %d is already being processed", extraction_id,
)
async def _process_queue(self) -> None:
@@ -81,7 +81,7 @@ class ExtractionProcessor:
try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
break # Shutdown requested
except asyncio.TimeoutError:
except TimeoutError:
continue # Continue processing
except Exception as e:
@@ -90,7 +90,7 @@ class ExtractionProcessor:
try:
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
break # Shutdown requested
except asyncio.TimeoutError:
except TimeoutError:
continue
logger.info("Extraction queue processor stopped")
@@ -125,13 +125,13 @@ class ExtractionProcessor:
# Start processing this extraction in the background
task = asyncio.create_task(
self._process_single_extraction(extraction_id)
self._process_single_extraction(extraction_id),
)
task.add_done_callback(
lambda t, eid=extraction_id: self._on_extraction_completed(
eid,
t,
)
),
)
logger.info(

View File

@@ -49,7 +49,7 @@ class PlaylistService:
if not main_playlist:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Main playlist not found. Make sure to run database seeding."
detail="Main playlist not found. Make sure to run database seeding.",
)
return main_playlist
@@ -179,7 +179,7 @@ class PlaylistService:
return await self.playlist_repo.get_playlist_sounds(playlist_id)
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None
self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None,
) -> None:
"""Add a sound to a playlist."""
# Verify playlist exists
@@ -202,11 +202,11 @@ class PlaylistService:
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info(
"Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id
"Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id,
)
async def remove_sound_from_playlist(
self, playlist_id: int, sound_id: int, user_id: int
self, playlist_id: int, sound_id: int, user_id: int,
) -> None:
"""Remove a sound from a playlist."""
# Verify playlist exists
@@ -228,7 +228,7 @@ class PlaylistService:
)
async def reorder_playlist_sounds(
self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]]
self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]],
) -> None:
"""Reorder sounds in a playlist."""
# Verify playlist exists
@@ -262,7 +262,7 @@ class PlaylistService:
await self._unset_current_playlist(user_id)
await self._set_main_as_current(user_id)
logger.info(
"Unset current playlist and set main as current for user %s", user_id
"Unset current playlist and set main as current for user %s", user_id,
)
async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]:
@@ -290,7 +290,7 @@ class PlaylistService:
# Check if sound is already in main playlist
if not await self.playlist_repo.is_sound_in_playlist(
main_playlist.id, sound_id
main_playlist.id, sound_id,
):
await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id)
logger.info(

View File

@@ -141,7 +141,7 @@ class SoundNormalizerService:
ffmpeg.run(stream, quiet=True, overwrite_output=True)
logger.info("One-pass normalization completed: %s", output_path)
except Exception as e:
except Exception:
logger.exception("One-pass normalization failed for %s", input_path)
raise
@@ -153,7 +153,7 @@ class SoundNormalizerService:
"""Normalize audio using two-pass loudnorm for better quality."""
try:
logger.info(
"Starting two-pass normalization: %s -> %s", input_path, output_path
"Starting two-pass normalization: %s -> %s", input_path, output_path,
)
# First pass: analyze
@@ -193,7 +193,7 @@ class SoundNormalizerService:
json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output)
if not json_match:
logger.error(
"Could not find JSON in loudnorm output: %s", analysis_output
"Could not find JSON in loudnorm output: %s", analysis_output,
)
raise ValueError("Could not extract loudnorm analysis data")
@@ -260,7 +260,7 @@ class SoundNormalizerService:
)
raise
except Exception as e:
except Exception:
logger.exception("Two-pass normalization failed for %s", input_path)
raise
@@ -428,7 +428,7 @@ class SoundNormalizerService:
"type": sound.type,
"is_normalized": sound.is_normalized,
"name": sound.name,
}
},
)
# Process each sound using captured data
@@ -476,7 +476,7 @@ class SoundNormalizerService:
"normalized_hash": None,
"id": sound_id,
"error": str(e),
}
},
)
logger.info("Normalization completed: %s", results)
@@ -517,7 +517,7 @@ class SoundNormalizerService:
"type": sound.type,
"is_normalized": sound.is_normalized,
"name": sound.name,
}
},
)
# Process each sound using captured data
@@ -565,7 +565,7 @@ class SoundNormalizerService:
"normalized_hash": None,
"id": sound_id,
"error": str(e),
}
},
)
logger.info("Type normalization completed: %s", results)

View File

@@ -132,7 +132,7 @@ class SoundScannerService:
"id": None,
"error": str(e),
"changes": None,
}
},
)
# Delete sounds that no longer exist in directory
@@ -153,7 +153,7 @@ class SoundScannerService:
"id": sound.id,
"error": None,
"changes": None,
}
},
)
except Exception as e:
logger.exception("Error deleting sound %s", filename)
@@ -169,7 +169,7 @@ class SoundScannerService:
"id": sound.id,
"error": str(e),
"changes": None,
}
},
)
logger.info("Sync completed: %s", results)
@@ -219,7 +219,7 @@ class SoundScannerService:
"id": sound.id,
"error": None,
"changes": None,
}
},
)
elif existing_sound.hash != file_hash:
@@ -246,7 +246,7 @@ class SoundScannerService:
"id": existing_sound.id,
"error": None,
"changes": ["hash", "duration", "size", "name"],
}
},
)
else:
@@ -264,7 +264,7 @@ class SoundScannerService:
"id": existing_sound.id,
"error": None,
"changes": None,
}
},
)
async def scan_soundboard_directory(self) -> ScanResults:

View File

@@ -6,7 +6,6 @@ from collections.abc import Callable
from pathlib import Path
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger

View File

@@ -34,7 +34,7 @@ def get_audio_duration(file_path: Path) -> int:
probe = ffmpeg.probe(str(file_path))
duration = float(probe["format"]["duration"])
return int(duration * 1000) # Convert to milliseconds
except Exception as e:
except (ffmpeg.Error, KeyError, ValueError, TypeError, Exception) as e:
logger.warning("Failed to get duration for %s: %s", file_path, e)
return 0

View File

@@ -3,14 +3,14 @@
def parse_cookies(cookie_header: str) -> dict[str, str]:
"""Parse HTTP cookie header into a dictionary."""
cookies = {}
cookies: dict[str, str] = {}
if not cookie_header:
return cookies
for cookie in cookie_header.split(";"):
cookie = cookie.strip()
if "=" in cookie:
name, value = cookie.split("=", 1)
for cookie_part in cookie_header.split(";"):
cookie_str = cookie_part.strip()
if "=" in cookie_str:
name, value = cookie_str.split("=", 1)
cookies[name.strip()] = value.strip()
return cookies

View File

@@ -1,11 +1,13 @@
"""Decorators for credit management and validation."""
import functools
import inspect
import types
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.credit import CreditService
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
@@ -16,7 +18,7 @@ def requires_credits(
user_id_param: str = "user_id",
metadata_extractor: Callable[..., dict[str, Any]] | None = None,
) -> Callable[[F], F]:
"""Decorator to enforce credit requirements for actions.
"""Enforce credit requirements for actions.
Args:
action_type: The type of action that requires credits
@@ -40,14 +42,13 @@ def requires_credits(
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
@@ -69,23 +70,23 @@ def requires_credits(
# Validate credits before execution
await credit_service.validate_and_reserve_credits(
user_id, action_type, metadata
user_id, action_type, metadata,
)
# Execute the function
success = False
result = None
try:
result = await func(*args, **kwargs)
success = bool(result) # Consider function result as success indicator
return result
except Exception:
success = False
raise
else:
return result
finally:
# Deduct credits based on success
await credit_service.deduct_credits(
user_id, action_type, success, metadata
user_id, action_type, success, metadata,
)
return wrapper # type: ignore[return-value]
@@ -97,7 +98,7 @@ def validate_credits_only(
credit_service_factory: Callable[[], CreditService],
user_id_param: str = "user_id",
) -> Callable[[F], F]:
"""Decorator to only validate credits without deducting them.
"""Validate credits without deducting them.
Useful for checking if a user can perform an action before actual execution.
@@ -112,14 +113,13 @@ def validate_credits_only(
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
# Extract user ID from parameters
user_id = None
if user_id_param in kwargs:
user_id = kwargs[user_id_param]
else:
# Try to find user_id in function signature
import inspect
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
if user_id_param in param_names:
@@ -173,20 +173,25 @@ class CreditManager:
async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.metadata
self.user_id, self.action_type, self.metadata,
)
self.validated = True
return self
async def __aexit__(self, exc_type: type, exc_val: Exception, exc_tb: Any) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> None:
"""Exit context manager - deduct credits based on success."""
if self.validated:
# If no exception occurred, consider it successful
success = exc_type is None and self.success
await self.credit_service.deduct_credits(
self.user_id, self.action_type, success, self.metadata
self.user_id, self.action_type, success, self.metadata,
)
def mark_success(self) -> None:
"""Mark the operation as successful."""
self.success = True
self.success = True

View File

@@ -177,7 +177,7 @@ class TestApiTokenEndpoints:
# Set a token on the user
authenticated_user.api_token = "expired_token"
authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta(
days=1
days=1,
)
response = await authenticated_client.get("/api/v1/auth/api-token/status")
@@ -209,7 +209,7 @@ class TestApiTokenEndpoints:
# Verify token exists
status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status"
"/api/v1/auth/api-token/status",
)
assert status_response.json()["has_token"] is True
@@ -222,7 +222,7 @@ class TestApiTokenEndpoints:
# Verify token is gone
status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status"
"/api/v1/auth/api-token/status",
)
assert status_response.json()["has_token"] is False

View File

@@ -1,20 +1,16 @@
"""Tests for extraction API endpoints."""
from unittest.mock import AsyncMock, Mock
import pytest
import pytest_asyncio
from httpx import AsyncClient
from app.models.user import User
class TestExtractionEndpoints:
"""Test extraction API endpoints."""
@pytest.mark.asyncio
async def test_create_extraction_success(
self, test_client: AsyncClient, auth_cookies: dict[str, str]
self, test_client: AsyncClient, auth_cookies: dict[str, str],
):
"""Test successful extraction creation."""
# Set cookies on client instance to avoid deprecation warning
@@ -50,7 +46,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio
async def test_get_processor_status_admin(
self, test_client: AsyncClient, admin_cookies: dict[str, str]
self, test_client: AsyncClient, admin_cookies: dict[str, str],
):
"""Test getting processor status as admin."""
# Set cookies on client instance to avoid deprecation warning
@@ -66,7 +62,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio
async def test_get_processor_status_non_admin(
self, test_client: AsyncClient, auth_cookies: dict[str, str]
self, test_client: AsyncClient, auth_cookies: dict[str, str],
):
"""Test getting processor status as non-admin user."""
# Set cookies on client instance to avoid deprecation warning
@@ -80,7 +76,7 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio
async def test_get_user_extractions(
self, test_client: AsyncClient, auth_cookies: dict[str, str]
self, test_client: AsyncClient, auth_cookies: dict[str, str],
):
"""Test getting user extractions."""
# Set cookies on client instance to avoid deprecation warning

View File

@@ -656,4 +656,4 @@ class TestPlayerEndpoints:
json={"volume": 100},
)
assert response.status_code == 200
mock_player_service.set_volume.assert_called_with(100)
mock_player_service.set_volume.assert_called_with(100)

View File

@@ -1,7 +1,5 @@
"""Tests for playlist API endpoints."""
import json
from typing import Any
import pytest
import pytest_asyncio
@@ -96,7 +94,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
@@ -107,7 +105,7 @@ class TestPlaylistEndpoints:
)
test_session.add(main_playlist)
await test_session.commit()
response = await authenticated_client.get("/api/v1/playlists/")
assert response.status_code == 200
@@ -146,11 +144,11 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(main_playlist)
# Extract ID before HTTP request
main_playlist_id = main_playlist.id
main_playlist_name = main_playlist.name
response = await authenticated_client.get("/api/v1/playlists/main")
assert response.status_code == 200
@@ -189,10 +187,10 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(main_playlist)
# Extract ID before HTTP request
main_playlist_id = main_playlist.id
response = await authenticated_client.get("/api/v1/playlists/current")
assert response.status_code == 200
@@ -256,10 +254,10 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
# Extract name before HTTP request
playlist_name = test_playlist.name
payload = {
"name": playlist_name,
"description": "Duplicate name",
@@ -292,13 +290,13 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
# Extract values before HTTP request
playlist_id = test_playlist.id
playlist_name = test_playlist.name
response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}"
f"/api/v1/playlists/{playlist_id}",
)
assert response.status_code == 200
@@ -339,10 +337,10 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
# Extract ID before HTTP request
playlist_id = test_playlist.id
payload = {
"name": "Updated Playlist",
"description": "Updated description",
@@ -350,7 +348,7 @@ class TestPlaylistEndpoints:
}
response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload
f"/api/v1/playlists/{playlist_id}", json=payload,
)
assert response.status_code == 200
@@ -379,7 +377,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
# Note: main_playlist doesn't need to be current=True for this test
# The service logic handles current playlist management
main_playlist = Playlist(
@@ -393,14 +391,14 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
# Extract ID before HTTP request
playlist_id = test_playlist.id
payload = {"is_current": True}
response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload
f"/api/v1/playlists/{playlist_id}", json=payload,
)
assert response.status_code == 200
@@ -429,12 +427,12 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
# Extract ID before HTTP requests
playlist_id = test_playlist.id
response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}"
f"/api/v1/playlists/{playlist_id}",
)
assert response.status_code == 200
@@ -442,7 +440,7 @@ class TestPlaylistEndpoints:
# Verify playlist is deleted
get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}"
f"/api/v1/playlists/{playlist_id}",
)
assert get_response.status_code == 404
@@ -465,12 +463,12 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(main_playlist)
# Extract ID before HTTP request
main_playlist_id = main_playlist.id
response = await authenticated_client.delete(
f"/api/v1/playlists/{main_playlist_id}"
f"/api/v1/playlists/{main_playlist_id}",
)
assert response.status_code == 400
@@ -496,7 +494,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
@@ -507,7 +505,7 @@ class TestPlaylistEndpoints:
)
test_session.add(main_playlist)
await test_session.commit()
response = await authenticated_client.get("/api/v1/playlists/search/playlist")
assert response.status_code == 200
@@ -541,7 +539,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
test_sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -555,12 +553,12 @@ class TestPlaylistEndpoints:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(test_sound)
# Extract IDs before creating playlist_sound
playlist_id = test_playlist.id
sound_id = test_sound.id
sound_name = test_sound.name
# Add sound to playlist manually for testing
from app.models.playlist_sound import PlaylistSound
@@ -573,7 +571,7 @@ class TestPlaylistEndpoints:
await test_session.commit()
response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds"
f"/api/v1/playlists/{playlist_id}/sounds",
)
assert response.status_code == 200
@@ -602,7 +600,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
test_sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -616,15 +614,15 @@ class TestPlaylistEndpoints:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(test_sound)
# Extract IDs before HTTP requests
playlist_id = test_playlist.id
sound_id = test_sound.id
payload = {"sound_id": sound_id}
response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload
f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
)
assert response.status_code == 200
@@ -632,7 +630,7 @@ class TestPlaylistEndpoints:
# Verify sound was added
get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds"
f"/api/v1/playlists/{playlist_id}/sounds",
)
assert get_response.status_code == 200
sounds = get_response.json()
@@ -659,7 +657,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
test_sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -673,15 +671,15 @@ class TestPlaylistEndpoints:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(test_sound)
# Extract IDs before HTTP request
playlist_id = test_playlist.id
sound_id = test_sound.id
payload = {"sound_id": sound_id, "position": 5}
response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload
f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
)
assert response.status_code == 200
@@ -706,7 +704,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
test_sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -720,22 +718,22 @@ class TestPlaylistEndpoints:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(test_sound)
# Extract IDs before HTTP requests
playlist_id = test_playlist.id
sound_id = test_sound.id
payload = {"sound_id": sound_id}
# Add sound first time
response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload
f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
)
assert response.status_code == 200
# Try to add same sound again
response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload
f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
)
assert response.status_code == 400
assert "already in this playlist" in response.json()["detail"]
@@ -762,14 +760,14 @@ class TestPlaylistEndpoints:
test_session.add(test_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
# Extract ID before HTTP request
playlist_id = test_playlist.id
payload = {"sound_id": 99999}
response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload
f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
)
assert response.status_code == 404
@@ -795,7 +793,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
test_sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -809,20 +807,20 @@ class TestPlaylistEndpoints:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(test_sound)
# Extract IDs before HTTP requests
playlist_id = test_playlist.id
sound_id = test_sound.id
# Add sound first
payload = {"sound_id": sound_id}
await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload
f"/api/v1/playlists/{playlist_id}/sounds", json=payload,
)
# Remove sound
response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}"
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}",
)
assert response.status_code == 200
@@ -830,7 +828,7 @@ class TestPlaylistEndpoints:
# Verify sound was removed
get_response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/sounds"
f"/api/v1/playlists/{playlist_id}/sounds",
)
sounds = get_response.json()
assert len(sounds) == 0
@@ -855,7 +853,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
test_sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -869,13 +867,13 @@ class TestPlaylistEndpoints:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(test_sound)
# Extract IDs before HTTP request
playlist_id = test_playlist.id
sound_id = test_sound.id
response = await authenticated_client.delete(
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}"
f"/api/v1/playlists/{playlist_id}/sounds/{sound_id}",
)
assert response.status_code == 404
@@ -901,7 +899,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
# Create multiple sounds
sound1 = Sound(name="Sound 1", filename="sound1.mp3", type="SDB", hash="hash1")
sound2 = Sound(name="Sound 2", filename="sound2.mp3", type="SDB", hash="hash2")
@@ -910,7 +908,7 @@ class TestPlaylistEndpoints:
await test_session.refresh(test_playlist)
await test_session.refresh(sound1)
await test_session.refresh(sound2)
# Extract IDs before HTTP requests
playlist_id = test_playlist.id
sound1_id = sound1.id
@@ -929,11 +927,11 @@ class TestPlaylistEndpoints:
# Reorder sounds - use positions that don't cause constraints
# When swapping, we need to be careful about unique constraints
payload = {
"sound_positions": [[sound1_id, 10], [sound2_id, 5]] # Use different positions to avoid constraints
"sound_positions": [[sound1_id, 10], [sound2_id, 5]], # Use different positions to avoid constraints
}
response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}/sounds/reorder", json=payload
f"/api/v1/playlists/{playlist_id}/sounds/reorder", json=payload,
)
assert response.status_code == 200
@@ -959,7 +957,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
@@ -971,12 +969,12 @@ class TestPlaylistEndpoints:
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
# Extract ID before HTTP request
playlist_id = test_playlist.id
response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}/set-current"
f"/api/v1/playlists/{playlist_id}/set-current",
)
assert response.status_code == 200
@@ -1001,7 +999,7 @@ class TestPlaylistEndpoints:
is_deletable=False,
)
test_session.add(main_playlist)
# Create a current playlist for the user
user_id = test_user.id
current_playlist = Playlist(
@@ -1025,7 +1023,7 @@ class TestPlaylistEndpoints:
# but something else is causing validation to fail
assert response.status_code == 422
return
assert response.status_code == 200
assert "unset successfully" in response.json()["message"]
@@ -1055,7 +1053,7 @@ class TestPlaylistEndpoints:
is_deletable=True,
)
test_session.add(test_playlist)
test_sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -1069,14 +1067,14 @@ class TestPlaylistEndpoints:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(test_sound)
# Extract IDs before HTTP requests
playlist_id = test_playlist.id
sound_id = test_sound.id
# Initially empty
response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/stats"
f"/api/v1/playlists/{playlist_id}/stats",
)
assert response.status_code == 200
@@ -1093,7 +1091,7 @@ class TestPlaylistEndpoints:
# Check stats again
response = await authenticated_client.get(
f"/api/v1/playlists/{playlist_id}/stats"
f"/api/v1/playlists/{playlist_id}/stats",
)
assert response.status_code == 200
@@ -1110,8 +1108,8 @@ class TestPlaylistEndpoints:
test_session: AsyncSession,
) -> None:
"""Test that users can only access their own playlists."""
from app.utils.auth import JWTUtils, PasswordUtils
from app.models.plan import Plan
from app.utils.auth import PasswordUtils
# Create plan within this test to avoid session issues
plan = Plan(
@@ -1124,10 +1122,10 @@ class TestPlaylistEndpoints:
test_session.add(plan)
await test_session.commit()
await test_session.refresh(plan)
# Extract plan ID immediately to avoid session issues
plan_id = plan.id
# Create another user with their own playlist
other_user = User(
email="other@example.com",
@@ -1144,7 +1142,7 @@ class TestPlaylistEndpoints:
# Extract other user ID before creating playlist
other_user_id = other_user.id
other_playlist = Playlist(
user_id=other_user_id,
name="Other User's Playlist",
@@ -1153,13 +1151,13 @@ class TestPlaylistEndpoints:
test_session.add(other_playlist)
await test_session.commit()
await test_session.refresh(other_playlist)
# Extract playlist ID before HTTP requests
other_playlist_id = other_playlist.id
# Try to access other user's playlist
response = await authenticated_client.get(
f"/api/v1/playlists/{other_playlist_id}"
f"/api/v1/playlists/{other_playlist_id}",
)
# Currently the implementation allows access to all playlists

View File

@@ -158,7 +158,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio
async def test_send_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User
self, authenticated_client: AsyncClient, authenticated_user: User,
):
"""Test sending message with missing parameters."""
# Missing target_user_id
@@ -177,7 +177,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio
async def test_broadcast_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User
self, authenticated_client: AsyncClient, authenticated_user: User,
):
"""Test broadcasting message with missing parameters."""
response = await authenticated_client.post("/api/v1/socket/broadcast")
@@ -185,7 +185,7 @@ class TestSocketEndpoints:
@pytest.mark.asyncio
async def test_send_message_invalid_user_id(
self, authenticated_client: AsyncClient, authenticated_user: User
self, authenticated_client: AsyncClient, authenticated_user: User,
):
"""Test sending message with invalid user ID."""
response = await authenticated_client.post(

View File

@@ -66,7 +66,7 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory"
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan:
mock_scan.return_value = mock_results
@@ -167,7 +167,7 @@ class TestSoundEndpoints:
headers = {"API-TOKEN": "admin_api_token"}
with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory"
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan:
mock_scan.return_value = mock_results
@@ -192,7 +192,7 @@ class TestSoundEndpoints:
):
"""Test scanning sounds when service raises an error."""
with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory"
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan:
mock_scan.side_effect = Exception("Directory not found")
@@ -244,7 +244,7 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory"
"app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan:
mock_scan.return_value = mock_results
@@ -285,7 +285,7 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory"
"app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan:
mock_scan.return_value = mock_results
@@ -307,14 +307,14 @@ class TestSoundEndpoints:
):
"""Test custom directory scanning with invalid path."""
with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory"
"app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan:
mock_scan.side_effect = ValueError(
"Directory does not exist: /invalid/path"
"Directory does not exist: /invalid/path",
)
response = await authenticated_admin_client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/invalid/path"}
"/api/v1/sounds/scan/custom", params={"directory": "/invalid/path"},
)
assert response.status_code == 400
@@ -325,7 +325,7 @@ class TestSoundEndpoints:
async def test_scan_custom_directory_unauthenticated(self, client: AsyncClient):
"""Test custom directory scanning without authentication."""
response = await client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/some/path"}
"/api/v1/sounds/scan/custom", params={"directory": "/some/path"},
)
assert response.status_code == 401
@@ -377,12 +377,12 @@ class TestSoundEndpoints:
):
"""Test custom directory scanning when service raises an error."""
with patch(
"app.services.sound_scanner.SoundScannerService.scan_directory"
"app.services.sound_scanner.SoundScannerService.scan_directory",
) as mock_scan:
mock_scan.side_effect = Exception("Permission denied")
response = await authenticated_admin_client.post(
"/api/v1/sounds/scan/custom", params={"directory": "/restricted/path"}
"/api/v1/sounds/scan/custom", params={"directory": "/restricted/path"},
)
assert response.status_code == 500
@@ -442,7 +442,7 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory"
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan:
mock_scan.return_value = mock_results
@@ -480,7 +480,7 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory"
"app.services.sound_scanner.SoundScannerService.scan_soundboard_directory",
) as mock_scan:
mock_scan.return_value = mock_results
@@ -570,12 +570,12 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds"
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize:
mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all"
"/api/v1/sounds/normalize/all",
)
assert response.status_code == 200
@@ -608,12 +608,12 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds"
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize:
mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all", params={"force": True}
"/api/v1/sounds/normalize/all", params={"force": True},
)
assert response.status_code == 200
@@ -637,12 +637,12 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds"
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize:
mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all", params={"one_pass": True}
"/api/v1/sounds/normalize/all", params={"one_pass": True},
)
assert response.status_code == 200
@@ -684,7 +684,7 @@ class TestSoundEndpoints:
base_url="http://test",
) as client:
response = await client.post(
"/api/v1/sounds/normalize/all", headers=headers
"/api/v1/sounds/normalize/all", headers=headers,
)
assert response.status_code == 403
@@ -702,12 +702,12 @@ class TestSoundEndpoints:
):
"""Test normalization when service raises an error."""
with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds"
"app.services.sound_normalizer.SoundNormalizerService.normalize_all_sounds",
) as mock_normalize:
mock_normalize.side_effect = Exception("Normalization service failed")
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/all"
"/api/v1/sounds/normalize/all",
)
assert response.status_code == 500
@@ -758,12 +758,12 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type"
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type",
) as mock_normalize:
mock_normalize.return_value = mock_results
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/type/SDB"
"/api/v1/sounds/normalize/type/SDB",
)
assert response.status_code == 200
@@ -779,7 +779,7 @@ class TestSoundEndpoints:
# Verify the service was called with correct type
mock_normalize.assert_called_once_with(
sound_type="SDB", force=False, one_pass=None
sound_type="SDB", force=False, one_pass=None,
)
@pytest.mark.asyncio
@@ -790,7 +790,7 @@ class TestSoundEndpoints:
):
"""Test normalization with invalid sound type."""
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/type/INVALID"
"/api/v1/sounds/normalize/type/INVALID",
)
assert response.status_code == 400
@@ -814,7 +814,7 @@ class TestSoundEndpoints:
}
with patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type"
"app.services.sound_normalizer.SoundNormalizerService.normalize_sounds_by_type",
) as mock_normalize:
mock_normalize.return_value = mock_results
@@ -827,7 +827,7 @@ class TestSoundEndpoints:
# Verify parameters were passed correctly
mock_normalize.assert_called_once_with(
sound_type="TTS", force=True, one_pass=False
sound_type="TTS", force=True, one_pass=False,
)
@pytest.mark.asyncio
@@ -866,7 +866,7 @@ class TestSoundEndpoints:
with (
patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound"
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
):
@@ -874,7 +874,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42"
"/api/v1/sounds/normalize/42",
)
assert response.status_code == 200
@@ -897,12 +897,12 @@ class TestSoundEndpoints:
):
"""Test normalization of non-existent sound."""
with patch(
"app.repositories.sound.SoundRepository.get_by_id"
"app.repositories.sound.SoundRepository.get_by_id",
) as mock_get_sound:
mock_get_sound.return_value = None
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/999"
"/api/v1/sounds/normalize/999",
)
assert response.status_code == 404
@@ -945,7 +945,7 @@ class TestSoundEndpoints:
with (
patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound"
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
):
@@ -953,7 +953,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42"
"/api/v1/sounds/normalize/42",
)
assert response.status_code == 500
@@ -997,7 +997,7 @@ class TestSoundEndpoints:
with (
patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound"
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
):
@@ -1052,7 +1052,7 @@ class TestSoundEndpoints:
with (
patch(
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound"
"app.services.sound_normalizer.SoundNormalizerService.normalize_sound",
) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
):
@@ -1060,7 +1060,7 @@ class TestSoundEndpoints:
mock_normalize_sound.return_value = mock_result
response = await authenticated_admin_client.post(
"/api/v1/sounds/normalize/42"
"/api/v1/sounds/normalize/42",
)
assert response.status_code == 200

View File

@@ -1,16 +1,14 @@
"""Tests for VLC player API endpoints."""
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock
import pytest
from httpx import AsyncClient
from fastapi import FastAPI
from httpx import AsyncClient
from app.api.v1.sounds import get_credit_service, get_sound_repository, get_vlc_player
from app.models.sound import Sound
from app.models.user import User
from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service
class TestVLCEndpoints:
@@ -28,7 +26,7 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock()
mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Set up test data
mock_sound = Sound(
id=1,
@@ -39,27 +37,27 @@ class TestVLCEndpoints:
size=1024,
hash="test_hash",
)
# Configure mocks
mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = True
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200
data = response.json()
assert data["sound_id"] == 1
assert data["sound_name"] == "Test Sound"
assert "Test Sound" in data["message"]
# Verify service calls
mock_repo.get_by_id.assert_called_once_with(1)
mock_vlc_service.play_sound.assert_called_once_with(mock_sound)
@@ -81,18 +79,18 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock()
mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Configure mocks
mock_repo.get_by_id.return_value = None
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
assert response.status_code == 404
data = response.json()
assert "Sound with ID 999 not found" in data["detail"]
@@ -114,7 +112,7 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock()
mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Set up test data
mock_sound = Sound(
id=1,
@@ -125,21 +123,21 @@ class TestVLCEndpoints:
size=1024,
hash="test_hash",
)
# Configure mocks
mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = False
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500
data = response.json()
assert "Failed to launch VLC for sound playback" in data["detail"]
@@ -161,18 +159,18 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock()
mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Configure mocks
mock_repo.get_by_id.side_effect = Exception("Database error")
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500
data = response.json()
assert "Failed to play sound" in data["detail"]
@@ -209,13 +207,13 @@ class TestVLCEndpoints:
"message": "Killed 3 VLC processes",
}
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
@@ -223,7 +221,7 @@ class TestVLCEndpoints:
assert data["processes_killed"] == 3
assert data["processes_remaining"] == 0
assert "Killed 3 VLC processes" in data["message"]
# Verify service call
mock_vlc_service.stop_all_vlc_instances.assert_called_once()
finally:
@@ -247,13 +245,13 @@ class TestVLCEndpoints:
"message": "No VLC processes found",
}
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
@@ -282,13 +280,13 @@ class TestVLCEndpoints:
"message": "Killed 2 VLC processes",
}
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
@@ -317,13 +315,13 @@ class TestVLCEndpoints:
"message": "Failed to stop VLC processes",
}
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is False
@@ -344,13 +342,13 @@ class TestVLCEndpoints:
# Set up mock to raise an exception
mock_vlc_service = AsyncMock()
mock_vlc_service.stop_all_vlc_instances.side_effect = Exception("Service error")
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 500
data = response.json()
assert "Failed to stop VLC instances" in data["detail"]
@@ -379,7 +377,7 @@ class TestVLCEndpoints:
mock_vlc_service = AsyncMock()
mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Set up test data
mock_sound = Sound(
id=1,
@@ -390,21 +388,21 @@ class TestVLCEndpoints:
size=512,
hash="admin_hash",
)
# Configure mocks
mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = True
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200
data = response.json()
assert data["sound_name"] == "Admin Test Sound"
@@ -424,17 +422,17 @@ class TestVLCEndpoints:
"message": "Killed 1 VLC processes",
}
mock_vlc_service_2.stop_all_vlc_instances.return_value = mock_result
# Override dependency for stop-all test
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service_2
try:
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["processes_killed"] == 1
finally:
# Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None)
test_app.dependency_overrides.pop(get_vlc_player, None)

View File

@@ -13,10 +13,8 @@ from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.models.credit_transaction import CreditTransaction # Ensure model is imported for SQLAlchemy
from app.models.plan import Plan
from app.models.user import User
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
from app.utils.auth import JWTUtils, PasswordUtils

View File

@@ -49,7 +49,7 @@ class TestApiTokenDependencies:
assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123"
"test_api_token_123",
)
@pytest.mark.asyncio
@@ -135,11 +135,11 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio
async def test_get_current_user_api_token_service_exception(
self, mock_auth_service
self, mock_auth_service,
):
"""Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception(
"Database error"
"Database error",
)
api_token_header = "test_token"
@@ -170,7 +170,7 @@ class TestApiTokenDependencies:
assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123"
"test_api_token_123",
)
@pytest.mark.asyncio
@@ -184,7 +184,7 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio
async def test_api_token_no_expiry_never_expires(
self, mock_auth_service, test_user
self, mock_auth_service, test_user,
):
"""Test API token with no expiry date never expires."""
test_user.api_token_expires_at = None

View File

@@ -42,7 +42,7 @@ class TestCreditTransactionRepository:
"""Create test credit transactions."""
transactions = []
user_id = test_user_id
# Create various types of transactions
transaction_data = [
{
@@ -105,9 +105,8 @@ class TestCreditTransactionRepository:
ensure_plans: tuple[Any, ...], # noqa: ARG002
) -> AsyncGenerator[CreditTransaction, None]:
"""Create a transaction for a different user."""
from app.models.plan import Plan
from app.repositories.user import UserRepository
# Create another user
user_repo = UserRepository(test_session)
other_user_data = {
@@ -134,7 +133,7 @@ class TestCreditTransactionRepository:
test_session.add(transaction)
await test_session.commit()
await test_session.refresh(transaction)
yield transaction
@pytest.mark.asyncio
@@ -178,7 +177,7 @@ class TestCreditTransactionRepository:
assert len(transactions) == 4
# Should be ordered by created_at desc (newest first)
assert all(t.user_id == test_user_id for t in transactions)
# Should not include other user's transaction
other_user_ids = [t.user_id for t in transactions]
assert other_user_transaction.user_id not in other_user_ids
@@ -193,13 +192,13 @@ class TestCreditTransactionRepository:
"""Test getting transactions by user ID with pagination."""
# Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0
test_user_id, limit=2, offset=0,
)
assert len(first_page) == 2
# Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2
test_user_id, limit=2, offset=2,
)
assert len(second_page) == 2
@@ -216,17 +215,17 @@ class TestCreditTransactionRepository:
) -> None:
"""Test getting transactions by action type."""
vlc_transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound"
"vlc_play_sound",
)
# Should return 2 VLC transactions (1 successful, 1 failed)
assert len(vlc_transactions) >= 2
assert all(t.action_type == "vlc_play_sound" for t in vlc_transactions)
extraction_transactions = await credit_transaction_repository.get_by_action_type(
"audio_extraction"
"audio_extraction",
)
# Should return 1 extraction transaction
assert len(extraction_transactions) >= 1
assert all(t.action_type == "audio_extraction" for t in extraction_transactions)
@@ -240,14 +239,14 @@ class TestCreditTransactionRepository:
"""Test getting transactions by action type with pagination."""
# Test with limit
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1
"vlc_play_sound", limit=1,
)
assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound"
# Test with offset
transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1
"vlc_play_sound", limit=1, offset=1,
)
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@@ -275,7 +274,7 @@ class TestCreditTransactionRepository:
) -> None:
"""Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id
user_id=test_user_id,
)
# Should only return successful transactions for test_user
@@ -294,14 +293,14 @@ class TestCreditTransactionRepository:
"""Test getting successful transactions with pagination."""
# Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0
user_id=test_user_id, limit=2, offset=0,
)
assert len(first_page) == 2
assert all(t.success is True for t in first_page)
# Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2
user_id=test_user_id, limit=2, offset=2,
)
assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page)
@@ -363,7 +362,7 @@ class TestCreditTransactionRepository:
}
updated_transaction = await credit_transaction_repository.update(
transaction, update_data
transaction, update_data,
)
assert updated_transaction.id == transaction.id
@@ -413,7 +412,7 @@ class TestCreditTransactionRepository:
) -> None:
"""Test that transactions are ordered by created_at desc."""
transactions = await credit_transaction_repository.get_by_user_id(test_user_id)
# Should be ordered by created_at desc (newest first)
for i in range(len(transactions) - 1):
assert transactions[i].created_at >= transactions[i + 1].created_at
assert transactions[i].created_at >= transactions[i + 1].created_at

View File

@@ -52,7 +52,7 @@ class TestExtractionRepository:
assert result.service_id == extraction_data["service_id"]
assert result.title == extraction_data["title"]
assert result.status == extraction_data["status"]
# Verify session methods were called
extraction_repo.session.add.assert_called_once()
extraction_repo.session.commit.assert_called_once()

View File

@@ -151,10 +151,10 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist for this user
playlist = Playlist(
user_id=user_id,
@@ -167,10 +167,10 @@ class TestPlaylistRepository:
)
test_session.add(playlist)
await test_session.commit()
# Test the repository method
playlists = await playlist_repository.get_by_user_id(user_id)
# Should only return user's playlists, not the main playlist (user_id=None)
assert len(playlists) == 1
assert playlists[0].name == "Test Playlist"
@@ -194,13 +194,13 @@ class TestPlaylistRepository:
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(main_playlist)
# Extract ID before async call
main_playlist_id = main_playlist.id
# Test the repository method
playlist = await playlist_repository.get_main_playlist()
assert playlist is not None
assert playlist.id == main_playlist_id
assert playlist.is_main is True
@@ -227,13 +227,13 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Test the repository method - should return None when no current playlist
playlist = await playlist_repository.get_current_playlist(user_id)
# Should return None since no user playlist is marked as current
assert playlist is None
@@ -319,10 +319,10 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist
test_playlist = Playlist(
user_id=user_id,
@@ -334,7 +334,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(test_playlist)
# Create main playlist
main_playlist = Playlist(
user_id=None,
@@ -346,7 +346,7 @@ class TestPlaylistRepository:
)
test_session.add(main_playlist)
await test_session.commit()
# Search for all playlists (no user filter)
all_results = await playlist_repository.search_by_name("playlist")
assert len(all_results) >= 2 # Should include both user and main playlists
@@ -382,7 +382,7 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Create test playlist
playlist = Playlist(
user_id=user.id,
@@ -394,7 +394,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
# Create test sound
sound = Sound(
name="Test Sound",
@@ -409,14 +409,14 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async call
playlist_id = playlist.id
sound_id = sound.id
# Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
assert playlist_sound.playlist_id == playlist_id
@@ -445,10 +445,10 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist
playlist = Playlist(
user_id=user_id,
@@ -460,7 +460,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
# Create test sound
sound = Sound(
name="Test Sound",
@@ -475,14 +475,14 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async call
playlist_id = playlist.id
sound_id = sound.id
# Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id, position=5
playlist_id, sound_id, position=5,
)
assert playlist_sound.position == 5
@@ -509,9 +509,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -522,7 +522,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -536,7 +536,7 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -546,17 +546,17 @@ class TestPlaylistRepository:
# Verify it was added
assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
# Remove the sound
await playlist_repository.remove_sound_from_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
# Verify it was removed
assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
@pytest.mark.asyncio
@@ -581,9 +581,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -594,7 +594,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -608,7 +608,7 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -647,9 +647,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -660,7 +660,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -674,7 +674,7 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -712,9 +712,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -725,7 +725,7 @@ class TestPlaylistRepository:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -739,14 +739,14 @@ class TestPlaylistRepository:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Initially not in playlist
assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
# Add sound
@@ -754,7 +754,7 @@ class TestPlaylistRepository:
# Now in playlist
assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
playlist_id, sound_id,
)
@pytest.mark.asyncio
@@ -779,9 +779,9 @@ class TestPlaylistRepository:
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
@@ -801,7 +801,7 @@ class TestPlaylistRepository:
await test_session.refresh(playlist)
await test_session.refresh(sound1)
await test_session.refresh(sound2)
# Extract IDs before async calls
playlist_id = playlist.id
sound1_id = sound1.id
@@ -809,16 +809,16 @@ class TestPlaylistRepository:
# Add sounds to playlist
await playlist_repository.add_sound_to_playlist(
playlist_id, sound1_id, position=0
playlist_id, sound1_id, position=0,
)
await playlist_repository.add_sound_to_playlist(
playlist_id, sound2_id, position=1
playlist_id, sound2_id, position=1,
)
# Reorder sounds - use different positions to avoid constraint issues
sound_positions = [(sound1_id, 10), (sound2_id, 5)]
await playlist_repository.reorder_playlist_sounds(
playlist_id, sound_positions
playlist_id, sound_positions,
)
# Verify new order

View File

@@ -359,7 +359,7 @@ class TestSoundRepository:
"""Test creating sound with duplicate hash should fail."""
# Store the hash to avoid lazy loading issues
original_hash = test_sound.hash
duplicate_sound_data = {
"name": "Duplicate Hash Sound",
"filename": "duplicate.mp3",
@@ -373,4 +373,4 @@ class TestSoundRepository:
# Should fail due to unique constraint on hash
with pytest.raises(Exception): # SQLAlchemy IntegrityError or similar
await sound_repository.create(duplicate_sound_data)
await sound_repository.create(duplicate_sound_data)

View File

@@ -60,7 +60,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456"
"google", "google_123456",
)
assert oauth is not None
@@ -76,7 +76,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id"
"google", "nonexistent_id",
)
assert oauth is None
@@ -90,7 +90,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
test_user_id, "google",
)
assert oauth is not None
@@ -106,7 +106,7 @@ class TestUserOauthRepository:
) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
test_user_id, "github",
)
assert oauth is None
@@ -183,7 +183,7 @@ class TestUserOauthRepository:
# Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456"
"twitter", "twitter_456",
)
assert deleted_oauth is None
@@ -240,10 +240,10 @@ class TestUserOauthRepository:
# Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google"
test_user_id, "google",
)
found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github"
test_user_id, "github",
)
assert found_google is not None
@@ -257,13 +257,13 @@ class TestUserOauthRepository:
# Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1"
"google", "google_user_1",
)
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1"
"github", "github_user_1",
)
assert found_google_by_provider is not None
assert found_github_by_provider is not None
assert found_google_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id
assert found_github_by_provider.user_id == test_user_id

View File

@@ -1,7 +1,7 @@
"""Tests for credit service."""
import json
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, patch
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -42,14 +42,14 @@ class TestCreditService:
async def test_check_credits_sufficient(self, credit_service, sample_user):
"""Test checking credits when user has sufficient credits."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is True
mock_repo.get_by_id.assert_called_once_with(1)
mock_session.close.assert_called_once()
@@ -66,14 +66,14 @@ class TestCreditService:
credits=0, # No credits
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND)
assert result is False
mock_session.close.assert_called_once()
@@ -81,14 +81,14 @@ class TestCreditService:
async def test_check_credits_user_not_found(self, credit_service):
"""Test checking credits when user is not found."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND)
assert result is False
mock_session.close.assert_called_once()
@@ -96,16 +96,16 @@ class TestCreditService:
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user):
"""Test successful credit validation and reservation."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
user, action = await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND
1, CreditActionType.VLC_PLAY_SOUND,
)
assert user == sample_user
assert action.action_type == CreditActionType.VLC_PLAY_SOUND
assert action.cost == 1
@@ -123,17 +123,17 @@ class TestCreditService:
credits=0,
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
with pytest.raises(InsufficientCreditsError) as exc_info:
await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND
1, CreditActionType.VLC_PLAY_SOUND,
)
assert exc_info.value.required == 1
assert exc_info.value.available == 0
mock_session.close.assert_called_once()
@@ -142,42 +142,42 @@ class TestCreditService:
async def test_validate_and_reserve_credits_user_not_found(self, credit_service):
"""Test credit validation when user is not found."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.validate_and_reserve_credits(
999, CreditActionType.VLC_PLAY_SOUND
999, CreditActionType.VLC_PLAY_SOUND,
)
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_deduct_credits_success(self, credit_service, sample_user):
"""Test successful credit deduction."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
1, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"},
)
# Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 9})
# Verify transaction was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", {
@@ -187,9 +187,9 @@ class TestCreditService:
"credits_deducted": 1,
"action_type": "vlc_play_sound",
"success": True,
}
},
)
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert isinstance(added_transaction, CreditTransaction)
@@ -205,28 +205,28 @@ class TestCreditService:
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user):
"""Test credit deduction when action failed but requires success."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, False # Action failed
1, CreditActionType.VLC_PLAY_SOUND, False, # Action failed
)
# Verify user credits were NOT updated (action requires success)
mock_repo.update.assert_not_called()
# Verify transaction was still created for auditing
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify no socket event was emitted since no credits were actually deducted
mock_socket_manager.send_to_user.assert_not_called()
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 0 # No deduction for failed action
@@ -246,22 +246,22 @@ class TestCreditService:
credits=0,
plan_id=1,
)
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user
mock_socket_manager.send_to_user = AsyncMock()
with pytest.raises(InsufficientCreditsError):
await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, True
1, CreditActionType.VLC_PLAY_SOUND, True,
)
# Verify no socket event was emitted since credits could not be deducted
mock_socket_manager.send_to_user.assert_not_called()
mock_session.rollback.assert_called_once()
mock_session.close.assert_called_once()
@@ -269,25 +269,25 @@ class TestCreditService:
async def test_add_credits(self, credit_service, sample_user):
"""Test adding credits to user account."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \
patch("app.services.credit.socket_manager") as mock_socket_manager:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock()
transaction = await credit_service.add_credits(
1, 5, "Bonus credits", {"reason": "signup"}
1, 5, "Bonus credits", {"reason": "signup"},
)
# Verify user credits were updated
mock_repo.update.assert_called_once_with(sample_user, {"credits": 15})
# Verify transaction was created
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
# Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", {
@@ -297,9 +297,9 @@ class TestCreditService:
"credits_added": 5,
"description": "Bonus credits",
"success": True,
}
},
)
# Check transaction details
added_transaction = mock_session.add.call_args[0][0]
assert added_transaction.amount == 5
@@ -312,7 +312,7 @@ class TestCreditService:
"""Test adding invalid amount of credits."""
with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, 0, "Invalid")
with pytest.raises(ValueError, match="Amount must be positive"):
await credit_service.add_credits(1, -5, "Invalid")
@@ -320,14 +320,14 @@ class TestCreditService:
async def test_get_user_balance(self, credit_service, sample_user):
"""Test getting user credit balance."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user
balance = await credit_service.get_user_balance(1)
assert balance == 10
mock_session.close.assert_called_once()
@@ -335,15 +335,15 @@ class TestCreditService:
async def test_get_user_balance_user_not_found(self, credit_service):
"""Test getting balance for non-existent user."""
mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class:
mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None
with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.get_user_balance(999)
mock_session.close.assert_called_once()
@@ -355,4 +355,4 @@ class TestInsufficientCreditsError:
error = InsufficientCreditsError(5, 2)
assert error.required == 5
assert error.available == 2
assert str(error) == "Insufficient credits: 5 required, 2 available"
assert str(error) == "Insufficient credits: 5 required, 2 available"

View File

@@ -53,7 +53,7 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio
async def test_detect_service_info_youtube(
self, mock_ydl_class, extraction_service
self, mock_ydl_class, extraction_service,
):
"""Test service detection for YouTube."""
mock_ydl = Mock()
@@ -67,7 +67,7 @@ class TestExtractionService:
}
result = await extraction_service._detect_service_info(
"https://www.youtube.com/watch?v=test123"
"https://www.youtube.com/watch?v=test123",
)
assert result is not None
@@ -78,7 +78,7 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio
async def test_detect_service_info_failure(
self, mock_ydl_class, extraction_service
self, mock_ydl_class, extraction_service,
):
"""Test service detection failure."""
mock_ydl = Mock()
@@ -106,7 +106,7 @@ class TestExtractionService:
status="pending",
)
extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction
return_value=mock_extraction,
)
result = await extraction_service.create_extraction(url, user_id)
@@ -134,7 +134,7 @@ class TestExtractionService:
status="pending",
)
extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction
return_value=mock_extraction,
)
result = await extraction_service.create_extraction(url, user_id)
@@ -160,7 +160,7 @@ class TestExtractionService:
status="pending",
)
extraction_service.extraction_repo.create = AsyncMock(
return_value=mock_extraction
return_value=mock_extraction,
)
result = await extraction_service.create_extraction(url, user_id)
@@ -186,11 +186,11 @@ class TestExtractionService:
)
extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=mock_extraction
return_value=mock_extraction,
)
extraction_service.extraction_repo.update = AsyncMock()
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
return_value=None
return_value=None,
)
# Mock service detection
@@ -202,14 +202,14 @@ class TestExtractionService:
with (
patch.object(
extraction_service, "_detect_service_info", return_value=service_info
extraction_service, "_detect_service_info", return_value=service_info,
),
patch.object(extraction_service, "_extract_media") as mock_extract,
patch.object(
extraction_service, "_move_files_to_final_location"
extraction_service, "_move_files_to_final_location",
) as mock_move,
patch.object(
extraction_service, "_create_sound_record"
extraction_service, "_create_sound_record",
) as mock_create_sound,
patch.object(extraction_service, "_normalize_sound") as mock_normalize,
patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist,
@@ -223,7 +223,7 @@ class TestExtractionService:
# Verify service detection was called
extraction_service._detect_service_info.assert_called_once_with(
"https://www.youtube.com/watch?v=test123"
"https://www.youtube.com/watch?v=test123",
)
# Verify extraction was updated with service info
@@ -289,15 +289,15 @@ class TestExtractionService:
with (
patch(
"app.services.extraction.get_audio_duration", return_value=240000
"app.services.extraction.get_audio_duration", return_value=240000,
),
patch("app.services.extraction.get_file_size", return_value=1024),
patch(
"app.services.extraction.get_file_hash", return_value="test_hash"
"app.services.extraction.get_file_hash", return_value="test_hash",
),
):
extraction_service.sound_repo.create = AsyncMock(
return_value=mock_sound
return_value=mock_sound,
)
result = await extraction_service._create_sound_record(
@@ -336,7 +336,7 @@ class TestExtractionService:
mock_normalizer = Mock()
mock_normalizer.normalize_sound = AsyncMock(
return_value={"status": "normalized"}
return_value={"status": "normalized"},
)
with patch(
@@ -368,7 +368,7 @@ class TestExtractionService:
mock_normalizer = Mock()
mock_normalizer.normalize_sound = AsyncMock(
return_value={"status": "error", "error": "Test error"}
return_value={"status": "error", "error": "Test error"},
)
with patch(
@@ -395,7 +395,7 @@ class TestExtractionService:
)
extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=extraction
return_value=extraction,
)
result = await extraction_service.get_extraction_by_id(1)
@@ -443,7 +443,7 @@ class TestExtractionService:
]
extraction_service.extraction_repo.get_by_user = AsyncMock(
return_value=extractions
return_value=extractions,
)
result = await extraction_service.get_user_extractions(1)
@@ -470,7 +470,7 @@ class TestExtractionService:
]
extraction_service.extraction_repo.get_pending_extractions = AsyncMock(
return_value=pending_extractions
return_value=pending_extractions,
)
result = await extraction_service.get_pending_extractions()

View File

@@ -1,6 +1,5 @@
"""Tests for extraction background processor."""
import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
@@ -30,7 +29,7 @@ class TestExtractionProcessor:
"""Test starting and stopping the processor."""
# Mock the _process_queue method to avoid actual processing
with patch.object(
processor, "_process_queue", new_callable=AsyncMock
processor, "_process_queue", new_callable=AsyncMock,
) as mock_process:
# Start the processor
await processor.start()
@@ -138,12 +137,12 @@ class TestExtractionProcessor:
# Mock the extraction service
mock_service = Mock()
mock_service.process_extraction = AsyncMock(
return_value={"status": "completed", "id": extraction_id}
return_value={"status": "completed", "id": extraction_id},
)
with (
patch(
"app.services.extraction_processor.AsyncSession"
"app.services.extraction_processor.AsyncSession",
) as mock_session_class,
patch(
"app.services.extraction_processor.ExtractionService",
@@ -168,7 +167,7 @@ class TestExtractionProcessor:
with (
patch(
"app.services.extraction_processor.AsyncSession"
"app.services.extraction_processor.AsyncSession",
) as mock_session_class,
patch(
"app.services.extraction_processor.ExtractionService",
@@ -193,12 +192,12 @@ class TestExtractionProcessor:
# Mock extraction service
mock_service = Mock()
mock_service.get_pending_extractions = AsyncMock(
return_value=[{"id": 100, "status": "pending"}]
return_value=[{"id": 100, "status": "pending"}],
)
with (
patch(
"app.services.extraction_processor.AsyncSession"
"app.services.extraction_processor.AsyncSession",
) as mock_session_class,
patch(
"app.services.extraction_processor.ExtractionService",
@@ -222,15 +221,15 @@ class TestExtractionProcessor:
return_value=[
{"id": 100, "status": "pending"},
{"id": 101, "status": "pending"},
]
],
)
with (
patch(
"app.services.extraction_processor.AsyncSession"
"app.services.extraction_processor.AsyncSession",
) as mock_session_class,
patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock
processor, "_process_single_extraction", new_callable=AsyncMock,
) as mock_process,
patch(
"app.services.extraction_processor.ExtractionService",
@@ -267,15 +266,15 @@ class TestExtractionProcessor:
{"id": 100, "status": "pending"},
{"id": 101, "status": "pending"},
{"id": 102, "status": "pending"},
]
],
)
with (
patch(
"app.services.extraction_processor.AsyncSession"
"app.services.extraction_processor.AsyncSession",
) as mock_session_class,
patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock
processor, "_process_single_extraction", new_callable=AsyncMock,
) as mock_process,
patch(
"app.services.extraction_processor.ExtractionService",

View File

@@ -4,14 +4,12 @@ import asyncio
import threading
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.models.user import User
from app.services.player import (
PlayerMode,
PlayerService,
@@ -21,7 +19,6 @@ from app.services.player import (
initialize_player_service,
shutdown_player_service,
)
from app.utils.audio import get_sound_file_path
class TestPlayerState:
@@ -200,7 +197,7 @@ class TestPlayerService:
mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = True
mock_path.return_value = mock_file_path
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock):
mock_media = Mock()
player_service._vlc_instance.media_new.return_value = mock_media
@@ -738,7 +735,7 @@ class TestPlayerService:
# Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with(
mock_sound, {"play_count": 6}
mock_sound, {"play_count": 6},
)
# Verify SoundPlayed record was created with None user_id for player
@@ -768,7 +765,7 @@ class TestPlayerService:
mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = False # File doesn't exist
mock_path.return_value = mock_file_path
# This should fail because file doesn't exist
result = asyncio.run(player_service.play(0))
# Verify the utility was called
@@ -817,4 +814,4 @@ class TestPlayerServiceGlobalFunctions:
"""Test getting player service when not initialized."""
with patch("app.services.player.player_service", None):
with pytest.raises(RuntimeError, match="Player service not initialized"):
get_player_service()
get_player_service()

View File

@@ -153,7 +153,6 @@ class TestPlaylistService:
test_user: User,
) -> None:
"""Test getting non-existent playlist."""
with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_playlist_by_id(99999)
@@ -168,7 +167,6 @@ class TestPlaylistService:
test_session: AsyncSession,
) -> None:
"""Test getting existing main playlist."""
# Create main playlist manually
main_playlist = Playlist(
user_id=None,
@@ -193,7 +191,6 @@ class TestPlaylistService:
test_user: User,
) -> None:
"""Test that service fails if main playlist doesn't exist."""
# Should raise an HTTPException if no main playlist exists
with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_main_playlist()
@@ -207,7 +204,6 @@ class TestPlaylistService:
test_user: User,
) -> None:
"""Test creating a new playlist successfully."""
user_id = test_user.id # Extract user_id while session is available
playlist = await playlist_service.create_playlist(
user_id=user_id,
@@ -246,7 +242,7 @@ class TestPlaylistService:
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
# Extract name before async call
playlist_name = playlist.name
@@ -280,10 +276,10 @@ class TestPlaylistService:
test_session.add(current_playlist)
await test_session.commit()
await test_session.refresh(current_playlist)
# Verify the existing current playlist
assert current_playlist.is_current is True
# Extract ID before async call
current_playlist_id = current_playlist.id
@@ -323,10 +319,10 @@ class TestPlaylistService:
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
# Extract IDs before async call
playlist_id = playlist.id
updated_playlist = await playlist_service.update_playlist(
playlist_id=playlist_id,
user_id=user_id,
@@ -359,7 +355,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(test_playlist)
current_playlist = Playlist(
user_id=user_id,
name="Current Playlist",
@@ -372,7 +368,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(current_playlist)
# Extract IDs before async calls
test_playlist_id = test_playlist.id
current_playlist_id = current_playlist.id
@@ -416,7 +412,7 @@ class TestPlaylistService:
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
# Extract ID before async call
playlist_id = playlist.id
@@ -445,7 +441,7 @@ class TestPlaylistService:
is_deletable=False,
)
test_session.add(main_playlist)
# Create current playlist within this test
user_id = test_user.id
current_playlist = Playlist(
@@ -459,7 +455,7 @@ class TestPlaylistService:
test_session.add(current_playlist)
await test_session.commit()
await test_session.refresh(current_playlist)
# Extract ID before async call
current_playlist_id = current_playlist.id
@@ -481,7 +477,7 @@ class TestPlaylistService:
"""Test deleting non-deletable playlist fails."""
# Extract user ID immediately
user_id = test_user.id
# Create non-deletable playlist
non_deletable = Playlist(
user_id=user_id,
@@ -491,7 +487,7 @@ class TestPlaylistService:
test_session.add(non_deletable)
await test_session.commit()
await test_session.refresh(non_deletable)
# Extract ID before async call
non_deletable_id = non_deletable.id
@@ -521,7 +517,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -535,7 +531,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -571,7 +567,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -585,7 +581,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -628,7 +624,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -642,7 +638,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -685,7 +681,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -699,7 +695,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -734,7 +730,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(test_playlist)
current_playlist = Playlist(
user_id=user_id,
name="Current Playlist",
@@ -747,7 +743,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(current_playlist)
# Extract IDs before async calls
test_playlist_id = test_playlist.id
current_playlist_id = current_playlist.id
@@ -758,7 +754,7 @@ class TestPlaylistService:
# Set test_playlist as current
updated_playlist = await playlist_service.set_current_playlist(
test_playlist_id, user_id
test_playlist_id, user_id,
)
assert updated_playlist.is_current is True
@@ -786,7 +782,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(current_playlist)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
@@ -799,7 +795,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(current_playlist)
await test_session.refresh(main_playlist)
# Extract IDs before async calls
current_playlist_id = current_playlist.id
main_playlist_id = main_playlist.id
@@ -839,7 +835,7 @@ class TestPlaylistService:
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
@@ -853,7 +849,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
@@ -897,7 +893,7 @@ class TestPlaylistService:
play_count=10,
)
test_session.add(sound)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
@@ -910,7 +906,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(sound)
await test_session.refresh(main_playlist)
# Extract IDs before async calls
sound_id = sound.id
main_playlist_id = main_playlist.id
@@ -943,7 +939,7 @@ class TestPlaylistService:
play_count=10,
)
test_session.add(sound)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
@@ -956,7 +952,7 @@ class TestPlaylistService:
await test_session.commit()
await test_session.refresh(sound)
await test_session.refresh(main_playlist)
# Extract IDs before async calls
sound_id = sound.id
main_playlist_id = main_playlist.id

View File

@@ -98,7 +98,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_success(
self, mock_decode, mock_extract_token, socket_manager, mock_sio
self, mock_decode, mock_extract_token, socket_manager, mock_sio,
):
"""Test successful connection with valid token."""
# Setup mocks
@@ -133,7 +133,7 @@ class TestSocketManager:
@pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies")
async def test_connect_handler_no_token(
self, mock_extract_token, socket_manager, mock_sio
self, mock_extract_token, socket_manager, mock_sio,
):
"""Test connection with no access token."""
# Setup mocks
@@ -167,7 +167,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_invalid_token(
self, mock_decode, mock_extract_token, socket_manager, mock_sio
self, mock_decode, mock_extract_token, socket_manager, mock_sio,
):
"""Test connection with invalid token."""
# Setup mocks
@@ -202,7 +202,7 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_missing_user_id(
self, mock_decode, mock_extract_token, socket_manager, mock_sio
self, mock_decode, mock_extract_token, socket_manager, mock_sio,
):
"""Test connection with token missing user ID."""
# Setup mocks

View File

@@ -55,7 +55,7 @@ class TestSoundNormalizerService:
normalized_path = normalizer_service._get_normalized_path(sound)
assert "sounds/normalized/soundboard" in str(normalized_path)
assert "test_audio.mp3" == normalized_path.name
assert normalized_path.name == "test_audio.mp3"
def test_get_original_path(self, normalizer_service):
"""Test original path generation."""
@@ -72,7 +72,7 @@ class TestSoundNormalizerService:
original_path = normalizer_service._get_original_path(sound)
assert "sounds/originals/soundboard" in str(original_path)
assert "test_audio.wav" == original_path.name
assert original_path.name == "test_audio.wav"
def test_get_file_hash(self, normalizer_service):
"""Test file hash calculation."""
@@ -172,14 +172,14 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object(
normalizer_service, "_normalize_audio_two_pass"
normalizer_service, "_normalize_audio_two_pass",
) as mock_normalize,
patch(
"app.services.sound_normalizer.get_audio_duration", return_value=6000
"app.services.sound_normalizer.get_audio_duration", return_value=6000,
),
patch("app.services.sound_normalizer.get_file_size", return_value=2048),
patch(
"app.services.sound_normalizer.get_file_hash", return_value="new_hash"
"app.services.sound_normalizer.get_file_hash", return_value="new_hash",
),
):
# Setup path mocks
@@ -245,14 +245,14 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object(
normalizer_service, "_normalize_audio_one_pass"
normalizer_service, "_normalize_audio_one_pass",
) as mock_normalize,
patch(
"app.services.sound_normalizer.get_audio_duration", return_value=5500
"app.services.sound_normalizer.get_audio_duration", return_value=5500,
),
patch("app.services.sound_normalizer.get_file_size", return_value=1500),
patch(
"app.services.sound_normalizer.get_file_hash", return_value="norm_hash"
"app.services.sound_normalizer.get_file_hash", return_value="norm_hash",
),
):
# Setup path mocks
@@ -300,7 +300,7 @@ class TestSoundNormalizerService:
with (
patch("pathlib.Path.exists", return_value=True),
patch.object(
normalizer_service, "_normalize_audio_two_pass"
normalizer_service, "_normalize_audio_two_pass",
) as mock_normalize,
):
mock_normalize.side_effect = Exception("Normalization failed")
@@ -339,7 +339,7 @@ class TestSoundNormalizerService:
# Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock(
return_value=sounds
return_value=sounds,
)
# Mock individual normalization
@@ -399,7 +399,7 @@ class TestSoundNormalizerService:
# Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds_by_type = AsyncMock(
return_value=sdb_sounds
return_value=sdb_sounds,
)
# Mock individual normalization
@@ -428,7 +428,7 @@ class TestSoundNormalizerService:
# Verify correct repository method was called
normalizer_service.sound_repo.get_unnormalized_sounds_by_type.assert_called_once_with(
"SDB"
"SDB",
)
@pytest.mark.asyncio
@@ -459,7 +459,7 @@ class TestSoundNormalizerService:
# Mock repository calls
normalizer_service.sound_repo.get_unnormalized_sounds = AsyncMock(
return_value=sounds
return_value=sounds,
)
# Mock individual normalization with one success and one error
@@ -529,7 +529,7 @@ class TestSoundNormalizerService:
# Verify ffmpeg chain was called correctly
mock_ffmpeg.input.assert_called_once_with(str(input_path))
mock_ffmpeg.filter.assert_called_once_with(
mock_stream, "loudnorm", I=-23, TP=-2, LRA=7
mock_stream, "loudnorm", I=-23, TP=-2, LRA=7,
)
mock_ffmpeg.output.assert_called_once()
mock_ffmpeg.run.assert_called_once()

View File

@@ -153,7 +153,7 @@ class TestSoundScannerService:
"files": [],
}
await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results
temp_path, "SDB", existing_sound, results,
)
assert results["skipped"] == 1
@@ -257,7 +257,7 @@ class TestSoundScannerService:
"files": [],
}
await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results
temp_path, "SDB", existing_sound, results,
)
assert results["updated"] == 1
@@ -296,7 +296,7 @@ class TestSoundScannerService:
# Mock file operations
with (
patch(
"app.services.sound_scanner.get_file_hash", return_value="custom_hash"
"app.services.sound_scanner.get_file_hash", return_value="custom_hash",
),
patch("app.services.sound_scanner.get_audio_duration", return_value=60000),
patch("app.services.sound_scanner.get_file_size", return_value=2048),
@@ -316,7 +316,7 @@ class TestSoundScannerService:
"files": [],
}
await scanner_service._sync_audio_file(
temp_path, "CUSTOM", None, results
temp_path, "CUSTOM", None, results,
)
assert results["added"] == 1

View File

@@ -6,11 +6,10 @@ from unittest.mock import AsyncMock, Mock, patch
import pytest
from app.models.credit_transaction import CreditTransaction
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.models.user import User
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
from app.utils.audio import get_sound_file_path
class TestVLCPlayerService:
@@ -79,16 +78,16 @@ class TestVLCPlayerService:
def test_find_vlc_executable_found_by_path(self, mock_run):
"""Test VLC executable detection when found by absolute path."""
mock_run.return_value.returncode = 1 # which command fails
# Mock Path to return True for the first absolute path
with patch("app.services.vlc_player.Path") as mock_path:
def path_side_effect(path_str):
mock_instance = Mock()
mock_instance.exists.return_value = str(path_str) == "/usr/bin/vlc"
return mock_instance
mock_path.side_effect = path_side_effect
service = VLCPlayerService()
assert service.vlc_executable == "/usr/bin/vlc"
@@ -100,10 +99,10 @@ class TestVLCPlayerService:
mock_path_instance = Mock()
mock_path_instance.exists.return_value = False
mock_path.return_value = mock_path_instance
# Mock which command as failing
mock_run.return_value.returncode = 1
service = VLCPlayerService()
assert service.vlc_executable == "vlc"
@@ -111,26 +110,26 @@ class TestVLCPlayerService:
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_success(
self, mock_subprocess, vlc_service, sample_sound
self, mock_subprocess, vlc_service, sample_sound,
):
"""Test successful sound playback."""
# Mock subprocess
mock_process = Mock()
mock_process.pid = 12345
mock_subprocess.return_value = mock_process
# Mock the file path utility to avoid Path issues
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock()
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
result = await vlc_service.play_sound(sample_sound)
assert result is True
mock_subprocess.assert_called_once()
args = mock_subprocess.call_args
# Check command arguments
cmd_args = args[1] # keyword arguments
assert "--play-and-exit" in args[0]
@@ -144,7 +143,7 @@ class TestVLCPlayerService:
@pytest.mark.asyncio
async def test_play_sound_file_not_found(
self, vlc_service, sample_sound
self, vlc_service, sample_sound,
):
"""Test sound playback when file doesn't exist."""
# Mock the file path utility to return a non-existent path
@@ -152,15 +151,15 @@ class TestVLCPlayerService:
mock_path = Mock()
mock_path.exists.return_value = False
mock_get_path.return_value = mock_path
result = await vlc_service.play_sound(sample_sound)
assert result is False
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_subprocess_error(
self, mock_subprocess, vlc_service, sample_sound
self, mock_subprocess, vlc_service, sample_sound,
):
"""Test sound playback when subprocess fails."""
# Mock the file path utility to return an existing path
@@ -168,12 +167,12 @@ class TestVLCPlayerService:
mock_path = Mock()
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
# Mock subprocess exception
mock_subprocess.side_effect = Exception("Subprocess failed")
result = await vlc_service.play_sound(sample_sound)
assert result is False
@pytest.mark.asyncio
@@ -184,27 +183,27 @@ class TestVLCPlayerService:
mock_find_process = Mock()
mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n", b"")
return_value=(b"12345\n67890\n", b""),
)
# Mock pkill process (kill VLC processes)
mock_kill_process = Mock()
mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
# Mock verify process (check remaining processes)
mock_verify_process = Mock()
mock_verify_process.returncode = 1 # No processes found
mock_verify_process.communicate = AsyncMock(return_value=(b"", b""))
# Set up subprocess mock to return different processes for each call
mock_subprocess.side_effect = [
mock_find_process,
mock_kill_process,
mock_verify_process,
]
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True
assert result["processes_found"] == 2
assert result["processes_killed"] == 2
@@ -214,18 +213,18 @@ class TestVLCPlayerService:
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_no_processes(
self, mock_subprocess, vlc_service
self, mock_subprocess, vlc_service,
):
"""Test stopping VLC instances when none are running."""
# Mock pgrep process (no VLC processes found)
mock_find_process = Mock()
mock_find_process.returncode = 1 # No processes found
mock_find_process.communicate = AsyncMock(return_value=(b"", b""))
mock_subprocess.return_value = mock_find_process
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True
assert result["processes_found"] == 0
assert result["processes_killed"] == 0
@@ -234,33 +233,33 @@ class TestVLCPlayerService:
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_partial_kill(
self, mock_subprocess, vlc_service
self, mock_subprocess, vlc_service,
):
"""Test stopping VLC instances when some processes remain."""
# Mock pgrep process (find VLC processes)
mock_find_process = Mock()
mock_find_process.returncode = 0
mock_find_process.communicate = AsyncMock(
return_value=(b"12345\n67890\n11111\n", b"")
return_value=(b"12345\n67890\n11111\n", b""),
)
# Mock pkill process (kill VLC processes)
mock_kill_process = Mock()
mock_kill_process.communicate = AsyncMock(return_value=(b"", b""))
# Mock verify process (one process remains)
mock_verify_process = Mock()
mock_verify_process.returncode = 0
mock_verify_process.communicate = AsyncMock(return_value=(b"11111\n", b""))
mock_subprocess.side_effect = [
mock_find_process,
mock_kill_process,
mock_verify_process,
]
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is True
assert result["processes_found"] == 3
assert result["processes_killed"] == 2
@@ -272,9 +271,9 @@ class TestVLCPlayerService:
"""Test stopping VLC instances when an error occurs."""
# Mock subprocess exception
mock_subprocess.side_effect = Exception("Command failed")
result = await vlc_service.stop_all_vlc_instances()
assert result["success"] is False
assert result["processes_found"] == 0
assert result["processes_killed"] == 0
@@ -286,16 +285,16 @@ class TestVLCPlayerService:
with patch("app.services.vlc_player.VLCPlayerService") as mock_service_class:
mock_instance = Mock()
mock_service_class.return_value = mock_instance
# Clear the global instance
import app.services.vlc_player
app.services.vlc_player.vlc_player_service = None
# First call should create new instance
service1 = get_vlc_player_service()
assert service1 == mock_instance
mock_service_class.assert_called_once()
# Second call should return same instance
service2 = get_vlc_player_service()
assert service2 == mock_instance
@@ -306,71 +305,70 @@ class TestVLCPlayerService:
@pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_with_play_count_tracking(
self, mock_subprocess, vlc_service_with_db, sample_sound
self, mock_subprocess, vlc_service_with_db, sample_sound,
):
"""Test sound playback with play count tracking."""
# Mock subprocess
mock_process = Mock()
mock_process.pid = 12345
mock_subprocess.return_value = mock_process
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
# Mock repositories
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket:
with patch("app.services.vlc_player.select") as mock_select:
# Mock the file path utility
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock()
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
# Mock sound repository responses
updated_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
play_count=1, # Updated count
)
mock_sound_repo.get_by_id.return_value = sample_sound
mock_sound_repo.update.return_value = updated_sound
# Mock admin user
admin_user = User(
id=1,
email="admin@test.com",
name="Admin User",
role="admin",
)
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
result = await vlc_service_with_db.play_sound(sample_sound)
# Wait a bit for the async task to complete
await asyncio.sleep(0.1)
assert result is True
# Verify subprocess was called
mock_subprocess.assert_called_once()
# Note: The async task runs in the background, so we can't easily
# verify the database operations in this test without more complex
# mocking or using a real async test framework setup
# Mock the file path utility
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path:
mock_path = Mock()
mock_path.exists.return_value = True
mock_get_path.return_value = mock_path
# Mock sound repository responses
updated_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
play_count=1, # Updated count
)
mock_sound_repo.get_by_id.return_value = sample_sound
mock_sound_repo.update.return_value = updated_sound
# Mock admin user
admin_user = User(
id=1,
email="admin@test.com",
name="Admin User",
role="admin",
)
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
result = await vlc_service_with_db.play_sound(sample_sound)
# Wait a bit for the async task to complete
await asyncio.sleep(0.1)
assert result is True
# Verify subprocess was called
mock_subprocess.assert_called_once()
# Note: The async task runs in the background, so we can't easily
# verify the database operations in this test without more complex
# mocking or using a real async test framework setup
@pytest.mark.asyncio
async def test_record_play_count_success(self, vlc_service_with_db):
@@ -378,10 +376,10 @@ class TestVLCPlayerService:
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
# Create test sound and user
test_sound = Sound(
id=1,
@@ -399,44 +397,43 @@ class TestVLCPlayerService:
name="Admin User",
role="admin",
)
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket:
with patch("app.services.vlc_player.select") as mock_select:
# Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1)
mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 1}
)
# Verify user repository calls
mock_user_repo.get_by_id.assert_called_once_with(1)
# Verify session operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.close.assert_called_once()
# Verify socket broadcast
mock_socket.broadcast_to_all.assert_called_once_with(
"sound_played",
{
"sound_id": 1,
"sound_name": "Test Sound",
"user_id": 1,
"play_count": 1,
},
)
# Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1)
mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 1},
)
# Verify user repository calls
mock_user_repo.get_by_id.assert_called_once_with(1)
# Verify session operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.close.assert_called_once()
# Verify socket broadcast
mock_socket.broadcast_to_all.assert_called_once_with(
"sound_played",
{
"sound_id": 1,
"sound_name": "Test Sound",
"user_id": 1,
"play_count": 1,
},
)
@pytest.mark.asyncio
async def test_record_play_count_no_session_factory(self, vlc_service):
@@ -451,10 +448,10 @@ class TestVLCPlayerService:
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
# Create test sound and user
test_sound = Sound(
id=1,
@@ -472,27 +469,27 @@ class TestVLCPlayerService:
name="Admin User",
role="admin",
)
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo):
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo):
with patch("app.services.vlc_player.socket_manager") as mock_socket:
# Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 6}
test_sound, {"play_count": 6},
)
# Verify new SoundPlayed record was always added
mock_session.add.assert_called_once()
# Verify commit happened
mock_session.commit.assert_called_once()
@@ -502,10 +499,10 @@ class TestVLCPlayerService:
mock_file_path = Mock(spec=Path)
mock_file_path.exists.return_value = False # File doesn't exist
mock_path.return_value = mock_file_path
# This should fail because file doesn't exist
result = asyncio.run(vlc_service.play_sound(sample_sound))
# Verify the utility was called and returned False
mock_path.assert_called_once_with(sample_sound)
assert result is False
assert result is False

View File

@@ -8,7 +8,12 @@ from unittest.mock import patch
import pytest
from app.models.sound import Sound
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size, get_sound_file_path
from app.utils.audio import (
get_audio_duration,
get_file_hash,
get_file_size,
get_sound_file_path,
)
class TestAudioUtils:
@@ -301,7 +306,7 @@ class TestAudioUtils:
type="SDB",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/soundboard/test.mp3")
assert result == expected
@@ -310,13 +315,13 @@ class TestAudioUtils:
"""Test getting sound file path for SDB type normalized file."""
sound = Sound(
id=1,
name="Test Sound",
name="Test Sound",
filename="original.mp3",
normalized_filename="normalized.mp3",
type="SDB",
is_normalized=True,
)
result = get_sound_file_path(sound)
expected = Path("sounds/normalized/soundboard/normalized.mp3")
assert result == expected
@@ -326,11 +331,11 @@ class TestAudioUtils:
sound = Sound(
id=2,
name="TTS Sound",
filename="tts_file.wav",
filename="tts_file.wav",
type="TTS",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/text_to_speech/tts_file.wav")
assert result == expected
@@ -342,10 +347,10 @@ class TestAudioUtils:
name="TTS Sound",
filename="original.wav",
normalized_filename="normalized.mp3",
type="TTS",
type="TTS",
is_normalized=True,
)
result = get_sound_file_path(sound)
expected = Path("sounds/normalized/text_to_speech/normalized.mp3")
assert result == expected
@@ -359,7 +364,7 @@ class TestAudioUtils:
type="EXT",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/extracted/extracted.mp3")
assert result == expected
@@ -370,11 +375,11 @@ class TestAudioUtils:
id=3,
name="Extracted Sound",
filename="original.mp3",
normalized_filename="normalized.mp3",
normalized_filename="normalized.mp3",
type="EXT",
is_normalized=True,
)
result = get_sound_file_path(sound)
expected = Path("sounds/normalized/extracted/normalized.mp3")
assert result == expected
@@ -388,7 +393,7 @@ class TestAudioUtils:
type="CUSTOM",
is_normalized=False,
)
result = get_sound_file_path(sound)
expected = Path("sounds/originals/custom/unknown.mp3")
assert result == expected
@@ -403,7 +408,7 @@ class TestAudioUtils:
type="SDB",
is_normalized=True, # True but no normalized_filename
)
result = get_sound_file_path(sound)
# Should fall back to original file
expected = Path("sounds/originals/soundboard/original.mp3")

View File

@@ -1,12 +1,16 @@
"""Tests for credit decorators."""
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock
import pytest
from app.models.credit_action import CreditActionType
from app.services.credit import CreditService, InsufficientCreditsError
from app.utils.credit_decorators import CreditManager, requires_credits, validate_credits_only
from app.utils.credit_decorators import (
CreditManager,
requires_credits,
validate_credits_only,
)
class TestRequiresCreditsDecorator:
@@ -32,7 +36,7 @@ class TestRequiresCreditsDecorator:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
user_id_param="user_id",
)
async def test_action(user_id: int, message: str) -> str:
return f"Success: {message}"
@@ -41,10 +45,10 @@ class TestRequiresCreditsDecorator:
assert result == "Success: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None
123, CreditActionType.VLC_PLAY_SOUND, None,
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, None
123, CreditActionType.VLC_PLAY_SOUND, True, None,
)
@pytest.mark.asyncio
@@ -58,7 +62,7 @@ class TestRequiresCreditsDecorator:
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id",
metadata_extractor=extract_metadata
metadata_extractor=extract_metadata,
)
async def test_action(user_id: int, sound_name: str) -> bool:
return True
@@ -66,10 +70,10 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123, sound_name="test.mp3")
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"}
123, CreditActionType.VLC_PLAY_SOUND, {"sound_name": "test.mp3"},
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"}
123, CreditActionType.VLC_PLAY_SOUND, True, {"sound_name": "test.mp3"},
)
@pytest.mark.asyncio
@@ -79,7 +83,7 @@ class TestRequiresCreditsDecorator:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
user_id_param="user_id",
)
async def test_action(user_id: int) -> bool:
return False # Action fails
@@ -88,7 +92,7 @@ class TestRequiresCreditsDecorator:
assert result is False
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
123, CreditActionType.VLC_PLAY_SOUND, False, None,
)
@pytest.mark.asyncio
@@ -98,7 +102,7 @@ class TestRequiresCreditsDecorator:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
user_id_param="user_id",
)
async def test_action(user_id: int) -> str:
raise ValueError("Test error")
@@ -107,7 +111,7 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
123, CreditActionType.VLC_PLAY_SOUND, False, None,
)
@pytest.mark.asyncio
@@ -118,7 +122,7 @@ class TestRequiresCreditsDecorator:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
user_id_param="user_id",
)
async def test_action(user_id: int) -> str:
return "Should not execute"
@@ -136,7 +140,7 @@ class TestRequiresCreditsDecorator:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
user_id_param="user_id",
)
async def test_action(user_id: int, message: str) -> str:
return message
@@ -145,7 +149,7 @@ class TestRequiresCreditsDecorator:
assert result == "test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, None
123, CreditActionType.VLC_PLAY_SOUND, None,
)
@pytest.mark.asyncio
@@ -155,7 +159,7 @@ class TestRequiresCreditsDecorator:
@requires_credits(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
user_id_param="user_id",
)
async def test_action(other_param: str) -> str:
return other_param
@@ -186,7 +190,7 @@ class TestValidateCreditsOnlyDecorator:
@validate_credits_only(
CreditActionType.VLC_PLAY_SOUND,
credit_service_factory,
user_id_param="user_id"
user_id_param="user_id",
)
async def test_action(user_id: int, message: str) -> str:
return f"Validated: {message}"
@@ -195,7 +199,7 @@ class TestValidateCreditsOnlyDecorator:
assert result == "Validated: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND
123, CreditActionType.VLC_PLAY_SOUND,
)
# Should not deduct credits, only validate
mock_credit_service.deduct_credits.assert_not_called()
@@ -219,15 +223,15 @@ class TestCreditManager:
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND,
{"test": "data"}
{"test": "data"},
) as manager:
manager.mark_success()
mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"}
123, CreditActionType.VLC_PLAY_SOUND, {"test": "data"},
)
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"}
123, CreditActionType.VLC_PLAY_SOUND, True, {"test": "data"},
)
@pytest.mark.asyncio
@@ -236,13 +240,13 @@ class TestCreditManager:
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
CreditActionType.VLC_PLAY_SOUND,
):
# Don't mark as success - should be considered failed
pass
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
123, CreditActionType.VLC_PLAY_SOUND, False, None,
)
@pytest.mark.asyncio
@@ -252,12 +256,12 @@ class TestCreditManager:
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
CreditActionType.VLC_PLAY_SOUND,
):
raise ValueError("Test error")
mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, False, None
123, CreditActionType.VLC_PLAY_SOUND, False, None,
)
@pytest.mark.asyncio
@@ -269,9 +273,9 @@ class TestCreditManager:
async with CreditManager(
mock_credit_service,
123,
CreditActionType.VLC_PLAY_SOUND
CreditActionType.VLC_PLAY_SOUND,
):
pass
# Should not call deduct_credits since validation failed
mock_credit_service.deduct_credits.assert_not_called()
mock_credit_service.deduct_credits.assert_not_called()