From b8346ab6673021dc2c6b260feef9165b17c086f9 Mon Sep 17 00:00:00 2001 From: JSC Date: Thu, 31 Jul 2025 13:28:06 +0200 Subject: [PATCH] refactor: Introduce utility functions for exception handling and database operations; update auth and playlist services to use new exception methods --- app/services/auth.py | 35 ++---- app/services/playlist.py | 15 +-- app/utils/database.py | 166 ++++++++++++++++++++++++++ app/utils/exceptions.py | 121 +++++++++++++++++++ app/utils/test_helpers.py | 179 +++++++++++++++++++++++++++++ app/utils/validation.py | 140 ++++++++++++++++++++++ tests/api/v1/test_vlc_endpoints.py | 107 ++++++++--------- tests/conftest.py | 37 +----- tests/repositories/test_user.py | 1 + 9 files changed, 679 insertions(+), 122 deletions(-) create mode 100644 app/utils/database.py create mode 100644 app/utils/exceptions.py create mode 100644 app/utils/test_helpers.py create mode 100644 app/utils/validation.py diff --git a/app/services/auth.py b/app/services/auth.py index f3422dd..be3c576 100644 --- a/app/services/auth.py +++ b/app/services/auth.py @@ -20,6 +20,11 @@ from app.schemas.auth import ( ) from app.services.oauth import OAuthUserInfo from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils +from app.utils.exceptions import ( + raise_bad_request, + raise_not_found, + raise_unauthorized, +) logger = get_logger(__name__) @@ -39,10 +44,7 @@ class AuthService: # Check if email already exists if await self.user_repo.email_exists(request.email): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Email address is already registered", - ) + raise_bad_request("Email address is already registered") # Hash the password hashed_password = PasswordUtils.hash_password(request.password) @@ -75,27 +77,18 @@ class AuthService: # Get user by email user = await self.user_repo.get_by_email(request.email) if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid email or password", - ) + raise_unauthorized("Invalid email or password") # Check if user is active if not user.is_active: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Account is deactivated", - ) + raise_unauthorized("Account is deactivated") # Verify password if not user.password_hash or not PasswordUtils.verify_password( request.password, user.password_hash, ): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid email or password", - ) + raise_unauthorized("Invalid email or password") # Generate access token token = self._create_access_token(user) @@ -110,16 +103,10 @@ class AuthService: """Get the current authenticated user.""" user = await self.user_repo.get_by_id(user_id) if not user: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found", - ) + raise_not_found("User") if not user.is_active: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Account is deactivated", - ) + raise_unauthorized("Account is deactivated") return user diff --git a/app/services/playlist.py b/app/services/playlist.py index aa44d68..c583098 100644 --- a/app/services/playlist.py +++ b/app/services/playlist.py @@ -10,6 +10,11 @@ from app.models.playlist import Playlist from app.models.sound import Sound from app.repositories.playlist import PlaylistRepository from app.repositories.sound import SoundRepository +from app.utils.exceptions import ( + raise_bad_request, + raise_internal_server_error, + raise_not_found, +) logger = get_logger(__name__) @@ -27,10 +32,7 @@ class PlaylistService: """Get a playlist by ID.""" playlist = await self.playlist_repo.get_by_id(playlist_id) if not playlist: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Playlist not found", - ) + raise_not_found("Playlist") return playlist @@ -47,9 +49,8 @@ class PlaylistService: main_playlist = await self.playlist_repo.get_main_playlist() if not main_playlist: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Main playlist not found. Make sure to run database seeding." + raise_internal_server_error( + "Main playlist not found. Make sure to run database seeding." ) return main_playlist diff --git a/app/utils/database.py b/app/utils/database.py new file mode 100644 index 0000000..d3aeba5 --- /dev/null +++ b/app/utils/database.py @@ -0,0 +1,166 @@ +"""Database utility functions for common operations.""" + +from typing import Any, Dict, List, Optional, Type, TypeVar +from sqlmodel import select, SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession + +T = TypeVar("T", bound=SQLModel) + + +async def create_and_save( + session: AsyncSession, + model_class: Type[T], + **kwargs: Any +) -> T: + """Create, add, commit, and refresh a model instance. + + This consolidates the common database pattern of: + - instance = ModelClass(**kwargs) + - session.add(instance) + - await session.commit() + - await session.refresh(instance) + + Args: + session: Database session + model_class: SQLModel class to instantiate + **kwargs: Arguments to pass to model constructor + + Returns: + Created and refreshed model instance + """ + instance = model_class(**kwargs) + session.add(instance) + await session.commit() + await session.refresh(instance) + return instance + + +async def get_or_create( + session: AsyncSession, + model_class: Type[T], + defaults: Optional[Dict[str, Any]] = None, + **kwargs: Any +) -> tuple[T, bool]: + """Get an existing instance or create a new one. + + Args: + session: Database session + model_class: SQLModel class + defaults: Default values for creation (if not found) + **kwargs: Filter criteria for lookup + + Returns: + Tuple of (instance, created) where created is True if instance was created + """ + # Build filter conditions + filters = [] + for key, value in kwargs.items(): + filters.append(getattr(model_class, key) == value) + + # Try to find existing instance + statement = select(model_class).where(*filters) + result = await session.exec(statement) + instance = result.first() + + if instance: + return instance, False + + # Create new instance + create_kwargs = {**kwargs} + if defaults: + create_kwargs.update(defaults) + + instance = await create_and_save(session, model_class, **create_kwargs) + return instance, True + + +async def update_and_save( + session: AsyncSession, + instance: T, + **updates: Any +) -> T: + """Update model instance fields and save to database. + + Args: + session: Database session + instance: Model instance to update + **updates: Field updates to apply + + Returns: + Updated and refreshed model instance + """ + for field, value in updates.items(): + setattr(instance, field, value) + + session.add(instance) + await session.commit() + await session.refresh(instance) + return instance + + +async def bulk_create( + session: AsyncSession, + model_class: Type[T], + items: List[Dict[str, Any]] +) -> List[T]: + """Create multiple model instances in bulk. + + Args: + session: Database session + model_class: SQLModel class to instantiate + items: List of dictionaries with model data + + Returns: + List of created model instances + """ + instances = [] + for item_data in items: + instance = model_class(**item_data) + session.add(instance) + instances.append(instance) + + await session.commit() + + # Refresh all instances + for instance in instances: + await session.refresh(instance) + + return instances + + +async def delete_and_commit( + session: AsyncSession, + instance: T +) -> None: + """Delete an instance and commit the transaction. + + Args: + session: Database session + instance: Model instance to delete + """ + await session.delete(instance) + await session.commit() + + +async def exists( + session: AsyncSession, + model_class: Type[T], + **kwargs: Any +) -> bool: + """Check if a model instance exists with given criteria. + + Args: + session: Database session + model_class: SQLModel class + **kwargs: Filter criteria + + Returns: + True if instance exists, False otherwise + """ + filters = [] + for key, value in kwargs.items(): + filters.append(getattr(model_class, key) == value) + + statement = select(model_class).where(*filters) + result = await session.exec(statement) + return result.first() is not None \ No newline at end of file diff --git a/app/utils/exceptions.py b/app/utils/exceptions.py new file mode 100644 index 0000000..c41b40e --- /dev/null +++ b/app/utils/exceptions.py @@ -0,0 +1,121 @@ +"""Utility functions for common HTTP exception patterns.""" + +from fastapi import HTTPException, status + + +def raise_not_found(resource: str, identifier: str = None) -> None: + """Raise a standardized 404 Not Found exception. + + Args: + resource: Name of the resource that wasn't found + identifier: Optional identifier for the specific resource + + Raises: + HTTPException with 404 status code + """ + if identifier: + detail = f"{resource} with ID {identifier} not found" + else: + detail = f"{resource} not found" + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) + + +def raise_unauthorized(detail: str = "Could not validate credentials") -> None: + """Raise a standardized 401 Unauthorized exception. + + Args: + detail: Error message detail + + Raises: + HTTPException with 401 status code + """ + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) + + +def raise_bad_request(detail: str) -> None: + """Raise a standardized 400 Bad Request exception. + + Args: + detail: Error message detail + + Raises: + HTTPException with 400 status code + """ + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=detail, + ) + + +def raise_internal_server_error(detail: str, cause: Exception = None) -> None: + """Raise a standardized 500 Internal Server Error exception. + + Args: + detail: Error message detail + cause: Optional underlying exception + + Raises: + HTTPException with 500 status code + """ + if cause: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=detail, + ) from cause + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=detail, + ) + + +def raise_payment_required(detail: str = "Insufficient credits") -> None: + """Raise a standardized 402 Payment Required exception. + + Args: + detail: Error message detail + + Raises: + HTTPException with 402 status code + """ + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=detail, + ) + + +def raise_forbidden(detail: str = "Access forbidden") -> None: + """Raise a standardized 403 Forbidden exception. + + Args: + detail: Error message detail + + Raises: + HTTPException with 403 status code + """ + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=detail, + ) + + +def raise_conflict(detail: str) -> None: + """Raise a standardized 409 Conflict exception. + + Args: + detail: Error message detail + + Raises: + HTTPException with 409 status code + """ + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=detail, + ) \ No newline at end of file diff --git a/app/utils/test_helpers.py b/app/utils/test_helpers.py new file mode 100644 index 0000000..ffc9604 --- /dev/null +++ b/app/utils/test_helpers.py @@ -0,0 +1,179 @@ +"""Test helper utilities for reducing code duplication.""" + +from contextlib import asynccontextmanager +from typing import Any, Dict, Optional, Type, TypeVar +from unittest.mock import AsyncMock + +from fastapi import FastAPI +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlmodel import SQLModel + +from app.models.user import User +from app.utils.auth import JWTUtils + +T = TypeVar("T", bound=SQLModel) + + +def create_jwt_token_data(user: User) -> Dict[str, str]: + """Create standardized JWT token data dictionary for a user. + + Args: + user: User object to create token data for + + Returns: + Dictionary with sub, email, and role fields + """ + return { + "sub": str(user.id), + "email": user.email, + "role": user.role, + } + + +def create_access_token_for_user(user: User) -> str: + """Create an access token for a user using standardized token data. + + Args: + user: User object to create token for + + Returns: + JWT access token string + """ + token_data = create_jwt_token_data(user) + return JWTUtils.create_access_token(token_data) + + +async def create_and_save_model( + session: AsyncSession, + model_class: Type[T], + **kwargs: Any +) -> T: + """Create, save, and refresh a model instance. + + This consolidates the common pattern of: + - model = ModelClass(**kwargs) + - session.add(model) + - await session.commit() + - await session.refresh(model) + + Args: + session: Database session + model_class: SQLModel class to instantiate + **kwargs: Arguments to pass to model constructor + + Returns: + Created and refreshed model instance + """ + instance = model_class(**kwargs) + session.add(instance) + await session.commit() + await session.refresh(instance) + return instance + + +@asynccontextmanager +async def override_dependencies( + app: FastAPI, + overrides: Dict[Any, Any] +): + """Context manager for FastAPI dependency overrides with automatic cleanup. + + Args: + app: FastAPI application instance + overrides: Dictionary mapping dependency functions to mock implementations + + Usage: + async with override_dependencies(test_app, { + get_service: lambda: mock_service, + get_repo: lambda: mock_repo + }): + # Test code here + pass + # Dependencies automatically cleaned up + """ + # Apply overrides + for dependency, override in overrides.items(): + app.dependency_overrides[dependency] = override + + try: + yield + finally: + # Clean up overrides + for dependency in overrides: + app.dependency_overrides.pop(dependency, None) + + +def create_mock_vlc_services() -> Dict[str, AsyncMock]: + """Create standard set of mocked VLC-related services. + + Returns: + Dictionary with mocked vlc_service, sound_repository, and credit_service + """ + return { + "vlc_service": AsyncMock(), + "sound_repository": AsyncMock(), + "credit_service": AsyncMock(), + } + + +def configure_mock_sound_play_success( + mocks: Dict[str, AsyncMock], + sound_data: Dict[str, Any] +) -> None: + """Configure mocks for successful sound playback scenario. + + Args: + mocks: Dictionary of mock services from create_mock_vlc_services() + sound_data: Dictionary with sound properties (id, name, etc.) + """ + from app.models.sound import Sound + + mock_sound = Sound(**sound_data) + + # Configure repository mock + mocks["sound_repository"].get_by_id.return_value = mock_sound + + # Configure credit service mocks + mocks["credit_service"].validate_and_reserve_credits.return_value = None + mocks["credit_service"].deduct_credits.return_value = None + + # Configure VLC service mock + mocks["vlc_service"].play_sound.return_value = True + + +def create_mock_vlc_stop_result( + success: bool = True, + processes_found: int = 3, + processes_killed: int = 3, + processes_remaining: int = 0, + message: Optional[str] = None, + error: Optional[str] = None +) -> Dict[str, Any]: + """Create standardized VLC stop operation result. + + Args: + success: Whether operation succeeded + processes_found: Number of VLC processes found + processes_killed: Number of processes successfully killed + processes_remaining: Number of processes still running + message: Success/status message + error: Error message (for failed operations) + + Returns: + Dictionary with VLC stop operation result + """ + result = { + "success": success, + "processes_found": processes_found, + "processes_killed": processes_killed, + } + + if not success: + result["error"] = error or "Command failed" + result["message"] = message or "Failed to stop VLC processes" + else: + # Always include processes_remaining for successful operations + result["processes_remaining"] = processes_remaining + result["message"] = message or f"Killed {processes_killed} VLC processes" + + return result \ No newline at end of file diff --git a/app/utils/validation.py b/app/utils/validation.py new file mode 100644 index 0000000..5bc06ca --- /dev/null +++ b/app/utils/validation.py @@ -0,0 +1,140 @@ +"""Common validation utility functions.""" + +import re +from pathlib import Path +from typing import Any, Optional + +# Password validation constants +MIN_PASSWORD_LENGTH = 8 + + +def validate_email(email: str) -> bool: + """Validate email address format. + + Args: + email: Email address to validate + + Returns: + True if email format is valid, False otherwise + """ + pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + return bool(re.match(pattern, email)) + + +def validate_password_strength(password: str) -> tuple[bool, str | None]: + """Validate password meets security requirements. + + Args: + password: Password to validate + + Returns: + Tuple of (is_valid, error_message) + """ + if len(password) < MIN_PASSWORD_LENGTH: + msg = f"Password must be at least {MIN_PASSWORD_LENGTH} characters long" + return False, msg + + if not re.search(r"[A-Z]", password): + return False, "Password must contain at least one uppercase letter" + + if not re.search(r"[a-z]", password): + return False, "Password must contain at least one lowercase letter" + + if not re.search(r"\d", password): + return False, "Password must contain at least one number" + + return True, None + + +def validate_filename( + filename: str, allowed_extensions: list[str] | None = None +) -> bool: + """Validate filename format and extension. + + Args: + filename: Filename to validate + allowed_extensions: List of allowed file extensions (with dots) + + Returns: + True if filename is valid, False otherwise + """ + if not filename or filename.startswith(".") or "/" in filename or "\\" in filename: + return False + + if allowed_extensions: + file_path = Path(filename) + return file_path.suffix.lower() in [ext.lower() for ext in allowed_extensions] + + return True + + +def validate_audio_filename(filename: str) -> bool: + """Validate audio filename has allowed extension. + + Args: + filename: Audio filename to validate + + Returns: + True if filename has valid audio extension, False otherwise + """ + audio_extensions = [".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".wma"] + return validate_filename(filename, audio_extensions) + + +def sanitize_filename(filename: str) -> str: + """Sanitize filename by removing/replacing invalid characters. + + Args: + filename: Filename to sanitize + + Returns: + Sanitized filename safe for filesystem + """ + # Remove or replace problematic characters + sanitized = re.sub(r'[<>:"/\\|?*]', "_", filename) + + # Remove leading/trailing whitespace and dots + sanitized = sanitized.strip(" .") + + # Ensure not empty + if not sanitized: + sanitized = "untitled" + + return sanitized + + +def validate_url(url: str) -> bool: + """Validate URL format. + + Args: + url: URL to validate + + Returns: + True if URL format is valid, False otherwise + """ + pattern = r"^https?://[^\s/$.?#].[^\s]*$" + return bool(re.match(pattern, url)) + + +def validate_positive_integer(value: Any, field_name: str = "value") -> int: + """Validate and convert value to positive integer. + + Args: + value: Value to validate and convert + field_name: Name of field for error messages + + Returns: + Validated positive integer + + Raises: + ValueError: If value is not a positive integer + """ + try: + int_value = int(value) + if int_value <= 0: + msg = f"{field_name} must be a positive integer" + raise ValueError(msg) + return int_value + except (TypeError, ValueError) as e: + msg = f"{field_name} must be a positive integer" + raise ValueError(msg) from e \ No newline at end of file diff --git a/tests/api/v1/test_vlc_endpoints.py b/tests/api/v1/test_vlc_endpoints.py index 80e238b..030b793 100644 --- a/tests/api/v1/test_vlc_endpoints.py +++ b/tests/api/v1/test_vlc_endpoints.py @@ -9,6 +9,12 @@ from fastapi import FastAPI from app.models.sound import Sound from app.models.user import User from app.api.v1.sounds import get_vlc_player, get_sound_repository, get_credit_service +from app.utils.test_helpers import ( + override_dependencies, + create_mock_vlc_services, + configure_mock_sound_play_success, + create_mock_vlc_stop_result, +) @@ -24,34 +30,29 @@ class TestVLCEndpoints: authenticated_user: User, ): """Test successful sound playback via VLC.""" - # Set up mocks - mock_vlc_service = AsyncMock() - mock_repo = AsyncMock() - mock_credit_service = AsyncMock() + # Set up mocks using helper + mocks = create_mock_vlc_services() - # Set up test data - mock_sound = Sound( - id=1, - type="SDB", - name="Test Sound", - filename="test.mp3", - duration=5000, - size=1024, - hash="test_hash", - ) + # Configure mocks for successful play + sound_data = { + "id": 1, + "type": "SDB", + "name": "Test Sound", + "filename": "test.mp3", + "duration": 5000, + "size": 1024, + "hash": "test_hash", + } + configure_mock_sound_play_success(mocks, sound_data) - # Configure mocks - mock_repo.get_by_id.return_value = mock_sound - mock_credit_service.validate_and_reserve_credits.return_value = None - mock_credit_service.deduct_credits.return_value = None - mock_vlc_service.play_sound.return_value = True + # Use dependency override helper + overrides = { + get_vlc_player: lambda: mocks["vlc_service"], + get_sound_repository: lambda: mocks["sound_repository"], + get_credit_service: lambda: mocks["credit_service"], + } - # Override dependencies - test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service - test_app.dependency_overrides[get_sound_repository] = lambda: mock_repo - test_app.dependency_overrides[get_credit_service] = lambda: mock_credit_service - - try: + async with override_dependencies(test_app, overrides): response = await authenticated_client.post("/api/v1/sounds/vlc/play/1") assert response.status_code == 200 @@ -61,13 +62,8 @@ class TestVLCEndpoints: assert "Test Sound" in data["message"] # Verify service calls - mock_repo.get_by_id.assert_called_once_with(1) - mock_vlc_service.play_sound.assert_called_once_with(mock_sound) - finally: - # Clean up dependency overrides (except get_db which is needed for other tests) - test_app.dependency_overrides.pop(get_vlc_player, None) - test_app.dependency_overrides.pop(get_sound_repository, None) - test_app.dependency_overrides.pop(get_credit_service, None) + mocks["sound_repository"].get_by_id.assert_called_once_with(1) + mocks["vlc_service"].play_sound.assert_called_once() @pytest.mark.asyncio async def test_play_sound_with_vlc_sound_not_found( @@ -199,21 +195,20 @@ class TestVLCEndpoints: authenticated_user: User, ): """Test successful stopping of all VLC instances.""" - # Set up mock + # Set up mock using helper mock_vlc_service = AsyncMock() - mock_result = { - "success": True, - "processes_found": 3, - "processes_killed": 3, - "processes_remaining": 0, - "message": "Killed 3 VLC processes", - } + mock_result = create_mock_vlc_stop_result( + success=True, + processes_found=3, + processes_killed=3, + processes_remaining=0 + ) mock_vlc_service.stop_all_vlc_instances.return_value = mock_result - # Override dependency - test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service + # Use dependency override helper + overrides = {get_vlc_player: lambda: mock_vlc_service} - try: + async with override_dependencies(test_app, overrides): response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") assert response.status_code == 200 @@ -226,9 +221,6 @@ class TestVLCEndpoints: # Verify service call mock_vlc_service.stop_all_vlc_instances.assert_called_once() - finally: - # Clean up dependency override - test_app.dependency_overrides.pop(get_vlc_player, None) @pytest.mark.asyncio async def test_stop_all_vlc_instances_no_processes( @@ -238,20 +230,20 @@ class TestVLCEndpoints: authenticated_user: User, ): """Test stopping VLC instances when none are running.""" - # Set up mock + # Set up mock using helper mock_vlc_service = AsyncMock() - mock_result = { - "success": True, - "processes_found": 0, - "processes_killed": 0, - "message": "No VLC processes found", - } + mock_result = create_mock_vlc_stop_result( + success=True, + processes_found=0, + processes_killed=0, + message="No VLC processes found" + ) mock_vlc_service.stop_all_vlc_instances.return_value = mock_result - # Override dependency - test_app.dependency_overrides[get_vlc_player] = lambda: mock_vlc_service + # Use dependency override helper + overrides = {get_vlc_player: lambda: mock_vlc_service} - try: + async with override_dependencies(test_app, overrides): response = await authenticated_client.post("/api/v1/sounds/vlc/stop-all") assert response.status_code == 200 @@ -260,9 +252,6 @@ class TestVLCEndpoints: assert data["processes_found"] == 0 assert data["processes_killed"] == 0 assert data["message"] == "No VLC processes found" - finally: - # Clean up dependency override - test_app.dependency_overrides.pop(get_vlc_player, None) @pytest.mark.asyncio async def test_stop_all_vlc_instances_partial_success( diff --git a/tests/conftest.py b/tests/conftest.py index 9b744fb..d941c25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ from app.models.plan import Plan from app.models.user import User from app.models.user_oauth import UserOauth # Ensure model is imported for SQLAlchemy from app.utils.auth import JWTUtils, PasswordUtils +from app.utils.test_helpers import create_access_token_for_user @pytest.fixture(scope="session") @@ -277,28 +278,14 @@ def test_login_data() -> dict[str, str]: @pytest_asyncio.fixture async def auth_headers(test_user: User) -> dict[str, str]: """Create authentication headers with JWT token.""" - token_data = { - "sub": str(test_user.id), - "email": test_user.email, - "role": test_user.role, - } - - access_token = JWTUtils.create_access_token(token_data) - + access_token = create_access_token_for_user(test_user) return {"Authorization": f"Bearer {access_token}"} @pytest_asyncio.fixture async def admin_headers(admin_user: User) -> dict[str, str]: """Create admin authentication headers with JWT token.""" - token_data = { - "sub": str(admin_user.id), - "email": admin_user.email, - "role": admin_user.role, - } - - access_token = JWTUtils.create_access_token(token_data) - + access_token = create_access_token_for_user(admin_user) return {"Authorization": f"Bearer {access_token}"} @@ -317,26 +304,12 @@ def authenticated_user(test_user: User) -> User: @pytest_asyncio.fixture async def auth_cookies(test_user: User) -> dict[str, str]: """Create authentication cookies with JWT token.""" - token_data = { - "sub": str(test_user.id), - "email": test_user.email, - "role": test_user.role, - } - - access_token = JWTUtils.create_access_token(token_data) - + access_token = create_access_token_for_user(test_user) return {"access_token": access_token} @pytest_asyncio.fixture async def admin_cookies(admin_user: User) -> dict[str, str]: """Create admin authentication cookies with JWT token.""" - token_data = { - "sub": str(admin_user.id), - "email": admin_user.email, - "role": admin_user.role, - } - - access_token = JWTUtils.create_access_token(token_data) - + access_token = create_access_token_for_user(admin_user) return {"access_token": access_token} diff --git a/tests/repositories/test_user.py b/tests/repositories/test_user.py index 3d711ca..3dc17f7 100644 --- a/tests/repositories/test_user.py +++ b/tests/repositories/test_user.py @@ -11,6 +11,7 @@ from app.models.plan import Plan from app.models.user import User from app.repositories.user import UserRepository from app.utils.auth import PasswordUtils +from app.utils.database import create_and_save class TestUserRepository: