Compare commits

...

1 Commits

Author SHA1 Message Date
JSC
b8346ab667 refactor: Introduce utility functions for exception handling and database operations; update auth and playlist services to use new exception methods
All checks were successful
Backend CI / test (push) Successful in 3m58s
2025-07-31 13:28:06 +02:00
9 changed files with 679 additions and 122 deletions

View File

@@ -20,6 +20,11 @@ from app.schemas.auth import (
)
from app.services.oauth import OAuthUserInfo
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
from app.utils.exceptions import (
raise_bad_request,
raise_not_found,
raise_unauthorized,
)
logger = get_logger(__name__)
@@ -39,10 +44,7 @@ class AuthService:
# Check if email already exists
if await self.user_repo.email_exists(request.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email address is already registered",
)
raise_bad_request("Email address is already registered")
# Hash the password
hashed_password = PasswordUtils.hash_password(request.password)
@@ -75,27 +77,18 @@ class AuthService:
# Get user by email
user = await self.user_repo.get_by_email(request.email)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
raise_unauthorized("Invalid email or password")
# Check if user is active
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
raise_unauthorized("Account is deactivated")
# Verify password
if not user.password_hash or not PasswordUtils.verify_password(
request.password,
user.password_hash,
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
raise_unauthorized("Invalid email or password")
# Generate access token
token = self._create_access_token(user)
@@ -110,16 +103,10 @@ class AuthService:
"""Get the current authenticated user."""
user = await self.user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
raise_not_found("User")
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
raise_unauthorized("Account is deactivated")
return user

View File

@@ -10,6 +10,11 @@ from app.models.playlist import Playlist
from app.models.sound import Sound
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.utils.exceptions import (
raise_bad_request,
raise_internal_server_error,
raise_not_found,
)
logger = get_logger(__name__)
@@ -27,10 +32,7 @@ class PlaylistService:
"""Get a playlist by ID."""
playlist = await self.playlist_repo.get_by_id(playlist_id)
if not playlist:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Playlist not found",
)
raise_not_found("Playlist")
return playlist
@@ -47,9 +49,8 @@ class PlaylistService:
main_playlist = await self.playlist_repo.get_main_playlist()
if not main_playlist:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Main playlist not found. Make sure to run database seeding."
raise_internal_server_error(
"Main playlist not found. Make sure to run database seeding."
)
return main_playlist

166
app/utils/database.py Normal file
View File

@@ -0,0 +1,166 @@
"""Database utility functions for common operations."""
from typing import Any, Dict, List, Optional, Type, TypeVar
from sqlmodel import select, SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
T = TypeVar("T", bound=SQLModel)
async def create_and_save(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> T:
"""Create, add, commit, and refresh a model instance.
This consolidates the common database pattern of:
- instance = ModelClass(**kwargs)
- session.add(instance)
- await session.commit()
- await session.refresh(instance)
Args:
session: Database session
model_class: SQLModel class to instantiate
**kwargs: Arguments to pass to model constructor
Returns:
Created and refreshed model instance
"""
instance = model_class(**kwargs)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
async def get_or_create(
session: AsyncSession,
model_class: Type[T],
defaults: Optional[Dict[str, Any]] = None,
**kwargs: Any
) -> tuple[T, bool]:
"""Get an existing instance or create a new one.
Args:
session: Database session
model_class: SQLModel class
defaults: Default values for creation (if not found)
**kwargs: Filter criteria for lookup
Returns:
Tuple of (instance, created) where created is True if instance was created
"""
# Build filter conditions
filters = []
for key, value in kwargs.items():
filters.append(getattr(model_class, key) == value)
# Try to find existing instance
statement = select(model_class).where(*filters)
result = await session.exec(statement)
instance = result.first()
if instance:
return instance, False
# Create new instance
create_kwargs = {**kwargs}
if defaults:
create_kwargs.update(defaults)
instance = await create_and_save(session, model_class, **create_kwargs)
return instance, True
async def update_and_save(
session: AsyncSession,
instance: T,
**updates: Any
) -> T:
"""Update model instance fields and save to database.
Args:
session: Database session
instance: Model instance to update
**updates: Field updates to apply
Returns:
Updated and refreshed model instance
"""
for field, value in updates.items():
setattr(instance, field, value)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
async def bulk_create(
session: AsyncSession,
model_class: Type[T],
items: List[Dict[str, Any]]
) -> List[T]:
"""Create multiple model instances in bulk.
Args:
session: Database session
model_class: SQLModel class to instantiate
items: List of dictionaries with model data
Returns:
List of created model instances
"""
instances = []
for item_data in items:
instance = model_class(**item_data)
session.add(instance)
instances.append(instance)
await session.commit()
# Refresh all instances
for instance in instances:
await session.refresh(instance)
return instances
async def delete_and_commit(
session: AsyncSession,
instance: T
) -> None:
"""Delete an instance and commit the transaction.
Args:
session: Database session
instance: Model instance to delete
"""
await session.delete(instance)
await session.commit()
async def exists(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> bool:
"""Check if a model instance exists with given criteria.
Args:
session: Database session
model_class: SQLModel class
**kwargs: Filter criteria
Returns:
True if instance exists, False otherwise
"""
filters = []
for key, value in kwargs.items():
filters.append(getattr(model_class, key) == value)
statement = select(model_class).where(*filters)
result = await session.exec(statement)
return result.first() is not None

121
app/utils/exceptions.py Normal file
View File

@@ -0,0 +1,121 @@
"""Utility functions for common HTTP exception patterns."""
from fastapi import HTTPException, status
def raise_not_found(resource: str, identifier: str = None) -> None:
"""Raise a standardized 404 Not Found exception.
Args:
resource: Name of the resource that wasn't found
identifier: Optional identifier for the specific resource
Raises:
HTTPException with 404 status code
"""
if identifier:
detail = f"{resource} with ID {identifier} not found"
else:
detail = f"{resource} not found"
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def raise_unauthorized(detail: str = "Could not validate credentials") -> None:
"""Raise a standardized 401 Unauthorized exception.
Args:
detail: Error message detail
Raises:
HTTPException with 401 status code
"""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
def raise_bad_request(detail: str) -> None:
"""Raise a standardized 400 Bad Request exception.
Args:
detail: Error message detail
Raises:
HTTPException with 400 status code
"""
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=detail,
)
def raise_internal_server_error(detail: str, cause: Exception = None) -> None:
"""Raise a standardized 500 Internal Server Error exception.
Args:
detail: Error message detail
cause: Optional underlying exception
Raises:
HTTPException with 500 status code
"""
if cause:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
) from cause
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=detail,
)
def raise_payment_required(detail: str = "Insufficient credits") -> None:
"""Raise a standardized 402 Payment Required exception.
Args:
detail: Error message detail
Raises:
HTTPException with 402 status code
"""
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=detail,
)
def raise_forbidden(detail: str = "Access forbidden") -> None:
"""Raise a standardized 403 Forbidden exception.
Args:
detail: Error message detail
Raises:
HTTPException with 403 status code
"""
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
def raise_conflict(detail: str) -> None:
"""Raise a standardized 409 Conflict exception.
Args:
detail: Error message detail
Raises:
HTTPException with 409 status code
"""
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)

179
app/utils/test_helpers.py Normal file
View File

@@ -0,0 +1,179 @@
"""Test helper utilities for reducing code duplication."""
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional, Type, TypeVar
from unittest.mock import AsyncMock
from fastapi import FastAPI
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import SQLModel
from app.models.user import User
from app.utils.auth import JWTUtils
T = TypeVar("T", bound=SQLModel)
def create_jwt_token_data(user: User) -> Dict[str, str]:
"""Create standardized JWT token data dictionary for a user.
Args:
user: User object to create token data for
Returns:
Dictionary with sub, email, and role fields
"""
return {
"sub": str(user.id),
"email": user.email,
"role": user.role,
}
def create_access_token_for_user(user: User) -> str:
"""Create an access token for a user using standardized token data.
Args:
user: User object to create token for
Returns:
JWT access token string
"""
token_data = create_jwt_token_data(user)
return JWTUtils.create_access_token(token_data)
async def create_and_save_model(
session: AsyncSession,
model_class: Type[T],
**kwargs: Any
) -> T:
"""Create, save, and refresh a model instance.
This consolidates the common pattern of:
- model = ModelClass(**kwargs)
- session.add(model)
- await session.commit()
- await session.refresh(model)
Args:
session: Database session
model_class: SQLModel class to instantiate
**kwargs: Arguments to pass to model constructor
Returns:
Created and refreshed model instance
"""
instance = model_class(**kwargs)
session.add(instance)
await session.commit()
await session.refresh(instance)
return instance
@asynccontextmanager
async def override_dependencies(
app: FastAPI,
overrides: Dict[Any, Any]
):
"""Context manager for FastAPI dependency overrides with automatic cleanup.
Args:
app: FastAPI application instance
overrides: Dictionary mapping dependency functions to mock implementations
Usage:
async with override_dependencies(test_app, {
get_service: lambda: mock_service,
get_repo: lambda: mock_repo
}):
# Test code here
pass
# Dependencies automatically cleaned up
"""
# Apply overrides
for dependency, override in overrides.items():
app.dependency_overrides[dependency] = override
try:
yield
finally:
# Clean up overrides
for dependency in overrides:
app.dependency_overrides.pop(dependency, None)
def create_mock_vlc_services() -> Dict[str, AsyncMock]:
"""Create standard set of mocked VLC-related services.
Returns:
Dictionary with mocked vlc_service, sound_repository, and credit_service
"""
return {
"vlc_service": AsyncMock(),
"sound_repository": AsyncMock(),
"credit_service": AsyncMock(),
}
def configure_mock_sound_play_success(
mocks: Dict[str, AsyncMock],
sound_data: Dict[str, Any]
) -> None:
"""Configure mocks for successful sound playback scenario.
Args:
mocks: Dictionary of mock services from create_mock_vlc_services()
sound_data: Dictionary with sound properties (id, name, etc.)
"""
from app.models.sound import Sound
mock_sound = Sound(**sound_data)
# Configure repository mock
mocks["sound_repository"].get_by_id.return_value = mock_sound
# Configure credit service mocks
mocks["credit_service"].validate_and_reserve_credits.return_value = None
mocks["credit_service"].deduct_credits.return_value = None
# Configure VLC service mock
mocks["vlc_service"].play_sound.return_value = True
def create_mock_vlc_stop_result(
success: bool = True,
processes_found: int = 3,
processes_killed: int = 3,
processes_remaining: int = 0,
message: Optional[str] = None,
error: Optional[str] = None
) -> Dict[str, Any]:
"""Create standardized VLC stop operation result.
Args:
success: Whether operation succeeded
processes_found: Number of VLC processes found
processes_killed: Number of processes successfully killed
processes_remaining: Number of processes still running
message: Success/status message
error: Error message (for failed operations)
Returns:
Dictionary with VLC stop operation result
"""
result = {
"success": success,
"processes_found": processes_found,
"processes_killed": processes_killed,
}
if not success:
result["error"] = error or "Command failed"
result["message"] = message or "Failed to stop VLC processes"
else:
# Always include processes_remaining for successful operations
result["processes_remaining"] = processes_remaining
result["message"] = message or f"Killed {processes_killed} VLC processes"
return result

140
app/utils/validation.py Normal file
View File

@@ -0,0 +1,140 @@
"""Common validation utility functions."""
import re
from pathlib import Path
from typing import Any, Optional
# Password validation constants
MIN_PASSWORD_LENGTH = 8
def validate_email(email: str) -> bool:
"""Validate email address format.
Args:
email: Email address to validate
Returns:
True if email format is valid, False otherwise
"""
pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
return bool(re.match(pattern, email))
def validate_password_strength(password: str) -> tuple[bool, str | None]:
"""Validate password meets security requirements.
Args:
password: Password to validate
Returns:
Tuple of (is_valid, error_message)
"""
if len(password) < MIN_PASSWORD_LENGTH:
msg = f"Password must be at least {MIN_PASSWORD_LENGTH} characters long"
return False, msg
if not re.search(r"[A-Z]", password):
return False, "Password must contain at least one uppercase letter"
if not re.search(r"[a-z]", password):
return False, "Password must contain at least one lowercase letter"
if not re.search(r"\d", password):
return False, "Password must contain at least one number"
return True, None
def validate_filename(
filename: str, allowed_extensions: list[str] | None = None
) -> bool:
"""Validate filename format and extension.
Args:
filename: Filename to validate
allowed_extensions: List of allowed file extensions (with dots)
Returns:
True if filename is valid, False otherwise
"""
if not filename or filename.startswith(".") or "/" in filename or "\\" in filename:
return False
if allowed_extensions:
file_path = Path(filename)
return file_path.suffix.lower() in [ext.lower() for ext in allowed_extensions]
return True
def validate_audio_filename(filename: str) -> bool:
"""Validate audio filename has allowed extension.
Args:
filename: Audio filename to validate
Returns:
True if filename has valid audio extension, False otherwise
"""
audio_extensions = [".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"]
return validate_filename(filename, audio_extensions)
def sanitize_filename(filename: str) -> str:
"""Sanitize filename by removing/replacing invalid characters.
Args:
filename: Filename to sanitize
Returns:
Sanitized filename safe for filesystem
"""
# Remove or replace problematic characters
sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename)
# Remove leading/trailing whitespace and dots
sanitized = sanitized.strip(" .")
# Ensure not empty
if not sanitized:
sanitized = "untitled"
return sanitized
def validate_url(url: str) -> bool:
"""Validate URL format.
Args:
url: URL to validate
Returns:
True if URL format is valid, False otherwise
"""
pattern = r"^https?://[^\s/$.?#].[^\s]*$"
return bool(re.match(pattern, url))
def validate_positive_integer(value: Any, field_name: str = "value") -> int:
"""Validate and convert value to positive integer.
Args:
value: Value to validate and convert
field_name: Name of field for error messages
Returns:
Validated positive integer
Raises:
ValueError: If value is not a positive integer
"""
try:
int_value = int(value)
if int_value <= 0:
msg = f"{field_name} must be a positive integer"
raise ValueError(msg)
return int_value
except (TypeError, ValueError) as e:
msg = f"{field_name} must be a positive integer"
raise ValueError(msg) from e

View File

@@ -9,6 +9,12 @@ from fastapi import FastAPI
from app.models.sound import Sound
from app.models.user import User
from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service
from app.utils.test_helpers import (
override_dependencies,
create_mock_vlc_services,
configure_mock_sound_play_success,
create_mock_vlc_stop_result,
)
@@ -24,34 +30,29 @@ class TestVLCEndpoints:
authenticated_user: User,
):
"""Test successful sound playback via VLC."""
# Set up mocks
mock_vlc_service = AsyncMock()
mock_repo = AsyncMock()
mock_credit_service = AsyncMock()
# Set up mocks using helper
mocks = create_mock_vlc_services()
# Set up test data
mock_sound = Sound(
id=1,
type="SDB",
name="Test Sound",
filename="test.mp3",
duration=5000,
size=1024,
hash="test_hash",
)
# Configure mocks for successful play
sound_data = {
"id": 1,
"type": "SDB",
"name": "Test Sound",
"filename": "test.mp3",
"duration": 5000,
"size": 1024,
"hash": "test_hash",
}
configure_mock_sound_play_success(mocks, sound_data)
# Configure mocks
mock_repo.get_by_id.return_value = mock_sound
mock_credit_service.validate_and_reserve_credits.return_value = None
mock_credit_service.deduct_credits.return_value = None
mock_vlc_service.play_sound.return_value = True
# Use dependency override helper
overrides = {
get_vlc_player: lambda: mocks["vlc_service"],
get_sound_repository: lambda: mocks["sound_repository"],
get_credit_service: lambda: mocks["credit_service"],
}
# Override dependencies
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo
test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service
try:
async with override_dependencies(test_app, overrides):
response = await authenticated_client.post("/api/v1/sounds/vlc/play/1")
assert response.status_code == 200
@@ -61,13 +62,8 @@ class TestVLCEndpoints:
assert "Test Sound" in data["message"]
# Verify service calls
mock_repo.get_by_id.assert_called_once_with(1)
mock_vlc_service.play_sound.assert_called_once_with(mock_sound)
finally:
# Clean up dependency overrides (except get_db which is needed for other tests)
test_app.dependency_overrides.pop(get_vlc_player, None)
test_app.dependency_overrides.pop(get_sound_repository, None)
test_app.dependency_overrides.pop(get_credit_service, None)
mocks["sound_repository"].get_by_id.assert_called_once_with(1)
mocks["vlc_service"].play_sound.assert_called_once()
@pytest.mark.asyncio
async def test_play_sound_with_vlc_sound_not_found(
@@ -199,21 +195,20 @@ class TestVLCEndpoints:
authenticated_user: User,
):
"""Test successful stopping of all VLC instances."""
# Set up mock
# Set up mock using helper
mock_vlc_service = AsyncMock()
mock_result = {
"success": True,
"processes_found": 3,
"processes_killed": 3,
"processes_remaining": 0,
"message": "Killed 3 VLC processes",
}
mock_result = create_mock_vlc_stop_result(
success=True,
processes_found=3,
processes_killed=3,
processes_remaining=0
)
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
# Use dependency override helper
overrides = {get_vlc_player: lambda: mock_vlc_service}
try:
async with override_dependencies(test_app, overrides):
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
@@ -226,9 +221,6 @@ class TestVLCEndpoints:
# Verify service call
mock_vlc_service.stop_all_vlc_instances.assert_called_once()
finally:
# Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None)
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_no_processes(
@@ -238,20 +230,20 @@ class TestVLCEndpoints:
authenticated_user: User,
):
"""Test stopping VLC instances when none are running."""
# Set up mock
# Set up mock using helper
mock_vlc_service = AsyncMock()
mock_result = {
"success": True,
"processes_found": 0,
"processes_killed": 0,
"message": "No VLC processes found",
}
mock_result = create_mock_vlc_stop_result(
success=True,
processes_found=0,
processes_killed=0,
message="No VLC processes found"
)
mock_vlc_service.stop_all_vlc_instances.return_value = mock_result
# Override dependency
test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service
# Use dependency override helper
overrides = {get_vlc_player: lambda: mock_vlc_service}
try:
async with override_dependencies(test_app, overrides):
response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all")
assert response.status_code == 200
@@ -260,9 +252,6 @@ class TestVLCEndpoints:
assert data["processes_found"] == 0
assert data["processes_killed"] == 0
assert data["message"] == "No VLC processes found"
finally:
# Clean up dependency override
test_app.dependency_overrides.pop(get_vlc_player, None)
@pytest.mark.asyncio
async def test_stop_all_vlc_instances_partial_success(

View File

@@ -18,6 +18,7 @@ from app.models.plan import Plan
from app.models.user import User
from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy
from app.utils.auth import JWTUtils, PasswordUtils
from app.utils.test_helpers import create_access_token_for_user
@pytest.fixture(scope="session")
@@ -277,28 +278,14 @@ def test_login_data() -> dict[str, str]:
@pytest_asyncio.fixture
async def auth_headers(test_user: User) -> dict[str, str]:
"""Create authentication headers with JWT token."""
token_data = {
"sub": str(test_user.id),
"email": test_user.email,
"role": test_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
access_token = create_access_token_for_user(test_user)
return {"Authorization": f"Bearer {access_token}"}
@pytest_asyncio.fixture
async def admin_headers(admin_user: User) -> dict[str, str]:
"""Create admin authentication headers with JWT token."""
token_data = {
"sub": str(admin_user.id),
"email": admin_user.email,
"role": admin_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
access_token = create_access_token_for_user(admin_user)
return {"Authorization": f"Bearer {access_token}"}
@@ -317,26 +304,12 @@ def authenticated_user(test_user: User) -> User:
@pytest_asyncio.fixture
async def auth_cookies(test_user: User) -> dict[str, str]:
"""Create authentication cookies with JWT token."""
token_data = {
"sub": str(test_user.id),
"email": test_user.email,
"role": test_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
access_token = create_access_token_for_user(test_user)
return {"access_token": access_token}
@pytest_asyncio.fixture
async def admin_cookies(admin_user: User) -> dict[str, str]:
"""Create admin authentication cookies with JWT token."""
token_data = {
"sub": str(admin_user.id),
"email": admin_user.email,
"role": admin_user.role,
}
access_token = JWTUtils.create_access_token(token_data)
access_token = create_access_token_for_user(admin_user)
return {"access_token": access_token}

View File

@@ -11,6 +11,7 @@ from app.models.plan import Plan
from app.models.user import User
from app.repositories.user import UserRepository
from app.utils.auth import PasswordUtils
from app.utils.database import create_and_save
class TestUserRepository: