Compare commits
2 Commits
c13285ca4e
...
b8346ab667
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8346ab667 | ||
|
|
f24698e3ff |
@@ -20,6 +20,11 @@ from app.schemas.auth import (
|
||||
)
|
||||
from app.services.oauth import OAuthUserInfo
|
||||
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
|
||||
from app.utils.exceptions import (
|
||||
raise_bad_request,
|
||||
raise_not_found,
|
||||
raise_unauthorized,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -39,10 +44,7 @@ class AuthService:
|
||||
|
||||
# Check if email already exists
|
||||
if await self.user_repo.email_exists(request.email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email address is already registered",
|
||||
)
|
||||
raise_bad_request("Email address is already registered")
|
||||
|
||||
# Hash the password
|
||||
hashed_password = PasswordUtils.hash_password(request.password)
|
||||
@@ -75,27 +77,18 @@ class AuthService:
|
||||
# Get user by email
|
||||
user = await self.user_repo.get_by_email(request.email)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
raise_unauthorized("Invalid email or password")
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
raise_unauthorized("Account is deactivated")
|
||||
|
||||
# Verify password
|
||||
if not user.password_hash or not PasswordUtils.verify_password(
|
||||
request.password,
|
||||
user.password_hash,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
raise_unauthorized("Invalid email or password")
|
||||
|
||||
# Generate access token
|
||||
token = self._create_access_token(user)
|
||||
@@ -110,16 +103,10 @@ class AuthService:
|
||||
"""Get the current authenticated user."""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
raise_not_found("User")
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
raise_unauthorized("Account is deactivated")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -10,6 +10,11 @@ from app.models.playlist import Playlist
|
||||
from app.models.sound import Sound
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.utils.exceptions import (
|
||||
raise_bad_request,
|
||||
raise_internal_server_error,
|
||||
raise_not_found,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -27,10 +32,7 @@ class PlaylistService:
|
||||
"""Get a playlist by ID."""
|
||||
playlist = await self.playlist_repo.get_by_id(playlist_id)
|
||||
if not playlist:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Playlist not found",
|
||||
)
|
||||
raise_not_found("Playlist")
|
||||
|
||||
return playlist
|
||||
|
||||
@@ -47,9 +49,8 @@ class PlaylistService:
|
||||
main_playlist = await self.playlist_repo.get_main_playlist()
|
||||
|
||||
if not main_playlist:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Main playlist not found. Make sure to run database seeding."
|
||||
raise_internal_server_error(
|
||||
"Main playlist not found. Make sure to run database seeding."
|
||||
)
|
||||
|
||||
return main_playlist
|
||||
|
||||
166
app/utils/database.py
Normal file
166
app/utils/database.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Database utility functions for common operations."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar
|
||||
from sqlmodel import select, SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
T = TypeVar("T", bound=SQLModel)
|
||||
|
||||
|
||||
async def create_and_save(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
**kwargs: Any
|
||||
) -> T:
|
||||
"""Create, add, commit, and refresh a model instance.
|
||||
|
||||
This consolidates the common database pattern of:
|
||||
- instance = ModelClass(**kwargs)
|
||||
- session.add(instance)
|
||||
- await session.commit()
|
||||
- await session.refresh(instance)
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class to instantiate
|
||||
**kwargs: Arguments to pass to model constructor
|
||||
|
||||
Returns:
|
||||
Created and refreshed model instance
|
||||
"""
|
||||
instance = model_class(**kwargs)
|
||||
session.add(instance)
|
||||
await session.commit()
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
async def get_or_create(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
defaults: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> tuple[T, bool]:
|
||||
"""Get an existing instance or create a new one.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class
|
||||
defaults: Default values for creation (if not found)
|
||||
**kwargs: Filter criteria for lookup
|
||||
|
||||
Returns:
|
||||
Tuple of (instance, created) where created is True if instance was created
|
||||
"""
|
||||
# Build filter conditions
|
||||
filters = []
|
||||
for key, value in kwargs.items():
|
||||
filters.append(getattr(model_class, key) == value)
|
||||
|
||||
# Try to find existing instance
|
||||
statement = select(model_class).where(*filters)
|
||||
result = await session.exec(statement)
|
||||
instance = result.first()
|
||||
|
||||
if instance:
|
||||
return instance, False
|
||||
|
||||
# Create new instance
|
||||
create_kwargs = {**kwargs}
|
||||
if defaults:
|
||||
create_kwargs.update(defaults)
|
||||
|
||||
instance = await create_and_save(session, model_class, **create_kwargs)
|
||||
return instance, True
|
||||
|
||||
|
||||
async def update_and_save(
|
||||
session: AsyncSession,
|
||||
instance: T,
|
||||
**updates: Any
|
||||
) -> T:
|
||||
"""Update model instance fields and save to database.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
instance: Model instance to update
|
||||
**updates: Field updates to apply
|
||||
|
||||
Returns:
|
||||
Updated and refreshed model instance
|
||||
"""
|
||||
for field, value in updates.items():
|
||||
setattr(instance, field, value)
|
||||
|
||||
session.add(instance)
|
||||
await session.commit()
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
async def bulk_create(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
items: List[Dict[str, Any]]
|
||||
) -> List[T]:
|
||||
"""Create multiple model instances in bulk.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class to instantiate
|
||||
items: List of dictionaries with model data
|
||||
|
||||
Returns:
|
||||
List of created model instances
|
||||
"""
|
||||
instances = []
|
||||
for item_data in items:
|
||||
instance = model_class(**item_data)
|
||||
session.add(instance)
|
||||
instances.append(instance)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Refresh all instances
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
|
||||
return instances
|
||||
|
||||
|
||||
async def delete_and_commit(
|
||||
session: AsyncSession,
|
||||
instance: T
|
||||
) -> None:
|
||||
"""Delete an instance and commit the transaction.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
instance: Model instance to delete
|
||||
"""
|
||||
await session.delete(instance)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def exists(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
**kwargs: Any
|
||||
) -> bool:
|
||||
"""Check if a model instance exists with given criteria.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class
|
||||
**kwargs: Filter criteria
|
||||
|
||||
Returns:
|
||||
True if instance exists, False otherwise
|
||||
"""
|
||||
filters = []
|
||||
for key, value in kwargs.items():
|
||||
filters.append(getattr(model_class, key) == value)
|
||||
|
||||
statement = select(model_class).where(*filters)
|
||||
result = await session.exec(statement)
|
||||
return result.first() is not None
|
||||
121
app/utils/exceptions.py
Normal file
121
app/utils/exceptions.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Utility functions for common HTTP exception patterns."""
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def raise_not_found(resource: str, identifier: str = None) -> None:
|
||||
"""Raise a standardized 404 Not Found exception.
|
||||
|
||||
Args:
|
||||
resource: Name of the resource that wasn't found
|
||||
identifier: Optional identifier for the specific resource
|
||||
|
||||
Raises:
|
||||
HTTPException with 404 status code
|
||||
"""
|
||||
if identifier:
|
||||
detail = f"{resource} with ID {identifier} not found"
|
||||
else:
|
||||
detail = f"{resource} not found"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_unauthorized(detail: str = "Could not validate credentials") -> None:
|
||||
"""Raise a standardized 401 Unauthorized exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 401 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_bad_request(detail: str) -> None:
|
||||
"""Raise a standardized 400 Bad Request exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 400 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_internal_server_error(detail: str, cause: Exception = None) -> None:
|
||||
"""Raise a standardized 500 Internal Server Error exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
cause: Optional underlying exception
|
||||
|
||||
Raises:
|
||||
HTTPException with 500 status code
|
||||
"""
|
||||
if cause:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail,
|
||||
) from cause
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_payment_required(detail: str = "Insufficient credits") -> None:
|
||||
"""Raise a standardized 402 Payment Required exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 402 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_forbidden(detail: str = "Access forbidden") -> None:
|
||||
"""Raise a standardized 403 Forbidden exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 403 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def raise_conflict(detail: str) -> None:
|
||||
"""Raise a standardized 409 Conflict exception.
|
||||
|
||||
Args:
|
||||
detail: Error message detail
|
||||
|
||||
Raises:
|
||||
HTTPException with 409 status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=detail,
|
||||
)
|
||||
179
app/utils/test_helpers.py
Normal file
179
app/utils/test_helpers.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Test helper utilities for reducing code duplication."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Optional, Type, TypeVar
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from app.models.user import User
|
||||
from app.utils.auth import JWTUtils
|
||||
|
||||
T = TypeVar("T", bound=SQLModel)
|
||||
|
||||
|
||||
def create_jwt_token_data(user: User) -> Dict[str, str]:
|
||||
"""Create standardized JWT token data dictionary for a user.
|
||||
|
||||
Args:
|
||||
user: User object to create token data for
|
||||
|
||||
Returns:
|
||||
Dictionary with sub, email, and role fields
|
||||
"""
|
||||
return {
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
|
||||
def create_access_token_for_user(user: User) -> str:
|
||||
"""Create an access token for a user using standardized token data.
|
||||
|
||||
Args:
|
||||
user: User object to create token for
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
token_data = create_jwt_token_data(user)
|
||||
return JWTUtils.create_access_token(token_data)
|
||||
|
||||
|
||||
async def create_and_save_model(
|
||||
session: AsyncSession,
|
||||
model_class: Type[T],
|
||||
**kwargs: Any
|
||||
) -> T:
|
||||
"""Create, save, and refresh a model instance.
|
||||
|
||||
This consolidates the common pattern of:
|
||||
- model = ModelClass(**kwargs)
|
||||
- session.add(model)
|
||||
- await session.commit()
|
||||
- await session.refresh(model)
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
model_class: SQLModel class to instantiate
|
||||
**kwargs: Arguments to pass to model constructor
|
||||
|
||||
Returns:
|
||||
Created and refreshed model instance
|
||||
"""
|
||||
instance = model_class(**kwargs)
|
||||
session.add(instance)
|
||||
await session.commit()
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def override_dependencies(
|
||||
app: FastAPI,
|
||||
overrides: Dict[Any, Any]
|
||||
):
|
||||
"""Context manager for FastAPI dependency overrides with automatic cleanup.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
overrides: Dictionary mapping dependency functions to mock implementations
|
||||
|
||||
Usage:
|
||||
async with override_dependencies(test_app, {
|
||||
get_service: lambda: mock_service,
|
||||
get_repo: lambda: mock_repo
|
||||
}):
|
||||
# Test code here
|
||||
pass
|
||||
# Dependencies automatically cleaned up
|
||||
"""
|
||||
# Apply overrides
|
||||
for dependency, override in overrides.items():
|
||||
app.dependency_overrides[dependency] = override
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Clean up overrides
|
||||
for dependency in overrides:
|
||||
app.dependency_overrides.pop(dependency, None)
|
||||
|
||||
|
||||
def create_mock_vlc_services() -> Dict[str, AsyncMock]:
|
||||
"""Create standard set of mocked VLC-related services.
|
||||
|
||||
Returns:
|
||||
Dictionary with mocked vlc_service, sound_repository, and credit_service
|
||||
"""
|
||||
return {
|
||||
"vlc_service": AsyncMock(),
|
||||
"sound_repository": AsyncMock(),
|
||||
"credit_service": AsyncMock(),
|
||||
}
|
||||
|
||||
|
||||
def configure_mock_sound_play_success(
|
||||
mocks: Dict[str, AsyncMock],
|
||||
sound_data: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Configure mocks for successful sound playback scenario.
|
||||
|
||||
Args:
|
||||
mocks: Dictionary of mock services from create_mock_vlc_services()
|
||||
sound_data: Dictionary with sound properties (id, name, etc.)
|
||||
"""
|
||||
from app.models.sound import Sound
|
||||
|
||||
mock_sound = Sound(**sound_data)
|
||||
|
||||
# Configure repository mock
|
||||
mocks["sound_repository"].get_by_id.return_value = mock_sound
|
||||
|
||||
# Configure credit service mocks
|
||||
mocks["credit_service"].validate_and_reserve_credits.return_value = None
|
||||
mocks["credit_service"].deduct_credits.return_value = None
|
||||
|
||||
# Configure VLC service mock
|
||||
mocks["vlc_service"].play_sound.return_value = True
|
||||
|
||||
|
||||
def create_mock_vlc_stop_result(
|
||||
success: bool = True,
|
||||
processes_found: int = 3,
|
||||
processes_killed: int = 3,
|
||||
processes_remaining: int = 0,
|
||||
message: Optional[str] = None,
|
||||
error: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Create standardized VLC stop operation result.
|
||||
|
||||
Args:
|
||||
success: Whether operation succeeded
|
||||
processes_found: Number of VLC processes found
|
||||
processes_killed: Number of processes successfully killed
|
||||
processes_remaining: Number of processes still running
|
||||
message: Success/status message
|
||||
error: Error message (for failed operations)
|
||||
|
||||
Returns:
|
||||
Dictionary with VLC stop operation result
|
||||
"""
|
||||
result = {
|
||||
"success": success,
|
||||
"processes_found": processes_found,
|
||||
"processes_killed": processes_killed,
|
||||
}
|
||||
|
||||
if not success:
|
||||
result["error"] = error or "Command failed"
|
||||
result["message"] = message or "Failed to stop VLC processes"
|
||||
else:
|
||||
# Always include processes_remaining for successful operations
|
||||
result["processes_remaining"] = processes_remaining
|
||||
result["message"] = message or f"Killed {processes_killed} VLC processes"
|
||||
|
||||
return result
|
||||
140
app/utils/validation.py
Normal file
140
app/utils/validation.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Common validation utility functions."""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
# Password validation constants
|
||||
MIN_PASSWORD_LENGTH = 8
|
||||
|
||||
|
||||
def validate_email(email: str) -> bool:
|
||||
"""Validate email address format.
|
||||
|
||||
Args:
|
||||
email: Email address to validate
|
||||
|
||||
Returns:
|
||||
True if email format is valid, False otherwise
|
||||
"""
|
||||
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
def validate_password_strength(password: str) -> tuple[bool, str | None]:
|
||||
"""Validate password meets security requirements.
|
||||
|
||||
Args:
|
||||
password: Password to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
if len(password) < MIN_PASSWORD_LENGTH:
|
||||
msg = f"Password must be at least {MIN_PASSWORD_LENGTH} characters long"
|
||||
return False, msg
|
||||
|
||||
if not re.search(r"[A-Z]", password):
|
||||
return False, "Password must contain at least one uppercase letter"
|
||||
|
||||
if not re.search(r"[a-z]", password):
|
||||
return False, "Password must contain at least one lowercase letter"
|
||||
|
||||
if not re.search(r"\d", password):
|
||||
return False, "Password must contain at least one number"
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def validate_filename(
|
||||
filename: str, allowed_extensions: list[str] | None = None
|
||||
) -> bool:
|
||||
"""Validate filename format and extension.
|
||||
|
||||
Args:
|
||||
filename: Filename to validate
|
||||
allowed_extensions: List of allowed file extensions (with dots)
|
||||
|
||||
Returns:
|
||||
True if filename is valid, False otherwise
|
||||
"""
|
||||
if not filename or filename.startswith(".") or "/" in filename or "\\" in filename:
|
||||
return False
|
||||
|
||||
if allowed_extensions:
|
||||
file_path = Path(filename)
|
||||
return file_path.suffix.lower() in [ext.lower() for ext in allowed_extensions]
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def validate_audio_filename(filename: str) -> bool:
|
||||
"""Validate audio filename has allowed extension.
|
||||
|
||||
Args:
|
||||
filename: Audio filename to validate
|
||||
|
||||
Returns:
|
||||
True if filename has valid audio extension, False otherwise
|
||||
"""
|
||||
audio_extensions = [".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"]
|
||||
return validate_filename(filename, audio_extensions)
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""Sanitize filename by removing/replacing invalid characters.
|
||||
|
||||
Args:
|
||||
filename: Filename to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized filename safe for filesystem
|
||||
"""
|
||||
# Remove or replace problematic characters
|
||||
sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
|
||||
|
||||
# Remove leading/trailing whitespace and dots
|
||||
sanitized = sanitized.strip(" .")
|
||||
|
||||
# Ensure not empty
|
||||
if not sanitized:
|
||||
sanitized = "untitled"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
|
||||
"""Validate URL format.
|
||||
|
||||
Args:
|
||||
url: URL to validate
|
||||
|
||||
Returns:
|
||||
True if URL format is valid, False otherwise
|
||||
"""
|
||||
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
|
||||
return bool(re.match(pattern, url))
|
||||
|
||||
|
||||
def validate_positive_integer(value: Any, field_name: str = "value") -> int:
|
||||
"""Validate and convert value to positive integer.
|
||||
|
||||
Args:
|
||||
value: Value to validate and convert
|
||||
field_name: Name of field for error messages
|
||||
|
||||
Returns:
|
||||
Validated positive integer
|
||||
|
||||
Raises:
|
||||
ValueError: If value is not a positive integer
|
||||
"""
|
||||
try:
|
||||
int_value = int(value)
|
||||
if int_value <= 0:
|
||||
msg = f"{field_name} must be a positive integer"
|
||||
raise ValueError(msg)
|
||||
return int_value
|
||||
except (TypeError, ValueError) as e:
|
||||
msg = f"{field_name} must be a positive integer"
|
||||
raise ValueError(msg) from e
|
||||
@@ -516,10 +516,9 @@ class TestPlayerEndpoints:
|
||||
"status": PlayerStatus.PLAYING.value,
|
||||
"mode": PlayerMode.CONTINUOUS.value,
|
||||
"volume": 50,
|
||||
"current_sound_id": 1,
|
||||
"current_sound_index": 0,
|
||||
"current_sound_position": 5000,
|
||||
"current_sound_duration": 30000,
|
||||
"position_ms": 5000,
|
||||
"duration_ms": 30000,
|
||||
"index": 0,
|
||||
"current_sound": {
|
||||
"id": 1,
|
||||
"name": "Test Song",
|
||||
@@ -530,11 +529,13 @@ class TestPlayerEndpoints:
|
||||
"thumbnail": None,
|
||||
"play_count": 0,
|
||||
},
|
||||
"playlist_id": 1,
|
||||
"playlist_name": "Test Playlist",
|
||||
"playlist_length": 1,
|
||||
"playlist_duration": 30000,
|
||||
"playlist_sounds": [],
|
||||
"playlist": {
|
||||
"id": 1,
|
||||
"name": "Test Playlist",
|
||||
"length": 1,
|
||||
"duration": 30000,
|
||||
"sounds": [],
|
||||
},
|
||||
}
|
||||
mock_player_service.get_state.return_value = mock_state
|
||||
|
||||
|
||||
@@ -1,12 +1,22 @@
|
||||
"""Tests for VLC player API endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.models.sound import Sound
|
||||
from app.models.user import User
|
||||
from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service
|
||||
from app.utils.test_helpers import (
|
||||
override_dependencies,
|
||||
create_mock_vlc_services,
|
||||
configure_mock_sound_play_success,
|
||||
create_mock_vlc_stop_result,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
class TestVLCEndpoints:
|
||||
@@ -15,101 +25,158 @@ class TestVLCEndpoints:
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_success(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test successful sound playback via VLC."""
|
||||
# Mock the VLC player service and sound repository methods
|
||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
||||
mock_play_sound.return_value = True
|
||||
# Set up mocks using helper
|
||||
mocks = create_mock_vlc_services()
|
||||
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
)
|
||||
mock_get_by_id.return_value = mock_sound
|
||||
# Configure mocks for successful play
|
||||
sound_data = {
|
||||
"id": 1,
|
||||
"type": "SDB",
|
||||
"name": "Test Sound",
|
||||
"filename": "test.mp3",
|
||||
"duration": 5000,
|
||||
"size": 1024,
|
||||
"hash": "test_hash",
|
||||
}
|
||||
configure_mock_sound_play_success(mocks, sound_data)
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
# Use dependency override helper
|
||||
overrides = {
|
||||
get_vlc_player: lambda: mocks["vlc_service"],
|
||||
get_sound_repository: lambda: mocks["sound_repository"],
|
||||
get_credit_service: lambda: mocks["credit_service"],
|
||||
}
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["sound_id"] == 1
|
||||
assert data["sound_name"] == "Test Sound"
|
||||
assert "Test Sound" in data["message"]
|
||||
async with override_dependencies(test_app, overrides):
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
# Verify service calls
|
||||
mock_get_by_id.assert_called_once_with(1)
|
||||
mock_play_sound.assert_called_once_with(mock_sound)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["sound_id"] == 1
|
||||
assert data["sound_name"] == "Test Sound"
|
||||
assert "Test Sound" in data["message"]
|
||||
|
||||
# Verify service calls
|
||||
mocks["sound_repository"].get_by_id.assert_called_once_with(1)
|
||||
mocks["vlc_service"].play_sound.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_sound_not_found(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test VLC playback when sound is not found."""
|
||||
# Mock the sound repository to return None
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_get_by_id.return_value = None
|
||||
# Set up mocks
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_repo = AsyncMock()
|
||||
mock_credit_service = AsyncMock()
|
||||
|
||||
# Configure mocks
|
||||
mock_repo.get_by_id.return_value = None
|
||||
|
||||
# Override dependencies
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||
|
||||
try:
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
|
||||
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "Sound with ID 999 not found" in data["detail"]
|
||||
finally:
|
||||
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_launch_failure(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test VLC playback when VLC launch fails."""
|
||||
# Mock the VLC player service to fail
|
||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
||||
mock_play_sound.return_value = False
|
||||
# Set up mocks
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_repo = AsyncMock()
|
||||
mock_credit_service = AsyncMock()
|
||||
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
)
|
||||
mock_get_by_id.return_value = mock_sound
|
||||
# Set up test data
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Test Sound",
|
||||
filename="test.mp3",
|
||||
duration=5000,
|
||||
size=1024,
|
||||
hash="test_hash",
|
||||
)
|
||||
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
# Configure mocks
|
||||
mock_repo.get_by_id.return_value = mock_sound
|
||||
mock_credit_service.validate_and_reserve_credits.return_value = None
|
||||
mock_credit_service.deduct_credits.return_value = None
|
||||
mock_vlc_service.play_sound.return_value = False
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.json()
|
||||
assert "Failed to launch VLC for sound playback" in data["detail"]
|
||||
# Override dependencies
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||
|
||||
try:
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.json()
|
||||
assert "Failed to launch VLC for sound playback" in data["detail"]
|
||||
finally:
|
||||
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_service_exception(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test VLC playback when service raises an exception."""
|
||||
# Mock the sound repository to raise an exception
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_get_by_id.side_effect = Exception("Database error")
|
||||
# Set up mocks
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_repo = AsyncMock()
|
||||
mock_credit_service = AsyncMock()
|
||||
|
||||
# Configure mocks
|
||||
mock_repo.get_by_id.side_effect = Exception("Database error")
|
||||
|
||||
# Override dependencies
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||
|
||||
try:
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.json()
|
||||
assert "Failed to play sound" in data["detail"]
|
||||
finally:
|
||||
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_play_sound_with_vlc_unauthenticated(
|
||||
@@ -123,21 +190,25 @@ class TestVLCEndpoints:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_success(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test successful stopping of all VLC instances."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 3,
|
||||
"processes_killed": 3,
|
||||
"processes_remaining": 0,
|
||||
"message": "Killed 3 VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
# Set up mock using helper
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_result = create_mock_vlc_stop_result(
|
||||
success=True,
|
||||
processes_found=3,
|
||||
processes_killed=3,
|
||||
processes_remaining=0
|
||||
)
|
||||
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||
|
||||
# Use dependency override helper
|
||||
overrides = {get_vlc_player: lambda: mock_vlc_service}
|
||||
|
||||
async with override_dependencies(test_app, overrides):
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -149,25 +220,30 @@ class TestVLCEndpoints:
|
||||
assert "Killed 3 VLC processes" in data["message"]
|
||||
|
||||
# Verify service call
|
||||
mock_stop_all.assert_called_once()
|
||||
mock_vlc_service.stop_all_vlc_instances.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_no_processes(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances when none are running."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"message": "No VLC processes found",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
# Set up mock using helper
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_result = create_mock_vlc_stop_result(
|
||||
success=True,
|
||||
processes_found=0,
|
||||
processes_killed=0,
|
||||
message="No VLC processes found"
|
||||
)
|
||||
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||
|
||||
# Use dependency override helper
|
||||
overrides = {get_vlc_player: lambda: mock_vlc_service}
|
||||
|
||||
async with override_dependencies(test_app, overrides):
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -180,21 +256,26 @@ class TestVLCEndpoints:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_partial_success(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances with partial success."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 3,
|
||||
"processes_killed": 2,
|
||||
"processes_remaining": 1,
|
||||
"message": "Killed 2 VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
# Set up mock
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 3,
|
||||
"processes_killed": 2,
|
||||
"processes_remaining": 1,
|
||||
"message": "Killed 2 VLC processes",
|
||||
}
|
||||
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||
|
||||
# Override dependency
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||
|
||||
try:
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -203,25 +284,33 @@ class TestVLCEndpoints:
|
||||
assert data["processes_found"] == 3
|
||||
assert data["processes_killed"] == 2
|
||||
assert data["processes_remaining"] == 1
|
||||
finally:
|
||||
# Clean up dependency override
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_failure(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances when service fails."""
|
||||
# Mock the VLC player service
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": False,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"error": "Command failed",
|
||||
"message": "Failed to stop VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
# Set up mock
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_result = {
|
||||
"success": False,
|
||||
"processes_found": 0,
|
||||
"processes_killed": 0,
|
||||
"error": "Command failed",
|
||||
"message": "Failed to stop VLC processes",
|
||||
}
|
||||
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||
|
||||
# Override dependency
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||
|
||||
try:
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -229,23 +318,34 @@ class TestVLCEndpoints:
|
||||
assert data["success"] is False
|
||||
assert data["error"] == "Command failed"
|
||||
assert data["message"] == "Failed to stop VLC processes"
|
||||
finally:
|
||||
# Clean up dependency override
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_service_exception(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_client: AsyncClient,
|
||||
authenticated_user: User,
|
||||
):
|
||||
"""Test stopping VLC instances when service raises an exception."""
|
||||
# Mock the VLC player service to raise an exception
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_stop_all.side_effect = Exception("Service error")
|
||||
# Set up mock to raise an exception
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_vlc_service.stop_all_vlc_instances.side_effect = Exception("Service error")
|
||||
|
||||
# Override dependency
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||
|
||||
try:
|
||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.json()
|
||||
assert "Failed to stop VLC instances" in data["detail"]
|
||||
finally:
|
||||
# Clean up dependency override
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all_vlc_instances_unauthenticated(
|
||||
@@ -259,47 +359,71 @@ class TestVLCEndpoints:
|
||||
@pytest.mark.asyncio
|
||||
async def test_vlc_endpoints_with_admin_user(
|
||||
self,
|
||||
test_app: FastAPI,
|
||||
authenticated_admin_client: AsyncClient,
|
||||
admin_user: User,
|
||||
):
|
||||
"""Test VLC endpoints work with admin user."""
|
||||
# Test play endpoint with admin
|
||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
||||
mock_play_sound.return_value = True
|
||||
# Set up mocks
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_repo = AsyncMock()
|
||||
mock_credit_service = AsyncMock()
|
||||
|
||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Admin Test Sound",
|
||||
filename="admin_test.mp3",
|
||||
duration=3000,
|
||||
size=512,
|
||||
hash="admin_hash",
|
||||
)
|
||||
mock_get_by_id.return_value = mock_sound
|
||||
# Set up test data
|
||||
mock_sound = Sound(
|
||||
id=1,
|
||||
type="SDB",
|
||||
name="Admin Test Sound",
|
||||
filename="admin_test.mp3",
|
||||
duration=3000,
|
||||
size=512,
|
||||
hash="admin_hash",
|
||||
)
|
||||
|
||||
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
|
||||
# Configure mocks
|
||||
mock_repo.get_by_id.return_value = mock_sound
|
||||
mock_credit_service.validate_and_reserve_credits.return_value = None
|
||||
mock_credit_service.deduct_credits.return_value = None
|
||||
mock_vlc_service.play_sound.return_value = True
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["sound_name"] == "Admin Test Sound"
|
||||
# Override dependencies
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||
|
||||
try:
|
||||
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["sound_name"] == "Admin Test Sound"
|
||||
finally:
|
||||
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||
|
||||
# Test stop-all endpoint with admin
|
||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 1,
|
||||
"processes_killed": 1,
|
||||
"processes_remaining": 0,
|
||||
"message": "Killed 1 VLC processes",
|
||||
}
|
||||
mock_stop_all.return_value = mock_result
|
||||
mock_vlc_service_2 = AsyncMock()
|
||||
mock_result = {
|
||||
"success": True,
|
||||
"processes_found": 1,
|
||||
"processes_killed": 1,
|
||||
"processes_remaining": 0,
|
||||
"message": "Killed 1 VLC processes",
|
||||
}
|
||||
mock_vlc_service_2.stop_all_vlc_instances.return_value = mock_result
|
||||
|
||||
# Override dependency for stop-all test
|
||||
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service_2
|
||||
|
||||
try:
|
||||
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["processes_killed"] == 1
|
||||
finally:
|
||||
# Clean up dependency override
|
||||
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||
@@ -18,6 +18,7 @@ from app.models.plan import Plan
|
||||
from app.models.user import User
|
||||
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
|
||||
from app.utils.auth import JWTUtils, PasswordUtils
|
||||
from app.utils.test_helpers import create_access_token_for_user
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -277,28 +278,14 @@ def test_login_data() -> dict[str, str]:
|
||||
@pytest_asyncio.fixture
|
||||
async def auth_headers(test_user: User) -> dict[str, str]:
|
||||
"""Create authentication headers with JWT token."""
|
||||
token_data = {
|
||||
"sub": str(test_user.id),
|
||||
"email": test_user.email,
|
||||
"role": test_user.role,
|
||||
}
|
||||
|
||||
access_token = JWTUtils.create_access_token(token_data)
|
||||
|
||||
access_token = create_access_token_for_user(test_user)
|
||||
return {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def admin_headers(admin_user: User) -> dict[str, str]:
|
||||
"""Create admin authentication headers with JWT token."""
|
||||
token_data = {
|
||||
"sub": str(admin_user.id),
|
||||
"email": admin_user.email,
|
||||
"role": admin_user.role,
|
||||
}
|
||||
|
||||
access_token = JWTUtils.create_access_token(token_data)
|
||||
|
||||
access_token = create_access_token_for_user(admin_user)
|
||||
return {"Authorization": f"Bearer {access_token}"}
|
||||
|
||||
|
||||
@@ -317,26 +304,12 @@ def authenticated_user(test_user: User) -> User:
|
||||
@pytest_asyncio.fixture
|
||||
async def auth_cookies(test_user: User) -> dict[str, str]:
|
||||
"""Create authentication cookies with JWT token."""
|
||||
token_data = {
|
||||
"sub": str(test_user.id),
|
||||
"email": test_user.email,
|
||||
"role": test_user.role,
|
||||
}
|
||||
|
||||
access_token = JWTUtils.create_access_token(token_data)
|
||||
|
||||
access_token = create_access_token_for_user(test_user)
|
||||
return {"access_token": access_token}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def admin_cookies(admin_user: User) -> dict[str, str]:
|
||||
"""Create admin authentication cookies with JWT token."""
|
||||
token_data = {
|
||||
"sub": str(admin_user.id),
|
||||
"email": admin_user.email,
|
||||
"role": admin_user.role,
|
||||
}
|
||||
|
||||
access_token = JWTUtils.create_access_token(token_data)
|
||||
|
||||
access_token = create_access_token_for_user(admin_user)
|
||||
return {"access_token": access_token}
|
||||
|
||||
@@ -11,6 +11,7 @@ from app.models.plan import Plan
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.utils.auth import PasswordUtils
|
||||
from app.utils.database import create_and_save
|
||||
|
||||
|
||||
class TestUserRepository:
|
||||
|
||||
Reference in New Issue
Block a user