Refactor user endpoint tests to include pagination and response structure validation

- Updated tests for listing users to validate pagination and response format.
- Changed mock return values to include total count and pagination details.
- Refactored user creation mocks for clarity and consistency.
- Enhanced assertions to check for presence of pagination fields in responses.
- Adjusted test cases for user retrieval and updates to ensure proper handling of user data.
- Improved readability by restructuring mock definitions and assertions across various test files.
This commit is contained in:
JSC
2025-08-17 12:36:52 +02:00
parent e6f796a3c9
commit 6b55ff0e81
35 changed files with 863 additions and 503 deletions

View File

@@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.repositories.plan import PlanRepository from app.repositories.plan import PlanRepository
from app.repositories.user import UserRepository, UserSortField, SortOrder, UserStatus from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
from app.schemas.auth import UserResponse from app.schemas.auth import UserResponse
from app.schemas.user import UserUpdate from app.schemas.user import UserUpdate
@@ -36,21 +36,27 @@ def _user_to_response(user: User) -> UserResponse:
"name": user.plan.name, "name": user.plan.name,
"max_credits": user.plan.max_credits, "max_credits": user.plan.max_credits,
"features": [], # Add features if needed "features": [], # Add features if needed
} if user.plan else {}, }
if user.plan
else {},
created_at=user.created_at, created_at=user.created_at,
updated_at=user.updated_at, updated_at=user.updated_at,
) )
@router.get("/") @router.get("/")
async def list_users( async def list_users( # noqa: PLR0913
session: Annotated[AsyncSession, Depends(get_db)], session: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(description="Page number", ge=1)] = 1, page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50, limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
search: Annotated[str | None, Query(description="Search in name or email")] = None, search: Annotated[str | None, Query(description="Search in name or email")] = None,
sort_by: Annotated[UserSortField, Query(description="Sort by field")] = UserSortField.NAME, sort_by: Annotated[
UserSortField, Query(description="Sort by field"),
] = UserSortField.NAME,
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC, sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
status_filter: Annotated[UserStatus, Query(description="Filter by status")] = UserStatus.ALL, status_filter: Annotated[
UserStatus, Query(description="Filter by status"),
] = UserStatus.ALL,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Get all users with pagination, search, and filters (admin only).""" """Get all users with pagination, search, and filters (admin only)."""
user_repo = UserRepository(session) user_repo = UserRepository(session)

View File

@@ -464,7 +464,8 @@ async def update_profile(
"""Update the current user's profile.""" """Update the current user's profile."""
try: try:
updated_user = await auth_service.update_user_profile( updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True), current_user,
request.model_dump(exclude_unset=True),
) )
return await auth_service.user_to_response(updated_user) return await auth_service.user_to_response(updated_user)
except Exception as e: except Exception as e:
@@ -486,7 +487,9 @@ async def change_password(
user_email = current_user.email user_email = current_user.email
try: try:
await auth_service.change_user_password( await auth_service.change_user_password(
current_user, request.current_password, request.new_password, current_user,
request.current_password,
request.new_password,
) )
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
@@ -513,11 +516,13 @@ async def get_user_providers(
# Add password provider if user has password # Add password provider if user has password
if current_user.password_hash: if current_user.password_hash:
providers.append({ providers.append(
"provider": "password", {
"display_name": "Password", "provider": "password",
"connected_at": current_user.created_at.isoformat(), "display_name": "Password",
}) "connected_at": current_user.created_at.isoformat(),
},
)
# Get OAuth providers from the database # Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user) oauth_providers = await auth_service.get_user_oauth_providers(current_user)
@@ -528,10 +533,12 @@ async def get_user_providers(
elif oauth.provider == "google": elif oauth.provider == "google":
display_name = "Google" display_name = "Google"
providers.append({ providers.append(
"provider": oauth.provider, {
"display_name": display_name, "provider": oauth.provider,
"connected_at": oauth.created_at.isoformat(), "display_name": display_name,
}) "connected_at": oauth.created_at.isoformat(),
},
)
return providers return providers

View File

@@ -34,7 +34,8 @@ async def get_top_sounds(
_current_user: Annotated[User, Depends(get_current_user)], _current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)], dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
sound_type: Annotated[ sound_type: Annotated[
str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"), str,
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
], ],
period: Annotated[ period: Annotated[
str, str,
@@ -43,7 +44,8 @@ async def get_top_sounds(
), ),
] = "all_time", ] = "all_time",
limit: Annotated[ limit: Annotated[
int, Query(description="Number of top sounds to return", ge=1, le=100), int,
Query(description="Number of top sounds to return", ge=1, le=100),
] = 10, ] = 10,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Get top sounds by play count for a specific period.""" """Get top sounds by play count for a specific period."""

View File

@@ -60,68 +60,13 @@ async def create_extraction(
} }
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions(
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result
@router.get("/user") @router.get("/user")
async def get_user_extractions( async def get_user_extractions( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None, search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at", sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc", sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None, status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
@@ -153,3 +98,62 @@ async def get_user_extractions(
) from e ) from e
else: else:
return result return result
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions( # noqa: PLR0913
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result

View File

@@ -4,6 +4,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.database import get_session_factory
from app.core.dependencies import get_current_active_user from app.core.dependencies import get_current_active_user
from app.models.user import User from app.models.user import User
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
@@ -19,12 +20,10 @@ router = APIRouter(prefix="/favorites", tags=["favorites"])
def get_favorite_service() -> FavoriteService: def get_favorite_service() -> FavoriteService:
"""Get the favorite service.""" """Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory()) return FavoriteService(get_session_factory())
@router.get("/", response_model=FavoritesListResponse) @router.get("/")
async def get_user_favorites( async def get_user_favorites(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -33,12 +32,14 @@ async def get_user_favorites(
) -> FavoritesListResponse: ) -> FavoritesListResponse:
"""Get all favorites for the current user.""" """Get all favorites for the current user."""
favorites = await favorite_service.get_user_favorites( favorites = await favorite_service.get_user_favorites(
current_user.id, limit, offset, current_user.id,
limit,
offset,
) )
return FavoritesListResponse(favorites=favorites) return FavoritesListResponse(favorites=favorites)
@router.get("/sounds", response_model=FavoritesListResponse) @router.get("/sounds")
async def get_user_sound_favorites( async def get_user_sound_favorites(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -47,12 +48,14 @@ async def get_user_sound_favorites(
) -> FavoritesListResponse: ) -> FavoritesListResponse:
"""Get sound favorites for the current user.""" """Get sound favorites for the current user."""
favorites = await favorite_service.get_user_sound_favorites( favorites = await favorite_service.get_user_sound_favorites(
current_user.id, limit, offset, current_user.id,
limit,
offset,
) )
return FavoritesListResponse(favorites=favorites) return FavoritesListResponse(favorites=favorites)
@router.get("/playlists", response_model=FavoritesListResponse) @router.get("/playlists")
async def get_user_playlist_favorites( async def get_user_playlist_favorites(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -61,12 +64,14 @@ async def get_user_playlist_favorites(
) -> FavoritesListResponse: ) -> FavoritesListResponse:
"""Get playlist favorites for the current user.""" """Get playlist favorites for the current user."""
favorites = await favorite_service.get_user_playlist_favorites( favorites = await favorite_service.get_user_playlist_favorites(
current_user.id, limit, offset, current_user.id,
limit,
offset,
) )
return FavoritesListResponse(favorites=favorites) return FavoritesListResponse(favorites=favorites)
@router.get("/counts", response_model=FavoriteCountsResponse) @router.get("/counts")
async def get_favorite_counts( async def get_favorite_counts(
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)], favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -76,7 +81,7 @@ async def get_favorite_counts(
return FavoriteCountsResponse(**counts) return FavoriteCountsResponse(**counts)
@router.post("/sounds/{sound_id}", response_model=FavoriteResponse) @router.post("/sounds/{sound_id}")
async def add_sound_favorite( async def add_sound_favorite(
sound_id: int, sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
@@ -103,7 +108,7 @@ async def add_sound_favorite(
) from e ) from e
@router.post("/playlists/{playlist_id}", response_model=FavoriteResponse) @router.post("/playlists/{playlist_id}")
async def add_playlist_favorite( async def add_playlist_favorite(
playlist_id: int, playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
@@ -112,7 +117,8 @@ async def add_playlist_favorite(
"""Add a playlist to favorites.""" """Add a playlist to favorites."""
try: try:
favorite = await favorite_service.add_playlist_favorite( favorite = await favorite_service.add_playlist_favorite(
current_user.id, playlist_id, current_user.id,
playlist_id,
) )
return FavoriteResponse.model_validate(favorite) return FavoriteResponse.model_validate(favorite)
except ValueError as e: except ValueError as e:
@@ -132,7 +138,7 @@ async def add_playlist_favorite(
) from e ) from e
@router.delete("/sounds/{sound_id}", response_model=MessageResponse) @router.delete("/sounds/{sound_id}")
async def remove_sound_favorite( async def remove_sound_favorite(
sound_id: int, sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
@@ -149,7 +155,7 @@ async def remove_sound_favorite(
) from e ) from e
@router.delete("/playlists/{playlist_id}", response_model=MessageResponse) @router.delete("/playlists/{playlist_id}")
async def remove_playlist_favorite( async def remove_playlist_favorite(
playlist_id: int, playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)], current_user: Annotated[User, Depends(get_current_active_user)],
@@ -185,6 +191,7 @@ async def check_playlist_favorited(
) -> dict[str, bool]: ) -> dict[str, bool]:
"""Check if a playlist is favorited by the current user.""" """Check if a playlist is favorited by the current user."""
is_favorited = await favorite_service.is_playlist_favorited( is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id, current_user.id,
playlist_id,
) )
return {"is_favorited": is_favorited} return {"is_favorited": is_favorited}

View File

@@ -5,7 +5,7 @@ from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder from app.repositories.playlist import PlaylistSortField, SortOrder
@@ -34,7 +34,6 @@ async def get_playlist_service(
def get_favorite_service() -> FavoriteService: def get_favorite_service() -> FavoriteService:
"""Get the favorite service.""" """Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory()) return FavoriteService(get_session_factory())
@@ -57,7 +56,7 @@ async def get_all_playlists( # noqa: PLR0913
] = SortOrder.ASC, ] = SortOrder.ASC,
page: Annotated[int, Query(description="Page number", ge=1)] = 1, page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50, limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
favorites_only: Annotated[ favorites_only: Annotated[ # noqa: FBT002
bool, bool,
Query(description="Show only favorited playlists"), Query(description="Show only favorited playlists"),
] = False, ] = False,
@@ -78,15 +77,26 @@ async def get_all_playlists( # noqa: PLR0913
# Convert to PlaylistResponse with favorite indicators # Convert to PlaylistResponse with favorite indicators
playlist_responses = [] playlist_responses = []
for playlist_dict in result["playlists"]: for playlist_dict in result["playlists"]:
# The playlist service returns dict, need to create playlist object-like structure # The playlist service returns dict, need to create playlist object structure
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist_dict["id"]) playlist_id = playlist_dict["id"]
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_dict["id"]) is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
# Create a PlaylistResponse-like dict with proper datetime conversion # Create a PlaylistResponse-like dict with proper datetime conversion
playlist_response = { playlist_response = {
**playlist_dict, **playlist_dict,
"created_at": playlist_dict["created_at"].isoformat() if playlist_dict["created_at"] else None, "created_at": (
"updated_at": playlist_dict["updated_at"].isoformat() if playlist_dict["updated_at"] else None, playlist_dict["created_at"].isoformat()
if playlist_dict["created_at"]
else None
),
"updated_at": (
playlist_dict["updated_at"].isoformat()
if playlist_dict["updated_at"]
else None
),
"is_favorited": is_favorited, "is_favorited": is_favorited,
"favorite_count": favorite_count, "favorite_count": favorite_count,
} }
@@ -113,9 +123,13 @@ async def get_user_playlists(
# Add favorite indicators for each playlist # Add favorite indicators for each playlist
playlist_responses = [] playlist_responses = []
for playlist in playlists: for playlist in playlists:
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
playlist_response = PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) playlist_response = PlaylistResponse.from_playlist(
playlist, is_favorited, favorite_count,
)
playlist_responses.append(playlist_response) playlist_responses.append(playlist_response)
return playlist_responses return playlist_responses
@@ -129,7 +143,9 @@ async def get_main_playlist(
) -> PlaylistResponse: ) -> PlaylistResponse:
"""Get the global main playlist.""" """Get the global main playlist."""
playlist = await playlist_service.get_main_playlist() playlist = await playlist_service.get_main_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -142,7 +158,9 @@ async def get_current_playlist(
) -> PlaylistResponse: ) -> PlaylistResponse:
"""Get the global current playlist (falls back to main playlist).""" """Get the global current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist() playlist = await playlist_service.get_current_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -172,7 +190,9 @@ async def get_playlist(
) -> PlaylistResponse: ) -> PlaylistResponse:
"""Get a specific playlist.""" """Get a specific playlist."""
playlist = await playlist_service.get_playlist_by_id(playlist_id) playlist = await playlist_service.get_playlist_by_id(playlist_id)
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id) is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id) favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count) return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)

View File

@@ -40,7 +40,7 @@ async def get_sound_repository(
return SoundRepository(session) return SoundRepository(session)
@router.get("/", response_model=SoundsListResponse) @router.get("/")
async def get_sounds( # noqa: PLR0913 async def get_sounds( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)], current_user: Annotated[User, Depends(get_current_active_user_flexible)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)], sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
@@ -69,7 +69,7 @@ async def get_sounds( # noqa: PLR0913
int, int,
Query(description="Number of results to skip", ge=0), Query(description="Number of results to skip", ge=0),
] = 0, ] = 0,
favorites_only: Annotated[ favorites_only: Annotated[ # noqa: FBT002
bool, bool,
Query(description="Show only favorited sounds"), Query(description="Show only favorited sounds"),
] = False, ] = False,
@@ -90,9 +90,13 @@ async def get_sounds( # noqa: PLR0913
# Add favorite indicators for each sound # Add favorite indicators for each sound
sound_responses = [] sound_responses = []
for sound in sounds: for sound in sounds:
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound.id) is_favorited = await favorite_service.is_sound_favorited(
current_user.id, sound.id,
)
favorite_count = await favorite_service.get_sound_favorite_count(sound.id) favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
sound_response = SoundResponse.from_sound(sound, is_favorited, favorite_count) sound_response = SoundResponse.from_sound(
sound, is_favorited, favorite_count,
)
sound_responses.append(sound_response) sound_responses.append(sound_response)
except Exception as e: except Exception as e:

View File

@@ -20,7 +20,7 @@ class Settings(BaseSettings):
# Production URLs (for reverse proxy deployment) # Production URLs (for reverse proxy deployment)
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
BACKEND_URL: str = "http://localhost:8000" # Backend base URL BACKEND_URL: str = "http://localhost:8000" # Backend base URL
# CORS Configuration # CORS Configuration
CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS

View File

@@ -20,7 +20,9 @@ class BaseModel(SQLModel):
# SQLAlchemy event listener to automatically update updated_at timestamp # SQLAlchemy event listener to automatically update updated_at timestamp
@event.listens_for(BaseModel, "before_update", propagate=True) @event.listens_for(BaseModel, "before_update", propagate=True)
def update_timestamp( def update_timestamp(
mapper: Mapper[Any], connection: Connection, target: BaseModel, # noqa: ARG001 mapper: Mapper[Any], # noqa: ARG001
connection: Connection, # noqa: ARG001
target: BaseModel,
) -> None: ) -> None:
"""Automatically set updated_at timestamp before update operations.""" """Automatically set updated_at timestamp before update operations."""
target.updated_at = datetime.now(UTC) target.updated_at = datetime.now(UTC)

View File

@@ -35,5 +35,3 @@ class PlaylistSound(BaseModel, table=True):
# relationships # relationships
playlist: "Playlist" = Relationship(back_populates="playlist_sounds") playlist: "Playlist" = Relationship(back_populates="playlist_sounds")
sound: "Sound" = Relationship(back_populates="playlist_sounds") sound: "Sound" = Relationship(back_populates="playlist_sounds")

View File

@@ -58,7 +58,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
) )
return list(result.all()) return list(result.all())
async def get_user_extractions_filtered( async def get_user_extractions_filtered( # noqa: PLR0913
self, self,
user_id: int, user_id: int,
search: str | None = None, search: str | None = None,
@@ -92,7 +92,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
# Get total count before pagination # Get total count before pagination
count_query = select(func.count()).select_from( count_query = select(func.count()).select_from(
base_query.subquery() base_query.subquery(),
) )
count_result = await self.session.exec(count_query) count_result = await self.session.exec(count_query)
total_count = count_result.one() total_count = count_result.one()
@@ -109,7 +109,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
return list(result.all()), total_count return list(result.all()), total_count
async def get_all_extractions_filtered( async def get_all_extractions_filtered( # noqa: PLR0913
self, self,
search: str | None = None, search: str | None = None,
sort_by: str = "created_at", sort_by: str = "created_at",
@@ -138,7 +138,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
# Get total count before pagination # Get total count before pagination
count_query = select(func.count()).select_from( count_query = select(func.count()).select_from(
base_query.subquery() base_query.subquery(),
) )
count_result = await self.session.exec(count_query) count_result = await self.session.exec(count_query)
total_count = count_result.one() total_count = count_result.one()

View File

@@ -118,7 +118,9 @@ class FavoriteRepository(BaseRepository[Favorite]):
raise raise
async def get_by_user_and_sound( async def get_by_user_and_sound(
self, user_id: int, sound_id: int, self,
user_id: int,
sound_id: int,
) -> Favorite | None: ) -> Favorite | None:
"""Get a favorite by user and sound. """Get a favorite by user and sound.
@@ -138,12 +140,16 @@ class FavoriteRepository(BaseRepository[Favorite]):
return result.first() return result.first()
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to get favorite for user %s and sound %s", user_id, sound_id, "Failed to get favorite for user %s and sound %s",
user_id,
sound_id,
) )
raise raise
async def get_by_user_and_playlist( async def get_by_user_and_playlist(
self, user_id: int, playlist_id: int, self,
user_id: int,
playlist_id: int,
) -> Favorite | None: ) -> Favorite | None:
"""Get a favorite by user and playlist. """Get a favorite by user and playlist.

View File

@@ -57,7 +57,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
# management # management
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to update playlist timestamp for playlist: %s", playlist_id, "Failed to update playlist timestamp for playlist: %s",
playlist_id,
) )
raise raise
@@ -341,7 +342,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
include_stats: bool = False, # noqa: FBT001, FBT002 include_stats: bool = False, # noqa: FBT001, FBT002
limit: int | None = None, limit: int | None = None,
offset: int = 0, offset: int = 0,
favorites_only: bool = False, favorites_only: bool = False, # noqa: FBT001, FBT002
current_user_id: int | None = None, current_user_id: int | None = None,
*, *,
return_count: bool = False, return_count: bool = False,
@@ -395,9 +396,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Apply favorites filter # Apply favorites filter
if favorites_only and current_user_id is not None: if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY # Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = select(1).select_from(Favorite).where( favorites_subquery = (
Favorite.user_id == current_user_id, select(1)
Favorite.playlist_id == Playlist.id, .select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
) )
subquery = subquery.where(favorites_subquery.exists()) subquery = subquery.where(favorites_subquery.exists())
@@ -466,9 +471,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Apply favorites filter # Apply favorites filter
if favorites_only and current_user_id is not None: if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY # Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = select(1).select_from(Favorite).where( favorites_subquery = (
Favorite.user_id == current_user_id, select(1)
Favorite.playlist_id == Playlist.id, .select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
) )
subquery = subquery.where(favorites_subquery.exists()) subquery = subquery.where(favorites_subquery.exists())

View File

@@ -141,7 +141,7 @@ class SoundRepository(BaseRepository[Sound]):
sort_order: SortOrder = SortOrder.ASC, sort_order: SortOrder = SortOrder.ASC,
limit: int | None = None, limit: int | None = None,
offset: int = 0, offset: int = 0,
favorites_only: bool = False, favorites_only: bool = False, # noqa: FBT001, FBT002
user_id: int | None = None, user_id: int | None = None,
) -> list[Sound]: ) -> list[Sound]:
"""Search and sort sounds with optional filtering.""" """Search and sort sounds with optional filtering."""
@@ -189,7 +189,8 @@ class SoundRepository(BaseRepository[Sound]):
logger.exception( logger.exception(
( (
"Failed to search and sort sounds: " "Failed to search and sort sounds: "
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, user_id=%s" "query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
"user_id=%s"
), ),
search_query, search_query,
sound_types, sound_types,
@@ -288,8 +289,7 @@ class SoundRepository(BaseRepository[Sound]):
# Group by sound and order by play count descending # Group by sound and order by play count descending
statement = ( statement = (
statement statement.group_by(
.group_by(
Sound.id, Sound.id,
Sound.name, Sound.name,
Sound.type, Sound.type,

View File

@@ -1,7 +1,7 @@
"""User repository.""" """User repository."""
from typing import Any
from enum import Enum from enum import Enum
from typing import Any
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class UserSortField(str, Enum): class UserSortField(str, Enum):
"""User sort fields.""" """User sort fields."""
NAME = "name" NAME = "name"
EMAIL = "email" EMAIL = "email"
ROLE = "role" ROLE = "role"
@@ -27,12 +28,14 @@ class UserSortField(str, Enum):
class SortOrder(str, Enum): class SortOrder(str, Enum):
"""Sort order.""" """Sort order."""
ASC = "asc" ASC = "asc"
DESC = "desc" DESC = "desc"
class UserStatus(str, Enum): class UserStatus(str, Enum):
"""User status filter.""" """User status filter."""
ALL = "all" ALL = "all"
ACTIVE = "active" ACTIVE = "active"
INACTIVE = "inactive" INACTIVE = "inactive"
@@ -64,7 +67,7 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get all users with plan") logger.exception("Failed to get all users with plan")
raise raise
async def get_all_with_plan_paginated( async def get_all_with_plan_paginated( # noqa: PLR0913
self, self,
page: int = 1, page: int = 1,
limit: int = 50, limit: int = 50,
@@ -85,10 +88,9 @@ class UserRepository(BaseRepository[User]):
# Apply search filter # Apply search filter
if search and search.strip(): if search and search.strip():
search_pattern = f"%{search.strip().lower()}%" search_pattern = f"%{search.strip().lower()}%"
search_condition = ( search_condition = func.lower(User.name).like(
func.lower(User.name).like(search_pattern) | search_pattern,
func.lower(User.email).like(search_pattern) ) | func.lower(User.email).like(search_pattern)
)
base_query = base_query.where(search_condition) base_query = base_query.where(search_condition)
count_query = count_query.where(search_condition) count_query = count_query.where(search_condition)
@@ -101,31 +103,18 @@ class UserRepository(BaseRepository[User]):
count_query = count_query.where(User.is_active == False) # noqa: E712 count_query = count_query.where(User.is_active == False) # noqa: E712
# Apply sorting # Apply sorting
if sort_by == UserSortField.EMAIL: sort_column = {
if sort_order == SortOrder.DESC: UserSortField.NAME: User.name,
base_query = base_query.order_by(User.email.desc()) UserSortField.EMAIL: User.email,
else: UserSortField.ROLE: User.role,
base_query = base_query.order_by(User.email.asc()) UserSortField.CREDITS: User.credits,
elif sort_by == UserSortField.ROLE: UserSortField.CREATED_AT: User.created_at,
if sort_order == SortOrder.DESC: }.get(sort_by, User.name)
base_query = base_query.order_by(User.role.desc())
else: if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.role.asc()) base_query = base_query.order_by(sort_column.desc())
elif sort_by == UserSortField.CREDITS: else:
if sort_order == SortOrder.DESC: base_query = base_query.order_by(sort_column.asc())
base_query = base_query.order_by(User.credits.desc())
else:
base_query = base_query.order_by(User.credits.asc())
elif sort_by == UserSortField.CREATED_AT:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.created_at.desc())
else:
base_query = base_query.order_by(User.created_at.asc())
else: # Default to name
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.name.desc())
else:
base_query = base_query.order_by(User.name.asc())
# Get total count # Get total count
count_result = await self.session.exec(count_query) count_result = await self.session.exec(count_query)
@@ -135,11 +124,11 @@ class UserRepository(BaseRepository[User]):
paginated_query = base_query.limit(limit).offset(offset) paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query) result = await self.session.exec(paginated_query)
users = list(result.all()) users = list(result.all())
return users, total_count
except Exception: except Exception:
logger.exception("Failed to get paginated users with plan") logger.exception("Failed to get paginated users with plan")
raise raise
else:
return users, total_count
async def get_by_id_with_plan(self, entity_id: int) -> User | None: async def get_by_id_with_plan(self, entity_id: int) -> User | None:
"""Get a user by ID with plan relationship loaded.""" """Get a user by ID with plan relationship loaded."""
@@ -178,7 +167,7 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get user by API token") logger.exception("Failed to get user by API token")
raise raise
async def create(self, user_data: dict[str, Any]) -> User: async def create(self, entity_data: dict[str, Any]) -> User:
"""Create a new user with plan assignment and first user admin logic.""" """Create a new user with plan assignment and first user admin logic."""
def _raise_plan_not_found() -> None: def _raise_plan_not_found() -> None:
@@ -194,7 +183,7 @@ class UserRepository(BaseRepository[User]):
if is_first_user: if is_first_user:
# First user gets admin role and pro plan # First user gets admin role and pro plan
plan_statement = select(Plan).where(Plan.code == "pro") plan_statement = select(Plan).where(Plan.code == "pro")
user_data["role"] = "admin" entity_data["role"] = "admin"
logger.info("Creating first user with admin role and pro plan") logger.info("Creating first user with admin role and pro plan")
else: else:
# Regular users get free plan # Regular users get free plan
@@ -210,11 +199,11 @@ class UserRepository(BaseRepository[User]):
assert default_plan is not None # noqa: S101 assert default_plan is not None # noqa: S101
# Set plan_id and default credits # Set plan_id and default credits
user_data["plan_id"] = default_plan.id entity_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits entity_data["credits"] = default_plan.credits
# Use BaseRepository's create method # Use BaseRepository's create method
return await super().create(user_data) return await super().create(entity_data)
except Exception: except Exception:
logger.exception("Failed to create user") logger.exception("Failed to create user")
raise raise

View File

@@ -85,7 +85,8 @@ class ChangePasswordRequest(BaseModel):
"""Schema for password change request.""" """Schema for password change request."""
current_password: str | None = Field( current_password: str | None = Field(
None, description="Current password (required if user has existing password)", None,
description="Current password (required if user has existing password)",
) )
new_password: str = Field( new_password: str = Field(
..., ...,
@@ -98,5 +99,8 @@ class UpdateProfileRequest(BaseModel):
"""Schema for profile update request.""" """Schema for profile update request."""
name: str | None = Field( name: str | None = Field(
None, min_length=1, max_length=100, description="User display name", None,
min_length=1,
max_length=100,
description="User display name",
) )

View File

@@ -11,10 +11,12 @@ class FavoriteResponse(BaseModel):
id: int = Field(description="Favorite ID") id: int = Field(description="Favorite ID")
user_id: int = Field(description="User ID") user_id: int = Field(description="User ID")
sound_id: int | None = Field( sound_id: int | None = Field(
description="Sound ID if this is a sound favorite", default=None, description="Sound ID if this is a sound favorite",
default=None,
) )
playlist_id: int | None = Field( playlist_id: int | None = Field(
description="Playlist ID if this is a playlist favorite", default=None, description="Playlist ID if this is a playlist favorite",
default=None,
) )
created_at: datetime = Field(description="Creation timestamp") created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp") updated_at: datetime = Field(description="Last update timestamp")

View File

@@ -39,7 +39,12 @@ class PlaylistResponse(BaseModel):
updated_at: str | None updated_at: str | None
@classmethod @classmethod
def from_playlist(cls, playlist: Playlist, is_favorited: bool = False, favorite_count: int = 0) -> "PlaylistResponse": def from_playlist(
cls,
playlist: Playlist,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "PlaylistResponse":
"""Create response from playlist model. """Create response from playlist model.
Args: Args:

View File

@@ -18,16 +18,20 @@ class SoundResponse(BaseModel):
size: int = Field(description="File size in bytes") size: int = Field(description="File size in bytes")
hash: str = Field(description="File hash") hash: str = Field(description="File hash")
normalized_filename: str | None = Field( normalized_filename: str | None = Field(
description="Normalized filename", default=None, description="Normalized filename",
default=None,
) )
normalized_duration: int | None = Field( normalized_duration: int | None = Field(
description="Normalized duration in milliseconds", default=None, description="Normalized duration in milliseconds",
default=None,
) )
normalized_size: int | None = Field( normalized_size: int | None = Field(
description="Normalized file size in bytes", default=None, description="Normalized file size in bytes",
default=None,
) )
normalized_hash: str | None = Field( normalized_hash: str | None = Field(
description="Normalized file hash", default=None, description="Normalized file hash",
default=None,
) )
thumbnail: str | None = Field(description="Thumbnail filename", default=None) thumbnail: str | None = Field(description="Thumbnail filename", default=None)
play_count: int = Field(description="Number of times played") play_count: int = Field(description="Number of times played")
@@ -35,10 +39,12 @@ class SoundResponse(BaseModel):
is_music: bool = Field(description="Whether the sound is music") is_music: bool = Field(description="Whether the sound is music")
is_deletable: bool = Field(description="Whether the sound can be deleted") is_deletable: bool = Field(description="Whether the sound can be deleted")
is_favorited: bool = Field( is_favorited: bool = Field(
description="Whether the sound is favorited by the current user", default=False, description="Whether the sound is favorited by the current user",
default=False,
) )
favorite_count: int = Field( favorite_count: int = Field(
description="Number of users who favorited this sound", default=0, description="Number of users who favorited this sound",
default=0,
) )
created_at: datetime = Field(description="Creation timestamp") created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp") updated_at: datetime = Field(description="Last update timestamp")
@@ -50,7 +56,10 @@ class SoundResponse(BaseModel):
@classmethod @classmethod
def from_sound( def from_sound(
cls, sound: Sound, is_favorited: bool = False, favorite_count: int = 0, cls,
sound: Sound,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "SoundResponse": ) -> "SoundResponse":
"""Create a SoundResponse from a Sound model. """Create a SoundResponse from a Sound model.
@@ -64,7 +73,8 @@ class SoundResponse(BaseModel):
""" """
if sound.id is None: if sound.id is None:
raise ValueError("Sound ID cannot be None") msg = "Sound ID cannot be None"
raise ValueError(msg)
return cls( return cls(
id=sound.id, id=sound.id,

View File

@@ -7,7 +7,10 @@ class UserUpdate(BaseModel):
"""Schema for updating a user.""" """Schema for updating a user."""
name: str | None = Field( name: str | None = Field(
None, min_length=1, max_length=100, description="User full name", None,
min_length=1,
max_length=100,
description="User full name",
) )
plan_id: int | None = Field(None, description="User plan ID") plan_id: int | None = Field(None, description="User plan ID")
credits: int | None = Field(None, ge=0, description="User credits") credits: int | None = Field(None, ge=0, description="User credits")

View File

@@ -454,7 +454,10 @@ class AuthService:
return user return user
async def change_user_password( async def change_user_password(
self, user: User, current_password: str | None, new_password: str, self,
user: User,
current_password: str | None,
new_password: str,
) -> None: ) -> None:
"""Change user's password.""" """Change user's password."""
# Store user email before any operations to avoid session detachment issues # Store user email before any operations to avoid session detachment issues
@@ -484,8 +487,11 @@ class AuthService:
self.session.add(user) self.session.add(user)
await self.session.commit() await self.session.commit()
logger.info("Password %s successfully for user: %s", logger.info(
"changed" if had_existing_password else "set", user_email) "Password %s successfully for user: %s",
"changed" if had_existing_password else "set",
user_email,
)
async def user_to_response(self, user: User) -> UserResponse: async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information.""" """Convert User model to UserResponse with plan information."""

View File

@@ -72,9 +72,7 @@ class DashboardService:
"play_count": sound["play_count"], "play_count": sound["play_count"],
"duration": sound["duration"], "duration": sound["duration"],
"created_at": ( "created_at": (
sound["created_at"].isoformat() sound["created_at"].isoformat() if sound["created_at"] else None
if sound["created_at"]
else None
), ),
} }
for sound in top_sounds for sound in top_sounds

View File

@@ -532,7 +532,8 @@ class ExtractionService:
"""Add the sound to the user's main playlist.""" """Add the sound to the user's main playlist."""
try: try:
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001 await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
sound_id, user_id, sound_id,
user_id,
) )
logger.info( logger.info(
"Added sound %d to main playlist for user %d", "Added sound %d to main playlist for user %d",
@@ -554,6 +555,10 @@ class ExtractionService:
if not extraction: if not extraction:
return None return None
# Get user information
user = await self.user_repo.get_by_id(extraction.user_id)
user_name = user.name if user else None
return { return {
"id": extraction.id or 0, # Should never be None for existing extraction "id": extraction.id or 0, # Should never be None for existing extraction
"url": extraction.url, "url": extraction.url,
@@ -564,11 +569,12 @@ class ExtractionService:
"error": extraction.error, "error": extraction.error,
"sound_id": extraction.sound_id, "sound_id": extraction.sound_id,
"user_id": extraction.user_id, "user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(), "created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(), "updated_at": extraction.updated_at.isoformat(),
} }
async def get_user_extractions( async def get_user_extractions( # noqa: PLR0913
self, self,
user_id: int, user_id: int,
search: str | None = None, search: str | None = None,
@@ -580,7 +586,10 @@ class ExtractionService:
) -> PaginatedExtractionsResponse: ) -> PaginatedExtractionsResponse:
"""Get all extractions for a user with filtering, search, and sorting.""" """Get all extractions for a user with filtering, search, and sorting."""
offset = (page - 1) * limit offset = (page - 1) * limit
extraction_user_tuples, total_count = await self.extraction_repo.get_user_extractions_filtered( (
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_user_extractions_filtered(
user_id=user_id, user_id=user_id,
search=search, search=search,
sort_by=sort_by, sort_by=sort_by,
@@ -619,7 +628,7 @@ class ExtractionService:
"total_pages": total_pages, "total_pages": total_pages,
} }
async def get_all_extractions( async def get_all_extractions( # noqa: PLR0913
self, self,
search: str | None = None, search: str | None = None,
sort_by: str = "created_at", sort_by: str = "created_at",
@@ -630,7 +639,10 @@ class ExtractionService:
) -> PaginatedExtractionsResponse: ) -> PaginatedExtractionsResponse:
"""Get all extractions with filtering, search, and sorting.""" """Get all extractions with filtering, search, and sorting."""
offset = (page - 1) * limit offset = (page - 1) * limit
extraction_user_tuples, total_count = await self.extraction_repo.get_all_extractions_filtered( (
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_all_extractions_filtered(
search=search, search=search,
sort_by=sort_by, sort_by=sort_by,
sort_order=sort_order, sort_order=sort_order,

View File

@@ -49,12 +49,14 @@ class FavoriteService:
# Verify user exists # Verify user exists
user = await user_repo.get_by_id(user_id) user = await user_repo.get_by_id(user_id)
if not user: if not user:
raise ValueError(f"User with ID {user_id} not found") msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify sound exists # Verify sound exists
sound = await sound_repo.get_by_id(sound_id) sound = await sound_repo.get_by_id(sound_id)
if not sound: if not sound:
raise ValueError(f"Sound with ID {sound_id} not found") msg = f"Sound with ID {sound_id} not found"
raise ValueError(msg)
# Get data for the event immediately after loading # Get data for the event immediately after loading
sound_name = sound.name sound_name = sound.name
@@ -63,9 +65,8 @@ class FavoriteService:
# Check if already favorited # Check if already favorited
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id) existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if existing: if existing:
raise ValueError( msg = f"Sound {sound_id} is already favorited by user {user_id}"
f"Sound {sound_id} is already favorited by user {user_id}", raise ValueError(msg)
)
# Create favorite # Create favorite
favorite_data = { favorite_data = {
@@ -120,12 +121,14 @@ class FavoriteService:
# Verify user exists # Verify user exists
user = await user_repo.get_by_id(user_id) user = await user_repo.get_by_id(user_id)
if not user: if not user:
raise ValueError(f"User with ID {user_id} not found") msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify playlist exists # Verify playlist exists
playlist = await playlist_repo.get_by_id(playlist_id) playlist = await playlist_repo.get_by_id(playlist_id)
if not playlist: if not playlist:
raise ValueError(f"Playlist with ID {playlist_id} not found") msg = f"Playlist with ID {playlist_id} not found"
raise ValueError(msg)
# Check if already favorited # Check if already favorited
existing = await favorite_repo.get_by_user_and_playlist( existing = await favorite_repo.get_by_user_and_playlist(
@@ -133,9 +136,8 @@ class FavoriteService:
playlist_id, playlist_id,
) )
if existing: if existing:
raise ValueError( msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
f"Playlist {playlist_id} is already favorited by user {user_id}", raise ValueError(msg)
)
# Create favorite # Create favorite
favorite_data = { favorite_data = {
@@ -163,7 +165,8 @@ class FavoriteService:
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id) favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if not favorite: if not favorite:
raise ValueError(f"Sound {sound_id} is not favorited by user {user_id}") msg = f"Sound {sound_id} is not favorited by user {user_id}"
raise ValueError(msg)
# Get user and sound info before deletion for the event # Get user and sound info before deletion for the event
user_repo = UserRepository(session) user_repo = UserRepository(session)
@@ -192,7 +195,8 @@ class FavoriteService:
} }
await socket_manager.broadcast_to_all("sound_favorited", event_data) await socket_manager.broadcast_to_all("sound_favorited", event_data)
logger.info( logger.info(
"Broadcasted sound_favorited event for sound %s removal", sound_id, "Broadcasted sound_favorited event for sound %s removal",
sound_id,
) )
except Exception: except Exception:
logger.exception( logger.exception(
@@ -219,9 +223,8 @@ class FavoriteService:
playlist_id, playlist_id,
) )
if not favorite: if not favorite:
raise ValueError( msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
f"Playlist {playlist_id} is not favorited by user {user_id}", raise ValueError(msg)
)
await favorite_repo.delete(favorite) await favorite_repo.delete(favorite)
logger.info( logger.info(

View File

@@ -16,6 +16,7 @@ logger = get_logger(__name__)
class PaginatedPlaylistsResponse(TypedDict): class PaginatedPlaylistsResponse(TypedDict):
"""Response type for paginated playlists.""" """Response type for paginated playlists."""
playlists: list[dict] playlists: list[dict]
total: int total: int
page: int page: int
@@ -468,7 +469,9 @@ class PlaylistService:
} }
async def add_sound_to_main_playlist( async def add_sound_to_main_playlist(
self, sound_id: int, user_id: int, # noqa: ARG002 self,
sound_id: int, # noqa: ARG002
user_id: int, # noqa: ARG002
) -> None: ) -> None:
"""Add a sound to the global main playlist.""" """Add a sound to the global main playlist."""
raise HTTPException( raise HTTPException(
@@ -477,7 +480,9 @@ class PlaylistService:
) )
async def _add_sound_to_main_playlist_internal( async def _add_sound_to_main_playlist_internal(
self, sound_id: int, user_id: int, self,
sound_id: int,
user_id: int,
) -> None: ) -> None:
"""Add sound to main playlist bypassing restrictions. """Add sound to main playlist bypassing restrictions.

View File

@@ -21,8 +21,6 @@ def mock_plan_repository():
return Mock() return Mock()
@pytest.fixture @pytest.fixture
def regular_user(): def regular_user():
"""Create regular user for testing.""" """Create regular user for testing."""
@@ -60,52 +58,78 @@ class TestAdminUserEndpoints:
test_plan: Plan, test_plan: Plan,
) -> None: ) -> None:
"""Test listing users successfully.""" """Test listing users successfully."""
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all: with patch(
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
) as mock_get_all:
# Create mock user objects that don't trigger database saves # Create mock user objects that don't trigger database saves
mock_admin = type("User", (), { mock_admin = type(
"id": admin_user.id, "User",
"email": admin_user.email, (),
"name": admin_user.name, {
"picture": None, "id": admin_user.id,
"role": admin_user.role, "email": admin_user.email,
"credits": admin_user.credits, "name": admin_user.name,
"is_active": admin_user.is_active, "picture": None,
"created_at": admin_user.created_at, "role": admin_user.role,
"updated_at": admin_user.updated_at, "credits": admin_user.credits,
"plan": type("Plan", (), { "is_active": admin_user.is_active,
"id": test_plan.id, "created_at": admin_user.created_at,
"name": test_plan.name, "updated_at": admin_user.updated_at,
"max_credits": test_plan.max_credits, "plan": type(
})(), "Plan",
})() (),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_regular = type("User", (), { mock_regular = type(
"id": regular_user.id, "User",
"email": regular_user.email, (),
"name": regular_user.name, {
"picture": None, "id": regular_user.id,
"role": regular_user.role, "email": regular_user.email,
"credits": regular_user.credits, "name": regular_user.name,
"is_active": regular_user.is_active, "picture": None,
"created_at": regular_user.created_at, "role": regular_user.role,
"updated_at": regular_user.updated_at, "credits": regular_user.credits,
"plan": type("Plan", (), { "is_active": regular_user.is_active,
"id": test_plan.id, "created_at": regular_user.created_at,
"name": test_plan.name, "updated_at": regular_user.updated_at,
"max_credits": test_plan.max_credits, "plan": type(
})(), "Plan",
})() (),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_all.return_value = [mock_admin, mock_regular] # Mock returns tuple (users, total_count)
mock_get_all.return_value = ([mock_admin, mock_regular], 2)
response = await authenticated_admin_client.get("/api/v1/admin/users/") response = await authenticated_admin_client.get("/api/v1/admin/users/")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 2 assert "users" in data
assert data[0]["email"] == "admin@example.com" assert "total" in data
assert data[1]["email"] == "user@example.com" assert "page" in data
mock_get_all.assert_called_once_with(limit=100, offset=0) assert "limit" in data
assert "total_pages" in data
assert len(data["users"]) == 2
assert data["users"][0]["email"] == "admin@example.com"
assert data["users"][1]["email"] == "user@example.com"
assert data["total"] == 2
assert data["page"] == 1
assert data["limit"] == 50
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_users_with_pagination( async def test_list_users_with_pagination(
@@ -115,29 +139,55 @@ class TestAdminUserEndpoints:
test_plan: Plan, test_plan: Plan,
) -> None: ) -> None:
"""Test listing users with pagination.""" """Test listing users with pagination."""
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all: from app.repositories.user import SortOrder, UserSortField, UserStatus
mock_admin = type("User", (), {
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_get_all.return_value = [mock_admin]
response = await authenticated_admin_client.get("/api/v1/admin/users/?limit=10&offset=5") with patch(
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
) as mock_get_all:
mock_admin = type(
"User",
(),
{
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
# Mock returns tuple (users, total_count)
mock_get_all.return_value = ([mock_admin], 1)
response = await authenticated_admin_client.get(
"/api/v1/admin/users/?page=2&limit=10",
)
assert response.status_code == 200 assert response.status_code == 200
mock_get_all.assert_called_once_with(limit=10, offset=5) data = response.json()
assert "users" in data
assert data["page"] == 2
assert data["limit"] == 10
mock_get_all.assert_called_once_with(
page=2,
limit=10,
search=None,
sort_by=UserSortField.NAME,
sort_order=SortOrder.ASC,
status_filter=UserStatus.ALL,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_users_unauthenticated(self, client: AsyncClient) -> None: async def test_list_users_unauthenticated(self, client: AsyncClient) -> None:
@@ -153,7 +203,9 @@ class TestAdminUserEndpoints:
regular_user: User, regular_user: User,
) -> None: ) -> None:
"""Test listing users as non-admin user.""" """Test listing users as non-admin user."""
with patch("app.core.dependencies.get_current_active_user", return_value=regular_user): with patch(
"app.core.dependencies.get_current_active_user", return_value=regular_user,
):
response = await client.get("/api/v1/admin/users/") response = await client.get("/api/v1/admin/users/")
assert response.status_code == 401 assert response.status_code == 401
@@ -169,24 +221,34 @@ class TestAdminUserEndpoints:
"""Test getting specific user successfully.""" """Test getting specific user successfully."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
): ):
mock_user = type("User", (), { mock_user = type(
"id": regular_user.id, "User",
"email": regular_user.email, (),
"name": regular_user.name, {
"picture": None, "id": regular_user.id,
"role": regular_user.role, "email": regular_user.email,
"credits": regular_user.credits, "name": regular_user.name,
"is_active": regular_user.is_active, "picture": None,
"created_at": regular_user.created_at, "role": regular_user.role,
"updated_at": regular_user.updated_at, "credits": regular_user.credits,
"plan": type("Plan", (), { "is_active": regular_user.is_active,
"id": test_plan.id, "created_at": regular_user.created_at,
"name": test_plan.name, "updated_at": regular_user.updated_at,
"max_credits": test_plan.max_credits, "plan": type(
})(), "Plan",
})() (),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.get("/api/v1/admin/users/2") response = await authenticated_admin_client.get("/api/v1/admin/users/2")
@@ -207,7 +269,10 @@ class TestAdminUserEndpoints:
"""Test getting non-existent user.""" """Test getting non-existent user."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
): ):
response = await authenticated_admin_client.get("/api/v1/admin/users/999") response = await authenticated_admin_client.get("/api/v1/admin/users/999")
@@ -226,43 +291,63 @@ class TestAdminUserEndpoints:
"""Test updating user successfully.""" """Test updating user successfully."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update, patch("app.repositories.user.UserRepository.update") as mock_update,
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan), patch(
"app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan,
),
): ):
mock_user = type("User", (), { mock_user = type(
"id": regular_user.id, "User",
"email": regular_user.email, (),
"name": regular_user.name, {
"picture": None, "id": regular_user.id,
"role": regular_user.role, "email": regular_user.email,
"credits": regular_user.credits, "name": regular_user.name,
"is_active": regular_user.is_active, "picture": None,
"created_at": regular_user.created_at, "role": regular_user.role,
"updated_at": regular_user.updated_at, "credits": regular_user.credits,
"plan": type("Plan", (), { "is_active": regular_user.is_active,
"id": test_plan.id, "created_at": regular_user.created_at,
"name": test_plan.name, "updated_at": regular_user.updated_at,
"max_credits": test_plan.max_credits, "plan": type(
})(), "Plan",
})() (),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
updated_mock = type("User", (), { updated_mock = type(
"id": regular_user.id, "User",
"email": regular_user.email, (),
"name": "Updated Name", {
"picture": None, "id": regular_user.id,
"role": regular_user.role, "email": regular_user.email,
"credits": 200, "name": "Updated Name",
"is_active": regular_user.is_active, "picture": None,
"created_at": regular_user.created_at, "role": regular_user.role,
"updated_at": regular_user.updated_at, "credits": 200,
"plan": type("Plan", (), { "is_active": regular_user.is_active,
"id": test_plan.id, "created_at": regular_user.created_at,
"name": test_plan.name, "updated_at": regular_user.updated_at,
"max_credits": test_plan.max_credits, "plan": type(
})(), "Plan",
})() (),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
mock_update.return_value = updated_mock mock_update.return_value = updated_mock
@@ -271,7 +356,10 @@ class TestAdminUserEndpoints:
async def mock_refresh(instance, attributes=None): async def mock_refresh(instance, attributes=None):
pass pass
with patch("sqlmodel.ext.asyncio.session.AsyncSession.refresh", side_effect=mock_refresh): with patch(
"sqlmodel.ext.asyncio.session.AsyncSession.refresh",
side_effect=mock_refresh,
):
response = await authenticated_admin_client.patch( response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2", "/api/v1/admin/users/2",
json={ json={
@@ -295,7 +383,10 @@ class TestAdminUserEndpoints:
"""Test updating non-existent user.""" """Test updating non-existent user."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
): ):
response = await authenticated_admin_client.patch( response = await authenticated_admin_client.patch(
"/api/v1/admin/users/999", "/api/v1/admin/users/999",
@@ -316,25 +407,35 @@ class TestAdminUserEndpoints:
"""Test updating user with invalid plan.""" """Test updating user with invalid plan."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=None), patch("app.repositories.plan.PlanRepository.get_by_id", return_value=None),
): ):
mock_user = type("User", (), { mock_user = type(
"id": regular_user.id, "User",
"email": regular_user.email, (),
"name": regular_user.name, {
"picture": None, "id": regular_user.id,
"role": regular_user.role, "email": regular_user.email,
"credits": regular_user.credits, "name": regular_user.name,
"is_active": regular_user.is_active, "picture": None,
"created_at": regular_user.created_at, "role": regular_user.role,
"updated_at": regular_user.updated_at, "credits": regular_user.credits,
"plan": type("Plan", (), { "is_active": regular_user.is_active,
"id": 1, "created_at": regular_user.created_at,
"name": "Basic", "updated_at": regular_user.updated_at,
"max_credits": 100, "plan": type(
})(), "Plan",
})() (),
{
"id": 1,
"name": "Basic",
"max_credits": 100,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.patch( response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2", "/api/v1/admin/users/2",
@@ -356,29 +457,41 @@ class TestAdminUserEndpoints:
"""Test disabling user successfully.""" """Test disabling user successfully."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update, patch("app.repositories.user.UserRepository.update") as mock_update,
): ):
mock_user = type("User", (), { mock_user = type(
"id": regular_user.id, "User",
"email": regular_user.email, (),
"name": regular_user.name, {
"picture": None, "id": regular_user.id,
"role": regular_user.role, "email": regular_user.email,
"credits": regular_user.credits, "name": regular_user.name,
"is_active": regular_user.is_active, "picture": None,
"created_at": regular_user.created_at, "role": regular_user.role,
"updated_at": regular_user.updated_at, "credits": regular_user.credits,
"plan": type("Plan", (), { "is_active": regular_user.is_active,
"id": test_plan.id, "created_at": regular_user.created_at,
"name": test_plan.name, "updated_at": regular_user.updated_at,
"max_credits": test_plan.max_credits, "plan": type(
})(), "Plan",
})() (),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
mock_update.return_value = mock_user mock_update.return_value = mock_user
response = await authenticated_admin_client.post("/api/v1/admin/users/2/disable") response = await authenticated_admin_client.post(
"/api/v1/admin/users/2/disable",
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -393,9 +506,14 @@ class TestAdminUserEndpoints:
"""Test disabling non-existent user.""" """Test disabling non-existent user."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
): ):
response = await authenticated_admin_client.post("/api/v1/admin/users/999/disable") response = await authenticated_admin_client.post(
"/api/v1/admin/users/999/disable",
)
assert response.status_code == 404 assert response.status_code == 404
data = response.json() data = response.json()
@@ -421,29 +539,41 @@ class TestAdminUserEndpoints:
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id, patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update, patch("app.repositories.user.UserRepository.update") as mock_update,
): ):
mock_disabled_user = type("User", (), { mock_disabled_user = type(
"id": disabled_user.id, "User",
"email": disabled_user.email, (),
"name": disabled_user.name, {
"picture": None, "id": disabled_user.id,
"role": disabled_user.role, "email": disabled_user.email,
"credits": disabled_user.credits, "name": disabled_user.name,
"is_active": disabled_user.is_active, "picture": None,
"created_at": disabled_user.created_at, "role": disabled_user.role,
"updated_at": disabled_user.updated_at, "credits": disabled_user.credits,
"plan": type("Plan", (), { "is_active": disabled_user.is_active,
"id": test_plan.id, "created_at": disabled_user.created_at,
"name": test_plan.name, "updated_at": disabled_user.updated_at,
"max_credits": test_plan.max_credits, "plan": type(
})(), "Plan",
})() (),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_disabled_user mock_get_by_id.return_value = mock_disabled_user
mock_update.return_value = mock_disabled_user mock_update.return_value = mock_disabled_user
response = await authenticated_admin_client.post("/api/v1/admin/users/3/enable") response = await authenticated_admin_client.post(
"/api/v1/admin/users/3/enable",
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -458,9 +588,14 @@ class TestAdminUserEndpoints:
"""Test enabling non-existent user.""" """Test enabling non-existent user."""
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None), patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
): ):
response = await authenticated_admin_client.post("/api/v1/admin/users/999/enable") response = await authenticated_admin_client.post(
"/api/v1/admin/users/999/enable",
)
assert response.status_code == 404 assert response.status_code == 404
data = response.json() data = response.json()
@@ -479,9 +614,14 @@ class TestAdminUserEndpoints:
with ( with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user), patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.plan.PlanRepository.get_all", return_value=[basic_plan, premium_plan]), patch(
"app.repositories.plan.PlanRepository.get_all",
return_value=[basic_plan, premium_plan],
),
): ):
response = await authenticated_admin_client.get("/api/v1/admin/users/plans/list") response = await authenticated_admin_client.get(
"/api/v1/admin/users/plans/list",
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()

View File

@@ -488,11 +488,17 @@ class TestAuthEndpoints:
test_plan: Plan, test_plan: Plan,
) -> None: ) -> None:
"""Test refresh token success.""" """Test refresh token success."""
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh: with patch(
mock_refresh.return_value = type("TokenResponse", (), { "app.services.auth.AuthService.refresh_access_token",
"access_token": "new_access_token", ) as mock_refresh:
"expires_in": 3600, mock_refresh.return_value = type(
})() "TokenResponse",
(),
{
"access_token": "new_access_token",
"expires_in": 3600,
},
)()
response = await test_client.post( response = await test_client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh",
@@ -516,7 +522,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_token_service_error(self, test_client: AsyncClient) -> None: async def test_refresh_token_service_error(self, test_client: AsyncClient) -> None:
"""Test refresh token with service error.""" """Test refresh token with service error."""
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh: with patch(
"app.services.auth.AuthService.refresh_access_token",
) as mock_refresh:
mock_refresh.side_effect = Exception("Database error") mock_refresh.side_effect = Exception("Database error")
response = await test_client.post( response = await test_client.post(
@@ -528,7 +536,6 @@ class TestAuthEndpoints:
data = response.json() data = response.json()
assert "Token refresh failed" in data["detail"] assert "Token refresh failed" in data["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exchange_oauth_token_invalid_code( async def test_exchange_oauth_token_invalid_code(
self, self,
@@ -554,7 +561,9 @@ class TestAuthEndpoints:
"""Test update profile success.""" """Test update profile success."""
with ( with (
patch("app.services.auth.AuthService.update_user_profile") as mock_update, patch("app.services.auth.AuthService.update_user_profile") as mock_update,
patch("app.services.auth.AuthService.user_to_response") as mock_user_to_response, patch(
"app.services.auth.AuthService.user_to_response",
) as mock_user_to_response,
): ):
updated_user = User( updated_user = User(
id=test_user.id, id=test_user.id,
@@ -569,6 +578,7 @@ class TestAuthEndpoints:
# Mock the user_to_response to return UserResponse format # Mock the user_to_response to return UserResponse format
from app.schemas.auth import UserResponse from app.schemas.auth import UserResponse
mock_user_to_response.return_value = UserResponse( mock_user_to_response.return_value = UserResponse(
id=test_user.id, id=test_user.id,
email=test_user.email, email=test_user.email,
@@ -598,7 +608,9 @@ class TestAuthEndpoints:
assert data["name"] == "Updated Name" assert data["name"] == "Updated Name"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_profile_unauthenticated(self, test_client: AsyncClient) -> None: async def test_update_profile_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test update profile without authentication.""" """Test update profile without authentication."""
response = await test_client.patch( response = await test_client.patch(
"/api/v1/auth/me", "/api/v1/auth/me",
@@ -632,7 +644,9 @@ class TestAuthEndpoints:
assert data["message"] == "Password changed successfully" assert data["message"] == "Password changed successfully"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_unauthenticated(self, test_client: AsyncClient) -> None: async def test_change_password_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test change password without authentication.""" """Test change password without authentication."""
response = await test_client.post( response = await test_client.post(
"/api/v1/auth/change-password", "/api/v1/auth/change-password",
@@ -652,7 +666,9 @@ class TestAuthEndpoints:
auth_cookies: dict[str, str], auth_cookies: dict[str, str],
) -> None: ) -> None:
"""Test get user OAuth providers success.""" """Test get user OAuth providers success."""
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers: with patch(
"app.services.auth.AuthService.get_user_oauth_providers",
) as mock_providers:
from datetime import datetime from datetime import datetime
from app.models.user_oauth import UserOauth from app.models.user_oauth import UserOauth
@@ -699,7 +715,9 @@ class TestAuthEndpoints:
assert data[2]["display_name"] == "GitHub" assert data[2]["display_name"] == "GitHub"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_providers_unauthenticated(self, test_client: AsyncClient) -> None: async def test_get_user_providers_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test get user OAuth providers without authentication.""" """Test get user OAuth providers without authentication."""
response = await test_client.get("/api/v1/auth/user-providers") response = await test_client.get("/api/v1/auth/user-providers")

View File

@@ -109,9 +109,15 @@ class TestPlaylistEndpoints:
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 2 assert "playlists" in data
assert "total" in data
assert "page" in data
assert "limit" in data
assert "total_pages" in data
assert len(data["playlists"]) == 2
assert data["total"] == 2
playlist_names = {p["name"] for p in data} playlist_names = {p["name"] for p in data["playlists"]}
assert "Test Playlist" in playlist_names assert "Test Playlist" in playlist_names
assert "Main Playlist" in playlist_names assert "Main Playlist" in playlist_names

View File

@@ -9,7 +9,7 @@ from httpx import AsyncClient
from app.models.user import User from app.models.user import User
if TYPE_CHECKING: if TYPE_CHECKING:
from app.services.extraction import ExtractionInfo from app.services.extraction import ExtractionInfo, PaginatedExtractionsResponse
class TestSoundEndpoints: class TestSoundEndpoints:
@@ -32,6 +32,7 @@ class TestSoundEndpoints:
"error": None, "error": None,
"sound_id": None, "sound_id": None,
"user_id": authenticated_user.id, "user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z", "created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z", "updated_at": "2025-08-03T12:00:00Z",
} }
@@ -111,6 +112,7 @@ class TestSoundEndpoints:
"error": None, "error": None,
"sound_id": 42, "sound_id": 42,
"user_id": authenticated_user.id, "user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z", "created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z", "updated_at": "2025-08-03T12:00:00Z",
} }
@@ -154,41 +156,49 @@ class TestSoundEndpoints:
authenticated_user: User, authenticated_user: User,
) -> None: ) -> None:
"""Test getting user extractions.""" """Test getting user extractions."""
mock_extractions: list[ExtractionInfo] = [ mock_extractions: PaginatedExtractionsResponse = {
{ "extractions": [
"id": 1, {
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", "id": 1,
"title": "Never Gonna Give You Up", "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"service": "youtube", "title": "Never Gonna Give You Up",
"service_id": "dQw4w9WgXcQ", "service": "youtube",
"status": "completed", "service_id": "dQw4w9WgXcQ",
"error": None, "status": "completed",
"sound_id": 42, "error": None,
"user_id": authenticated_user.id, "sound_id": 42,
"created_at": "2025-08-03T12:00:00Z", "user_id": authenticated_user.id,
"updated_at": "2025-08-03T12:00:00Z", "user_name": authenticated_user.name,
}, "created_at": "2025-08-03T12:00:00Z",
{ "updated_at": "2025-08-03T12:00:00Z",
"id": 2, },
"url": "https://soundcloud.com/example/track", {
"title": "Example Track", "id": 2,
"service": "soundcloud", "url": "https://soundcloud.com/example/track",
"service_id": "example-track", "title": "Example Track",
"status": "pending", "service": "soundcloud",
"error": None, "service_id": "example-track",
"sound_id": None, "status": "pending",
"user_id": authenticated_user.id, "error": None,
"created_at": "2025-08-03T12:00:00Z", "sound_id": None,
"updated_at": "2025-08-03T12:00:00Z", "user_id": authenticated_user.id,
}, "user_name": authenticated_user.name,
] "created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
],
"total": 2,
"page": 1,
"limit": 50,
"total_pages": 1,
}
with patch( with patch(
"app.services.extraction.ExtractionService.get_user_extractions", "app.services.extraction.ExtractionService.get_user_extractions",
) as mock_get: ) as mock_get:
mock_get.return_value = mock_extractions mock_get.return_value = mock_extractions
response = await authenticated_client.get("/api/v1/extractions/") response = await authenticated_client.get("/api/v1/extractions/user")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -337,7 +347,9 @@ class TestSoundEndpoints:
"""Test getting sounds with authentication.""" """Test getting sounds with authentication."""
from app.models.sound import Sound from app.models.sound import Sound
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get: with patch(
"app.repositories.sound.SoundRepository.search_and_sort",
) as mock_get:
# Create mock sounds with all required fields # Create mock sounds with all required fields
mock_sound_1 = Sound( mock_sound_1 = Sound(
id=1, id=1,
@@ -383,7 +395,9 @@ class TestSoundEndpoints:
"""Test getting sounds with type filtering.""" """Test getting sounds with type filtering."""
from app.models.sound import Sound from app.models.sound import Sound
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get: with patch(
"app.repositories.sound.SoundRepository.search_and_sort",
) as mock_get:
# Create mock sound with all required fields # Create mock sound with all required fields
mock_sound = Sound( mock_sound = Sound(
id=1, id=1,

View File

@@ -335,5 +335,3 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
access_token = JWTUtils.create_access_token(token_data) access_token = JWTUtils.create_access_token(token_data)
return {"access_token": access_token} return {"access_token": access_token}

View File

@@ -539,21 +539,35 @@ class TestPlaylistRepository:
sound_ids = [s.id for s in sounds] sound_ids = [s.id for s in sounds]
# Add first two sounds sequentially (positions 0, 1) # Add first two sounds sequentially (positions 0, 1)
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0 await playlist_repository.add_sound_to_playlist(
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1 playlist_id, sound_ids[0],
) # position 0
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[1],
) # position 1
# Now insert third sound at position 1 - should shift existing sound at position 1 to position 2 # Now insert third sound at position 1 - should shift existing sound at position 1 to position 2
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=1) await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[2], position=1,
)
# Verify the final positions # Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) playlist_sounds = await playlist_repository.get_playlist_sound_entries(
playlist_id,
)
assert len(playlist_sounds) == 3 assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0 assert (
playlist_sounds[0].sound_id == sound_ids[0]
) # Original sound 0 stays at position 0
assert playlist_sounds[0].position == 0 assert playlist_sounds[0].position == 0
assert playlist_sounds[1].sound_id == sound_ids[2] # New sound 2 inserted at position 1 assert (
playlist_sounds[1].sound_id == sound_ids[2]
) # New sound 2 inserted at position 1
assert playlist_sounds[1].position == 1 assert playlist_sounds[1].position == 1
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2 assert (
playlist_sounds[2].sound_id == sound_ids[1]
) # Original sound 1 shifted to position 2
assert playlist_sounds[2].position == 2 assert playlist_sounds[2].position == 2
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -615,21 +629,35 @@ class TestPlaylistRepository:
sound_ids = [s.id for s in sounds] sound_ids = [s.id for s in sounds]
# Add first two sounds sequentially (positions 0, 1) # Add first two sounds sequentially (positions 0, 1)
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0 await playlist_repository.add_sound_to_playlist(
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1 playlist_id, sound_ids[0],
) # position 0
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[1],
) # position 1
# Now insert third sound at position 0 - should shift existing sounds to positions 1, 2 # Now insert third sound at position 0 - should shift existing sounds to positions 1, 2
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=0) await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[2], position=0,
)
# Verify the final positions # Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) playlist_sounds = await playlist_repository.get_playlist_sound_entries(
playlist_id,
)
assert len(playlist_sounds) == 3 assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0 assert (
playlist_sounds[0].sound_id == sound_ids[2]
) # New sound 2 inserted at position 0
assert playlist_sounds[0].position == 0 assert playlist_sounds[0].position == 0
assert playlist_sounds[1].sound_id == sound_ids[0] # Original sound 0 shifted to position 1 assert (
playlist_sounds[1].sound_id == sound_ids[0]
) # Original sound 0 shifted to position 1
assert playlist_sounds[1].position == 1 assert playlist_sounds[1].position == 1
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2 assert (
playlist_sounds[2].sound_id == sound_ids[1]
) # Original sound 1 shifted to position 2
assert playlist_sounds[2].position == 2 assert playlist_sounds[2].position == 2
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -409,7 +409,9 @@ class TestCreditService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_recharge_user_credits_success( async def test_recharge_user_credits_success(
self, credit_service, sample_user, self,
credit_service,
sample_user,
) -> None: ) -> None:
"""Test successful credit recharge for a user.""" """Test successful credit recharge for a user."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()

View File

@@ -43,7 +43,9 @@ class TestDashboardService:
"total_duration": 75000, "total_duration": 75000,
"total_size": 1024000, "total_size": 1024000,
} }
mock_sound_repository.get_soundboard_statistics = AsyncMock(return_value=mock_stats) mock_sound_repository.get_soundboard_statistics = AsyncMock(
return_value=mock_stats,
)
result = await dashboard_service.get_soundboard_statistics() result = await dashboard_service.get_soundboard_statistics()

View File

@@ -99,6 +99,11 @@ class TestExtractionService:
url = "https://www.youtube.com/watch?v=test123" url = "https://www.youtube.com/watch?v=test123"
user_id = 1 user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - no service detection happens during creation # Mock repository call - no service detection happens during creation
mock_extraction = Extraction( mock_extraction = Extraction(
id=1, id=1,
@@ -120,6 +125,7 @@ class TestExtractionService:
assert result["service_id"] is None # Not detected during creation assert result["service_id"] is None # Not detected during creation
assert result["title"] is None # Not detected during creation assert result["title"] is None # Not detected during creation
assert result["status"] == "pending" assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_extraction_basic(self, extraction_service) -> None: async def test_create_extraction_basic(self, extraction_service) -> None:
@@ -127,6 +133,11 @@ class TestExtractionService:
url = "https://www.youtube.com/watch?v=test123" url = "https://www.youtube.com/watch?v=test123"
user_id = 1 user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - creation always succeeds now # Mock repository call - creation always succeeds now
mock_extraction = Extraction( mock_extraction = Extraction(
id=2, id=2,
@@ -146,6 +157,7 @@ class TestExtractionService:
assert result["id"] == 2 assert result["id"] == 2
assert result["url"] == url assert result["url"] == url
assert result["status"] == "pending" assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_extraction_any_url(self, extraction_service) -> None: async def test_create_extraction_any_url(self, extraction_service) -> None:
@@ -153,6 +165,11 @@ class TestExtractionService:
url = "https://invalid.url" url = "https://invalid.url"
user_id = 1 user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - even invalid URLs are accepted during creation # Mock repository call - even invalid URLs are accepted during creation
mock_extraction = Extraction( mock_extraction = Extraction(
id=3, id=3,
@@ -172,6 +189,7 @@ class TestExtractionService:
assert result["id"] == 3 assert result["id"] == 3
assert result["url"] == url assert result["url"] == url
assert result["status"] == "pending" assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_extraction_with_service_detection( async def test_process_extraction_with_service_detection(
@@ -408,9 +426,16 @@ class TestExtractionService:
sound_id=42, sound_id=42,
) )
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.extraction_repo.get_by_id = AsyncMock( extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=extraction, return_value=extraction,
) )
extraction_service.user_repo.get_by_id = AsyncMock(
return_value=mock_user,
)
result = await extraction_service.get_extraction_by_id(1) result = await extraction_service.get_extraction_by_id(1)
@@ -421,6 +446,7 @@ class TestExtractionService:
assert result["title"] == "Test Video" assert result["title"] == "Test Video"
assert result["status"] == "completed" assert result["status"] == "completed"
assert result["sound_id"] == 42 assert result["sound_id"] == 42
assert result["user_name"] == "Test User"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_extraction_by_id_not_found(self, extraction_service) -> None: async def test_get_extraction_by_id_not_found(self, extraction_service) -> None:
@@ -434,52 +460,70 @@ class TestExtractionService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_extractions(self, extraction_service) -> None: async def test_get_user_extractions(self, extraction_service) -> None:
"""Test getting user extractions.""" """Test getting user extractions."""
from app.models.user import User
user = User(id=1, name="Test User", email="test@example.com")
extractions = [ extractions = [
Extraction( (
id=1, Extraction(
service="youtube", id=1,
service_id="test123", service="youtube",
url="https://www.youtube.com/watch?v=test123", service_id="test123",
user_id=1, url="https://www.youtube.com/watch?v=test123",
title="Test Video 1", user_id=1,
status="completed", title="Test Video 1",
sound_id=42, status="completed",
sound_id=42,
),
user,
), ),
Extraction( (
id=2, Extraction(
service="youtube", id=2,
service_id="test456", service="youtube",
url="https://www.youtube.com/watch?v=test456", service_id="test456",
user_id=1, url="https://www.youtube.com/watch?v=test456",
title="Test Video 2", user_id=1,
status="pending", title="Test Video 2",
status="pending",
),
user,
), ),
] ]
extraction_service.extraction_repo.get_by_user = AsyncMock( extraction_service.extraction_repo.get_user_extractions_filtered = AsyncMock(
return_value=extractions, return_value=(extractions, 2),
) )
result = await extraction_service.get_user_extractions(1) result = await extraction_service.get_user_extractions(1)
assert len(result) == 2 assert result["total"] == 2
assert result[0]["id"] == 1 assert len(result["extractions"]) == 2
assert result[0]["title"] == "Test Video 1" assert result["extractions"][0]["id"] == 1
assert result[1]["id"] == 2 assert result["extractions"][0]["title"] == "Test Video 1"
assert result[1]["title"] == "Test Video 2" assert result["extractions"][0]["user_name"] == "Test User"
assert result["extractions"][1]["id"] == 2
assert result["extractions"][1]["title"] == "Test Video 2"
assert result["extractions"][1]["user_name"] == "Test User"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_pending_extractions(self, extraction_service) -> None: async def test_get_pending_extractions(self, extraction_service) -> None:
"""Test getting pending extractions.""" """Test getting pending extractions."""
from app.models.user import User
user = User(id=1, name="Test User", email="test@example.com")
pending_extractions = [ pending_extractions = [
Extraction( (
id=1, Extraction(
service="youtube", id=1,
service_id="test123", service="youtube",
url="https://www.youtube.com/watch?v=test123", service_id="test123",
user_id=1, url="https://www.youtube.com/watch?v=test123",
title="Pending Video", user_id=1,
status="pending", title="Pending Video",
status="pending",
),
user,
), ),
] ]
@@ -492,3 +536,4 @@ class TestExtractionService:
assert len(result) == 1 assert len(result) == 1
assert result[0]["id"] == 1 assert result[0]["id"] == 1
assert result[0]["status"] == "pending" assert result[0]["status"] == "pending"
assert result[0]["user_name"] == "Test User"

View File

@@ -25,9 +25,10 @@ class TestSchedulerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_scheduler(self, scheduler_service) -> None: async def test_start_scheduler(self, scheduler_service) -> None:
"""Test starting the scheduler service.""" """Test starting the scheduler service."""
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \ with (
patch.object(scheduler_service.scheduler, "start") as mock_start: patch.object(scheduler_service.scheduler, "add_job") as mock_add_job,
patch.object(scheduler_service.scheduler, "start") as mock_start,
):
await scheduler_service.start() await scheduler_service.start()
# Verify job was added # Verify job was added
@@ -61,7 +62,9 @@ class TestSchedulerService:
"total_credits_added": 500, "total_credits_added": 500,
} }
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge: with patch.object(
scheduler_service.credit_service, "recharge_all_users_credits",
) as mock_recharge:
mock_recharge.return_value = mock_stats mock_recharge.return_value = mock_stats
await scheduler_service._daily_credit_recharge() await scheduler_service._daily_credit_recharge()
@@ -71,7 +74,9 @@ class TestSchedulerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None: async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
"""Test daily credit recharge task with failure.""" """Test daily credit recharge task with failure."""
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge: with patch.object(
scheduler_service.credit_service, "recharge_all_users_credits",
) as mock_recharge:
mock_recharge.side_effect = Exception("Database error") mock_recharge.side_effect = Exception("Database error")
# Should not raise exception, just log it # Should not raise exception, just log it