Refactor user endpoint tests to include pagination and response structure validation

- Updated tests for listing users to validate pagination and response format.
- Changed mock return values to include total count and pagination details.
- Refactored user creation mocks for clarity and consistency.
- Enhanced assertions to check for presence of pagination fields in responses.
- Adjusted test cases for user retrieval and updates to ensure proper handling of user data.
- Improved readability by restructuring mock definitions and assertions across various test files.
This commit is contained in:
JSC
2025-08-17 12:36:52 +02:00
parent e6f796a3c9
commit 6b55ff0e81
35 changed files with 863 additions and 503 deletions

View File

@@ -10,7 +10,7 @@ from app.core.dependencies import get_admin_user
from app.models.plan import Plan
from app.models.user import User
from app.repositories.plan import PlanRepository
from app.repositories.user import UserRepository, UserSortField, SortOrder, UserStatus
from app.repositories.user import SortOrder, UserRepository, UserSortField, UserStatus
from app.schemas.auth import UserResponse
from app.schemas.user import UserUpdate
@@ -36,21 +36,27 @@ def _user_to_response(user: User) -> UserResponse:
"name": user.plan.name,
"max_credits": user.plan.max_credits,
"features": [], # Add features if needed
} if user.plan else {},
}
if user.plan
else {},
created_at=user.created_at,
updated_at=user.updated_at,
)
@router.get("/")
async def list_users(
async def list_users( # noqa: PLR0913
session: Annotated[AsyncSession, Depends(get_db)],
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
search: Annotated[str | None, Query(description="Search in name or email")] = None,
sort_by: Annotated[UserSortField, Query(description="Sort by field")] = UserSortField.NAME,
sort_by: Annotated[
UserSortField, Query(description="Sort by field"),
] = UserSortField.NAME,
sort_order: Annotated[SortOrder, Query(description="Sort order")] = SortOrder.ASC,
status_filter: Annotated[UserStatus, Query(description="Filter by status")] = UserStatus.ALL,
status_filter: Annotated[
UserStatus, Query(description="Filter by status"),
] = UserStatus.ALL,
) -> dict[str, Any]:
"""Get all users with pagination, search, and filters (admin only)."""
user_repo = UserRepository(session)

View File

@@ -464,7 +464,8 @@ async def update_profile(
"""Update the current user's profile."""
try:
updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True),
current_user,
request.model_dump(exclude_unset=True),
)
return await auth_service.user_to_response(updated_user)
except Exception as e:
@@ -486,7 +487,9 @@ async def change_password(
user_email = current_user.email
try:
await auth_service.change_user_password(
current_user, request.current_password, request.new_password,
current_user,
request.current_password,
request.new_password,
)
except ValueError as e:
raise HTTPException(
@@ -513,11 +516,13 @@ async def get_user_providers(
# Add password provider if user has password
if current_user.password_hash:
providers.append({
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
})
providers.append(
{
"provider": "password",
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
},
)
# Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
@@ -528,10 +533,12 @@ async def get_user_providers(
elif oauth.provider == "google":
display_name = "Google"
providers.append({
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
})
providers.append(
{
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
},
)
return providers

View File

@@ -34,7 +34,8 @@ async def get_top_sounds(
_current_user: Annotated[User, Depends(get_current_user)],
dashboard_service: Annotated[DashboardService, Depends(get_dashboard_service)],
sound_type: Annotated[
str, Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
str,
Query(description="Sound type filter (SDB, TTS, EXT, or 'all')"),
],
period: Annotated[
str,
@@ -43,7 +44,8 @@ async def get_top_sounds(
),
] = "all_time",
limit: Annotated[
int, Query(description="Number of top sounds to return", ge=1, le=100),
int,
Query(description="Number of top sounds to return", ge=1, le=100),
] = 10,
) -> list[dict[str, Any]]:
"""Get top sounds by play count for a specific period."""

View File

@@ -60,68 +60,13 @@ async def create_extraction(
}
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions(
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result
@router.get("/user")
async def get_user_extractions(
async def get_user_extractions( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[str | None, Query(description="Search in title, URL, or service")] = None,
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
@@ -153,3 +98,62 @@ async def get_user_extractions(
) from e
else:
return result
@router.get("/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
else:
return extraction_info
@router.get("/")
async def get_all_extractions( # noqa: PLR0913
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
search: Annotated[
str | None, Query(description="Search in title, URL, or service"),
] = None,
sort_by: Annotated[str, Query(description="Sort by field")] = "created_at",
sort_order: Annotated[str, Query(description="Sort order (asc/desc)")] = "desc",
status_filter: Annotated[str | None, Query(description="Filter by status")] = None,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
) -> dict:
"""Get all extractions with optional filtering, search, and sorting."""
try:
result = await extraction_service.get_all_extractions(
search=search,
sort_by=sort_by,
sort_order=sort_order,
status_filter=status_filter,
page=page,
limit=limit,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e
else:
return result

View File

@@ -4,6 +4,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from app.core.database import get_session_factory
from app.core.dependencies import get_current_active_user
from app.models.user import User
from app.schemas.common import MessageResponse
@@ -19,12 +20,10 @@ router = APIRouter(prefix="/favorites", tags=["favorites"])
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory())
@router.get("/", response_model=FavoritesListResponse)
@router.get("/")
async def get_user_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -33,12 +32,14 @@ async def get_user_favorites(
) -> FavoritesListResponse:
"""Get all favorites for the current user."""
favorites = await favorite_service.get_user_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/sounds", response_model=FavoritesListResponse)
@router.get("/sounds")
async def get_user_sound_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -47,12 +48,14 @@ async def get_user_sound_favorites(
) -> FavoritesListResponse:
"""Get sound favorites for the current user."""
favorites = await favorite_service.get_user_sound_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/playlists", response_model=FavoritesListResponse)
@router.get("/playlists")
async def get_user_playlist_favorites(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -61,12 +64,14 @@ async def get_user_playlist_favorites(
) -> FavoritesListResponse:
"""Get playlist favorites for the current user."""
favorites = await favorite_service.get_user_playlist_favorites(
current_user.id, limit, offset,
current_user.id,
limit,
offset,
)
return FavoritesListResponse(favorites=favorites)
@router.get("/counts", response_model=FavoriteCountsResponse)
@router.get("/counts")
async def get_favorite_counts(
current_user: Annotated[User, Depends(get_current_active_user)],
favorite_service: Annotated[FavoriteService, Depends(get_favorite_service)],
@@ -76,7 +81,7 @@ async def get_favorite_counts(
return FavoriteCountsResponse(**counts)
@router.post("/sounds/{sound_id}", response_model=FavoriteResponse)
@router.post("/sounds/{sound_id}")
async def add_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -103,7 +108,7 @@ async def add_sound_favorite(
) from e
@router.post("/playlists/{playlist_id}", response_model=FavoriteResponse)
@router.post("/playlists/{playlist_id}")
async def add_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -112,7 +117,8 @@ async def add_playlist_favorite(
"""Add a playlist to favorites."""
try:
favorite = await favorite_service.add_playlist_favorite(
current_user.id, playlist_id,
current_user.id,
playlist_id,
)
return FavoriteResponse.model_validate(favorite)
except ValueError as e:
@@ -132,7 +138,7 @@ async def add_playlist_favorite(
) from e
@router.delete("/sounds/{sound_id}", response_model=MessageResponse)
@router.delete("/sounds/{sound_id}")
async def remove_sound_favorite(
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -149,7 +155,7 @@ async def remove_sound_favorite(
) from e
@router.delete("/playlists/{playlist_id}", response_model=MessageResponse)
@router.delete("/playlists/{playlist_id}")
async def remove_playlist_favorite(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user)],
@@ -185,6 +191,7 @@ async def check_playlist_favorited(
) -> dict[str, bool]:
"""Check if a playlist is favorited by the current user."""
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
current_user.id,
playlist_id,
)
return {"is_favorited": is_favorited}

View File

@@ -5,7 +5,7 @@ from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.database import get_db, get_session_factory
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
@@ -34,7 +34,6 @@ async def get_playlist_service(
def get_favorite_service() -> FavoriteService:
"""Get the favorite service."""
from app.core.database import get_session_factory
return FavoriteService(get_session_factory())
@@ -57,7 +56,7 @@ async def get_all_playlists( # noqa: PLR0913
] = SortOrder.ASC,
page: Annotated[int, Query(description="Page number", ge=1)] = 1,
limit: Annotated[int, Query(description="Items per page", ge=1, le=100)] = 50,
favorites_only: Annotated[
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited playlists"),
] = False,
@@ -78,15 +77,26 @@ async def get_all_playlists( # noqa: PLR0913
# Convert to PlaylistResponse with favorite indicators
playlist_responses = []
for playlist_dict in result["playlists"]:
# The playlist service returns dict, need to create playlist object-like structure
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist_dict["id"])
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_dict["id"])
# The playlist service returns dict, need to create playlist object structure
playlist_id = playlist_dict["id"]
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist_id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist_id)
# Create a PlaylistResponse-like dict with proper datetime conversion
playlist_response = {
**playlist_dict,
"created_at": playlist_dict["created_at"].isoformat() if playlist_dict["created_at"] else None,
"updated_at": playlist_dict["updated_at"].isoformat() if playlist_dict["updated_at"] else None,
"created_at": (
playlist_dict["created_at"].isoformat()
if playlist_dict["created_at"]
else None
),
"updated_at": (
playlist_dict["updated_at"].isoformat()
if playlist_dict["updated_at"]
else None
),
"is_favorited": is_favorited,
"favorite_count": favorite_count,
}
@@ -113,9 +123,13 @@ async def get_user_playlists(
# Add favorite indicators for each playlist
playlist_responses = []
for playlist in playlists:
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
playlist_response = PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
playlist_response = PlaylistResponse.from_playlist(
playlist, is_favorited, favorite_count,
)
playlist_responses.append(playlist_response)
return playlist_responses
@@ -129,7 +143,9 @@ async def get_main_playlist(
) -> PlaylistResponse:
"""Get the global main playlist."""
playlist = await playlist_service.get_main_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -142,7 +158,9 @@ async def get_current_playlist(
) -> PlaylistResponse:
"""Get the global current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist()
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)
@@ -172,7 +190,9 @@ async def get_playlist(
) -> PlaylistResponse:
"""Get a specific playlist."""
playlist = await playlist_service.get_playlist_by_id(playlist_id)
is_favorited = await favorite_service.is_playlist_favorited(current_user.id, playlist.id)
is_favorited = await favorite_service.is_playlist_favorited(
current_user.id, playlist.id,
)
favorite_count = await favorite_service.get_playlist_favorite_count(playlist.id)
return PlaylistResponse.from_playlist(playlist, is_favorited, favorite_count)

View File

@@ -40,7 +40,7 @@ async def get_sound_repository(
return SoundRepository(session)
@router.get("/", response_model=SoundsListResponse)
@router.get("/")
async def get_sounds( # noqa: PLR0913
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
sound_repo: Annotated[SoundRepository, Depends(get_sound_repository)],
@@ -69,7 +69,7 @@ async def get_sounds( # noqa: PLR0913
int,
Query(description="Number of results to skip", ge=0),
] = 0,
favorites_only: Annotated[
favorites_only: Annotated[ # noqa: FBT002
bool,
Query(description="Show only favorited sounds"),
] = False,
@@ -90,9 +90,13 @@ async def get_sounds( # noqa: PLR0913
# Add favorite indicators for each sound
sound_responses = []
for sound in sounds:
is_favorited = await favorite_service.is_sound_favorited(current_user.id, sound.id)
is_favorited = await favorite_service.is_sound_favorited(
current_user.id, sound.id,
)
favorite_count = await favorite_service.get_sound_favorite_count(sound.id)
sound_response = SoundResponse.from_sound(sound, is_favorited, favorite_count)
sound_response = SoundResponse.from_sound(
sound, is_favorited, favorite_count,
)
sound_responses.append(sound_response)
except Exception as e:

View File

@@ -20,7 +20,7 @@ class Settings(BaseSettings):
# Production URLs (for reverse proxy deployment)
FRONTEND_URL: str = "http://localhost:8001" # Frontend URL in production
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
BACKEND_URL: str = "http://localhost:8000" # Backend base URL
# CORS Configuration
CORS_ORIGINS: list[str] = ["http://localhost:8001"] # Allowed origins for CORS

View File

@@ -20,7 +20,9 @@ class BaseModel(SQLModel):
# SQLAlchemy event listener to automatically update updated_at timestamp
@event.listens_for(BaseModel, "before_update", propagate=True)
def update_timestamp(
mapper: Mapper[Any], connection: Connection, target: BaseModel, # noqa: ARG001
mapper: Mapper[Any], # noqa: ARG001
connection: Connection, # noqa: ARG001
target: BaseModel,
) -> None:
"""Automatically set updated_at timestamp before update operations."""
target.updated_at = datetime.now(UTC)

View File

@@ -35,5 +35,3 @@ class PlaylistSound(BaseModel, table=True):
# relationships
playlist: "Playlist" = Relationship(back_populates="playlist_sounds")
sound: "Sound" = Relationship(back_populates="playlist_sounds")

View File

@@ -58,7 +58,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
)
return list(result.all())
async def get_user_extractions_filtered(
async def get_user_extractions_filtered( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
@@ -92,7 +92,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery()
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()
@@ -109,7 +109,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
return list(result.all()), total_count
async def get_all_extractions_filtered(
async def get_all_extractions_filtered( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
@@ -138,7 +138,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
# Get total count before pagination
count_query = select(func.count()).select_from(
base_query.subquery()
base_query.subquery(),
)
count_result = await self.session.exec(count_query)
total_count = count_result.one()

View File

@@ -118,7 +118,9 @@ class FavoriteRepository(BaseRepository[Favorite]):
raise
async def get_by_user_and_sound(
self, user_id: int, sound_id: int,
self,
user_id: int,
sound_id: int,
) -> Favorite | None:
"""Get a favorite by user and sound.
@@ -138,12 +140,16 @@ class FavoriteRepository(BaseRepository[Favorite]):
return result.first()
except Exception:
logger.exception(
"Failed to get favorite for user %s and sound %s", user_id, sound_id,
"Failed to get favorite for user %s and sound %s",
user_id,
sound_id,
)
raise
async def get_by_user_and_playlist(
self, user_id: int, playlist_id: int,
self,
user_id: int,
playlist_id: int,
) -> Favorite | None:
"""Get a favorite by user and playlist.

View File

@@ -57,7 +57,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
# management
except Exception:
logger.exception(
"Failed to update playlist timestamp for playlist: %s", playlist_id,
"Failed to update playlist timestamp for playlist: %s",
playlist_id,
)
raise
@@ -341,7 +342,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
include_stats: bool = False, # noqa: FBT001, FBT002
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False,
favorites_only: bool = False, # noqa: FBT001, FBT002
current_user_id: int | None = None,
*,
return_count: bool = False,
@@ -395,9 +396,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = select(1).select_from(Favorite).where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())
@@ -466,9 +471,13 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Apply favorites filter
if favorites_only and current_user_id is not None:
# Use EXISTS subquery to avoid JOIN conflicts with GROUP BY
favorites_subquery = select(1).select_from(Favorite).where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
favorites_subquery = (
select(1)
.select_from(Favorite)
.where(
Favorite.user_id == current_user_id,
Favorite.playlist_id == Playlist.id,
)
)
subquery = subquery.where(favorites_subquery.exists())

View File

@@ -141,7 +141,7 @@ class SoundRepository(BaseRepository[Sound]):
sort_order: SortOrder = SortOrder.ASC,
limit: int | None = None,
offset: int = 0,
favorites_only: bool = False,
favorites_only: bool = False, # noqa: FBT001, FBT002
user_id: int | None = None,
) -> list[Sound]:
"""Search and sort sounds with optional filtering."""
@@ -189,7 +189,8 @@ class SoundRepository(BaseRepository[Sound]):
logger.exception(
(
"Failed to search and sort sounds: "
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, user_id=%s"
"query=%s, types=%s, sort_by=%s, sort_order=%s, favorites_only=%s, "
"user_id=%s"
),
search_query,
sound_types,
@@ -288,8 +289,7 @@ class SoundRepository(BaseRepository[Sound]):
# Group by sound and order by play count descending
statement = (
statement
.group_by(
statement.group_by(
Sound.id,
Sound.name,
Sound.type,

View File

@@ -1,7 +1,7 @@
"""User repository."""
from typing import Any
from enum import Enum
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import selectinload
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class UserSortField(str, Enum):
"""User sort fields."""
NAME = "name"
EMAIL = "email"
ROLE = "role"
@@ -27,12 +28,14 @@ class UserSortField(str, Enum):
class SortOrder(str, Enum):
"""Sort order."""
ASC = "asc"
DESC = "desc"
class UserStatus(str, Enum):
"""User status filter."""
ALL = "all"
ACTIVE = "active"
INACTIVE = "inactive"
@@ -64,7 +67,7 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get all users with plan")
raise
async def get_all_with_plan_paginated(
async def get_all_with_plan_paginated( # noqa: PLR0913
self,
page: int = 1,
limit: int = 50,
@@ -85,10 +88,9 @@ class UserRepository(BaseRepository[User]):
# Apply search filter
if search and search.strip():
search_pattern = f"%{search.strip().lower()}%"
search_condition = (
func.lower(User.name).like(search_pattern) |
func.lower(User.email).like(search_pattern)
)
search_condition = func.lower(User.name).like(
search_pattern,
) | func.lower(User.email).like(search_pattern)
base_query = base_query.where(search_condition)
count_query = count_query.where(search_condition)
@@ -101,31 +103,18 @@ class UserRepository(BaseRepository[User]):
count_query = count_query.where(User.is_active == False) # noqa: E712
# Apply sorting
if sort_by == UserSortField.EMAIL:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.email.desc())
else:
base_query = base_query.order_by(User.email.asc())
elif sort_by == UserSortField.ROLE:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.role.desc())
else:
base_query = base_query.order_by(User.role.asc())
elif sort_by == UserSortField.CREDITS:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.credits.desc())
else:
base_query = base_query.order_by(User.credits.asc())
elif sort_by == UserSortField.CREATED_AT:
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.created_at.desc())
else:
base_query = base_query.order_by(User.created_at.asc())
else: # Default to name
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(User.name.desc())
else:
base_query = base_query.order_by(User.name.asc())
sort_column = {
UserSortField.NAME: User.name,
UserSortField.EMAIL: User.email,
UserSortField.ROLE: User.role,
UserSortField.CREDITS: User.credits,
UserSortField.CREATED_AT: User.created_at,
}.get(sort_by, User.name)
if sort_order == SortOrder.DESC:
base_query = base_query.order_by(sort_column.desc())
else:
base_query = base_query.order_by(sort_column.asc())
# Get total count
count_result = await self.session.exec(count_query)
@@ -135,11 +124,11 @@ class UserRepository(BaseRepository[User]):
paginated_query = base_query.limit(limit).offset(offset)
result = await self.session.exec(paginated_query)
users = list(result.all())
return users, total_count
except Exception:
logger.exception("Failed to get paginated users with plan")
raise
else:
return users, total_count
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
"""Get a user by ID with plan relationship loaded."""
@@ -178,7 +167,7 @@ class UserRepository(BaseRepository[User]):
logger.exception("Failed to get user by API token")
raise
async def create(self, user_data: dict[str, Any]) -> User:
async def create(self, entity_data: dict[str, Any]) -> User:
"""Create a new user with plan assignment and first user admin logic."""
def _raise_plan_not_found() -> None:
@@ -194,7 +183,7 @@ class UserRepository(BaseRepository[User]):
if is_first_user:
# First user gets admin role and pro plan
plan_statement = select(Plan).where(Plan.code == "pro")
user_data["role"] = "admin"
entity_data["role"] = "admin"
logger.info("Creating first user with admin role and pro plan")
else:
# Regular users get free plan
@@ -210,11 +199,11 @@ class UserRepository(BaseRepository[User]):
assert default_plan is not None # noqa: S101
# Set plan_id and default credits
user_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits
entity_data["plan_id"] = default_plan.id
entity_data["credits"] = default_plan.credits
# Use BaseRepository's create method
return await super().create(user_data)
return await super().create(entity_data)
except Exception:
logger.exception("Failed to create user")
raise

View File

@@ -85,7 +85,8 @@ class ChangePasswordRequest(BaseModel):
"""Schema for password change request."""
current_password: str | None = Field(
None, description="Current password (required if user has existing password)",
None,
description="Current password (required if user has existing password)",
)
new_password: str = Field(
...,
@@ -98,5 +99,8 @@ class UpdateProfileRequest(BaseModel):
"""Schema for profile update request."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User display name",
None,
min_length=1,
max_length=100,
description="User display name",
)

View File

@@ -11,10 +11,12 @@ class FavoriteResponse(BaseModel):
id: int = Field(description="Favorite ID")
user_id: int = Field(description="User ID")
sound_id: int | None = Field(
description="Sound ID if this is a sound favorite", default=None,
description="Sound ID if this is a sound favorite",
default=None,
)
playlist_id: int | None = Field(
description="Playlist ID if this is a playlist favorite", default=None,
description="Playlist ID if this is a playlist favorite",
default=None,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")

View File

@@ -39,7 +39,12 @@ class PlaylistResponse(BaseModel):
updated_at: str | None
@classmethod
def from_playlist(cls, playlist: Playlist, is_favorited: bool = False, favorite_count: int = 0) -> "PlaylistResponse":
def from_playlist(
cls,
playlist: Playlist,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "PlaylistResponse":
"""Create response from playlist model.
Args:

View File

@@ -18,16 +18,20 @@ class SoundResponse(BaseModel):
size: int = Field(description="File size in bytes")
hash: str = Field(description="File hash")
normalized_filename: str | None = Field(
description="Normalized filename", default=None,
description="Normalized filename",
default=None,
)
normalized_duration: int | None = Field(
description="Normalized duration in milliseconds", default=None,
description="Normalized duration in milliseconds",
default=None,
)
normalized_size: int | None = Field(
description="Normalized file size in bytes", default=None,
description="Normalized file size in bytes",
default=None,
)
normalized_hash: str | None = Field(
description="Normalized file hash", default=None,
description="Normalized file hash",
default=None,
)
thumbnail: str | None = Field(description="Thumbnail filename", default=None)
play_count: int = Field(description="Number of times played")
@@ -35,10 +39,12 @@ class SoundResponse(BaseModel):
is_music: bool = Field(description="Whether the sound is music")
is_deletable: bool = Field(description="Whether the sound can be deleted")
is_favorited: bool = Field(
description="Whether the sound is favorited by the current user", default=False,
description="Whether the sound is favorited by the current user",
default=False,
)
favorite_count: int = Field(
description="Number of users who favorited this sound", default=0,
description="Number of users who favorited this sound",
default=0,
)
created_at: datetime = Field(description="Creation timestamp")
updated_at: datetime = Field(description="Last update timestamp")
@@ -50,7 +56,10 @@ class SoundResponse(BaseModel):
@classmethod
def from_sound(
cls, sound: Sound, is_favorited: bool = False, favorite_count: int = 0,
cls,
sound: Sound,
is_favorited: bool = False, # noqa: FBT001, FBT002
favorite_count: int = 0,
) -> "SoundResponse":
"""Create a SoundResponse from a Sound model.
@@ -64,7 +73,8 @@ class SoundResponse(BaseModel):
"""
if sound.id is None:
raise ValueError("Sound ID cannot be None")
msg = "Sound ID cannot be None"
raise ValueError(msg)
return cls(
id=sound.id,

View File

@@ -7,7 +7,10 @@ class UserUpdate(BaseModel):
"""Schema for updating a user."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User full name",
None,
min_length=1,
max_length=100,
description="User full name",
)
plan_id: int | None = Field(None, description="User plan ID")
credits: int | None = Field(None, ge=0, description="User credits")

View File

@@ -454,7 +454,10 @@ class AuthService:
return user
async def change_user_password(
self, user: User, current_password: str | None, new_password: str,
self,
user: User,
current_password: str | None,
new_password: str,
) -> None:
"""Change user's password."""
# Store user email before any operations to avoid session detachment issues
@@ -484,8 +487,11 @@ class AuthService:
self.session.add(user)
await self.session.commit()
logger.info("Password %s successfully for user: %s",
"changed" if had_existing_password else "set", user_email)
logger.info(
"Password %s successfully for user: %s",
"changed" if had_existing_password else "set",
user_email,
)
async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information."""

View File

@@ -72,9 +72,7 @@ class DashboardService:
"play_count": sound["play_count"],
"duration": sound["duration"],
"created_at": (
sound["created_at"].isoformat()
if sound["created_at"]
else None
sound["created_at"].isoformat() if sound["created_at"] else None
),
}
for sound in top_sounds

View File

@@ -532,7 +532,8 @@ class ExtractionService:
"""Add the sound to the user's main playlist."""
try:
await self.playlist_service._add_sound_to_main_playlist_internal( # noqa: SLF001
sound_id, user_id,
sound_id,
user_id,
)
logger.info(
"Added sound %d to main playlist for user %d",
@@ -554,6 +555,10 @@ class ExtractionService:
if not extraction:
return None
# Get user information
user = await self.user_repo.get_by_id(extraction.user_id)
user_name = user.name if user else None
return {
"id": extraction.id or 0, # Should never be None for existing extraction
"url": extraction.url,
@@ -564,11 +569,12 @@ class ExtractionService:
"error": extraction.error,
"sound_id": extraction.sound_id,
"user_id": extraction.user_id,
"user_name": user_name,
"created_at": extraction.created_at.isoformat(),
"updated_at": extraction.updated_at.isoformat(),
}
async def get_user_extractions(
async def get_user_extractions( # noqa: PLR0913
self,
user_id: int,
search: str | None = None,
@@ -580,7 +586,10 @@ class ExtractionService:
) -> PaginatedExtractionsResponse:
"""Get all extractions for a user with filtering, search, and sorting."""
offset = (page - 1) * limit
extraction_user_tuples, total_count = await self.extraction_repo.get_user_extractions_filtered(
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_user_extractions_filtered(
user_id=user_id,
search=search,
sort_by=sort_by,
@@ -619,7 +628,7 @@ class ExtractionService:
"total_pages": total_pages,
}
async def get_all_extractions(
async def get_all_extractions( # noqa: PLR0913
self,
search: str | None = None,
sort_by: str = "created_at",
@@ -630,7 +639,10 @@ class ExtractionService:
) -> PaginatedExtractionsResponse:
"""Get all extractions with filtering, search, and sorting."""
offset = (page - 1) * limit
extraction_user_tuples, total_count = await self.extraction_repo.get_all_extractions_filtered(
(
extraction_user_tuples,
total_count,
) = await self.extraction_repo.get_all_extractions_filtered(
search=search,
sort_by=sort_by,
sort_order=sort_order,

View File

@@ -49,12 +49,14 @@ class FavoriteService:
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
raise ValueError(f"User with ID {user_id} not found")
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify sound exists
sound = await sound_repo.get_by_id(sound_id)
if not sound:
raise ValueError(f"Sound with ID {sound_id} not found")
msg = f"Sound with ID {sound_id} not found"
raise ValueError(msg)
# Get data for the event immediately after loading
sound_name = sound.name
@@ -63,9 +65,8 @@ class FavoriteService:
# Check if already favorited
existing = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if existing:
raise ValueError(
f"Sound {sound_id} is already favorited by user {user_id}",
)
msg = f"Sound {sound_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
@@ -120,12 +121,14 @@ class FavoriteService:
# Verify user exists
user = await user_repo.get_by_id(user_id)
if not user:
raise ValueError(f"User with ID {user_id} not found")
msg = f"User with ID {user_id} not found"
raise ValueError(msg)
# Verify playlist exists
playlist = await playlist_repo.get_by_id(playlist_id)
if not playlist:
raise ValueError(f"Playlist with ID {playlist_id} not found")
msg = f"Playlist with ID {playlist_id} not found"
raise ValueError(msg)
# Check if already favorited
existing = await favorite_repo.get_by_user_and_playlist(
@@ -133,9 +136,8 @@ class FavoriteService:
playlist_id,
)
if existing:
raise ValueError(
f"Playlist {playlist_id} is already favorited by user {user_id}",
)
msg = f"Playlist {playlist_id} is already favorited by user {user_id}"
raise ValueError(msg)
# Create favorite
favorite_data = {
@@ -163,7 +165,8 @@ class FavoriteService:
favorite = await favorite_repo.get_by_user_and_sound(user_id, sound_id)
if not favorite:
raise ValueError(f"Sound {sound_id} is not favorited by user {user_id}")
msg = f"Sound {sound_id} is not favorited by user {user_id}"
raise ValueError(msg)
# Get user and sound info before deletion for the event
user_repo = UserRepository(session)
@@ -192,7 +195,8 @@ class FavoriteService:
}
await socket_manager.broadcast_to_all("sound_favorited", event_data)
logger.info(
"Broadcasted sound_favorited event for sound %s removal", sound_id,
"Broadcasted sound_favorited event for sound %s removal",
sound_id,
)
except Exception:
logger.exception(
@@ -219,9 +223,8 @@ class FavoriteService:
playlist_id,
)
if not favorite:
raise ValueError(
f"Playlist {playlist_id} is not favorited by user {user_id}",
)
msg = f"Playlist {playlist_id} is not favorited by user {user_id}"
raise ValueError(msg)
await favorite_repo.delete(favorite)
logger.info(

View File

@@ -16,6 +16,7 @@ logger = get_logger(__name__)
class PaginatedPlaylistsResponse(TypedDict):
"""Response type for paginated playlists."""
playlists: list[dict]
total: int
page: int
@@ -468,7 +469,9 @@ class PlaylistService:
}
async def add_sound_to_main_playlist(
self, sound_id: int, user_id: int, # noqa: ARG002
self,
sound_id: int, # noqa: ARG002
user_id: int, # noqa: ARG002
) -> None:
"""Add a sound to the global main playlist."""
raise HTTPException(
@@ -477,7 +480,9 @@ class PlaylistService:
)
async def _add_sound_to_main_playlist_internal(
self, sound_id: int, user_id: int,
self,
sound_id: int,
user_id: int,
) -> None:
"""Add sound to main playlist bypassing restrictions.

View File

@@ -21,8 +21,6 @@ def mock_plan_repository():
return Mock()
@pytest.fixture
def regular_user():
"""Create regular user for testing."""
@@ -60,52 +58,78 @@ class TestAdminUserEndpoints:
test_plan: Plan,
) -> None:
"""Test listing users successfully."""
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all:
with patch(
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
) as mock_get_all:
# Create mock user objects that don't trigger database saves
mock_admin = type("User", (), {
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_admin = type(
"User",
(),
{
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_regular = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_regular = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_all.return_value = [mock_admin, mock_regular]
# Mock returns tuple (users, total_count)
mock_get_all.return_value = ([mock_admin, mock_regular], 2)
response = await authenticated_admin_client.get("/api/v1/admin/users/")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["email"] == "admin@example.com"
assert data[1]["email"] == "user@example.com"
mock_get_all.assert_called_once_with(limit=100, offset=0)
assert "users" in data
assert "total" in data
assert "page" in data
assert "limit" in data
assert "total_pages" in data
assert len(data["users"]) == 2
assert data["users"][0]["email"] == "admin@example.com"
assert data["users"][1]["email"] == "user@example.com"
assert data["total"] == 2
assert data["page"] == 1
assert data["limit"] == 50
@pytest.mark.asyncio
async def test_list_users_with_pagination(
@@ -115,29 +139,55 @@ class TestAdminUserEndpoints:
test_plan: Plan,
) -> None:
"""Test listing users with pagination."""
with patch("app.repositories.user.UserRepository.get_all_with_plan") as mock_get_all:
mock_admin = type("User", (), {
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_get_all.return_value = [mock_admin]
from app.repositories.user import SortOrder, UserSortField, UserStatus
response = await authenticated_admin_client.get("/api/v1/admin/users/?limit=10&offset=5")
with patch(
"app.repositories.user.UserRepository.get_all_with_plan_paginated",
) as mock_get_all:
mock_admin = type(
"User",
(),
{
"id": admin_user.id,
"email": admin_user.email,
"name": admin_user.name,
"picture": None,
"role": admin_user.role,
"credits": admin_user.credits,
"is_active": admin_user.is_active,
"created_at": admin_user.created_at,
"updated_at": admin_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
# Mock returns tuple (users, total_count)
mock_get_all.return_value = ([mock_admin], 1)
response = await authenticated_admin_client.get(
"/api/v1/admin/users/?page=2&limit=10",
)
assert response.status_code == 200
mock_get_all.assert_called_once_with(limit=10, offset=5)
data = response.json()
assert "users" in data
assert data["page"] == 2
assert data["limit"] == 10
mock_get_all.assert_called_once_with(
page=2,
limit=10,
search=None,
sort_by=UserSortField.NAME,
sort_order=SortOrder.ASC,
status_filter=UserStatus.ALL,
)
@pytest.mark.asyncio
async def test_list_users_unauthenticated(self, client: AsyncClient) -> None:
@@ -153,7 +203,9 @@ class TestAdminUserEndpoints:
regular_user: User,
) -> None:
"""Test listing users as non-admin user."""
with patch("app.core.dependencies.get_current_active_user", return_value=regular_user):
with patch(
"app.core.dependencies.get_current_active_user", return_value=regular_user,
):
response = await client.get("/api/v1/admin/users/")
assert response.status_code == 401
@@ -169,24 +221,34 @@ class TestAdminUserEndpoints:
"""Test getting specific user successfully."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.get("/api/v1/admin/users/2")
@@ -207,7 +269,10 @@ class TestAdminUserEndpoints:
"""Test getting non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.get("/api/v1/admin/users/999")
@@ -226,43 +291,63 @@ class TestAdminUserEndpoints:
"""Test updating user successfully."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update,
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan),
patch(
"app.repositories.plan.PlanRepository.get_by_id", return_value=test_plan,
),
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
updated_mock = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": "Updated Name",
"picture": None,
"role": regular_user.role,
"credits": 200,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
updated_mock = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": "Updated Name",
"picture": None,
"role": regular_user.role,
"credits": 200,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
mock_update.return_value = updated_mock
@@ -271,7 +356,10 @@ class TestAdminUserEndpoints:
async def mock_refresh(instance, attributes=None):
pass
with patch("sqlmodel.ext.asyncio.session.AsyncSession.refresh", side_effect=mock_refresh):
with patch(
"sqlmodel.ext.asyncio.session.AsyncSession.refresh",
side_effect=mock_refresh,
):
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2",
json={
@@ -295,7 +383,10 @@ class TestAdminUserEndpoints:
"""Test updating non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/999",
@@ -316,25 +407,35 @@ class TestAdminUserEndpoints:
"""Test updating user with invalid plan."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.plan.PlanRepository.get_by_id", return_value=None),
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": 1,
"name": "Basic",
"max_credits": 100,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": 1,
"name": "Basic",
"max_credits": 100,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2",
@@ -356,29 +457,41 @@ class TestAdminUserEndpoints:
"""Test disabling user successfully."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update,
):
mock_user = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_user = type(
"User",
(),
{
"id": regular_user.id,
"email": regular_user.email,
"name": regular_user.name,
"picture": None,
"role": regular_user.role,
"credits": regular_user.credits,
"is_active": regular_user.is_active,
"created_at": regular_user.created_at,
"updated_at": regular_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_user
mock_update.return_value = mock_user
response = await authenticated_admin_client.post("/api/v1/admin/users/2/disable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/2/disable",
)
assert response.status_code == 200
data = response.json()
@@ -393,9 +506,14 @@ class TestAdminUserEndpoints:
"""Test disabling non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.post("/api/v1/admin/users/999/disable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/999/disable",
)
assert response.status_code == 404
data = response.json()
@@ -421,29 +539,41 @@ class TestAdminUserEndpoints:
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan") as mock_get_by_id,
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
) as mock_get_by_id,
patch("app.repositories.user.UserRepository.update") as mock_update,
):
mock_disabled_user = type("User", (), {
"id": disabled_user.id,
"email": disabled_user.email,
"name": disabled_user.name,
"picture": None,
"role": disabled_user.role,
"credits": disabled_user.credits,
"is_active": disabled_user.is_active,
"created_at": disabled_user.created_at,
"updated_at": disabled_user.updated_at,
"plan": type("Plan", (), {
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})(),
})()
mock_disabled_user = type(
"User",
(),
{
"id": disabled_user.id,
"email": disabled_user.email,
"name": disabled_user.name,
"picture": None,
"role": disabled_user.role,
"credits": disabled_user.credits,
"is_active": disabled_user.is_active,
"created_at": disabled_user.created_at,
"updated_at": disabled_user.updated_at,
"plan": type(
"Plan",
(),
{
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
},
)(),
},
)()
mock_get_by_id.return_value = mock_disabled_user
mock_update.return_value = mock_disabled_user
response = await authenticated_admin_client.post("/api/v1/admin/users/3/enable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/3/enable",
)
assert response.status_code == 200
data = response.json()
@@ -458,9 +588,14 @@ class TestAdminUserEndpoints:
"""Test enabling non-existent user."""
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.user.UserRepository.get_by_id_with_plan", return_value=None),
patch(
"app.repositories.user.UserRepository.get_by_id_with_plan",
return_value=None,
),
):
response = await authenticated_admin_client.post("/api/v1/admin/users/999/enable")
response = await authenticated_admin_client.post(
"/api/v1/admin/users/999/enable",
)
assert response.status_code == 404
data = response.json()
@@ -479,9 +614,14 @@ class TestAdminUserEndpoints:
with (
patch("app.core.dependencies.get_admin_user", return_value=admin_user),
patch("app.repositories.plan.PlanRepository.get_all", return_value=[basic_plan, premium_plan]),
patch(
"app.repositories.plan.PlanRepository.get_all",
return_value=[basic_plan, premium_plan],
),
):
response = await authenticated_admin_client.get("/api/v1/admin/users/plans/list")
response = await authenticated_admin_client.get(
"/api/v1/admin/users/plans/list",
)
assert response.status_code == 200
data = response.json()

View File

@@ -488,11 +488,17 @@ class TestAuthEndpoints:
test_plan: Plan,
) -> None:
"""Test refresh token success."""
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh:
mock_refresh.return_value = type("TokenResponse", (), {
"access_token": "new_access_token",
"expires_in": 3600,
})()
with patch(
"app.services.auth.AuthService.refresh_access_token",
) as mock_refresh:
mock_refresh.return_value = type(
"TokenResponse",
(),
{
"access_token": "new_access_token",
"expires_in": 3600,
},
)()
response = await test_client.post(
"/api/v1/auth/refresh",
@@ -516,7 +522,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio
async def test_refresh_token_service_error(self, test_client: AsyncClient) -> None:
"""Test refresh token with service error."""
with patch("app.services.auth.AuthService.refresh_access_token") as mock_refresh:
with patch(
"app.services.auth.AuthService.refresh_access_token",
) as mock_refresh:
mock_refresh.side_effect = Exception("Database error")
response = await test_client.post(
@@ -528,7 +536,6 @@ class TestAuthEndpoints:
data = response.json()
assert "Token refresh failed" in data["detail"]
@pytest.mark.asyncio
async def test_exchange_oauth_token_invalid_code(
self,
@@ -554,7 +561,9 @@ class TestAuthEndpoints:
"""Test update profile success."""
with (
patch("app.services.auth.AuthService.update_user_profile") as mock_update,
patch("app.services.auth.AuthService.user_to_response") as mock_user_to_response,
patch(
"app.services.auth.AuthService.user_to_response",
) as mock_user_to_response,
):
updated_user = User(
id=test_user.id,
@@ -569,6 +578,7 @@ class TestAuthEndpoints:
# Mock the user_to_response to return UserResponse format
from app.schemas.auth import UserResponse
mock_user_to_response.return_value = UserResponse(
id=test_user.id,
email=test_user.email,
@@ -598,7 +608,9 @@ class TestAuthEndpoints:
assert data["name"] == "Updated Name"
@pytest.mark.asyncio
async def test_update_profile_unauthenticated(self, test_client: AsyncClient) -> None:
async def test_update_profile_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test update profile without authentication."""
response = await test_client.patch(
"/api/v1/auth/me",
@@ -632,7 +644,9 @@ class TestAuthEndpoints:
assert data["message"] == "Password changed successfully"
@pytest.mark.asyncio
async def test_change_password_unauthenticated(self, test_client: AsyncClient) -> None:
async def test_change_password_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test change password without authentication."""
response = await test_client.post(
"/api/v1/auth/change-password",
@@ -652,7 +666,9 @@ class TestAuthEndpoints:
auth_cookies: dict[str, str],
) -> None:
"""Test get user OAuth providers success."""
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
with patch(
"app.services.auth.AuthService.get_user_oauth_providers",
) as mock_providers:
from datetime import datetime
from app.models.user_oauth import UserOauth
@@ -699,7 +715,9 @@ class TestAuthEndpoints:
assert data[2]["display_name"] == "GitHub"
@pytest.mark.asyncio
async def test_get_user_providers_unauthenticated(self, test_client: AsyncClient) -> None:
async def test_get_user_providers_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test get user OAuth providers without authentication."""
response = await test_client.get("/api/v1/auth/user-providers")

View File

@@ -109,9 +109,15 @@ class TestPlaylistEndpoints:
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert "playlists" in data
assert "total" in data
assert "page" in data
assert "limit" in data
assert "total_pages" in data
assert len(data["playlists"]) == 2
assert data["total"] == 2
playlist_names = {p["name"] for p in data}
playlist_names = {p["name"] for p in data["playlists"]}
assert "Test Playlist" in playlist_names
assert "Main Playlist" in playlist_names

View File

@@ -9,7 +9,7 @@ from httpx import AsyncClient
from app.models.user import User
if TYPE_CHECKING:
from app.services.extraction import ExtractionInfo
from app.services.extraction import ExtractionInfo, PaginatedExtractionsResponse
class TestSoundEndpoints:
@@ -32,6 +32,7 @@ class TestSoundEndpoints:
"error": None,
"sound_id": None,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
}
@@ -111,6 +112,7 @@ class TestSoundEndpoints:
"error": None,
"sound_id": 42,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
}
@@ -154,41 +156,49 @@ class TestSoundEndpoints:
authenticated_user: User,
) -> None:
"""Test getting user extractions."""
mock_extractions: list[ExtractionInfo] = [
{
"id": 1,
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"title": "Never Gonna Give You Up",
"service": "youtube",
"service_id": "dQw4w9WgXcQ",
"status": "completed",
"error": None,
"sound_id": 42,
"user_id": authenticated_user.id,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
{
"id": 2,
"url": "https://soundcloud.com/example/track",
"title": "Example Track",
"service": "soundcloud",
"service_id": "example-track",
"status": "pending",
"error": None,
"sound_id": None,
"user_id": authenticated_user.id,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
]
mock_extractions: PaginatedExtractionsResponse = {
"extractions": [
{
"id": 1,
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
"title": "Never Gonna Give You Up",
"service": "youtube",
"service_id": "dQw4w9WgXcQ",
"status": "completed",
"error": None,
"sound_id": 42,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
{
"id": 2,
"url": "https://soundcloud.com/example/track",
"title": "Example Track",
"service": "soundcloud",
"service_id": "example-track",
"status": "pending",
"error": None,
"sound_id": None,
"user_id": authenticated_user.id,
"user_name": authenticated_user.name,
"created_at": "2025-08-03T12:00:00Z",
"updated_at": "2025-08-03T12:00:00Z",
},
],
"total": 2,
"page": 1,
"limit": 50,
"total_pages": 1,
}
with patch(
"app.services.extraction.ExtractionService.get_user_extractions",
) as mock_get:
mock_get.return_value = mock_extractions
response = await authenticated_client.get("/api/v1/extractions/")
response = await authenticated_client.get("/api/v1/extractions/user")
assert response.status_code == 200
data = response.json()
@@ -337,7 +347,9 @@ class TestSoundEndpoints:
"""Test getting sounds with authentication."""
from app.models.sound import Sound
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get:
with patch(
"app.repositories.sound.SoundRepository.search_and_sort",
) as mock_get:
# Create mock sounds with all required fields
mock_sound_1 = Sound(
id=1,
@@ -383,7 +395,9 @@ class TestSoundEndpoints:
"""Test getting sounds with type filtering."""
from app.models.sound import Sound
with patch("app.repositories.sound.SoundRepository.search_and_sort") as mock_get:
with patch(
"app.repositories.sound.SoundRepository.search_and_sort",
) as mock_get:
# Create mock sound with all required fields
mock_sound = Sound(
id=1,

View File

@@ -335,5 +335,3 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
access_token = JWTUtils.create_access_token(token_data)
return {"access_token": access_token}

View File

@@ -539,21 +539,35 @@ class TestPlaylistRepository:
sound_ids = [s.id for s in sounds]
# Add first two sounds sequentially (positions 0, 1)
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[0],
) # position 0
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[1],
) # position 1
# Now insert third sound at position 1 - should shift existing sound at position 1 to position 2
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=1)
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[2], position=1,
)
# Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
playlist_sounds = await playlist_repository.get_playlist_sound_entries(
playlist_id,
)
assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
assert (
playlist_sounds[0].sound_id == sound_ids[0]
) # Original sound 0 stays at position 0
assert playlist_sounds[0].position == 0
assert playlist_sounds[1].sound_id == sound_ids[2] # New sound 2 inserted at position 1
assert (
playlist_sounds[1].sound_id == sound_ids[2]
) # New sound 2 inserted at position 1
assert playlist_sounds[1].position == 1
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2
assert (
playlist_sounds[2].sound_id == sound_ids[1]
) # Original sound 1 shifted to position 2
assert playlist_sounds[2].position == 2
@pytest.mark.asyncio
@@ -615,21 +629,35 @@ class TestPlaylistRepository:
sound_ids = [s.id for s in sounds]
# Add first two sounds sequentially (positions 0, 1)
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[0]) # position 0
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[1]) # position 1
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[0],
) # position 0
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[1],
) # position 1
# Now insert third sound at position 0 - should shift existing sounds to positions 1, 2
await playlist_repository.add_sound_to_playlist(playlist_id, sound_ids[2], position=0)
await playlist_repository.add_sound_to_playlist(
playlist_id, sound_ids[2], position=0,
)
# Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
playlist_sounds = await playlist_repository.get_playlist_sound_entries(
playlist_id,
)
assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
assert (
playlist_sounds[0].sound_id == sound_ids[2]
) # New sound 2 inserted at position 0
assert playlist_sounds[0].position == 0
assert playlist_sounds[1].sound_id == sound_ids[0] # Original sound 0 shifted to position 1
assert (
playlist_sounds[1].sound_id == sound_ids[0]
) # Original sound 0 shifted to position 1
assert playlist_sounds[1].position == 1
assert playlist_sounds[2].sound_id == sound_ids[1] # Original sound 1 shifted to position 2
assert (
playlist_sounds[2].sound_id == sound_ids[1]
) # Original sound 1 shifted to position 2
assert playlist_sounds[2].position == 2
@pytest.mark.asyncio

View File

@@ -409,7 +409,9 @@ class TestCreditService:
@pytest.mark.asyncio
async def test_recharge_user_credits_success(
self, credit_service, sample_user,
self,
credit_service,
sample_user,
) -> None:
"""Test successful credit recharge for a user."""
mock_session = credit_service.db_session_factory()

View File

@@ -43,7 +43,9 @@ class TestDashboardService:
"total_duration": 75000,
"total_size": 1024000,
}
mock_sound_repository.get_soundboard_statistics = AsyncMock(return_value=mock_stats)
mock_sound_repository.get_soundboard_statistics = AsyncMock(
return_value=mock_stats,
)
result = await dashboard_service.get_soundboard_statistics()

View File

@@ -99,6 +99,11 @@ class TestExtractionService:
url = "https://www.youtube.com/watch?v=test123"
user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - no service detection happens during creation
mock_extraction = Extraction(
id=1,
@@ -120,6 +125,7 @@ class TestExtractionService:
assert result["service_id"] is None # Not detected during creation
assert result["title"] is None # Not detected during creation
assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_create_extraction_basic(self, extraction_service) -> None:
@@ -127,6 +133,11 @@ class TestExtractionService:
url = "https://www.youtube.com/watch?v=test123"
user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - creation always succeeds now
mock_extraction = Extraction(
id=2,
@@ -146,6 +157,7 @@ class TestExtractionService:
assert result["id"] == 2
assert result["url"] == url
assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_create_extraction_any_url(self, extraction_service) -> None:
@@ -153,6 +165,11 @@ class TestExtractionService:
url = "https://invalid.url"
user_id = 1
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock repository call - even invalid URLs are accepted during creation
mock_extraction = Extraction(
id=3,
@@ -172,6 +189,7 @@ class TestExtractionService:
assert result["id"] == 3
assert result["url"] == url
assert result["status"] == "pending"
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_process_extraction_with_service_detection(
@@ -408,9 +426,16 @@ class TestExtractionService:
sound_id=42,
)
# Mock user for user_name retrieval
mock_user = Mock()
mock_user.name = "Test User"
extraction_service.extraction_repo.get_by_id = AsyncMock(
return_value=extraction,
)
extraction_service.user_repo.get_by_id = AsyncMock(
return_value=mock_user,
)
result = await extraction_service.get_extraction_by_id(1)
@@ -421,6 +446,7 @@ class TestExtractionService:
assert result["title"] == "Test Video"
assert result["status"] == "completed"
assert result["sound_id"] == 42
assert result["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_get_extraction_by_id_not_found(self, extraction_service) -> None:
@@ -434,52 +460,70 @@ class TestExtractionService:
@pytest.mark.asyncio
async def test_get_user_extractions(self, extraction_service) -> None:
"""Test getting user extractions."""
from app.models.user import User
user = User(id=1, name="Test User", email="test@example.com")
extractions = [
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Test Video 1",
status="completed",
sound_id=42,
(
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Test Video 1",
status="completed",
sound_id=42,
),
user,
),
Extraction(
id=2,
service="youtube",
service_id="test456",
url="https://www.youtube.com/watch?v=test456",
user_id=1,
title="Test Video 2",
status="pending",
(
Extraction(
id=2,
service="youtube",
service_id="test456",
url="https://www.youtube.com/watch?v=test456",
user_id=1,
title="Test Video 2",
status="pending",
),
user,
),
]
extraction_service.extraction_repo.get_by_user = AsyncMock(
return_value=extractions,
extraction_service.extraction_repo.get_user_extractions_filtered = AsyncMock(
return_value=(extractions, 2),
)
result = await extraction_service.get_user_extractions(1)
assert len(result) == 2
assert result[0]["id"] == 1
assert result[0]["title"] == "Test Video 1"
assert result[1]["id"] == 2
assert result[1]["title"] == "Test Video 2"
assert result["total"] == 2
assert len(result["extractions"]) == 2
assert result["extractions"][0]["id"] == 1
assert result["extractions"][0]["title"] == "Test Video 1"
assert result["extractions"][0]["user_name"] == "Test User"
assert result["extractions"][1]["id"] == 2
assert result["extractions"][1]["title"] == "Test Video 2"
assert result["extractions"][1]["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_get_pending_extractions(self, extraction_service) -> None:
"""Test getting pending extractions."""
from app.models.user import User
user = User(id=1, name="Test User", email="test@example.com")
pending_extractions = [
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Pending Video",
status="pending",
(
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test123",
user_id=1,
title="Pending Video",
status="pending",
),
user,
),
]
@@ -492,3 +536,4 @@ class TestExtractionService:
assert len(result) == 1
assert result[0]["id"] == 1
assert result[0]["status"] == "pending"
assert result[0]["user_name"] == "Test User"

View File

@@ -25,9 +25,10 @@ class TestSchedulerService:
@pytest.mark.asyncio
async def test_start_scheduler(self, scheduler_service) -> None:
"""Test starting the scheduler service."""
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
patch.object(scheduler_service.scheduler, "start") as mock_start:
with (
patch.object(scheduler_service.scheduler, "add_job") as mock_add_job,
patch.object(scheduler_service.scheduler, "start") as mock_start,
):
await scheduler_service.start()
# Verify job was added
@@ -61,7 +62,9 @@ class TestSchedulerService:
"total_credits_added": 500,
}
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
with patch.object(
scheduler_service.credit_service, "recharge_all_users_credits",
) as mock_recharge:
mock_recharge.return_value = mock_stats
await scheduler_service._daily_credit_recharge()
@@ -71,7 +74,9 @@ class TestSchedulerService:
@pytest.mark.asyncio
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
"""Test daily credit recharge task with failure."""
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
with patch.object(
scheduler_service.credit_service, "recharge_all_users_credits",
) as mock_recharge:
mock_recharge.side_effect = Exception("Database error")
# Should not raise exception, just log it