fix: Add missing commas in function calls and improve code formatting
Some checks failed
Backend CI / lint (push) Failing after 4m51s
Backend CI / test (push) Successful in 4m19s

This commit is contained in:
JSC
2025-08-12 23:37:38 +02:00
parent d3d7edb287
commit f094fbf140
18 changed files with 135 additions and 133 deletions

View File

@@ -460,7 +460,7 @@ async def update_profile(
"""Update the current user's profile.""" """Update the current user's profile."""
try: try:
updated_user = await auth_service.update_user_profile( updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True) current_user, request.model_dump(exclude_unset=True),
) )
return await auth_service.user_to_response(updated_user) return await auth_service.user_to_response(updated_user)
except Exception as e: except Exception as e:
@@ -482,7 +482,7 @@ async def change_password(
user_email = current_user.email user_email = current_user.email
try: try:
await auth_service.change_user_password( await auth_service.change_user_password(
current_user, request.current_password, request.new_password current_user, request.current_password, request.new_password,
) )
return {"message": "Password changed successfully"} return {"message": "Password changed successfully"}
except ValueError as e: except ValueError as e:
@@ -505,7 +505,7 @@ async def get_user_providers(
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""Get the current user's connected authentication providers.""" """Get the current user's connected authentication providers."""
providers = [] providers = []
# Add password provider if user has password # Add password provider if user has password
if current_user.password_hash: if current_user.password_hash:
providers.append({ providers.append({
@@ -513,7 +513,7 @@ async def get_user_providers(
"display_name": "Password", "display_name": "Password",
"connected_at": current_user.created_at.isoformat(), "connected_at": current_user.created_at.isoformat(),
}) })
# Get OAuth providers from the database # Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user) oauth_providers = await auth_service.get_user_oauth_providers(current_user)
for oauth in oauth_providers: for oauth in oauth_providers:
@@ -522,11 +522,11 @@ async def get_user_providers(
display_name = "GitHub" display_name = "GitHub"
elif oauth.provider == "google": elif oauth.provider == "google":
display_name = "Google" display_name = "Google"
providers.append({ providers.append({
"provider": oauth.provider, "provider": oauth.provider,
"display_name": display_name, "display_name": display_name,
"connected_at": oauth.created_at.isoformat(), "connected_at": oauth.created_at.isoformat(),
}) })
return providers return providers

View File

@@ -41,5 +41,5 @@ async def get_top_sounds(
return await dashboard_service.get_top_sounds( return await dashboard_service.get_top_sounds(
sound_type=sound_type, sound_type=sound_type,
period=period, period=period,
limit=limit limit=limit,
) )

View File

@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
from app.schemas.common import MessageResponse from app.schemas.common import MessageResponse
from app.schemas.playlist import ( from app.schemas.playlist import (
PlaylistAddSoundRequest, PlaylistAddSoundRequest,
@@ -19,7 +20,6 @@ from app.schemas.playlist import (
PlaylistUpdateRequest, PlaylistUpdateRequest,
) )
from app.services.playlist import PlaylistService from app.services.playlist import PlaylistService
from app.repositories.playlist import PlaylistSortField, SortOrder
router = APIRouter(prefix="/playlists", tags=["playlists"]) router = APIRouter(prefix="/playlists", tags=["playlists"])

View File

@@ -10,7 +10,7 @@ from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType from app.models.credit_action import CreditActionType
from app.models.sound import Sound from app.models.sound import Sound
from app.models.user import User from app.models.user import User
from app.repositories.sound import SoundRepository, SoundSortField, SortOrder from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
from app.services.credit import CreditService, InsufficientCreditsError from app.services.credit import CreditService, InsufficientCreditsError
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service from app.services.vlc_player import VLCPlayerService, get_vlc_player_service

View File

@@ -8,10 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.user import User from app.models.user import User
from app.repositories.sound import SoundRepository
from app.services.auth import AuthService from app.services.auth import AuthService
from app.services.dashboard import DashboardService from app.services.dashboard import DashboardService
from app.services.oauth import OAuthService from app.services.oauth import OAuthService
from app.repositories.sound import SoundRepository
from app.utils.auth import JWTUtils, TokenUtils from app.utils.auth import JWTUtils, TokenUtils
logger = get_logger(__name__) logger = get_logger(__name__)

View File

@@ -1,9 +1,10 @@
"""Playlist repository for database operations.""" """Playlist repository for database operations."""
from enum import Enum from enum import Enum
from sqlalchemy import func, update from sqlalchemy import func, update
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from sqlmodel import col, select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger from app.core.logging import get_logger
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
class PlaylistSortField(str, Enum): class PlaylistSortField(str, Enum):
"""Playlist sort field enumeration.""" """Playlist sort field enumeration."""
NAME = "name" NAME = "name"
GENRE = "genre" GENRE = "genre"
CREATED_AT = "created_at" CREATED_AT = "created_at"
@@ -29,7 +30,7 @@ class PlaylistSortField(str, Enum):
class SortOrder(str, Enum): class SortOrder(str, Enum):
"""Sort order enumeration.""" """Sort order enumeration."""
ASC = "asc" ASC = "asc"
DESC = "desc" DESC = "desc"
@@ -154,7 +155,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Use a two-step approach to avoid unique constraint violations: # Use a two-step approach to avoid unique constraint violations:
# 1. Move all affected positions to negative temporary positions # 1. Move all affected positions to negative temporary positions
# 2. Then move them to their final positions # 2. Then move them to their final positions
# Step 1: Move to temporary negative positions # Step 1: Move to temporary negative positions
update_to_negative = ( update_to_negative = (
update(PlaylistSound) update(PlaylistSound)
@@ -166,8 +167,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
) )
await self.session.exec(update_to_negative) await self.session.exec(update_to_negative)
await self.session.commit() await self.session.commit()
# Step 2: Move from temporary negative positions to final positions # Step 2: Move from temporary negative positions to final positions
update_to_final = ( update_to_final = (
update(PlaylistSound) update(PlaylistSound)
.where( .where(
@@ -337,15 +338,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True) .join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name) .group_by(Playlist.id, User.name)
) )
# Apply filters # Apply filters
if search_query and search_query.strip(): if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%" search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern)) subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
if user_id is not None: if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id) subquery = subquery.where(Playlist.user_id == user_id)
# Apply sorting # Apply sorting
if sort_by == PlaylistSortField.SOUND_COUNT: if sort_by == PlaylistSortField.SOUND_COUNT:
if sort_order == SortOrder.DESC: if sort_order == SortOrder.DESC:
@@ -360,7 +361,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
else: else:
# Default sorting by name # Default sorting by name
subquery = subquery.order_by(Playlist.name.asc()) subquery = subquery.order_by(Playlist.name.asc())
else: else:
# Simple query without stats-based sorting # Simple query without stats-based sorting
subquery = ( subquery = (
@@ -385,15 +386,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True) .join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name) .group_by(Playlist.id, User.name)
) )
# Apply filters # Apply filters
if search_query and search_query.strip(): if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%" search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern)) subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
if user_id is not None: if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id) subquery = subquery.where(Playlist.user_id == user_id)
# Apply sorting # Apply sorting
if sort_by: if sort_by:
if sort_by == PlaylistSortField.NAME: if sort_by == PlaylistSortField.NAME:
@@ -406,7 +407,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
sort_column = Playlist.updated_at sort_column = Playlist.updated_at
else: else:
sort_column = Playlist.name sort_column = Playlist.name
if sort_order == SortOrder.DESC: if sort_order == SortOrder.DESC:
subquery = subquery.order_by(sort_column.desc()) subquery = subquery.order_by(sort_column.desc())
else: else:
@@ -414,16 +415,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
else: else:
# Default sorting by name ascending # Default sorting by name ascending
subquery = subquery.order_by(Playlist.name.asc()) subquery = subquery.order_by(Playlist.name.asc())
# Apply pagination # Apply pagination
if offset > 0: if offset > 0:
subquery = subquery.offset(offset) subquery = subquery.offset(offset)
if limit is not None: if limit is not None:
subquery = subquery.limit(limit) subquery = subquery.limit(limit)
result = await self.session.exec(subquery) result = await self.session.exec(subquery)
rows = result.all() rows = result.all()
# Convert to dictionary format # Convert to dictionary format
playlists = [] playlists = []
for row in rows: for row in rows:
@@ -442,11 +443,11 @@ class PlaylistRepository(BaseRepository[Playlist]):
"sound_count": row.sound_count or 0, "sound_count": row.sound_count or 0,
"total_duration": row.total_duration or 0, "total_duration": row.total_duration or 0,
}) })
return playlists return playlists
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s", "Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
search_query, sort_by, sort_order search_query, sort_by, sort_order,
) )
raise raise

View File

@@ -17,7 +17,7 @@ logger = get_logger(__name__)
class SoundSortField(str, Enum): class SoundSortField(str, Enum):
"""Sound sort field enumeration.""" """Sound sort field enumeration."""
NAME = "name" NAME = "name"
FILENAME = "filename" FILENAME = "filename"
DURATION = "duration" DURATION = "duration"
@@ -30,7 +30,7 @@ class SoundSortField(str, Enum):
class SortOrder(str, Enum): class SortOrder(str, Enum):
"""Sort order enumeration.""" """Sort order enumeration."""
ASC = "asc" ASC = "asc"
DESC = "desc" DESC = "desc"
@@ -144,18 +144,18 @@ class SoundRepository(BaseRepository[Sound]):
"""Search and sort sounds with optional filtering.""" """Search and sort sounds with optional filtering."""
try: try:
statement = select(Sound) statement = select(Sound)
# Apply type filter # Apply type filter
if sound_types: if sound_types:
statement = statement.where(col(Sound.type).in_(sound_types)) statement = statement.where(col(Sound.type).in_(sound_types))
# Apply search filter # Apply search filter
if search_query and search_query.strip(): if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%" search_pattern = f"%{search_query.strip().lower()}%"
statement = statement.where( statement = statement.where(
func.lower(Sound.name).like(search_pattern) func.lower(Sound.name).like(search_pattern),
) )
# Apply sorting # Apply sorting
if sort_by: if sort_by:
sort_column = getattr(Sound, sort_by.value) sort_column = getattr(Sound, sort_by.value)
@@ -166,19 +166,19 @@ class SoundRepository(BaseRepository[Sound]):
else: else:
# Default sorting by name ascending # Default sorting by name ascending
statement = statement.order_by(Sound.name.asc()) statement = statement.order_by(Sound.name.asc())
# Apply pagination # Apply pagination
if offset > 0: if offset > 0:
statement = statement.offset(offset) statement = statement.offset(offset)
if limit is not None: if limit is not None:
statement = statement.limit(limit) statement = statement.limit(limit)
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s", "Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
search_query, sound_types, sort_by, sort_order search_query, sound_types, sort_by, sort_order,
) )
raise raise
@@ -189,17 +189,17 @@ class SoundRepository(BaseRepository[Sound]):
func.count(Sound.id).label("count"), func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"), func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"), func.sum(Sound.duration).label("total_duration"),
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size") func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
).where(Sound.type == "SDB") ).where(Sound.type == "SDB")
result = await self.session.exec(statement) result = await self.session.exec(statement)
row = result.first() row = result.first()
return { return {
"count": row.count if row.count is not None else 0, "count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0, "total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": row.total_duration if row.total_duration is not None else 0, "total_duration": row.total_duration if row.total_duration is not None else 0,
"total_size": row.total_size if row.total_size is not None else 0 "total_size": row.total_size if row.total_size is not None else 0,
} }
except Exception: except Exception:
logger.exception("Failed to get soundboard statistics") logger.exception("Failed to get soundboard statistics")
@@ -212,17 +212,17 @@ class SoundRepository(BaseRepository[Sound]):
func.count(Sound.id).label("count"), func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"), func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"), func.sum(Sound.duration).label("total_duration"),
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size") func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
).where(Sound.type == "EXT") ).where(Sound.type == "EXT")
result = await self.session.exec(statement) result = await self.session.exec(statement)
row = result.first() row = result.first()
return { return {
"count": row.count if row.count is not None else 0, "count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0, "total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": row.total_duration if row.total_duration is not None else 0, "total_duration": row.total_duration if row.total_duration is not None else 0,
"total_size": row.total_size if row.total_size is not None else 0 "total_size": row.total_size if row.total_size is not None else 0,
} }
except Exception: except Exception:
logger.exception("Failed to get track statistics") logger.exception("Failed to get track statistics")
@@ -244,20 +244,20 @@ class SoundRepository(BaseRepository[Sound]):
Sound.type, Sound.type,
Sound.duration, Sound.duration,
Sound.created_at, Sound.created_at,
func.count(SoundPlayed.id).label("play_count") func.count(SoundPlayed.id).label("play_count"),
) )
.select_from(SoundPlayed) .select_from(SoundPlayed)
.join(Sound, SoundPlayed.sound_id == Sound.id) .join(Sound, SoundPlayed.sound_id == Sound.id)
) )
# Apply sound type filter # Apply sound type filter
if sound_type != "all": if sound_type != "all":
statement = statement.where(Sound.type == sound_type.upper()) statement = statement.where(Sound.type == sound_type.upper())
# Apply date filter if provided # Apply date filter if provided
if date_filter: if date_filter:
statement = statement.where(SoundPlayed.created_at >= date_filter) statement = statement.where(SoundPlayed.created_at >= date_filter)
# Group by sound and order by play count descending # Group by sound and order by play count descending
statement = ( statement = (
statement statement
@@ -266,15 +266,15 @@ class SoundRepository(BaseRepository[Sound]):
Sound.name, Sound.name,
Sound.type, Sound.type,
Sound.duration, Sound.duration,
Sound.created_at Sound.created_at,
) )
.order_by(func.count(SoundPlayed.id).desc()) .order_by(func.count(SoundPlayed.id).desc())
.limit(limit) .limit(limit)
) )
result = await self.session.exec(statement) result = await self.session.exec(statement)
rows = result.all() rows = result.all()
# Convert to dictionaries with the play count from the period # Convert to dictionaries with the play count from the period
return [ return [
{ {

View File

@@ -2,9 +2,9 @@
from typing import Any from typing import Any
from sqlalchemy.orm import selectinload
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.plan import Plan from app.models.plan import Plan

View File

@@ -96,5 +96,5 @@ class UpdateProfileRequest(BaseModel):
"""Schema for profile update request.""" """Schema for profile update request."""
name: str | None = Field( name: str | None = Field(
None, min_length=1, max_length=100, description="User display name" None, min_length=1, max_length=100, description="User display name",
) )

View File

@@ -434,36 +434,36 @@ class AuthService:
async def update_user_profile(self, user: User, data: dict) -> User: async def update_user_profile(self, user: User, data: dict) -> User:
"""Update user profile information.""" """Update user profile information."""
logger.info("Updating profile for user: %s", user.email) logger.info("Updating profile for user: %s", user.email)
# Only allow updating specific fields # Only allow updating specific fields
allowed_fields = {"name"} allowed_fields = {"name"}
update_data = {k: v for k, v in data.items() if k in allowed_fields} update_data = {k: v for k, v in data.items() if k in allowed_fields}
if not update_data: if not update_data:
return user return user
# Update user # Update user
for field, value in update_data.items(): for field, value in update_data.items():
setattr(user, field, value) setattr(user, field, value)
self.session.add(user) self.session.add(user)
await self.session.commit() await self.session.commit()
await self.session.refresh(user, ["plan"]) await self.session.refresh(user, ["plan"])
logger.info("Profile updated successfully for user: %s", user.email) logger.info("Profile updated successfully for user: %s", user.email)
return user return user
async def change_user_password( async def change_user_password(
self, user: User, current_password: str | None, new_password: str self, user: User, current_password: str | None, new_password: str,
) -> None: ) -> None:
"""Change user's password.""" """Change user's password."""
# Store user email before any operations to avoid session detachment issues # Store user email before any operations to avoid session detachment issues
user_email = user.email user_email = user.email
logger.info("Changing password for user: %s", user_email) logger.info("Changing password for user: %s", user_email)
# Store whether user had existing password before we modify it # Store whether user had existing password before we modify it
had_existing_password = user.password_hash is not None had_existing_password = user.password_hash is not None
# If user has existing password, verify it # If user has existing password, verify it
if had_existing_password: if had_existing_password:
if not current_password: if not current_password:
@@ -473,24 +473,24 @@ class AuthService:
else: else:
# User doesn't have a password (OAuth-only user), so we're setting their first password # User doesn't have a password (OAuth-only user), so we're setting their first password
logger.info("Setting first password for OAuth user: %s", user_email) logger.info("Setting first password for OAuth user: %s", user_email)
# Hash new password # Hash new password
new_password_hash = PasswordUtils.hash_password(new_password) new_password_hash = PasswordUtils.hash_password(new_password)
# Update user # Update user
user.password_hash = new_password_hash user.password_hash = new_password_hash
self.session.add(user) self.session.add(user)
await self.session.commit() await self.session.commit()
logger.info("Password %s successfully for user: %s", logger.info("Password %s successfully for user: %s",
"changed" if had_existing_password else "set", user_email) "changed" if had_existing_password else "set", user_email)
async def user_to_response(self, user: User) -> UserResponse: async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information.""" """Convert User model to UserResponse with plan information."""
# Load plan relationship if not already loaded # Load plan relationship if not already loaded
if not hasattr(user, 'plan') or not user.plan: if not hasattr(user, "plan") or not user.plan:
await self.session.refresh(user, ["plan"]) await self.session.refresh(user, ["plan"])
return UserResponse( return UserResponse(
id=user.id, id=user.id,
email=user.email, email=user.email,

View File

@@ -56,14 +56,14 @@ class DashboardService:
try: try:
# Calculate the date filter based on period # Calculate the date filter based on period
date_filter = self._get_date_filter(period) date_filter = self._get_date_filter(period)
# Get top sounds from repository # Get top sounds from repository
top_sounds = await self.sound_repository.get_top_sounds( top_sounds = await self.sound_repository.get_top_sounds(
sound_type=sound_type, sound_type=sound_type,
date_filter=date_filter, date_filter=date_filter,
limit=limit, limit=limit,
) )
return [ return [
{ {
"id": sound["id"], "id": sound["id"],
@@ -86,7 +86,7 @@ class DashboardService:
period, period,
) )
raise raise
def _get_date_filter(self, period: str) -> datetime | None: def _get_date_filter(self, period: str) -> datetime | None:
"""Calculate the date filter based on the period.""" """Calculate the date filter based on the period."""
now = datetime.now(UTC) now = datetime.now(UTC)

View File

@@ -272,9 +272,8 @@ class PlaylistService:
# Ensure position doesn't create gaps - if position is too high, place at end # Ensure position doesn't create gaps - if position is too high, place at end
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id) current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
max_position = len(current_sounds) max_position = len(current_sounds)
if position > max_position: position = min(position, max_position)
position = max_position
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position) await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info( logger.info(
"Added sound %s to playlist %s for user %s at position %s", "Added sound %s to playlist %s for user %s at position %s",
@@ -306,10 +305,10 @@ class PlaylistService:
) )
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id) await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
# Reorder remaining sounds to eliminate gaps # Reorder remaining sounds to eliminate gaps
await self._reorder_playlist_positions(playlist_id) await self._reorder_playlist_positions(playlist_id)
logger.info( logger.info(
"Removed sound %s from playlist %s for user %s and reordered positions", "Removed sound %s from playlist %s for user %s and reordered positions",
sound_id, sound_id,
@@ -326,7 +325,7 @@ class PlaylistService:
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id) sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
if not sounds: if not sounds:
return return
# Create sequential positions: 0, 1, 2, 3... # Create sequential positions: 0, 1, 2, 3...
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)] sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions) await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)

View File

@@ -1,6 +1,6 @@
"""Tests for admin user endpoints.""" """Tests for admin user endpoints."""
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@@ -76,9 +76,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
mock_regular = type("User", (), { mock_regular = type("User", (), {
"id": regular_user.id, "id": regular_user.id,
"email": regular_user.email, "email": regular_user.email,
@@ -93,9 +93,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
mock_get_all.return_value = [mock_admin, mock_regular] mock_get_all.return_value = [mock_admin, mock_regular]
response = await authenticated_admin_client.get("/api/v1/admin/users/") response = await authenticated_admin_client.get("/api/v1/admin/users/")
@@ -130,7 +130,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
mock_get_all.return_value = [mock_admin] mock_get_all.return_value = [mock_admin]
@@ -185,7 +185,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
@@ -244,9 +244,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
updated_mock = type("User", (), { updated_mock = type("User", (), {
"id": regular_user.id, "id": regular_user.id,
"email": regular_user.email, "email": regular_user.email,
@@ -261,9 +261,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
mock_update.return_value = updated_mock mock_update.return_value = updated_mock
@@ -278,7 +278,7 @@ class TestAdminUserEndpoints:
"name": "Updated Name", "name": "Updated Name",
"credits": 200, "credits": 200,
"plan_id": 1, "plan_id": 1,
} },
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -299,7 +299,7 @@ class TestAdminUserEndpoints:
): ):
response = await authenticated_admin_client.patch( response = await authenticated_admin_client.patch(
"/api/v1/admin/users/999", "/api/v1/admin/users/999",
json={"name": "Updated Name"} json={"name": "Updated Name"},
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -333,12 +333,12 @@ class TestAdminUserEndpoints:
"id": 1, "id": 1,
"name": "Basic", "name": "Basic",
"max_credits": 100, "max_credits": 100,
})() })(),
})() })()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.patch( response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2", "/api/v1/admin/users/2",
json={"plan_id": 999} json={"plan_id": 999},
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -373,7 +373,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
mock_get_by_id.return_value = mock_user mock_get_by_id.return_value = mock_user
mock_update.return_value = mock_user mock_update.return_value = mock_user
@@ -438,7 +438,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id, "id": test_plan.id,
"name": test_plan.name, "name": test_plan.name,
"max_credits": test_plan.max_credits, "max_credits": test_plan.max_credits,
})() })(),
})() })()
mock_get_by_id.return_value = mock_disabled_user mock_get_by_id.return_value = mock_disabled_user
mock_update.return_value = mock_disabled_user mock_update.return_value = mock_disabled_user
@@ -487,4 +487,4 @@ class TestAdminUserEndpoints:
data = response.json() data = response.json()
assert len(data) == 2 assert len(data) == 2
assert data[0]["name"] == "Basic" assert data[0]["name"] == "Basic"
assert data[1]["name"] == "Premium" assert data[1]["name"] == "Premium"

View File

@@ -1,5 +1,6 @@
"""Tests for authentication endpoints.""" """Tests for authentication endpoints."""
from datetime import UTC
from typing import Any from typing import Any
from unittest.mock import patch from unittest.mock import patch
@@ -495,7 +496,7 @@ class TestAuthEndpoints:
response = await test_client.post( response = await test_client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh",
cookies={"refresh_token": "valid_refresh_token"} cookies={"refresh_token": "valid_refresh_token"},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -520,7 +521,7 @@ class TestAuthEndpoints:
response = await test_client.post( response = await test_client.post(
"/api/v1/auth/refresh", "/api/v1/auth/refresh",
cookies={"refresh_token": "valid_refresh_token"} cookies={"refresh_token": "valid_refresh_token"},
) )
assert response.status_code == 500 assert response.status_code == 500
@@ -536,7 +537,7 @@ class TestAuthEndpoints:
"""Test OAuth token exchange with invalid code.""" """Test OAuth token exchange with invalid code."""
response = await test_client.post( response = await test_client.post(
"/api/v1/auth/exchange-oauth-token", "/api/v1/auth/exchange-oauth-token",
json={"code": "invalid_code"} json={"code": "invalid_code"},
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -565,7 +566,7 @@ class TestAuthEndpoints:
is_active=test_user.is_active, is_active=test_user.is_active,
) )
mock_update.return_value = updated_user mock_update.return_value = updated_user
# Mock the user_to_response to return UserResponse format # Mock the user_to_response to return UserResponse format
from app.schemas.auth import UserResponse from app.schemas.auth import UserResponse
mock_user_to_response.return_value = UserResponse( mock_user_to_response.return_value = UserResponse(
@@ -589,7 +590,7 @@ class TestAuthEndpoints:
response = await test_client.patch( response = await test_client.patch(
"/api/v1/auth/me", "/api/v1/auth/me",
json={"name": "Updated Name"}, json={"name": "Updated Name"},
cookies=auth_cookies cookies=auth_cookies,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -601,7 +602,7 @@ class TestAuthEndpoints:
"""Test update profile without authentication.""" """Test update profile without authentication."""
response = await test_client.patch( response = await test_client.patch(
"/api/v1/auth/me", "/api/v1/auth/me",
json={"name": "Updated Name"} json={"name": "Updated Name"},
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -621,9 +622,9 @@ class TestAuthEndpoints:
"/api/v1/auth/change-password", "/api/v1/auth/change-password",
json={ json={
"current_password": "old_password", "current_password": "old_password",
"new_password": "new_password" "new_password": "new_password",
}, },
cookies=auth_cookies cookies=auth_cookies,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -637,8 +638,8 @@ class TestAuthEndpoints:
"/api/v1/auth/change-password", "/api/v1/auth/change-password",
json={ json={
"current_password": "old_password", "current_password": "old_password",
"new_password": "new_password" "new_password": "new_password",
} },
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -652,9 +653,10 @@ class TestAuthEndpoints:
) -> None: ) -> None:
"""Test get user OAuth providers success.""" """Test get user OAuth providers success."""
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers: with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
from datetime import datetime
from app.models.user_oauth import UserOauth from app.models.user_oauth import UserOauth
from datetime import datetime, timezone
mock_oauth_google = UserOauth( mock_oauth_google = UserOauth(
id=1, id=1,
user_id=test_user.id, user_id=test_user.id,
@@ -662,34 +664,34 @@ class TestAuthEndpoints:
provider_user_id="google123", provider_user_id="google123",
email="test@example.com", email="test@example.com",
name="Test User", name="Test User",
created_at=datetime.now(timezone.utc), created_at=datetime.now(UTC),
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(UTC),
) )
mock_oauth_github = UserOauth( mock_oauth_github = UserOauth(
id=2, id=2,
user_id=test_user.id, user_id=test_user.id,
provider="github", provider="github",
provider_user_id="github456", provider_user_id="github456",
email="test@example.com", email="test@example.com",
name="Test User", name="Test User",
created_at=datetime.now(timezone.utc), created_at=datetime.now(UTC),
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(UTC),
) )
mock_providers.return_value = [mock_oauth_google, mock_oauth_github] mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
response = await test_client.get( response = await test_client.get(
"/api/v1/auth/user-providers", "/api/v1/auth/user-providers",
cookies=auth_cookies cookies=auth_cookies,
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 3 # password + 2 OAuth providers assert len(data) == 3 # password + 2 OAuth providers
# Check password provider (first) # Check password provider (first)
assert data[0]["provider"] == "password" assert data[0]["provider"] == "password"
assert data[0]["display_name"] == "Password" assert data[0]["display_name"] == "Password"
# Check OAuth providers # Check OAuth providers
assert data[1]["provider"] == "google" assert data[1]["provider"] == "google"
assert data[1]["display_name"] == "Google" assert data[1]["display_name"] == "Google"

View File

@@ -529,7 +529,7 @@ class TestPlaylistRepository:
) )
test_session.add(sound) test_session.add(sound)
sounds.append(sound) sounds.append(sound)
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
for sound in sounds: for sound in sounds:
@@ -547,7 +547,7 @@ class TestPlaylistRepository:
# Verify the final positions # Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
assert len(playlist_sounds) == 3 assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0 assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
assert playlist_sounds[0].position == 0 assert playlist_sounds[0].position == 0
@@ -605,7 +605,7 @@ class TestPlaylistRepository:
) )
test_session.add(sound) test_session.add(sound)
sounds.append(sound) sounds.append(sound)
await test_session.commit() await test_session.commit()
await test_session.refresh(playlist) await test_session.refresh(playlist)
for sound in sounds: for sound in sounds:
@@ -623,7 +623,7 @@ class TestPlaylistRepository:
# Verify the final positions # Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id) playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
assert len(playlist_sounds) == 3 assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0 assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
assert playlist_sounds[0].position == 0 assert playlist_sounds[0].position == 0

View File

@@ -409,7 +409,7 @@ class TestCreditService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_recharge_user_credits_success( async def test_recharge_user_credits_success(
self, credit_service, sample_user self, credit_service, sample_user,
) -> None: ) -> None:
"""Test successful credit recharge for a user.""" """Test successful credit recharge for a user."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()

View File

@@ -1,6 +1,6 @@
"""Tests for dashboard service.""" """Tests for dashboard service."""
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
@@ -63,7 +63,7 @@ class TestDashboardService:
): ):
"""Test getting soundboard statistics with exception.""" """Test getting soundboard statistics with exception."""
mock_sound_repository.get_soundboard_statistics = AsyncMock( mock_sound_repository.get_soundboard_statistics = AsyncMock(
side_effect=Exception("Database error") side_effect=Exception("Database error"),
) )
with pytest.raises(Exception, match="Database error"): with pytest.raises(Exception, match="Database error"):
@@ -105,7 +105,7 @@ class TestDashboardService:
): ):
"""Test getting track statistics with exception.""" """Test getting track statistics with exception."""
mock_sound_repository.get_track_statistics = AsyncMock( mock_sound_repository.get_track_statistics = AsyncMock(
side_effect=Exception("Database error") side_effect=Exception("Database error"),
) )
with pytest.raises(Exception, match="Database error"): with pytest.raises(Exception, match="Database error"):
@@ -198,7 +198,7 @@ class TestDashboardService:
): ):
"""Test getting top sounds with exception.""" """Test getting top sounds with exception."""
mock_sound_repository.get_top_sounds = AsyncMock( mock_sound_repository.get_top_sounds = AsyncMock(
side_effect=Exception("Database error") side_effect=Exception("Database error"),
) )
with pytest.raises(Exception, match="Database error"): with pytest.raises(Exception, match="Database error"):
@@ -274,4 +274,4 @@ class TestDashboardService:
def test_get_date_filter_unknown_period(self, dashboard_service): def test_get_date_filter_unknown_period(self, dashboard_service):
"""Test date filter for unknown period.""" """Test date filter for unknown period."""
result = dashboard_service._get_date_filter("unknown_period") result = dashboard_service._get_date_filter("unknown_period")
assert result is None assert result is None

View File

@@ -25,9 +25,9 @@ class TestSchedulerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_scheduler(self, scheduler_service) -> None: async def test_start_scheduler(self, scheduler_service) -> None:
"""Test starting the scheduler service.""" """Test starting the scheduler service."""
with patch.object(scheduler_service.scheduler, 'add_job') as mock_add_job, \ with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
patch.object(scheduler_service.scheduler, 'start') as mock_start: patch.object(scheduler_service.scheduler, "start") as mock_start:
await scheduler_service.start() await scheduler_service.start()
# Verify job was added # Verify job was added
@@ -47,7 +47,7 @@ class TestSchedulerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_scheduler(self, scheduler_service) -> None: async def test_stop_scheduler(self, scheduler_service) -> None:
"""Test stopping the scheduler service.""" """Test stopping the scheduler service."""
with patch.object(scheduler_service.scheduler, 'shutdown') as mock_shutdown: with patch.object(scheduler_service.scheduler, "shutdown") as mock_shutdown:
await scheduler_service.stop() await scheduler_service.stop()
mock_shutdown.assert_called_once() mock_shutdown.assert_called_once()
@@ -61,7 +61,7 @@ class TestSchedulerService:
"total_credits_added": 500, "total_credits_added": 500,
} }
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge: with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
mock_recharge.return_value = mock_stats mock_recharge.return_value = mock_stats
await scheduler_service._daily_credit_recharge() await scheduler_service._daily_credit_recharge()
@@ -71,10 +71,10 @@ class TestSchedulerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None: async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
"""Test daily credit recharge task with failure.""" """Test daily credit recharge task with failure."""
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge: with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
mock_recharge.side_effect = Exception("Database error") mock_recharge.side_effect = Exception("Database error")
# Should not raise exception, just log it # Should not raise exception, just log it
await scheduler_service._daily_credit_recharge() await scheduler_service._daily_credit_recharge()
mock_recharge.assert_called_once() mock_recharge.assert_called_once()