Compare commits
2 Commits
c13285ca4e
...
b8346ab667
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8346ab667 | ||
|
|
f24698e3ff |
@@ -20,6 +20,11 @@ from app.schemas.auth import (
|
|||||||
)
|
)
|
||||||
from app.services.oauth import OAuthUserInfo
|
from app.services.oauth import OAuthUserInfo
|
||||||
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
|
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
|
||||||
|
from app.utils.exceptions import (
|
||||||
|
raise_bad_request,
|
||||||
|
raise_not_found,
|
||||||
|
raise_unauthorized,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -39,10 +44,7 @@ class AuthService:
|
|||||||
|
|
||||||
# Check if email already exists
|
# Check if email already exists
|
||||||
if await self.user_repo.email_exists(request.email):
|
if await self.user_repo.email_exists(request.email):
|
||||||
raise HTTPException(
|
raise_bad_request("Email address is already registered")
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Email address is already registered",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hash the password
|
# Hash the password
|
||||||
hashed_password = PasswordUtils.hash_password(request.password)
|
hashed_password = PasswordUtils.hash_password(request.password)
|
||||||
@@ -75,27 +77,18 @@ class AuthService:
|
|||||||
# Get user by email
|
# Get user by email
|
||||||
user = await self.user_repo.get_by_email(request.email)
|
user = await self.user_repo.get_by_email(request.email)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise_unauthorized("Invalid email or password")
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid email or password",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if user is active
|
# Check if user is active
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
raise HTTPException(
|
raise_unauthorized("Account is deactivated")
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Account is deactivated",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify password
|
# Verify password
|
||||||
if not user.password_hash or not PasswordUtils.verify_password(
|
if not user.password_hash or not PasswordUtils.verify_password(
|
||||||
request.password,
|
request.password,
|
||||||
user.password_hash,
|
user.password_hash,
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise_unauthorized("Invalid email or password")
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid email or password",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate access token
|
# Generate access token
|
||||||
token = self._create_access_token(user)
|
token = self._create_access_token(user)
|
||||||
@@ -110,16 +103,10 @@ class AuthService:
|
|||||||
"""Get the current authenticated user."""
|
"""Get the current authenticated user."""
|
||||||
user = await self.user_repo.get_by_id(user_id)
|
user = await self.user_repo.get_by_id(user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise_not_found("User")
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
raise HTTPException(
|
raise_unauthorized("Account is deactivated")
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Account is deactivated",
|
|
||||||
)
|
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,11 @@ from app.models.playlist import Playlist
|
|||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
from app.repositories.playlist import PlaylistRepository
|
from app.repositories.playlist import PlaylistRepository
|
||||||
from app.repositories.sound import SoundRepository
|
from app.repositories.sound import SoundRepository
|
||||||
|
from app.utils.exceptions import (
|
||||||
|
raise_bad_request,
|
||||||
|
raise_internal_server_error,
|
||||||
|
raise_not_found,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -27,10 +32,7 @@ class PlaylistService:
|
|||||||
"""Get a playlist by ID."""
|
"""Get a playlist by ID."""
|
||||||
playlist = await self.playlist_repo.get_by_id(playlist_id)
|
playlist = await self.playlist_repo.get_by_id(playlist_id)
|
||||||
if not playlist:
|
if not playlist:
|
||||||
raise HTTPException(
|
raise_not_found("Playlist")
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Playlist not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
return playlist
|
return playlist
|
||||||
|
|
||||||
@@ -47,9 +49,8 @@ class PlaylistService:
|
|||||||
main_playlist = await self.playlist_repo.get_main_playlist()
|
main_playlist = await self.playlist_repo.get_main_playlist()
|
||||||
|
|
||||||
if not main_playlist:
|
if not main_playlist:
|
||||||
raise HTTPException(
|
raise_internal_server_error(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
"Main playlist not found. Make sure to run database seeding."
|
||||||
detail="Main playlist not found. Make sure to run database seeding."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return main_playlist
|
return main_playlist
|
||||||
|
|||||||
166
app/utils/database.py
Normal file
166
app/utils/database.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Database utility functions for common operations."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Type, TypeVar
|
||||||
|
from sqlmodel import select, SQLModel
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=SQLModel)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_and_save(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_class: Type[T],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> T:
|
||||||
|
"""Create, add, commit, and refresh a model instance.
|
||||||
|
|
||||||
|
This consolidates the common database pattern of:
|
||||||
|
- instance = ModelClass(**kwargs)
|
||||||
|
- session.add(instance)
|
||||||
|
- await session.commit()
|
||||||
|
- await session.refresh(instance)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
model_class: SQLModel class to instantiate
|
||||||
|
**kwargs: Arguments to pass to model constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created and refreshed model instance
|
||||||
|
"""
|
||||||
|
instance = model_class(**kwargs)
|
||||||
|
session.add(instance)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(instance)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_class: Type[T],
|
||||||
|
defaults: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> tuple[T, bool]:
|
||||||
|
"""Get an existing instance or create a new one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
model_class: SQLModel class
|
||||||
|
defaults: Default values for creation (if not found)
|
||||||
|
**kwargs: Filter criteria for lookup
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (instance, created) where created is True if instance was created
|
||||||
|
"""
|
||||||
|
# Build filter conditions
|
||||||
|
filters = []
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
filters.append(getattr(model_class, key) == value)
|
||||||
|
|
||||||
|
# Try to find existing instance
|
||||||
|
statement = select(model_class).where(*filters)
|
||||||
|
result = await session.exec(statement)
|
||||||
|
instance = result.first()
|
||||||
|
|
||||||
|
if instance:
|
||||||
|
return instance, False
|
||||||
|
|
||||||
|
# Create new instance
|
||||||
|
create_kwargs = {**kwargs}
|
||||||
|
if defaults:
|
||||||
|
create_kwargs.update(defaults)
|
||||||
|
|
||||||
|
instance = await create_and_save(session, model_class, **create_kwargs)
|
||||||
|
return instance, True
|
||||||
|
|
||||||
|
|
||||||
|
async def update_and_save(
|
||||||
|
session: AsyncSession,
|
||||||
|
instance: T,
|
||||||
|
**updates: Any
|
||||||
|
) -> T:
|
||||||
|
"""Update model instance fields and save to database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
instance: Model instance to update
|
||||||
|
**updates: Field updates to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated and refreshed model instance
|
||||||
|
"""
|
||||||
|
for field, value in updates.items():
|
||||||
|
setattr(instance, field, value)
|
||||||
|
|
||||||
|
session.add(instance)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(instance)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
async def bulk_create(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_class: Type[T],
|
||||||
|
items: List[Dict[str, Any]]
|
||||||
|
) -> List[T]:
|
||||||
|
"""Create multiple model instances in bulk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
model_class: SQLModel class to instantiate
|
||||||
|
items: List of dictionaries with model data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of created model instances
|
||||||
|
"""
|
||||||
|
instances = []
|
||||||
|
for item_data in items:
|
||||||
|
instance = model_class(**item_data)
|
||||||
|
session.add(instance)
|
||||||
|
instances.append(instance)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# Refresh all instances
|
||||||
|
for instance in instances:
|
||||||
|
await session.refresh(instance)
|
||||||
|
|
||||||
|
return instances
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_and_commit(
|
||||||
|
session: AsyncSession,
|
||||||
|
instance: T
|
||||||
|
) -> None:
|
||||||
|
"""Delete an instance and commit the transaction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
instance: Model instance to delete
|
||||||
|
"""
|
||||||
|
await session.delete(instance)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def exists(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_class: Type[T],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a model instance exists with given criteria.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
model_class: SQLModel class
|
||||||
|
**kwargs: Filter criteria
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if instance exists, False otherwise
|
||||||
|
"""
|
||||||
|
filters = []
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
filters.append(getattr(model_class, key) == value)
|
||||||
|
|
||||||
|
statement = select(model_class).where(*filters)
|
||||||
|
result = await session.exec(statement)
|
||||||
|
return result.first() is not None
|
||||||
121
app/utils/exceptions.py
Normal file
121
app/utils/exceptions.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Utility functions for common HTTP exception patterns."""
|
||||||
|
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
|
def raise_not_found(resource: str, identifier: str = None) -> None:
|
||||||
|
"""Raise a standardized 404 Not Found exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource: Name of the resource that wasn't found
|
||||||
|
identifier: Optional identifier for the specific resource
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException with 404 status code
|
||||||
|
"""
|
||||||
|
if identifier:
|
||||||
|
detail = f"{resource} with ID {identifier} not found"
|
||||||
|
else:
|
||||||
|
detail = f"{resource} not found"
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_unauthorized(detail: str = "Could not validate credentials") -> None:
|
||||||
|
"""Raise a standardized 401 Unauthorized exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detail: Error message detail
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException with 401 status code
|
||||||
|
"""
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_bad_request(detail: str) -> None:
|
||||||
|
"""Raise a standardized 400 Bad Request exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detail: Error message detail
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException with 400 status code
|
||||||
|
"""
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_internal_server_error(detail: str, cause: Exception = None) -> None:
|
||||||
|
"""Raise a standardized 500 Internal Server Error exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detail: Error message detail
|
||||||
|
cause: Optional underlying exception
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException with 500 status code
|
||||||
|
"""
|
||||||
|
if cause:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=detail,
|
||||||
|
) from cause
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_payment_required(detail: str = "Insufficient credits") -> None:
|
||||||
|
"""Raise a standardized 402 Payment Required exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detail: Error message detail
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException with 402 status code
|
||||||
|
"""
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_forbidden(detail: str = "Access forbidden") -> None:
|
||||||
|
"""Raise a standardized 403 Forbidden exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detail: Error message detail
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException with 403 status code
|
||||||
|
"""
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_conflict(detail: str) -> None:
|
||||||
|
"""Raise a standardized 409 Conflict exception.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detail: Error message detail
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException with 409 status code
|
||||||
|
"""
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=detail,
|
||||||
|
)
|
||||||
179
app/utils/test_helpers.py
Normal file
179
app/utils/test_helpers.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""Test helper utilities for reducing code duplication."""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any, Dict, Optional, Type, TypeVar
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlmodel import SQLModel
|
||||||
|
|
||||||
|
from app.models.user import User
|
||||||
|
from app.utils.auth import JWTUtils
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=SQLModel)
|
||||||
|
|
||||||
|
|
||||||
|
def create_jwt_token_data(user: User) -> Dict[str, str]:
|
||||||
|
"""Create standardized JWT token data dictionary for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User object to create token data for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with sub, email, and role fields
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"sub": str(user.id),
|
||||||
|
"email": user.email,
|
||||||
|
"role": user.role,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token_for_user(user: User) -> str:
|
||||||
|
"""Create an access token for a user using standardized token data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User object to create token for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JWT access token string
|
||||||
|
"""
|
||||||
|
token_data = create_jwt_token_data(user)
|
||||||
|
return JWTUtils.create_access_token(token_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_and_save_model(
|
||||||
|
session: AsyncSession,
|
||||||
|
model_class: Type[T],
|
||||||
|
**kwargs: Any
|
||||||
|
) -> T:
|
||||||
|
"""Create, save, and refresh a model instance.
|
||||||
|
|
||||||
|
This consolidates the common pattern of:
|
||||||
|
- model = ModelClass(**kwargs)
|
||||||
|
- session.add(model)
|
||||||
|
- await session.commit()
|
||||||
|
- await session.refresh(model)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
model_class: SQLModel class to instantiate
|
||||||
|
**kwargs: Arguments to pass to model constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created and refreshed model instance
|
||||||
|
"""
|
||||||
|
instance = model_class(**kwargs)
|
||||||
|
session.add(instance)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(instance)
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def override_dependencies(
|
||||||
|
app: FastAPI,
|
||||||
|
overrides: Dict[Any, Any]
|
||||||
|
):
|
||||||
|
"""Context manager for FastAPI dependency overrides with automatic cleanup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance
|
||||||
|
overrides: Dictionary mapping dependency functions to mock implementations
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with override_dependencies(test_app, {
|
||||||
|
get_service: lambda: mock_service,
|
||||||
|
get_repo: lambda: mock_repo
|
||||||
|
}):
|
||||||
|
# Test code here
|
||||||
|
pass
|
||||||
|
# Dependencies automatically cleaned up
|
||||||
|
"""
|
||||||
|
# Apply overrides
|
||||||
|
for dependency, override in overrides.items():
|
||||||
|
app.dependency_overrides[dependency] = override
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# Clean up overrides
|
||||||
|
for dependency in overrides:
|
||||||
|
app.dependency_overrides.pop(dependency, None)
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_vlc_services() -> Dict[str, AsyncMock]:
|
||||||
|
"""Create standard set of mocked VLC-related services.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with mocked vlc_service, sound_repository, and credit_service
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"vlc_service": AsyncMock(),
|
||||||
|
"sound_repository": AsyncMock(),
|
||||||
|
"credit_service": AsyncMock(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def configure_mock_sound_play_success(
|
||||||
|
mocks: Dict[str, AsyncMock],
|
||||||
|
sound_data: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Configure mocks for successful sound playback scenario.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mocks: Dictionary of mock services from create_mock_vlc_services()
|
||||||
|
sound_data: Dictionary with sound properties (id, name, etc.)
|
||||||
|
"""
|
||||||
|
from app.models.sound import Sound
|
||||||
|
|
||||||
|
mock_sound = Sound(**sound_data)
|
||||||
|
|
||||||
|
# Configure repository mock
|
||||||
|
mocks["sound_repository"].get_by_id.return_value = mock_sound
|
||||||
|
|
||||||
|
# Configure credit service mocks
|
||||||
|
mocks["credit_service"].validate_and_reserve_credits.return_value = None
|
||||||
|
mocks["credit_service"].deduct_credits.return_value = None
|
||||||
|
|
||||||
|
# Configure VLC service mock
|
||||||
|
mocks["vlc_service"].play_sound.return_value = True
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_vlc_stop_result(
|
||||||
|
success: bool = True,
|
||||||
|
processes_found: int = 3,
|
||||||
|
processes_killed: int = 3,
|
||||||
|
processes_remaining: int = 0,
|
||||||
|
message: Optional[str] = None,
|
||||||
|
error: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Create standardized VLC stop operation result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
success: Whether operation succeeded
|
||||||
|
processes_found: Number of VLC processes found
|
||||||
|
processes_killed: Number of processes successfully killed
|
||||||
|
processes_remaining: Number of processes still running
|
||||||
|
message: Success/status message
|
||||||
|
error: Error message (for failed operations)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with VLC stop operation result
|
||||||
|
"""
|
||||||
|
result = {
|
||||||
|
"success": success,
|
||||||
|
"processes_found": processes_found,
|
||||||
|
"processes_killed": processes_killed,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
result["error"] = error or "Command failed"
|
||||||
|
result["message"] = message or "Failed to stop VLC processes"
|
||||||
|
else:
|
||||||
|
# Always include processes_remaining for successful operations
|
||||||
|
result["processes_remaining"] = processes_remaining
|
||||||
|
result["message"] = message or f"Killed {processes_killed} VLC processes"
|
||||||
|
|
||||||
|
return result
|
||||||
140
app/utils/validation.py
Normal file
140
app/utils/validation.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""Common validation utility functions."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Password validation constants
|
||||||
|
MIN_PASSWORD_LENGTH = 8
|
||||||
|
|
||||||
|
|
||||||
|
def validate_email(email: str) -> bool:
|
||||||
|
"""Validate email address format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: Email address to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if email format is valid, False otherwise
|
||||||
|
"""
|
||||||
|
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||||
|
return bool(re.match(pattern, email))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_password_strength(password: str) -> tuple[bool, str | None]:
|
||||||
|
"""Validate password meets security requirements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
password: Password to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_valid, error_message)
|
||||||
|
"""
|
||||||
|
if len(password) < MIN_PASSWORD_LENGTH:
|
||||||
|
msg = f"Password must be at least {MIN_PASSWORD_LENGTH} characters long"
|
||||||
|
return False, msg
|
||||||
|
|
||||||
|
if not re.search(r"[A-Z]", password):
|
||||||
|
return False, "Password must contain at least one uppercase letter"
|
||||||
|
|
||||||
|
if not re.search(r"[a-z]", password):
|
||||||
|
return False, "Password must contain at least one lowercase letter"
|
||||||
|
|
||||||
|
if not re.search(r"\d", password):
|
||||||
|
return False, "Password must contain at least one number"
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_filename(
|
||||||
|
filename: str, allowed_extensions: list[str] | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""Validate filename format and extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to validate
|
||||||
|
allowed_extensions: List of allowed file extensions (with dots)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if filename is valid, False otherwise
|
||||||
|
"""
|
||||||
|
if not filename or filename.startswith(".") or "/" in filename or "\\" in filename:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if allowed_extensions:
|
||||||
|
file_path = Path(filename)
|
||||||
|
return file_path.suffix.lower() in [ext.lower() for ext in allowed_extensions]
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def validate_audio_filename(filename: str) -> bool:
|
||||||
|
"""Validate audio filename has allowed extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Audio filename to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if filename has valid audio extension, False otherwise
|
||||||
|
"""
|
||||||
|
audio_extensions = [".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"]
|
||||||
|
return validate_filename(filename, audio_extensions)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_filename(filename: str) -> str:
|
||||||
|
"""Sanitize filename by removing/replacing invalid characters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Filename to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized filename safe for filesystem
|
||||||
|
"""
|
||||||
|
# Remove or replace problematic characters
|
||||||
|
sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
|
||||||
|
|
||||||
|
# Remove leading/trailing whitespace and dots
|
||||||
|
sanitized = sanitized.strip(" .")
|
||||||
|
|
||||||
|
# Ensure not empty
|
||||||
|
if not sanitized:
|
||||||
|
sanitized = "untitled"
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
|
||||||
|
def validate_url(url: str) -> bool:
|
||||||
|
"""Validate URL format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if URL format is valid, False otherwise
|
||||||
|
"""
|
||||||
|
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
|
||||||
|
return bool(re.match(pattern, url))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_positive_integer(value: Any, field_name: str = "value") -> int:
|
||||||
|
"""Validate and convert value to positive integer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to validate and convert
|
||||||
|
field_name: Name of field for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated positive integer
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If value is not a positive integer
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
int_value = int(value)
|
||||||
|
if int_value <= 0:
|
||||||
|
msg = f"{field_name} must be a positive integer"
|
||||||
|
raise ValueError(msg)
|
||||||
|
return int_value
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
msg = f"{field_name} must be a positive integer"
|
||||||
|
raise ValueError(msg) from e
|
||||||
@@ -516,10 +516,9 @@ class TestPlayerEndpoints:
|
|||||||
"status": PlayerStatus.PLAYING.value,
|
"status": PlayerStatus.PLAYING.value,
|
||||||
"mode": PlayerMode.CONTINUOUS.value,
|
"mode": PlayerMode.CONTINUOUS.value,
|
||||||
"volume": 50,
|
"volume": 50,
|
||||||
"current_sound_id": 1,
|
"position_ms": 5000,
|
||||||
"current_sound_index": 0,
|
"duration_ms": 30000,
|
||||||
"current_sound_position": 5000,
|
"index": 0,
|
||||||
"current_sound_duration": 30000,
|
|
||||||
"current_sound": {
|
"current_sound": {
|
||||||
"id": 1,
|
"id": 1,
|
||||||
"name": "Test Song",
|
"name": "Test Song",
|
||||||
@@ -530,11 +529,13 @@ class TestPlayerEndpoints:
|
|||||||
"thumbnail": None,
|
"thumbnail": None,
|
||||||
"play_count": 0,
|
"play_count": 0,
|
||||||
},
|
},
|
||||||
"playlist_id": 1,
|
"playlist": {
|
||||||
"playlist_name": "Test Playlist",
|
"id": 1,
|
||||||
"playlist_length": 1,
|
"name": "Test Playlist",
|
||||||
"playlist_duration": 30000,
|
"length": 1,
|
||||||
"playlist_sounds": [],
|
"duration": 30000,
|
||||||
|
"sounds": [],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
mock_player_service.get_state.return_value = mock_state
|
mock_player_service.get_state.return_value = mock_state
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,22 @@
|
|||||||
"""Tests for VLC player API endpoints."""
|
"""Tests for VLC player API endpoints."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service
|
||||||
|
from app.utils.test_helpers import (
|
||||||
|
override_dependencies,
|
||||||
|
create_mock_vlc_services,
|
||||||
|
configure_mock_sound_play_success,
|
||||||
|
create_mock_vlc_stop_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestVLCEndpoints:
|
class TestVLCEndpoints:
|
||||||
@@ -15,101 +25,158 @@ class TestVLCEndpoints:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_play_sound_with_vlc_success(
|
async def test_play_sound_with_vlc_success(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test successful sound playback via VLC."""
|
"""Test successful sound playback via VLC."""
|
||||||
# Mock the VLC player service and sound repository methods
|
# Set up mocks using helper
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
mocks = create_mock_vlc_services()
|
||||||
mock_play_sound.return_value = True
|
|
||||||
|
# Configure mocks for successful play
|
||||||
|
sound_data = {
|
||||||
|
"id": 1,
|
||||||
|
"type": "SDB",
|
||||||
|
"name": "Test Sound",
|
||||||
|
"filename": "test.mp3",
|
||||||
|
"duration": 5000,
|
||||||
|
"size": 1024,
|
||||||
|
"hash": "test_hash",
|
||||||
|
}
|
||||||
|
configure_mock_sound_play_success(mocks, sound_data)
|
||||||
|
|
||||||
|
# Use dependency override helper
|
||||||
|
overrides = {
|
||||||
|
get_vlc_player: lambda: mocks["vlc_service"],
|
||||||
|
get_sound_repository: lambda: mocks["sound_repository"],
|
||||||
|
get_credit_service: lambda: mocks["credit_service"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async with override_dependencies(test_app, overrides):
|
||||||
|
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||||
|
|
||||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
assert response.status_code == 200
|
||||||
mock_sound = Sound(
|
data = response.json()
|
||||||
id=1,
|
assert data["sound_id"] == 1
|
||||||
type="SDB",
|
assert data["sound_name"] == "Test Sound"
|
||||||
name="Test Sound",
|
assert "Test Sound" in data["message"]
|
||||||
filename="test.mp3",
|
|
||||||
duration=5000,
|
# Verify service calls
|
||||||
size=1024,
|
mocks["sound_repository"].get_by_id.assert_called_once_with(1)
|
||||||
hash="test_hash",
|
mocks["vlc_service"].play_sound.assert_called_once()
|
||||||
)
|
|
||||||
mock_get_by_id.return_value = mock_sound
|
|
||||||
|
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["sound_id"] == 1
|
|
||||||
assert data["sound_name"] == "Test Sound"
|
|
||||||
assert "Test Sound" in data["message"]
|
|
||||||
|
|
||||||
# Verify service calls
|
|
||||||
mock_get_by_id.assert_called_once_with(1)
|
|
||||||
mock_play_sound.assert_called_once_with(mock_sound)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_play_sound_with_vlc_sound_not_found(
|
async def test_play_sound_with_vlc_sound_not_found(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test VLC playback when sound is not found."""
|
"""Test VLC playback when sound is not found."""
|
||||||
# Mock the sound repository to return None
|
# Set up mocks
|
||||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_get_by_id.return_value = None
|
mock_repo = AsyncMock()
|
||||||
|
mock_credit_service = AsyncMock()
|
||||||
|
|
||||||
|
# Configure mocks
|
||||||
|
mock_repo.get_by_id.return_value = None
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||||
|
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||||
|
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||||
|
|
||||||
|
try:
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
|
response = await authenticated_client.post("/api/v1/sounds/vlc/play/999")
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "Sound with ID 999 not found" in data["detail"]
|
assert "Sound with ID 999 not found" in data["detail"]
|
||||||
|
finally:
|
||||||
|
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||||
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
|
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||||
|
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_play_sound_with_vlc_launch_failure(
|
async def test_play_sound_with_vlc_launch_failure(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test VLC playback when VLC launch fails."""
|
"""Test VLC playback when VLC launch fails."""
|
||||||
# Mock the VLC player service to fail
|
# Set up mocks
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_play_sound.return_value = False
|
mock_repo = AsyncMock()
|
||||||
|
mock_credit_service = AsyncMock()
|
||||||
|
|
||||||
|
# Set up test data
|
||||||
|
mock_sound = Sound(
|
||||||
|
id=1,
|
||||||
|
type="SDB",
|
||||||
|
name="Test Sound",
|
||||||
|
filename="test.mp3",
|
||||||
|
duration=5000,
|
||||||
|
size=1024,
|
||||||
|
hash="test_hash",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure mocks
|
||||||
|
mock_repo.get_by_id.return_value = mock_sound
|
||||||
|
mock_credit_service.validate_and_reserve_credits.return_value = None
|
||||||
|
mock_credit_service.deduct_credits.return_value = None
|
||||||
|
mock_vlc_service.play_sound.return_value = False
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||||
|
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||||
|
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||||
|
|
||||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
assert response.status_code == 500
|
||||||
mock_sound = Sound(
|
data = response.json()
|
||||||
id=1,
|
assert "Failed to launch VLC for sound playback" in data["detail"]
|
||||||
type="SDB",
|
finally:
|
||||||
name="Test Sound",
|
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||||
filename="test.mp3",
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
duration=5000,
|
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||||
size=1024,
|
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||||
hash="test_hash",
|
|
||||||
)
|
|
||||||
mock_get_by_id.return_value = mock_sound
|
|
||||||
|
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
|
||||||
|
|
||||||
assert response.status_code == 500
|
|
||||||
data = response.json()
|
|
||||||
assert "Failed to launch VLC for sound playback" in data["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_play_sound_with_vlc_service_exception(
|
async def test_play_sound_with_vlc_service_exception(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test VLC playback when service raises an exception."""
|
"""Test VLC playback when service raises an exception."""
|
||||||
# Mock the sound repository to raise an exception
|
# Set up mocks
|
||||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_get_by_id.side_effect = Exception("Database error")
|
mock_repo = AsyncMock()
|
||||||
|
mock_credit_service = AsyncMock()
|
||||||
|
|
||||||
|
# Configure mocks
|
||||||
|
mock_repo.get_by_id.side_effect = Exception("Database error")
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||||
|
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||||
|
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||||
|
|
||||||
|
try:
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
|
||||||
|
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "Failed to play sound" in data["detail"]
|
assert "Failed to play sound" in data["detail"]
|
||||||
|
finally:
|
||||||
|
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||||
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
|
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||||
|
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_play_sound_with_vlc_unauthenticated(
|
async def test_play_sound_with_vlc_unauthenticated(
|
||||||
@@ -123,21 +190,25 @@ class TestVLCEndpoints:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_all_vlc_instances_success(
|
async def test_stop_all_vlc_instances_success(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test successful stopping of all VLC instances."""
|
"""Test successful stopping of all VLC instances."""
|
||||||
# Mock the VLC player service
|
# Set up mock using helper
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_result = {
|
mock_result = create_mock_vlc_stop_result(
|
||||||
"success": True,
|
success=True,
|
||||||
"processes_found": 3,
|
processes_found=3,
|
||||||
"processes_killed": 3,
|
processes_killed=3,
|
||||||
"processes_remaining": 0,
|
processes_remaining=0
|
||||||
"message": "Killed 3 VLC processes",
|
)
|
||||||
}
|
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||||
mock_stop_all.return_value = mock_result
|
|
||||||
|
# Use dependency override helper
|
||||||
|
overrides = {get_vlc_player: lambda: mock_vlc_service}
|
||||||
|
|
||||||
|
async with override_dependencies(test_app, overrides):
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -149,25 +220,30 @@ class TestVLCEndpoints:
|
|||||||
assert "Killed 3 VLC processes" in data["message"]
|
assert "Killed 3 VLC processes" in data["message"]
|
||||||
|
|
||||||
# Verify service call
|
# Verify service call
|
||||||
mock_stop_all.assert_called_once()
|
mock_vlc_service.stop_all_vlc_instances.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_all_vlc_instances_no_processes(
|
async def test_stop_all_vlc_instances_no_processes(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test stopping VLC instances when none are running."""
|
"""Test stopping VLC instances when none are running."""
|
||||||
# Mock the VLC player service
|
# Set up mock using helper
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_result = {
|
mock_result = create_mock_vlc_stop_result(
|
||||||
"success": True,
|
success=True,
|
||||||
"processes_found": 0,
|
processes_found=0,
|
||||||
"processes_killed": 0,
|
processes_killed=0,
|
||||||
"message": "No VLC processes found",
|
message="No VLC processes found"
|
||||||
}
|
)
|
||||||
mock_stop_all.return_value = mock_result
|
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||||
|
|
||||||
|
# Use dependency override helper
|
||||||
|
overrides = {get_vlc_player: lambda: mock_vlc_service}
|
||||||
|
|
||||||
|
async with override_dependencies(test_app, overrides):
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -180,21 +256,26 @@ class TestVLCEndpoints:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_all_vlc_instances_partial_success(
|
async def test_stop_all_vlc_instances_partial_success(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test stopping VLC instances with partial success."""
|
"""Test stopping VLC instances with partial success."""
|
||||||
# Mock the VLC player service
|
# Set up mock
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_result = {
|
mock_result = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"processes_found": 3,
|
"processes_found": 3,
|
||||||
"processes_killed": 2,
|
"processes_killed": 2,
|
||||||
"processes_remaining": 1,
|
"processes_remaining": 1,
|
||||||
"message": "Killed 2 VLC processes",
|
"message": "Killed 2 VLC processes",
|
||||||
}
|
}
|
||||||
mock_stop_all.return_value = mock_result
|
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||||
|
|
||||||
|
# Override dependency
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||||
|
|
||||||
|
try:
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -203,25 +284,33 @@ class TestVLCEndpoints:
|
|||||||
assert data["processes_found"] == 3
|
assert data["processes_found"] == 3
|
||||||
assert data["processes_killed"] == 2
|
assert data["processes_killed"] == 2
|
||||||
assert data["processes_remaining"] == 1
|
assert data["processes_remaining"] == 1
|
||||||
|
finally:
|
||||||
|
# Clean up dependency override
|
||||||
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_all_vlc_instances_failure(
|
async def test_stop_all_vlc_instances_failure(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test stopping VLC instances when service fails."""
|
"""Test stopping VLC instances when service fails."""
|
||||||
# Mock the VLC player service
|
# Set up mock
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_result = {
|
mock_result = {
|
||||||
"success": False,
|
"success": False,
|
||||||
"processes_found": 0,
|
"processes_found": 0,
|
||||||
"processes_killed": 0,
|
"processes_killed": 0,
|
||||||
"error": "Command failed",
|
"error": "Command failed",
|
||||||
"message": "Failed to stop VLC processes",
|
"message": "Failed to stop VLC processes",
|
||||||
}
|
}
|
||||||
mock_stop_all.return_value = mock_result
|
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
|
||||||
|
|
||||||
|
# Override dependency
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||||
|
|
||||||
|
try:
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -229,23 +318,34 @@ class TestVLCEndpoints:
|
|||||||
assert data["success"] is False
|
assert data["success"] is False
|
||||||
assert data["error"] == "Command failed"
|
assert data["error"] == "Command failed"
|
||||||
assert data["message"] == "Failed to stop VLC processes"
|
assert data["message"] == "Failed to stop VLC processes"
|
||||||
|
finally:
|
||||||
|
# Clean up dependency override
|
||||||
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_all_vlc_instances_service_exception(
|
async def test_stop_all_vlc_instances_service_exception(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_client: AsyncClient,
|
authenticated_client: AsyncClient,
|
||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
):
|
):
|
||||||
"""Test stopping VLC instances when service raises an exception."""
|
"""Test stopping VLC instances when service raises an exception."""
|
||||||
# Mock the VLC player service to raise an exception
|
# Set up mock to raise an exception
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_stop_all.side_effect = Exception("Service error")
|
mock_vlc_service.stop_all_vlc_instances.side_effect = Exception("Service error")
|
||||||
|
|
||||||
|
# Override dependency
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||||
|
|
||||||
|
try:
|
||||||
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
|
||||||
|
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "Failed to stop VLC instances" in data["detail"]
|
assert "Failed to stop VLC instances" in data["detail"]
|
||||||
|
finally:
|
||||||
|
# Clean up dependency override
|
||||||
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_all_vlc_instances_unauthenticated(
|
async def test_stop_all_vlc_instances_unauthenticated(
|
||||||
@@ -259,47 +359,71 @@ class TestVLCEndpoints:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_vlc_endpoints_with_admin_user(
|
async def test_vlc_endpoints_with_admin_user(
|
||||||
self,
|
self,
|
||||||
|
test_app: FastAPI,
|
||||||
authenticated_admin_client: AsyncClient,
|
authenticated_admin_client: AsyncClient,
|
||||||
admin_user: User,
|
admin_user: User,
|
||||||
):
|
):
|
||||||
"""Test VLC endpoints work with admin user."""
|
"""Test VLC endpoints work with admin user."""
|
||||||
# Test play endpoint with admin
|
# Set up mocks
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play_sound:
|
mock_vlc_service = AsyncMock()
|
||||||
mock_play_sound.return_value = True
|
mock_repo = AsyncMock()
|
||||||
|
mock_credit_service = AsyncMock()
|
||||||
|
|
||||||
|
# Set up test data
|
||||||
|
mock_sound = Sound(
|
||||||
|
id=1,
|
||||||
|
type="SDB",
|
||||||
|
name="Admin Test Sound",
|
||||||
|
filename="admin_test.mp3",
|
||||||
|
duration=3000,
|
||||||
|
size=512,
|
||||||
|
hash="admin_hash",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure mocks
|
||||||
|
mock_repo.get_by_id.return_value = mock_sound
|
||||||
|
mock_credit_service.validate_and_reserve_credits.return_value = None
|
||||||
|
mock_credit_service.deduct_credits.return_value = None
|
||||||
|
mock_vlc_service.play_sound.return_value = True
|
||||||
|
|
||||||
|
# Override dependencies
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
|
||||||
|
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
|
||||||
|
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
|
||||||
|
|
||||||
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_by_id:
|
assert response.status_code == 200
|
||||||
mock_sound = Sound(
|
data = response.json()
|
||||||
id=1,
|
assert data["sound_name"] == "Admin Test Sound"
|
||||||
type="SDB",
|
finally:
|
||||||
name="Admin Test Sound",
|
# Clean up dependency overrides (except get_db which is needed for other tests)
|
||||||
filename="admin_test.mp3",
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
duration=3000,
|
test_app.dependency_overrides.pop(get_sound_repository, None)
|
||||||
size=512,
|
test_app.dependency_overrides.pop(get_credit_service, None)
|
||||||
hash="admin_hash",
|
|
||||||
)
|
|
||||||
mock_get_by_id.return_value = mock_sound
|
|
||||||
|
|
||||||
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/play/1")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["success"] is True
|
|
||||||
assert data["sound_name"] == "Admin Test Sound"
|
|
||||||
|
|
||||||
# Test stop-all endpoint with admin
|
# Test stop-all endpoint with admin
|
||||||
with patch("app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances") as mock_stop_all:
|
mock_vlc_service_2 = AsyncMock()
|
||||||
mock_result = {
|
mock_result = {
|
||||||
"success": True,
|
"success": True,
|
||||||
"processes_found": 1,
|
"processes_found": 1,
|
||||||
"processes_killed": 1,
|
"processes_killed": 1,
|
||||||
"processes_remaining": 0,
|
"processes_remaining": 0,
|
||||||
"message": "Killed 1 VLC processes",
|
"message": "Killed 1 VLC processes",
|
||||||
}
|
}
|
||||||
mock_stop_all.return_value = mock_result
|
mock_vlc_service_2.stop_all_vlc_instances.return_value = mock_result
|
||||||
|
|
||||||
|
# Override dependency for stop-all test
|
||||||
|
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service_2
|
||||||
|
|
||||||
|
try:
|
||||||
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
|
response = await authenticated_admin_client.post("/api/v1/sounds/vlc/stop-all")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["success"] is True
|
assert data["success"] is True
|
||||||
assert data["processes_killed"] == 1
|
assert data["processes_killed"] == 1
|
||||||
|
finally:
|
||||||
|
# Clean up dependency override
|
||||||
|
test_app.dependency_overrides.pop(get_vlc_player, None)
|
||||||
@@ -18,6 +18,7 @@ from app.models.plan import Plan
|
|||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
|
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
|
||||||
from app.utils.auth import JWTUtils, PasswordUtils
|
from app.utils.auth import JWTUtils, PasswordUtils
|
||||||
|
from app.utils.test_helpers import create_access_token_for_user
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@@ -277,28 +278,14 @@ def test_login_data() -> dict[str, str]:
|
|||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
async def auth_headers(test_user: User) -> dict[str, str]:
|
async def auth_headers(test_user: User) -> dict[str, str]:
|
||||||
"""Create authentication headers with JWT token."""
|
"""Create authentication headers with JWT token."""
|
||||||
token_data = {
|
access_token = create_access_token_for_user(test_user)
|
||||||
"sub": str(test_user.id),
|
|
||||||
"email": test_user.email,
|
|
||||||
"role": test_user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
access_token = JWTUtils.create_access_token(token_data)
|
|
||||||
|
|
||||||
return {"Authorization": f"Bearer {access_token}"}
|
return {"Authorization": f"Bearer {access_token}"}
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
async def admin_headers(admin_user: User) -> dict[str, str]:
|
async def admin_headers(admin_user: User) -> dict[str, str]:
|
||||||
"""Create admin authentication headers with JWT token."""
|
"""Create admin authentication headers with JWT token."""
|
||||||
token_data = {
|
access_token = create_access_token_for_user(admin_user)
|
||||||
"sub": str(admin_user.id),
|
|
||||||
"email": admin_user.email,
|
|
||||||
"role": admin_user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
access_token = JWTUtils.create_access_token(token_data)
|
|
||||||
|
|
||||||
return {"Authorization": f"Bearer {access_token}"}
|
return {"Authorization": f"Bearer {access_token}"}
|
||||||
|
|
||||||
|
|
||||||
@@ -317,26 +304,12 @@ def authenticated_user(test_user: User) -> User:
|
|||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
async def auth_cookies(test_user: User) -> dict[str, str]:
|
async def auth_cookies(test_user: User) -> dict[str, str]:
|
||||||
"""Create authentication cookies with JWT token."""
|
"""Create authentication cookies with JWT token."""
|
||||||
token_data = {
|
access_token = create_access_token_for_user(test_user)
|
||||||
"sub": str(test_user.id),
|
|
||||||
"email": test_user.email,
|
|
||||||
"role": test_user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
access_token = JWTUtils.create_access_token(token_data)
|
|
||||||
|
|
||||||
return {"access_token": access_token}
|
return {"access_token": access_token}
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
async def admin_cookies(admin_user: User) -> dict[str, str]:
|
async def admin_cookies(admin_user: User) -> dict[str, str]:
|
||||||
"""Create admin authentication cookies with JWT token."""
|
"""Create admin authentication cookies with JWT token."""
|
||||||
token_data = {
|
access_token = create_access_token_for_user(admin_user)
|
||||||
"sub": str(admin_user.id),
|
|
||||||
"email": admin_user.email,
|
|
||||||
"role": admin_user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
access_token = JWTUtils.create_access_token(token_data)
|
|
||||||
|
|
||||||
return {"access_token": access_token}
|
return {"access_token": access_token}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from app.models.plan import Plan
|
|||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.user import UserRepository
|
from app.repositories.user import UserRepository
|
||||||
from app.utils.auth import PasswordUtils
|
from app.utils.auth import PasswordUtils
|
||||||
|
from app.utils.database import create_and_save
|
||||||
|
|
||||||
|
|
||||||
class TestUserRepository:
|
class TestUserRepository:
|
||||||
|
|||||||
Reference in New Issue
Block a user