Compare commits

..

2 Commits

10 changed files with 901 additions and 208 deletions

View File

@@ -20,6 +20,11 @@ from app.schemas.auth import (
) )
from app.services.oauth import OAuthUserInfo from app.services.oauth import OAuthUserInfo
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
from app.utils.exceptions import (
raise_bad_request,
raise_not_found,
raise_unauthorized,
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -39,10 +44,7 @@ class AuthService:
# Check if email already exists # Check if email already exists
if await self.user_repo.email_exists(request.email): if await self.user_repo.email_exists(request.email):
raise HTTPException( raise_bad_request("Email address is already registered")
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email address is already registered",
)
# Hash the password # Hash the password
hashed_password = PasswordUtils.hash_password(request.password) hashed_password = PasswordUtils.hash_password(request.password)
@@ -75,27 +77,18 @@ class AuthService:
# Get user by email # Get user by email
user = await self.user_repo.get_by_email(request.email) user = await self.user_repo.get_by_email(request.email)
if not user: if not user:
raise HTTPException( raise_unauthorized("Invalid email or password")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
# Check if user is active # Check if user is active
if not user.is_active: if not user.is_active:
raise HTTPException( raise_unauthorized("Account is deactivated")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
# Verify password # Verify password
if not user.password_hash or not PasswordUtils.verify_password( if not user.password_hash or not PasswordUtils.verify_password(
request.password, request.password,
user.password_hash, user.password_hash,
): ):
raise HTTPException( raise_unauthorized("Invalid email or password")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
# Generate access token # Generate access token
token = self._create_access_token(user) token = self._create_access_token(user)
@@ -110,16 +103,10 @@ class AuthService:
"""Get the current authenticated user.""" """Get the current authenticated user."""
user = await self.user_repo.get_by_id(user_id) user = await self.user_repo.get_by_id(user_id)
if not user: if not user:
raise HTTPException( raise_not_found("User")
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
if not user.is_active: if not user.is_active:
raise HTTPException( raise_unauthorized("Account is deactivated")
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
return user return user

View File

@@ -10,6 +10,11 @@ from app.models.playlist import Playlist
from app.models.sound import Sound from app.models.sound import Sound
from app.repositories.playlist import PlaylistRepository from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
from app.utils.exceptions import (
raise_bad_request,
raise_internal_server_error,
raise_not_found,
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -27,10 +32,7 @@ class PlaylistService:
"""Get a playlist by ID.""" """Get a playlist by ID."""
playlist = await self.playlist_repo.get_by_id(playlist_id) playlist = await self.playlist_repo.get_by_id(playlist_id)
if not playlist: if not playlist:
raise HTTPException( raise_not_found("Playlist")
status_code=status.HTTP_404_NOT_FOUND,
detail="Playlist not found",
)
return playlist return playlist
@@ -47,9 +49,8 @@ class PlaylistService:
main_playlist = await self.playlist_repo.get_main_playlist() main_playlist = await self.playlist_repo.get_main_playlist()
if not main_playlist: if not main_playlist:
raise HTTPException( raise_internal_server_error(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, "Main playlist not found. Make sure to run database seeding."
detail="Main playlist not found. Make sure to run database seeding."
) )
return main_playlist return main_playlist

166
app/utils/database.py Normal file
View File

@@ -0,0 +1,166 @@
"""Database utility functions for common operations."""
from typing import Any, Dict, List, Optional, Type, TypeVar
from sqlmodel import select, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
T = TypeVar("T", bound=SQLModel)
async def create_and_save(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> T:
"""Create, add, commit, and refresh a model instance.
This consolidates the common database pattern of:
- instance = ModelClass(**kwargs)
- session.add(instance)
- await session.commit()
- await session.refresh(instance)
Args:
session: Database session
model_class: SQLModel class to instantiate
**kwargs: Arguments to pass to model constructor
Returns:
Created and refreshed model instance
"""
instance = model_class(**kwargs)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
async def get_or_create(
session: AsyncSession,
model_class: Type[T],
defaults: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> tuple[T, bool]:
"""Get an existing instance or create a new one.
Args:
session: Database session
model_class: SQLModel class
defaults: Default values for creation (if not found)
**kwargs: Filter criteria for lookup
Returns:
Tuple of (instance, created) where created is True if instance was created
"""
# Build filter conditions
filters = []
for key, value in kwargs.items():
filters.append(getattr(model_class, key) == value)
# Try to find existing instance
statement = select(model_class).where(*filters)
result = await session.exec(statement)
instance = result.first()
if instance:
return instance, False
# Create new instance
create_kwargs = {**kwargs}
if defaults:
create_kwargs.update(defaults)
instance = await create_and_save(session, model_class, **create_kwargs)
return instance, True
async def update_and_save(
session: AsyncSession,
instance: T,
**updates: Any
) -> T:
"""Update model instance fields and save to database.
Args:
session: Database session
instance: Model instance to update
**updates: Field updates to apply
Returns:
Updated and refreshed model instance
"""
for field, value in updates.items():
setattr(instance, field, value)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
async def bulk_create(
session: AsyncSession,
model_class: Type[T],
items: List[Dict[str, Any]]
) -> List[T]:
"""Create multiple model instances in bulk.
Args:
session: Database session
model_class: SQLModel class to instantiate
items: List of dictionaries with model data
Returns:
List of created model instances
"""
instances = []
for item_data in items:
instance = model_class(**item_data)
session.add(instance)
instances.append(instance)
await session.commit()
# Refresh all instances
for instance in instances:
await session.refresh(instance)
return instances
async def delete_and_commit(
session: AsyncSession,
instance: T
) -> None:
"""Delete an instance and commit the transaction.
Args:
session: Database session
instance: Model instance to delete
"""
await session.delete(instance)
await session.commit()
async def exists(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> bool:
"""Check if a model instance exists with given criteria.
Args:
session: Database session
model_class: SQLModel class
**kwargs: Filter criteria
Returns:
True if instance exists, False otherwise
"""
filters = []
for key, value in kwargs.items():
filters.append(getattr(model_class, key) == value)
statement = select(model_class).where(*filters)
result = await session.exec(statement)
return result.first() is not None

121
app/utils/exceptions.py Normal file
View File

@@ -0,0 +1,121 @@
"""Utility functions for common HTTP exception patterns."""
from fastapi import HTTPException, status
def raise_not_found(resource: str, identifier: str = None) -> None:
"""Raise a standardized 404 Not Found exception.
Args:
resource: Name of the resource that wasn't found
identifier: Optional identifier for the specific resource
Raises:
HTTPException with 404 status code
"""
if identifier:
detail = f"{resource} with ID {identifier} not found"
else:
detail = f"{resource} not found"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def raise_unauthorized(detail: str = "Could not validate credentials") -> None:
"""Raise a standardized 401 Unauthorized exception.
Args:
detail: Error message detail
Raises:
HTTPException with 401 status code
"""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
def raise_bad_request(detail: str) -> None:
"""Raise a standardized 400 Bad Request exception.
Args:
detail: Error message detail
Raises:
HTTPException with 400 status code
"""
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=detail,
)
def raise_internal_server_error(detail: str, cause: Exception = None) -> None:
"""Raise a standardized 500 Internal Server Error exception.
Args:
detail: Error message detail
cause: Optional underlying exception
Raises:
HTTPException with 500 status code
"""
if cause:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
) from cause
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
)
def raise_payment_required(detail: str = "Insufficient credits") -> None:
"""Raise a standardized 402 Payment Required exception.
Args:
detail: Error message detail
Raises:
HTTPException with 402 status code
"""
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=detail,
)
def raise_forbidden(detail: str = "Access forbidden") -> None:
"""Raise a standardized 403 Forbidden exception.
Args:
detail: Error message detail
Raises:
HTTPException with 403 status code
"""
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
def raise_conflict(detail: str) -> None:
"""Raise a standardized 409 Conflict exception.
Args:
detail: Error message detail
Raises:
HTTPException with 409 status code
"""
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)

179
app/utils/test_helpers.py Normal file
View File

@@ -0,0 +1,179 @@
"""Test helper utilities for reducing code duplication."""
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Type, TypeVar
from unittest.mock import AsyncMock
from fastapi import FastAPI
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import SQLModel
from app.models.user import User
from app.utils.auth import JWTUtils
T = TypeVar("T", bound=SQLModel)
def create_jwt_token_data(user: User) -> Dict[str, str]:
"""Create standardized JWT token data dictionary for a user.
Args:
user: User object to create token data for
Returns:
Dictionary with sub, email, and role fields
"""
return {
"sub": str(user.id),
"email": user.email,
"role": user.role,
}
def create_access_token_for_user(user: User) -> str:
"""Create an access token for a user using standardized token data.
Args:
user: User object to create token for
Returns:
JWT access token string
"""
token_data = create_jwt_token_data(user)
return JWTUtils.create_access_token(token_data)
async def create_and_save_model(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> T:
"""Create, save, and refresh a model instance.
This consolidates the common pattern of:
- model = ModelClass(**kwargs)
- session.add(model)
- await session.commit()
- await session.refresh(model)
Args:
session: Database session
model_class: SQLModel class to instantiate
**kwargs: Arguments to pass to model constructor
Returns:
Created and refreshed model instance
"""
instance = model_class(**kwargs)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
@asynccontextmanager
async def override_dependencies(
app: FastAPI,
overrides: Dict[Any, Any]
):
"""Context manager for FastAPI dependency overrides with automatic cleanup.
Args:
app: FastAPI application instance
overrides: Dictionary mapping dependency functions to mock implementations
Usage:
async with override_dependencies(test_app, {
get_service: lambda: mock_service,
get_repo: lambda: mock_repo
}):
# Test code here
pass
# Dependencies automatically cleaned up
"""
# Apply overrides
for dependency, override in overrides.items():
app.dependency_overrides[dependency] = override
try:
yield
finally:
# Clean up overrides
for dependency in overrides:
app.dependency_overrides.pop(dependency, None)
def create_mock_vlc_services() -> Dict[str, AsyncMock]:
"""Create standard set of mocked VLC-related services.
Returns:
Dictionary with mocked vlc_service, sound_repository, and credit_service
"""
return {
"vlc_service": AsyncMock(),
"sound_repository": AsyncMock(),
"credit_service": AsyncMock(),
}
def configure_mock_sound_play_success(
mocks: Dict[str, AsyncMock],
sound_data: Dict[str, Any]
) -> None:
"""Configure mocks for successful sound playback scenario.
Args:
mocks: Dictionary of mock services from create_mock_vlc_services()
sound_data: Dictionary with sound properties (id, name, etc.)
"""
from app.models.sound import Sound
mock_sound = Sound(**sound_data)
# Configure repository mock
mocks["sound_repository"].get_by_id.return_value = mock_sound
# Configure credit service mocks
mocks["credit_service"].validate_and_reserve_credits.return_value = None
mocks["credit_service"].deduct_credits.return_value = None
# Configure VLC service mock
mocks["vlc_service"].play_sound.return_value = True
def create_mock_vlc_stop_result(
success: bool = True,
processes_found: int = 3,
processes_killed: int = 3,
processes_remaining: int = 0,
message: Optional[str] = None,
error: Optional[str] = None
) -> Dict[str, Any]:
"""Create standardized VLC stop operation result.
Args:
success: Whether operation succeeded
processes_found: Number of VLC processes found
processes_killed: Number of processes successfully killed
processes_remaining: Number of processes still running
message: Success/status message
error: Error message (for failed operations)
Returns:
Dictionary with VLC stop operation result
"""
result = {
"success": success,
"processes_found": processes_found,
"processes_killed": processes_killed,
}
if not success:
result["error"] = error or "Command failed"
result["message"] = message or "Failed to stop VLC processes"
else:
# Always include processes_remaining for successful operations
result["processes_remaining"] = processes_remaining
result["message"] = message or f"Killed {processes_killed} VLC processes"
return result

140
app/utils/validation.py Normal file
View File

@@ -0,0 +1,140 @@
"""Common validation utility functions."""
import re
from pathlib import Path
from typing import Any, Optional
# Password validation constants
MIN_PASSWORD_LENGTH = 8
def validate_email(email: str) -> bool:
"""Validate email address format.
Args:
email: Email address to validate
Returns:
True if email format is valid, False otherwise
"""
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
return bool(re.match(pattern, email))
def validate_password_strength(password: str) -> tuple[bool, str | None]:
"""Validate password meets security requirements.
Args:
password: Password to validate
Returns:
Tuple of (is_valid, error_message)
"""
if len(password) < MIN_PASSWORD_LENGTH:
msg = f"Password must be at least {MIN_PASSWORD_LENGTH} characters long"
return False, msg
if not re.search(r"[A-Z]", password):
return False, "Password must contain at least one uppercase letter"
if not re.search(r"[a-z]", password):
return False, "Password must contain at least one lowercase letter"
if not re.search(r"\d", password):
return False, "Password must contain at least one number"
return True, None
def validate_filename(
filename: str, allowed_extensions: list[str] | None = None
) -> bool:
"""Validate filename format and extension.
Args:
filename: Filename to validate
allowed_extensions: List of allowed file extensions (with dots)
Returns:
True if filename is valid, False otherwise
"""
if not filename or filename.startswith(".") or "/" in filename or "\\" in filename:
return False
if allowed_extensions:
file_path = Path(filename)
return file_path.suffix.lower() in [ext.lower() for ext in allowed_extensions]
return True
def validate_audio_filename(filename: str) -> bool:
"""Validate audio filename has allowed extension.
Args:
filename: Audio filename to validate
Returns:
True if filename has valid audio extension, False otherwise
"""
audio_extensions = [".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"]
return validate_filename(filename, audio_extensions)
def sanitize_filename(filename: str) -> str:
"""Sanitize filename by removing/replacing invalid characters.
Args:
filename: Filename to sanitize
Returns:
Sanitized filename safe for filesystem
"""
# Remove or replace problematic characters
sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
# Remove leading/trailing whitespace and dots
sanitized = sanitized.strip(" .")
# Ensure not empty
if not sanitized:
sanitized = "untitled"
return sanitized
def validate_url(url: str) -> bool:
"""Validate URL format.
Args:
url: URL to validate
Returns:
True if URL format is valid, False otherwise
"""
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
return bool(re.match(pattern, url))
def validate_positive_integer(value: Any, field_name: str = "value") -> int:
"""Validate and convert value to positive integer.
Args:
value: Value to validate and convert
field_name: Name of field for error messages
Returns:
Validated positive integer
Raises:
ValueError: If value is not a positive integer
"""
try:
int_value = int(value)
if int_value <= 0:
msg = f"{field_name} must be a positive integer"
raise ValueError(msg)
return int_value
except (TypeError, ValueError) as e:
msg = f"{field_name} must be a positive integer"
raise ValueError(msg) from e

View File

@@ -516,10 +516,9 @@ class TestPlayerEndpoints:
"status": PlayerStatus.PLAYING.value, "status": PlayerStatus.PLAYING.value,
"mode": PlayerMode.CONTINUOUS.value, "mode": PlayerMode.CONTINUOUS.value,
"volume": 50, "volume": 50,
"current_sound_id": 1, "position_ms": 5000,
"current_sound_index": 0, "duration_ms": 30000,
"current_sound_position": 5000, "index": 0,
"current_sound_duration": 30000,
"current_sound": { "current_sound": {
"id": 1, "id": 1,
"name": "Test Song", "name": "Test Song",
@@ -530,11 +529,13 @@ class TestPlayerEndpoints:
"thumbnail": None, "thumbnail": None,
"play_count": 0, "play_count": 0,
}, },
"playlist_id": 1, "playlist": {
"playlist_name": "Test Playlist", "id": 1,
"playlist_length": 1, "name": "Test Playlist",
"playlist_duration": 30000, "length": 1,
"playlist_sounds": [], "duration": 30000,
"sounds": [],
},
} }
mock_player_service.get_state.return_value = mock_state mock_player_service.get_state.return_value = mock_state

View File

@@ -1,12 +1,22 @@
"""Tests for VLC player API endpoints.""" """Tests for VLC player API endpoints."""
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from fastapi import FastAPI
from app.models.sound import Sound from app.models.sound import Sound
from app.models.user import User from app.models.user import User
from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service
from app.utils.test_helpers import (
override_dependencies,
create_mock_vlc_services,
configure_mock_sound_play_success,
create_mock_vlc_stop_result,
)
class TestVLCEndpoints: class TestVLCEndpoints:
@@ -15,68 +25,93 @@ class TestVLCEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_with_vlc_success( async def test_play_sound_with_vlc_success(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test successful sound playback via VLC.""" """Test successful sound playback via VLC."""
# Mock the VLC player service and sound repository methods # Set up mocks using helper
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound: mocks = create_mock_vlc_services()
mock_play_sound.return_value = True
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id: # Configure mocks for successful play
mock_sound = Sound( sound_data = {
id=1, "id": 1,
type="SDB", "type": "SDB",
name="Test Sound", "name": "Test Sound",
filename="test.mp3", "filename": "test.mp3",
duration=5000, "duration": 5000,
size=1024, "size": 1024,
hash="test_hash", "hash": "test_hash",
) }
mock_get_by_id.return_value = mock_sound configure_mock_sound_play_success(mocks, sound_data)
# Use dependency override helper
overrides = {
get_vlc_player: lambda: mocks["vlc_service"],
get_sound_repository: lambda: mocks["sound_repository"],
get_credit_service: lambda: mocks["credit_service"],
}
async with override_dependencies(test_app, overrides):
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True
assert data["sound_id"] == 1 assert data["sound_id"] == 1
assert data["sound_name"] == "Test Sound" assert data["sound_name"] == "Test Sound"
assert "Test Sound" in data["message"] assert "Test Sound" in data["message"]
# Verify service calls # Verify service calls
mock_get_by_id.assert_called_once_with(1) mocks["sound_repository"].get_by_id.assert_called_once_with(1)
mock_play_sound.assert_called_once_with(mock_sound) mocks["vlc_service"].play_sound.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_with_vlc_sound_not_found( async def test_play_sound_with_vlc_sound_not_found(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test VLC playback when sound is not found.""" """Test VLC playback when sound is not found."""
# Mock the sound repository to return None # Set up mocks
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id: mock_vlc_service = AsyncMock()
mock_get_by_id.return_value = None mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Configure mocks
mock_repo.get_by_id.return_value = None
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999") response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
assert response.status_code == 404 assert response.status_code == 404
data = response.json() data = response.json()
assert "Sound with ID 999 not found" in data["detail"] assert "Sound with ID 999 not found" in data["detail"]
finally:
# Clean up dependency overrides (except get_db which is needed for other tests)
test_app.dependency_overrides.pop(get_vlc_player, None)
test_app.dependency_overrides.pop(get_sound_repository, None)
test_app.dependency_overrides.pop(get_credit_service, None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_with_vlc_launch_failure( async def test_play_sound_with_vlc_launch_failure(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test VLC playback when VLC launch fails.""" """Test VLC playback when VLC launch fails."""
# Mock the VLC player service to fail # Set up mocks
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound: mock_vlc_service = AsyncMock()
mock_play_sound.return_value = False mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id: # Set up test data
mock_sound = Sound( mock_sound = Sound(
id=1, id=1,
type="SDB", type="SDB",
@@ -86,30 +121,62 @@ class TestVLCEndpoints:
size=1024, size=1024,
hash="test_hash", hash="test_hash",
) )
mock_get_by_id.return_value = mock_sound
# Configure mocks
mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = False
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
assert "Failed to launch VLC for sound playback" in data["detail"] assert "Failed to launch VLC for sound playback" in data["detail"]
finally:
# Clean up dependency overrides (except get_db which is needed for other tests)
test_app.dependency_overrides.pop(get_vlc_player, None)
test_app.dependency_overrides.pop(get_sound_repository, None)
test_app.dependency_overrides.pop(get_credit_service, None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_with_vlc_service_exception( async def test_play_sound_with_vlc_service_exception(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test VLC playback when service raises an exception.""" """Test VLC playback when service raises an exception."""
# Mock the sound repository to raise an exception # Set up mocks
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id: mock_vlc_service = AsyncMock()
mock_get_by_id.side_effect = Exception("Database error") mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Configure mocks
mock_repo.get_by_id.side_effect = Exception("Database error")
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
assert "Failed to play sound" in data["detail"] assert "Failed to play sound" in data["detail"]
finally:
# Clean up dependency overrides (except get_db which is needed for other tests)
test_app.dependency_overrides.pop(get_vlc_player, None)
test_app.dependency_overrides.pop(get_sound_repository, None)
test_app.dependency_overrides.pop(get_credit_service, None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_with_vlc_unauthenticated( async def test_play_sound_with_vlc_unauthenticated(
@@ -123,21 +190,25 @@ class TestVLCEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_all_vlc_instances_success( async def test_stop_all_vlc_instances_success(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test successful stopping of all VLC instances.""" """Test successful stopping of all VLC instances."""
# Mock the VLC player service # Set up mock using helper
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all: mock_vlc_service = AsyncMock()
mock_result = { mock_result = create_mock_vlc_stop_result(
"success": True, success=True,
"processes_found": 3, processes_found=3,
"processes_killed": 3, processes_killed=3,
"processes_remaining": 0, processes_remaining=0
"message": "Killed 3 VLC processes", )
} mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
mock_stop_all.return_value = mock_result
# Use dependency override helper
overrides = {get_vlc_player: lambda: mock_vlc_service}
async with override_dependencies(test_app, overrides):
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
@@ -149,25 +220,30 @@ class TestVLCEndpoints:
assert "Killed 3 VLC processes" in data["message"] assert "Killed 3 VLC processes" in data["message"]
# Verify service call # Verify service call
mock_stop_all.assert_called_once() mock_vlc_service.stop_all_vlc_instances.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_all_vlc_instances_no_processes( async def test_stop_all_vlc_instances_no_processes(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test stopping VLC instances when none are running.""" """Test stopping VLC instances when none are running."""
# Mock the VLC player service # Set up mock using helper
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all: mock_vlc_service = AsyncMock()
mock_result = { mock_result = create_mock_vlc_stop_result(
"success": True, success=True,
"processes_found": 0, processes_found=0,
"processes_killed": 0, processes_killed=0,
"message": "No VLC processes found", message="No VLC processes found"
} )
mock_stop_all.return_value = mock_result mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Use dependency override helper
overrides = {get_vlc_player: lambda: mock_vlc_service}
async with override_dependencies(test_app, overrides):
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
@@ -180,12 +256,13 @@ class TestVLCEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_all_vlc_instances_partial_success( async def test_stop_all_vlc_instances_partial_success(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test stopping VLC instances with partial success.""" """Test stopping VLC instances with partial success."""
# Mock the VLC player service # Set up mock
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all: mock_vlc_service = AsyncMock()
mock_result = { mock_result = {
"success": True, "success": True,
"processes_found": 3, "processes_found": 3,
@@ -193,8 +270,12 @@ class TestVLCEndpoints:
"processes_remaining": 1, "processes_remaining": 1,
"message": "Killed 2 VLC processes", "message": "Killed 2 VLC processes",
} }
mock_stop_all.return_value = mock_result mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
@@ -203,16 +284,20 @@ class TestVLCEndpoints:
assert data["processes_found"] == 3 assert data["processes_found"] == 3
assert data["processes_killed"] == 2 assert data["processes_killed"] == 2
assert data["processes_remaining"] == 1 assert data["processes_remaining"] == 1
finally:
# Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_all_vlc_instances_failure( async def test_stop_all_vlc_instances_failure(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test stopping VLC instances when service fails.""" """Test stopping VLC instances when service fails."""
# Mock the VLC player service # Set up mock
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all: mock_vlc_service = AsyncMock()
mock_result = { mock_result = {
"success": False, "success": False,
"processes_found": 0, "processes_found": 0,
@@ -220,8 +305,12 @@ class TestVLCEndpoints:
"error": "Command failed", "error": "Command failed",
"message": "Failed to stop VLC processes", "message": "Failed to stop VLC processes",
} }
mock_stop_all.return_value = mock_result mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
@@ -229,23 +318,34 @@ class TestVLCEndpoints:
assert data["success"] is False assert data["success"] is False
assert data["error"] == "Command failed" assert data["error"] == "Command failed"
assert data["message"] == "Failed to stop VLC processes" assert data["message"] == "Failed to stop VLC processes"
finally:
# Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_all_vlc_instances_service_exception( async def test_stop_all_vlc_instances_service_exception(
self, self,
test_app: FastAPI,
authenticated_client: AsyncClient, authenticated_client: AsyncClient,
authenticated_user: User, authenticated_user: User,
): ):
"""Test stopping VLC instances when service raises an exception.""" """Test stopping VLC instances when service raises an exception."""
# Mock the VLC player service to raise an exception # Set up mock to raise an exception
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all: mock_vlc_service = AsyncMock()
mock_stop_all.side_effect = Exception("Service error") mock_vlc_service.stop_all_vlc_instances.side_effect = Exception("Service error")
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
try:
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
assert "Failed to stop VLC instances" in data["detail"] assert "Failed to stop VLC instances" in data["detail"]
finally:
# Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_all_vlc_instances_unauthenticated( async def test_stop_all_vlc_instances_unauthenticated(
@@ -259,15 +359,17 @@ class TestVLCEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_vlc_endpoints_with_admin_user( async def test_vlc_endpoints_with_admin_user(
self, self,
test_app: FastAPI,
authenticated_admin_client: AsyncClient, authenticated_admin_client: AsyncClient,
admin_user: User, admin_user: User,
): ):
"""Test VLC endpoints work with admin user.""" """Test VLC endpoints work with admin user."""
# Test play endpoint with admin # Set up mocks
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound: mock_vlc_service = AsyncMock()
mock_play_sound.return_value = True mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id: # Set up test data
mock_sound = Sound( mock_sound = Sound(
id=1, id=1,
type="SDB", type="SDB",
@@ -277,17 +379,32 @@ class TestVLCEndpoints:
size=512, size=512,
hash="admin_hash", hash="admin_hash",
) )
mock_get_by_id.return_value = mock_sound
# Configure mocks
mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = True
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1") response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True
assert data["sound_name"] == "Admin Test Sound" assert data["sound_name"] == "Admin Test Sound"
finally:
# Clean up dependency overrides (except get_db which is needed for other tests)
test_app.dependency_overrides.pop(get_vlc_player, None)
test_app.dependency_overrides.pop(get_sound_repository, None)
test_app.dependency_overrides.pop(get_credit_service, None)
# Test stop-all endpoint with admin # Test stop-all endpoint with admin
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all: mock_vlc_service_2 = AsyncMock()
mock_result = { mock_result = {
"success": True, "success": True,
"processes_found": 1, "processes_found": 1,
@@ -295,11 +412,18 @@ class TestVLCEndpoints:
"processes_remaining": 0, "processes_remaining": 0,
"message": "Killed 1 VLC processes", "message": "Killed 1 VLC processes",
} }
mock_stop_all.return_value = mock_result mock_vlc_service_2.stop_all_vlc_instances.return_value = mock_result
# Override dependency for stop-all test
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service_2
try:
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all") response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["success"] is True assert data["success"] is True
assert data["processes_killed"] == 1 assert data["processes_killed"] == 1
finally:
# Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None)

View File

@@ -18,6 +18,7 @@ from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
from app.utils.auth import JWTUtils, PasswordUtils from app.utils.auth import JWTUtils, PasswordUtils
from app.utils.test_helpers import create_access_token_for_user
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@@ -277,28 +278,14 @@ def test_login_data() -> dict[str, str]:
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def auth_headers(test_user: User) -> dict[str, str]: async def auth_headers(test_user: User) -> dict[str, str]:
"""Create authentication headers with JWT token.""" """Create authentication headers with JWT token."""
token_data = { access_token = create_access_token_for_user(test_user)
"sub": str(test_user.id),
"email": test_user.email,
"role": test_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
return {"Authorization": f"Bearer {access_token}"} return {"Authorization": f"Bearer {access_token}"}
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_headers(admin_user: User) -> dict[str, str]: async def admin_headers(admin_user: User) -> dict[str, str]:
"""Create admin authentication headers with JWT token.""" """Create admin authentication headers with JWT token."""
token_data = { access_token = create_access_token_for_user(admin_user)
"sub": str(admin_user.id),
"email": admin_user.email,
"role": admin_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
return {"Authorization": f"Bearer {access_token}"} return {"Authorization": f"Bearer {access_token}"}
@@ -317,26 +304,12 @@ def authenticated_user(test_user: User) -> User:
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def auth_cookies(test_user: User) -> dict[str, str]: async def auth_cookies(test_user: User) -> dict[str, str]:
"""Create authentication cookies with JWT token.""" """Create authentication cookies with JWT token."""
token_data = { access_token = create_access_token_for_user(test_user)
"sub": str(test_user.id),
"email": test_user.email,
"role": test_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
return {"access_token": access_token} return {"access_token": access_token}
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_cookies(admin_user: User) -> dict[str, str]: async def admin_cookies(admin_user: User) -> dict[str, str]:
"""Create admin authentication cookies with JWT token.""" """Create admin authentication cookies with JWT token."""
token_data = { access_token = create_access_token_for_user(admin_user)
"sub": str(admin_user.id),
"email": admin_user.email,
"role": admin_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
return {"access_token": access_token} return {"access_token": access_token}

View File

@@ -11,6 +11,7 @@ from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.repositories.user import UserRepository from app.repositories.user import UserRepository
from app.utils.auth import PasswordUtils from app.utils.auth import PasswordUtils
from app.utils.database import create_and_save
class TestUserRepository: class TestUserRepository: