From 6b55ff0e81e75c5ab7f78f327a40717bfb7823e2 Mon Sep 17 00:00:00 2001 From: JSC Date: Sun, 17 Aug 2025 12:36:52 +0200 Subject: [PATCH] Refactor user endpoint tests to include pagination and response structure validation - Updated tests for listing users to validate pagination and response format. - Changed mock return values to include total count and pagination details. - Refactored user creation mocks for clarity and consistency. - Enhanced assertions to check for presence of pagination fields in responses. - Adjusted test cases for user retrieval and updates to ensure proper handling of user data. - Improved readability by restructuring mock definitions and assertions across various test files. --- app/api/v1/admin/users.py | 20 +- app/api/v1/auth.py | 31 +- app/api/v1/dashboard.py | 6 +- app/api/v1/extractions.py | 122 +++--- app/api/v1/favorites.py | 37 +- app/api/v1/playlists.py | 46 +- app/api/v1/sounds.py | 12 +- app/core/config.py | 2 +- app/models/base.py | 4 +- app/models/playlist_sound.py | 2 - app/repositories/extraction.py | 12 +- app/repositories/favorite.py | 12 +- app/repositories/playlist.py | 25 +- app/repositories/sound.py | 8 +- app/repositories/user.py | 77 ++-- app/schemas/auth.py | 8 +- app/schemas/favorite.py | 6 +- app/schemas/playlist.py | 11 +- app/schemas/sound.py | 26 +- app/schemas/user.py | 5 +- app/services/auth.py | 12 +- app/services/dashboard.py | 4 +- app/services/extraction.py | 22 +- app/services/favorite.py | 33 +- app/services/playlist.py | 15 +- tests/api/v1/admin/test_users_endpoints.py | 488 +++++++++++++-------- tests/api/v1/test_auth_endpoints.py | 42 +- tests/api/v1/test_playlist_endpoints.py | 10 +- tests/api/v1/test_sound_endpoints.py | 78 ++-- tests/conftest.py | 2 - tests/repositories/test_playlist.py | 56 ++- tests/services/test_credit.py | 4 +- tests/services/test_dashboard.py | 4 +- tests/services/test_extraction.py | 109 +++-- tests/services/test_scheduler.py | 15 +- 35 files changed, 863 insertions(+), 503 deletions(-) diff --git a/app/api/v1/admin/users.py b/app/api/v1/admin/users.py index 5e969ab..91da88f 100644 --- a/app/api/v1/admin/users.py +++ b/app/api/v1/admin/users.py @@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user from app.models.plan import Plan from app.models.user import User from app.repositories.plan import PlanRepository -from app.repositories.user import UserRepository, UserSortField, SortOrder, UserStatus +from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus from app.schemas.auth import UserResponse from app.schemas.user import UserUpdate @@ -36,21 +36,27 @@ def _user_to_response(user: User) -> UserResponse: "name": user.plan.name, "max_credits": user.plan.max_credits, "features": [], # Add features if needed - } if user.plan else {}, + } + if user.plan + else {}, created_at=user.created_at, updated_at=user.updated_at, ) @router.get("/") -async def list_users( +async def list_users( # noqa: PLR0913 session: Annotated[AsyncSession, Depends(get_db)], page: Annotated[int, Query(description="Page number", ge=1)] = 1, limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50, search: Annotated[str | None, Query(description="Search in name or email")] = None, - sort_by: Annotated[UserSortField, Query(description="Sort by field")] = UserSortField.NAME, + sort_by: Annotated[ + UserSortField, Query(description="Sort by field"), + ] = UserSortField.NAME, sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC, - status_filter: Annotated[UserStatus, Query(description="Filter by status")] = UserStatus.ALL, + status_filter: Annotated[ + UserStatus, Query(description="Filter by status"), + ] = UserStatus.ALL, ) -> dict[str, Any]: """Get all users with pagination, search, and filters (admin only).""" user_repo = UserRepository(session) @@ -62,9 +68,9 @@ async def list_users( sort_order=sort_order, status_filter=status_filter, ) - + total_pages = (total_count + limit - 1) // limit # Ceiling division - + return { "users": [_user_to_response(user) for user in users], "total": total_count, diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index 906bdd4..5b5e29a 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -464,7 +464,8 @@ async def update_profile( """Update the current user's profile.""" try: updated_user = await auth_service.update_user_profile( - current_user, request.model_dump(exclude_unset=True), + current_user, + request.model_dump(exclude_unset=True), ) return await auth_service.user_to_response(updated_user) except Exception as e: @@ -486,7 +487,9 @@ async def change_password( user_email = current_user.email try: await auth_service.change_user_password( - current_user, request.current_password, request.new_password, + current_user, + request.current_password, + request.new_password, ) except ValueError as e: raise HTTPException( @@ -513,11 +516,13 @@ async def get_user_providers( # Add password provider if user has password if current_user.password_hash: - providers.append({ - "provider": "password", - "display_name": "Password", - "connected_at": current_user.created_at.isoformat(), - }) + providers.append( + { + "provider": "password", + "display_name": "Password", + "connected_at": current_user.created_at.isoformat(), + }, + ) # Get OAuth providers from the database oauth_providers = await auth_service.get_user_oauth_providers(current_user) @@ -528,10 +533,12 @@ async def get_user_providers( elif oauth.provider == "google": display_name = "Google" - providers.append({ - "provider": oauth.provider, - "display_name": display_name, - "connected_at": oauth.created_at.isoformat(), - }) + providers.append( + { + "provider": oauth.provider, + "display_name": display_name, + "connected_at": oauth.created_at.isoformat(), + }, + ) return providers diff --git a/app/api/v1/dashboard.py b/app/api/v1/dashboard.py index 2920367..73690c6 100644 --- a/app/api/v1/dashboard.py +++ b/app/api/v1/dashboard.py @@ -34,7 +34,8 @@ async def get_top_sounds( _current_user: Annotated[User, Depends(get_current_user)], dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)], sound_type: Annotated[ - str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"), + str, + Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"), ], period: Annotated[ str, @@ -43,7 +44,8 @@ async def get_top_sounds( ), ] = "all_time", limit: Annotated[ - int, Query(description="Number of top sounds to return", ge=1, le=100), + int, + Query(description="Number of top sounds to return", ge=1, le=100), ] = 10, ) -> list[dict[str, Any]]: """Get top sounds by play count for a specific period.""" diff --git a/app/api/v1/extractions.py b/app/api/v1/extractions.py index b139cc2..9c535d3 100644 --- a/app/api/v1/extractions.py +++ b/app/api/v1/extractions.py @@ -60,68 +60,13 @@ async def create_extraction( } -@router.get("/{extraction_id}") -async def get_extraction( - extraction_id: int, - current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001 - extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], -) -> ExtractionInfo: - """Get extraction information by ID.""" - try: - extraction_info = await extraction_service.get_extraction_by_id(extraction_id) - - if not extraction_info: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Extraction {extraction_id} not found", - ) - - except HTTPException: - raise - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to get extraction: {e!s}", - ) from e - else: - return extraction_info - - -@router.get("/") -async def get_all_extractions( - extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], - search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None, - sort_by: Annotated[str, Query(description="Sort by field")] = "created_at", - sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc", - status_filter: Annotated[str | None, Query(description="Filter by status")] = None, - page: Annotated[int, Query(description="Page number", ge=1)] = 1, - limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50, -) -> dict: - """Get all extractions with optional filtering, search, and sorting.""" - try: - result = await extraction_service.get_all_extractions( - search=search, - sort_by=sort_by, - sort_order=sort_order, - status_filter=status_filter, - page=page, - limit=limit, - ) - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to get extractions: {e!s}", - ) from e - else: - return result - - @router.get("/user") -async def get_user_extractions( +async def get_user_extractions( # noqa: PLR0913 current_user: Annotated[User, Depends(get_current_active_user_flexible)], extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], - search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None, + search: Annotated[ + str | None, Query(description="Search in title, URL, or service"), + ] = None, sort_by: Annotated[str, Query(description="Sort by field")] = "created_at", sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc", status_filter: Annotated[str | None, Query(description="Filter by status")] = None, @@ -153,3 +98,62 @@ async def get_user_extractions( ) from e else: return result + + +@router.get("/{extraction_id}") +async def get_extraction( + extraction_id: int, + current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001 + extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], +) -> ExtractionInfo: + """Get extraction information by ID.""" + try: + extraction_info = await extraction_service.get_extraction_by_id(extraction_id) + + if not extraction_info: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Extraction {extraction_id} not found", + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get extraction: {e!s}", + ) from e + else: + return extraction_info + + +@router.get("/") +async def get_all_extractions( # noqa: PLR0913 + extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], + search: Annotated[ + str | None, Query(description="Search in title, URL, or service"), + ] = None, + sort_by: Annotated[str, Query(description="Sort by field")] = "created_at", + sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc", + status_filter: Annotated[str | None, Query(description="Filter by status")] = None, + page: Annotated[int, Query(description="Page number", ge=1)] = 1, + limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50, +) -> dict: + """Get all extractions with optional filtering, search, and sorting.""" + try: + result = await extraction_service.get_all_extractions( + search=search, + sort_by=sort_by, + sort_order=sort_order, + status_filter=status_filter, + page=page, + limit=limit, + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get extractions: {e!s}", + ) from e + else: + return result diff --git a/app/api/v1/favorites.py b/app/api/v1/favorites.py index fd66382..eb13219 100644 --- a/app/api/v1/favorites.py +++ b/app/api/v1/favorites.py @@ -4,6 +4,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, Query, status +from app.core.database import get_session_factory from app.core.dependencies import get_current_active_user from app.models.user import User from app.schemas.common import MessageResponse @@ -19,12 +20,10 @@ router = APIRouter(prefix="/favorites", tags=["favorites"]) def get_favorite_service() -> FavoriteService: """Get the favorite service.""" - from app.core.database import get_session_factory - return FavoriteService(get_session_factory()) -@router.get("/", response_model=FavoritesListResponse) +@router.get("/") async def get_user_favorites( current_user: Annotated[User, Depends(get_current_active_user)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], @@ -33,12 +32,14 @@ async def get_user_favorites( ) -> FavoritesListResponse: """Get all favorites for the current user.""" favorites = await favorite_service.get_user_favorites( - current_user.id, limit, offset, + current_user.id, + limit, + offset, ) return FavoritesListResponse(favorites=favorites) -@router.get("/sounds", response_model=FavoritesListResponse) +@router.get("/sounds") async def get_user_sound_favorites( current_user: Annotated[User, Depends(get_current_active_user)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], @@ -47,12 +48,14 @@ async def get_user_sound_favorites( ) -> FavoritesListResponse: """Get sound favorites for the current user.""" favorites = await favorite_service.get_user_sound_favorites( - current_user.id, limit, offset, + current_user.id, + limit, + offset, ) return FavoritesListResponse(favorites=favorites) -@router.get("/playlists", response_model=FavoritesListResponse) +@router.get("/playlists") async def get_user_playlist_favorites( current_user: Annotated[User, Depends(get_current_active_user)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], @@ -61,12 +64,14 @@ async def get_user_playlist_favorites( ) -> FavoritesListResponse: """Get playlist favorites for the current user.""" favorites = await favorite_service.get_user_playlist_favorites( - current_user.id, limit, offset, + current_user.id, + limit, + offset, ) return FavoritesListResponse(favorites=favorites) -@router.get("/counts", response_model=FavoriteCountsResponse) +@router.get("/counts") async def get_favorite_counts( current_user: Annotated[User, Depends(get_current_active_user)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], @@ -76,7 +81,7 @@ async def get_favorite_counts( return FavoriteCountsResponse(**counts) -@router.post("/sounds/{sound_id}", response_model=FavoriteResponse) +@router.post("/sounds/{sound_id}") async def add_sound_favorite( sound_id: int, current_user: Annotated[User, Depends(get_current_active_user)], @@ -103,7 +108,7 @@ async def add_sound_favorite( ) from e -@router.post("/playlists/{playlist_id}", response_model=FavoriteResponse) +@router.post("/playlists/{playlist_id}") async def add_playlist_favorite( playlist_id: int, current_user: Annotated[User, Depends(get_current_active_user)], @@ -112,7 +117,8 @@ async def add_playlist_favorite( """Add a playlist to favorites.""" try: favorite = await favorite_service.add_playlist_favorite( - current_user.id, playlist_id, + current_user.id, + playlist_id, ) return FavoriteResponse.model_validate(favorite) except ValueError as e: @@ -132,7 +138,7 @@ async def add_playlist_favorite( ) from e -@router.delete("/sounds/{sound_id}", response_model=MessageResponse) +@router.delete("/sounds/{sound_id}") async def remove_sound_favorite( sound_id: int, current_user: Annotated[User, Depends(get_current_active_user)], @@ -149,7 +155,7 @@ async def remove_sound_favorite( ) from e -@router.delete("/playlists/{playlist_id}", response_model=MessageResponse) +@router.delete("/playlists/{playlist_id}") async def remove_playlist_favorite( playlist_id: int, current_user: Annotated[User, Depends(get_current_active_user)], @@ -185,6 +191,7 @@ async def check_playlist_favorited( ) -> dict[str, bool]: """Check if a playlist is favorited by the current user.""" is_favorited = await favorite_service.is_playlist_favorited( - current_user.id, playlist_id, + current_user.id, + playlist_id, ) return {"is_favorited": is_favorited} diff --git a/app/api/v1/playlists.py b/app/api/v1/playlists.py index 76cac08..2efcfba 100644 --- a/app/api/v1/playlists.py +++ b/app/api/v1/playlists.py @@ -5,7 +5,7 @@ from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlmodel.ext.asyncio.session import AsyncSession -from app.core.database import get_db +from app.core.database import get_db, get_session_factory from app.core.dependencies import get_current_active_user_flexible from app.models.user import User from app.repositories.playlist import PlaylistSortField, SortOrder @@ -34,7 +34,6 @@ async def get_playlist_service( def get_favorite_service() -> FavoriteService: """Get the favorite service.""" - from app.core.database import get_session_factory return FavoriteService(get_session_factory()) @@ -57,7 +56,7 @@ async def get_all_playlists( # noqa: PLR0913 ] = SortOrder.ASC, page: Annotated[int, Query(description="Page number", ge=1)] = 1, limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50, - favorites_only: Annotated[ + favorites_only: Annotated[ # noqa: FBT002 bool, Query(description="Show only favorited playlists"), ] = False, @@ -78,15 +77,26 @@ async def get_all_playlists( # noqa: PLR0913 # Convert to PlaylistResponse with favorite indicators playlist_responses = [] for playlist_dict in result["playlists"]: - # The playlist service returns dict, need to create playlist object-like structure - is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist_dict["id"]) - favorite_count = await favorite_service.get_playlist_favorite_count(playlist_dict["id"]) + # The playlist service returns dict, need to create playlist object structure + playlist_id = playlist_dict["id"] + is_favorited = await favorite_service.is_playlist_favorited( + current_user.id, playlist_id, + ) + favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id) # Create a PlaylistResponse-like dict with proper datetime conversion playlist_response = { **playlist_dict, - "created_at": playlist_dict["created_at"].isoformat() if playlist_dict["created_at"] else None, - "updated_at": playlist_dict["updated_at"].isoformat() if playlist_dict["updated_at"] else None, + "created_at": ( + playlist_dict["created_at"].isoformat() + if playlist_dict["created_at"] + else None + ), + "updated_at": ( + playlist_dict["updated_at"].isoformat() + if playlist_dict["updated_at"] + else None + ), "is_favorited": is_favorited, "favorite_count": favorite_count, } @@ -113,9 +123,13 @@ async def get_user_playlists( # Add favorite indicators for each playlist playlist_responses = [] for playlist in playlists: - is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) + is_favorited = await favorite_service.is_playlist_favorited( + current_user.id, playlist.id, + ) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) - playlist_response = PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) + playlist_response = PlaylistResponse.from_playlist( + playlist, is_favorited, favorite_count, + ) playlist_responses.append(playlist_response) return playlist_responses @@ -129,7 +143,9 @@ async def get_main_playlist( ) -> PlaylistResponse: """Get the global main playlist.""" playlist = await playlist_service.get_main_playlist() - is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) + is_favorited = await favorite_service.is_playlist_favorited( + current_user.id, playlist.id, + ) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) @@ -142,7 +158,9 @@ async def get_current_playlist( ) -> PlaylistResponse: """Get the global current playlist (falls back to main playlist).""" playlist = await playlist_service.get_current_playlist() - is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) + is_favorited = await favorite_service.is_playlist_favorited( + current_user.id, playlist.id, + ) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) @@ -172,7 +190,9 @@ async def get_playlist( ) -> PlaylistResponse: """Get a specific playlist.""" playlist = await playlist_service.get_playlist_by_id(playlist_id) - is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) + is_favorited = await favorite_service.is_playlist_favorited( + current_user.id, playlist.id, + ) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) diff --git a/app/api/v1/sounds.py b/app/api/v1/sounds.py index 597e3d9..74c44ab 100644 --- a/app/api/v1/sounds.py +++ b/app/api/v1/sounds.py @@ -40,7 +40,7 @@ async def get_sound_repository( return SoundRepository(session) -@router.get("/", response_model=SoundsListResponse) +@router.get("/") async def get_sounds( # noqa: PLR0913 current_user: Annotated[User, Depends(get_current_active_user_flexible)], sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)], @@ -69,7 +69,7 @@ async def get_sounds( # noqa: PLR0913 int, Query(description="Number of results to skip", ge=0), ] = 0, - favorites_only: Annotated[ + favorites_only: Annotated[ # noqa: FBT002 bool, Query(description="Show only favorited sounds"), ] = False, @@ -90,9 +90,13 @@ async def get_sounds( # noqa: PLR0913 # Add favorite indicators for each sound sound_responses = [] for sound in sounds: - is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound.id) + is_favorited = await favorite_service.is_sound_favorited( + current_user.id, sound.id, + ) favorite_count = await favorite_service.get_sound_favorite_count(sound.id) - sound_response = SoundResponse.from_sound(sound, is_favorited, favorite_count) + sound_response = SoundResponse.from_sound( + sound, is_favorited, favorite_count, + ) sound_responses.append(sound_response) except Exception as e: diff --git a/app/core/config.py b/app/core/config.py index fc11bf8..ca890dd 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -20,7 +20,7 @@ class Settings(BaseSettings): # Production URLs (for reverse proxy deployment) FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production - BACKEND_URL: str = "http://localhost:8000" # Backend base URL + BACKEND_URL: str = "http://localhost:8000" # Backend base URL # CORS Configuration CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS diff --git a/app/models/base.py b/app/models/base.py index 4562d65..c4fb6b0 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -20,7 +20,9 @@ class BaseModel(SQLModel): # SQLAlchemy event listener to automatically update updated_at timestamp @event.listens_for(BaseModel, "before_update", propagate=True) def update_timestamp( - mapper: Mapper[Any], connection: Connection, target: BaseModel, # noqa: ARG001 + mapper: Mapper[Any], # noqa: ARG001 + connection: Connection, # noqa: ARG001 + target: BaseModel, ) -> None: """Automatically set updated_at timestamp before update operations.""" target.updated_at = datetime.now(UTC) diff --git a/app/models/playlist_sound.py b/app/models/playlist_sound.py index 0a0c11d..8d56763 100644 --- a/app/models/playlist_sound.py +++ b/app/models/playlist_sound.py @@ -35,5 +35,3 @@ class PlaylistSound(BaseModel, table=True): # relationships playlist: "Playlist" = Relationship(back_populates="playlist_sounds") sound: "Sound" = Relationship(back_populates="playlist_sounds") - - diff --git a/app/repositories/extraction.py b/app/repositories/extraction.py index ab6fe1d..2b791e2 100644 --- a/app/repositories/extraction.py +++ b/app/repositories/extraction.py @@ -58,7 +58,7 @@ class ExtractionRepository(BaseRepository[Extraction]): ) return list(result.all()) - async def get_user_extractions_filtered( + async def get_user_extractions_filtered( # noqa: PLR0913 self, user_id: int, search: str | None = None, @@ -92,7 +92,7 @@ class ExtractionRepository(BaseRepository[Extraction]): # Get total count before pagination count_query = select(func.count()).select_from( - base_query.subquery() + base_query.subquery(), ) count_result = await self.session.exec(count_query) total_count = count_result.one() @@ -106,10 +106,10 @@ class ExtractionRepository(BaseRepository[Extraction]): paginated_query = base_query.limit(limit).offset(offset) result = await self.session.exec(paginated_query) - + return list(result.all()), total_count - async def get_all_extractions_filtered( + async def get_all_extractions_filtered( # noqa: PLR0913 self, search: str | None = None, sort_by: str = "created_at", @@ -138,7 +138,7 @@ class ExtractionRepository(BaseRepository[Extraction]): # Get total count before pagination count_query = select(func.count()).select_from( - base_query.subquery() + base_query.subquery(), ) count_result = await self.session.exec(count_query) total_count = count_result.one() @@ -152,5 +152,5 @@ class ExtractionRepository(BaseRepository[Extraction]): paginated_query = base_query.limit(limit).offset(offset) result = await self.session.exec(paginated_query) - + return list(result.all()), total_count diff --git a/app/repositories/favorite.py b/app/repositories/favorite.py index 853c5ac..5ffa977 100644 --- a/app/repositories/favorite.py +++ b/app/repositories/favorite.py @@ -118,7 +118,9 @@ class FavoriteRepository(BaseRepository[Favorite]): raise async def get_by_user_and_sound( - self, user_id: int, sound_id: int, + self, + user_id: int, + sound_id: int, ) -> Favorite | None: """Get a favorite by user and sound. @@ -138,12 +140,16 @@ class FavoriteRepository(BaseRepository[Favorite]): return result.first() except Exception: logger.exception( - "Failed to get favorite for user %s and sound %s", user_id, sound_id, + "Failed to get favorite for user %s and sound %s", + user_id, + sound_id, ) raise async def get_by_user_and_playlist( - self, user_id: int, playlist_id: int, + self, + user_id: int, + playlist_id: int, ) -> Favorite | None: """Get a favorite by user and playlist. diff --git a/app/repositories/playlist.py b/app/repositories/playlist.py index 768b87c..63bbdce 100644 --- a/app/repositories/playlist.py +++ b/app/repositories/playlist.py @@ -57,7 +57,8 @@ class PlaylistRepository(BaseRepository[Playlist]): # management except Exception: logger.exception( - "Failed to update playlist timestamp for playlist: %s", playlist_id, + "Failed to update playlist timestamp for playlist: %s", + playlist_id, ) raise @@ -341,7 +342,7 @@ class PlaylistRepository(BaseRepository[Playlist]): include_stats: bool = False, # noqa: FBT001, FBT002 limit: int | None = None, offset: int = 0, - favorites_only: bool = False, + favorites_only: bool = False, # noqa: FBT001, FBT002 current_user_id: int | None = None, *, return_count: bool = False, @@ -395,9 +396,13 @@ class PlaylistRepository(BaseRepository[Playlist]): # Apply favorites filter if favorites_only and current_user_id is not None: # Use EXISTS subquery to avoid JOIN conflicts with GROUP BY - favorites_subquery = select(1).select_from(Favorite).where( - Favorite.user_id == current_user_id, - Favorite.playlist_id == Playlist.id, + favorites_subquery = ( + select(1) + .select_from(Favorite) + .where( + Favorite.user_id == current_user_id, + Favorite.playlist_id == Playlist.id, + ) ) subquery = subquery.where(favorites_subquery.exists()) @@ -466,9 +471,13 @@ class PlaylistRepository(BaseRepository[Playlist]): # Apply favorites filter if favorites_only and current_user_id is not None: # Use EXISTS subquery to avoid JOIN conflicts with GROUP BY - favorites_subquery = select(1).select_from(Favorite).where( - Favorite.user_id == current_user_id, - Favorite.playlist_id == Playlist.id, + favorites_subquery = ( + select(1) + .select_from(Favorite) + .where( + Favorite.user_id == current_user_id, + Favorite.playlist_id == Playlist.id, + ) ) subquery = subquery.where(favorites_subquery.exists()) diff --git a/app/repositories/sound.py b/app/repositories/sound.py index 8c8f76e..9209fc4 100644 --- a/app/repositories/sound.py +++ b/app/repositories/sound.py @@ -141,7 +141,7 @@ class SoundRepository(BaseRepository[Sound]): sort_order: SortOrder = SortOrder.ASC, limit: int | None = None, offset: int = 0, - favorites_only: bool = False, + favorites_only: bool = False, # noqa: FBT001, FBT002 user_id: int | None = None, ) -> list[Sound]: """Search and sort sounds with optional filtering.""" @@ -189,7 +189,8 @@ class SoundRepository(BaseRepository[Sound]): logger.exception( ( "Failed to search and sort sounds: " - "query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, user_id=%s" + "query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, " + "user_id=%s" ), search_query, sound_types, @@ -288,8 +289,7 @@ class SoundRepository(BaseRepository[Sound]): # Group by sound and order by play count descending statement = ( - statement - .group_by( + statement.group_by( Sound.id, Sound.name, Sound.type, diff --git a/app/repositories/user.py b/app/repositories/user.py index e58a966..c7292a6 100644 --- a/app/repositories/user.py +++ b/app/repositories/user.py @@ -1,7 +1,7 @@ """User repository.""" -from typing import Any from enum import Enum +from typing import Any from sqlalchemy import func from sqlalchemy.orm import selectinload @@ -18,6 +18,7 @@ logger = get_logger(__name__) class UserSortField(str, Enum): """User sort fields.""" + NAME = "name" EMAIL = "email" ROLE = "role" @@ -27,12 +28,14 @@ class UserSortField(str, Enum): class SortOrder(str, Enum): """Sort order.""" + ASC = "asc" DESC = "desc" class UserStatus(str, Enum): """User status filter.""" + ALL = "all" ACTIVE = "active" INACTIVE = "inactive" @@ -64,7 +67,7 @@ class UserRepository(BaseRepository[User]): logger.exception("Failed to get all users with plan") raise - async def get_all_with_plan_paginated( + async def get_all_with_plan_paginated( # noqa: PLR0913 self, page: int = 1, limit: int = 50, @@ -77,21 +80,20 @@ class UserRepository(BaseRepository[User]): try: # Calculate offset offset = (page - 1) * limit - + # Build base query base_query = select(User).options(selectinload(User.plan)) count_query = select(func.count(User.id)) - + # Apply search filter if search and search.strip(): search_pattern = f"%{search.strip().lower()}%" - search_condition = ( - func.lower(User.name).like(search_pattern) | - func.lower(User.email).like(search_pattern) - ) + search_condition = func.lower(User.name).like( + search_pattern, + ) | func.lower(User.email).like(search_pattern) base_query = base_query.where(search_condition) count_query = count_query.where(search_condition) - + # Apply status filter if status_filter == UserStatus.ACTIVE: base_query = base_query.where(User.is_active == True) # noqa: E712 @@ -99,47 +101,34 @@ class UserRepository(BaseRepository[User]): elif status_filter == UserStatus.INACTIVE: base_query = base_query.where(User.is_active == False) # noqa: E712 count_query = count_query.where(User.is_active == False) # noqa: E712 - + # Apply sorting - if sort_by == UserSortField.EMAIL: - if sort_order == SortOrder.DESC: - base_query = base_query.order_by(User.email.desc()) - else: - base_query = base_query.order_by(User.email.asc()) - elif sort_by == UserSortField.ROLE: - if sort_order == SortOrder.DESC: - base_query = base_query.order_by(User.role.desc()) - else: - base_query = base_query.order_by(User.role.asc()) - elif sort_by == UserSortField.CREDITS: - if sort_order == SortOrder.DESC: - base_query = base_query.order_by(User.credits.desc()) - else: - base_query = base_query.order_by(User.credits.asc()) - elif sort_by == UserSortField.CREATED_AT: - if sort_order == SortOrder.DESC: - base_query = base_query.order_by(User.created_at.desc()) - else: - base_query = base_query.order_by(User.created_at.asc()) - else: # Default to name - if sort_order == SortOrder.DESC: - base_query = base_query.order_by(User.name.desc()) - else: - base_query = base_query.order_by(User.name.asc()) - + sort_column = { + UserSortField.NAME: User.name, + UserSortField.EMAIL: User.email, + UserSortField.ROLE: User.role, + UserSortField.CREDITS: User.credits, + UserSortField.CREATED_AT: User.created_at, + }.get(sort_by, User.name) + + if sort_order == SortOrder.DESC: + base_query = base_query.order_by(sort_column.desc()) + else: + base_query = base_query.order_by(sort_column.asc()) + # Get total count count_result = await self.session.exec(count_query) total_count = count_result.one() - + # Apply pagination and get results paginated_query = base_query.limit(limit).offset(offset) result = await self.session.exec(paginated_query) users = list(result.all()) - - return users, total_count except Exception: logger.exception("Failed to get paginated users with plan") raise + else: + return users, total_count async def get_by_id_with_plan(self, entity_id: int) -> User | None: """Get a user by ID with plan relationship loaded.""" @@ -178,7 +167,7 @@ class UserRepository(BaseRepository[User]): logger.exception("Failed to get user by API token") raise - async def create(self, user_data: dict[str, Any]) -> User: + async def create(self, entity_data: dict[str, Any]) -> User: """Create a new user with plan assignment and first user admin logic.""" def _raise_plan_not_found() -> None: @@ -194,7 +183,7 @@ class UserRepository(BaseRepository[User]): if is_first_user: # First user gets admin role and pro plan plan_statement = select(Plan).where(Plan.code == "pro") - user_data["role"] = "admin" + entity_data["role"] = "admin" logger.info("Creating first user with admin role and pro plan") else: # Regular users get free plan @@ -210,11 +199,11 @@ class UserRepository(BaseRepository[User]): assert default_plan is not None # noqa: S101 # Set plan_id and default credits - user_data["plan_id"] = default_plan.id - user_data["credits"] = default_plan.credits + entity_data["plan_id"] = default_plan.id + entity_data["credits"] = default_plan.credits # Use BaseRepository's create method - return await super().create(user_data) + return await super().create(entity_data) except Exception: logger.exception("Failed to create user") raise diff --git a/app/schemas/auth.py b/app/schemas/auth.py index 60614c7..c1db2e8 100644 --- a/app/schemas/auth.py +++ b/app/schemas/auth.py @@ -85,7 +85,8 @@ class ChangePasswordRequest(BaseModel): """Schema for password change request.""" current_password: str | None = Field( - None, description="Current password (required if user has existing password)", + None, + description="Current password (required if user has existing password)", ) new_password: str = Field( ..., @@ -98,5 +99,8 @@ class UpdateProfileRequest(BaseModel): """Schema for profile update request.""" name: str | None = Field( - None, min_length=1, max_length=100, description="User display name", + None, + min_length=1, + max_length=100, + description="User display name", ) diff --git a/app/schemas/favorite.py b/app/schemas/favorite.py index 1e6bd67..2e0dcb3 100644 --- a/app/schemas/favorite.py +++ b/app/schemas/favorite.py @@ -11,10 +11,12 @@ class FavoriteResponse(BaseModel): id: int = Field(description="Favorite ID") user_id: int = Field(description="User ID") sound_id: int | None = Field( - description="Sound ID if this is a sound favorite", default=None, + description="Sound ID if this is a sound favorite", + default=None, ) playlist_id: int | None = Field( - description="Playlist ID if this is a playlist favorite", default=None, + description="Playlist ID if this is a playlist favorite", + default=None, ) created_at: datetime = Field(description="Creation timestamp") updated_at: datetime = Field(description="Last update timestamp") diff --git a/app/schemas/playlist.py b/app/schemas/playlist.py index 1d66d06..a0cc349 100644 --- a/app/schemas/playlist.py +++ b/app/schemas/playlist.py @@ -39,14 +39,19 @@ class PlaylistResponse(BaseModel): updated_at: str | None @classmethod - def from_playlist(cls, playlist: Playlist, is_favorited: bool = False, favorite_count: int = 0) -> "PlaylistResponse": + def from_playlist( + cls, + playlist: Playlist, + is_favorited: bool = False, # noqa: FBT001, FBT002 + favorite_count: int = 0, + ) -> "PlaylistResponse": """Create response from playlist model. - + Args: playlist: The Playlist model is_favorited: Whether the playlist is favorited by the current user favorite_count: Number of users who favorited this playlist - + Returns: PlaylistResponse instance diff --git a/app/schemas/sound.py b/app/schemas/sound.py index 3618701..ebe9fb9 100644 --- a/app/schemas/sound.py +++ b/app/schemas/sound.py @@ -18,16 +18,20 @@ class SoundResponse(BaseModel): size: int = Field(description="File size in bytes") hash: str = Field(description="File hash") normalized_filename: str | None = Field( - description="Normalized filename", default=None, + description="Normalized filename", + default=None, ) normalized_duration: int | None = Field( - description="Normalized duration in milliseconds", default=None, + description="Normalized duration in milliseconds", + default=None, ) normalized_size: int | None = Field( - description="Normalized file size in bytes", default=None, + description="Normalized file size in bytes", + default=None, ) normalized_hash: str | None = Field( - description="Normalized file hash", default=None, + description="Normalized file hash", + default=None, ) thumbnail: str | None = Field(description="Thumbnail filename", default=None) play_count: int = Field(description="Number of times played") @@ -35,10 +39,12 @@ class SoundResponse(BaseModel): is_music: bool = Field(description="Whether the sound is music") is_deletable: bool = Field(description="Whether the sound can be deleted") is_favorited: bool = Field( - description="Whether the sound is favorited by the current user", default=False, + description="Whether the sound is favorited by the current user", + default=False, ) favorite_count: int = Field( - description="Number of users who favorited this sound", default=0, + description="Number of users who favorited this sound", + default=0, ) created_at: datetime = Field(description="Creation timestamp") updated_at: datetime = Field(description="Last update timestamp") @@ -50,7 +56,10 @@ class SoundResponse(BaseModel): @classmethod def from_sound( - cls, sound: Sound, is_favorited: bool = False, favorite_count: int = 0, + cls, + sound: Sound, + is_favorited: bool = False, # noqa: FBT001, FBT002 + favorite_count: int = 0, ) -> "SoundResponse": """Create a SoundResponse from a Sound model. @@ -64,7 +73,8 @@ class SoundResponse(BaseModel): """ if sound.id is None: - raise ValueError("Sound ID cannot be None") + msg = "Sound ID cannot be None" + raise ValueError(msg) return cls( id=sound.id, diff --git a/app/schemas/user.py b/app/schemas/user.py index 853a7ae..2a79ef9 100644 --- a/app/schemas/user.py +++ b/app/schemas/user.py @@ -7,7 +7,10 @@ class UserUpdate(BaseModel): """Schema for updating a user.""" name: str | None = Field( - None, min_length=1, max_length=100, description="User full name", + None, + min_length=1, + max_length=100, + description="User full name", ) plan_id: int | None = Field(None, description="User plan ID") credits: int | None = Field(None, ge=0, description="User credits") diff --git a/app/services/auth.py b/app/services/auth.py index c6c89c0..46f902f 100644 --- a/app/services/auth.py +++ b/app/services/auth.py @@ -454,7 +454,10 @@ class AuthService: return user async def change_user_password( - self, user: User, current_password: str | None, new_password: str, + self, + user: User, + current_password: str | None, + new_password: str, ) -> None: """Change user's password.""" # Store user email before any operations to avoid session detachment issues @@ -484,8 +487,11 @@ class AuthService: self.session.add(user) await self.session.commit() - logger.info("Password %s successfully for user: %s", - "changed" if had_existing_password else "set", user_email) + logger.info( + "Password %s successfully for user: %s", + "changed" if had_existing_password else "set", + user_email, + ) async def user_to_response(self, user: User) -> UserResponse: """Convert User model to UserResponse with plan information.""" diff --git a/app/services/dashboard.py b/app/services/dashboard.py index a034f0f..ca62183 100644 --- a/app/services/dashboard.py +++ b/app/services/dashboard.py @@ -72,9 +72,7 @@ class DashboardService: "play_count": sound["play_count"], "duration": sound["duration"], "created_at": ( - sound["created_at"].isoformat() - if sound["created_at"] - else None + sound["created_at"].isoformat() if sound["created_at"] else None ), } for sound in top_sounds diff --git a/app/services/extraction.py b/app/services/extraction.py index bb648a1..fd7ce52 100644 --- a/app/services/extraction.py +++ b/app/services/extraction.py @@ -532,7 +532,8 @@ class ExtractionService: """Add the sound to the user's main playlist.""" try: await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001 - sound_id, user_id, + sound_id, + user_id, ) logger.info( "Added sound %d to main playlist for user %d", @@ -554,6 +555,10 @@ class ExtractionService: if not extraction: return None + # Get user information + user = await self.user_repo.get_by_id(extraction.user_id) + user_name = user.name if user else None + return { "id": extraction.id or 0, # Should never be None for existing extraction "url": extraction.url, @@ -564,11 +569,12 @@ class ExtractionService: "error": extraction.error, "sound_id": extraction.sound_id, "user_id": extraction.user_id, + "user_name": user_name, "created_at": extraction.created_at.isoformat(), "updated_at": extraction.updated_at.isoformat(), } - async def get_user_extractions( + async def get_user_extractions( # noqa: PLR0913 self, user_id: int, search: str | None = None, @@ -580,7 +586,10 @@ class ExtractionService: ) -> PaginatedExtractionsResponse: """Get all extractions for a user with filtering, search, and sorting.""" offset = (page - 1) * limit - extraction_user_tuples, total_count = await self.extraction_repo.get_user_extractions_filtered( + ( + extraction_user_tuples, + total_count, + ) = await self.extraction_repo.get_user_extractions_filtered( user_id=user_id, search=search, sort_by=sort_by, @@ -619,7 +628,7 @@ class ExtractionService: "total_pages": total_pages, } - async def get_all_extractions( + async def get_all_extractions( # noqa: PLR0913 self, search: str | None = None, sort_by: str = "created_at", @@ -630,7 +639,10 @@ class ExtractionService: ) -> PaginatedExtractionsResponse: """Get all extractions with filtering, search, and sorting.""" offset = (page - 1) * limit - extraction_user_tuples, total_count = await self.extraction_repo.get_all_extractions_filtered( + ( + extraction_user_tuples, + total_count, + ) = await self.extraction_repo.get_all_extractions_filtered( search=search, sort_by=sort_by, sort_order=sort_order, diff --git a/app/services/favorite.py b/app/services/favorite.py index 26fab4a..870b7a7 100644 --- a/app/services/favorite.py +++ b/app/services/favorite.py @@ -49,12 +49,14 @@ class FavoriteService: # Verify user exists user = await user_repo.get_by_id(user_id) if not user: - raise ValueError(f"User with ID {user_id} not found") + msg = f"User with ID {user_id} not found" + raise ValueError(msg) # Verify sound exists sound = await sound_repo.get_by_id(sound_id) if not sound: - raise ValueError(f"Sound with ID {sound_id} not found") + msg = f"Sound with ID {sound_id} not found" + raise ValueError(msg) # Get data for the event immediately after loading sound_name = sound.name @@ -63,9 +65,8 @@ class FavoriteService: # Check if already favorited existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id) if existing: - raise ValueError( - f"Sound {sound_id} is already favorited by user {user_id}", - ) + msg = f"Sound {sound_id} is already favorited by user {user_id}" + raise ValueError(msg) # Create favorite favorite_data = { @@ -120,12 +121,14 @@ class FavoriteService: # Verify user exists user = await user_repo.get_by_id(user_id) if not user: - raise ValueError(f"User with ID {user_id} not found") + msg = f"User with ID {user_id} not found" + raise ValueError(msg) # Verify playlist exists playlist = await playlist_repo.get_by_id(playlist_id) if not playlist: - raise ValueError(f"Playlist with ID {playlist_id} not found") + msg = f"Playlist with ID {playlist_id} not found" + raise ValueError(msg) # Check if already favorited existing = await favorite_repo.get_by_user_and_playlist( @@ -133,9 +136,8 @@ class FavoriteService: playlist_id, ) if existing: - raise ValueError( - f"Playlist {playlist_id} is already favorited by user {user_id}", - ) + msg = f"Playlist {playlist_id} is already favorited by user {user_id}" + raise ValueError(msg) # Create favorite favorite_data = { @@ -163,7 +165,8 @@ class FavoriteService: favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id) if not favorite: - raise ValueError(f"Sound {sound_id} is not favorited by user {user_id}") + msg = f"Sound {sound_id} is not favorited by user {user_id}" + raise ValueError(msg) # Get user and sound info before deletion for the event user_repo = UserRepository(session) @@ -192,7 +195,8 @@ class FavoriteService: } await socket_manager.broadcast_to_all("sound_favorited", event_data) logger.info( - "Broadcasted sound_favorited event for sound %s removal", sound_id, + "Broadcasted sound_favorited event for sound %s removal", + sound_id, ) except Exception: logger.exception( @@ -219,9 +223,8 @@ class FavoriteService: playlist_id, ) if not favorite: - raise ValueError( - f"Playlist {playlist_id} is not favorited by user {user_id}", - ) + msg = f"Playlist {playlist_id} is not favorited by user {user_id}" + raise ValueError(msg) await favorite_repo.delete(favorite) logger.info( diff --git a/app/services/playlist.py b/app/services/playlist.py index f7ee0d5..ab3eba8 100644 --- a/app/services/playlist.py +++ b/app/services/playlist.py @@ -16,6 +16,7 @@ logger = get_logger(__name__) class PaginatedPlaylistsResponse(TypedDict): """Response type for paginated playlists.""" + playlists: list[dict] total: int page: int @@ -286,7 +287,7 @@ class PlaylistService: ) -> PaginatedPlaylistsResponse: """Search and sort playlists with pagination.""" offset = (page - 1) * limit - + playlists, total_count = await self.playlist_repo.search_and_sort( search_query=search_query, sort_by=sort_by, @@ -299,9 +300,9 @@ class PlaylistService: current_user_id=current_user_id, return_count=True, ) - + total_pages = (total_count + limit - 1) // limit # Ceiling division - + return PaginatedPlaylistsResponse( playlists=playlists, total=total_count, @@ -468,7 +469,9 @@ class PlaylistService: } async def add_sound_to_main_playlist( - self, sound_id: int, user_id: int, # noqa: ARG002 + self, + sound_id: int, # noqa: ARG002 + user_id: int, # noqa: ARG002 ) -> None: """Add a sound to the global main playlist.""" raise HTTPException( @@ -477,7 +480,9 @@ class PlaylistService: ) async def _add_sound_to_main_playlist_internal( - self, sound_id: int, user_id: int, + self, + sound_id: int, + user_id: int, ) -> None: """Add sound to main playlist bypassing restrictions. diff --git a/tests/api/v1/admin/test_users_endpoints.py b/tests/api/v1/admin/test_users_endpoints.py index b01c812..e59f123 100644 --- a/tests/api/v1/admin/test_users_endpoints.py +++ b/tests/api/v1/admin/test_users_endpoints.py @@ -21,8 +21,6 @@ def mock_plan_repository(): return Mock() - - @pytest.fixture def regular_user(): """Create regular user for testing.""" @@ -60,52 +58,78 @@ class TestAdminUserEndpoints: test_plan: Plan, ) -> None: """Test listing users successfully.""" - with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all: + with patch( + "app.repositories.user.UserRepository.get_all_with_plan_paginated", + ) as mock_get_all: # Create mock user objects that don't trigger database saves - mock_admin = type("User", (), { - "id": admin_user.id, - "email": admin_user.email, - "name": admin_user.name, - "picture": None, - "role": admin_user.role, - "credits": admin_user.credits, - "is_active": admin_user.is_active, - "created_at": admin_user.created_at, - "updated_at": admin_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() + mock_admin = type( + "User", + (), + { + "id": admin_user.id, + "email": admin_user.email, + "name": admin_user.name, + "picture": None, + "role": admin_user.role, + "credits": admin_user.credits, + "is_active": admin_user.is_active, + "created_at": admin_user.created_at, + "updated_at": admin_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() - mock_regular = type("User", (), { - "id": regular_user.id, - "email": regular_user.email, - "name": regular_user.name, - "picture": None, - "role": regular_user.role, - "credits": regular_user.credits, - "is_active": regular_user.is_active, - "created_at": regular_user.created_at, - "updated_at": regular_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() + mock_regular = type( + "User", + (), + { + "id": regular_user.id, + "email": regular_user.email, + "name": regular_user.name, + "picture": None, + "role": regular_user.role, + "credits": regular_user.credits, + "is_active": regular_user.is_active, + "created_at": regular_user.created_at, + "updated_at": regular_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() - mock_get_all.return_value = [mock_admin, mock_regular] + # Mock returns tuple (users, total_count) + mock_get_all.return_value = ([mock_admin, mock_regular], 2) response = await authenticated_admin_client.get("/api/v1/admin/users/") assert response.status_code == 200 data = response.json() - assert len(data) == 2 - assert data[0]["email"] == "admin@example.com" - assert data[1]["email"] == "user@example.com" - mock_get_all.assert_called_once_with(limit=100, offset=0) + assert "users" in data + assert "total" in data + assert "page" in data + assert "limit" in data + assert "total_pages" in data + assert len(data["users"]) == 2 + assert data["users"][0]["email"] == "admin@example.com" + assert data["users"][1]["email"] == "user@example.com" + assert data["total"] == 2 + assert data["page"] == 1 + assert data["limit"] == 50 @pytest.mark.asyncio async def test_list_users_with_pagination( @@ -115,29 +139,55 @@ class TestAdminUserEndpoints: test_plan: Plan, ) -> None: """Test listing users with pagination.""" - with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all: - mock_admin = type("User", (), { - "id": admin_user.id, - "email": admin_user.email, - "name": admin_user.name, - "picture": None, - "role": admin_user.role, - "credits": admin_user.credits, - "is_active": admin_user.is_active, - "created_at": admin_user.created_at, - "updated_at": admin_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() - mock_get_all.return_value = [mock_admin] + from app.repositories.user import SortOrder, UserSortField, UserStatus - response = await authenticated_admin_client.get("/api/v1/admin/users/?limit=10&offset=5") + with patch( + "app.repositories.user.UserRepository.get_all_with_plan_paginated", + ) as mock_get_all: + mock_admin = type( + "User", + (), + { + "id": admin_user.id, + "email": admin_user.email, + "name": admin_user.name, + "picture": None, + "role": admin_user.role, + "credits": admin_user.credits, + "is_active": admin_user.is_active, + "created_at": admin_user.created_at, + "updated_at": admin_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() + # Mock returns tuple (users, total_count) + mock_get_all.return_value = ([mock_admin], 1) + + response = await authenticated_admin_client.get( + "/api/v1/admin/users/?page=2&limit=10", + ) assert response.status_code == 200 - mock_get_all.assert_called_once_with(limit=10, offset=5) + data = response.json() + assert "users" in data + assert data["page"] == 2 + assert data["limit"] == 10 + mock_get_all.assert_called_once_with( + page=2, + limit=10, + search=None, + sort_by=UserSortField.NAME, + sort_order=SortOrder.ASC, + status_filter=UserStatus.ALL, + ) @pytest.mark.asyncio async def test_list_users_unauthenticated(self, client: AsyncClient) -> None: @@ -153,7 +203,9 @@ class TestAdminUserEndpoints: regular_user: User, ) -> None: """Test listing users as non-admin user.""" - with patch("app.core.dependencies.get_current_active_user", return_value=regular_user): + with patch( + "app.core.dependencies.get_current_active_user", return_value=regular_user, + ): response = await client.get("/api/v1/admin/users/") assert response.status_code == 401 @@ -169,24 +221,34 @@ class TestAdminUserEndpoints: """Test getting specific user successfully.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + ) as mock_get_by_id, ): - mock_user = type("User", (), { - "id": regular_user.id, - "email": regular_user.email, - "name": regular_user.name, - "picture": None, - "role": regular_user.role, - "credits": regular_user.credits, - "is_active": regular_user.is_active, - "created_at": regular_user.created_at, - "updated_at": regular_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() + mock_user = type( + "User", + (), + { + "id": regular_user.id, + "email": regular_user.email, + "name": regular_user.name, + "picture": None, + "role": regular_user.role, + "credits": regular_user.credits, + "is_active": regular_user.is_active, + "created_at": regular_user.created_at, + "updated_at": regular_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() mock_get_by_id.return_value = mock_user response = await authenticated_admin_client.get("/api/v1/admin/users/2") @@ -207,7 +269,10 @@ class TestAdminUserEndpoints: """Test getting non-existent user.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + return_value=None, + ), ): response = await authenticated_admin_client.get("/api/v1/admin/users/999") @@ -226,43 +291,63 @@ class TestAdminUserEndpoints: """Test updating user successfully.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + ) as mock_get_by_id, patch("app.repositories.user.UserRepository.update") as mock_update, - patch("app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan), + patch( + "app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan, + ), ): - mock_user = type("User", (), { - "id": regular_user.id, - "email": regular_user.email, - "name": regular_user.name, - "picture": None, - "role": regular_user.role, - "credits": regular_user.credits, - "is_active": regular_user.is_active, - "created_at": regular_user.created_at, - "updated_at": regular_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() + mock_user = type( + "User", + (), + { + "id": regular_user.id, + "email": regular_user.email, + "name": regular_user.name, + "picture": None, + "role": regular_user.role, + "credits": regular_user.credits, + "is_active": regular_user.is_active, + "created_at": regular_user.created_at, + "updated_at": regular_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() - updated_mock = type("User", (), { - "id": regular_user.id, - "email": regular_user.email, - "name": "Updated Name", - "picture": None, - "role": regular_user.role, - "credits": 200, - "is_active": regular_user.is_active, - "created_at": regular_user.created_at, - "updated_at": regular_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() + updated_mock = type( + "User", + (), + { + "id": regular_user.id, + "email": regular_user.email, + "name": "Updated Name", + "picture": None, + "role": regular_user.role, + "credits": 200, + "is_active": regular_user.is_active, + "created_at": regular_user.created_at, + "updated_at": regular_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() mock_get_by_id.return_value = mock_user mock_update.return_value = updated_mock @@ -271,7 +356,10 @@ class TestAdminUserEndpoints: async def mock_refresh(instance, attributes=None): pass - with patch("sqlmodel.ext.asyncio.session.AsyncSession.refresh", side_effect=mock_refresh): + with patch( + "sqlmodel.ext.asyncio.session.AsyncSession.refresh", + side_effect=mock_refresh, + ): response = await authenticated_admin_client.patch( "/api/v1/admin/users/2", json={ @@ -295,7 +383,10 @@ class TestAdminUserEndpoints: """Test updating non-existent user.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + return_value=None, + ), ): response = await authenticated_admin_client.patch( "/api/v1/admin/users/999", @@ -316,25 +407,35 @@ class TestAdminUserEndpoints: """Test updating user with invalid plan.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + ) as mock_get_by_id, patch("app.repositories.plan.PlanRepository.get_by_id", return_value=None), ): - mock_user = type("User", (), { - "id": regular_user.id, - "email": regular_user.email, - "name": regular_user.name, - "picture": None, - "role": regular_user.role, - "credits": regular_user.credits, - "is_active": regular_user.is_active, - "created_at": regular_user.created_at, - "updated_at": regular_user.updated_at, - "plan": type("Plan", (), { - "id": 1, - "name": "Basic", - "max_credits": 100, - })(), - })() + mock_user = type( + "User", + (), + { + "id": regular_user.id, + "email": regular_user.email, + "name": regular_user.name, + "picture": None, + "role": regular_user.role, + "credits": regular_user.credits, + "is_active": regular_user.is_active, + "created_at": regular_user.created_at, + "updated_at": regular_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": 1, + "name": "Basic", + "max_credits": 100, + }, + )(), + }, + )() mock_get_by_id.return_value = mock_user response = await authenticated_admin_client.patch( "/api/v1/admin/users/2", @@ -356,29 +457,41 @@ class TestAdminUserEndpoints: """Test disabling user successfully.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + ) as mock_get_by_id, patch("app.repositories.user.UserRepository.update") as mock_update, ): - mock_user = type("User", (), { - "id": regular_user.id, - "email": regular_user.email, - "name": regular_user.name, - "picture": None, - "role": regular_user.role, - "credits": regular_user.credits, - "is_active": regular_user.is_active, - "created_at": regular_user.created_at, - "updated_at": regular_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() + mock_user = type( + "User", + (), + { + "id": regular_user.id, + "email": regular_user.email, + "name": regular_user.name, + "picture": None, + "role": regular_user.role, + "credits": regular_user.credits, + "is_active": regular_user.is_active, + "created_at": regular_user.created_at, + "updated_at": regular_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() mock_get_by_id.return_value = mock_user mock_update.return_value = mock_user - response = await authenticated_admin_client.post("/api/v1/admin/users/2/disable") + response = await authenticated_admin_client.post( + "/api/v1/admin/users/2/disable", + ) assert response.status_code == 200 data = response.json() @@ -393,9 +506,14 @@ class TestAdminUserEndpoints: """Test disabling non-existent user.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + return_value=None, + ), ): - response = await authenticated_admin_client.post("/api/v1/admin/users/999/disable") + response = await authenticated_admin_client.post( + "/api/v1/admin/users/999/disable", + ) assert response.status_code == 404 data = response.json() @@ -421,29 +539,41 @@ class TestAdminUserEndpoints: with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + ) as mock_get_by_id, patch("app.repositories.user.UserRepository.update") as mock_update, ): - mock_disabled_user = type("User", (), { - "id": disabled_user.id, - "email": disabled_user.email, - "name": disabled_user.name, - "picture": None, - "role": disabled_user.role, - "credits": disabled_user.credits, - "is_active": disabled_user.is_active, - "created_at": disabled_user.created_at, - "updated_at": disabled_user.updated_at, - "plan": type("Plan", (), { - "id": test_plan.id, - "name": test_plan.name, - "max_credits": test_plan.max_credits, - })(), - })() + mock_disabled_user = type( + "User", + (), + { + "id": disabled_user.id, + "email": disabled_user.email, + "name": disabled_user.name, + "picture": None, + "role": disabled_user.role, + "credits": disabled_user.credits, + "is_active": disabled_user.is_active, + "created_at": disabled_user.created_at, + "updated_at": disabled_user.updated_at, + "plan": type( + "Plan", + (), + { + "id": test_plan.id, + "name": test_plan.name, + "max_credits": test_plan.max_credits, + }, + )(), + }, + )() mock_get_by_id.return_value = mock_disabled_user mock_update.return_value = mock_disabled_user - response = await authenticated_admin_client.post("/api/v1/admin/users/3/enable") + response = await authenticated_admin_client.post( + "/api/v1/admin/users/3/enable", + ) assert response.status_code == 200 data = response.json() @@ -458,9 +588,14 @@ class TestAdminUserEndpoints: """Test enabling non-existent user.""" with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), + patch( + "app.repositories.user.UserRepository.get_by_id_with_plan", + return_value=None, + ), ): - response = await authenticated_admin_client.post("/api/v1/admin/users/999/enable") + response = await authenticated_admin_client.post( + "/api/v1/admin/users/999/enable", + ) assert response.status_code == 404 data = response.json() @@ -479,9 +614,14 @@ class TestAdminUserEndpoints: with ( patch("app.core.dependencies.get_admin_user", return_value=admin_user), - patch("app.repositories.plan.PlanRepository.get_all", return_value=[basic_plan, premium_plan]), + patch( + "app.repositories.plan.PlanRepository.get_all", + return_value=[basic_plan, premium_plan], + ), ): - response = await authenticated_admin_client.get("/api/v1/admin/users/plans/list") + response = await authenticated_admin_client.get( + "/api/v1/admin/users/plans/list", + ) assert response.status_code == 200 data = response.json() diff --git a/tests/api/v1/test_auth_endpoints.py b/tests/api/v1/test_auth_endpoints.py index bbd6f8d..13f0bfa 100644 --- a/tests/api/v1/test_auth_endpoints.py +++ b/tests/api/v1/test_auth_endpoints.py @@ -488,11 +488,17 @@ class TestAuthEndpoints: test_plan: Plan, ) -> None: """Test refresh token success.""" - with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh: - mock_refresh.return_value = type("TokenResponse", (), { - "access_token": "new_access_token", - "expires_in": 3600, - })() + with patch( + "app.services.auth.AuthService.refresh_access_token", + ) as mock_refresh: + mock_refresh.return_value = type( + "TokenResponse", + (), + { + "access_token": "new_access_token", + "expires_in": 3600, + }, + )() response = await test_client.post( "/api/v1/auth/refresh", @@ -516,7 +522,9 @@ class TestAuthEndpoints: @pytest.mark.asyncio async def test_refresh_token_service_error(self, test_client: AsyncClient) -> None: """Test refresh token with service error.""" - with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh: + with patch( + "app.services.auth.AuthService.refresh_access_token", + ) as mock_refresh: mock_refresh.side_effect = Exception("Database error") response = await test_client.post( @@ -528,7 +536,6 @@ class TestAuthEndpoints: data = response.json() assert "Token refresh failed" in data["detail"] - @pytest.mark.asyncio async def test_exchange_oauth_token_invalid_code( self, @@ -554,7 +561,9 @@ class TestAuthEndpoints: """Test update profile success.""" with ( patch("app.services.auth.AuthService.update_user_profile") as mock_update, - patch("app.services.auth.AuthService.user_to_response") as mock_user_to_response, + patch( + "app.services.auth.AuthService.user_to_response", + ) as mock_user_to_response, ): updated_user = User( id=test_user.id, @@ -569,6 +578,7 @@ class TestAuthEndpoints: # Mock the user_to_response to return UserResponse format from app.schemas.auth import UserResponse + mock_user_to_response.return_value = UserResponse( id=test_user.id, email=test_user.email, @@ -598,7 +608,9 @@ class TestAuthEndpoints: assert data["name"] == "Updated Name" @pytest.mark.asyncio - async def test_update_profile_unauthenticated(self, test_client: AsyncClient) -> None: + async def test_update_profile_unauthenticated( + self, test_client: AsyncClient, + ) -> None: """Test update profile without authentication.""" response = await test_client.patch( "/api/v1/auth/me", @@ -632,7 +644,9 @@ class TestAuthEndpoints: assert data["message"] == "Password changed successfully" @pytest.mark.asyncio - async def test_change_password_unauthenticated(self, test_client: AsyncClient) -> None: + async def test_change_password_unauthenticated( + self, test_client: AsyncClient, + ) -> None: """Test change password without authentication.""" response = await test_client.post( "/api/v1/auth/change-password", @@ -652,7 +666,9 @@ class TestAuthEndpoints: auth_cookies: dict[str, str], ) -> None: """Test get user OAuth providers success.""" - with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers: + with patch( + "app.services.auth.AuthService.get_user_oauth_providers", + ) as mock_providers: from datetime import datetime from app.models.user_oauth import UserOauth @@ -699,7 +715,9 @@ class TestAuthEndpoints: assert data[2]["display_name"] == "GitHub" @pytest.mark.asyncio - async def test_get_user_providers_unauthenticated(self, test_client: AsyncClient) -> None: + async def test_get_user_providers_unauthenticated( + self, test_client: AsyncClient, + ) -> None: """Test get user OAuth providers without authentication.""" response = await test_client.get("/api/v1/auth/user-providers") diff --git a/tests/api/v1/test_playlist_endpoints.py b/tests/api/v1/test_playlist_endpoints.py index fa6bc7b..4d4e0f0 100644 --- a/tests/api/v1/test_playlist_endpoints.py +++ b/tests/api/v1/test_playlist_endpoints.py @@ -109,9 +109,15 @@ class TestPlaylistEndpoints: assert response.status_code == 200 data = response.json() - assert len(data) == 2 + assert "playlists" in data + assert "total" in data + assert "page" in data + assert "limit" in data + assert "total_pages" in data + assert len(data["playlists"]) == 2 + assert data["total"] == 2 - playlist_names = {p["name"] for p in data} + playlist_names = {p["name"] for p in data["playlists"]} assert "Test Playlist" in playlist_names assert "Main Playlist" in playlist_names diff --git a/tests/api/v1/test_sound_endpoints.py b/tests/api/v1/test_sound_endpoints.py index a523c60..d4f326c 100644 --- a/tests/api/v1/test_sound_endpoints.py +++ b/tests/api/v1/test_sound_endpoints.py @@ -9,7 +9,7 @@ from httpx import AsyncClient from app.models.user import User if TYPE_CHECKING: - from app.services.extraction import ExtractionInfo + from app.services.extraction import ExtractionInfo, PaginatedExtractionsResponse class TestSoundEndpoints: @@ -32,6 +32,7 @@ class TestSoundEndpoints: "error": None, "sound_id": None, "user_id": authenticated_user.id, + "user_name": authenticated_user.name, "created_at": "2025-08-03T12:00:00Z", "updated_at": "2025-08-03T12:00:00Z", } @@ -111,6 +112,7 @@ class TestSoundEndpoints: "error": None, "sound_id": 42, "user_id": authenticated_user.id, + "user_name": authenticated_user.name, "created_at": "2025-08-03T12:00:00Z", "updated_at": "2025-08-03T12:00:00Z", } @@ -154,41 +156,49 @@ class TestSoundEndpoints: authenticated_user: User, ) -> None: """Test getting user extractions.""" - mock_extractions: list[ExtractionInfo] = [ - { - "id": 1, - "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", - "title": "Never Gonna Give You Up", - "service": "youtube", - "service_id": "dQw4w9WgXcQ", - "status": "completed", - "error": None, - "sound_id": 42, - "user_id": authenticated_user.id, - "created_at": "2025-08-03T12:00:00Z", - "updated_at": "2025-08-03T12:00:00Z", - }, - { - "id": 2, - "url": "https://soundcloud.com/example/track", - "title": "Example Track", - "service": "soundcloud", - "service_id": "example-track", - "status": "pending", - "error": None, - "sound_id": None, - "user_id": authenticated_user.id, - "created_at": "2025-08-03T12:00:00Z", - "updated_at": "2025-08-03T12:00:00Z", - }, - ] + mock_extractions: PaginatedExtractionsResponse = { + "extractions": [ + { + "id": 1, + "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", + "title": "Never Gonna Give You Up", + "service": "youtube", + "service_id": "dQw4w9WgXcQ", + "status": "completed", + "error": None, + "sound_id": 42, + "user_id": authenticated_user.id, + "user_name": authenticated_user.name, + "created_at": "2025-08-03T12:00:00Z", + "updated_at": "2025-08-03T12:00:00Z", + }, + { + "id": 2, + "url": "https://soundcloud.com/example/track", + "title": "Example Track", + "service": "soundcloud", + "service_id": "example-track", + "status": "pending", + "error": None, + "sound_id": None, + "user_id": authenticated_user.id, + "user_name": authenticated_user.name, + "created_at": "2025-08-03T12:00:00Z", + "updated_at": "2025-08-03T12:00:00Z", + }, + ], + "total": 2, + "page": 1, + "limit": 50, + "total_pages": 1, + } with patch( "app.services.extraction.ExtractionService.get_user_extractions", ) as mock_get: mock_get.return_value = mock_extractions - response = await authenticated_client.get("/api/v1/extractions/") + response = await authenticated_client.get("/api/v1/extractions/user") assert response.status_code == 200 data = response.json() @@ -337,7 +347,9 @@ class TestSoundEndpoints: """Test getting sounds with authentication.""" from app.models.sound import Sound - with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get: + with patch( + "app.repositories.sound.SoundRepository.search_and_sort", + ) as mock_get: # Create mock sounds with all required fields mock_sound_1 = Sound( id=1, @@ -383,7 +395,9 @@ class TestSoundEndpoints: """Test getting sounds with type filtering.""" from app.models.sound import Sound - with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get: + with patch( + "app.repositories.sound.SoundRepository.search_and_sort", + ) as mock_get: # Create mock sound with all required fields mock_sound = Sound( id=1, diff --git a/tests/conftest.py b/tests/conftest.py index 680c7da..79b27a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -335,5 +335,3 @@ async def admin_cookies(admin_user: User) -> dict[str, str]: access_token = JWTUtils.create_access_token(token_data) return {"access_token": access_token} - - diff --git a/tests/repositories/test_playlist.py b/tests/repositories/test_playlist.py index 961f18e..79ed892 100644 --- a/tests/repositories/test_playlist.py +++ b/tests/repositories/test_playlist.py @@ -539,21 +539,35 @@ class TestPlaylistRepository: sound_ids = [s.id for s in sounds] # Add first two sounds sequentially (positions 0, 1) - await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0 - await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1 + await playlist_repository.add_sound_to_playlist( + playlist_id, sound_ids[0], + ) # position 0 + await playlist_repository.add_sound_to_playlist( + playlist_id, sound_ids[1], + ) # position 1 # Now insert third sound at position 1 - should shift existing sound at position 1 to position 2 - await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=1) + await playlist_repository.add_sound_to_playlist( + playlist_id, sound_ids[2], position=1, + ) # Verify the final positions - playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) + playlist_sounds = await playlist_repository.get_playlist_sound_entries( + playlist_id, + ) assert len(playlist_sounds) == 3 - assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0 + assert ( + playlist_sounds[0].sound_id == sound_ids[0] + ) # Original sound 0 stays at position 0 assert playlist_sounds[0].position == 0 - assert playlist_sounds[1].sound_id == sound_ids[2] # New sound 2 inserted at position 1 + assert ( + playlist_sounds[1].sound_id == sound_ids[2] + ) # New sound 2 inserted at position 1 assert playlist_sounds[1].position == 1 - assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2 + assert ( + playlist_sounds[2].sound_id == sound_ids[1] + ) # Original sound 1 shifted to position 2 assert playlist_sounds[2].position == 2 @pytest.mark.asyncio @@ -615,21 +629,35 @@ class TestPlaylistRepository: sound_ids = [s.id for s in sounds] # Add first two sounds sequentially (positions 0, 1) - await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0 - await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1 + await playlist_repository.add_sound_to_playlist( + playlist_id, sound_ids[0], + ) # position 0 + await playlist_repository.add_sound_to_playlist( + playlist_id, sound_ids[1], + ) # position 1 # Now insert third sound at position 0 - should shift existing sounds to positions 1, 2 - await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=0) + await playlist_repository.add_sound_to_playlist( + playlist_id, sound_ids[2], position=0, + ) # Verify the final positions - playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) + playlist_sounds = await playlist_repository.get_playlist_sound_entries( + playlist_id, + ) assert len(playlist_sounds) == 3 - assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0 + assert ( + playlist_sounds[0].sound_id == sound_ids[2] + ) # New sound 2 inserted at position 0 assert playlist_sounds[0].position == 0 - assert playlist_sounds[1].sound_id == sound_ids[0] # Original sound 0 shifted to position 1 + assert ( + playlist_sounds[1].sound_id == sound_ids[0] + ) # Original sound 0 shifted to position 1 assert playlist_sounds[1].position == 1 - assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2 + assert ( + playlist_sounds[2].sound_id == sound_ids[1] + ) # Original sound 1 shifted to position 2 assert playlist_sounds[2].position == 2 @pytest.mark.asyncio diff --git a/tests/services/test_credit.py b/tests/services/test_credit.py index c9b2ad9..73f9fb4 100644 --- a/tests/services/test_credit.py +++ b/tests/services/test_credit.py @@ -409,7 +409,9 @@ class TestCreditService: @pytest.mark.asyncio async def test_recharge_user_credits_success( - self, credit_service, sample_user, + self, + credit_service, + sample_user, ) -> None: """Test successful credit recharge for a user.""" mock_session = credit_service.db_session_factory() diff --git a/tests/services/test_dashboard.py b/tests/services/test_dashboard.py index d03c55f..fae0886 100644 --- a/tests/services/test_dashboard.py +++ b/tests/services/test_dashboard.py @@ -43,7 +43,9 @@ class TestDashboardService: "total_duration": 75000, "total_size": 1024000, } - mock_sound_repository.get_soundboard_statistics = AsyncMock(return_value=mock_stats) + mock_sound_repository.get_soundboard_statistics = AsyncMock( + return_value=mock_stats, + ) result = await dashboard_service.get_soundboard_statistics() diff --git a/tests/services/test_extraction.py b/tests/services/test_extraction.py index a225950..8640c0d 100644 --- a/tests/services/test_extraction.py +++ b/tests/services/test_extraction.py @@ -99,6 +99,11 @@ class TestExtractionService: url = "https://www.youtube.com/watch?v=test123" user_id = 1 + # Mock user for user_name retrieval + mock_user = Mock() + mock_user.name = "Test User" + extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user) + # Mock repository call - no service detection happens during creation mock_extraction = Extraction( id=1, @@ -120,6 +125,7 @@ class TestExtractionService: assert result["service_id"] is None # Not detected during creation assert result["title"] is None # Not detected during creation assert result["status"] == "pending" + assert result["user_name"] == "Test User" @pytest.mark.asyncio async def test_create_extraction_basic(self, extraction_service) -> None: @@ -127,6 +133,11 @@ class TestExtractionService: url = "https://www.youtube.com/watch?v=test123" user_id = 1 + # Mock user for user_name retrieval + mock_user = Mock() + mock_user.name = "Test User" + extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user) + # Mock repository call - creation always succeeds now mock_extraction = Extraction( id=2, @@ -146,6 +157,7 @@ class TestExtractionService: assert result["id"] == 2 assert result["url"] == url assert result["status"] == "pending" + assert result["user_name"] == "Test User" @pytest.mark.asyncio async def test_create_extraction_any_url(self, extraction_service) -> None: @@ -153,6 +165,11 @@ class TestExtractionService: url = "https://invalid.url" user_id = 1 + # Mock user for user_name retrieval + mock_user = Mock() + mock_user.name = "Test User" + extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user) + # Mock repository call - even invalid URLs are accepted during creation mock_extraction = Extraction( id=3, @@ -172,6 +189,7 @@ class TestExtractionService: assert result["id"] == 3 assert result["url"] == url assert result["status"] == "pending" + assert result["user_name"] == "Test User" @pytest.mark.asyncio async def test_process_extraction_with_service_detection( @@ -408,9 +426,16 @@ class TestExtractionService: sound_id=42, ) + # Mock user for user_name retrieval + mock_user = Mock() + mock_user.name = "Test User" + extraction_service.extraction_repo.get_by_id = AsyncMock( return_value=extraction, ) + extraction_service.user_repo.get_by_id = AsyncMock( + return_value=mock_user, + ) result = await extraction_service.get_extraction_by_id(1) @@ -421,6 +446,7 @@ class TestExtractionService: assert result["title"] == "Test Video" assert result["status"] == "completed" assert result["sound_id"] == 42 + assert result["user_name"] == "Test User" @pytest.mark.asyncio async def test_get_extraction_by_id_not_found(self, extraction_service) -> None: @@ -434,52 +460,70 @@ class TestExtractionService: @pytest.mark.asyncio async def test_get_user_extractions(self, extraction_service) -> None: """Test getting user extractions.""" + from app.models.user import User + + user = User(id=1, name="Test User", email="test@example.com") extractions = [ - Extraction( - id=1, - service="youtube", - service_id="test123", - url="https://www.youtube.com/watch?v=test123", - user_id=1, - title="Test Video 1", - status="completed", - sound_id=42, + ( + Extraction( + id=1, + service="youtube", + service_id="test123", + url="https://www.youtube.com/watch?v=test123", + user_id=1, + title="Test Video 1", + status="completed", + sound_id=42, + ), + user, ), - Extraction( - id=2, - service="youtube", - service_id="test456", - url="https://www.youtube.com/watch?v=test456", - user_id=1, - title="Test Video 2", - status="pending", + ( + Extraction( + id=2, + service="youtube", + service_id="test456", + url="https://www.youtube.com/watch?v=test456", + user_id=1, + title="Test Video 2", + status="pending", + ), + user, ), ] - extraction_service.extraction_repo.get_by_user = AsyncMock( - return_value=extractions, + extraction_service.extraction_repo.get_user_extractions_filtered = AsyncMock( + return_value=(extractions, 2), ) result = await extraction_service.get_user_extractions(1) - assert len(result) == 2 - assert result[0]["id"] == 1 - assert result[0]["title"] == "Test Video 1" - assert result[1]["id"] == 2 - assert result[1]["title"] == "Test Video 2" + assert result["total"] == 2 + assert len(result["extractions"]) == 2 + assert result["extractions"][0]["id"] == 1 + assert result["extractions"][0]["title"] == "Test Video 1" + assert result["extractions"][0]["user_name"] == "Test User" + assert result["extractions"][1]["id"] == 2 + assert result["extractions"][1]["title"] == "Test Video 2" + assert result["extractions"][1]["user_name"] == "Test User" @pytest.mark.asyncio async def test_get_pending_extractions(self, extraction_service) -> None: """Test getting pending extractions.""" + from app.models.user import User + + user = User(id=1, name="Test User", email="test@example.com") pending_extractions = [ - Extraction( - id=1, - service="youtube", - service_id="test123", - url="https://www.youtube.com/watch?v=test123", - user_id=1, - title="Pending Video", - status="pending", + ( + Extraction( + id=1, + service="youtube", + service_id="test123", + url="https://www.youtube.com/watch?v=test123", + user_id=1, + title="Pending Video", + status="pending", + ), + user, ), ] @@ -492,3 +536,4 @@ class TestExtractionService: assert len(result) == 1 assert result[0]["id"] == 1 assert result[0]["status"] == "pending" + assert result[0]["user_name"] == "Test User" diff --git a/tests/services/test_scheduler.py b/tests/services/test_scheduler.py index 333bc21..22a7975 100644 --- a/tests/services/test_scheduler.py +++ b/tests/services/test_scheduler.py @@ -25,9 +25,10 @@ class TestSchedulerService: @pytest.mark.asyncio async def test_start_scheduler(self, scheduler_service) -> None: """Test starting the scheduler service.""" - with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \ - patch.object(scheduler_service.scheduler, "start") as mock_start: - + with ( + patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, + patch.object(scheduler_service.scheduler, "start") as mock_start, + ): await scheduler_service.start() # Verify job was added @@ -61,7 +62,9 @@ class TestSchedulerService: "total_credits_added": 500, } - with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge: + with patch.object( + scheduler_service.credit_service, "recharge_all_users_credits", + ) as mock_recharge: mock_recharge.return_value = mock_stats await scheduler_service._daily_credit_recharge() @@ -71,7 +74,9 @@ class TestSchedulerService: @pytest.mark.asyncio async def test_daily_credit_recharge_failure(self, scheduler_service) -> None: """Test daily credit recharge task with failure.""" - with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge: + with patch.object( + scheduler_service.credit_service, "recharge_all_users_credits", + ) as mock_recharge: mock_recharge.side_effect = Exception("Database error") # Should not raise exception, just log it