Refactor user endpoint tests to include pagination and response structure validation

- Updated tests for listing users to validate pagination and response format.
- Changed mock return values to include total count and pagination details.
- Refactored user creation mocks for clarity and consistency.
- Enhanced assertions to check for presence of pagination fields in responses.
- Adjusted test cases for user retrieval and updates to ensure proper handling of user data.
- Improved readability by restructuring mock definitions and assertions across various test files.
This commit is contained in:
JSC
2025-08-17 12:36:52 +02:00
parent e6f796a3c9
commit 6b55ff0e81
35 changed files with 863 additions and 503 deletions

View File

@@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user
from app.models.plan import Plan
from app.models.user import User
from app.repositories.plan import PlanRepository
from app.repositories.user import UserRepository, UserSortField, SortOrder, UserStatus
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
from app.schemas.auth import UserResponse
from app.schemas.user import UserUpdate
@@ -36,21 +36,27 @@ def _user_to_response(user: User) -> UserResponse:
"name": user.plan.name,
"max_credits": user.plan.max_credits,
"features": [], # Add features if needed
} if user.plan else {},
}
if user.plan
else {},
created_at=user.created_at,
updated_at=user.updated_at,
)
@router.get("/")
async def list_users(
async def list_users( # noqa: PLR0913
session: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
search: Annotated[str | None, Query(description="Search in name or email")] = None,
sort_by: Annotated[UserSortField, Query(description="Sort by field")] = UserSortField.NAME,
sort_by: Annotated[
UserSortField, Query(description="Sort by field"),
] = UserSortField.NAME,
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
status_filter: Annotated[UserStatus, Query(description="Filter by status")] = UserStatus.ALL,
status_filter: Annotated[
UserStatus, Query(description="Filter by status"),
] = UserStatus.ALL,
) -> dict[str, Any]:
"""Get all users with pagination, search, and filters (admin only)."""
user_repo = UserRepository(session)
@@ -62,9 +68,9 @@ async def list_users(
sort_order=sort_order,
status_filter=status_filter,
)
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"users": [_user_to_response(user) for user in users],
"total": total_count,

View File

@@ -464,7 +464,8 @@ async def update_profile(
"""Update the current user's profile."""
try:
updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True),
current_user,
request.model_dump(exclude_unset=True),
)
return await auth_service.user_to_response(updated_user)
except Exception as e:
@@ -486,7 +487,9 @@ async def change_password(
user_email = current_user.email
try:
await auth_service.change_user_password(
current_user, request.current_password, request.new_password,
current_user,
request.current_password,
request.new_password,
)
except ValueError as e:
raise HTTPException(
@@ -513,11 +516,13 @@ async def get_user_providers(
# Add password provider if user has password
if current_user.password_hash:
providers.append({
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
})
providers.append(
{
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
},
)
# Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
@@ -528,10 +533,12 @@ async def get_user_providers(
elif oauth.provider == "google":
display_name = "Google"
providers.append({
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
})
providers.append(
{
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
},
)
return providers

View File

@@ -34,7 +34,8 @@ async def get_top_sounds(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
sound_type: Annotated[
str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
str,
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
],
period: Annotated[
str,
@@ -43,7 +44,8 @@ async def get_top_sounds(
),
] = "all_time",
limit: Annotated[
int, Query(description="Number of top sounds to return", ge=1, le=100),
int,
Query(description="Number of top sounds to return", ge=1, le=100),
] = 10,
) -> list[dict[str, Any]]:
"""Get top sounds by play count for a specific period."""

View File

@@ -60,68 +60,13 @@ async def create_extraction(
}
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions(
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result
@router.get("/user")
async def get_user_extractions(
async def get_user_extractions( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
@@ -153,3 +98,62 @@ async def get_user_extractions(
) from e
else:
return result
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions( # noqa: PLR0913
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result

View File

@@ -4,6 +4,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.database import get_session_factory
from app.core.dependencies import get_current_active_user
from app.models.user import User
from app.schemas.common import MessageResponse
@@ -19,12 +20,10 @@ router = APIRouter(prefix="/favorites", tags=["favorites"])
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory())
@router.get("/", response_model=FavoritesListResponse)
@router.get("/")
async def get_user_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -33,12 +32,14 @@ async def get_user_favorites(
) -> FavoritesListResponse:
"""Get all favorites for the current user."""
favorites = await favorite_service.get_user_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/sounds", response_model=FavoritesListResponse)
@router.get("/sounds")
async def get_user_sound_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -47,12 +48,14 @@ async def get_user_sound_favorites(
) -> FavoritesListResponse:
"""Get sound favorites for the current user."""
favorites = await favorite_service.get_user_sound_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/playlists", response_model=FavoritesListResponse)
@router.get("/playlists")
async def get_user_playlist_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -61,12 +64,14 @@ async def get_user_playlist_favorites(
) -> FavoritesListResponse:
"""Get playlist favorites for the current user."""
favorites = await favorite_service.get_user_playlist_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/counts", response_model=FavoriteCountsResponse)
@router.get("/counts")
async def get_favorite_counts(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -76,7 +81,7 @@ async def get_favorite_counts(
return FavoriteCountsResponse(**counts)
@router.post("/sounds/{sound_id}", response_model=FavoriteResponse)
@router.post("/sounds/{sound_id}")
async def add_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -103,7 +108,7 @@ async def add_sound_favorite(
) from e
@router.post("/playlists/{playlist_id}", response_model=FavoriteResponse)
@router.post("/playlists/{playlist_id}")
async def add_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -112,7 +117,8 @@ async def add_playlist_favorite(
"""Add a playlist to favorites."""
try:
favorite = await favorite_service.add_playlist_favorite(
current_user.id, playlist_id,
current_user.id,
playlist_id,
)
return FavoriteResponse.model_validate(favorite)
except ValueError as e:
@@ -132,7 +138,7 @@ async def add_playlist_favorite(
) from e
@router.delete("/sounds/{sound_id}", response_model=MessageResponse)
@router.delete("/sounds/{sound_id}")
async def remove_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -149,7 +155,7 @@ async def remove_sound_favorite(
) from e
@router.delete("/playlists/{playlist_id}", response_model=MessageResponse)
@router.delete("/playlists/{playlist_id}")
async def remove_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -185,6 +191,7 @@ async def check_playlist_favorited(
) -> dict[str, bool]:
"""Check if a playlist is favorited by the current user."""
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
current_user.id,
playlist_id,
)
return {"is_favorited": is_favorited}

View File

@@ -5,7 +5,7 @@ from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
@@ -34,7 +34,6 @@ async def get_playlist_service(
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory())
@@ -57,7 +56,7 @@ async def get_all_playlists( # noqa: PLR0913
] = SortOrder.ASC,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
favorites_only: Annotated[
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited playlists"),
] = False,
@@ -78,15 +77,26 @@ async def get_all_playlists( # noqa: PLR0913
# Convert to PlaylistResponse with favorite indicators
playlist_responses = []
for playlist_dict in result["playlists"]:
# The playlist service returns dict, need to create playlist object-like structure
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist_dict["id"])
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_dict["id"])
# The playlist service returns dict, need to create playlist object structure
playlist_id = playlist_dict["id"]
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
# Create a PlaylistResponse-like dict with proper datetime conversion
playlist_response = {
**playlist_dict,
"created_at": playlist_dict["created_at"].isoformat() if playlist_dict["created_at"] else None,
"updated_at": playlist_dict["updated_at"].isoformat() if playlist_dict["updated_at"] else None,
"created_at": (
playlist_dict["created_at"].isoformat()
if playlist_dict["created_at"]
else None
),
"updated_at": (
playlist_dict["updated_at"].isoformat()
if playlist_dict["updated_at"]
else None
),
"is_favorited": is_favorited,
"favorite_count": favorite_count,
}
@@ -113,9 +123,13 @@ async def get_user_playlists(
# Add favorite indicators for each playlist
playlist_responses = []
for playlist in playlists:
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
playlist_response = PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
playlist_response = PlaylistResponse.from_playlist(
playlist, is_favorited, favorite_count,
)
playlist_responses.append(playlist_response)
return playlist_responses
@@ -129,7 +143,9 @@ async def get_main_playlist(
) -> PlaylistResponse:
"""Get the global main playlist."""
playlist = await playlist_service.get_main_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -142,7 +158,9 @@ async def get_current_playlist(
) -> PlaylistResponse:
"""Get the global current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -172,7 +190,9 @@ async def get_playlist(
) -> PlaylistResponse:
"""Get a specific playlist."""
playlist = await playlist_service.get_playlist_by_id(playlist_id)
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)

View File

@@ -40,7 +40,7 @@ async def get_sound_repository(
return SoundRepository(session)
@router.get("/", response_model=SoundsListResponse)
@router.get("/")
async def get_sounds( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
@@ -69,7 +69,7 @@ async def get_sounds( # noqa: PLR0913
int,
Query(description="Number of results to skip", ge=0),
] = 0,
favorites_only: Annotated[
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited sounds"),
] = False,
@@ -90,9 +90,13 @@ async def get_sounds( # noqa: PLR0913
# Add favorite indicators for each sound
sound_responses = []
for sound in sounds:
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound.id)
is_favorited = await favorite_service.is_sound_favorited(
current_user.id, sound.id,
)
favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
sound_response = SoundResponse.from_sound(sound, is_favorited, favorite_count)
sound_response = SoundResponse.from_sound(
sound, is_favorited, favorite_count,
)
sound_responses.append(sound_response)
except Exception as e: