Compare commits
12 Commits
5e6cc04ad2
...
a660cc1861
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a660cc1861 | ||
|
|
6b55ff0e81 | ||
|
|
e6f796a3c9 | ||
|
|
99c757a073 | ||
|
|
f598ec2c12 | ||
|
|
66d22df7dd | ||
|
|
3326e406f8 | ||
|
|
fe15e7a6af | ||
|
|
f56cc8b4cc | ||
|
|
f906b6d643 | ||
|
|
78508c84eb | ||
|
|
a947fd830b |
@@ -7,6 +7,7 @@ from app.api.v1 import (
|
|||||||
auth,
|
auth,
|
||||||
dashboard,
|
dashboard,
|
||||||
extractions,
|
extractions,
|
||||||
|
favorites,
|
||||||
files,
|
files,
|
||||||
main,
|
main,
|
||||||
player,
|
player,
|
||||||
@@ -22,6 +23,7 @@ api_router = APIRouter(prefix="/v1")
|
|||||||
api_router.include_router(auth.router, tags=["authentication"])
|
api_router.include_router(auth.router, tags=["authentication"])
|
||||||
api_router.include_router(dashboard.router, tags=["dashboard"])
|
api_router.include_router(dashboard.router, tags=["dashboard"])
|
||||||
api_router.include_router(extractions.router, tags=["extractions"])
|
api_router.include_router(extractions.router, tags=["extractions"])
|
||||||
|
api_router.include_router(favorites.router, tags=["favorites"])
|
||||||
api_router.include_router(files.router, tags=["files"])
|
api_router.include_router(files.router, tags=["files"])
|
||||||
api_router.include_router(main.router, tags=["main"])
|
api_router.include_router(main.router, tags=["main"])
|
||||||
api_router.include_router(player.router, tags=["player"])
|
api_router.include_router(player.router, tags=["player"])
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Admin users endpoints."""
|
"""Admin users endpoints."""
|
||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
@@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user
|
|||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.plan import PlanRepository
|
from app.repositories.plan import PlanRepository
|
||||||
from app.repositories.user import UserRepository
|
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
|
||||||
from app.schemas.auth import UserResponse
|
from app.schemas.auth import UserResponse
|
||||||
from app.schemas.user import UserUpdate
|
from app.schemas.user import UserUpdate
|
||||||
|
|
||||||
@@ -36,22 +36,48 @@ def _user_to_response(user: User) -> UserResponse:
|
|||||||
"name": user.plan.name,
|
"name": user.plan.name,
|
||||||
"max_credits": user.plan.max_credits,
|
"max_credits": user.plan.max_credits,
|
||||||
"features": [], # Add features if needed
|
"features": [], # Add features if needed
|
||||||
} if user.plan else {},
|
}
|
||||||
|
if user.plan
|
||||||
|
else {},
|
||||||
created_at=user.created_at,
|
created_at=user.created_at,
|
||||||
updated_at=user.updated_at,
|
updated_at=user.updated_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def list_users(
|
async def list_users( # noqa: PLR0913
|
||||||
session: Annotated[AsyncSession, Depends(get_db)],
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
limit: int = 100,
|
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||||
offset: int = 0,
|
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||||
) -> list[UserResponse]:
|
search: Annotated[str | None, Query(description="Search in name or email")] = None,
|
||||||
"""Get all users (admin only)."""
|
sort_by: Annotated[
|
||||||
|
UserSortField, Query(description="Sort by field"),
|
||||||
|
] = UserSortField.NAME,
|
||||||
|
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
|
||||||
|
status_filter: Annotated[
|
||||||
|
UserStatus, Query(description="Filter by status"),
|
||||||
|
] = UserStatus.ALL,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Get all users with pagination, search, and filters (admin only)."""
|
||||||
user_repo = UserRepository(session)
|
user_repo = UserRepository(session)
|
||||||
users = await user_repo.get_all_with_plan(limit=limit, offset=offset)
|
users, total_count = await user_repo.get_all_with_plan_paginated(
|
||||||
return [_user_to_response(user) for user in users]
|
page=page,
|
||||||
|
limit=limit,
|
||||||
|
search=search,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
status_filter=status_filter,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||||
|
|
||||||
|
return {
|
||||||
|
"users": [_user_to_response(user) for user in users],
|
||||||
|
"total": total_count,
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"total_pages": total_pages,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}")
|
@router.get("/{user_id}")
|
||||||
|
|||||||
@@ -464,7 +464,8 @@ async def update_profile(
|
|||||||
"""Update the current user's profile."""
|
"""Update the current user's profile."""
|
||||||
try:
|
try:
|
||||||
updated_user = await auth_service.update_user_profile(
|
updated_user = await auth_service.update_user_profile(
|
||||||
current_user, request.model_dump(exclude_unset=True),
|
current_user,
|
||||||
|
request.model_dump(exclude_unset=True),
|
||||||
)
|
)
|
||||||
return await auth_service.user_to_response(updated_user)
|
return await auth_service.user_to_response(updated_user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -486,7 +487,9 @@ async def change_password(
|
|||||||
user_email = current_user.email
|
user_email = current_user.email
|
||||||
try:
|
try:
|
||||||
await auth_service.change_user_password(
|
await auth_service.change_user_password(
|
||||||
current_user, request.current_password, request.new_password,
|
current_user,
|
||||||
|
request.current_password,
|
||||||
|
request.new_password,
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -513,11 +516,13 @@ async def get_user_providers(
|
|||||||
|
|
||||||
# Add password provider if user has password
|
# Add password provider if user has password
|
||||||
if current_user.password_hash:
|
if current_user.password_hash:
|
||||||
providers.append({
|
providers.append(
|
||||||
"provider": "password",
|
{
|
||||||
"display_name": "Password",
|
"provider": "password",
|
||||||
"connected_at": current_user.created_at.isoformat(),
|
"display_name": "Password",
|
||||||
})
|
"connected_at": current_user.created_at.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Get OAuth providers from the database
|
# Get OAuth providers from the database
|
||||||
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
||||||
@@ -528,10 +533,12 @@ async def get_user_providers(
|
|||||||
elif oauth.provider == "google":
|
elif oauth.provider == "google":
|
||||||
display_name = "Google"
|
display_name = "Google"
|
||||||
|
|
||||||
providers.append({
|
providers.append(
|
||||||
"provider": oauth.provider,
|
{
|
||||||
"display_name": display_name,
|
"provider": oauth.provider,
|
||||||
"connected_at": oauth.created_at.isoformat(),
|
"display_name": display_name,
|
||||||
})
|
"connected_at": oauth.created_at.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return providers
|
return providers
|
||||||
|
|||||||
@@ -34,7 +34,8 @@ async def get_top_sounds(
|
|||||||
_current_user: Annotated[User, Depends(get_current_user)],
|
_current_user: Annotated[User, Depends(get_current_user)],
|
||||||
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
|
||||||
sound_type: Annotated[
|
sound_type: Annotated[
|
||||||
str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
|
str,
|
||||||
|
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
|
||||||
],
|
],
|
||||||
period: Annotated[
|
period: Annotated[
|
||||||
str,
|
str,
|
||||||
@@ -43,7 +44,8 @@ async def get_top_sounds(
|
|||||||
),
|
),
|
||||||
] = "all_time",
|
] = "all_time",
|
||||||
limit: Annotated[
|
limit: Annotated[
|
||||||
int, Query(description="Number of top sounds to return", ge=1, le=100),
|
int,
|
||||||
|
Query(description="Number of top sounds to return", ge=1, le=100),
|
||||||
] = 10,
|
] = 10,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Get top sounds by play count for a specific period."""
|
"""Get top sounds by play count for a specific period."""
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
@@ -60,6 +60,46 @@ async def create_extraction(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user")
|
||||||
|
async def get_user_extractions( # noqa: PLR0913
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
|
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||||
|
search: Annotated[
|
||||||
|
str | None, Query(description="Search in title, URL, or service"),
|
||||||
|
] = None,
|
||||||
|
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
|
||||||
|
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
|
||||||
|
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||||
|
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||||
|
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""Get all extractions for the current user."""
|
||||||
|
try:
|
||||||
|
if current_user.id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User ID not available",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extraction_service.get_user_extractions(
|
||||||
|
user_id=current_user.id,
|
||||||
|
search=search,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
status_filter=status_filter,
|
||||||
|
page=page,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to get extractions: {e!s}",
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{extraction_id}")
|
@router.get("/{extraction_id}")
|
||||||
async def get_extraction(
|
async def get_extraction(
|
||||||
extraction_id: int,
|
extraction_id: int,
|
||||||
@@ -88,19 +128,27 @@ async def get_extraction(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def get_user_extractions(
|
async def get_all_extractions( # noqa: PLR0913
|
||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
|
||||||
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||||
) -> dict[str, list[ExtractionInfo]]:
|
search: Annotated[
|
||||||
"""Get all extractions for the current user."""
|
str | None, Query(description="Search in title, URL, or service"),
|
||||||
|
] = None,
|
||||||
|
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
|
||||||
|
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
|
||||||
|
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
|
||||||
|
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||||
|
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||||
|
) -> dict:
|
||||||
|
"""Get all extractions with optional filtering, search, and sorting."""
|
||||||
try:
|
try:
|
||||||
if current_user.id is None:
|
result = await extraction_service.get_all_extractions(
|
||||||
raise HTTPException(
|
search=search,
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
sort_by=sort_by,
|
||||||
detail="User ID not available",
|
sort_order=sort_order,
|
||||||
)
|
status_filter=status_filter,
|
||||||
|
page=page,
|
||||||
extractions = await extraction_service.get_user_extractions(current_user.id)
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -108,6 +156,4 @@ async def get_user_extractions(
|
|||||||
detail=f"Failed to get extractions: {e!s}",
|
detail=f"Failed to get extractions: {e!s}",
|
||||||
) from e
|
) from e
|
||||||
else:
|
else:
|
||||||
return {
|
return result
|
||||||
"extractions": extractions,
|
|
||||||
}
|
|
||||||
|
|||||||
197
app/api/v1/favorites.py
Normal file
197
app/api/v1/favorites.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""Favorites management API endpoints."""
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
|
||||||
|
from app.core.database import get_session_factory
|
||||||
|
from app.core.dependencies import get_current_active_user
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.common import MessageResponse
|
||||||
|
from app.schemas.favorite import (
|
||||||
|
FavoriteCountsResponse,
|
||||||
|
FavoriteResponse,
|
||||||
|
FavoritesListResponse,
|
||||||
|
)
|
||||||
|
from app.services.favorite import FavoriteService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/favorites", tags=["favorites"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_favorite_service() -> FavoriteService:
|
||||||
|
"""Get the favorite service."""
|
||||||
|
return FavoriteService(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
async def get_user_favorites(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
|
offset: Annotated[int, Query(ge=0)] = 0,
|
||||||
|
) -> FavoritesListResponse:
|
||||||
|
"""Get all favorites for the current user."""
|
||||||
|
favorites = await favorite_service.get_user_favorites(
|
||||||
|
current_user.id,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
)
|
||||||
|
return FavoritesListResponse(favorites=favorites)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sounds")
|
||||||
|
async def get_user_sound_favorites(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
|
offset: Annotated[int, Query(ge=0)] = 0,
|
||||||
|
) -> FavoritesListResponse:
|
||||||
|
"""Get sound favorites for the current user."""
|
||||||
|
favorites = await favorite_service.get_user_sound_favorites(
|
||||||
|
current_user.id,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
)
|
||||||
|
return FavoritesListResponse(favorites=favorites)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/playlists")
|
||||||
|
async def get_user_playlist_favorites(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||||
|
offset: Annotated[int, Query(ge=0)] = 0,
|
||||||
|
) -> FavoritesListResponse:
|
||||||
|
"""Get playlist favorites for the current user."""
|
||||||
|
favorites = await favorite_service.get_user_playlist_favorites(
|
||||||
|
current_user.id,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
)
|
||||||
|
return FavoritesListResponse(favorites=favorites)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/counts")
|
||||||
|
async def get_favorite_counts(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
) -> FavoriteCountsResponse:
|
||||||
|
"""Get favorite counts for the current user."""
|
||||||
|
counts = await favorite_service.get_favorite_counts(current_user.id)
|
||||||
|
return FavoriteCountsResponse(**counts)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/sounds/{sound_id}")
|
||||||
|
async def add_sound_favorite(
|
||||||
|
sound_id: int,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
) -> FavoriteResponse:
|
||||||
|
"""Add a sound to favorites."""
|
||||||
|
try:
|
||||||
|
favorite = await favorite_service.add_sound_favorite(current_user.id, sound_id)
|
||||||
|
return FavoriteResponse.model_validate(favorite)
|
||||||
|
except ValueError as e:
|
||||||
|
if "not found" in str(e):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
if "already favorited" in str(e):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/playlists/{playlist_id}")
|
||||||
|
async def add_playlist_favorite(
|
||||||
|
playlist_id: int,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
) -> FavoriteResponse:
|
||||||
|
"""Add a playlist to favorites."""
|
||||||
|
try:
|
||||||
|
favorite = await favorite_service.add_playlist_favorite(
|
||||||
|
current_user.id,
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
return FavoriteResponse.model_validate(favorite)
|
||||||
|
except ValueError as e:
|
||||||
|
if "not found" in str(e):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
if "already favorited" in str(e):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/sounds/{sound_id}")
|
||||||
|
async def remove_sound_favorite(
|
||||||
|
sound_id: int,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
) -> MessageResponse:
|
||||||
|
"""Remove a sound from favorites."""
|
||||||
|
try:
|
||||||
|
await favorite_service.remove_sound_favorite(current_user.id, sound_id)
|
||||||
|
return MessageResponse(message="Sound removed from favorites")
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/playlists/{playlist_id}")
|
||||||
|
async def remove_playlist_favorite(
|
||||||
|
playlist_id: int,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
) -> MessageResponse:
|
||||||
|
"""Remove a playlist from favorites."""
|
||||||
|
try:
|
||||||
|
await favorite_service.remove_playlist_favorite(current_user.id, playlist_id)
|
||||||
|
return MessageResponse(message="Playlist removed from favorites")
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sounds/{sound_id}/check")
|
||||||
|
async def check_sound_favorited(
|
||||||
|
sound_id: int,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Check if a sound is favorited by the current user."""
|
||||||
|
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound_id)
|
||||||
|
return {"is_favorited": is_favorited}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/playlists/{playlist_id}/check")
|
||||||
|
async def check_playlist_favorited(
|
||||||
|
playlist_id: int,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Check if a playlist is favorited by the current user."""
|
||||||
|
is_favorited = await favorite_service.is_playlist_favorited(
|
||||||
|
current_user.id,
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
return {"is_favorited": is_favorited}
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Playlist management API endpoints."""
|
"""Playlist management API endpoints."""
|
||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db, get_session_factory
|
||||||
from app.core.dependencies import get_current_active_user_flexible
|
from app.core.dependencies import get_current_active_user_flexible
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||||
@@ -19,6 +19,7 @@ from app.schemas.playlist import (
|
|||||||
PlaylistStatsResponse,
|
PlaylistStatsResponse,
|
||||||
PlaylistUpdateRequest,
|
PlaylistUpdateRequest,
|
||||||
)
|
)
|
||||||
|
from app.services.favorite import FavoriteService
|
||||||
from app.services.playlist import PlaylistService
|
from app.services.playlist import PlaylistService
|
||||||
|
|
||||||
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
||||||
@@ -31,10 +32,16 @@ async def get_playlist_service(
|
|||||||
return PlaylistService(session)
|
return PlaylistService(session)
|
||||||
|
|
||||||
|
|
||||||
|
def get_favorite_service() -> FavoriteService:
|
||||||
|
"""Get the favorite service."""
|
||||||
|
return FavoriteService(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def get_all_playlists( # noqa: PLR0913
|
async def get_all_playlists( # noqa: PLR0913
|
||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
search: Annotated[
|
search: Annotated[
|
||||||
str | None,
|
str | None,
|
||||||
Query(description="Search playlists by name"),
|
Query(description="Search playlists by name"),
|
||||||
@@ -47,55 +54,115 @@ async def get_all_playlists( # noqa: PLR0913
|
|||||||
SortOrder,
|
SortOrder,
|
||||||
Query(description="Sort order (asc or desc)"),
|
Query(description="Sort order (asc or desc)"),
|
||||||
] = SortOrder.ASC,
|
] = SortOrder.ASC,
|
||||||
limit: Annotated[
|
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
|
||||||
int | None,
|
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
|
||||||
Query(description="Maximum number of results", ge=1, le=1000),
|
favorites_only: Annotated[ # noqa: FBT002
|
||||||
] = None,
|
bool,
|
||||||
offset: Annotated[
|
Query(description="Show only favorited playlists"),
|
||||||
int,
|
] = False,
|
||||||
Query(description="Number of results to skip", ge=0),
|
) -> dict[str, Any]:
|
||||||
] = 0,
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Get all playlists from all users with search and sorting."""
|
"""Get all playlists from all users with search and sorting."""
|
||||||
return await playlist_service.search_and_sort_playlists(
|
result = await playlist_service.search_and_sort_playlists_paginated(
|
||||||
search_query=search,
|
search_query=search,
|
||||||
sort_by=sort_by,
|
sort_by=sort_by,
|
||||||
sort_order=sort_order,
|
sort_order=sort_order,
|
||||||
user_id=None,
|
user_id=None,
|
||||||
include_stats=True,
|
include_stats=True,
|
||||||
|
page=page,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset,
|
favorites_only=favorites_only,
|
||||||
|
current_user_id=current_user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Convert to PlaylistResponse with favorite indicators
|
||||||
|
playlist_responses = []
|
||||||
|
for playlist_dict in result["playlists"]:
|
||||||
|
# The playlist service returns dict, need to create playlist object structure
|
||||||
|
playlist_id = playlist_dict["id"]
|
||||||
|
is_favorited = await favorite_service.is_playlist_favorited(
|
||||||
|
current_user.id, playlist_id,
|
||||||
|
)
|
||||||
|
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
|
||||||
|
|
||||||
|
# Create a PlaylistResponse-like dict with proper datetime conversion
|
||||||
|
playlist_response = {
|
||||||
|
**playlist_dict,
|
||||||
|
"created_at": (
|
||||||
|
playlist_dict["created_at"].isoformat()
|
||||||
|
if playlist_dict["created_at"]
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"updated_at": (
|
||||||
|
playlist_dict["updated_at"].isoformat()
|
||||||
|
if playlist_dict["updated_at"]
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
"is_favorited": is_favorited,
|
||||||
|
"favorite_count": favorite_count,
|
||||||
|
}
|
||||||
|
playlist_responses.append(playlist_response)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"playlists": playlist_responses,
|
||||||
|
"total": result["total"],
|
||||||
|
"page": result["page"],
|
||||||
|
"limit": result["limit"],
|
||||||
|
"total_pages": result["total_pages"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/user")
|
@router.get("/user")
|
||||||
async def get_user_playlists(
|
async def get_user_playlists(
|
||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
) -> list[PlaylistResponse]:
|
) -> list[PlaylistResponse]:
|
||||||
"""Get playlists for the current user only."""
|
"""Get playlists for the current user only."""
|
||||||
playlists = await playlist_service.get_user_playlists(current_user.id)
|
playlists = await playlist_service.get_user_playlists(current_user.id)
|
||||||
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
|
|
||||||
|
# Add favorite indicators for each playlist
|
||||||
|
playlist_responses = []
|
||||||
|
for playlist in playlists:
|
||||||
|
is_favorited = await favorite_service.is_playlist_favorited(
|
||||||
|
current_user.id, playlist.id,
|
||||||
|
)
|
||||||
|
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||||
|
playlist_response = PlaylistResponse.from_playlist(
|
||||||
|
playlist, is_favorited, favorite_count,
|
||||||
|
)
|
||||||
|
playlist_responses.append(playlist_response)
|
||||||
|
|
||||||
|
return playlist_responses
|
||||||
|
|
||||||
|
|
||||||
@router.get("/main")
|
@router.get("/main")
|
||||||
async def get_main_playlist(
|
async def get_main_playlist(
|
||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
) -> PlaylistResponse:
|
) -> PlaylistResponse:
|
||||||
"""Get the global main playlist."""
|
"""Get the global main playlist."""
|
||||||
playlist = await playlist_service.get_main_playlist()
|
playlist = await playlist_service.get_main_playlist()
|
||||||
return PlaylistResponse.from_playlist(playlist)
|
is_favorited = await favorite_service.is_playlist_favorited(
|
||||||
|
current_user.id, playlist.id,
|
||||||
|
)
|
||||||
|
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||||
|
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/current")
|
@router.get("/current")
|
||||||
async def get_current_playlist(
|
async def get_current_playlist(
|
||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
) -> PlaylistResponse:
|
) -> PlaylistResponse:
|
||||||
"""Get the global current playlist (falls back to main playlist)."""
|
"""Get the global current playlist (falls back to main playlist)."""
|
||||||
playlist = await playlist_service.get_current_playlist()
|
playlist = await playlist_service.get_current_playlist()
|
||||||
return PlaylistResponse.from_playlist(playlist)
|
is_favorited = await favorite_service.is_playlist_favorited(
|
||||||
|
current_user.id, playlist.id,
|
||||||
|
)
|
||||||
|
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||||
|
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/")
|
@router.post("/")
|
||||||
@@ -117,12 +184,17 @@ async def create_playlist(
|
|||||||
@router.get("/{playlist_id}")
|
@router.get("/{playlist_id}")
|
||||||
async def get_playlist(
|
async def get_playlist(
|
||||||
playlist_id: int,
|
playlist_id: int,
|
||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
) -> PlaylistResponse:
|
) -> PlaylistResponse:
|
||||||
"""Get a specific playlist."""
|
"""Get a specific playlist."""
|
||||||
playlist = await playlist_service.get_playlist_by_id(playlist_id)
|
playlist = await playlist_service.get_playlist_by_id(playlist_id)
|
||||||
return PlaylistResponse.from_playlist(playlist)
|
is_favorited = await favorite_service.is_playlist_favorited(
|
||||||
|
current_user.id, playlist.id,
|
||||||
|
)
|
||||||
|
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
|
||||||
|
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{playlist_id}")
|
@router.put("/{playlist_id}")
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.database import get_db, get_session_factory
|
from app.core.database import get_db, get_session_factory
|
||||||
from app.core.dependencies import get_current_active_user_flexible
|
from app.core.dependencies import get_current_active_user_flexible
|
||||||
from app.models.credit_action import CreditActionType
|
from app.models.credit_action import CreditActionType
|
||||||
from app.models.sound import Sound
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
|
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
|
||||||
|
from app.schemas.sound import SoundResponse, SoundsListResponse
|
||||||
from app.services.credit import CreditService, InsufficientCreditsError
|
from app.services.credit import CreditService, InsufficientCreditsError
|
||||||
|
from app.services.favorite import FavoriteService
|
||||||
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||||
|
|
||||||
router = APIRouter(prefix="/sounds", tags=["sounds"])
|
router = APIRouter(prefix="/sounds", tags=["sounds"])
|
||||||
@@ -27,6 +28,11 @@ def get_credit_service() -> CreditService:
|
|||||||
return CreditService(get_session_factory())
|
return CreditService(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
|
def get_favorite_service() -> FavoriteService:
|
||||||
|
"""Get the favorite service."""
|
||||||
|
return FavoriteService(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
async def get_sound_repository(
|
async def get_sound_repository(
|
||||||
session: Annotated[AsyncSession, Depends(get_db)],
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
) -> SoundRepository:
|
) -> SoundRepository:
|
||||||
@@ -36,8 +42,9 @@ async def get_sound_repository(
|
|||||||
|
|
||||||
@router.get("/")
|
@router.get("/")
|
||||||
async def get_sounds( # noqa: PLR0913
|
async def get_sounds( # noqa: PLR0913
|
||||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
|
||||||
|
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
|
||||||
types: Annotated[
|
types: Annotated[
|
||||||
list[str] | None,
|
list[str] | None,
|
||||||
Query(description="Filter by sound types (e.g., SDB, TTS, EXT)"),
|
Query(description="Filter by sound types (e.g., SDB, TTS, EXT)"),
|
||||||
@@ -62,7 +69,11 @@ async def get_sounds( # noqa: PLR0913
|
|||||||
int,
|
int,
|
||||||
Query(description="Number of results to skip", ge=0),
|
Query(description="Number of results to skip", ge=0),
|
||||||
] = 0,
|
] = 0,
|
||||||
) -> dict[str, list[Sound]]:
|
favorites_only: Annotated[ # noqa: FBT002
|
||||||
|
bool,
|
||||||
|
Query(description="Show only favorited sounds"),
|
||||||
|
] = False,
|
||||||
|
) -> SoundsListResponse:
|
||||||
"""Get sounds with optional search, filtering, and sorting."""
|
"""Get sounds with optional search, filtering, and sorting."""
|
||||||
try:
|
try:
|
||||||
sounds = await sound_repo.search_and_sort(
|
sounds = await sound_repo.search_and_sort(
|
||||||
@@ -72,14 +83,29 @@ async def get_sounds( # noqa: PLR0913
|
|||||||
sort_order=sort_order,
|
sort_order=sort_order,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
favorites_only=favorites_only,
|
||||||
|
user_id=current_user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add favorite indicators for each sound
|
||||||
|
sound_responses = []
|
||||||
|
for sound in sounds:
|
||||||
|
is_favorited = await favorite_service.is_sound_favorited(
|
||||||
|
current_user.id, sound.id,
|
||||||
|
)
|
||||||
|
favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
|
||||||
|
sound_response = SoundResponse.from_sound(
|
||||||
|
sound, is_favorited, favorite_count,
|
||||||
|
)
|
||||||
|
sound_responses.append(sound_response)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to get sounds: {e!s}",
|
detail=f"Failed to get sounds: {e!s}",
|
||||||
) from e
|
) from e
|
||||||
else:
|
else:
|
||||||
return {"sounds": sounds}
|
return SoundsListResponse(sounds=sound_responses)
|
||||||
|
|
||||||
|
|
||||||
# VLC PLAYER
|
# VLC PLAYER
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# Production URLs (for reverse proxy deployment)
|
# Production URLs (for reverse proxy deployment)
|
||||||
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
|
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
|
||||||
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
|
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
|
||||||
|
|
||||||
# CORS Configuration
|
# CORS Configuration
|
||||||
CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS
|
CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from app.core.logging import get_logger
|
|||||||
from app.core.seeds import seed_all_data
|
from app.core.seeds import seed_all_data
|
||||||
from app.models import ( # noqa: F401
|
from app.models import ( # noqa: F401
|
||||||
extraction,
|
extraction,
|
||||||
|
favorite,
|
||||||
plan,
|
plan,
|
||||||
playlist,
|
playlist,
|
||||||
playlist_sound,
|
playlist_sound,
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ class BaseModel(SQLModel):
|
|||||||
# SQLAlchemy event listener to automatically update updated_at timestamp
|
# SQLAlchemy event listener to automatically update updated_at timestamp
|
||||||
@event.listens_for(BaseModel, "before_update", propagate=True)
|
@event.listens_for(BaseModel, "before_update", propagate=True)
|
||||||
def update_timestamp(
|
def update_timestamp(
|
||||||
mapper: Mapper[Any], connection: Connection, target: BaseModel, # noqa: ARG001
|
mapper: Mapper[Any], # noqa: ARG001
|
||||||
|
connection: Connection, # noqa: ARG001
|
||||||
|
target: BaseModel,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Automatically set updated_at timestamp before update operations."""
|
"""Automatically set updated_at timestamp before update operations."""
|
||||||
target.updated_at = datetime.now(UTC)
|
target.updated_at = datetime.now(UTC)
|
||||||
|
|||||||
29
app/models/favorite.py
Normal file
29
app/models/favorite.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||||
|
|
||||||
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.models.playlist import Playlist
|
||||||
|
from app.models.sound import Sound
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
class Favorite(BaseModel, table=True):
|
||||||
|
"""Database model for user favorites (sounds and playlists)."""
|
||||||
|
|
||||||
|
user_id: int = Field(foreign_key="user.id", nullable=False)
|
||||||
|
sound_id: int | None = Field(foreign_key="sound.id", default=None)
|
||||||
|
playlist_id: int | None = Field(foreign_key="playlist.id", default=None)
|
||||||
|
|
||||||
|
# constraints
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("user_id", "sound_id", name="uq_favorite_user_sound"),
|
||||||
|
UniqueConstraint("user_id", "playlist_id", name="uq_favorite_user_playlist"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# relationships
|
||||||
|
user: "User" = Relationship(back_populates="favorites")
|
||||||
|
sound: "Sound" = Relationship(back_populates="favorites")
|
||||||
|
playlist: "Playlist" = Relationship(back_populates="favorites")
|
||||||
@@ -5,6 +5,7 @@ from sqlmodel import Field, Relationship
|
|||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from app.models.favorite import Favorite
|
||||||
from app.models.playlist_sound import PlaylistSound
|
from app.models.playlist_sound import PlaylistSound
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
@@ -23,3 +24,4 @@ class Playlist(BaseModel, table=True):
|
|||||||
# relationships
|
# relationships
|
||||||
user: "User" = Relationship(back_populates="playlists")
|
user: "User" = Relationship(back_populates="playlists")
|
||||||
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="playlist")
|
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="playlist")
|
||||||
|
favorites: list["Favorite"] = Relationship(back_populates="playlist")
|
||||||
|
|||||||
@@ -35,5 +35,3 @@ class PlaylistSound(BaseModel, table=True):
|
|||||||
# relationships
|
# relationships
|
||||||
playlist: "Playlist" = Relationship(back_populates="playlist_sounds")
|
playlist: "Playlist" = Relationship(back_populates="playlist_sounds")
|
||||||
sound: "Sound" = Relationship(back_populates="playlist_sounds")
|
sound: "Sound" = Relationship(back_populates="playlist_sounds")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from app.models.base import BaseModel
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.models.extraction import Extraction
|
from app.models.extraction import Extraction
|
||||||
|
from app.models.favorite import Favorite
|
||||||
from app.models.playlist_sound import PlaylistSound
|
from app.models.playlist_sound import PlaylistSound
|
||||||
from app.models.sound_played import SoundPlayed
|
from app.models.sound_played import SoundPlayed
|
||||||
|
|
||||||
@@ -36,3 +37,4 @@ class Sound(BaseModel, table=True):
|
|||||||
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")
|
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")
|
||||||
extractions: list["Extraction"] = Relationship(back_populates="sound")
|
extractions: list["Extraction"] = Relationship(back_populates="sound")
|
||||||
play_history: list["SoundPlayed"] = Relationship(back_populates="sound")
|
play_history: list["SoundPlayed"] = Relationship(back_populates="sound")
|
||||||
|
favorites: list["Favorite"] = Relationship(back_populates="sound")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.models.base import BaseModel
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.models.credit_transaction import CreditTransaction
|
from app.models.credit_transaction import CreditTransaction
|
||||||
from app.models.extraction import Extraction
|
from app.models.extraction import Extraction
|
||||||
|
from app.models.favorite import Favorite
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
from app.models.playlist import Playlist
|
from app.models.playlist import Playlist
|
||||||
from app.models.sound_played import SoundPlayed
|
from app.models.sound_played import SoundPlayed
|
||||||
@@ -37,3 +38,4 @@ class User(BaseModel, table=True):
|
|||||||
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
|
sounds_played: list["SoundPlayed"] = Relationship(back_populates="user")
|
||||||
extractions: list["Extraction"] = Relationship(back_populates="user")
|
extractions: list["Extraction"] = Relationship(back_populates="user")
|
||||||
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
|
credit_transactions: list["CreditTransaction"] = Relationship(back_populates="user")
|
||||||
|
favorites: list["Favorite"] = Relationship(back_populates="user")
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
"""Extraction repository for database operations."""
|
"""Extraction repository for database operations."""
|
||||||
|
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import asc, desc, func, or_
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.models.extraction import Extraction
|
from app.models.extraction import Extraction
|
||||||
|
from app.models.user import User
|
||||||
from app.repositories.base import BaseRepository
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
@@ -38,10 +39,11 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
|||||||
)
|
)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
async def get_pending_extractions(self) -> list[Extraction]:
|
async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
|
||||||
"""Get all pending extractions."""
|
"""Get all pending extractions."""
|
||||||
result = await self.session.exec(
|
result = await self.session.exec(
|
||||||
select(Extraction)
|
select(Extraction, User)
|
||||||
|
.join(User, Extraction.user_id == User.id)
|
||||||
.where(Extraction.status == "pending")
|
.where(Extraction.status == "pending")
|
||||||
.order_by(Extraction.created_at),
|
.order_by(Extraction.created_at),
|
||||||
)
|
)
|
||||||
@@ -55,3 +57,100 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
|||||||
.order_by(desc(Extraction.created_at)),
|
.order_by(desc(Extraction.created_at)),
|
||||||
)
|
)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_user_extractions_filtered( # noqa: PLR0913
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
search: str | None = None,
|
||||||
|
sort_by: str = "created_at",
|
||||||
|
sort_order: str = "desc",
|
||||||
|
status_filter: str | None = None,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> tuple[list[tuple[Extraction, User]], int]:
|
||||||
|
"""Get extractions for a user with filtering, search, and sorting."""
|
||||||
|
base_query = (
|
||||||
|
select(Extraction, User)
|
||||||
|
.join(User, Extraction.user_id == User.id)
|
||||||
|
.where(Extraction.user_id == user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply search filter
|
||||||
|
if search:
|
||||||
|
search_pattern = f"%{search}%"
|
||||||
|
base_query = base_query.where(
|
||||||
|
or_(
|
||||||
|
Extraction.title.ilike(search_pattern),
|
||||||
|
Extraction.url.ilike(search_pattern),
|
||||||
|
Extraction.service.ilike(search_pattern),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply status filter
|
||||||
|
if status_filter:
|
||||||
|
base_query = base_query.where(Extraction.status == status_filter)
|
||||||
|
|
||||||
|
# Get total count before pagination
|
||||||
|
count_query = select(func.count()).select_from(
|
||||||
|
base_query.subquery(),
|
||||||
|
)
|
||||||
|
count_result = await self.session.exec(count_query)
|
||||||
|
total_count = count_result.one()
|
||||||
|
|
||||||
|
# Apply sorting and pagination
|
||||||
|
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
|
||||||
|
if sort_order.lower() == "asc":
|
||||||
|
base_query = base_query.order_by(asc(sort_column))
|
||||||
|
else:
|
||||||
|
base_query = base_query.order_by(desc(sort_column))
|
||||||
|
|
||||||
|
paginated_query = base_query.limit(limit).offset(offset)
|
||||||
|
result = await self.session.exec(paginated_query)
|
||||||
|
|
||||||
|
return list(result.all()), total_count
|
||||||
|
|
||||||
|
async def get_all_extractions_filtered( # noqa: PLR0913
|
||||||
|
self,
|
||||||
|
search: str | None = None,
|
||||||
|
sort_by: str = "created_at",
|
||||||
|
sort_order: str = "desc",
|
||||||
|
status_filter: str | None = None,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> tuple[list[tuple[Extraction, User]], int]:
|
||||||
|
"""Get all extractions with filtering, search, and sorting."""
|
||||||
|
base_query = select(Extraction, User).join(User, Extraction.user_id == User.id)
|
||||||
|
|
||||||
|
# Apply search filter
|
||||||
|
if search:
|
||||||
|
search_pattern = f"%{search}%"
|
||||||
|
base_query = base_query.where(
|
||||||
|
or_(
|
||||||
|
Extraction.title.ilike(search_pattern),
|
||||||
|
Extraction.url.ilike(search_pattern),
|
||||||
|
Extraction.service.ilike(search_pattern),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply status filter
|
||||||
|
if status_filter:
|
||||||
|
base_query = base_query.where(Extraction.status == status_filter)
|
||||||
|
|
||||||
|
# Get total count before pagination
|
||||||
|
count_query = select(func.count()).select_from(
|
||||||
|
base_query.subquery(),
|
||||||
|
)
|
||||||
|
count_result = await self.session.exec(count_query)
|
||||||
|
total_count = count_result.one()
|
||||||
|
|
||||||
|
# Apply sorting and pagination
|
||||||
|
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
|
||||||
|
if sort_order.lower() == "asc":
|
||||||
|
base_query = base_query.order_by(asc(sort_column))
|
||||||
|
else:
|
||||||
|
base_query = base_query.order_by(desc(sort_column))
|
||||||
|
|
||||||
|
paginated_query = base_query.limit(limit).offset(offset)
|
||||||
|
result = await self.session.exec(paginated_query)
|
||||||
|
|
||||||
|
return list(result.all()), total_count
|
||||||
|
|||||||
258
app/repositories/favorite.py
Normal file
258
app/repositories/favorite.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
"""Repository for managing favorites."""
|
||||||
|
|
||||||
|
from sqlmodel import and_, select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.favorite import Favorite
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FavoriteRepository(BaseRepository[Favorite]):
|
||||||
|
"""Repository for managing favorites."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize the favorite repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__(Favorite, session)
|
||||||
|
|
||||||
|
async def get_user_favorites(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[Favorite]:
|
||||||
|
"""Get all favorites for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
limit: Maximum number of favorites to return
|
||||||
|
offset: Number of favorites to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user favorites
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = (
|
||||||
|
select(Favorite)
|
||||||
|
.where(Favorite.user_id == user_id)
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
.order_by(Favorite.created_at.desc())
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get favorites for user: %s", user_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_sound_favorites(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[Favorite]:
|
||||||
|
"""Get sound favorites for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
limit: Maximum number of favorites to return
|
||||||
|
offset: Number of favorites to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user sound favorites
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = (
|
||||||
|
select(Favorite)
|
||||||
|
.where(and_(Favorite.user_id == user_id, Favorite.sound_id.isnot(None)))
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
.order_by(Favorite.created_at.desc())
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get sound favorites for user: %s", user_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_user_playlist_favorites(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[Favorite]:
|
||||||
|
"""Get playlist favorites for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
limit: Maximum number of favorites to return
|
||||||
|
offset: Number of favorites to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user playlist favorites
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = (
|
||||||
|
select(Favorite)
|
||||||
|
.where(
|
||||||
|
and_(Favorite.user_id == user_id, Favorite.playlist_id.isnot(None)),
|
||||||
|
)
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
.order_by(Favorite.created_at.desc())
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get playlist favorites for user: %s", user_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_user_and_sound(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
sound_id: int,
|
||||||
|
) -> Favorite | None:
|
||||||
|
"""Get a favorite by user and sound.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
sound_id: The sound ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The favorite if found, None otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = select(Favorite).where(
|
||||||
|
and_(Favorite.user_id == user_id, Favorite.sound_id == sound_id),
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return result.first()
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to get favorite for user %s and sound %s",
|
||||||
|
user_id,
|
||||||
|
sound_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_user_and_playlist(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
playlist_id: int,
|
||||||
|
) -> Favorite | None:
|
||||||
|
"""Get a favorite by user and playlist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
playlist_id: The playlist ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The favorite if found, None otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = select(Favorite).where(
|
||||||
|
and_(Favorite.user_id == user_id, Favorite.playlist_id == playlist_id),
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return result.first()
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to get favorite for user %s and playlist %s",
|
||||||
|
user_id,
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
|
||||||
|
"""Check if a sound is favorited by a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
sound_id: The sound ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the sound is favorited, False otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
favorite = await self.get_by_user_and_sound(user_id, sound_id)
|
||||||
|
return favorite is not None
|
||||||
|
|
||||||
|
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
|
||||||
|
"""Check if a playlist is favorited by a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
playlist_id: The playlist ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the playlist is favorited, False otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
favorite = await self.get_by_user_and_playlist(user_id, playlist_id)
|
||||||
|
return favorite is not None
|
||||||
|
|
||||||
|
async def count_user_favorites(self, user_id: int) -> int:
|
||||||
|
"""Count total favorites for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of favorites
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = select(Favorite).where(Favorite.user_id == user_id)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return len(list(result.all()))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to count favorites for user: %s", user_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def count_sound_favorites(self, sound_id: int) -> int:
|
||||||
|
"""Count how many users have favorited a sound.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sound_id: The sound ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users who favorited this sound
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = select(Favorite).where(Favorite.sound_id == sound_id)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return len(list(result.all()))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to count favorites for sound: %s", sound_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def count_playlist_favorites(self, playlist_id: int) -> int:
|
||||||
|
"""Count how many users have favorited a playlist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
playlist_id: The playlist ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users who favorited this playlist
|
||||||
|
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
statement = select(Favorite).where(Favorite.playlist_id == playlist_id)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return len(list(result.all()))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to count favorites for playlist: %s", playlist_id)
|
||||||
|
raise
|
||||||
@@ -9,6 +9,7 @@ from sqlmodel import select
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
from app.models.favorite import Favorite
|
||||||
from app.models.playlist import Playlist
|
from app.models.playlist import Playlist
|
||||||
from app.models.playlist_sound import PlaylistSound
|
from app.models.playlist_sound import PlaylistSound
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
@@ -56,7 +57,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
# management
|
# management
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to update playlist timestamp for playlist: %s", playlist_id,
|
"Failed to update playlist timestamp for playlist: %s",
|
||||||
|
playlist_id,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -340,7 +342,11 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
include_stats: bool = False, # noqa: FBT001, FBT002
|
include_stats: bool = False, # noqa: FBT001, FBT002
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[dict]:
|
favorites_only: bool = False, # noqa: FBT001, FBT002
|
||||||
|
current_user_id: int | None = None,
|
||||||
|
*,
|
||||||
|
return_count: bool = False,
|
||||||
|
) -> list[dict] | tuple[list[dict], int]:
|
||||||
"""Search and sort playlists with optional statistics."""
|
"""Search and sort playlists with optional statistics."""
|
||||||
try:
|
try:
|
||||||
if include_stats and sort_by in (
|
if include_stats and sort_by in (
|
||||||
@@ -387,6 +393,19 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
subquery = subquery.where(Playlist.user_id == user_id)
|
subquery = subquery.where(Playlist.user_id == user_id)
|
||||||
|
|
||||||
|
# Apply favorites filter
|
||||||
|
if favorites_only and current_user_id is not None:
|
||||||
|
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
|
||||||
|
favorites_subquery = (
|
||||||
|
select(1)
|
||||||
|
.select_from(Favorite)
|
||||||
|
.where(
|
||||||
|
Favorite.user_id == current_user_id,
|
||||||
|
Favorite.playlist_id == Playlist.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
subquery = subquery.where(favorites_subquery.exists())
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
if sort_by == PlaylistSortField.SOUND_COUNT:
|
if sort_by == PlaylistSortField.SOUND_COUNT:
|
||||||
if sort_order == SortOrder.DESC:
|
if sort_order == SortOrder.DESC:
|
||||||
@@ -449,6 +468,19 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
subquery = subquery.where(Playlist.user_id == user_id)
|
subquery = subquery.where(Playlist.user_id == user_id)
|
||||||
|
|
||||||
|
# Apply favorites filter
|
||||||
|
if favorites_only and current_user_id is not None:
|
||||||
|
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
|
||||||
|
favorites_subquery = (
|
||||||
|
select(1)
|
||||||
|
.select_from(Favorite)
|
||||||
|
.where(
|
||||||
|
Favorite.user_id == current_user_id,
|
||||||
|
Favorite.playlist_id == Playlist.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
subquery = subquery.where(favorites_subquery.exists())
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
if sort_by:
|
if sort_by:
|
||||||
if sort_by == PlaylistSortField.NAME:
|
if sort_by == PlaylistSortField.NAME:
|
||||||
@@ -470,6 +502,14 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
# Default sorting by name ascending
|
# Default sorting by name ascending
|
||||||
subquery = subquery.order_by(Playlist.name.asc())
|
subquery = subquery.order_by(Playlist.name.asc())
|
||||||
|
|
||||||
|
# Get total count if requested
|
||||||
|
total_count = 0
|
||||||
|
if return_count:
|
||||||
|
# Create count query from the subquery before pagination
|
||||||
|
count_query = select(func.count()).select_from(subquery.subquery())
|
||||||
|
count_result = await self.session.exec(count_query)
|
||||||
|
total_count = count_result.one()
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
if offset > 0:
|
if offset > 0:
|
||||||
subquery = subquery.offset(offset)
|
subquery = subquery.offset(offset)
|
||||||
@@ -511,4 +551,6 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
|
if return_count:
|
||||||
|
return playlists, total_count
|
||||||
return playlists
|
return playlists
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from sqlmodel import col, select
|
|||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
from app.models.favorite import Favorite
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
from app.models.sound_played import SoundPlayed
|
from app.models.sound_played import SoundPlayed
|
||||||
from app.repositories.base import BaseRepository
|
from app.repositories.base import BaseRepository
|
||||||
@@ -140,11 +141,20 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
sort_order: SortOrder = SortOrder.ASC,
|
sort_order: SortOrder = SortOrder.ASC,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
|
favorites_only: bool = False, # noqa: FBT001, FBT002
|
||||||
|
user_id: int | None = None,
|
||||||
) -> list[Sound]:
|
) -> list[Sound]:
|
||||||
"""Search and sort sounds with optional filtering."""
|
"""Search and sort sounds with optional filtering."""
|
||||||
try:
|
try:
|
||||||
statement = select(Sound)
|
statement = select(Sound)
|
||||||
|
|
||||||
|
# Apply favorites filter
|
||||||
|
if favorites_only and user_id is not None:
|
||||||
|
statement = statement.join(Favorite).where(
|
||||||
|
Favorite.user_id == user_id,
|
||||||
|
Favorite.sound_id == Sound.id,
|
||||||
|
)
|
||||||
|
|
||||||
# Apply type filter
|
# Apply type filter
|
||||||
if sound_types:
|
if sound_types:
|
||||||
statement = statement.where(col(Sound.type).in_(sound_types))
|
statement = statement.where(col(Sound.type).in_(sound_types))
|
||||||
@@ -179,12 +189,15 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
logger.exception(
|
logger.exception(
|
||||||
(
|
(
|
||||||
"Failed to search and sort sounds: "
|
"Failed to search and sort sounds: "
|
||||||
"query=%s, types=%s, sort_by=%s, sort_order=%s"
|
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
|
||||||
|
"user_id=%s"
|
||||||
),
|
),
|
||||||
search_query,
|
search_query,
|
||||||
sound_types,
|
sound_types,
|
||||||
sort_by,
|
sort_by,
|
||||||
sort_order,
|
sort_order,
|
||||||
|
favorites_only,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -276,8 +289,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
|
|
||||||
# Group by sound and order by play count descending
|
# Group by sound and order by play count descending
|
||||||
statement = (
|
statement = (
|
||||||
statement
|
statement.group_by(
|
||||||
.group_by(
|
|
||||||
Sound.id,
|
Sound.id,
|
||||||
Sound.name,
|
Sound.name,
|
||||||
Sound.type,
|
Sound.type,
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""User repository."""
|
"""User repository."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -14,6 +16,31 @@ from app.repositories.base import BaseRepository
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserSortField(str, Enum):
|
||||||
|
"""User sort fields."""
|
||||||
|
|
||||||
|
NAME = "name"
|
||||||
|
EMAIL = "email"
|
||||||
|
ROLE = "role"
|
||||||
|
CREDITS = "credits"
|
||||||
|
CREATED_AT = "created_at"
|
||||||
|
|
||||||
|
|
||||||
|
class SortOrder(str, Enum):
|
||||||
|
"""Sort order."""
|
||||||
|
|
||||||
|
ASC = "asc"
|
||||||
|
DESC = "desc"
|
||||||
|
|
||||||
|
|
||||||
|
class UserStatus(str, Enum):
|
||||||
|
"""User status filter."""
|
||||||
|
|
||||||
|
ALL = "all"
|
||||||
|
ACTIVE = "active"
|
||||||
|
INACTIVE = "inactive"
|
||||||
|
|
||||||
|
|
||||||
class UserRepository(BaseRepository[User]):
|
class UserRepository(BaseRepository[User]):
|
||||||
"""Repository for user operations."""
|
"""Repository for user operations."""
|
||||||
|
|
||||||
@@ -40,6 +67,69 @@ class UserRepository(BaseRepository[User]):
|
|||||||
logger.exception("Failed to get all users with plan")
|
logger.exception("Failed to get all users with plan")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_all_with_plan_paginated( # noqa: PLR0913
|
||||||
|
self,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
search: str | None = None,
|
||||||
|
sort_by: UserSortField = UserSortField.NAME,
|
||||||
|
sort_order: SortOrder = SortOrder.ASC,
|
||||||
|
status_filter: UserStatus = UserStatus.ALL,
|
||||||
|
) -> tuple[list[User], int]:
|
||||||
|
"""Get all users with plan relationship loaded and return total count."""
|
||||||
|
try:
|
||||||
|
# Calculate offset
|
||||||
|
offset = (page - 1) * limit
|
||||||
|
|
||||||
|
# Build base query
|
||||||
|
base_query = select(User).options(selectinload(User.plan))
|
||||||
|
count_query = select(func.count(User.id))
|
||||||
|
|
||||||
|
# Apply search filter
|
||||||
|
if search and search.strip():
|
||||||
|
search_pattern = f"%{search.strip().lower()}%"
|
||||||
|
search_condition = func.lower(User.name).like(
|
||||||
|
search_pattern,
|
||||||
|
) | func.lower(User.email).like(search_pattern)
|
||||||
|
base_query = base_query.where(search_condition)
|
||||||
|
count_query = count_query.where(search_condition)
|
||||||
|
|
||||||
|
# Apply status filter
|
||||||
|
if status_filter == UserStatus.ACTIVE:
|
||||||
|
base_query = base_query.where(User.is_active == True) # noqa: E712
|
||||||
|
count_query = count_query.where(User.is_active == True) # noqa: E712
|
||||||
|
elif status_filter == UserStatus.INACTIVE:
|
||||||
|
base_query = base_query.where(User.is_active == False) # noqa: E712
|
||||||
|
count_query = count_query.where(User.is_active == False) # noqa: E712
|
||||||
|
|
||||||
|
# Apply sorting
|
||||||
|
sort_column = {
|
||||||
|
UserSortField.NAME: User.name,
|
||||||
|
UserSortField.EMAIL: User.email,
|
||||||
|
UserSortField.ROLE: User.role,
|
||||||
|
UserSortField.CREDITS: User.credits,
|
||||||
|
UserSortField.CREATED_AT: User.created_at,
|
||||||
|
}.get(sort_by, User.name)
|
||||||
|
|
||||||
|
if sort_order == SortOrder.DESC:
|
||||||
|
base_query = base_query.order_by(sort_column.desc())
|
||||||
|
else:
|
||||||
|
base_query = base_query.order_by(sort_column.asc())
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_result = await self.session.exec(count_query)
|
||||||
|
total_count = count_result.one()
|
||||||
|
|
||||||
|
# Apply pagination and get results
|
||||||
|
paginated_query = base_query.limit(limit).offset(offset)
|
||||||
|
result = await self.session.exec(paginated_query)
|
||||||
|
users = list(result.all())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get paginated users with plan")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
return users, total_count
|
||||||
|
|
||||||
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
|
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
|
||||||
"""Get a user by ID with plan relationship loaded."""
|
"""Get a user by ID with plan relationship loaded."""
|
||||||
try:
|
try:
|
||||||
@@ -77,7 +167,7 @@ class UserRepository(BaseRepository[User]):
|
|||||||
logger.exception("Failed to get user by API token")
|
logger.exception("Failed to get user by API token")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def create(self, user_data: dict[str, Any]) -> User:
|
async def create(self, entity_data: dict[str, Any]) -> User:
|
||||||
"""Create a new user with plan assignment and first user admin logic."""
|
"""Create a new user with plan assignment and first user admin logic."""
|
||||||
|
|
||||||
def _raise_plan_not_found() -> None:
|
def _raise_plan_not_found() -> None:
|
||||||
@@ -93,7 +183,7 @@ class UserRepository(BaseRepository[User]):
|
|||||||
if is_first_user:
|
if is_first_user:
|
||||||
# First user gets admin role and pro plan
|
# First user gets admin role and pro plan
|
||||||
plan_statement = select(Plan).where(Plan.code == "pro")
|
plan_statement = select(Plan).where(Plan.code == "pro")
|
||||||
user_data["role"] = "admin"
|
entity_data["role"] = "admin"
|
||||||
logger.info("Creating first user with admin role and pro plan")
|
logger.info("Creating first user with admin role and pro plan")
|
||||||
else:
|
else:
|
||||||
# Regular users get free plan
|
# Regular users get free plan
|
||||||
@@ -109,11 +199,11 @@ class UserRepository(BaseRepository[User]):
|
|||||||
assert default_plan is not None # noqa: S101
|
assert default_plan is not None # noqa: S101
|
||||||
|
|
||||||
# Set plan_id and default credits
|
# Set plan_id and default credits
|
||||||
user_data["plan_id"] = default_plan.id
|
entity_data["plan_id"] = default_plan.id
|
||||||
user_data["credits"] = default_plan.credits
|
entity_data["credits"] = default_plan.credits
|
||||||
|
|
||||||
# Use BaseRepository's create method
|
# Use BaseRepository's create method
|
||||||
return await super().create(user_data)
|
return await super().create(entity_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to create user")
|
logger.exception("Failed to create user")
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -85,7 +85,8 @@ class ChangePasswordRequest(BaseModel):
|
|||||||
"""Schema for password change request."""
|
"""Schema for password change request."""
|
||||||
|
|
||||||
current_password: str | None = Field(
|
current_password: str | None = Field(
|
||||||
None, description="Current password (required if user has existing password)",
|
None,
|
||||||
|
description="Current password (required if user has existing password)",
|
||||||
)
|
)
|
||||||
new_password: str = Field(
|
new_password: str = Field(
|
||||||
...,
|
...,
|
||||||
@@ -98,5 +99,8 @@ class UpdateProfileRequest(BaseModel):
|
|||||||
"""Schema for profile update request."""
|
"""Schema for profile update request."""
|
||||||
|
|
||||||
name: str | None = Field(
|
name: str | None = Field(
|
||||||
None, min_length=1, max_length=100, description="User display name",
|
None,
|
||||||
|
min_length=1,
|
||||||
|
max_length=100,
|
||||||
|
description="User display name",
|
||||||
)
|
)
|
||||||
|
|||||||
41
app/schemas/favorite.py
Normal file
41
app/schemas/favorite.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Favorite response schemas."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class FavoriteResponse(BaseModel):
|
||||||
|
"""Response schema for a favorite."""
|
||||||
|
|
||||||
|
id: int = Field(description="Favorite ID")
|
||||||
|
user_id: int = Field(description="User ID")
|
||||||
|
sound_id: int | None = Field(
|
||||||
|
description="Sound ID if this is a sound favorite",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
playlist_id: int | None = Field(
|
||||||
|
description="Playlist ID if this is a playlist favorite",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
created_at: datetime = Field(description="Creation timestamp")
|
||||||
|
updated_at: datetime = Field(description="Last update timestamp")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Pydantic config."""
|
||||||
|
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class FavoritesListResponse(BaseModel):
|
||||||
|
"""Response schema for a list of favorites."""
|
||||||
|
|
||||||
|
favorites: list[FavoriteResponse] = Field(description="List of favorites")
|
||||||
|
|
||||||
|
|
||||||
|
class FavoriteCountsResponse(BaseModel):
|
||||||
|
"""Response schema for favorite counts."""
|
||||||
|
|
||||||
|
total: int = Field(description="Total number of favorites")
|
||||||
|
sounds: int = Field(description="Number of favorited sounds")
|
||||||
|
playlists: int = Field(description="Number of favorited playlists")
|
||||||
@@ -33,12 +33,29 @@ class PlaylistResponse(BaseModel):
|
|||||||
is_main: bool
|
is_main: bool
|
||||||
is_current: bool
|
is_current: bool
|
||||||
is_deletable: bool
|
is_deletable: bool
|
||||||
|
is_favorited: bool = False
|
||||||
|
favorite_count: int = 0
|
||||||
created_at: str
|
created_at: str
|
||||||
updated_at: str | None
|
updated_at: str | None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_playlist(cls, playlist: Playlist) -> "PlaylistResponse":
|
def from_playlist(
|
||||||
"""Create response from playlist model."""
|
cls,
|
||||||
|
playlist: Playlist,
|
||||||
|
is_favorited: bool = False, # noqa: FBT001, FBT002
|
||||||
|
favorite_count: int = 0,
|
||||||
|
) -> "PlaylistResponse":
|
||||||
|
"""Create response from playlist model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
playlist: The Playlist model
|
||||||
|
is_favorited: Whether the playlist is favorited by the current user
|
||||||
|
favorite_count: Number of users who favorited this playlist
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PlaylistResponse instance
|
||||||
|
|
||||||
|
"""
|
||||||
if playlist.id is None:
|
if playlist.id is None:
|
||||||
msg = "Playlist ID cannot be None"
|
msg = "Playlist ID cannot be None"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
@@ -50,6 +67,8 @@ class PlaylistResponse(BaseModel):
|
|||||||
is_main=playlist.is_main,
|
is_main=playlist.is_main,
|
||||||
is_current=playlist.is_current,
|
is_current=playlist.is_current,
|
||||||
is_deletable=playlist.is_deletable,
|
is_deletable=playlist.is_deletable,
|
||||||
|
is_favorited=is_favorited,
|
||||||
|
favorite_count=favorite_count,
|
||||||
created_at=playlist.created_at.isoformat(),
|
created_at=playlist.created_at.isoformat(),
|
||||||
updated_at=playlist.updated_at.isoformat() if playlist.updated_at else None,
|
updated_at=playlist.updated_at.isoformat() if playlist.updated_at else None,
|
||||||
)
|
)
|
||||||
|
|||||||
106
app/schemas/sound.py
Normal file
106
app/schemas/sound.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""Sound response schemas."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.sound import Sound
|
||||||
|
|
||||||
|
|
||||||
|
class SoundResponse(BaseModel):
|
||||||
|
"""Response schema for a sound with favorite indicator."""
|
||||||
|
|
||||||
|
id: int = Field(description="Sound ID")
|
||||||
|
type: str = Field(description="Sound type")
|
||||||
|
name: str = Field(description="Sound name")
|
||||||
|
filename: str = Field(description="Sound filename")
|
||||||
|
duration: int = Field(description="Duration in milliseconds")
|
||||||
|
size: int = Field(description="File size in bytes")
|
||||||
|
hash: str = Field(description="File hash")
|
||||||
|
normalized_filename: str | None = Field(
|
||||||
|
description="Normalized filename",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
normalized_duration: int | None = Field(
|
||||||
|
description="Normalized duration in milliseconds",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
normalized_size: int | None = Field(
|
||||||
|
description="Normalized file size in bytes",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
normalized_hash: str | None = Field(
|
||||||
|
description="Normalized file hash",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
thumbnail: str | None = Field(description="Thumbnail filename", default=None)
|
||||||
|
play_count: int = Field(description="Number of times played")
|
||||||
|
is_normalized: bool = Field(description="Whether the sound is normalized")
|
||||||
|
is_music: bool = Field(description="Whether the sound is music")
|
||||||
|
is_deletable: bool = Field(description="Whether the sound can be deleted")
|
||||||
|
is_favorited: bool = Field(
|
||||||
|
description="Whether the sound is favorited by the current user",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
favorite_count: int = Field(
|
||||||
|
description="Number of users who favorited this sound",
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
created_at: datetime = Field(description="Creation timestamp")
|
||||||
|
updated_at: datetime = Field(description="Last update timestamp")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Pydantic config."""
|
||||||
|
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_sound(
|
||||||
|
cls,
|
||||||
|
sound: Sound,
|
||||||
|
is_favorited: bool = False, # noqa: FBT001, FBT002
|
||||||
|
favorite_count: int = 0,
|
||||||
|
) -> "SoundResponse":
|
||||||
|
"""Create a SoundResponse from a Sound model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sound: The Sound model
|
||||||
|
is_favorited: Whether the sound is favorited by the current user
|
||||||
|
favorite_count: Number of users who favorited this sound
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SoundResponse instance
|
||||||
|
|
||||||
|
"""
|
||||||
|
if sound.id is None:
|
||||||
|
msg = "Sound ID cannot be None"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=sound.id,
|
||||||
|
type=sound.type,
|
||||||
|
name=sound.name,
|
||||||
|
filename=sound.filename,
|
||||||
|
duration=sound.duration,
|
||||||
|
size=sound.size,
|
||||||
|
hash=sound.hash,
|
||||||
|
normalized_filename=sound.normalized_filename,
|
||||||
|
normalized_duration=sound.normalized_duration,
|
||||||
|
normalized_size=sound.normalized_size,
|
||||||
|
normalized_hash=sound.normalized_hash,
|
||||||
|
thumbnail=sound.thumbnail,
|
||||||
|
play_count=sound.play_count,
|
||||||
|
is_normalized=sound.is_normalized,
|
||||||
|
is_music=sound.is_music,
|
||||||
|
is_deletable=sound.is_deletable,
|
||||||
|
is_favorited=is_favorited,
|
||||||
|
favorite_count=favorite_count,
|
||||||
|
created_at=sound.created_at,
|
||||||
|
updated_at=sound.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SoundsListResponse(BaseModel):
|
||||||
|
"""Response schema for a list of sounds."""
|
||||||
|
|
||||||
|
sounds: list[SoundResponse] = Field(description="List of sounds")
|
||||||
@@ -7,7 +7,10 @@ class UserUpdate(BaseModel):
|
|||||||
"""Schema for updating a user."""
|
"""Schema for updating a user."""
|
||||||
|
|
||||||
name: str | None = Field(
|
name: str | None = Field(
|
||||||
None, min_length=1, max_length=100, description="User full name",
|
None,
|
||||||
|
min_length=1,
|
||||||
|
max_length=100,
|
||||||
|
description="User full name",
|
||||||
)
|
)
|
||||||
plan_id: int | None = Field(None, description="User plan ID")
|
plan_id: int | None = Field(None, description="User plan ID")
|
||||||
credits: int | None = Field(None, ge=0, description="User credits")
|
credits: int | None = Field(None, ge=0, description="User credits")
|
||||||
|
|||||||
@@ -454,7 +454,10 @@ class AuthService:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
async def change_user_password(
|
async def change_user_password(
|
||||||
self, user: User, current_password: str | None, new_password: str,
|
self,
|
||||||
|
user: User,
|
||||||
|
current_password: str | None,
|
||||||
|
new_password: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Change user's password."""
|
"""Change user's password."""
|
||||||
# Store user email before any operations to avoid session detachment issues
|
# Store user email before any operations to avoid session detachment issues
|
||||||
@@ -484,8 +487,11 @@ class AuthService:
|
|||||||
self.session.add(user)
|
self.session.add(user)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
|
|
||||||
logger.info("Password %s successfully for user: %s",
|
logger.info(
|
||||||
"changed" if had_existing_password else "set", user_email)
|
"Password %s successfully for user: %s",
|
||||||
|
"changed" if had_existing_password else "set",
|
||||||
|
user_email,
|
||||||
|
)
|
||||||
|
|
||||||
async def user_to_response(self, user: User) -> UserResponse:
|
async def user_to_response(self, user: User) -> UserResponse:
|
||||||
"""Convert User model to UserResponse with plan information."""
|
"""Convert User model to UserResponse with plan information."""
|
||||||
|
|||||||
@@ -72,9 +72,7 @@ class DashboardService:
|
|||||||
"play_count": sound["play_count"],
|
"play_count": sound["play_count"],
|
||||||
"duration": sound["duration"],
|
"duration": sound["duration"],
|
||||||
"created_at": (
|
"created_at": (
|
||||||
sound["created_at"].isoformat()
|
sound["created_at"].isoformat() if sound["created_at"] else None
|
||||||
if sound["created_at"]
|
|
||||||
else None
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
for sound in top_sounds
|
for sound in top_sounds
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from app.core.logging import get_logger
|
|||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
from app.repositories.extraction import ExtractionRepository
|
from app.repositories.extraction import ExtractionRepository
|
||||||
from app.repositories.sound import SoundRepository
|
from app.repositories.sound import SoundRepository
|
||||||
|
from app.repositories.user import UserRepository
|
||||||
from app.services.playlist import PlaylistService
|
from app.services.playlist import PlaylistService
|
||||||
from app.services.sound_normalizer import SoundNormalizerService
|
from app.services.sound_normalizer import SoundNormalizerService
|
||||||
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||||
@@ -32,10 +33,21 @@ class ExtractionInfo(TypedDict):
|
|||||||
error: str | None
|
error: str | None
|
||||||
sound_id: int | None
|
sound_id: int | None
|
||||||
user_id: int
|
user_id: int
|
||||||
|
user_name: str | None
|
||||||
created_at: str
|
created_at: str
|
||||||
updated_at: str
|
updated_at: str
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedExtractionsResponse(TypedDict):
|
||||||
|
"""Type definition for paginated extractions response."""
|
||||||
|
|
||||||
|
extractions: list[ExtractionInfo]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
total_pages: int
|
||||||
|
|
||||||
|
|
||||||
class ExtractionService:
|
class ExtractionService:
|
||||||
"""Service for extracting audio from external services using yt-dlp."""
|
"""Service for extracting audio from external services using yt-dlp."""
|
||||||
|
|
||||||
@@ -44,6 +56,7 @@ class ExtractionService:
|
|||||||
self.session = session
|
self.session = session
|
||||||
self.extraction_repo = ExtractionRepository(session)
|
self.extraction_repo = ExtractionRepository(session)
|
||||||
self.sound_repo = SoundRepository(session)
|
self.sound_repo = SoundRepository(session)
|
||||||
|
self.user_repo = UserRepository(session)
|
||||||
self.playlist_service = PlaylistService(session)
|
self.playlist_service = PlaylistService(session)
|
||||||
|
|
||||||
# Ensure required directories exist
|
# Ensure required directories exist
|
||||||
@@ -66,6 +79,15 @@ class ExtractionService:
|
|||||||
logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)
|
logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Get user information
|
||||||
|
user = await self.user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Extract user name immediately while in session context
|
||||||
|
user_name = user.name
|
||||||
|
|
||||||
# Create the extraction record without service detection for fast response
|
# Create the extraction record without service detection for fast response
|
||||||
extraction_data = {
|
extraction_data = {
|
||||||
"url": url,
|
"url": url,
|
||||||
@@ -92,6 +114,7 @@ class ExtractionService:
|
|||||||
"error": extraction.error,
|
"error": extraction.error,
|
||||||
"sound_id": extraction.sound_id,
|
"sound_id": extraction.sound_id,
|
||||||
"user_id": extraction.user_id,
|
"user_id": extraction.user_id,
|
||||||
|
"user_name": user_name,
|
||||||
"created_at": extraction.created_at.isoformat(),
|
"created_at": extraction.created_at.isoformat(),
|
||||||
"updated_at": extraction.updated_at.isoformat(),
|
"updated_at": extraction.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
@@ -509,7 +532,8 @@ class ExtractionService:
|
|||||||
"""Add the sound to the user's main playlist."""
|
"""Add the sound to the user's main playlist."""
|
||||||
try:
|
try:
|
||||||
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
|
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
|
||||||
sound_id, user_id,
|
sound_id,
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Added sound %d to main playlist for user %d",
|
"Added sound %d to main playlist for user %d",
|
||||||
@@ -531,6 +555,10 @@ class ExtractionService:
|
|||||||
if not extraction:
|
if not extraction:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Get user information
|
||||||
|
user = await self.user_repo.get_by_id(extraction.user_id)
|
||||||
|
user_name = user.name if user else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": extraction.id or 0, # Should never be None for existing extraction
|
"id": extraction.id or 0, # Should never be None for existing extraction
|
||||||
"url": extraction.url,
|
"url": extraction.url,
|
||||||
@@ -541,15 +569,37 @@ class ExtractionService:
|
|||||||
"error": extraction.error,
|
"error": extraction.error,
|
||||||
"sound_id": extraction.sound_id,
|
"sound_id": extraction.sound_id,
|
||||||
"user_id": extraction.user_id,
|
"user_id": extraction.user_id,
|
||||||
|
"user_name": user_name,
|
||||||
"created_at": extraction.created_at.isoformat(),
|
"created_at": extraction.created_at.isoformat(),
|
||||||
"updated_at": extraction.updated_at.isoformat(),
|
"updated_at": extraction.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_user_extractions(self, user_id: int) -> list[ExtractionInfo]:
|
async def get_user_extractions( # noqa: PLR0913
|
||||||
"""Get all extractions for a user."""
|
self,
|
||||||
extractions = await self.extraction_repo.get_by_user(user_id)
|
user_id: int,
|
||||||
|
search: str | None = None,
|
||||||
|
sort_by: str = "created_at",
|
||||||
|
sort_order: str = "desc",
|
||||||
|
status_filter: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> PaginatedExtractionsResponse:
|
||||||
|
"""Get all extractions for a user with filtering, search, and sorting."""
|
||||||
|
offset = (page - 1) * limit
|
||||||
|
(
|
||||||
|
extraction_user_tuples,
|
||||||
|
total_count,
|
||||||
|
) = await self.extraction_repo.get_user_extractions_filtered(
|
||||||
|
user_id=user_id,
|
||||||
|
search=search,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
status_filter=status_filter,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
return [
|
extractions = [
|
||||||
{
|
{
|
||||||
"id": extraction.id
|
"id": extraction.id
|
||||||
or 0, # Should never be None for existing extraction
|
or 0, # Should never be None for existing extraction
|
||||||
@@ -561,15 +611,78 @@ class ExtractionService:
|
|||||||
"error": extraction.error,
|
"error": extraction.error,
|
||||||
"sound_id": extraction.sound_id,
|
"sound_id": extraction.sound_id,
|
||||||
"user_id": extraction.user_id,
|
"user_id": extraction.user_id,
|
||||||
|
"user_name": user.name,
|
||||||
"created_at": extraction.created_at.isoformat(),
|
"created_at": extraction.created_at.isoformat(),
|
||||||
"updated_at": extraction.updated_at.isoformat(),
|
"updated_at": extraction.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
for extraction in extractions
|
for extraction, user in extraction_user_tuples
|
||||||
]
|
]
|
||||||
|
|
||||||
|
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||||
|
|
||||||
|
return {
|
||||||
|
"extractions": extractions,
|
||||||
|
"total": total_count,
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"total_pages": total_pages,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_all_extractions( # noqa: PLR0913
|
||||||
|
self,
|
||||||
|
search: str | None = None,
|
||||||
|
sort_by: str = "created_at",
|
||||||
|
sort_order: str = "desc",
|
||||||
|
status_filter: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> PaginatedExtractionsResponse:
|
||||||
|
"""Get all extractions with filtering, search, and sorting."""
|
||||||
|
offset = (page - 1) * limit
|
||||||
|
(
|
||||||
|
extraction_user_tuples,
|
||||||
|
total_count,
|
||||||
|
) = await self.extraction_repo.get_all_extractions_filtered(
|
||||||
|
search=search,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
status_filter=status_filter,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
extractions = [
|
||||||
|
{
|
||||||
|
"id": extraction.id
|
||||||
|
or 0, # Should never be None for existing extraction
|
||||||
|
"url": extraction.url,
|
||||||
|
"service": extraction.service,
|
||||||
|
"service_id": extraction.service_id,
|
||||||
|
"title": extraction.title,
|
||||||
|
"status": extraction.status,
|
||||||
|
"error": extraction.error,
|
||||||
|
"sound_id": extraction.sound_id,
|
||||||
|
"user_id": extraction.user_id,
|
||||||
|
"user_name": user.name,
|
||||||
|
"created_at": extraction.created_at.isoformat(),
|
||||||
|
"updated_at": extraction.updated_at.isoformat(),
|
||||||
|
}
|
||||||
|
for extraction, user in extraction_user_tuples
|
||||||
|
]
|
||||||
|
|
||||||
|
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||||
|
|
||||||
|
return {
|
||||||
|
"extractions": extractions,
|
||||||
|
"total": total_count,
|
||||||
|
"page": page,
|
||||||
|
"limit": limit,
|
||||||
|
"total_pages": total_pages,
|
||||||
|
}
|
||||||
|
|
||||||
async def get_pending_extractions(self) -> list[ExtractionInfo]:
|
async def get_pending_extractions(self) -> list[ExtractionInfo]:
|
||||||
"""Get all pending extractions."""
|
"""Get all pending extractions."""
|
||||||
extractions = await self.extraction_repo.get_pending_extractions()
|
extraction_user_tuples = await self.extraction_repo.get_pending_extractions()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
@@ -583,8 +696,9 @@ class ExtractionService:
|
|||||||
"error": extraction.error,
|
"error": extraction.error,
|
||||||
"sound_id": extraction.sound_id,
|
"sound_id": extraction.sound_id,
|
||||||
"user_id": extraction.user_id,
|
"user_id": extraction.user_id,
|
||||||
|
"user_name": user.name,
|
||||||
"created_at": extraction.created_at.isoformat(),
|
"created_at": extraction.created_at.isoformat(),
|
||||||
"updated_at": extraction.updated_at.isoformat(),
|
"updated_at": extraction.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
for extraction in extractions
|
for extraction, user in extraction_user_tuples
|
||||||
]
|
]
|
||||||
|
|||||||
382
app/services/favorite.py
Normal file
382
app/services/favorite.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
"""Service for managing user favorites."""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.favorite import Favorite
|
||||||
|
from app.repositories.favorite import FavoriteRepository
|
||||||
|
from app.repositories.playlist import PlaylistRepository
|
||||||
|
from app.repositories.sound import SoundRepository
|
||||||
|
from app.repositories.user import UserRepository
|
||||||
|
from app.services.socket import socket_manager
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FavoriteService:
|
||||||
|
"""Service for managing user favorites."""
|
||||||
|
|
||||||
|
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
||||||
|
"""Initialize the favorite service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_session_factory: Factory function to create database sessions
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.db_session_factory = db_session_factory
|
||||||
|
|
||||||
|
async def add_sound_favorite(self, user_id: int, sound_id: int) -> Favorite:
|
||||||
|
"""Add a sound to user's favorites.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
sound_id: The sound ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created favorite
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user or sound not found, or already favorited
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
sound_repo = SoundRepository(session)
|
||||||
|
|
||||||
|
# Verify user exists
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User with ID {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Verify sound exists
|
||||||
|
sound = await sound_repo.get_by_id(sound_id)
|
||||||
|
if not sound:
|
||||||
|
msg = f"Sound with ID {sound_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Get data for the event immediately after loading
|
||||||
|
sound_name = sound.name
|
||||||
|
user_name = user.name
|
||||||
|
|
||||||
|
# Check if already favorited
|
||||||
|
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
|
||||||
|
if existing:
|
||||||
|
msg = f"Sound {sound_id} is already favorited by user {user_id}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Create favorite
|
||||||
|
favorite_data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"sound_id": sound_id,
|
||||||
|
"playlist_id": None,
|
||||||
|
}
|
||||||
|
favorite = await favorite_repo.create(favorite_data)
|
||||||
|
logger.info("User %s favorited sound %s", user_id, sound_id)
|
||||||
|
|
||||||
|
# Get updated favorite count within the same session
|
||||||
|
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
|
||||||
|
|
||||||
|
# Emit sound_favorited event via WebSocket (outside the session)
|
||||||
|
try:
|
||||||
|
event_data = {
|
||||||
|
"sound_id": sound_id,
|
||||||
|
"sound_name": sound_name,
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_name": user_name,
|
||||||
|
"favorite_count": favorite_count,
|
||||||
|
}
|
||||||
|
await socket_manager.broadcast_to_all("sound_favorited", event_data)
|
||||||
|
logger.info("Broadcasted sound_favorited event for sound %s", sound_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to broadcast sound_favorited event for sound %s",
|
||||||
|
sound_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return favorite
|
||||||
|
|
||||||
|
async def add_playlist_favorite(self, user_id: int, playlist_id: int) -> Favorite:
|
||||||
|
"""Add a playlist to user's favorites.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
playlist_id: The playlist ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created favorite
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user or playlist not found, or already favorited
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
playlist_repo = PlaylistRepository(session)
|
||||||
|
|
||||||
|
# Verify user exists
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
if not user:
|
||||||
|
msg = f"User with ID {user_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Verify playlist exists
|
||||||
|
playlist = await playlist_repo.get_by_id(playlist_id)
|
||||||
|
if not playlist:
|
||||||
|
msg = f"Playlist with ID {playlist_id} not found"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Check if already favorited
|
||||||
|
existing = await favorite_repo.get_by_user_and_playlist(
|
||||||
|
user_id,
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Create favorite
|
||||||
|
favorite_data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"sound_id": None,
|
||||||
|
"playlist_id": playlist_id,
|
||||||
|
}
|
||||||
|
favorite = await favorite_repo.create(favorite_data)
|
||||||
|
logger.info("User %s favorited playlist %s", user_id, playlist_id)
|
||||||
|
return favorite
|
||||||
|
|
||||||
|
async def remove_sound_favorite(self, user_id: int, sound_id: int) -> None:
|
||||||
|
"""Remove a sound from user's favorites.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
sound_id: The sound ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If favorite not found
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
|
||||||
|
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
|
||||||
|
if not favorite:
|
||||||
|
msg = f"Sound {sound_id} is not favorited by user {user_id}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# Get user and sound info before deletion for the event
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
sound_repo = SoundRepository(session)
|
||||||
|
user = await user_repo.get_by_id(user_id)
|
||||||
|
sound = await sound_repo.get_by_id(sound_id)
|
||||||
|
|
||||||
|
# Get data for the event immediately after loading
|
||||||
|
sound_name = sound.name if sound else "Unknown"
|
||||||
|
user_name = user.name if user else "Unknown"
|
||||||
|
|
||||||
|
await favorite_repo.delete(favorite)
|
||||||
|
logger.info("User %s removed sound %s from favorites", user_id, sound_id)
|
||||||
|
|
||||||
|
# Get updated favorite count after deletion within the same session
|
||||||
|
favorite_count = await favorite_repo.count_sound_favorites(sound_id)
|
||||||
|
|
||||||
|
# Emit sound_favorited event via WebSocket (outside the session)
|
||||||
|
try:
|
||||||
|
event_data = {
|
||||||
|
"sound_id": sound_id,
|
||||||
|
"sound_name": sound_name,
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_name": user_name,
|
||||||
|
"favorite_count": favorite_count,
|
||||||
|
}
|
||||||
|
await socket_manager.broadcast_to_all("sound_favorited", event_data)
|
||||||
|
logger.info(
|
||||||
|
"Broadcasted sound_favorited event for sound %s removal",
|
||||||
|
sound_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to broadcast sound_favorited event for sound %s removal",
|
||||||
|
sound_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def remove_playlist_favorite(self, user_id: int, playlist_id: int) -> None:
|
||||||
|
"""Remove a playlist from user's favorites.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
playlist_id: The playlist ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If favorite not found
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
|
||||||
|
favorite = await favorite_repo.get_by_user_and_playlist(
|
||||||
|
user_id,
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
if not favorite:
|
||||||
|
msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
await favorite_repo.delete(favorite)
|
||||||
|
logger.info(
|
||||||
|
"User %s removed playlist %s from favorites",
|
||||||
|
user_id,
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_favorites(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[Favorite]:
|
||||||
|
"""Get all favorites for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
limit: Maximum number of favorites to return
|
||||||
|
offset: Number of favorites to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user favorites
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
return await favorite_repo.get_user_favorites(user_id, limit, offset)
|
||||||
|
|
||||||
|
async def get_user_sound_favorites(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[Favorite]:
|
||||||
|
"""Get sound favorites for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
limit: Maximum number of favorites to return
|
||||||
|
offset: Number of favorites to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user sound favorites
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
return await favorite_repo.get_user_sound_favorites(user_id, limit, offset)
|
||||||
|
|
||||||
|
async def get_user_playlist_favorites(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[Favorite]:
|
||||||
|
"""Get playlist favorites for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
limit: Maximum number of favorites to return
|
||||||
|
offset: Number of favorites to skip
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of user playlist favorites
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
return await favorite_repo.get_user_playlist_favorites(
|
||||||
|
user_id,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def is_sound_favorited(self, user_id: int, sound_id: int) -> bool:
|
||||||
|
"""Check if a sound is favorited by a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
sound_id: The sound ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the sound is favorited, False otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
return await favorite_repo.is_sound_favorited(user_id, sound_id)
|
||||||
|
|
||||||
|
async def is_playlist_favorited(self, user_id: int, playlist_id: int) -> bool:
|
||||||
|
"""Check if a playlist is favorited by a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
playlist_id: The playlist ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the playlist is favorited, False otherwise
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
return await favorite_repo.is_playlist_favorited(user_id, playlist_id)
|
||||||
|
|
||||||
|
async def get_favorite_counts(self, user_id: int) -> dict[str, int]:
|
||||||
|
"""Get favorite counts for a user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with favorite counts
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
|
||||||
|
total = await favorite_repo.count_user_favorites(user_id)
|
||||||
|
sounds = len(await favorite_repo.get_user_sound_favorites(user_id))
|
||||||
|
playlists = len(await favorite_repo.get_user_playlist_favorites(user_id))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total": total,
|
||||||
|
"sounds": sounds,
|
||||||
|
"playlists": playlists,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_sound_favorite_count(self, sound_id: int) -> int:
|
||||||
|
"""Get the number of users who have favorited a sound.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sound_id: The sound ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users who favorited this sound
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
return await favorite_repo.count_sound_favorites(sound_id)
|
||||||
|
|
||||||
|
async def get_playlist_favorite_count(self, playlist_id: int) -> int:
|
||||||
|
"""Get the number of users who have favorited a playlist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
playlist_id: The playlist ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users who favorited this playlist
|
||||||
|
|
||||||
|
"""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
favorite_repo = FavoriteRepository(session)
|
||||||
|
return await favorite_repo.count_playlist_favorites(playlist_id)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Playlist service for business logic operations."""
|
"""Playlist service for business logic operations."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -14,6 +14,16 @@ from app.repositories.sound import SoundRepository
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedPlaylistsResponse(TypedDict):
|
||||||
|
"""Response type for paginated playlists."""
|
||||||
|
|
||||||
|
playlists: list[dict]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
total_pages: int
|
||||||
|
|
||||||
|
|
||||||
async def _reload_player_playlist() -> None:
|
async def _reload_player_playlist() -> None:
|
||||||
"""Reload the player playlist after current playlist changes."""
|
"""Reload the player playlist after current playlist changes."""
|
||||||
try:
|
try:
|
||||||
@@ -246,6 +256,8 @@ class PlaylistService:
|
|||||||
include_stats: bool = False,
|
include_stats: bool = False,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
|
favorites_only: bool = False,
|
||||||
|
current_user_id: int | None = None,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Search and sort playlists with optional statistics."""
|
"""Search and sort playlists with optional statistics."""
|
||||||
return await self.playlist_repo.search_and_sort(
|
return await self.playlist_repo.search_and_sort(
|
||||||
@@ -256,6 +268,47 @@ class PlaylistService:
|
|||||||
include_stats=include_stats,
|
include_stats=include_stats,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
|
favorites_only=favorites_only,
|
||||||
|
current_user_id=current_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search_and_sort_playlists_paginated( # noqa: PLR0913
|
||||||
|
self,
|
||||||
|
search_query: str | None = None,
|
||||||
|
sort_by: PlaylistSortField | None = None,
|
||||||
|
sort_order: SortOrder = SortOrder.ASC,
|
||||||
|
user_id: int | None = None,
|
||||||
|
*,
|
||||||
|
include_stats: bool = False,
|
||||||
|
page: int = 1,
|
||||||
|
limit: int = 50,
|
||||||
|
favorites_only: bool = False,
|
||||||
|
current_user_id: int | None = None,
|
||||||
|
) -> PaginatedPlaylistsResponse:
|
||||||
|
"""Search and sort playlists with pagination."""
|
||||||
|
offset = (page - 1) * limit
|
||||||
|
|
||||||
|
playlists, total_count = await self.playlist_repo.search_and_sort(
|
||||||
|
search_query=search_query,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order,
|
||||||
|
user_id=user_id,
|
||||||
|
include_stats=include_stats,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
favorites_only=favorites_only,
|
||||||
|
current_user_id=current_user_id,
|
||||||
|
return_count=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
total_pages = (total_count + limit - 1) // limit # Ceiling division
|
||||||
|
|
||||||
|
return PaginatedPlaylistsResponse(
|
||||||
|
playlists=playlists,
|
||||||
|
total=total_count,
|
||||||
|
page=page,
|
||||||
|
limit=limit,
|
||||||
|
total_pages=total_pages,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
|
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
|
||||||
@@ -416,7 +469,9 @@ class PlaylistService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def add_sound_to_main_playlist(
|
async def add_sound_to_main_playlist(
|
||||||
self, sound_id: int, user_id: int, # noqa: ARG002
|
self,
|
||||||
|
sound_id: int, # noqa: ARG002
|
||||||
|
user_id: int, # noqa: ARG002
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a sound to the global main playlist."""
|
"""Add a sound to the global main playlist."""
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -425,7 +480,9 @@ class PlaylistService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _add_sound_to_main_playlist_internal(
|
async def _add_sound_to_main_playlist_internal(
|
||||||
self, sound_id: int, user_id: int,
|
self,
|
||||||
|
sound_id: int,
|
||||||
|
user_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add sound to main playlist bypassing restrictions.
|
"""Add sound to main playlist bypassing restrictions.
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,6 @@ def mock_plan_repository():
|
|||||||
return Mock()
|
return Mock()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def regular_user():
|
def regular_user():
|
||||||
"""Create regular user for testing."""
|
"""Create regular user for testing."""
|
||||||
@@ -60,52 +58,78 @@ class TestAdminUserEndpoints:
|
|||||||
test_plan: Plan,
|
test_plan: Plan,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test listing users successfully."""
|
"""Test listing users successfully."""
|
||||||
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all:
|
with patch(
|
||||||
|
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
|
||||||
|
) as mock_get_all:
|
||||||
# Create mock user objects that don't trigger database saves
|
# Create mock user objects that don't trigger database saves
|
||||||
mock_admin = type("User", (), {
|
mock_admin = type(
|
||||||
"id": admin_user.id,
|
"User",
|
||||||
"email": admin_user.email,
|
(),
|
||||||
"name": admin_user.name,
|
{
|
||||||
"picture": None,
|
"id": admin_user.id,
|
||||||
"role": admin_user.role,
|
"email": admin_user.email,
|
||||||
"credits": admin_user.credits,
|
"name": admin_user.name,
|
||||||
"is_active": admin_user.is_active,
|
"picture": None,
|
||||||
"created_at": admin_user.created_at,
|
"role": admin_user.role,
|
||||||
"updated_at": admin_user.updated_at,
|
"credits": admin_user.credits,
|
||||||
"plan": type("Plan", (), {
|
"is_active": admin_user.is_active,
|
||||||
"id": test_plan.id,
|
"created_at": admin_user.created_at,
|
||||||
"name": test_plan.name,
|
"updated_at": admin_user.updated_at,
|
||||||
"max_credits": test_plan.max_credits,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
mock_regular = type("User", (), {
|
mock_regular = type(
|
||||||
"id": regular_user.id,
|
"User",
|
||||||
"email": regular_user.email,
|
(),
|
||||||
"name": regular_user.name,
|
{
|
||||||
"picture": None,
|
"id": regular_user.id,
|
||||||
"role": regular_user.role,
|
"email": regular_user.email,
|
||||||
"credits": regular_user.credits,
|
"name": regular_user.name,
|
||||||
"is_active": regular_user.is_active,
|
"picture": None,
|
||||||
"created_at": regular_user.created_at,
|
"role": regular_user.role,
|
||||||
"updated_at": regular_user.updated_at,
|
"credits": regular_user.credits,
|
||||||
"plan": type("Plan", (), {
|
"is_active": regular_user.is_active,
|
||||||
"id": test_plan.id,
|
"created_at": regular_user.created_at,
|
||||||
"name": test_plan.name,
|
"updated_at": regular_user.updated_at,
|
||||||
"max_credits": test_plan.max_credits,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
mock_get_all.return_value = [mock_admin, mock_regular]
|
# Mock returns tuple (users, total_count)
|
||||||
|
mock_get_all.return_value = ([mock_admin, mock_regular], 2)
|
||||||
|
|
||||||
response = await authenticated_admin_client.get("/api/v1/admin/users/")
|
response = await authenticated_admin_client.get("/api/v1/admin/users/")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data) == 2
|
assert "users" in data
|
||||||
assert data[0]["email"] == "admin@example.com"
|
assert "total" in data
|
||||||
assert data[1]["email"] == "user@example.com"
|
assert "page" in data
|
||||||
mock_get_all.assert_called_once_with(limit=100, offset=0)
|
assert "limit" in data
|
||||||
|
assert "total_pages" in data
|
||||||
|
assert len(data["users"]) == 2
|
||||||
|
assert data["users"][0]["email"] == "admin@example.com"
|
||||||
|
assert data["users"][1]["email"] == "user@example.com"
|
||||||
|
assert data["total"] == 2
|
||||||
|
assert data["page"] == 1
|
||||||
|
assert data["limit"] == 50
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_users_with_pagination(
|
async def test_list_users_with_pagination(
|
||||||
@@ -115,29 +139,55 @@ class TestAdminUserEndpoints:
|
|||||||
test_plan: Plan,
|
test_plan: Plan,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test listing users with pagination."""
|
"""Test listing users with pagination."""
|
||||||
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all:
|
from app.repositories.user import SortOrder, UserSortField, UserStatus
|
||||||
mock_admin = type("User", (), {
|
|
||||||
"id": admin_user.id,
|
|
||||||
"email": admin_user.email,
|
|
||||||
"name": admin_user.name,
|
|
||||||
"picture": None,
|
|
||||||
"role": admin_user.role,
|
|
||||||
"credits": admin_user.credits,
|
|
||||||
"is_active": admin_user.is_active,
|
|
||||||
"created_at": admin_user.created_at,
|
|
||||||
"updated_at": admin_user.updated_at,
|
|
||||||
"plan": type("Plan", (), {
|
|
||||||
"id": test_plan.id,
|
|
||||||
"name": test_plan.name,
|
|
||||||
"max_credits": test_plan.max_credits,
|
|
||||||
})(),
|
|
||||||
})()
|
|
||||||
mock_get_all.return_value = [mock_admin]
|
|
||||||
|
|
||||||
response = await authenticated_admin_client.get("/api/v1/admin/users/?limit=10&offset=5")
|
with patch(
|
||||||
|
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
|
||||||
|
) as mock_get_all:
|
||||||
|
mock_admin = type(
|
||||||
|
"User",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"id": admin_user.id,
|
||||||
|
"email": admin_user.email,
|
||||||
|
"name": admin_user.name,
|
||||||
|
"picture": None,
|
||||||
|
"role": admin_user.role,
|
||||||
|
"credits": admin_user.credits,
|
||||||
|
"is_active": admin_user.is_active,
|
||||||
|
"created_at": admin_user.created_at,
|
||||||
|
"updated_at": admin_user.updated_at,
|
||||||
|
"plan": type(
|
||||||
|
"Plan",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
# Mock returns tuple (users, total_count)
|
||||||
|
mock_get_all.return_value = ([mock_admin], 1)
|
||||||
|
|
||||||
|
response = await authenticated_admin_client.get(
|
||||||
|
"/api/v1/admin/users/?page=2&limit=10",
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
mock_get_all.assert_called_once_with(limit=10, offset=5)
|
data = response.json()
|
||||||
|
assert "users" in data
|
||||||
|
assert data["page"] == 2
|
||||||
|
assert data["limit"] == 10
|
||||||
|
mock_get_all.assert_called_once_with(
|
||||||
|
page=2,
|
||||||
|
limit=10,
|
||||||
|
search=None,
|
||||||
|
sort_by=UserSortField.NAME,
|
||||||
|
sort_order=SortOrder.ASC,
|
||||||
|
status_filter=UserStatus.ALL,
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_users_unauthenticated(self, client: AsyncClient) -> None:
|
async def test_list_users_unauthenticated(self, client: AsyncClient) -> None:
|
||||||
@@ -153,7 +203,9 @@ class TestAdminUserEndpoints:
|
|||||||
regular_user: User,
|
regular_user: User,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test listing users as non-admin user."""
|
"""Test listing users as non-admin user."""
|
||||||
with patch("app.core.dependencies.get_current_active_user", return_value=regular_user):
|
with patch(
|
||||||
|
"app.core.dependencies.get_current_active_user", return_value=regular_user,
|
||||||
|
):
|
||||||
response = await client.get("/api/v1/admin/users/")
|
response = await client.get("/api/v1/admin/users/")
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@@ -169,24 +221,34 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test getting specific user successfully."""
|
"""Test getting specific user successfully."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
) as mock_get_by_id,
|
||||||
):
|
):
|
||||||
mock_user = type("User", (), {
|
mock_user = type(
|
||||||
"id": regular_user.id,
|
"User",
|
||||||
"email": regular_user.email,
|
(),
|
||||||
"name": regular_user.name,
|
{
|
||||||
"picture": None,
|
"id": regular_user.id,
|
||||||
"role": regular_user.role,
|
"email": regular_user.email,
|
||||||
"credits": regular_user.credits,
|
"name": regular_user.name,
|
||||||
"is_active": regular_user.is_active,
|
"picture": None,
|
||||||
"created_at": regular_user.created_at,
|
"role": regular_user.role,
|
||||||
"updated_at": regular_user.updated_at,
|
"credits": regular_user.credits,
|
||||||
"plan": type("Plan", (), {
|
"is_active": regular_user.is_active,
|
||||||
"id": test_plan.id,
|
"created_at": regular_user.created_at,
|
||||||
"name": test_plan.name,
|
"updated_at": regular_user.updated_at,
|
||||||
"max_credits": test_plan.max_credits,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
|
|
||||||
response = await authenticated_admin_client.get("/api/v1/admin/users/2")
|
response = await authenticated_admin_client.get("/api/v1/admin/users/2")
|
||||||
@@ -207,7 +269,10 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test getting non-existent user."""
|
"""Test getting non-existent user."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
response = await authenticated_admin_client.get("/api/v1/admin/users/999")
|
response = await authenticated_admin_client.get("/api/v1/admin/users/999")
|
||||||
|
|
||||||
@@ -226,43 +291,63 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test updating user successfully."""
|
"""Test updating user successfully."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
) as mock_get_by_id,
|
||||||
patch("app.repositories.user.UserRepository.update") as mock_update,
|
patch("app.repositories.user.UserRepository.update") as mock_update,
|
||||||
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan),
|
patch(
|
||||||
|
"app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
mock_user = type("User", (), {
|
mock_user = type(
|
||||||
"id": regular_user.id,
|
"User",
|
||||||
"email": regular_user.email,
|
(),
|
||||||
"name": regular_user.name,
|
{
|
||||||
"picture": None,
|
"id": regular_user.id,
|
||||||
"role": regular_user.role,
|
"email": regular_user.email,
|
||||||
"credits": regular_user.credits,
|
"name": regular_user.name,
|
||||||
"is_active": regular_user.is_active,
|
"picture": None,
|
||||||
"created_at": regular_user.created_at,
|
"role": regular_user.role,
|
||||||
"updated_at": regular_user.updated_at,
|
"credits": regular_user.credits,
|
||||||
"plan": type("Plan", (), {
|
"is_active": regular_user.is_active,
|
||||||
"id": test_plan.id,
|
"created_at": regular_user.created_at,
|
||||||
"name": test_plan.name,
|
"updated_at": regular_user.updated_at,
|
||||||
"max_credits": test_plan.max_credits,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
updated_mock = type("User", (), {
|
updated_mock = type(
|
||||||
"id": regular_user.id,
|
"User",
|
||||||
"email": regular_user.email,
|
(),
|
||||||
"name": "Updated Name",
|
{
|
||||||
"picture": None,
|
"id": regular_user.id,
|
||||||
"role": regular_user.role,
|
"email": regular_user.email,
|
||||||
"credits": 200,
|
"name": "Updated Name",
|
||||||
"is_active": regular_user.is_active,
|
"picture": None,
|
||||||
"created_at": regular_user.created_at,
|
"role": regular_user.role,
|
||||||
"updated_at": regular_user.updated_at,
|
"credits": 200,
|
||||||
"plan": type("Plan", (), {
|
"is_active": regular_user.is_active,
|
||||||
"id": test_plan.id,
|
"created_at": regular_user.created_at,
|
||||||
"name": test_plan.name,
|
"updated_at": regular_user.updated_at,
|
||||||
"max_credits": test_plan.max_credits,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
mock_update.return_value = updated_mock
|
mock_update.return_value = updated_mock
|
||||||
@@ -271,7 +356,10 @@ class TestAdminUserEndpoints:
|
|||||||
async def mock_refresh(instance, attributes=None):
|
async def mock_refresh(instance, attributes=None):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
with patch("sqlmodel.ext.asyncio.session.AsyncSession.refresh", side_effect=mock_refresh):
|
with patch(
|
||||||
|
"sqlmodel.ext.asyncio.session.AsyncSession.refresh",
|
||||||
|
side_effect=mock_refresh,
|
||||||
|
):
|
||||||
response = await authenticated_admin_client.patch(
|
response = await authenticated_admin_client.patch(
|
||||||
"/api/v1/admin/users/2",
|
"/api/v1/admin/users/2",
|
||||||
json={
|
json={
|
||||||
@@ -295,7 +383,10 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test updating non-existent user."""
|
"""Test updating non-existent user."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
response = await authenticated_admin_client.patch(
|
response = await authenticated_admin_client.patch(
|
||||||
"/api/v1/admin/users/999",
|
"/api/v1/admin/users/999",
|
||||||
@@ -316,25 +407,35 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test updating user with invalid plan."""
|
"""Test updating user with invalid plan."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
) as mock_get_by_id,
|
||||||
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=None),
|
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=None),
|
||||||
):
|
):
|
||||||
mock_user = type("User", (), {
|
mock_user = type(
|
||||||
"id": regular_user.id,
|
"User",
|
||||||
"email": regular_user.email,
|
(),
|
||||||
"name": regular_user.name,
|
{
|
||||||
"picture": None,
|
"id": regular_user.id,
|
||||||
"role": regular_user.role,
|
"email": regular_user.email,
|
||||||
"credits": regular_user.credits,
|
"name": regular_user.name,
|
||||||
"is_active": regular_user.is_active,
|
"picture": None,
|
||||||
"created_at": regular_user.created_at,
|
"role": regular_user.role,
|
||||||
"updated_at": regular_user.updated_at,
|
"credits": regular_user.credits,
|
||||||
"plan": type("Plan", (), {
|
"is_active": regular_user.is_active,
|
||||||
"id": 1,
|
"created_at": regular_user.created_at,
|
||||||
"name": "Basic",
|
"updated_at": regular_user.updated_at,
|
||||||
"max_credits": 100,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"name": "Basic",
|
||||||
|
"max_credits": 100,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
response = await authenticated_admin_client.patch(
|
response = await authenticated_admin_client.patch(
|
||||||
"/api/v1/admin/users/2",
|
"/api/v1/admin/users/2",
|
||||||
@@ -356,29 +457,41 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test disabling user successfully."""
|
"""Test disabling user successfully."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
) as mock_get_by_id,
|
||||||
patch("app.repositories.user.UserRepository.update") as mock_update,
|
patch("app.repositories.user.UserRepository.update") as mock_update,
|
||||||
):
|
):
|
||||||
mock_user = type("User", (), {
|
mock_user = type(
|
||||||
"id": regular_user.id,
|
"User",
|
||||||
"email": regular_user.email,
|
(),
|
||||||
"name": regular_user.name,
|
{
|
||||||
"picture": None,
|
"id": regular_user.id,
|
||||||
"role": regular_user.role,
|
"email": regular_user.email,
|
||||||
"credits": regular_user.credits,
|
"name": regular_user.name,
|
||||||
"is_active": regular_user.is_active,
|
"picture": None,
|
||||||
"created_at": regular_user.created_at,
|
"role": regular_user.role,
|
||||||
"updated_at": regular_user.updated_at,
|
"credits": regular_user.credits,
|
||||||
"plan": type("Plan", (), {
|
"is_active": regular_user.is_active,
|
||||||
"id": test_plan.id,
|
"created_at": regular_user.created_at,
|
||||||
"name": test_plan.name,
|
"updated_at": regular_user.updated_at,
|
||||||
"max_credits": test_plan.max_credits,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
mock_update.return_value = mock_user
|
mock_update.return_value = mock_user
|
||||||
|
|
||||||
response = await authenticated_admin_client.post("/api/v1/admin/users/2/disable")
|
response = await authenticated_admin_client.post(
|
||||||
|
"/api/v1/admin/users/2/disable",
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -393,9 +506,14 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test disabling non-existent user."""
|
"""Test disabling non-existent user."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
response = await authenticated_admin_client.post("/api/v1/admin/users/999/disable")
|
response = await authenticated_admin_client.post(
|
||||||
|
"/api/v1/admin/users/999/disable",
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -421,29 +539,41 @@ class TestAdminUserEndpoints:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
) as mock_get_by_id,
|
||||||
patch("app.repositories.user.UserRepository.update") as mock_update,
|
patch("app.repositories.user.UserRepository.update") as mock_update,
|
||||||
):
|
):
|
||||||
mock_disabled_user = type("User", (), {
|
mock_disabled_user = type(
|
||||||
"id": disabled_user.id,
|
"User",
|
||||||
"email": disabled_user.email,
|
(),
|
||||||
"name": disabled_user.name,
|
{
|
||||||
"picture": None,
|
"id": disabled_user.id,
|
||||||
"role": disabled_user.role,
|
"email": disabled_user.email,
|
||||||
"credits": disabled_user.credits,
|
"name": disabled_user.name,
|
||||||
"is_active": disabled_user.is_active,
|
"picture": None,
|
||||||
"created_at": disabled_user.created_at,
|
"role": disabled_user.role,
|
||||||
"updated_at": disabled_user.updated_at,
|
"credits": disabled_user.credits,
|
||||||
"plan": type("Plan", (), {
|
"is_active": disabled_user.is_active,
|
||||||
"id": test_plan.id,
|
"created_at": disabled_user.created_at,
|
||||||
"name": test_plan.name,
|
"updated_at": disabled_user.updated_at,
|
||||||
"max_credits": test_plan.max_credits,
|
"plan": type(
|
||||||
})(),
|
"Plan",
|
||||||
})()
|
(),
|
||||||
|
{
|
||||||
|
"id": test_plan.id,
|
||||||
|
"name": test_plan.name,
|
||||||
|
"max_credits": test_plan.max_credits,
|
||||||
|
},
|
||||||
|
)(),
|
||||||
|
},
|
||||||
|
)()
|
||||||
mock_get_by_id.return_value = mock_disabled_user
|
mock_get_by_id.return_value = mock_disabled_user
|
||||||
mock_update.return_value = mock_disabled_user
|
mock_update.return_value = mock_disabled_user
|
||||||
|
|
||||||
response = await authenticated_admin_client.post("/api/v1/admin/users/3/enable")
|
response = await authenticated_admin_client.post(
|
||||||
|
"/api/v1/admin/users/3/enable",
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -458,9 +588,14 @@ class TestAdminUserEndpoints:
|
|||||||
"""Test enabling non-existent user."""
|
"""Test enabling non-existent user."""
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
|
patch(
|
||||||
|
"app.repositories.user.UserRepository.get_by_id_with_plan",
|
||||||
|
return_value=None,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
response = await authenticated_admin_client.post("/api/v1/admin/users/999/enable")
|
response = await authenticated_admin_client.post(
|
||||||
|
"/api/v1/admin/users/999/enable",
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -479,9 +614,14 @@ class TestAdminUserEndpoints:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
|
||||||
patch("app.repositories.plan.PlanRepository.get_all", return_value=[basic_plan, premium_plan]),
|
patch(
|
||||||
|
"app.repositories.plan.PlanRepository.get_all",
|
||||||
|
return_value=[basic_plan, premium_plan],
|
||||||
|
),
|
||||||
):
|
):
|
||||||
response = await authenticated_admin_client.get("/api/v1/admin/users/plans/list")
|
response = await authenticated_admin_client.get(
|
||||||
|
"/api/v1/admin/users/plans/list",
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|||||||
@@ -488,11 +488,17 @@ class TestAuthEndpoints:
|
|||||||
test_plan: Plan,
|
test_plan: Plan,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test refresh token success."""
|
"""Test refresh token success."""
|
||||||
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh:
|
with patch(
|
||||||
mock_refresh.return_value = type("TokenResponse", (), {
|
"app.services.auth.AuthService.refresh_access_token",
|
||||||
"access_token": "new_access_token",
|
) as mock_refresh:
|
||||||
"expires_in": 3600,
|
mock_refresh.return_value = type(
|
||||||
})()
|
"TokenResponse",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"access_token": "new_access_token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/refresh",
|
"/api/v1/auth/refresh",
|
||||||
@@ -516,7 +522,9 @@ class TestAuthEndpoints:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_refresh_token_service_error(self, test_client: AsyncClient) -> None:
|
async def test_refresh_token_service_error(self, test_client: AsyncClient) -> None:
|
||||||
"""Test refresh token with service error."""
|
"""Test refresh token with service error."""
|
||||||
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh:
|
with patch(
|
||||||
|
"app.services.auth.AuthService.refresh_access_token",
|
||||||
|
) as mock_refresh:
|
||||||
mock_refresh.side_effect = Exception("Database error")
|
mock_refresh.side_effect = Exception("Database error")
|
||||||
|
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
@@ -528,7 +536,6 @@ class TestAuthEndpoints:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
assert "Token refresh failed" in data["detail"]
|
assert "Token refresh failed" in data["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exchange_oauth_token_invalid_code(
|
async def test_exchange_oauth_token_invalid_code(
|
||||||
self,
|
self,
|
||||||
@@ -554,7 +561,9 @@ class TestAuthEndpoints:
|
|||||||
"""Test update profile success."""
|
"""Test update profile success."""
|
||||||
with (
|
with (
|
||||||
patch("app.services.auth.AuthService.update_user_profile") as mock_update,
|
patch("app.services.auth.AuthService.update_user_profile") as mock_update,
|
||||||
patch("app.services.auth.AuthService.user_to_response") as mock_user_to_response,
|
patch(
|
||||||
|
"app.services.auth.AuthService.user_to_response",
|
||||||
|
) as mock_user_to_response,
|
||||||
):
|
):
|
||||||
updated_user = User(
|
updated_user = User(
|
||||||
id=test_user.id,
|
id=test_user.id,
|
||||||
@@ -569,6 +578,7 @@ class TestAuthEndpoints:
|
|||||||
|
|
||||||
# Mock the user_to_response to return UserResponse format
|
# Mock the user_to_response to return UserResponse format
|
||||||
from app.schemas.auth import UserResponse
|
from app.schemas.auth import UserResponse
|
||||||
|
|
||||||
mock_user_to_response.return_value = UserResponse(
|
mock_user_to_response.return_value = UserResponse(
|
||||||
id=test_user.id,
|
id=test_user.id,
|
||||||
email=test_user.email,
|
email=test_user.email,
|
||||||
@@ -598,7 +608,9 @@ class TestAuthEndpoints:
|
|||||||
assert data["name"] == "Updated Name"
|
assert data["name"] == "Updated Name"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_profile_unauthenticated(self, test_client: AsyncClient) -> None:
|
async def test_update_profile_unauthenticated(
|
||||||
|
self, test_client: AsyncClient,
|
||||||
|
) -> None:
|
||||||
"""Test update profile without authentication."""
|
"""Test update profile without authentication."""
|
||||||
response = await test_client.patch(
|
response = await test_client.patch(
|
||||||
"/api/v1/auth/me",
|
"/api/v1/auth/me",
|
||||||
@@ -632,7 +644,9 @@ class TestAuthEndpoints:
|
|||||||
assert data["message"] == "Password changed successfully"
|
assert data["message"] == "Password changed successfully"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_change_password_unauthenticated(self, test_client: AsyncClient) -> None:
|
async def test_change_password_unauthenticated(
|
||||||
|
self, test_client: AsyncClient,
|
||||||
|
) -> None:
|
||||||
"""Test change password without authentication."""
|
"""Test change password without authentication."""
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/change-password",
|
"/api/v1/auth/change-password",
|
||||||
@@ -652,7 +666,9 @@ class TestAuthEndpoints:
|
|||||||
auth_cookies: dict[str, str],
|
auth_cookies: dict[str, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test get user OAuth providers success."""
|
"""Test get user OAuth providers success."""
|
||||||
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
|
with patch(
|
||||||
|
"app.services.auth.AuthService.get_user_oauth_providers",
|
||||||
|
) as mock_providers:
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.models.user_oauth import UserOauth
|
from app.models.user_oauth import UserOauth
|
||||||
@@ -699,7 +715,9 @@ class TestAuthEndpoints:
|
|||||||
assert data[2]["display_name"] == "GitHub"
|
assert data[2]["display_name"] == "GitHub"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_user_providers_unauthenticated(self, test_client: AsyncClient) -> None:
|
async def test_get_user_providers_unauthenticated(
|
||||||
|
self, test_client: AsyncClient,
|
||||||
|
) -> None:
|
||||||
"""Test get user OAuth providers without authentication."""
|
"""Test get user OAuth providers without authentication."""
|
||||||
response = await test_client.get("/api/v1/auth/user-providers")
|
response = await test_client.get("/api/v1/auth/user-providers")
|
||||||
|
|
||||||
|
|||||||
@@ -109,9 +109,15 @@ class TestPlaylistEndpoints:
|
|||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data) == 2
|
assert "playlists" in data
|
||||||
|
assert "total" in data
|
||||||
|
assert "page" in data
|
||||||
|
assert "limit" in data
|
||||||
|
assert "total_pages" in data
|
||||||
|
assert len(data["playlists"]) == 2
|
||||||
|
assert data["total"] == 2
|
||||||
|
|
||||||
playlist_names = {p["name"] for p in data}
|
playlist_names = {p["name"] for p in data["playlists"]}
|
||||||
assert "Test Playlist" in playlist_names
|
assert "Test Playlist" in playlist_names
|
||||||
assert "Main Playlist" in playlist_names
|
assert "Main Playlist" in playlist_names
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from httpx import AsyncClient
|
|||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.services.extraction import ExtractionInfo
|
from app.services.extraction import ExtractionInfo, PaginatedExtractionsResponse
|
||||||
|
|
||||||
|
|
||||||
class TestSoundEndpoints:
|
class TestSoundEndpoints:
|
||||||
@@ -32,6 +32,7 @@ class TestSoundEndpoints:
|
|||||||
"error": None,
|
"error": None,
|
||||||
"sound_id": None,
|
"sound_id": None,
|
||||||
"user_id": authenticated_user.id,
|
"user_id": authenticated_user.id,
|
||||||
|
"user_name": authenticated_user.name,
|
||||||
"created_at": "2025-08-03T12:00:00Z",
|
"created_at": "2025-08-03T12:00:00Z",
|
||||||
"updated_at": "2025-08-03T12:00:00Z",
|
"updated_at": "2025-08-03T12:00:00Z",
|
||||||
}
|
}
|
||||||
@@ -111,6 +112,7 @@ class TestSoundEndpoints:
|
|||||||
"error": None,
|
"error": None,
|
||||||
"sound_id": 42,
|
"sound_id": 42,
|
||||||
"user_id": authenticated_user.id,
|
"user_id": authenticated_user.id,
|
||||||
|
"user_name": authenticated_user.name,
|
||||||
"created_at": "2025-08-03T12:00:00Z",
|
"created_at": "2025-08-03T12:00:00Z",
|
||||||
"updated_at": "2025-08-03T12:00:00Z",
|
"updated_at": "2025-08-03T12:00:00Z",
|
||||||
}
|
}
|
||||||
@@ -154,41 +156,49 @@ class TestSoundEndpoints:
|
|||||||
authenticated_user: User,
|
authenticated_user: User,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test getting user extractions."""
|
"""Test getting user extractions."""
|
||||||
mock_extractions: list[ExtractionInfo] = [
|
mock_extractions: PaginatedExtractionsResponse = {
|
||||||
{
|
"extractions": [
|
||||||
"id": 1,
|
{
|
||||||
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
"id": 1,
|
||||||
"title": "Never Gonna Give You Up",
|
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||||
"service": "youtube",
|
"title": "Never Gonna Give You Up",
|
||||||
"service_id": "dQw4w9WgXcQ",
|
"service": "youtube",
|
||||||
"status": "completed",
|
"service_id": "dQw4w9WgXcQ",
|
||||||
"error": None,
|
"status": "completed",
|
||||||
"sound_id": 42,
|
"error": None,
|
||||||
"user_id": authenticated_user.id,
|
"sound_id": 42,
|
||||||
"created_at": "2025-08-03T12:00:00Z",
|
"user_id": authenticated_user.id,
|
||||||
"updated_at": "2025-08-03T12:00:00Z",
|
"user_name": authenticated_user.name,
|
||||||
},
|
"created_at": "2025-08-03T12:00:00Z",
|
||||||
{
|
"updated_at": "2025-08-03T12:00:00Z",
|
||||||
"id": 2,
|
},
|
||||||
"url": "https://soundcloud.com/example/track",
|
{
|
||||||
"title": "Example Track",
|
"id": 2,
|
||||||
"service": "soundcloud",
|
"url": "https://soundcloud.com/example/track",
|
||||||
"service_id": "example-track",
|
"title": "Example Track",
|
||||||
"status": "pending",
|
"service": "soundcloud",
|
||||||
"error": None,
|
"service_id": "example-track",
|
||||||
"sound_id": None,
|
"status": "pending",
|
||||||
"user_id": authenticated_user.id,
|
"error": None,
|
||||||
"created_at": "2025-08-03T12:00:00Z",
|
"sound_id": None,
|
||||||
"updated_at": "2025-08-03T12:00:00Z",
|
"user_id": authenticated_user.id,
|
||||||
},
|
"user_name": authenticated_user.name,
|
||||||
]
|
"created_at": "2025-08-03T12:00:00Z",
|
||||||
|
"updated_at": "2025-08-03T12:00:00Z",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"total": 2,
|
||||||
|
"page": 1,
|
||||||
|
"limit": 50,
|
||||||
|
"total_pages": 1,
|
||||||
|
}
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.extraction.ExtractionService.get_user_extractions",
|
"app.services.extraction.ExtractionService.get_user_extractions",
|
||||||
) as mock_get:
|
) as mock_get:
|
||||||
mock_get.return_value = mock_extractions
|
mock_get.return_value = mock_extractions
|
||||||
|
|
||||||
response = await authenticated_client.get("/api/v1/extractions/")
|
response = await authenticated_client.get("/api/v1/extractions/user")
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
@@ -337,7 +347,9 @@ class TestSoundEndpoints:
|
|||||||
"""Test getting sounds with authentication."""
|
"""Test getting sounds with authentication."""
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
|
|
||||||
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get:
|
with patch(
|
||||||
|
"app.repositories.sound.SoundRepository.search_and_sort",
|
||||||
|
) as mock_get:
|
||||||
# Create mock sounds with all required fields
|
# Create mock sounds with all required fields
|
||||||
mock_sound_1 = Sound(
|
mock_sound_1 = Sound(
|
||||||
id=1,
|
id=1,
|
||||||
@@ -383,7 +395,9 @@ class TestSoundEndpoints:
|
|||||||
"""Test getting sounds with type filtering."""
|
"""Test getting sounds with type filtering."""
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
|
|
||||||
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get:
|
with patch(
|
||||||
|
"app.repositories.sound.SoundRepository.search_and_sort",
|
||||||
|
) as mock_get:
|
||||||
# Create mock sound with all required fields
|
# Create mock sound with all required fields
|
||||||
mock_sound = Sound(
|
mock_sound = Sound(
|
||||||
id=1,
|
id=1,
|
||||||
|
|||||||
@@ -335,5 +335,3 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
|
|||||||
access_token = JWTUtils.create_access_token(token_data)
|
access_token = JWTUtils.create_access_token(token_data)
|
||||||
|
|
||||||
return {"access_token": access_token}
|
return {"access_token": access_token}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -539,21 +539,35 @@ class TestPlaylistRepository:
|
|||||||
sound_ids = [s.id for s in sounds]
|
sound_ids = [s.id for s in sounds]
|
||||||
|
|
||||||
# Add first two sounds sequentially (positions 0, 1)
|
# Add first two sounds sequentially (positions 0, 1)
|
||||||
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0
|
await playlist_repository.add_sound_to_playlist(
|
||||||
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1
|
playlist_id, sound_ids[0],
|
||||||
|
) # position 0
|
||||||
|
await playlist_repository.add_sound_to_playlist(
|
||||||
|
playlist_id, sound_ids[1],
|
||||||
|
) # position 1
|
||||||
|
|
||||||
# Now insert third sound at position 1 - should shift existing sound at position 1 to position 2
|
# Now insert third sound at position 1 - should shift existing sound at position 1 to position 2
|
||||||
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=1)
|
await playlist_repository.add_sound_to_playlist(
|
||||||
|
playlist_id, sound_ids[2], position=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the final positions
|
# Verify the final positions
|
||||||
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
playlist_sounds = await playlist_repository.get_playlist_sound_entries(
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
|
||||||
assert len(playlist_sounds) == 3
|
assert len(playlist_sounds) == 3
|
||||||
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
|
assert (
|
||||||
|
playlist_sounds[0].sound_id == sound_ids[0]
|
||||||
|
) # Original sound 0 stays at position 0
|
||||||
assert playlist_sounds[0].position == 0
|
assert playlist_sounds[0].position == 0
|
||||||
assert playlist_sounds[1].sound_id == sound_ids[2] # New sound 2 inserted at position 1
|
assert (
|
||||||
|
playlist_sounds[1].sound_id == sound_ids[2]
|
||||||
|
) # New sound 2 inserted at position 1
|
||||||
assert playlist_sounds[1].position == 1
|
assert playlist_sounds[1].position == 1
|
||||||
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2
|
assert (
|
||||||
|
playlist_sounds[2].sound_id == sound_ids[1]
|
||||||
|
) # Original sound 1 shifted to position 2
|
||||||
assert playlist_sounds[2].position == 2
|
assert playlist_sounds[2].position == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -615,21 +629,35 @@ class TestPlaylistRepository:
|
|||||||
sound_ids = [s.id for s in sounds]
|
sound_ids = [s.id for s in sounds]
|
||||||
|
|
||||||
# Add first two sounds sequentially (positions 0, 1)
|
# Add first two sounds sequentially (positions 0, 1)
|
||||||
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0
|
await playlist_repository.add_sound_to_playlist(
|
||||||
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1
|
playlist_id, sound_ids[0],
|
||||||
|
) # position 0
|
||||||
|
await playlist_repository.add_sound_to_playlist(
|
||||||
|
playlist_id, sound_ids[1],
|
||||||
|
) # position 1
|
||||||
|
|
||||||
# Now insert third sound at position 0 - should shift existing sounds to positions 1, 2
|
# Now insert third sound at position 0 - should shift existing sounds to positions 1, 2
|
||||||
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=0)
|
await playlist_repository.add_sound_to_playlist(
|
||||||
|
playlist_id, sound_ids[2], position=0,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the final positions
|
# Verify the final positions
|
||||||
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
playlist_sounds = await playlist_repository.get_playlist_sound_entries(
|
||||||
|
playlist_id,
|
||||||
|
)
|
||||||
|
|
||||||
assert len(playlist_sounds) == 3
|
assert len(playlist_sounds) == 3
|
||||||
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
|
assert (
|
||||||
|
playlist_sounds[0].sound_id == sound_ids[2]
|
||||||
|
) # New sound 2 inserted at position 0
|
||||||
assert playlist_sounds[0].position == 0
|
assert playlist_sounds[0].position == 0
|
||||||
assert playlist_sounds[1].sound_id == sound_ids[0] # Original sound 0 shifted to position 1
|
assert (
|
||||||
|
playlist_sounds[1].sound_id == sound_ids[0]
|
||||||
|
) # Original sound 0 shifted to position 1
|
||||||
assert playlist_sounds[1].position == 1
|
assert playlist_sounds[1].position == 1
|
||||||
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2
|
assert (
|
||||||
|
playlist_sounds[2].sound_id == sound_ids[1]
|
||||||
|
) # Original sound 1 shifted to position 2
|
||||||
assert playlist_sounds[2].position == 2
|
assert playlist_sounds[2].position == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -409,7 +409,9 @@ class TestCreditService:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_recharge_user_credits_success(
|
async def test_recharge_user_credits_success(
|
||||||
self, credit_service, sample_user,
|
self,
|
||||||
|
credit_service,
|
||||||
|
sample_user,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful credit recharge for a user."""
|
"""Test successful credit recharge for a user."""
|
||||||
mock_session = credit_service.db_session_factory()
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|||||||
@@ -43,7 +43,9 @@ class TestDashboardService:
|
|||||||
"total_duration": 75000,
|
"total_duration": 75000,
|
||||||
"total_size": 1024000,
|
"total_size": 1024000,
|
||||||
}
|
}
|
||||||
mock_sound_repository.get_soundboard_statistics = AsyncMock(return_value=mock_stats)
|
mock_sound_repository.get_soundboard_statistics = AsyncMock(
|
||||||
|
return_value=mock_stats,
|
||||||
|
)
|
||||||
|
|
||||||
result = await dashboard_service.get_soundboard_statistics()
|
result = await dashboard_service.get_soundboard_statistics()
|
||||||
|
|
||||||
|
|||||||
@@ -99,6 +99,11 @@ class TestExtractionService:
|
|||||||
url = "https://www.youtube.com/watch?v=test123"
|
url = "https://www.youtube.com/watch?v=test123"
|
||||||
user_id = 1
|
user_id = 1
|
||||||
|
|
||||||
|
# Mock user for user_name retrieval
|
||||||
|
mock_user = Mock()
|
||||||
|
mock_user.name = "Test User"
|
||||||
|
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
|
||||||
|
|
||||||
# Mock repository call - no service detection happens during creation
|
# Mock repository call - no service detection happens during creation
|
||||||
mock_extraction = Extraction(
|
mock_extraction = Extraction(
|
||||||
id=1,
|
id=1,
|
||||||
@@ -120,6 +125,7 @@ class TestExtractionService:
|
|||||||
assert result["service_id"] is None # Not detected during creation
|
assert result["service_id"] is None # Not detected during creation
|
||||||
assert result["title"] is None # Not detected during creation
|
assert result["title"] is None # Not detected during creation
|
||||||
assert result["status"] == "pending"
|
assert result["status"] == "pending"
|
||||||
|
assert result["user_name"] == "Test User"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_extraction_basic(self, extraction_service) -> None:
|
async def test_create_extraction_basic(self, extraction_service) -> None:
|
||||||
@@ -127,6 +133,11 @@ class TestExtractionService:
|
|||||||
url = "https://www.youtube.com/watch?v=test123"
|
url = "https://www.youtube.com/watch?v=test123"
|
||||||
user_id = 1
|
user_id = 1
|
||||||
|
|
||||||
|
# Mock user for user_name retrieval
|
||||||
|
mock_user = Mock()
|
||||||
|
mock_user.name = "Test User"
|
||||||
|
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
|
||||||
|
|
||||||
# Mock repository call - creation always succeeds now
|
# Mock repository call - creation always succeeds now
|
||||||
mock_extraction = Extraction(
|
mock_extraction = Extraction(
|
||||||
id=2,
|
id=2,
|
||||||
@@ -146,6 +157,7 @@ class TestExtractionService:
|
|||||||
assert result["id"] == 2
|
assert result["id"] == 2
|
||||||
assert result["url"] == url
|
assert result["url"] == url
|
||||||
assert result["status"] == "pending"
|
assert result["status"] == "pending"
|
||||||
|
assert result["user_name"] == "Test User"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_extraction_any_url(self, extraction_service) -> None:
|
async def test_create_extraction_any_url(self, extraction_service) -> None:
|
||||||
@@ -153,6 +165,11 @@ class TestExtractionService:
|
|||||||
url = "https://invalid.url"
|
url = "https://invalid.url"
|
||||||
user_id = 1
|
user_id = 1
|
||||||
|
|
||||||
|
# Mock user for user_name retrieval
|
||||||
|
mock_user = Mock()
|
||||||
|
mock_user.name = "Test User"
|
||||||
|
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
|
||||||
|
|
||||||
# Mock repository call - even invalid URLs are accepted during creation
|
# Mock repository call - even invalid URLs are accepted during creation
|
||||||
mock_extraction = Extraction(
|
mock_extraction = Extraction(
|
||||||
id=3,
|
id=3,
|
||||||
@@ -172,6 +189,7 @@ class TestExtractionService:
|
|||||||
assert result["id"] == 3
|
assert result["id"] == 3
|
||||||
assert result["url"] == url
|
assert result["url"] == url
|
||||||
assert result["status"] == "pending"
|
assert result["status"] == "pending"
|
||||||
|
assert result["user_name"] == "Test User"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_extraction_with_service_detection(
|
async def test_process_extraction_with_service_detection(
|
||||||
@@ -408,9 +426,16 @@ class TestExtractionService:
|
|||||||
sound_id=42,
|
sound_id=42,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mock user for user_name retrieval
|
||||||
|
mock_user = Mock()
|
||||||
|
mock_user.name = "Test User"
|
||||||
|
|
||||||
extraction_service.extraction_repo.get_by_id = AsyncMock(
|
extraction_service.extraction_repo.get_by_id = AsyncMock(
|
||||||
return_value=extraction,
|
return_value=extraction,
|
||||||
)
|
)
|
||||||
|
extraction_service.user_repo.get_by_id = AsyncMock(
|
||||||
|
return_value=mock_user,
|
||||||
|
)
|
||||||
|
|
||||||
result = await extraction_service.get_extraction_by_id(1)
|
result = await extraction_service.get_extraction_by_id(1)
|
||||||
|
|
||||||
@@ -421,6 +446,7 @@ class TestExtractionService:
|
|||||||
assert result["title"] == "Test Video"
|
assert result["title"] == "Test Video"
|
||||||
assert result["status"] == "completed"
|
assert result["status"] == "completed"
|
||||||
assert result["sound_id"] == 42
|
assert result["sound_id"] == 42
|
||||||
|
assert result["user_name"] == "Test User"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_extraction_by_id_not_found(self, extraction_service) -> None:
|
async def test_get_extraction_by_id_not_found(self, extraction_service) -> None:
|
||||||
@@ -434,52 +460,70 @@ class TestExtractionService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_user_extractions(self, extraction_service) -> None:
|
async def test_get_user_extractions(self, extraction_service) -> None:
|
||||||
"""Test getting user extractions."""
|
"""Test getting user extractions."""
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
user = User(id=1, name="Test User", email="test@example.com")
|
||||||
extractions = [
|
extractions = [
|
||||||
Extraction(
|
(
|
||||||
id=1,
|
Extraction(
|
||||||
service="youtube",
|
id=1,
|
||||||
service_id="test123",
|
service="youtube",
|
||||||
url="https://www.youtube.com/watch?v=test123",
|
service_id="test123",
|
||||||
user_id=1,
|
url="https://www.youtube.com/watch?v=test123",
|
||||||
title="Test Video 1",
|
user_id=1,
|
||||||
status="completed",
|
title="Test Video 1",
|
||||||
sound_id=42,
|
status="completed",
|
||||||
|
sound_id=42,
|
||||||
|
),
|
||||||
|
user,
|
||||||
),
|
),
|
||||||
Extraction(
|
(
|
||||||
id=2,
|
Extraction(
|
||||||
service="youtube",
|
id=2,
|
||||||
service_id="test456",
|
service="youtube",
|
||||||
url="https://www.youtube.com/watch?v=test456",
|
service_id="test456",
|
||||||
user_id=1,
|
url="https://www.youtube.com/watch?v=test456",
|
||||||
title="Test Video 2",
|
user_id=1,
|
||||||
status="pending",
|
title="Test Video 2",
|
||||||
|
status="pending",
|
||||||
|
),
|
||||||
|
user,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
extraction_service.extraction_repo.get_by_user = AsyncMock(
|
extraction_service.extraction_repo.get_user_extractions_filtered = AsyncMock(
|
||||||
return_value=extractions,
|
return_value=(extractions, 2),
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await extraction_service.get_user_extractions(1)
|
result = await extraction_service.get_user_extractions(1)
|
||||||
|
|
||||||
assert len(result) == 2
|
assert result["total"] == 2
|
||||||
assert result[0]["id"] == 1
|
assert len(result["extractions"]) == 2
|
||||||
assert result[0]["title"] == "Test Video 1"
|
assert result["extractions"][0]["id"] == 1
|
||||||
assert result[1]["id"] == 2
|
assert result["extractions"][0]["title"] == "Test Video 1"
|
||||||
assert result[1]["title"] == "Test Video 2"
|
assert result["extractions"][0]["user_name"] == "Test User"
|
||||||
|
assert result["extractions"][1]["id"] == 2
|
||||||
|
assert result["extractions"][1]["title"] == "Test Video 2"
|
||||||
|
assert result["extractions"][1]["user_name"] == "Test User"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_pending_extractions(self, extraction_service) -> None:
|
async def test_get_pending_extractions(self, extraction_service) -> None:
|
||||||
"""Test getting pending extractions."""
|
"""Test getting pending extractions."""
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
user = User(id=1, name="Test User", email="test@example.com")
|
||||||
pending_extractions = [
|
pending_extractions = [
|
||||||
Extraction(
|
(
|
||||||
id=1,
|
Extraction(
|
||||||
service="youtube",
|
id=1,
|
||||||
service_id="test123",
|
service="youtube",
|
||||||
url="https://www.youtube.com/watch?v=test123",
|
service_id="test123",
|
||||||
user_id=1,
|
url="https://www.youtube.com/watch?v=test123",
|
||||||
title="Pending Video",
|
user_id=1,
|
||||||
status="pending",
|
title="Pending Video",
|
||||||
|
status="pending",
|
||||||
|
),
|
||||||
|
user,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -492,3 +536,4 @@ class TestExtractionService:
|
|||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert result[0]["id"] == 1
|
assert result[0]["id"] == 1
|
||||||
assert result[0]["status"] == "pending"
|
assert result[0]["status"] == "pending"
|
||||||
|
assert result[0]["user_name"] == "Test User"
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_scheduler(self, scheduler_service) -> None:
|
async def test_start_scheduler(self, scheduler_service) -> None:
|
||||||
"""Test starting the scheduler service."""
|
"""Test starting the scheduler service."""
|
||||||
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
|
with (
|
||||||
patch.object(scheduler_service.scheduler, "start") as mock_start:
|
patch.object(scheduler_service.scheduler, "add_job") as mock_add_job,
|
||||||
|
patch.object(scheduler_service.scheduler, "start") as mock_start,
|
||||||
|
):
|
||||||
await scheduler_service.start()
|
await scheduler_service.start()
|
||||||
|
|
||||||
# Verify job was added
|
# Verify job was added
|
||||||
@@ -61,7 +62,9 @@ class TestSchedulerService:
|
|||||||
"total_credits_added": 500,
|
"total_credits_added": 500,
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
with patch.object(
|
||||||
|
scheduler_service.credit_service, "recharge_all_users_credits",
|
||||||
|
) as mock_recharge:
|
||||||
mock_recharge.return_value = mock_stats
|
mock_recharge.return_value = mock_stats
|
||||||
|
|
||||||
await scheduler_service._daily_credit_recharge()
|
await scheduler_service._daily_credit_recharge()
|
||||||
@@ -71,7 +74,9 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
||||||
"""Test daily credit recharge task with failure."""
|
"""Test daily credit recharge task with failure."""
|
||||||
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
with patch.object(
|
||||||
|
scheduler_service.credit_service, "recharge_all_users_credits",
|
||||||
|
) as mock_recharge:
|
||||||
mock_recharge.side_effect = Exception("Database error")
|
mock_recharge.side_effect = Exception("Database error")
|
||||||
|
|
||||||
# Should not raise exception, just log it
|
# Should not raise exception, just log it
|
||||||
|
|||||||
Reference in New Issue
Block a user