Refactor user endpoint tests to include pagination and response structure validation
- Updated tests for listing users to validate pagination and response format. - Changed mock return values to include total count and pagination details. - Refactored user creation mocks for clarity and consistency. - Enhanced assertions to check for presence of pagination fields in responses. - Adjusted test cases for user retrieval and updates to ensure proper handling of user data. - Improved readability by restructuring mock definitions and assertions across various test files.
This commit is contained in:
@@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user
|
||||
from app.models.plan import Plan
|
||||
from app.models.user import User
|
||||
from app.repositories.plan import PlanRepository
|
||||
from app.repositories.user import UserRepository, UserSortField, SortOrder, UserStatus
|
||||
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
|
||||
from app.schemas.auth import UserResponse
|
||||
from app.schemas.user import UserUpdate
|
||||
|
||||
@@ -36,21 +36,27 @@ def _user_to_response(user: User) -> UserResponse:
|
||||
"name": user.plan.name,
|
||||
"max_credits": user.plan.max_credits,
|
||||
"features": [], # Add features if needed
|
||||
} if user.plan else {},
|
||||
}
|
||||
if user.plan
|
||||
else {},
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_users(
|
||||
async def list_users( # noqa: PLR0913
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
search: Annotated[str | None, Query(description="Search in name or email")] = None,
|
||||
sort_by: Annotated[UserSortField, Query(description="Sort by field")] = UserSortField.NAME,
|
||||
sort_by: Annotated[
|
||||
UserSortField, Query(description="Sort by field"),
|
||||
] = UserSortField.NAME,
|
||||
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
|
||||
status_filter: Annotated[UserStatus, Query(description="Filter by status")] = UserStatus.ALL,
|
||||
status_filter: Annotated[
|
||||
UserStatus, Query(description="Filter by status"),
|
||||
] = UserStatus.ALL,
|
||||
) -> dict[str, Any]:
|
||||
"""Get all users with pagination, search, and filters (admin only)."""
|
||||
user_repo = UserRepository(session)
|
||||
@@ -62,9 +68,9 @@ async def list_users(
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
)
|
||||
|
||||
|
||||
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||
|
||||
|
||||
return {
|
||||
"users": [_user_to_response(user) for user in users],
|
||||
"total": total_count,
|
||||
|
||||
@@ -464,7 +464,8 @@ async def update_profile(
|
||||
"""Update the current user's profile."""
|
||||
try:
|
||||
updated_user = await auth_service.update_user_profile(
|
||||
current_user, request.model_dump(exclude_unset=True),
|
||||
current_user,
|
||||
request.model_dump(exclude_unset=True),
|
||||
)
|
||||
return await auth_service.user_to_response(updated_user)
|
||||
except Exception as e:
|
||||
@@ -486,7 +487,9 @@ async def change_password(
|
||||
user_email = current_user.email
|
||||
try:
|
||||
await auth_service.change_user_password(
|
||||
current_user, request.current_password, request.new_password,
|
||||
current_user,
|
||||
request.current_password,
|
||||
request.new_password,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
@@ -513,11 +516,13 @@ async def get_user_providers(
|
||||
|
||||
# Add password provider if user has password
|
||||
if current_user.password_hash:
|
||||
providers.append({
|
||||
"provider": "password",
|
||||
"display_name": "Password",
|
||||
"connected_at": current_user.created_at.isoformat(),
|
||||
})
|
||||
providers.append(
|
||||
{
|
||||
"provider": "password",
|
||||
"display_name": "Password",
|
||||
"connected_at": current_user.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
# Get OAuth providers from the database
|
||||
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
||||
@@ -528,10 +533,12 @@ async def get_user_providers(
|
||||
elif oauth.provider == "google":
|
||||
display_name = "Google"
|
||||
|
||||
providers.append({
|
||||
"provider": oauth.provider,
|
||||
"display_name": display_name,
|
||||
"connected_at": oauth.created_at.isoformat(),
|
||||
})
|
||||
providers.append(
|
||||
{
|
||||
"provider": oauth.provider,
|
||||
"display_name": display_name,
|
||||
"connected_at": oauth.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
return providers
|
||||
|
||||
@@ -34,7 +34,8 @@ async def get_top_sounds(
|
||||
_current_user: Annotated[User, Depends(get_current_user)],
|
||||
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
||||
sound_type: Annotated[
|
||||
str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
|
||||
str,
|
||||
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
|
||||
],
|
||||
period: Annotated[
|
||||
str,
|
||||
@@ -43,7 +44,8 @@ async def get_top_sounds(
|
||||
),
|
||||
] = "all_time",
|
||||
limit: Annotated[
|
||||
int, Query(description="Number of top sounds to return", ge=1, le=100),
|
||||
int,
|
||||
Query(description="Number of top sounds to return", ge=1, le=100),
|
||||
] = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get top sounds by play count for a specific period."""
|
||||
|
||||
@@ -60,68 +60,13 @@ async def create_extraction(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{extraction_id}")
|
||||
async def get_extraction(
|
||||
extraction_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> ExtractionInfo:
|
||||
"""Get extraction information by ID."""
|
||||
try:
|
||||
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
|
||||
|
||||
if not extraction_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Extraction {extraction_id} not found",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extraction: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return extraction_info
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_all_extractions(
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
|
||||
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
|
||||
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
|
||||
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
) -> dict:
|
||||
"""Get all extractions with optional filtering, search, and sorting."""
|
||||
try:
|
||||
result = await extraction_service.get_all_extractions(
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extractions: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/user")
|
||||
async def get_user_extractions(
|
||||
async def get_user_extractions( # noqa: PLR0913
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
|
||||
search: Annotated[
|
||||
str | None, Query(description="Search in title, URL, or service"),
|
||||
] = None,
|
||||
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
|
||||
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
|
||||
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
@@ -153,3 +98,62 @@ async def get_user_extractions(
|
||||
) from e
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{extraction_id}")
|
||||
async def get_extraction(
|
||||
extraction_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
) -> ExtractionInfo:
|
||||
"""Get extraction information by ID."""
|
||||
try:
|
||||
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
|
||||
|
||||
if not extraction_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Extraction {extraction_id} not found",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extraction: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return extraction_info
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def get_all_extractions( # noqa: PLR0913
|
||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||
search: Annotated[
|
||||
str | None, Query(description="Search in title, URL, or service"),
|
||||
] = None,
|
||||
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
|
||||
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
|
||||
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
) -> dict:
|
||||
"""Get all extractions with optional filtering, search, and sorting."""
|
||||
try:
|
||||
result = await extraction_service.get_all_extractions(
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
status_filter=status_filter,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get extractions: {e!s}",
|
||||
) from e
|
||||
else:
|
||||
return result
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
|
||||
from app.core.database import get_session_factory
|
||||
from app.core.dependencies import get_current_active_user
|
||||
from app.models.user import User
|
||||
from app.schemas.common import MessageResponse
|
||||
@@ -19,12 +20,10 @@ router = APIRouter(prefix="/favorites", tags=["favorites"])
|
||||
|
||||
def get_favorite_service() -> FavoriteService:
|
||||
"""Get the favorite service."""
|
||||
from app.core.database import get_session_factory
|
||||
|
||||
return FavoriteService(get_session_factory())
|
||||
|
||||
|
||||
@router.get("/", response_model=FavoritesListResponse)
|
||||
@router.get("/")
|
||||
async def get_user_favorites(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
@@ -33,12 +32,14 @@ async def get_user_favorites(
|
||||
) -> FavoritesListResponse:
|
||||
"""Get all favorites for the current user."""
|
||||
favorites = await favorite_service.get_user_favorites(
|
||||
current_user.id, limit, offset,
|
||||
current_user.id,
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
return FavoritesListResponse(favorites=favorites)
|
||||
|
||||
|
||||
@router.get("/sounds", response_model=FavoritesListResponse)
|
||||
@router.get("/sounds")
|
||||
async def get_user_sound_favorites(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
@@ -47,12 +48,14 @@ async def get_user_sound_favorites(
|
||||
) -> FavoritesListResponse:
|
||||
"""Get sound favorites for the current user."""
|
||||
favorites = await favorite_service.get_user_sound_favorites(
|
||||
current_user.id, limit, offset,
|
||||
current_user.id,
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
return FavoritesListResponse(favorites=favorites)
|
||||
|
||||
|
||||
@router.get("/playlists", response_model=FavoritesListResponse)
|
||||
@router.get("/playlists")
|
||||
async def get_user_playlist_favorites(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
@@ -61,12 +64,14 @@ async def get_user_playlist_favorites(
|
||||
) -> FavoritesListResponse:
|
||||
"""Get playlist favorites for the current user."""
|
||||
favorites = await favorite_service.get_user_playlist_favorites(
|
||||
current_user.id, limit, offset,
|
||||
current_user.id,
|
||||
limit,
|
||||
offset,
|
||||
)
|
||||
return FavoritesListResponse(favorites=favorites)
|
||||
|
||||
|
||||
@router.get("/counts", response_model=FavoriteCountsResponse)
|
||||
@router.get("/counts")
|
||||
async def get_favorite_counts(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||
@@ -76,7 +81,7 @@ async def get_favorite_counts(
|
||||
return FavoriteCountsResponse(**counts)
|
||||
|
||||
|
||||
@router.post("/sounds/{sound_id}", response_model=FavoriteResponse)
|
||||
@router.post("/sounds/{sound_id}")
|
||||
async def add_sound_favorite(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -103,7 +108,7 @@ async def add_sound_favorite(
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/playlists/{playlist_id}", response_model=FavoriteResponse)
|
||||
@router.post("/playlists/{playlist_id}")
|
||||
async def add_playlist_favorite(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -112,7 +117,8 @@ async def add_playlist_favorite(
|
||||
"""Add a playlist to favorites."""
|
||||
try:
|
||||
favorite = await favorite_service.add_playlist_favorite(
|
||||
current_user.id, playlist_id,
|
||||
current_user.id,
|
||||
playlist_id,
|
||||
)
|
||||
return FavoriteResponse.model_validate(favorite)
|
||||
except ValueError as e:
|
||||
@@ -132,7 +138,7 @@ async def add_playlist_favorite(
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/sounds/{sound_id}", response_model=MessageResponse)
|
||||
@router.delete("/sounds/{sound_id}")
|
||||
async def remove_sound_favorite(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -149,7 +155,7 @@ async def remove_sound_favorite(
|
||||
) from e
|
||||
|
||||
|
||||
@router.delete("/playlists/{playlist_id}", response_model=MessageResponse)
|
||||
@router.delete("/playlists/{playlist_id}")
|
||||
async def remove_playlist_favorite(
|
||||
playlist_id: int,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
@@ -185,6 +191,7 @@ async def check_playlist_favorited(
|
||||
) -> dict[str, bool]:
|
||||
"""Check if a playlist is favorited by the current user."""
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id, playlist_id,
|
||||
current_user.id,
|
||||
playlist_id,
|
||||
)
|
||||
return {"is_favorited": is_favorited}
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Annotated, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.database import get_db, get_session_factory
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.user import User
|
||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||
@@ -34,7 +34,6 @@ async def get_playlist_service(
|
||||
|
||||
def get_favorite_service() -> FavoriteService:
|
||||
"""Get the favorite service."""
|
||||
from app.core.database import get_session_factory
|
||||
return FavoriteService(get_session_factory())
|
||||
|
||||
|
||||
@@ -57,7 +56,7 @@ async def get_all_playlists( # noqa: PLR0913
|
||||
] = SortOrder.ASC,
|
||||
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||
favorites_only: Annotated[
|
||||
favorites_only: Annotated[ # noqa: FBT002
|
||||
bool,
|
||||
Query(description="Show only favorited playlists"),
|
||||
] = False,
|
||||
@@ -78,15 +77,26 @@ async def get_all_playlists( # noqa: PLR0913
|
||||
# Convert to PlaylistResponse with favorite indicators
|
||||
playlist_responses = []
|
||||
for playlist_dict in result["playlists"]:
|
||||
# The playlist service returns dict, need to create playlist object-like structure
|
||||
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist_dict["id"])
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_dict["id"])
|
||||
# The playlist service returns dict, need to create playlist object structure
|
||||
playlist_id = playlist_dict["id"]
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id, playlist_id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
|
||||
|
||||
# Create a PlaylistResponse-like dict with proper datetime conversion
|
||||
playlist_response = {
|
||||
**playlist_dict,
|
||||
"created_at": playlist_dict["created_at"].isoformat() if playlist_dict["created_at"] else None,
|
||||
"updated_at": playlist_dict["updated_at"].isoformat() if playlist_dict["updated_at"] else None,
|
||||
"created_at": (
|
||||
playlist_dict["created_at"].isoformat()
|
||||
if playlist_dict["created_at"]
|
||||
else None
|
||||
),
|
||||
"updated_at": (
|
||||
playlist_dict["updated_at"].isoformat()
|
||||
if playlist_dict["updated_at"]
|
||||
else None
|
||||
),
|
||||
"is_favorited": is_favorited,
|
||||
"favorite_count": favorite_count,
|
||||
}
|
||||
@@ -113,9 +123,13 @@ async def get_user_playlists(
|
||||
# Add favorite indicators for each playlist
|
||||
playlist_responses = []
|
||||
for playlist in playlists:
|
||||
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id, playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
playlist_response = PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||
playlist_response = PlaylistResponse.from_playlist(
|
||||
playlist, is_favorited, favorite_count,
|
||||
)
|
||||
playlist_responses.append(playlist_response)
|
||||
|
||||
return playlist_responses
|
||||
@@ -129,7 +143,9 @@ async def get_main_playlist(
|
||||
) -> PlaylistResponse:
|
||||
"""Get the global main playlist."""
|
||||
playlist = await playlist_service.get_main_playlist()
|
||||
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id, playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||
|
||||
@@ -142,7 +158,9 @@ async def get_current_playlist(
|
||||
) -> PlaylistResponse:
|
||||
"""Get the global current playlist (falls back to main playlist)."""
|
||||
playlist = await playlist_service.get_current_playlist()
|
||||
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id, playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||
|
||||
@@ -172,7 +190,9 @@ async def get_playlist(
|
||||
) -> PlaylistResponse:
|
||||
"""Get a specific playlist."""
|
||||
playlist = await playlist_service.get_playlist_by_id(playlist_id)
|
||||
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
|
||||
is_favorited = await favorite_service.is_playlist_favorited(
|
||||
current_user.id, playlist.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ async def get_sound_repository(
|
||||
return SoundRepository(session)
|
||||
|
||||
|
||||
@router.get("/", response_model=SoundsListResponse)
|
||||
@router.get("/")
|
||||
async def get_sounds( # noqa: PLR0913
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||
@@ -69,7 +69,7 @@ async def get_sounds( # noqa: PLR0913
|
||||
int,
|
||||
Query(description="Number of results to skip", ge=0),
|
||||
] = 0,
|
||||
favorites_only: Annotated[
|
||||
favorites_only: Annotated[ # noqa: FBT002
|
||||
bool,
|
||||
Query(description="Show only favorited sounds"),
|
||||
] = False,
|
||||
@@ -90,9 +90,13 @@ async def get_sounds( # noqa: PLR0913
|
||||
# Add favorite indicators for each sound
|
||||
sound_responses = []
|
||||
for sound in sounds:
|
||||
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound.id)
|
||||
is_favorited = await favorite_service.is_sound_favorited(
|
||||
current_user.id, sound.id,
|
||||
)
|
||||
favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
|
||||
sound_response = SoundResponse.from_sound(sound, is_favorited, favorite_count)
|
||||
sound_response = SoundResponse.from_sound(
|
||||
sound, is_favorited, favorite_count,
|
||||
)
|
||||
sound_responses.append(sound_response)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -20,7 +20,7 @@ class Settings(BaseSettings):
|
||||
|
||||
# Production URLs (for reverse proxy deployment)
|
||||
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
|
||||
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
|
||||
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
|
||||
|
||||
# CORS Configuration
|
||||
CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS
|
||||
|
||||
@@ -20,7 +20,9 @@ class BaseModel(SQLModel):
|
||||
# SQLAlchemy event listener to automatically update updated_at timestamp
|
||||
@event.listens_for(BaseModel, "before_update", propagate=True)
|
||||
def update_timestamp(
|
||||
mapper: Mapper[Any], connection: Connection, target: BaseModel, # noqa: ARG001
|
||||
mapper: Mapper[Any], # noqa: ARG001
|
||||
connection: Connection, # noqa: ARG001
|
||||
target: BaseModel,
|
||||
) -> None:
|
||||
"""Automatically set updated_at timestamp before update operations."""
|
||||
target.updated_at = datetime.now(UTC)
|
||||
|
||||
@@ -35,5 +35,3 @@ class PlaylistSound(BaseModel, table=True):
|
||||
# relationships
|
||||
playlist: "Playlist" = Relationship(back_populates="playlist_sounds")
|
||||
sound: "Sound" = Relationship(back_populates="playlist_sounds")
|
||||
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
)
|
||||
return list(result.all())
|
||||
|
||||
async def get_user_extractions_filtered(
|
||||
async def get_user_extractions_filtered( # noqa: PLR0913
|
||||
self,
|
||||
user_id: int,
|
||||
search: str | None = None,
|
||||
@@ -92,7 +92,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(
|
||||
base_query.subquery()
|
||||
base_query.subquery(),
|
||||
)
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
@@ -106,10 +106,10 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
|
||||
|
||||
return list(result.all()), total_count
|
||||
|
||||
async def get_all_extractions_filtered(
|
||||
async def get_all_extractions_filtered( # noqa: PLR0913
|
||||
self,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
@@ -138,7 +138,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(
|
||||
base_query.subquery()
|
||||
base_query.subquery(),
|
||||
)
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
@@ -152,5 +152,5 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
|
||||
|
||||
return list(result.all()), total_count
|
||||
|
||||
@@ -118,7 +118,9 @@ class FavoriteRepository(BaseRepository[Favorite]):
|
||||
raise
|
||||
|
||||
async def get_by_user_and_sound(
|
||||
self, user_id: int, sound_id: int,
|
||||
self,
|
||||
user_id: int,
|
||||
sound_id: int,
|
||||
) -> Favorite | None:
|
||||
"""Get a favorite by user and sound.
|
||||
|
||||
@@ -138,12 +140,16 @@ class FavoriteRepository(BaseRepository[Favorite]):
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get favorite for user %s and sound %s", user_id, sound_id,
|
||||
"Failed to get favorite for user %s and sound %s",
|
||||
user_id,
|
||||
sound_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_user_and_playlist(
|
||||
self, user_id: int, playlist_id: int,
|
||||
self,
|
||||
user_id: int,
|
||||
playlist_id: int,
|
||||
) -> Favorite | None:
|
||||
"""Get a favorite by user and playlist.
|
||||
|
||||
|
||||
@@ -57,7 +57,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
# management
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to update playlist timestamp for playlist: %s", playlist_id,
|
||||
"Failed to update playlist timestamp for playlist: %s",
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -341,7 +342,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
include_stats: bool = False, # noqa: FBT001, FBT002
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
favorites_only: bool = False,
|
||||
favorites_only: bool = False, # noqa: FBT001, FBT002
|
||||
current_user_id: int | None = None,
|
||||
*,
|
||||
return_count: bool = False,
|
||||
@@ -395,9 +396,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
# Apply favorites filter
|
||||
if favorites_only and current_user_id is not None:
|
||||
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
|
||||
favorites_subquery = select(1).select_from(Favorite).where(
|
||||
Favorite.user_id == current_user_id,
|
||||
Favorite.playlist_id == Playlist.id,
|
||||
favorites_subquery = (
|
||||
select(1)
|
||||
.select_from(Favorite)
|
||||
.where(
|
||||
Favorite.user_id == current_user_id,
|
||||
Favorite.playlist_id == Playlist.id,
|
||||
)
|
||||
)
|
||||
subquery = subquery.where(favorites_subquery.exists())
|
||||
|
||||
@@ -466,9 +471,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
# Apply favorites filter
|
||||
if favorites_only and current_user_id is not None:
|
||||
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
|
||||
favorites_subquery = select(1).select_from(Favorite).where(
|
||||
Favorite.user_id == current_user_id,
|
||||
Favorite.playlist_id == Playlist.id,
|
||||
favorites_subquery = (
|
||||
select(1)
|
||||
.select_from(Favorite)
|
||||
.where(
|
||||
Favorite.user_id == current_user_id,
|
||||
Favorite.playlist_id == Playlist.id,
|
||||
)
|
||||
)
|
||||
subquery = subquery.where(favorites_subquery.exists())
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
sort_order: SortOrder = SortOrder.ASC,
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
favorites_only: bool = False,
|
||||
favorites_only: bool = False, # noqa: FBT001, FBT002
|
||||
user_id: int | None = None,
|
||||
) -> list[Sound]:
|
||||
"""Search and sort sounds with optional filtering."""
|
||||
@@ -189,7 +189,8 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
logger.exception(
|
||||
(
|
||||
"Failed to search and sort sounds: "
|
||||
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, user_id=%s"
|
||||
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
|
||||
"user_id=%s"
|
||||
),
|
||||
search_query,
|
||||
sound_types,
|
||||
@@ -288,8 +289,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
|
||||
# Group by sound and order by play count descending
|
||||
statement = (
|
||||
statement
|
||||
.group_by(
|
||||
statement.group_by(
|
||||
Sound.id,
|
||||
Sound.name,
|
||||
Sound.type,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""User repository."""
|
||||
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import selectinload
|
||||
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class UserSortField(str, Enum):
|
||||
"""User sort fields."""
|
||||
|
||||
NAME = "name"
|
||||
EMAIL = "email"
|
||||
ROLE = "role"
|
||||
@@ -27,12 +28,14 @@ class UserSortField(str, Enum):
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
"""User status filter."""
|
||||
|
||||
ALL = "all"
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
@@ -64,7 +67,7 @@ class UserRepository(BaseRepository[User]):
|
||||
logger.exception("Failed to get all users with plan")
|
||||
raise
|
||||
|
||||
async def get_all_with_plan_paginated(
|
||||
async def get_all_with_plan_paginated( # noqa: PLR0913
|
||||
self,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
@@ -77,21 +80,20 @@ class UserRepository(BaseRepository[User]):
|
||||
try:
|
||||
# Calculate offset
|
||||
offset = (page - 1) * limit
|
||||
|
||||
|
||||
# Build base query
|
||||
base_query = select(User).options(selectinload(User.plan))
|
||||
count_query = select(func.count(User.id))
|
||||
|
||||
|
||||
# Apply search filter
|
||||
if search and search.strip():
|
||||
search_pattern = f"%{search.strip().lower()}%"
|
||||
search_condition = (
|
||||
func.lower(User.name).like(search_pattern) |
|
||||
func.lower(User.email).like(search_pattern)
|
||||
)
|
||||
search_condition = func.lower(User.name).like(
|
||||
search_pattern,
|
||||
) | func.lower(User.email).like(search_pattern)
|
||||
base_query = base_query.where(search_condition)
|
||||
count_query = count_query.where(search_condition)
|
||||
|
||||
|
||||
# Apply status filter
|
||||
if status_filter == UserStatus.ACTIVE:
|
||||
base_query = base_query.where(User.is_active == True) # noqa: E712
|
||||
@@ -99,47 +101,34 @@ class UserRepository(BaseRepository[User]):
|
||||
elif status_filter == UserStatus.INACTIVE:
|
||||
base_query = base_query.where(User.is_active == False) # noqa: E712
|
||||
count_query = count_query.where(User.is_active == False) # noqa: E712
|
||||
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == UserSortField.EMAIL:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.email.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.email.asc())
|
||||
elif sort_by == UserSortField.ROLE:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.role.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.role.asc())
|
||||
elif sort_by == UserSortField.CREDITS:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.credits.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.credits.asc())
|
||||
elif sort_by == UserSortField.CREATED_AT:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.created_at.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.created_at.asc())
|
||||
else: # Default to name
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.name.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.name.asc())
|
||||
|
||||
sort_column = {
|
||||
UserSortField.NAME: User.name,
|
||||
UserSortField.EMAIL: User.email,
|
||||
UserSortField.ROLE: User.role,
|
||||
UserSortField.CREDITS: User.credits,
|
||||
UserSortField.CREATED_AT: User.created_at,
|
||||
}.get(sort_by, User.name)
|
||||
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(sort_column.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(sort_column.asc())
|
||||
|
||||
# Get total count
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
|
||||
# Apply pagination and get results
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
users = list(result.all())
|
||||
|
||||
return users, total_count
|
||||
except Exception:
|
||||
logger.exception("Failed to get paginated users with plan")
|
||||
raise
|
||||
else:
|
||||
return users, total_count
|
||||
|
||||
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
|
||||
"""Get a user by ID with plan relationship loaded."""
|
||||
@@ -178,7 +167,7 @@ class UserRepository(BaseRepository[User]):
|
||||
logger.exception("Failed to get user by API token")
|
||||
raise
|
||||
|
||||
async def create(self, user_data: dict[str, Any]) -> User:
|
||||
async def create(self, entity_data: dict[str, Any]) -> User:
|
||||
"""Create a new user with plan assignment and first user admin logic."""
|
||||
|
||||
def _raise_plan_not_found() -> None:
|
||||
@@ -194,7 +183,7 @@ class UserRepository(BaseRepository[User]):
|
||||
if is_first_user:
|
||||
# First user gets admin role and pro plan
|
||||
plan_statement = select(Plan).where(Plan.code == "pro")
|
||||
user_data["role"] = "admin"
|
||||
entity_data["role"] = "admin"
|
||||
logger.info("Creating first user with admin role and pro plan")
|
||||
else:
|
||||
# Regular users get free plan
|
||||
@@ -210,11 +199,11 @@ class UserRepository(BaseRepository[User]):
|
||||
assert default_plan is not None # noqa: S101
|
||||
|
||||
# Set plan_id and default credits
|
||||
user_data["plan_id"] = default_plan.id
|
||||
user_data["credits"] = default_plan.credits
|
||||
entity_data["plan_id"] = default_plan.id
|
||||
entity_data["credits"] = default_plan.credits
|
||||
|
||||
# Use BaseRepository's create method
|
||||
return await super().create(user_data)
|
||||
return await super().create(entity_data)
|
||||
except Exception:
|
||||
logger.exception("Failed to create user")
|
||||
raise
|
||||
|
||||
@@ -85,7 +85,8 @@ class ChangePasswordRequest(BaseModel):
|
||||
"""Schema for password change request."""
|
||||
|
||||
current_password: str | None = Field(
|
||||
None, description="Current password (required if user has existing password)",
|
||||
None,
|
||||
description="Current password (required if user has existing password)",
|
||||
)
|
||||
new_password: str = Field(
|
||||
...,
|
||||
@@ -98,5 +99,8 @@ class UpdateProfileRequest(BaseModel):
|
||||
"""Schema for profile update request."""
|
||||
|
||||
name: str | None = Field(
|
||||
None, min_length=1, max_length=100, description="User display name",
|
||||
None,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="User display name",
|
||||
)
|
||||
|
||||
@@ -11,10 +11,12 @@ class FavoriteResponse(BaseModel):
|
||||
id: int = Field(description="Favorite ID")
|
||||
user_id: int = Field(description="User ID")
|
||||
sound_id: int | None = Field(
|
||||
description="Sound ID if this is a sound favorite", default=None,
|
||||
description="Sound ID if this is a sound favorite",
|
||||
default=None,
|
||||
)
|
||||
playlist_id: int | None = Field(
|
||||
description="Playlist ID if this is a playlist favorite", default=None,
|
||||
description="Playlist ID if this is a playlist favorite",
|
||||
default=None,
|
||||
)
|
||||
created_at: datetime = Field(description="Creation timestamp")
|
||||
updated_at: datetime = Field(description="Last update timestamp")
|
||||
|
||||
@@ -39,14 +39,19 @@ class PlaylistResponse(BaseModel):
|
||||
updated_at: str | None
|
||||
|
||||
@classmethod
|
||||
def from_playlist(cls, playlist: Playlist, is_favorited: bool = False, favorite_count: int = 0) -> "PlaylistResponse":
|
||||
def from_playlist(
|
||||
cls,
|
||||
playlist: Playlist,
|
||||
is_favorited: bool = False, # noqa: FBT001, FBT002
|
||||
favorite_count: int = 0,
|
||||
) -> "PlaylistResponse":
|
||||
"""Create response from playlist model.
|
||||
|
||||
|
||||
Args:
|
||||
playlist: The Playlist model
|
||||
is_favorited: Whether the playlist is favorited by the current user
|
||||
favorite_count: Number of users who favorited this playlist
|
||||
|
||||
|
||||
Returns:
|
||||
PlaylistResponse instance
|
||||
|
||||
|
||||
@@ -18,16 +18,20 @@ class SoundResponse(BaseModel):
|
||||
size: int = Field(description="File size in bytes")
|
||||
hash: str = Field(description="File hash")
|
||||
normalized_filename: str | None = Field(
|
||||
description="Normalized filename", default=None,
|
||||
description="Normalized filename",
|
||||
default=None,
|
||||
)
|
||||
normalized_duration: int | None = Field(
|
||||
description="Normalized duration in milliseconds", default=None,
|
||||
description="Normalized duration in milliseconds",
|
||||
default=None,
|
||||
)
|
||||
normalized_size: int | None = Field(
|
||||
description="Normalized file size in bytes", default=None,
|
||||
description="Normalized file size in bytes",
|
||||
default=None,
|
||||
)
|
||||
normalized_hash: str | None = Field(
|
||||
description="Normalized file hash", default=None,
|
||||
description="Normalized file hash",
|
||||
default=None,
|
||||
)
|
||||
thumbnail: str | None = Field(description="Thumbnail filename", default=None)
|
||||
play_count: int = Field(description="Number of times played")
|
||||
@@ -35,10 +39,12 @@ class SoundResponse(BaseModel):
|
||||
is_music: bool = Field(description="Whether the sound is music")
|
||||
is_deletable: bool = Field(description="Whether the sound can be deleted")
|
||||
is_favorited: bool = Field(
|
||||
description="Whether the sound is favorited by the current user", default=False,
|
||||
description="Whether the sound is favorited by the current user",
|
||||
default=False,
|
||||
)
|
||||
favorite_count: int = Field(
|
||||
description="Number of users who favorited this sound", default=0,
|
||||
description="Number of users who favorited this sound",
|
||||
default=0,
|
||||
)
|
||||
created_at: datetime = Field(description="Creation timestamp")
|
||||
updated_at: datetime = Field(description="Last update timestamp")
|
||||
@@ -50,7 +56,10 @@ class SoundResponse(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def from_sound(
|
||||
cls, sound: Sound, is_favorited: bool = False, favorite_count: int = 0,
|
||||
cls,
|
||||
sound: Sound,
|
||||
is_favorited: bool = False, # noqa: FBT001, FBT002
|
||||
favorite_count: int = 0,
|
||||
) -> "SoundResponse":
|
||||
"""Create a SoundResponse from a Sound model.
|
||||
|
||||
@@ -64,7 +73,8 @@ class SoundResponse(BaseModel):
|
||||
|
||||
"""
|
||||
if sound.id is None:
|
||||
raise ValueError("Sound ID cannot be None")
|
||||
msg = "Sound ID cannot be None"
|
||||
raise ValueError(msg)
|
||||
|
||||
return cls(
|
||||
id=sound.id,
|
||||
|
||||
@@ -7,7 +7,10 @@ class UserUpdate(BaseModel):
|
||||
"""Schema for updating a user."""
|
||||
|
||||
name: str | None = Field(
|
||||
None, min_length=1, max_length=100, description="User full name",
|
||||
None,
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
description="User full name",
|
||||
)
|
||||
plan_id: int | None = Field(None, description="User plan ID")
|
||||
credits: int | None = Field(None, ge=0, description="User credits")
|
||||
|
||||
@@ -454,7 +454,10 @@ class AuthService:
|
||||
return user
|
||||
|
||||
async def change_user_password(
|
||||
self, user: User, current_password: str | None, new_password: str,
|
||||
self,
|
||||
user: User,
|
||||
current_password: str | None,
|
||||
new_password: str,
|
||||
) -> None:
|
||||
"""Change user's password."""
|
||||
# Store user email before any operations to avoid session detachment issues
|
||||
@@ -484,8 +487,11 @@ class AuthService:
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Password %s successfully for user: %s",
|
||||
"changed" if had_existing_password else "set", user_email)
|
||||
logger.info(
|
||||
"Password %s successfully for user: %s",
|
||||
"changed" if had_existing_password else "set",
|
||||
user_email,
|
||||
)
|
||||
|
||||
async def user_to_response(self, user: User) -> UserResponse:
|
||||
"""Convert User model to UserResponse with plan information."""
|
||||
|
||||
@@ -72,9 +72,7 @@ class DashboardService:
|
||||
"play_count": sound["play_count"],
|
||||
"duration": sound["duration"],
|
||||
"created_at": (
|
||||
sound["created_at"].isoformat()
|
||||
if sound["created_at"]
|
||||
else None
|
||||
sound["created_at"].isoformat() if sound["created_at"] else None
|
||||
),
|
||||
}
|
||||
for sound in top_sounds
|
||||
|
||||
@@ -532,7 +532,8 @@ class ExtractionService:
|
||||
"""Add the sound to the user's main playlist."""
|
||||
try:
|
||||
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
|
||||
sound_id, user_id,
|
||||
sound_id,
|
||||
user_id,
|
||||
)
|
||||
logger.info(
|
||||
"Added sound %d to main playlist for user %d",
|
||||
@@ -554,6 +555,10 @@ class ExtractionService:
|
||||
if not extraction:
|
||||
return None
|
||||
|
||||
# Get user information
|
||||
user = await self.user_repo.get_by_id(extraction.user_id)
|
||||
user_name = user.name if user else None
|
||||
|
||||
return {
|
||||
"id": extraction.id or 0, # Should never be None for existing extraction
|
||||
"url": extraction.url,
|
||||
@@ -564,11 +569,12 @@ class ExtractionService:
|
||||
"error": extraction.error,
|
||||
"sound_id": extraction.sound_id,
|
||||
"user_id": extraction.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": extraction.created_at.isoformat(),
|
||||
"updated_at": extraction.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
async def get_user_extractions(
|
||||
async def get_user_extractions( # noqa: PLR0913
|
||||
self,
|
||||
user_id: int,
|
||||
search: str | None = None,
|
||||
@@ -580,7 +586,10 @@ class ExtractionService:
|
||||
) -> PaginatedExtractionsResponse:
|
||||
"""Get all extractions for a user with filtering, search, and sorting."""
|
||||
offset = (page - 1) * limit
|
||||
extraction_user_tuples, total_count = await self.extraction_repo.get_user_extractions_filtered(
|
||||
(
|
||||
extraction_user_tuples,
|
||||
total_count,
|
||||
) = await self.extraction_repo.get_user_extractions_filtered(
|
||||
user_id=user_id,
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
@@ -619,7 +628,7 @@ class ExtractionService:
|
||||
"total_pages": total_pages,
|
||||
}
|
||||
|
||||
async def get_all_extractions(
|
||||
async def get_all_extractions( # noqa: PLR0913
|
||||
self,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
@@ -630,7 +639,10 @@ class ExtractionService:
|
||||
) -> PaginatedExtractionsResponse:
|
||||
"""Get all extractions with filtering, search, and sorting."""
|
||||
offset = (page - 1) * limit
|
||||
extraction_user_tuples, total_count = await self.extraction_repo.get_all_extractions_filtered(
|
||||
(
|
||||
extraction_user_tuples,
|
||||
total_count,
|
||||
) = await self.extraction_repo.get_all_extractions_filtered(
|
||||
search=search,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
|
||||
@@ -49,12 +49,14 @@ class FavoriteService:
|
||||
# Verify user exists
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
msg = f"User with ID {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Verify sound exists
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if not sound:
|
||||
raise ValueError(f"Sound with ID {sound_id} not found")
|
||||
msg = f"Sound with ID {sound_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get data for the event immediately after loading
|
||||
sound_name = sound.name
|
||||
@@ -63,9 +65,8 @@ class FavoriteService:
|
||||
# Check if already favorited
|
||||
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
|
||||
if existing:
|
||||
raise ValueError(
|
||||
f"Sound {sound_id} is already favorited by user {user_id}",
|
||||
)
|
||||
msg = f"Sound {sound_id} is already favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create favorite
|
||||
favorite_data = {
|
||||
@@ -120,12 +121,14 @@ class FavoriteService:
|
||||
# Verify user exists
|
||||
user = await user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User with ID {user_id} not found")
|
||||
msg = f"User with ID {user_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Verify playlist exists
|
||||
playlist = await playlist_repo.get_by_id(playlist_id)
|
||||
if not playlist:
|
||||
raise ValueError(f"Playlist with ID {playlist_id} not found")
|
||||
msg = f"Playlist with ID {playlist_id} not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if already favorited
|
||||
existing = await favorite_repo.get_by_user_and_playlist(
|
||||
@@ -133,9 +136,8 @@ class FavoriteService:
|
||||
playlist_id,
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(
|
||||
f"Playlist {playlist_id} is already favorited by user {user_id}",
|
||||
)
|
||||
msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create favorite
|
||||
favorite_data = {
|
||||
@@ -163,7 +165,8 @@ class FavoriteService:
|
||||
|
||||
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
|
||||
if not favorite:
|
||||
raise ValueError(f"Sound {sound_id} is not favorited by user {user_id}")
|
||||
msg = f"Sound {sound_id} is not favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get user and sound info before deletion for the event
|
||||
user_repo = UserRepository(session)
|
||||
@@ -192,7 +195,8 @@ class FavoriteService:
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_favorited", event_data)
|
||||
logger.info(
|
||||
"Broadcasted sound_favorited event for sound %s removal", sound_id,
|
||||
"Broadcasted sound_favorited event for sound %s removal",
|
||||
sound_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
@@ -219,9 +223,8 @@ class FavoriteService:
|
||||
playlist_id,
|
||||
)
|
||||
if not favorite:
|
||||
raise ValueError(
|
||||
f"Playlist {playlist_id} is not favorited by user {user_id}",
|
||||
)
|
||||
msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
|
||||
raise ValueError(msg)
|
||||
|
||||
await favorite_repo.delete(favorite)
|
||||
logger.info(
|
||||
|
||||
@@ -16,6 +16,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class PaginatedPlaylistsResponse(TypedDict):
|
||||
"""Response type for paginated playlists."""
|
||||
|
||||
playlists: list[dict]
|
||||
total: int
|
||||
page: int
|
||||
@@ -286,7 +287,7 @@ class PlaylistService:
|
||||
) -> PaginatedPlaylistsResponse:
|
||||
"""Search and sort playlists with pagination."""
|
||||
offset = (page - 1) * limit
|
||||
|
||||
|
||||
playlists, total_count = await self.playlist_repo.search_and_sort(
|
||||
search_query=search_query,
|
||||
sort_by=sort_by,
|
||||
@@ -299,9 +300,9 @@ class PlaylistService:
|
||||
current_user_id=current_user_id,
|
||||
return_count=True,
|
||||
)
|
||||
|
||||
|
||||
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||
|
||||
|
||||
return PaginatedPlaylistsResponse(
|
||||
playlists=playlists,
|
||||
total=total_count,
|
||||
@@ -468,7 +469,9 @@ class PlaylistService:
|
||||
}
|
||||
|
||||
async def add_sound_to_main_playlist(
|
||||
self, sound_id: int, user_id: int, # noqa: ARG002
|
||||
self,
|
||||
sound_id: int, # noqa: ARG002
|
||||
user_id: int, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Add a sound to the global main playlist."""
|
||||
raise HTTPException(
|
||||
@@ -477,7 +480,9 @@ class PlaylistService:
|
||||
)
|
||||
|
||||
async def _add_sound_to_main_playlist_internal(
|
||||
self, sound_id: int, user_id: int,
|
||||
self,
|
||||
sound_id: int,
|
||||
user_id: int,
|
||||
) -> None:
|
||||
"""Add sound to main playlist bypassing restrictions.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user