Compare commits

...

12 Commits

Author SHA1 Message Date
JSC
a660cc1861 Merge branch 'favorite'
Some checks failed
Backend CI / lint (push) Successful in 9m21s
Backend CI / test (push) Failing after 3m59s
2025-08-17 13:25:59 +02:00
JSC
6b55ff0e81 Refactor user endpoint tests to include pagination and response structure validation
- Updated tests for listing users to validate pagination and response format.
- Changed mock return values to include total count and pagination details.
- Refactored user creation mocks for clarity and consistency.
- Enhanced assertions to check for presence of pagination fields in responses.
- Adjusted test cases for user retrieval and updates to ensure proper handling of user data.
- Improved readability by restructuring mock definitions and assertions across various test files.
2025-08-17 12:36:52 +02:00
JSC
e6f796a3c9 feat: Add pagination, search, and filter functionality to user retrieval endpoint 2025-08-17 11:44:15 +02:00
JSC
99c757a073 feat: Implement pagination for extractions and playlists with total count in responses 2025-08-17 11:21:55 +02:00
JSC
f598ec2c12 fix: Extract user name in session context for improved performance 2025-08-17 01:49:47 +02:00
JSC
66d22df7dd feat: Add filtering, searching, and sorting to extraction retrieval endpoints 2025-08-17 01:44:43 +02:00
JSC
3326e406f8 feat: Add filtering, searching, and sorting to user extractions retrieval 2025-08-17 01:27:41 +02:00
JSC
fe15e7a6af fix: Correct log message for sound favorited event broadcasting 2025-08-17 01:08:33 +02:00
JSC
f56cc8b4cc feat: Enhance sound favorite management; add WebSocket event broadcasting for favoriting and unfavoriting sounds 2025-08-16 22:19:24 +02:00
JSC
f906b6d643 feat: Enhance favorites functionality; add favorites filtering to playlists and sounds, and improve favorite indicators in responses 2025-08-16 21:41:50 +02:00
JSC
78508c84eb feat: Add favorites filtering to sound retrieval; include user-specific favorite sounds in the API response 2025-08-16 21:27:40 +02:00
JSC
a947fd830b feat: Implement favorites management API; add endpoints for adding, removing, and retrieving favorites for sounds and playlists
feat: Create Favorite model and repository for managing user favorites in the database
feat: Add FavoriteService to handle business logic for favorites management
feat: Enhance Playlist and Sound response schemas to include favorite indicators and counts
refactor: Update API routes to include favorites functionality in playlists and sounds
2025-08-16 21:16:02 +02:00
41 changed files with 2282 additions and 379 deletions

View File

@@ -7,6 +7,7 @@ from app.api.v1 import (
auth,
dashboard,
extractions,
favorites,
files,
main,
player,
@@ -22,6 +23,7 @@ api_router = APIRouter(prefix="/v1")
api_router.include_router(auth.router, tags=["authentication"])
api_router.include_router(dashboard.router, tags=["dashboard"])
api_router.include_router(extractions.router, tags=["extractions"])
api_router.include_router(favorites.router, tags=["favorites"])
api_router.include_router(files.router, tags=["files"])
api_router.include_router(main.router, tags=["main"])
api_router.include_router(player.router, tags=["player"])

View File

@@ -1,8 +1,8 @@
"""Admin users endpoints."""
from typing import Annotated
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
@@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user
from app.models.plan import Plan
from app.models.user import User
from app.repositories.plan import PlanRepository
from app.repositories.user import UserRepository
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
from app.schemas.auth import UserResponse
from app.schemas.user import UserUpdate
@@ -36,22 +36,48 @@ def _user_to_response(user: User) -> UserResponse:
"name": user.plan.name,
"max_credits": user.plan.max_credits,
"features": [], # Add features if needed
} if user.plan else {},
}
if user.plan
else {},
created_at=user.created_at,
updated_at=user.updated_at,
)
@router.get("/")
async def list_users(
async def list_users( # noqa: PLR0913
session: Annotated[AsyncSession, Depends(get_db)],
limit: int = 100,
offset: int = 0,
) -> list[UserResponse]:
"""Get all users (admin only)."""
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
search: Annotated[str | None, Query(description="Search in name or email")] = None,
sort_by: Annotated[
UserSortField, Query(description="Sort by field"),
] = UserSortField.NAME,
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
status_filter: Annotated[
UserStatus, Query(description="Filter by status"),
] = UserStatus.ALL,
) -> dict[str, Any]:
"""Get all users with pagination, search, and filters (admin only)."""
user_repo = UserRepository(session)
users = await user_repo.get_all_with_plan(limit=limit, offset=offset)
return [_user_to_response(user) for user in users]
users, total_count = await user_repo.get_all_with_plan_paginated(
page=page,
limit=limit,
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
)
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"users": [_user_to_response(user) for user in users],
"total": total_count,
"page": page,
"limit": limit,
"total_pages": total_pages,
}
@router.get("/{user_id}")

View File

@@ -464,7 +464,8 @@ async def update_profile(
"""Update the current user's profile."""
try:
updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True),
current_user,
request.model_dump(exclude_unset=True),
)
return await auth_service.user_to_response(updated_user)
except Exception as e:
@@ -486,7 +487,9 @@ async def change_password(
user_email = current_user.email
try:
await auth_service.change_user_password(
current_user, request.current_password, request.new_password,
current_user,
request.current_password,
request.new_password,
)
except ValueError as e:
raise HTTPException(
@@ -513,11 +516,13 @@ async def get_user_providers(
# Add password provider if user has password
if current_user.password_hash:
providers.append({
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
})
providers.append(
{
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
},
)
# Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
@@ -528,10 +533,12 @@ async def get_user_providers(
elif oauth.provider == "google":
display_name = "Google"
providers.append({
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
})
providers.append(
{
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
},
)
return providers

View File

@@ -34,7 +34,8 @@ async def get_top_sounds(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
sound_type: Annotated[
str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
str,
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
],
period: Annotated[
str,
@@ -43,7 +44,8 @@ async def get_top_sounds(
),
] = "all_time",
limit: Annotated[
int, Query(description="Number of top sounds to return", ge=1, le=100),
int,
Query(description="Number of top sounds to return", ge=1, le=100),
] = 10,
) -> list[dict[str, Any]]:
"""Get top sounds by play count for a specific period."""

View File

@@ -2,7 +2,7 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
@@ -60,6 +60,46 @@ async def create_extraction(
}
@router.get("/user")
async def get_user_extractions( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions for the current user."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
result = await extraction_service.get_user_extractions(
user_id=current_user.id,
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
@@ -88,19 +128,27 @@ async def get_extraction(
@router.get("/")
async def get_user_extractions(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
async def get_all_extractions( # noqa: PLR0913
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, list[ExtractionInfo]]:
"""Get all extractions for the current user."""
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
extractions = await extraction_service.get_user_extractions(current_user.id)
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
@@ -108,6 +156,4 @@ async def get_user_extractions(
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return {
"extractions": extractions,
}
return result

197
app/api/v1/favorites.py Normal file
View File

@@ -0,0 +1,197 @@
"""Favorites management API endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.database import get_session_factory
from app.core.dependencies import get_current_active_user
from app.models.user import User
from app.schemas.common import MessageResponse
from app.schemas.favorite import (
FavoriteCountsResponse,
FavoriteResponse,
FavoritesListResponse,
)
from app.services.favorite import FavoriteService
router = APIRouter(prefix="/favorites", tags=["favorites"])
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
return FavoriteService(get_session_factory())
@router.get("/")
async def get_user_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0,
) -> FavoritesListResponse:
"""Get all favorites for the current user."""
favorites = await favorite_service.get_user_favorites(
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/sounds")
async def get_user_sound_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0,
) -> FavoritesListResponse:
"""Get sound favorites for the current user."""
favorites = await favorite_service.get_user_sound_favorites(
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/playlists")
async def get_user_playlist_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0,
) -> FavoritesListResponse:
"""Get playlist favorites for the current user."""
favorites = await favorite_service.get_user_playlist_favorites(
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/counts")
async def get_favorite_counts(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> FavoriteCountsResponse:
"""Get favorite counts for the current user."""
counts = await favorite_service.get_favorite_counts(current_user.id)
return FavoriteCountsResponse(**counts)
@router.post("/sounds/{sound_id}")
async def add_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> FavoriteResponse:
"""Add a sound to favorites."""
try:
favorite = await favorite_service.add_sound_favorite(current_user.id, sound_id)
return FavoriteResponse.model_validate(favorite)
except ValueError as e:
if "not found" in str(e):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
if "already favorited" in str(e):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e),
) from e
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
@router.post("/playlists/{playlist_id}")
async def add_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> FavoriteResponse:
"""Add a playlist to favorites."""
try:
favorite = await favorite_service.add_playlist_favorite(
current_user.id,
playlist_id,
)
return FavoriteResponse.model_validate(favorite)
except ValueError as e:
if "not found" in str(e):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
if "already favorited" in str(e):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=str(e),
) from e
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
) from e
@router.delete("/sounds/{sound_id}")
async def remove_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> MessageResponse:
"""Remove a sound from favorites."""
try:
await favorite_service.remove_sound_favorite(current_user.id, sound_id)
return MessageResponse(message="Sound removed from favorites")
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
@router.delete("/playlists/{playlist_id}")
async def remove_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> MessageResponse:
"""Remove a playlist from favorites."""
try:
await favorite_service.remove_playlist_favorite(current_user.id, playlist_id)
return MessageResponse(message="Playlist removed from favorites")
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
) from e
@router.get("/sounds/{sound_id}/check")
async def check_sound_favorited(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> dict[str, bool]:
"""Check if a sound is favorited by the current user."""
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound_id)
return {"is_favorited": is_favorited}
@router.get("/playlists/{playlist_id}/check")
async def check_playlist_favorited(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> dict[str, bool]:
"""Check if a playlist is favorited by the current user."""
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id,
playlist_id,
)
return {"is_favorited": is_favorited}

View File

@@ -1,11 +1,11 @@
"""Playlist management API endpoints."""
from typing import Annotated
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
@@ -19,6 +19,7 @@ from app.schemas.playlist import (
PlaylistStatsResponse,
PlaylistUpdateRequest,
)
from app.services.favorite import FavoriteService
from app.services.playlist import PlaylistService
router = APIRouter(prefix="/playlists", tags=["playlists"])
@@ -31,10 +32,16 @@ async def get_playlist_service(
return PlaylistService(session)
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
return FavoriteService(get_session_factory())
@router.get("/")
async def get_all_playlists( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
search: Annotated[
str | None,
Query(description="Search playlists by name"),
@@ -47,55 +54,115 @@ async def get_all_playlists( # noqa: PLR0913
SortOrder,
Query(description="Sort order (asc or desc)"),
] = SortOrder.ASC,
limit: Annotated[
int | None,
Query(description="Maximum number of results", ge=1, le=1000),
] = None,
offset: Annotated[
int,
Query(description="Number of results to skip", ge=0),
] = 0,
) -> list[dict]:
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited playlists"),
] = False,
) -> dict[str, Any]:
"""Get all playlists from all users with search and sorting."""
return await playlist_service.search_and_sort_playlists(
result = await playlist_service.search_and_sort_playlists_paginated(
search_query=search,
sort_by=sort_by,
sort_order=sort_order,
user_id=None,
include_stats=True,
page=page,
limit=limit,
offset=offset,
favorites_only=favorites_only,
current_user_id=current_user.id,
)
# Convert to PlaylistResponse with favorite indicators
playlist_responses = []
for playlist_dict in result["playlists"]:
# The playlist service returns dict, need to create playlist object structure
playlist_id = playlist_dict["id"]
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
# Create a PlaylistResponse-like dict with proper datetime conversion
playlist_response = {
**playlist_dict,
"created_at": (
playlist_dict["created_at"].isoformat()
if playlist_dict["created_at"]
else None
),
"updated_at": (
playlist_dict["updated_at"].isoformat()
if playlist_dict["updated_at"]
else None
),
"is_favorited": is_favorited,
"favorite_count": favorite_count,
}
playlist_responses.append(playlist_response)
return {
"playlists": playlist_responses,
"total": result["total"],
"page": result["page"],
"limit": result["limit"],
"total_pages": result["total_pages"],
}
@router.get("/user")
async def get_user_playlists(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> list[PlaylistResponse]:
"""Get playlists for the current user only."""
playlists = await playlist_service.get_user_playlists(current_user.id)
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
# Add favorite indicators for each playlist
playlist_responses = []
for playlist in playlists:
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
playlist_response = PlaylistResponse.from_playlist(
playlist, is_favorited, favorite_count,
)
playlist_responses.append(playlist_response)
return playlist_responses
@router.get("/main")
async def get_main_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> PlaylistResponse:
"""Get the global main playlist."""
playlist = await playlist_service.get_main_playlist()
return PlaylistResponse.from_playlist(playlist)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@router.get("/current")
async def get_current_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> PlaylistResponse:
"""Get the global current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist()
return PlaylistResponse.from_playlist(playlist)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@router.post("/")
@@ -117,12 +184,17 @@ async def create_playlist(
@router.get("/{playlist_id}")
async def get_playlist(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
) -> PlaylistResponse:
"""Get a specific playlist."""
playlist = await playlist_service.get_playlist_by_id(playlist_id)
return PlaylistResponse.from_playlist(playlist)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@router.put("/{playlist_id}")

View File

@@ -8,10 +8,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.sound import Sound
from app.models.user import User
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
from app.schemas.sound import SoundResponse, SoundsListResponse
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.favorite import FavoriteService
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
router = APIRouter(prefix="/sounds", tags=["sounds"])
@@ -27,6 +28,11 @@ def get_credit_service() -> CreditService:
return CreditService(get_session_factory())
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
return FavoriteService(get_session_factory())
async def get_sound_repository(
session: Annotated[AsyncSession, Depends(get_db)],
) -> SoundRepository:
@@ -36,8 +42,9 @@ async def get_sound_repository(
@router.get("/")
async def get_sounds( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
types: Annotated[
list[str] | None,
Query(description="Filter by sound types (e.g., SDB, TTS, EXT)"),
@@ -62,7 +69,11 @@ async def get_sounds( # noqa: PLR0913
int,
Query(description="Number of results to skip", ge=0),
] = 0,
) -> dict[str, list[Sound]]:
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited sounds"),
] = False,
) -> SoundsListResponse:
"""Get sounds with optional search, filtering, and sorting."""
try:
sounds = await sound_repo.search_and_sort(
@@ -72,14 +83,29 @@ async def get_sounds( # noqa: PLR0913
sort_order=sort_order,
limit=limit,
offset=offset,
favorites_only=favorites_only,
user_id=current_user.id,
)
# Add favorite indicators for each sound
sound_responses = []
for sound in sounds:
is_favorited = await favorite_service.is_sound_favorited(
current_user.id, sound.id,
)
favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
sound_response = SoundResponse.from_sound(
sound, is_favorited, favorite_count,
)
sound_responses.append(sound_response)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get sounds: {e!s}",
) from e
else:
return {"sounds": sounds}
return SoundsListResponse(sounds=sound_responses)
# VLC PLAYER

View File

@@ -20,7 +20,7 @@ class Settings(BaseSettings):
# Production URLs (for reverse proxy deployment)
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
# CORS Configuration
CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS

View File

@@ -9,6 +9,7 @@ from app.core.logging import get_logger
from app.core.seeds import seed_all_data
from app.models import ( # noqa: F401
extraction,
favorite,
plan,
playlist,
playlist_sound,

View File

@@ -20,7 +20,9 @@ class BaseModel(SQLModel):
# SQLAlchemy event listener to automatically update updated_at timestamp
@event.listens_for(BaseModel, "before_update", propagate=True)
def update_timestamp(
mapper: Mapper[Any], connection: Connection, target: BaseModel, # noqa: ARG001
mapper: Mapper[Any], # noqa: ARG001
connection: Connection, # noqa: ARG001
target: BaseModel,
) -> None:
"""Automatically set updated_at timestamp before update operations."""
target.updated_at = datetime.now(UTC)

29
app/models/favorite.py Normal file
View File

@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.playlist import Playlist
from app.models.sound import Sound
from app.models.user import User
class Favorite(BaseModel, table=True):
"""Database model for user favorites (sounds and playlists)."""
user_id: int = Field(foreign_key="user.id", nullable=False)
sound_id: int | None = Field(foreign_key="sound.id", default=None)
playlist_id: int | None = Field(foreign_key="playlist.id", default=None)
# constraints
__table_args__ = (
UniqueConstraint("user_id", "sound_id", name="uq_favorite_user_sound"),
UniqueConstraint("user_id", "playlist_id", name="uq_favorite_user_playlist"),
)
# relationships
user: "User" = Relationship(back_populates="favorites")
sound: "Sound" = Relationship(back_populates="favorites")
playlist: "Playlist" = Relationship(back_populates="favorites")

View File

@@ -5,6 +5,7 @@ from sqlmodel import Field, Relationship
from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.favorite import Favorite
from app.models.playlist_sound import PlaylistSound
from app.models.user import User
@@ -23,3 +24,4 @@ class Playlist(BaseModel, table=True):
# relationships
user: "User" = Relationship(back_populates="playlists")
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="playlist")
favorites: list["Favorite"] = Relationship(back_populates="playlist")

View File

@@ -35,5 +35,3 @@ class PlaylistSound(BaseModel, table=True):
# relationships
playlist: "Playlist" = Relationship(back_populates="playlist_sounds")
sound: "Sound" = Relationship(back_populates="playlist_sounds")

View File

@@ -6,6 +6,7 @@ from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.extraction import Extraction
from app.models.favorite import Favorite
from app.models.playlist_sound import PlaylistSound
from app.models.sound_played import SoundPlayed
@@ -36,3 +37,4 @@ class Sound(BaseModel, table=True):
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")
extractions: list["Extraction"] = Relationship(back_populates="sound")
play_history: list["SoundPlayed"] = Relationship(back_populates="sound")
favorites: list["Favorite"] = Relationship(back_populates="sound")

View File

@@ -8,6 +8,7 @@ from app.models.base import BaseModel
if TYPE_CHECKING:
from app.models.credit_transaction import CreditTransaction
from app.models.extraction import Extraction
from app.models.favorite import Favorite
from app.models.plan import Plan
from app.models.playlist import Playlist
from app.models.sound_played import SoundPlayed
@@ -37,3 +38,4 @@ class User(BaseModel, table=True):
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
extractions: list["Extraction"] = Relationship(back_populates="user")
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
favorites: list["Favorite"] = Relationship(back_populates="user")

View File

@@ -1,10 +1,11 @@
"""Extraction repository for database operations."""
from sqlalchemy import desc
from sqlalchemy import asc, desc, func, or_
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction
from app.models.user import User
from app.repositories.base import BaseRepository
@@ -38,10 +39,11 @@ class ExtractionRepository(BaseRepository[Extraction]):
)
return list(result.all())
async def get_pending_extractions(self) -> list[Extraction]:
async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
"""Get all pending extractions."""
result = await self.session.exec(
select(Extraction)
select(Extraction, User)
.join(User, Extraction.user_id == User.id)
.where(Extraction.status == "pending")
.order_by(Extraction.created_at),
)
@@ -55,3 +57,100 @@ class ExtractionRepository(BaseRepository[Extraction]):
.order_by(desc(Extraction.created_at)),
)
return list(result.all())
async def get_user_extractions_filtered( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
limit: int = 50,
offset: int = 0,
) -> tuple[list[tuple[Extraction, User]], int]:
"""Get extractions for a user with filtering, search, and sorting."""
base_query = (
select(Extraction, User)
.join(User, Extraction.user_id == User.id)
.where(Extraction.user_id == user_id)
)
# Apply search filter
if search:
search_pattern = f"%{search}%"
base_query = base_query.where(
or_(
Extraction.title.ilike(search_pattern),
Extraction.url.ilike(search_pattern),
Extraction.service.ilike(search_pattern),
),
)
# Apply status filter
if status_filter:
base_query = base_query.where(Extraction.status == status_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply sorting and pagination
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
if sort_order.lower() == "asc":
base_query = base_query.order_by(asc(sort_column))
else:
base_query = base_query.order_by(desc(sort_column))
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
return list(result.all()), total_count
async def get_all_extractions_filtered( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
limit: int = 50,
offset: int = 0,
) -> tuple[list[tuple[Extraction, User]], int]:
"""Get all extractions with filtering, search, and sorting."""
base_query = select(Extraction, User).join(User, Extraction.user_id == User.id)
# Apply search filter
if search:
search_pattern = f"%{search}%"
base_query = base_query.where(
or_(
Extraction.title.ilike(search_pattern),
Extraction.url.ilike(search_pattern),
Extraction.service.ilike(search_pattern),
),
)
# Apply status filter
if status_filter:
base_query = base_query.where(Extraction.status == status_filter)
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply sorting and pagination
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
if sort_order.lower() == "asc":
base_query = base_query.order_by(asc(sort_column))
else:
base_query = base_query.order_by(desc(sort_column))
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
return list(result.all()), total_count

View File

@@ -0,0 +1,258 @@
"""Repository for managing favorites."""
from sqlmodel import and_, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class FavoriteRepository(BaseRepository[Favorite]):
"""Repository for managing favorites."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the favorite repository.
Args:
session: Database session
"""
super().__init__(Favorite, session)
async def get_user_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get all favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user favorites
"""
try:
statement = (
select(Favorite)
.where(Favorite.user_id == user_id)
.limit(limit)
.offset(offset)
.order_by(Favorite.created_at.desc())
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get favorites for user: %s", user_id)
raise
async def get_user_sound_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get sound favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user sound favorites
"""
try:
statement = (
select(Favorite)
.where(and_(Favorite.user_id == user_id, Favorite.sound_id.isnot(None)))
.limit(limit)
.offset(offset)
.order_by(Favorite.created_at.desc())
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get sound favorites for user: %s", user_id)
raise
async def get_user_playlist_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get playlist favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user playlist favorites
"""
try:
statement = (
select(Favorite)
.where(
and_(Favorite.user_id == user_id, Favorite.playlist_id.isnot(None)),
)
.limit(limit)
.offset(offset)
.order_by(Favorite.created_at.desc())
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get playlist favorites for user: %s", user_id)
raise
async def get_by_user_and_sound(
self,
user_id: int,
sound_id: int,
) -> Favorite | None:
"""Get a favorite by user and sound.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
The favorite if found, None otherwise
"""
try:
statement = select(Favorite).where(
and_(Favorite.user_id == user_id, Favorite.sound_id == sound_id),
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception(
"Failed to get favorite for user %s and sound %s",
user_id,
sound_id,
)
raise
async def get_by_user_and_playlist(
self,
user_id: int,
playlist_id: int,
) -> Favorite | None:
"""Get a favorite by user and playlist.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
The favorite if found, None otherwise
"""
try:
statement = select(Favorite).where(
and_(Favorite.user_id == user_id, Favorite.playlist_id == playlist_id),
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception(
"Failed to get favorite for user %s and playlist %s",
user_id,
playlist_id,
)
raise
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
"""Check if a sound is favorited by a user.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
True if the sound is favorited, False otherwise
"""
favorite = await self.get_by_user_and_sound(user_id, sound_id)
return favorite is not None
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
"""Check if a playlist is favorited by a user.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
True if the playlist is favorited, False otherwise
"""
favorite = await self.get_by_user_and_playlist(user_id, playlist_id)
return favorite is not None
async def count_user_favorites(self, user_id: int) -> int:
"""Count total favorites for a user.
Args:
user_id: The user ID
Returns:
Total number of favorites
"""
try:
statement = select(Favorite).where(Favorite.user_id == user_id)
result = await self.session.exec(statement)
return len(list(result.all()))
except Exception:
logger.exception("Failed to count favorites for user: %s", user_id)
raise
async def count_sound_favorites(self, sound_id: int) -> int:
"""Count how many users have favorited a sound.
Args:
sound_id: The sound ID
Returns:
Number of users who favorited this sound
"""
try:
statement = select(Favorite).where(Favorite.sound_id == sound_id)
result = await self.session.exec(statement)
return len(list(result.all()))
except Exception:
logger.exception("Failed to count favorites for sound: %s", sound_id)
raise
async def count_playlist_favorites(self, playlist_id: int) -> int:
"""Count how many users have favorited a playlist.
Args:
playlist_id: The playlist ID
Returns:
Number of users who favorited this playlist
"""
try:
statement = select(Favorite).where(Favorite.playlist_id == playlist_id)
result = await self.session.exec(statement)
return len(list(result.all()))
except Exception:
logger.exception("Failed to count favorites for playlist: %s", playlist_id)
raise

View File

@@ -9,6 +9,7 @@ from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
@@ -56,7 +57,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
# management
except Exception:
logger.exception(
"Failed to update playlist timestamp for playlist: %s", playlist_id,
"Failed to update playlist timestamp for playlist: %s",
playlist_id,
)
raise
@@ -340,7 +342,11 @@ class PlaylistRepository(BaseRepository[Playlist]):
include_stats: bool = False, # noqa: FBT001, FBT002
limit: int | None = None,
offset: int = 0,
) -> list[dict]:
favorites_only: bool = False, # noqa: FBT001, FBT002
current_user_id: int | None = None,
*,
return_count: bool = False,
) -> list[dict] | tuple[list[dict], int]:
"""Search and sort playlists with optional statistics."""
try:
if include_stats and sort_by in (
@@ -387,6 +393,19 @@ class PlaylistRepository(BaseRepository[Playlist]):
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())
# Apply sorting
if sort_by == PlaylistSortField.SOUND_COUNT:
if sort_order == SortOrder.DESC:
@@ -449,6 +468,19 @@ class PlaylistRepository(BaseRepository[Playlist]):
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())
# Apply sorting
if sort_by:
if sort_by == PlaylistSortField.NAME:
@@ -470,6 +502,14 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Default sorting by name ascending
subquery = subquery.order_by(Playlist.name.asc())
# Get total count if requested
total_count = 0
if return_count:
# Create count query from the subquery before pagination
count_query = select(func.count()).select_from(subquery.subquery())
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply pagination
if offset > 0:
subquery = subquery.offset(offset)
@@ -511,4 +551,6 @@ class PlaylistRepository(BaseRepository[Playlist]):
)
raise
else:
if return_count:
return playlists, total_count
return playlists

View File

@@ -8,6 +8,7 @@ from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.models.sound import Sound
from app.models.sound_played import SoundPlayed
from app.repositories.base import BaseRepository
@@ -140,11 +141,20 @@ class SoundRepository(BaseRepository[Sound]):
sort_order: SortOrder = SortOrder.ASC,
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False, # noqa: FBT001, FBT002
user_id: int | None = None,
) -> list[Sound]:
"""Search and sort sounds with optional filtering."""
try:
statement = select(Sound)
# Apply favorites filter
if favorites_only and user_id is not None:
statement = statement.join(Favorite).where(
Favorite.user_id == user_id,
Favorite.sound_id == Sound.id,
)
# Apply type filter
if sound_types:
statement = statement.where(col(Sound.type).in_(sound_types))
@@ -179,12 +189,15 @@ class SoundRepository(BaseRepository[Sound]):
logger.exception(
(
"Failed to search and sort sounds: "
"query=%s, types=%s, sort_by=%s, sort_order=%s"
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
"user_id=%s"
),
search_query,
sound_types,
sort_by,
sort_order,
favorites_only,
user_id,
)
raise
@@ -276,8 +289,7 @@ class SoundRepository(BaseRepository[Sound]):
# Group by sound and order by play count descending
statement = (
statement
.group_by(
statement.group_by(
Sound.id,
Sound.name,
Sound.type,

View File

@@ -1,7 +1,9 @@
"""User repository."""
from enum import Enum
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -14,6 +16,31 @@ from app.repositories.base import BaseRepository
logger = get_logger(__name__)
class UserSortField(str, Enum):
"""User sort fields."""
NAME = "name"
EMAIL = "email"
ROLE = "role"
CREDITS = "credits"
CREATED_AT = "created_at"
class SortOrder(str, Enum):
"""Sort order."""
ASC = "asc"
DESC = "desc"
class UserStatus(str, Enum):
"""User status filter."""
ALL = "all"
ACTIVE = "active"
INACTIVE = "inactive"
class UserRepository(BaseRepository[User]):
"""Repository for user operations."""
@@ -40,6 +67,69 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get all users with plan")
raise
async def get_all_with_plan_paginated( # noqa: PLR0913
self,
page: int = 1,
limit: int = 50,
search: str | None = None,
sort_by: UserSortField = UserSortField.NAME,
sort_order: SortOrder = SortOrder.ASC,
status_filter: UserStatus = UserStatus.ALL,
) -> tuple[list[User], int]:
"""Get all users with plan relationship loaded and return total count."""
try:
# Calculate offset
offset = (page - 1) * limit
# Build base query
base_query = select(User).options(selectinload(User.plan))
count_query = select(func.count(User.id))
# Apply search filter
if search and search.strip():
search_pattern = f"%{search.strip().lower()}%"
search_condition = func.lower(User.name).like(
search_pattern,
) | func.lower(User.email).like(search_pattern)
base_query = base_query.where(search_condition)
count_query = count_query.where(search_condition)
# Apply status filter
if status_filter == UserStatus.ACTIVE:
base_query = base_query.where(User.is_active == True) # noqa: E712
count_query = count_query.where(User.is_active == True) # noqa: E712
elif status_filter == UserStatus.INACTIVE:
base_query = base_query.where(User.is_active == False) # noqa: E712
count_query = count_query.where(User.is_active == False) # noqa: E712
# Apply sorting
sort_column = {
UserSortField.NAME: User.name,
UserSortField.EMAIL: User.email,
UserSortField.ROLE: User.role,
UserSortField.CREDITS: User.credits,
UserSortField.CREATED_AT: User.created_at,
}.get(sort_by, User.name)
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(sort_column.desc())
else:
base_query = base_query.order_by(sort_column.asc())
# Get total count
count_result = await self.session.exec(count_query)
total_count = count_result.one()
# Apply pagination and get results
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
users = list(result.all())
except Exception:
logger.exception("Failed to get paginated users with plan")
raise
else:
return users, total_count
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
"""Get a user by ID with plan relationship loaded."""
try:
@@ -77,7 +167,7 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get user by API token")
raise
async def create(self, user_data: dict[str, Any]) -> User:
async def create(self, entity_data: dict[str, Any]) -> User:
"""Create a new user with plan assignment and first user admin logic."""
def _raise_plan_not_found() -> None:
@@ -93,7 +183,7 @@ class UserRepository(BaseRepository[User]):
if is_first_user:
# First user gets admin role and pro plan
plan_statement = select(Plan).where(Plan.code == "pro")
user_data["role"] = "admin"
entity_data["role"] = "admin"
logger.info("Creating first user with admin role and pro plan")
else:
# Regular users get free plan
@@ -109,11 +199,11 @@ class UserRepository(BaseRepository[User]):
assert default_plan is not None # noqa: S101
# Set plan_id and default credits
user_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits
entity_data["plan_id"] = default_plan.id
entity_data["credits"] = default_plan.credits
# Use BaseRepository's create method
return await super().create(user_data)
return await super().create(entity_data)
except Exception:
logger.exception("Failed to create user")
raise

View File

@@ -85,7 +85,8 @@ class ChangePasswordRequest(BaseModel):
"""Schema for password change request."""
current_password: str | None = Field(
None, description="Current password (required if user has existing password)",
None,
description="Current password (required if user has existing password)",
)
new_password: str = Field(
...,
@@ -98,5 +99,8 @@ class UpdateProfileRequest(BaseModel):
"""Schema for profile update request."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User display name",
None,
min_length=1,
max_length=100,
description="User display name",
)

41
app/schemas/favorite.py Normal file
View File

@@ -0,0 +1,41 @@
"""Favorite response schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class FavoriteResponse(BaseModel):
"""Response schema for a favorite."""
id: int = Field(description="Favorite ID")
user_id: int = Field(description="User ID")
sound_id: int | None = Field(
description="Sound ID if this is a sound favorite",
default=None,
)
playlist_id: int | None = Field(
description="Playlist ID if this is a playlist favorite",
default=None,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")
class Config:
"""Pydantic config."""
from_attributes = True
class FavoritesListResponse(BaseModel):
"""Response schema for a list of favorites."""
favorites: list[FavoriteResponse] = Field(description="List of favorites")
class FavoriteCountsResponse(BaseModel):
"""Response schema for favorite counts."""
total: int = Field(description="Total number of favorites")
sounds: int = Field(description="Number of favorited sounds")
playlists: int = Field(description="Number of favorited playlists")

View File

@@ -33,12 +33,29 @@ class PlaylistResponse(BaseModel):
is_main: bool
is_current: bool
is_deletable: bool
is_favorited: bool = False
favorite_count: int = 0
created_at: str
updated_at: str | None
@classmethod
def from_playlist(cls, playlist: Playlist) -> "PlaylistResponse":
"""Create response from playlist model."""
def from_playlist(
cls,
playlist: Playlist,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "PlaylistResponse":
"""Create response from playlist model.
Args:
playlist: The Playlist model
is_favorited: Whether the playlist is favorited by the current user
favorite_count: Number of users who favorited this playlist
Returns:
PlaylistResponse instance
"""
if playlist.id is None:
msg = "Playlist ID cannot be None"
raise ValueError(msg)
@@ -50,6 +67,8 @@ class PlaylistResponse(BaseModel):
is_main=playlist.is_main,
is_current=playlist.is_current,
is_deletable=playlist.is_deletable,
is_favorited=is_favorited,
favorite_count=favorite_count,
created_at=playlist.created_at.isoformat(),
updated_at=playlist.updated_at.isoformat() if playlist.updated_at else None,
)

106
app/schemas/sound.py Normal file
View File

@@ -0,0 +1,106 @@
"""Sound response schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
from app.models.sound import Sound
class SoundResponse(BaseModel):
"""Response schema for a sound with favorite indicator."""
id: int = Field(description="Sound ID")
type: str = Field(description="Sound type")
name: str = Field(description="Sound name")
filename: str = Field(description="Sound filename")
duration: int = Field(description="Duration in milliseconds")
size: int = Field(description="File size in bytes")
hash: str = Field(description="File hash")
normalized_filename: str | None = Field(
description="Normalized filename",
default=None,
)
normalized_duration: int | None = Field(
description="Normalized duration in milliseconds",
default=None,
)
normalized_size: int | None = Field(
description="Normalized file size in bytes",
default=None,
)
normalized_hash: str | None = Field(
description="Normalized file hash",
default=None,
)
thumbnail: str | None = Field(description="Thumbnail filename", default=None)
play_count: int = Field(description="Number of times played")
is_normalized: bool = Field(description="Whether the sound is normalized")
is_music: bool = Field(description="Whether the sound is music")
is_deletable: bool = Field(description="Whether the sound can be deleted")
is_favorited: bool = Field(
description="Whether the sound is favorited by the current user",
default=False,
)
favorite_count: int = Field(
description="Number of users who favorited this sound",
default=0,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")
class Config:
"""Pydantic config."""
from_attributes = True
@classmethod
def from_sound(
cls,
sound: Sound,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "SoundResponse":
"""Create a SoundResponse from a Sound model.
Args:
sound: The Sound model
is_favorited: Whether the sound is favorited by the current user
favorite_count: Number of users who favorited this sound
Returns:
SoundResponse instance
"""
if sound.id is None:
msg = "Sound ID cannot be None"
raise ValueError(msg)
return cls(
id=sound.id,
type=sound.type,
name=sound.name,
filename=sound.filename,
duration=sound.duration,
size=sound.size,
hash=sound.hash,
normalized_filename=sound.normalized_filename,
normalized_duration=sound.normalized_duration,
normalized_size=sound.normalized_size,
normalized_hash=sound.normalized_hash,
thumbnail=sound.thumbnail,
play_count=sound.play_count,
is_normalized=sound.is_normalized,
is_music=sound.is_music,
is_deletable=sound.is_deletable,
is_favorited=is_favorited,
favorite_count=favorite_count,
created_at=sound.created_at,
updated_at=sound.updated_at,
)
class SoundsListResponse(BaseModel):
"""Response schema for a list of sounds."""
sounds: list[SoundResponse] = Field(description="List of sounds")

View File

@@ -7,7 +7,10 @@ class UserUpdate(BaseModel):
"""Schema for updating a user."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User full name",
None,
min_length=1,
max_length=100,
description="User full name",
)
plan_id: int | None = Field(None, description="User plan ID")
credits: int | None = Field(None, ge=0, description="User credits")

View File

@@ -454,7 +454,10 @@ class AuthService:
return user
async def change_user_password(
self, user: User, current_password: str | None, new_password: str,
self,
user: User,
current_password: str | None,
new_password: str,
) -> None:
"""Change user's password."""
# Store user email before any operations to avoid session detachment issues
@@ -484,8 +487,11 @@ class AuthService:
self.session.add(user)
await self.session.commit()
logger.info("Password %s successfully for user: %s",
"changed" if had_existing_password else "set", user_email)
logger.info(
"Password %s successfully for user: %s",
"changed" if had_existing_password else "set",
user_email,
)
async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information."""

View File

@@ -72,9 +72,7 @@ class DashboardService:
"play_count": sound["play_count"],
"duration": sound["duration"],
"created_at": (
sound["created_at"].isoformat()
if sound["created_at"]
else None
sound["created_at"].isoformat() if sound["created_at"] else None
),
}
for sound in top_sounds

View File

@@ -13,6 +13,7 @@ from app.core.logging import get_logger
from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
from app.services.playlist import PlaylistService
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
@@ -32,10 +33,21 @@ class ExtractionInfo(TypedDict):
error: str | None
sound_id: int | None
user_id: int
user_name: str | None
created_at: str
updated_at: str
class PaginatedExtractionsResponse(TypedDict):
"""Type definition for paginated extractions response."""
extractions: list[ExtractionInfo]
total: int
page: int
limit: int
total_pages: int
class ExtractionService:
"""Service for extracting audio from external services using yt-dlp."""
@@ -44,6 +56,7 @@ class ExtractionService:
self.session = session
self.extraction_repo = ExtractionRepository(session)
self.sound_repo = SoundRepository(session)
self.user_repo = UserRepository(session)
self.playlist_service = PlaylistService(session)
# Ensure required directories exist
@@ -66,6 +79,15 @@ class ExtractionService:
logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)
try:
# Get user information
user = await self.user_repo.get_by_id(user_id)
if not user:
msg = f"User {user_id} not found"
raise ValueError(msg)
# Extract user name immediately while in session context
user_name = user.name
# Create the extraction record without service detection for fast response
extraction_data = {
"url": url,
@@ -92,6 +114,7 @@ class ExtractionService:
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
@@ -509,7 +532,8 @@ class ExtractionService:
"""Add the sound to the user's main playlist."""
try:
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
sound_id, user_id,
sound_id,
user_id,
)
logger.info(
"Added sound %d to main playlist for user %d",
@@ -531,6 +555,10 @@ class ExtractionService:
if not extraction:
return None
# Get user information
user = await self.user_repo.get_by_id(extraction.user_id)
user_name = user.name if user else None
return {
"id": extraction.id or 0, # Should never be None for existing extraction
"url": extraction.url,
@@ -541,15 +569,37 @@ class ExtractionService:
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
async def get_user_extractions(self, user_id: int) -> list[ExtractionInfo]:
"""Get all extractions for a user."""
extractions = await self.extraction_repo.get_by_user(user_id)
async def get_user_extractions( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
page: int = 1,
limit: int = 50,
) -> PaginatedExtractionsResponse:
"""Get all extractions for a user with filtering, search, and sorting."""
offset = (page - 1) * limit
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_user_extractions_filtered(
user_id=user_id,
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
limit=limit,
offset=offset,
)
return [
extractions = [
{
"id": extraction.id
or 0, # Should never be None for existing extraction
@@ -561,15 +611,78 @@ class ExtractionService:
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user.name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
for extraction in extractions
for extraction, user in extraction_user_tuples
]
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"extractions": extractions,
"total": total_count,
"page": page,
"limit": limit,
"total_pages": total_pages,
}
async def get_all_extractions( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
status_filter: str | None = None,
page: int = 1,
limit: int = 50,
) -> PaginatedExtractionsResponse:
"""Get all extractions with filtering, search, and sorting."""
offset = (page - 1) * limit
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_all_extractions_filtered(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
limit=limit,
offset=offset,
)
extractions = [
{
"id": extraction.id
or 0, # Should never be None for existing extraction
"url": extraction.url,
"service": extraction.service,
"service_id": extraction.service_id,
"title": extraction.title,
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user.name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
for extraction, user in extraction_user_tuples
]
total_pages = (total_count + limit - 1) // limit # Ceiling division
return {
"extractions": extractions,
"total": total_count,
"page": page,
"limit": limit,
"total_pages": total_pages,
}
async def get_pending_extractions(self) -> list[ExtractionInfo]:
"""Get all pending extractions."""
extractions = await self.extraction_repo.get_pending_extractions()
extraction_user_tuples = await self.extraction_repo.get_pending_extractions()
return [
{
@@ -583,8 +696,9 @@ class ExtractionService:
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user.name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
for extraction in extractions
for extraction, user in extraction_user_tuples
]

382
app/services/favorite.py Normal file
View File

@@ -0,0 +1,382 @@
"""Service for managing user favorites."""
from collections.abc import Callable
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.favorite import Favorite
from app.repositories.favorite import FavoriteRepository
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.repositories.user import UserRepository
from app.services.socket import socket_manager
logger = get_logger(__name__)
class FavoriteService:
"""Service for managing user favorites."""
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
"""Initialize the favorite service.
Args:
db_session_factory: Factory function to create database sessions
"""
self.db_session_factory = db_session_factory
async def add_sound_favorite(self, user_id: int, sound_id: int) -> Favorite:
"""Add a sound to user's favorites.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
The created favorite
Raises:
ValueError: If user or sound not found, or already favorited
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
user_repo = UserRepository(session)
sound_repo = SoundRepository(session)
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify sound exists
sound = await sound_repo.get_by_id(sound_id)
if not sound:
msg = f"Sound with ID {sound_id} not found"
raise ValueError(msg)
# Get data for the event immediately after loading
sound_name = sound.name
user_name = user.name
# Check if already favorited
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if existing:
msg = f"Sound {sound_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
"user_id": user_id,
"sound_id": sound_id,
"playlist_id": None,
}
favorite = await favorite_repo.create(favorite_data)
logger.info("User %s favorited sound %s", user_id, sound_id)
# Get updated favorite count within the same session
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
# Emit sound_favorited event via WebSocket (outside the session)
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": user_id,
"user_name": user_name,
"favorite_count": favorite_count,
}
await socket_manager.broadcast_to_all("sound_favorited", event_data)
logger.info("Broadcasted sound_favorited event for sound %s", sound_id)
except Exception:
logger.exception(
"Failed to broadcast sound_favorited event for sound %s",
sound_id,
)
return favorite
async def add_playlist_favorite(self, user_id: int, playlist_id: int) -> Favorite:
"""Add a playlist to user's favorites.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
The created favorite
Raises:
ValueError: If user or playlist not found, or already favorited
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
user_repo = UserRepository(session)
playlist_repo = PlaylistRepository(session)
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify playlist exists
playlist = await playlist_repo.get_by_id(playlist_id)
if not playlist:
msg = f"Playlist with ID {playlist_id} not found"
raise ValueError(msg)
# Check if already favorited
existing = await favorite_repo.get_by_user_and_playlist(
user_id,
playlist_id,
)
if existing:
msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
"user_id": user_id,
"sound_id": None,
"playlist_id": playlist_id,
}
favorite = await favorite_repo.create(favorite_data)
logger.info("User %s favorited playlist %s", user_id, playlist_id)
return favorite
async def remove_sound_favorite(self, user_id: int, sound_id: int) -> None:
"""Remove a sound from user's favorites.
Args:
user_id: The user ID
sound_id: The sound ID
Raises:
ValueError: If favorite not found
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if not favorite:
msg = f"Sound {sound_id} is not favorited by user {user_id}"
raise ValueError(msg)
# Get user and sound info before deletion for the event
user_repo = UserRepository(session)
sound_repo = SoundRepository(session)
user = await user_repo.get_by_id(user_id)
sound = await sound_repo.get_by_id(sound_id)
# Get data for the event immediately after loading
sound_name = sound.name if sound else "Unknown"
user_name = user.name if user else "Unknown"
await favorite_repo.delete(favorite)
logger.info("User %s removed sound %s from favorites", user_id, sound_id)
# Get updated favorite count after deletion within the same session
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
# Emit sound_favorited event via WebSocket (outside the session)
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": user_id,
"user_name": user_name,
"favorite_count": favorite_count,
}
await socket_manager.broadcast_to_all("sound_favorited", event_data)
logger.info(
"Broadcasted sound_favorited event for sound %s removal",
sound_id,
)
except Exception:
logger.exception(
"Failed to broadcast sound_favorited event for sound %s removal",
sound_id,
)
async def remove_playlist_favorite(self, user_id: int, playlist_id: int) -> None:
"""Remove a playlist from user's favorites.
Args:
user_id: The user ID
playlist_id: The playlist ID
Raises:
ValueError: If favorite not found
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
favorite = await favorite_repo.get_by_user_and_playlist(
user_id,
playlist_id,
)
if not favorite:
msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
raise ValueError(msg)
await favorite_repo.delete(favorite)
logger.info(
"User %s removed playlist %s from favorites",
user_id,
playlist_id,
)
async def get_user_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get all favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user favorites
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.get_user_favorites(user_id, limit, offset)
async def get_user_sound_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get sound favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user sound favorites
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.get_user_sound_favorites(user_id, limit, offset)
async def get_user_playlist_favorites(
self,
user_id: int,
limit: int = 100,
offset: int = 0,
) -> list[Favorite]:
"""Get playlist favorites for a user.
Args:
user_id: The user ID
limit: Maximum number of favorites to return
offset: Number of favorites to skip
Returns:
List of user playlist favorites
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.get_user_playlist_favorites(
user_id,
limit,
offset,
)
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
"""Check if a sound is favorited by a user.
Args:
user_id: The user ID
sound_id: The sound ID
Returns:
True if the sound is favorited, False otherwise
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.is_sound_favorited(user_id, sound_id)
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
"""Check if a playlist is favorited by a user.
Args:
user_id: The user ID
playlist_id: The playlist ID
Returns:
True if the playlist is favorited, False otherwise
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.is_playlist_favorited(user_id, playlist_id)
async def get_favorite_counts(self, user_id: int) -> dict[str, int]:
"""Get favorite counts for a user.
Args:
user_id: The user ID
Returns:
Dictionary with favorite counts
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
total = await favorite_repo.count_user_favorites(user_id)
sounds = len(await favorite_repo.get_user_sound_favorites(user_id))
playlists = len(await favorite_repo.get_user_playlist_favorites(user_id))
return {
"total": total,
"sounds": sounds,
"playlists": playlists,
}
async def get_sound_favorite_count(self, sound_id: int) -> int:
"""Get the number of users who have favorited a sound.
Args:
sound_id: The sound ID
Returns:
Number of users who favorited this sound
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.count_sound_favorites(sound_id)
async def get_playlist_favorite_count(self, playlist_id: int) -> int:
"""Get the number of users who have favorited a playlist.
Args:
playlist_id: The playlist ID
Returns:
Number of users who favorited this playlist
"""
async with self.db_session_factory() as session:
favorite_repo = FavoriteRepository(session)
return await favorite_repo.count_playlist_favorites(playlist_id)

View File

@@ -1,6 +1,6 @@
"""Playlist service for business logic operations."""
from typing import Any
from typing import Any, TypedDict
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -14,6 +14,16 @@ from app.repositories.sound import SoundRepository
logger = get_logger(__name__)
class PaginatedPlaylistsResponse(TypedDict):
"""Response type for paginated playlists."""
playlists: list[dict]
total: int
page: int
limit: int
total_pages: int
async def _reload_player_playlist() -> None:
"""Reload the player playlist after current playlist changes."""
try:
@@ -246,6 +256,8 @@ class PlaylistService:
include_stats: bool = False,
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False,
current_user_id: int | None = None,
) -> list[dict]:
"""Search and sort playlists with optional statistics."""
return await self.playlist_repo.search_and_sort(
@@ -256,6 +268,47 @@ class PlaylistService:
include_stats=include_stats,
limit=limit,
offset=offset,
favorites_only=favorites_only,
current_user_id=current_user_id,
)
async def search_and_sort_playlists_paginated( # noqa: PLR0913
self,
search_query: str | None = None,
sort_by: PlaylistSortField | None = None,
sort_order: SortOrder = SortOrder.ASC,
user_id: int | None = None,
*,
include_stats: bool = False,
page: int = 1,
limit: int = 50,
favorites_only: bool = False,
current_user_id: int | None = None,
) -> PaginatedPlaylistsResponse:
"""Search and sort playlists with pagination."""
offset = (page - 1) * limit
playlists, total_count = await self.playlist_repo.search_and_sort(
search_query=search_query,
sort_by=sort_by,
sort_order=sort_order,
user_id=user_id,
include_stats=include_stats,
limit=limit,
offset=offset,
favorites_only=favorites_only,
current_user_id=current_user_id,
return_count=True,
)
total_pages = (total_count + limit - 1) // limit # Ceiling division
return PaginatedPlaylistsResponse(
playlists=playlists,
total=total_count,
page=page,
limit=limit,
total_pages=total_pages,
)
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
@@ -416,7 +469,9 @@ class PlaylistService:
}
async def add_sound_to_main_playlist(
self, sound_id: int, user_id: int, # noqa: ARG002
self,
sound_id: int, # noqa: ARG002
user_id: int, # noqa: ARG002
) -> None:
"""Add a sound to the global main playlist."""
raise HTTPException(
@@ -425,7 +480,9 @@ class PlaylistService:
)
async def _add_sound_to_main_playlist_internal(
self, sound_id: int, user_id: int,
self,
sound_id: int,
user_id: int,
) -> None:
"""Add sound to main playlist bypassing restrictions.

View File

@@ -21,8 +21,6 @@ def mock_plan_repository():
return Mock()
@pytest.fixture
def regular_user():
"""Create regular user for testing."""
@@ -60,52 +58,78 @@ class TestAdminUserEndpoints:
test_plan: Plan,
) -> None:
"""Test listing users successfully."""
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all:
with patch(
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
) as mock_get_all:
# Create mock user objects that don't trigger database saves
mock_admin = type("User", (), {
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_admin = type(
"User",
(),
{
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_regular = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_regular = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_all.return_value = [mock_admin, mock_regular]
# Mock returns tuple (users, total_count)
mock_get_all.return_value = ([mock_admin, mock_regular], 2)
response = await authenticated_admin_client.get("/api/v1/admin/users/")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["email"] == "admin@example.com"
assert data[1]["email"] == "user@example.com"
mock_get_all.assert_called_once_with(limit=100, offset=0)
assert "users" in data
assert "total" in data
assert "page" in data
assert "limit" in data
assert "total_pages" in data
assert len(data["users"]) == 2
assert data["users"][0]["email"] == "admin@example.com"
assert data["users"][1]["email"] == "user@example.com"
assert data["total"] == 2
assert data["page"] == 1
assert data["limit"] == 50
@pytest.mark.asyncio
async def test_list_users_with_pagination(
@@ -115,29 +139,55 @@ class TestAdminUserEndpoints:
test_plan: Plan,
) -> None:
"""Test listing users with pagination."""
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all:
mock_admin = type("User", (), {
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_get_all.return_value = [mock_admin]
from app.repositories.user import SortOrder, UserSortField, UserStatus
response = await authenticated_admin_client.get("/api/v1/admin/users/?limit=10&offset=5")
with patch(
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
) as mock_get_all:
mock_admin = type(
"User",
(),
{
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
# Mock returns tuple (users, total_count)
mock_get_all.return_value = ([mock_admin], 1)
response = await authenticated_admin_client.get(
"/api/v1/admin/users/?page=2&limit=10",
)
assert response.status_code == 200
mock_get_all.assert_called_once_with(limit=10, offset=5)
data = response.json()
assert "users" in data
assert data["page"] == 2
assert data["limit"] == 10
mock_get_all.assert_called_once_with(
page=2,
limit=10,
search=None,
sort_by=UserSortField.NAME,
sort_order=SortOrder.ASC,
status_filter=UserStatus.ALL,
)
@pytest.mark.asyncio
async def test_list_users_unauthenticated(self, client: AsyncClient) -> None:
@@ -153,7 +203,9 @@ class TestAdminUserEndpoints:
regular_user: User,
) -> None:
"""Test listing users as non-admin user."""
with patch("app.core.dependencies.get_current_active_user", return_value=regular_user):
with patch(
"app.core.dependencies.get_current_active_user", return_value=regular_user,
):
response = await client.get("/api/v1/admin/users/")
assert response.status_code == 401
@@ -169,24 +221,34 @@ class TestAdminUserEndpoints:
"""Test getting specific user successfully."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.get("/api/v1/admin/users/2")
@@ -207,7 +269,10 @@ class TestAdminUserEndpoints:
"""Test getting non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.get("/api/v1/admin/users/999")
@@ -226,43 +291,63 @@ class TestAdminUserEndpoints:
"""Test updating user successfully."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update,
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan),
patch(
"app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan,
),
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
updated_mock = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": "Updated Name",
"picture": None,
"role": regular_user.role,
"credits": 200,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
updated_mock = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": "Updated Name",
"picture": None,
"role": regular_user.role,
"credits": 200,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
mock_update.return_value = updated_mock
@@ -271,7 +356,10 @@ class TestAdminUserEndpoints:
async def mock_refresh(instance, attributes=None):
pass
with patch("sqlmodel.ext.asyncio.session.AsyncSession.refresh", side_effect=mock_refresh):
with patch(
"sqlmodel.ext.asyncio.session.AsyncSession.refresh",
side_effect=mock_refresh,
):
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2",
json={
@@ -295,7 +383,10 @@ class TestAdminUserEndpoints:
"""Test updating non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/999",
@@ -316,25 +407,35 @@ class TestAdminUserEndpoints:
"""Test updating user with invalid plan."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=None),
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": 1,
"name": "Basic",
"max_credits": 100,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": 1,
"name": "Basic",
"max_credits": 100,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2",
@@ -356,29 +457,41 @@ class TestAdminUserEndpoints:
"""Test disabling user successfully."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update,
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
mock_update.return_value = mock_user
response = await authenticated_admin_client.post("/api/v1/admin/users/2/disable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/2/disable",
)
assert response.status_code == 200
data = response.json()
@@ -393,9 +506,14 @@ class TestAdminUserEndpoints:
"""Test disabling non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.post("/api/v1/admin/users/999/disable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/999/disable",
)
assert response.status_code == 404
data = response.json()
@@ -421,29 +539,41 @@ class TestAdminUserEndpoints:
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update,
):
mock_disabled_user = type("User", (), {
"id": disabled_user.id,
"email": disabled_user.email,
"name": disabled_user.name,
"picture": None,
"role": disabled_user.role,
"credits": disabled_user.credits,
"is_active": disabled_user.is_active,
"created_at": disabled_user.created_at,
"updated_at": disabled_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_disabled_user = type(
"User",
(),
{
"id": disabled_user.id,
"email": disabled_user.email,
"name": disabled_user.name,
"picture": None,
"role": disabled_user.role,
"credits": disabled_user.credits,
"is_active": disabled_user.is_active,
"created_at": disabled_user.created_at,
"updated_at": disabled_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_disabled_user
mock_update.return_value = mock_disabled_user
response = await authenticated_admin_client.post("/api/v1/admin/users/3/enable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/3/enable",
)
assert response.status_code == 200
data = response.json()
@@ -458,9 +588,14 @@ class TestAdminUserEndpoints:
"""Test enabling non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.post("/api/v1/admin/users/999/enable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/999/enable",
)
assert response.status_code == 404
data = response.json()
@@ -479,9 +614,14 @@ class TestAdminUserEndpoints:
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.plan.PlanRepository.get_all", return_value=[basic_plan, premium_plan]),
patch(
"app.repositories.plan.PlanRepository.get_all",
return_value=[basic_plan, premium_plan],
),
):
response = await authenticated_admin_client.get("/api/v1/admin/users/plans/list")
response = await authenticated_admin_client.get(
"/api/v1/admin/users/plans/list",
)
assert response.status_code == 200
data = response.json()

View File

@@ -488,11 +488,17 @@ class TestAuthEndpoints:
test_plan: Plan,
) -> None:
"""Test refresh token success."""
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh:
mock_refresh.return_value = type("TokenResponse", (), {
"access_token": "new_access_token",
"expires_in": 3600,
})()
with patch(
"app.services.auth.AuthService.refresh_access_token",
) as mock_refresh:
mock_refresh.return_value = type(
"TokenResponse",
(),
{
"access_token": "new_access_token",
"expires_in": 3600,
},
)()
response = await test_client.post(
"/api/v1/auth/refresh",
@@ -516,7 +522,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio
async def test_refresh_token_service_error(self, test_client: AsyncClient) -> None:
"""Test refresh token with service error."""
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh:
with patch(
"app.services.auth.AuthService.refresh_access_token",
) as mock_refresh:
mock_refresh.side_effect = Exception("Database error")
response = await test_client.post(
@@ -528,7 +536,6 @@ class TestAuthEndpoints:
data = response.json()
assert "Token refresh failed" in data["detail"]
@pytest.mark.asyncio
async def test_exchange_oauth_token_invalid_code(
self,
@@ -554,7 +561,9 @@ class TestAuthEndpoints:
"""Test update profile success."""
with (
patch("app.services.auth.AuthService.update_user_profile") as mock_update,
patch("app.services.auth.AuthService.user_to_response") as mock_user_to_response,
patch(
"app.services.auth.AuthService.user_to_response",
) as mock_user_to_response,
):
updated_user = User(
id=test_user.id,
@@ -569,6 +578,7 @@ class TestAuthEndpoints:
# Mock the user_to_response to return UserResponse format
from app.schemas.auth import UserResponse
mock_user_to_response.return_value = UserResponse(
id=test_user.id,
email=test_user.email,
@@ -598,7 +608,9 @@ class TestAuthEndpoints:
assert data["name"] == "Updated Name"
@pytest.mark.asyncio
async def test_update_profile_unauthenticated(self, test_client: AsyncClient) -> None:
async def test_update_profile_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test update profile without authentication."""
response = await test_client.patch(
"/api/v1/auth/me",
@@ -632,7 +644,9 @@ class TestAuthEndpoints:
assert data["message"] == "Password changed successfully"
@pytest.mark.asyncio
async def test_change_password_unauthenticated(self, test_client: AsyncClient) -> None:
async def test_change_password_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test change password without authentication."""
response = await test_client.post(
"/api/v1/auth/change-password",
@@ -652,7 +666,9 @@ class TestAuthEndpoints:
auth_cookies: dict[str, str],
) -> None:
"""Test get user OAuth providers success."""
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
with patch(
"app.services.auth.AuthService.get_user_oauth_providers",
) as mock_providers:
from datetime import datetime
from app.models.user_oauth import UserOauth
@@ -699,7 +715,9 @@ class TestAuthEndpoints:
assert data[2]["display_name"] == "GitHub"
@pytest.mark.asyncio
async def test_get_user_providers_unauthenticated(self, test_client: AsyncClient) -> None:
async def test_get_user_providers_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test get user OAuth providers without authentication."""
response = await test_client.get("/api/v1/auth/user-providers")

View File

@@ -109,9 +109,15 @@ class TestPlaylistEndpoints:
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert "playlists" in data
assert "total" in data
assert "page" in data
assert "limit" in data
assert "total_pages" in data
assert len(data["playlists"]) == 2
assert data["total"] == 2
playlist_names = {p["name"] for p in data}
playlist_names = {p["name"] for p in data["playlists"]}
assert "Test Playlist" in playlist_names
assert "Main Playlist" in playlist_names

View File

@@ -9,7 +9,7 @@ from httpx import AsyncClient
from app.models.user import User
if TYPE_CHECKING:
from app.services.extraction import ExtractionInfo
from app.services.extraction import ExtractionInfo, PaginatedExtractionsResponse
class TestSoundEndpoints:
@@ -32,6 +32,7 @@ class TestSoundEndpoints:
"error": None,
"sound_id": None,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
}
@@ -111,6 +112,7 @@ class TestSoundEndpoints:
"error": None,
"sound_id": 42,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
}
@@ -154,41 +156,49 @@ class TestSoundEndpoints:
authenticated_user: User,
) -> None:
"""Test getting user extractions."""
mock_extractions: list[ExtractionInfo] = [
{
"id": 1,
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"title": "Never Gonna Give You Up",
"service": "youtube",
"service_id": "dQw4w9WgXcQ",
"status": "completed",
"error": None,
"sound_id": 42,
"user_id": authenticated_user.id,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
{
"id": 2,
"url": "https://soundcloud.com/example/track",
"title": "Example Track",
"service": "soundcloud",
"service_id": "example-track",
"status": "pending",
"error": None,
"sound_id": None,
"user_id": authenticated_user.id,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
]
mock_extractions: PaginatedExtractionsResponse = {
"extractions": [
{
"id": 1,
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"title": "Never Gonna Give You Up",
"service": "youtube",
"service_id": "dQw4w9WgXcQ",
"status": "completed",
"error": None,
"sound_id": 42,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
{
"id": 2,
"url": "https://soundcloud.com/example/track",
"title": "Example Track",
"service": "soundcloud",
"service_id": "example-track",
"status": "pending",
"error": None,
"sound_id": None,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
],
"total": 2,
"page": 1,
"limit": 50,
"total_pages": 1,
}
with patch(
"app.services.extraction.ExtractionService.get_user_extractions",
) as mock_get:
mock_get.return_value = mock_extractions
response = await authenticated_client.get("/api/v1/extractions/")
response = await authenticated_client.get("/api/v1/extractions/user")
assert response.status_code == 200
data = response.json()
@@ -337,7 +347,9 @@ class TestSoundEndpoints:
"""Test getting sounds with authentication."""
from app.models.sound import Sound
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get:
with patch(
"app.repositories.sound.SoundRepository.search_and_sort",
) as mock_get:
# Create mock sounds with all required fields
mock_sound_1 = Sound(
id=1,
@@ -383,7 +395,9 @@ class TestSoundEndpoints:
"""Test getting sounds with type filtering."""
from app.models.sound import Sound
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get:
with patch(
"app.repositories.sound.SoundRepository.search_and_sort",
) as mock_get:
# Create mock sound with all required fields
mock_sound = Sound(
id=1,

View File

@@ -335,5 +335,3 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
access_token = JWTUtils.create_access_token(token_data)
return {"access_token": access_token}

View File

@@ -539,21 +539,35 @@ class TestPlaylistRepository:
sound_ids = [s.id for s in sounds]
# Add first two sounds sequentially (positions 0, 1)
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[0],
) # position 0
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[1],
) # position 1
# Now insert third sound at position 1 - should shift existing sound at position 1 to position 2
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=1)
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[2], position=1,
)
# Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
playlist_sounds = await playlist_repository.get_playlist_sound_entries(
playlist_id,
)
assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
assert (
playlist_sounds[0].sound_id == sound_ids[0]
) # Original sound 0 stays at position 0
assert playlist_sounds[0].position == 0
assert playlist_sounds[1].sound_id == sound_ids[2] # New sound 2 inserted at position 1
assert (
playlist_sounds[1].sound_id == sound_ids[2]
) # New sound 2 inserted at position 1
assert playlist_sounds[1].position == 1
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2
assert (
playlist_sounds[2].sound_id == sound_ids[1]
) # Original sound 1 shifted to position 2
assert playlist_sounds[2].position == 2
@pytest.mark.asyncio
@@ -615,21 +629,35 @@ class TestPlaylistRepository:
sound_ids = [s.id for s in sounds]
# Add first two sounds sequentially (positions 0, 1)
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[0],
) # position 0
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[1],
) # position 1
# Now insert third sound at position 0 - should shift existing sounds to positions 1, 2
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=0)
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[2], position=0,
)
# Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
playlist_sounds = await playlist_repository.get_playlist_sound_entries(
playlist_id,
)
assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
assert (
playlist_sounds[0].sound_id == sound_ids[2]
) # New sound 2 inserted at position 0
assert playlist_sounds[0].position == 0
assert playlist_sounds[1].sound_id == sound_ids[0] # Original sound 0 shifted to position 1
assert (
playlist_sounds[1].sound_id == sound_ids[0]
) # Original sound 0 shifted to position 1
assert playlist_sounds[1].position == 1
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2
assert (
playlist_sounds[2].sound_id == sound_ids[1]
) # Original sound 1 shifted to position 2
assert playlist_sounds[2].position == 2
@pytest.mark.asyncio

View File

@@ -409,7 +409,9 @@ class TestCreditService:
@pytest.mark.asyncio
async def test_recharge_user_credits_success(
self, credit_service, sample_user,
self,
credit_service,
sample_user,
) -> None:
"""Test successful credit recharge for a user."""
mock_session = credit_service.db_session_factory()

View File

@@ -43,7 +43,9 @@ class TestDashboardService:
"total_duration": 75000,
"total_size": 1024000,
}
mock_sound_repository.get_soundboard_statistics = AsyncMock(return_value=mock_stats)
mock_sound_repository.get_soundboard_statistics = AsyncMock(
return_value=mock_stats,
)
result = await dashboard_service.get_soundboard_statistics()

View File

@@ -99,6 +99,11 @@ class TestExtractionService:
url = "https://www.youtube.com/watch?v=test123"
user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - no service detection happens during creation
mock_extraction = Extraction(
id=1,
@@ -120,6 +125,7 @@ class TestExtractionService:
assert result["service_id"] is None # Not detected during creation
assert result["title"] is None # Not detected during creation
assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_create_extraction_basic(self, extraction_service) -> None:
@@ -127,6 +133,11 @@ class TestExtractionService:
url = "https://www.youtube.com/watch?v=test123"
user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - creation always succeeds now
mock_extraction = Extraction(
id=2,
@@ -146,6 +157,7 @@ class TestExtractionService:
assert result["id"] == 2
assert result["url"] == url
assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_create_extraction_any_url(self, extraction_service) -> None:
@@ -153,6 +165,11 @@ class TestExtractionService:
url = "https://invalid.url"
user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - even invalid URLs are accepted during creation
mock_extraction = Extraction(
id=3,
@@ -172,6 +189,7 @@ class TestExtractionService:
assert result["id"] == 3
assert result["url"] == url
assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_process_extraction_with_service_detection(
@@ -408,9 +426,16 @@ class TestExtractionService:
sound_id=42,
)
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=extraction,
)
extraction_service.user_repo.get_by_id = AsyncMock(
return_value=mock_user,
)
result = await extraction_service.get_extraction_by_id(1)
@@ -421,6 +446,7 @@ class TestExtractionService:
assert result["title"] == "Test Video"
assert result["status"] == "completed"
assert result["sound_id"] == 42
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_get_extraction_by_id_not_found(self, extraction_service) -> None:
@@ -434,52 +460,70 @@ class TestExtractionService:
@pytest.mark.asyncio
async def test_get_user_extractions(self, extraction_service) -> None:
"""Test getting user extractions."""
from app.models.user import User
user = User(id=1, name="Test User", email="test@example.com")
extractions = [
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Test Video 1",
status="completed",
sound_id=42,
(
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Test Video 1",
status="completed",
sound_id=42,
),
user,
),
Extraction(
id=2,
service="youtube",
service_id="test456",
url="https://www.youtube.com/watch?v=test456",
user_id=1,
title="Test Video 2",
status="pending",
(
Extraction(
id=2,
service="youtube",
service_id="test456",
url="https://www.youtube.com/watch?v=test456",
user_id=1,
title="Test Video 2",
status="pending",
),
user,
),
]
extraction_service.extraction_repo.get_by_user = AsyncMock(
return_value=extractions,
extraction_service.extraction_repo.get_user_extractions_filtered = AsyncMock(
return_value=(extractions, 2),
)
result = await extraction_service.get_user_extractions(1)
assert len(result) == 2
assert result[0]["id"] == 1
assert result[0]["title"] == "Test Video 1"
assert result[1]["id"] == 2
assert result[1]["title"] == "Test Video 2"
assert result["total"] == 2
assert len(result["extractions"]) == 2
assert result["extractions"][0]["id"] == 1
assert result["extractions"][0]["title"] == "Test Video 1"
assert result["extractions"][0]["user_name"] == "Test User"
assert result["extractions"][1]["id"] == 2
assert result["extractions"][1]["title"] == "Test Video 2"
assert result["extractions"][1]["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_get_pending_extractions(self, extraction_service) -> None:
"""Test getting pending extractions."""
from app.models.user import User
user = User(id=1, name="Test User", email="test@example.com")
pending_extractions = [
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Pending Video",
status="pending",
(
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Pending Video",
status="pending",
),
user,
),
]
@@ -492,3 +536,4 @@ class TestExtractionService:
assert len(result) == 1
assert result[0]["id"] == 1
assert result[0]["status"] == "pending"
assert result[0]["user_name"] == "Test User"

View File

@@ -25,9 +25,10 @@ class TestSchedulerService:
@pytest.mark.asyncio
async def test_start_scheduler(self, scheduler_service) -> None:
"""Test starting the scheduler service."""
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
patch.object(scheduler_service.scheduler, "start") as mock_start:
with (
patch.object(scheduler_service.scheduler, "add_job") as mock_add_job,
patch.object(scheduler_service.scheduler, "start") as mock_start,
):
await scheduler_service.start()
# Verify job was added
@@ -61,7 +62,9 @@ class TestSchedulerService:
"total_credits_added": 500,
}
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
with patch.object(
scheduler_service.credit_service, "recharge_all_users_credits",
) as mock_recharge:
mock_recharge.return_value = mock_stats
await scheduler_service._daily_credit_recharge()
@@ -71,7 +74,9 @@ class TestSchedulerService:
@pytest.mark.asyncio
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
"""Test daily credit recharge task with failure."""
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
with patch.object(
scheduler_service.credit_service, "recharge_all_users_credits",
) as mock_recharge:
mock_recharge.side_effect = Exception("Database error")
# Should not raise exception, just log it