Refactor user endpoint tests to include pagination and response structure validation

- Updated tests for listing users to validate pagination and response format.
- Changed mock return values to include total count and pagination details.
- Refactored user creation mocks for clarity and consistency.
- Enhanced assertions to check for presence of pagination fields in responses.
- Adjusted test cases for user retrieval and updates to ensure proper handling of user data.
- Improved readability by restructuring mock definitions and assertions across various test files.
This commit is contained in:
JSC
2025-08-17 12:36:52 +02:00
parent e6f796a3c9
commit 6b55ff0e81
35 changed files with 863 additions and 503 deletions

View File

@@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user
from app.models.plan import Plan
from app.models.user import User
from app.repositories.plan import PlanRepository
from app.repositories.user import UserRepository, UserSortField, SortOrder, UserStatus
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
from app.schemas.auth import UserResponse
from app.schemas.user import UserUpdate
@@ -36,21 +36,27 @@ def _user_to_response(user: User) -> UserResponse:
"name": user.plan.name,
"max_credits": user.plan.max_credits,
"features": [], # Add features if needed
} if user.plan else {},
}
if user.plan
else {},
created_at=user.created_at,
updated_at=user.updated_at,
)
@router.get("/")
async def list_users(
async def list_users( # noqa: PLR0913
session: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
search: Annotated[str | None, Query(description="Search in name or email")] = None,
sort_by: Annotated[UserSortField, Query(description="Sort by field")] = UserSortField.NAME,
sort_by: Annotated[
UserSortField, Query(description="Sort by field"),
] = UserSortField.NAME,
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
status_filter: Annotated[UserStatus, Query(description="Filter by status")] = UserStatus.ALL,
status_filter: Annotated[
UserStatus, Query(description="Filter by status"),
] = UserStatus.ALL,
) -> dict[str, Any]:
"""Get all users with pagination, search, and filters (admin only)."""
user_repo = UserRepository(session)
@@ -62,9 +68,9 @@ async def list_users(
sort_order=sort_order,
status_filter=status_filter,
)
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"users": [_user_to_response(user) for user in users],
"total": total_count,

View File

@@ -464,7 +464,8 @@ async def update_profile(
"""Update the current user's profile."""
try:
updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True),
current_user,
request.model_dump(exclude_unset=True),
)
return await auth_service.user_to_response(updated_user)
except Exception as e:
@@ -486,7 +487,9 @@ async def change_password(
user_email = current_user.email
try:
await auth_service.change_user_password(
current_user, request.current_password, request.new_password,
current_user,
request.current_password,
request.new_password,
)
except ValueError as e:
raise HTTPException(
@@ -513,11 +516,13 @@ async def get_user_providers(
# Add password provider if user has password
if current_user.password_hash:
providers.append({
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
})
providers.append(
{
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
},
)
# Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
@@ -528,10 +533,12 @@ async def get_user_providers(
elif oauth.provider == "google":
display_name = "Google"
providers.append({
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
})
providers.append(
{
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
},
)
return providers

View File

@@ -34,7 +34,8 @@ async def get_top_sounds(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
sound_type: Annotated[
str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
str,
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
],
period: Annotated[
str,
@@ -43,7 +44,8 @@ async def get_top_sounds(
),
] = "all_time",
limit: Annotated[
int, Query(description="Number of top sounds to return", ge=1, le=100),
int,
Query(description="Number of top sounds to return", ge=1, le=100),
] = 10,
) -> list[dict[str, Any]]:
"""Get top sounds by play count for a specific period."""

View File

@@ -60,68 +60,13 @@ async def create_extraction(
}
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions(
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result
@router.get("/user")
async def get_user_extractions(
async def get_user_extractions( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
@@ -153,3 +98,62 @@ async def get_user_extractions(
) from e
else:
return result
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions( # noqa: PLR0913
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result

View File

@@ -4,6 +4,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.database import get_session_factory
from app.core.dependencies import get_current_active_user
from app.models.user import User
from app.schemas.common import MessageResponse
@@ -19,12 +20,10 @@ router = APIRouter(prefix="/favorites", tags=["favorites"])
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory())
@router.get("/", response_model=FavoritesListResponse)
@router.get("/")
async def get_user_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -33,12 +32,14 @@ async def get_user_favorites(
) -> FavoritesListResponse:
"""Get all favorites for the current user."""
favorites = await favorite_service.get_user_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/sounds", response_model=FavoritesListResponse)
@router.get("/sounds")
async def get_user_sound_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -47,12 +48,14 @@ async def get_user_sound_favorites(
) -> FavoritesListResponse:
"""Get sound favorites for the current user."""
favorites = await favorite_service.get_user_sound_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/playlists", response_model=FavoritesListResponse)
@router.get("/playlists")
async def get_user_playlist_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -61,12 +64,14 @@ async def get_user_playlist_favorites(
) -> FavoritesListResponse:
"""Get playlist favorites for the current user."""
favorites = await favorite_service.get_user_playlist_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/counts", response_model=FavoriteCountsResponse)
@router.get("/counts")
async def get_favorite_counts(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -76,7 +81,7 @@ async def get_favorite_counts(
return FavoriteCountsResponse(**counts)
@router.post("/sounds/{sound_id}", response_model=FavoriteResponse)
@router.post("/sounds/{sound_id}")
async def add_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -103,7 +108,7 @@ async def add_sound_favorite(
) from e
@router.post("/playlists/{playlist_id}", response_model=FavoriteResponse)
@router.post("/playlists/{playlist_id}")
async def add_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -112,7 +117,8 @@ async def add_playlist_favorite(
"""Add a playlist to favorites."""
try:
favorite = await favorite_service.add_playlist_favorite(
current_user.id, playlist_id,
current_user.id,
playlist_id,
)
return FavoriteResponse.model_validate(favorite)
except ValueError as e:
@@ -132,7 +138,7 @@ async def add_playlist_favorite(
) from e
@router.delete("/sounds/{sound_id}", response_model=MessageResponse)
@router.delete("/sounds/{sound_id}")
async def remove_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -149,7 +155,7 @@ async def remove_sound_favorite(
) from e
@router.delete("/playlists/{playlist_id}", response_model=MessageResponse)
@router.delete("/playlists/{playlist_id}")
async def remove_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -185,6 +191,7 @@ async def check_playlist_favorited(
) -> dict[str, bool]:
"""Check if a playlist is favorited by the current user."""
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
current_user.id,
playlist_id,
)
return {"is_favorited": is_favorited}

View File

@@ -5,7 +5,7 @@ from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
@@ -34,7 +34,6 @@ async def get_playlist_service(
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory())
@@ -57,7 +56,7 @@ async def get_all_playlists( # noqa: PLR0913
] = SortOrder.ASC,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
favorites_only: Annotated[
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited playlists"),
] = False,
@@ -78,15 +77,26 @@ async def get_all_playlists( # noqa: PLR0913
# Convert to PlaylistResponse with favorite indicators
playlist_responses = []
for playlist_dict in result["playlists"]:
# The playlist service returns dict, need to create playlist object-like structure
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist_dict["id"])
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_dict["id"])
# The playlist service returns dict, need to create playlist object structure
playlist_id = playlist_dict["id"]
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
# Create a PlaylistResponse-like dict with proper datetime conversion
playlist_response = {
**playlist_dict,
"created_at": playlist_dict["created_at"].isoformat() if playlist_dict["created_at"] else None,
"updated_at": playlist_dict["updated_at"].isoformat() if playlist_dict["updated_at"] else None,
"created_at": (
playlist_dict["created_at"].isoformat()
if playlist_dict["created_at"]
else None
),
"updated_at": (
playlist_dict["updated_at"].isoformat()
if playlist_dict["updated_at"]
else None
),
"is_favorited": is_favorited,
"favorite_count": favorite_count,
}
@@ -113,9 +123,13 @@ async def get_user_playlists(
# Add favorite indicators for each playlist
playlist_responses = []
for playlist in playlists:
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
playlist_response = PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
playlist_response = PlaylistResponse.from_playlist(
playlist, is_favorited, favorite_count,
)
playlist_responses.append(playlist_response)
return playlist_responses
@@ -129,7 +143,9 @@ async def get_main_playlist(
) -> PlaylistResponse:
"""Get the global main playlist."""
playlist = await playlist_service.get_main_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -142,7 +158,9 @@ async def get_current_playlist(
) -> PlaylistResponse:
"""Get the global current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -172,7 +190,9 @@ async def get_playlist(
) -> PlaylistResponse:
"""Get a specific playlist."""
playlist = await playlist_service.get_playlist_by_id(playlist_id)
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)

View File

@@ -40,7 +40,7 @@ async def get_sound_repository(
return SoundRepository(session)
@router.get("/", response_model=SoundsListResponse)
@router.get("/")
async def get_sounds( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
@@ -69,7 +69,7 @@ async def get_sounds( # noqa: PLR0913
int,
Query(description="Number of results to skip", ge=0),
] = 0,
favorites_only: Annotated[
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited sounds"),
] = False,
@@ -90,9 +90,13 @@ async def get_sounds( # noqa: PLR0913
# Add favorite indicators for each sound
sound_responses = []
for sound in sounds:
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound.id)
is_favorited = await favorite_service.is_sound_favorited(
current_user.id, sound.id,
)
favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
sound_response = SoundResponse.from_sound(sound, is_favorited, favorite_count)
sound_response = SoundResponse.from_sound(
sound, is_favorited, favorite_count,
)
sound_responses.append(sound_response)
except Exception as e:

View File

@@ -20,7 +20,7 @@ class Settings(BaseSettings):
# Production URLs (for reverse proxy deployment)
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
# CORS Configuration
CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS

View File

@@ -20,7 +20,9 @@ class BaseModel(SQLModel):
# SQLAlchemy event listener to automatically update updated_at timestamp
@event.listens_for(BaseModel, "before_update", propagate=True)
def update_timestamp(
mapper: Mapper[Any], connection: Connection, target: BaseModel, # noqa: ARG001
mapper: Mapper[Any], # noqa: ARG001
connection: Connection, # noqa: ARG001
target: BaseModel,
) -> None:
"""Automatically set updated_at timestamp before update operations."""
target.updated_at = datetime.now(UTC)

View File

@@ -35,5 +35,3 @@ class PlaylistSound(BaseModel, table=True):
# relationships
playlist: "Playlist" = Relationship(back_populates="playlist_sounds")
sound: "Sound" = Relationship(back_populates="playlist_sounds")

View File

@@ -58,7 +58,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
)
return list(result.all())
async def get_user_extractions_filtered(
async def get_user_extractions_filtered( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
@@ -92,7 +92,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery()
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()
@@ -106,10 +106,10 @@ class ExtractionRepository(BaseRepository[Extraction]):
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
return list(result.all()), total_count
async def get_all_extractions_filtered(
async def get_all_extractions_filtered( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
@@ -138,7 +138,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery()
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()
@@ -152,5 +152,5 @@ class ExtractionRepository(BaseRepository[Extraction]):
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
return list(result.all()), total_count

View File

@@ -118,7 +118,9 @@ class FavoriteRepository(BaseRepository[Favorite]):
raise
async def get_by_user_and_sound(
self, user_id: int, sound_id: int,
self,
user_id: int,
sound_id: int,
) -> Favorite | None:
"""Get a favorite by user and sound.
@@ -138,12 +140,16 @@ class FavoriteRepository(BaseRepository[Favorite]):
return result.first()
except Exception:
logger.exception(
"Failed to get favorite for user %s and sound %s", user_id, sound_id,
"Failed to get favorite for user %s and sound %s",
user_id,
sound_id,
)
raise
async def get_by_user_and_playlist(
self, user_id: int, playlist_id: int,
self,
user_id: int,
playlist_id: int,
) -> Favorite | None:
"""Get a favorite by user and playlist.

View File

@@ -57,7 +57,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
# management
except Exception:
logger.exception(
"Failed to update playlist timestamp for playlist: %s", playlist_id,
"Failed to update playlist timestamp for playlist: %s",
playlist_id,
)
raise
@@ -341,7 +342,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
include_stats: bool = False, # noqa: FBT001, FBT002
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False,
favorites_only: bool = False, # noqa: FBT001, FBT002
current_user_id: int | None = None,
*,
return_count: bool = False,
@@ -395,9 +396,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = select(1).select_from(Favorite).where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())
@@ -466,9 +471,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = select(1).select_from(Favorite).where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())

View File

@@ -141,7 +141,7 @@ class SoundRepository(BaseRepository[Sound]):
sort_order: SortOrder = SortOrder.ASC,
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False,
favorites_only: bool = False, # noqa: FBT001, FBT002
user_id: int | None = None,
) -> list[Sound]:
"""Search and sort sounds with optional filtering."""
@@ -189,7 +189,8 @@ class SoundRepository(BaseRepository[Sound]):
logger.exception(
(
"Failed to search and sort sounds: "
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, user_id=%s"
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
"user_id=%s"
),
search_query,
sound_types,
@@ -288,8 +289,7 @@ class SoundRepository(BaseRepository[Sound]):
# Group by sound and order by play count descending
statement = (
statement
.group_by(
statement.group_by(
Sound.id,
Sound.name,
Sound.type,

View File

@@ -1,7 +1,7 @@
"""User repository."""
from typing import Any
from enum import Enum
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import selectinload
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class UserSortField(str, Enum):
"""User sort fields."""
NAME = "name"
EMAIL = "email"
ROLE = "role"
@@ -27,12 +28,14 @@ class UserSortField(str, Enum):
class SortOrder(str, Enum):
"""Sort order."""
ASC = "asc"
DESC = "desc"
class UserStatus(str, Enum):
"""User status filter."""
ALL = "all"
ACTIVE = "active"
INACTIVE = "inactive"
@@ -64,7 +67,7 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get all users with plan")
raise
async def get_all_with_plan_paginated(
async def get_all_with_plan_paginated( # noqa: PLR0913
self,
page: int = 1,
limit: int = 50,
@@ -77,21 +80,20 @@ class UserRepository(BaseRepository[User]):
try:
# Calculate offset
offset = (page - 1) * limit
# Build base query
base_query = select(User).options(selectinload(User.plan))
count_query = select(func.count(User.id))
# Apply search filter
if search and search.strip():
search_pattern = f"%{search.strip().lower()}%"
search_condition = (
func.lower(User.name).like(search_pattern) |
func.lower(User.email).like(search_pattern)
)
search_condition = func.lower(User.name).like(
search_pattern,
) | func.lower(User.email).like(search_pattern)
base_query = base_query.where(search_condition)
count_query = count_query.where(search_condition)
# Apply status filter
if status_filter == UserStatus.ACTIVE:
base_query = base_query.where(User.is_active == True) # noqa: E712
@@ -99,47 +101,34 @@ class UserRepository(BaseRepository[User]):
elif status_filter == UserStatus.INACTIVE:
base_query = base_query.where(User.is_active == False) # noqa: E712
count_query = count_query.where(User.is_active == False) # noqa: E712
# Apply sorting
if sort_by == UserSortField.EMAIL:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.email.desc())
else:
base_query = base_query.order_by(User.email.asc())
elif sort_by == UserSortField.ROLE:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.role.desc())
else:
base_query = base_query.order_by(User.role.asc())
elif sort_by == UserSortField.CREDITS:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.credits.desc())
else:
base_query = base_query.order_by(User.credits.asc())
elif sort_by == UserSortField.CREATED_AT:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.created_at.desc())
else:
base_query = base_query.order_by(User.created_at.asc())
else: # Default to name
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.name.desc())
else:
base_query = base_query.order_by(User.name.asc())
sort_column = {
UserSortField.NAME: User.name,
UserSortField.EMAIL: User.email,
UserSortField.ROLE: User.role,
UserSortField.CREDITS: User.credits,
UserSortField.CREATED_AT: User.created_at,
}.get(sort_by, User.name)
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(sort_column.desc())
else:
base_query = base_query.order_by(sort_column.asc())
# Get total count
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply pagination and get results
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
users = list(result.all())
return users, total_count
except Exception:
logger.exception("Failed to get paginated users with plan")
raise
else:
return users, total_count
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
"""Get a user by ID with plan relationship loaded."""
@@ -178,7 +167,7 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get user by API token")
raise
async def create(self, user_data: dict[str, Any]) -> User:
async def create(self, entity_data: dict[str, Any]) -> User:
"""Create a new user with plan assignment and first user admin logic."""
def _raise_plan_not_found() -> None:
@@ -194,7 +183,7 @@ class UserRepository(BaseRepository[User]):
if is_first_user:
# First user gets admin role and pro plan
plan_statement = select(Plan).where(Plan.code == "pro")
user_data["role"] = "admin"
entity_data["role"] = "admin"
logger.info("Creating first user with admin role and pro plan")
else:
# Regular users get free plan
@@ -210,11 +199,11 @@ class UserRepository(BaseRepository[User]):
assert default_plan is not None # noqa: S101
# Set plan_id and default credits
user_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits
entity_data["plan_id"] = default_plan.id
entity_data["credits"] = default_plan.credits
# Use BaseRepository's create method
return await super().create(user_data)
return await super().create(entity_data)
except Exception:
logger.exception("Failed to create user")
raise

View File

@@ -85,7 +85,8 @@ class ChangePasswordRequest(BaseModel):
"""Schema for password change request."""
current_password: str | None = Field(
None, description="Current password (required if user has existing password)",
None,
description="Current password (required if user has existing password)",
)
new_password: str = Field(
...,
@@ -98,5 +99,8 @@ class UpdateProfileRequest(BaseModel):
"""Schema for profile update request."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User display name",
None,
min_length=1,
max_length=100,
description="User display name",
)

View File

@@ -11,10 +11,12 @@ class FavoriteResponse(BaseModel):
id: int = Field(description="Favorite ID")
user_id: int = Field(description="User ID")
sound_id: int | None = Field(
description="Sound ID if this is a sound favorite", default=None,
description="Sound ID if this is a sound favorite",
default=None,
)
playlist_id: int | None = Field(
description="Playlist ID if this is a playlist favorite", default=None,
description="Playlist ID if this is a playlist favorite",
default=None,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")

View File

@@ -39,14 +39,19 @@ class PlaylistResponse(BaseModel):
updated_at: str | None
@classmethod
def from_playlist(cls, playlist: Playlist, is_favorited: bool = False, favorite_count: int = 0) -> "PlaylistResponse":
def from_playlist(
cls,
playlist: Playlist,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "PlaylistResponse":
"""Create response from playlist model.
Args:
playlist: The Playlist model
is_favorited: Whether the playlist is favorited by the current user
favorite_count: Number of users who favorited this playlist
Returns:
PlaylistResponse instance

View File

@@ -18,16 +18,20 @@ class SoundResponse(BaseModel):
size: int = Field(description="File size in bytes")
hash: str = Field(description="File hash")
normalized_filename: str | None = Field(
description="Normalized filename", default=None,
description="Normalized filename",
default=None,
)
normalized_duration: int | None = Field(
description="Normalized duration in milliseconds", default=None,
description="Normalized duration in milliseconds",
default=None,
)
normalized_size: int | None = Field(
description="Normalized file size in bytes", default=None,
description="Normalized file size in bytes",
default=None,
)
normalized_hash: str | None = Field(
description="Normalized file hash", default=None,
description="Normalized file hash",
default=None,
)
thumbnail: str | None = Field(description="Thumbnail filename", default=None)
play_count: int = Field(description="Number of times played")
@@ -35,10 +39,12 @@ class SoundResponse(BaseModel):
is_music: bool = Field(description="Whether the sound is music")
is_deletable: bool = Field(description="Whether the sound can be deleted")
is_favorited: bool = Field(
description="Whether the sound is favorited by the current user", default=False,
description="Whether the sound is favorited by the current user",
default=False,
)
favorite_count: int = Field(
description="Number of users who favorited this sound", default=0,
description="Number of users who favorited this sound",
default=0,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")
@@ -50,7 +56,10 @@ class SoundResponse(BaseModel):
@classmethod
def from_sound(
cls, sound: Sound, is_favorited: bool = False, favorite_count: int = 0,
cls,
sound: Sound,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "SoundResponse":
"""Create a SoundResponse from a Sound model.
@@ -64,7 +73,8 @@ class SoundResponse(BaseModel):
"""
if sound.id is None:
raise ValueError("Sound ID cannot be None")
msg = "Sound ID cannot be None"
raise ValueError(msg)
return cls(
id=sound.id,

View File

@@ -7,7 +7,10 @@ class UserUpdate(BaseModel):
"""Schema for updating a user."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User full name",
None,
min_length=1,
max_length=100,
description="User full name",
)
plan_id: int | None = Field(None, description="User plan ID")
credits: int | None = Field(None, ge=0, description="User credits")

View File

@@ -454,7 +454,10 @@ class AuthService:
return user
async def change_user_password(
self, user: User, current_password: str | None, new_password: str,
self,
user: User,
current_password: str | None,
new_password: str,
) -> None:
"""Change user's password."""
# Store user email before any operations to avoid session detachment issues
@@ -484,8 +487,11 @@ class AuthService:
self.session.add(user)
await self.session.commit()
logger.info("Password %s successfully for user: %s",
"changed" if had_existing_password else "set", user_email)
logger.info(
"Password %s successfully for user: %s",
"changed" if had_existing_password else "set",
user_email,
)
async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information."""

View File

@@ -72,9 +72,7 @@ class DashboardService:
"play_count": sound["play_count"],
"duration": sound["duration"],
"created_at": (
sound["created_at"].isoformat()
if sound["created_at"]
else None
sound["created_at"].isoformat() if sound["created_at"] else None
),
}
for sound in top_sounds

View File

@@ -532,7 +532,8 @@ class ExtractionService:
"""Add the sound to the user's main playlist."""
try:
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
sound_id, user_id,
sound_id,
user_id,
)
logger.info(
"Added sound %d to main playlist for user %d",
@@ -554,6 +555,10 @@ class ExtractionService:
if not extraction:
return None
# Get user information
user = await self.user_repo.get_by_id(extraction.user_id)
user_name = user.name if user else None
return {
"id": extraction.id or 0, # Should never be None for existing extraction
"url": extraction.url,
@@ -564,11 +569,12 @@ class ExtractionService:
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
async def get_user_extractions(
async def get_user_extractions( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
@@ -580,7 +586,10 @@ class ExtractionService:
) -> PaginatedExtractionsResponse:
"""Get all extractions for a user with filtering, search, and sorting."""
offset = (page - 1) * limit
extraction_user_tuples, total_count = await self.extraction_repo.get_user_extractions_filtered(
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_user_extractions_filtered(
user_id=user_id,
search=search,
sort_by=sort_by,
@@ -619,7 +628,7 @@ class ExtractionService:
"total_pages": total_pages,
}
async def get_all_extractions(
async def get_all_extractions( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
@@ -630,7 +639,10 @@ class ExtractionService:
) -> PaginatedExtractionsResponse:
"""Get all extractions with filtering, search, and sorting."""
offset = (page - 1) * limit
extraction_user_tuples, total_count = await self.extraction_repo.get_all_extractions_filtered(
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_all_extractions_filtered(
search=search,
sort_by=sort_by,
sort_order=sort_order,

View File

@@ -49,12 +49,14 @@ class FavoriteService:
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
raise ValueError(f"User with ID {user_id} not found")
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify sound exists
sound = await sound_repo.get_by_id(sound_id)
if not sound:
raise ValueError(f"Sound with ID {sound_id} not found")
msg = f"Sound with ID {sound_id} not found"
raise ValueError(msg)
# Get data for the event immediately after loading
sound_name = sound.name
@@ -63,9 +65,8 @@ class FavoriteService:
# Check if already favorited
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if existing:
raise ValueError(
f"Sound {sound_id} is already favorited by user {user_id}",
)
msg = f"Sound {sound_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
@@ -120,12 +121,14 @@ class FavoriteService:
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
raise ValueError(f"User with ID {user_id} not found")
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify playlist exists
playlist = await playlist_repo.get_by_id(playlist_id)
if not playlist:
raise ValueError(f"Playlist with ID {playlist_id} not found")
msg = f"Playlist with ID {playlist_id} not found"
raise ValueError(msg)
# Check if already favorited
existing = await favorite_repo.get_by_user_and_playlist(
@@ -133,9 +136,8 @@ class FavoriteService:
playlist_id,
)
if existing:
raise ValueError(
f"Playlist {playlist_id} is already favorited by user {user_id}",
)
msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
@@ -163,7 +165,8 @@ class FavoriteService:
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if not favorite:
raise ValueError(f"Sound {sound_id} is not favorited by user {user_id}")
msg = f"Sound {sound_id} is not favorited by user {user_id}"
raise ValueError(msg)
# Get user and sound info before deletion for the event
user_repo = UserRepository(session)
@@ -192,7 +195,8 @@ class FavoriteService:
}
await socket_manager.broadcast_to_all("sound_favorited", event_data)
logger.info(
"Broadcasted sound_favorited event for sound %s removal", sound_id,
"Broadcasted sound_favorited event for sound %s removal",
sound_id,
)
except Exception:
logger.exception(
@@ -219,9 +223,8 @@ class FavoriteService:
playlist_id,
)
if not favorite:
raise ValueError(
f"Playlist {playlist_id} is not favorited by user {user_id}",
)
msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
raise ValueError(msg)
await favorite_repo.delete(favorite)
logger.info(

View File

@@ -16,6 +16,7 @@ logger = get_logger(__name__)
class PaginatedPlaylistsResponse(TypedDict):
"""Response type for paginated playlists."""
playlists: list[dict]
total: int
page: int
@@ -286,7 +287,7 @@ class PlaylistService:
) -> PaginatedPlaylistsResponse:
"""Search and sort playlists with pagination."""
offset = (page - 1) * limit
playlists, total_count = await self.playlist_repo.search_and_sort(
search_query=search_query,
sort_by=sort_by,
@@ -299,9 +300,9 @@ class PlaylistService:
current_user_id=current_user_id,
return_count=True,
)
total_pages = (total_count + limit - 1) // limit # Ceiling division
return PaginatedPlaylistsResponse(
playlists=playlists,
total=total_count,
@@ -468,7 +469,9 @@ class PlaylistService:
}
async def add_sound_to_main_playlist(
self, sound_id: int, user_id: int, # noqa: ARG002
self,
sound_id: int, # noqa: ARG002
user_id: int, # noqa: ARG002
) -> None:
"""Add a sound to the global main playlist."""
raise HTTPException(
@@ -477,7 +480,9 @@ class PlaylistService:
)
async def _add_sound_to_main_playlist_internal(
self, sound_id: int, user_id: int,
self,
sound_id: int,
user_id: int,
) -> None:
"""Add sound to main playlist bypassing restrictions.