fix: Add missing commas in function calls and improve code formatting
This commit is contained in:
@@ -460,7 +460,7 @@ async def update_profile(
|
||||
"""Update the current user's profile."""
|
||||
try:
|
||||
updated_user = await auth_service.update_user_profile(
|
||||
current_user, request.model_dump(exclude_unset=True)
|
||||
current_user, request.model_dump(exclude_unset=True),
|
||||
)
|
||||
return await auth_service.user_to_response(updated_user)
|
||||
except Exception as e:
|
||||
@@ -482,7 +482,7 @@ async def change_password(
|
||||
user_email = current_user.email
|
||||
try:
|
||||
await auth_service.change_user_password(
|
||||
current_user, request.current_password, request.new_password
|
||||
current_user, request.current_password, request.new_password,
|
||||
)
|
||||
return {"message": "Password changed successfully"}
|
||||
except ValueError as e:
|
||||
@@ -505,7 +505,7 @@ async def get_user_providers(
|
||||
) -> list[dict[str, str]]:
|
||||
"""Get the current user's connected authentication providers."""
|
||||
providers = []
|
||||
|
||||
|
||||
# Add password provider if user has password
|
||||
if current_user.password_hash:
|
||||
providers.append({
|
||||
@@ -513,7 +513,7 @@ async def get_user_providers(
|
||||
"display_name": "Password",
|
||||
"connected_at": current_user.created_at.isoformat(),
|
||||
})
|
||||
|
||||
|
||||
# Get OAuth providers from the database
|
||||
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
||||
for oauth in oauth_providers:
|
||||
@@ -522,11 +522,11 @@ async def get_user_providers(
|
||||
display_name = "GitHub"
|
||||
elif oauth.provider == "google":
|
||||
display_name = "Google"
|
||||
|
||||
|
||||
providers.append({
|
||||
"provider": oauth.provider,
|
||||
"display_name": display_name,
|
||||
"connected_at": oauth.created_at.isoformat(),
|
||||
})
|
||||
|
||||
|
||||
return providers
|
||||
|
||||
@@ -41,5 +41,5 @@ async def get_top_sounds(
|
||||
return await dashboard_service.get_top_sounds(
|
||||
sound_type=sound_type,
|
||||
period=period,
|
||||
limit=limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.user import User
|
||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.playlist import (
|
||||
PlaylistAddSoundRequest,
|
||||
@@ -19,7 +20,6 @@ from app.schemas.playlist import (
|
||||
PlaylistUpdateRequest,
|
||||
)
|
||||
from app.services.playlist import PlaylistService
|
||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||
|
||||
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.sound import Sound
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository, SoundSortField, SortOrder
|
||||
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.core.database import get_db
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.auth import AuthService
|
||||
from app.services.dashboard import DashboardService
|
||||
from app.services.oauth import OAuthService
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.utils.auth import JWTUtils, TokenUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Playlist repository for database operations."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class PlaylistSortField(str, Enum):
|
||||
"""Playlist sort field enumeration."""
|
||||
|
||||
|
||||
NAME = "name"
|
||||
GENRE = "genre"
|
||||
CREATED_AT = "created_at"
|
||||
@@ -29,7 +30,7 @@ class PlaylistSortField(str, Enum):
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order enumeration."""
|
||||
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
@@ -154,7 +155,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
# Use a two-step approach to avoid unique constraint violations:
|
||||
# 1. Move all affected positions to negative temporary positions
|
||||
# 2. Then move them to their final positions
|
||||
|
||||
|
||||
# Step 1: Move to temporary negative positions
|
||||
update_to_negative = (
|
||||
update(PlaylistSound)
|
||||
@@ -166,8 +167,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
)
|
||||
await self.session.exec(update_to_negative)
|
||||
await self.session.commit()
|
||||
|
||||
# Step 2: Move from temporary negative positions to final positions
|
||||
|
||||
# Step 2: Move from temporary negative positions to final positions
|
||||
update_to_final = (
|
||||
update(PlaylistSound)
|
||||
.where(
|
||||
@@ -337,15 +338,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
||||
.group_by(Playlist.id, User.name)
|
||||
)
|
||||
|
||||
|
||||
# Apply filters
|
||||
if search_query and search_query.strip():
|
||||
search_pattern = f"%{search_query.strip().lower()}%"
|
||||
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
|
||||
|
||||
|
||||
if user_id is not None:
|
||||
subquery = subquery.where(Playlist.user_id == user_id)
|
||||
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == PlaylistSortField.SOUND_COUNT:
|
||||
if sort_order == SortOrder.DESC:
|
||||
@@ -360,7 +361,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
else:
|
||||
# Default sorting by name
|
||||
subquery = subquery.order_by(Playlist.name.asc())
|
||||
|
||||
|
||||
else:
|
||||
# Simple query without stats-based sorting
|
||||
subquery = (
|
||||
@@ -385,15 +386,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
||||
.group_by(Playlist.id, User.name)
|
||||
)
|
||||
|
||||
|
||||
# Apply filters
|
||||
if search_query and search_query.strip():
|
||||
search_pattern = f"%{search_query.strip().lower()}%"
|
||||
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
|
||||
|
||||
|
||||
if user_id is not None:
|
||||
subquery = subquery.where(Playlist.user_id == user_id)
|
||||
|
||||
|
||||
# Apply sorting
|
||||
if sort_by:
|
||||
if sort_by == PlaylistSortField.NAME:
|
||||
@@ -406,7 +407,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
sort_column = Playlist.updated_at
|
||||
else:
|
||||
sort_column = Playlist.name
|
||||
|
||||
|
||||
if sort_order == SortOrder.DESC:
|
||||
subquery = subquery.order_by(sort_column.desc())
|
||||
else:
|
||||
@@ -414,16 +415,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
else:
|
||||
# Default sorting by name ascending
|
||||
subquery = subquery.order_by(Playlist.name.asc())
|
||||
|
||||
|
||||
# Apply pagination
|
||||
if offset > 0:
|
||||
subquery = subquery.offset(offset)
|
||||
if limit is not None:
|
||||
subquery = subquery.limit(limit)
|
||||
|
||||
|
||||
result = await self.session.exec(subquery)
|
||||
rows = result.all()
|
||||
|
||||
|
||||
# Convert to dictionary format
|
||||
playlists = []
|
||||
for row in rows:
|
||||
@@ -442,11 +443,11 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
"sound_count": row.sound_count or 0,
|
||||
"total_duration": row.total_duration or 0,
|
||||
})
|
||||
|
||||
|
||||
return playlists
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
|
||||
search_query, sort_by, sort_order
|
||||
search_query, sort_by, sort_order,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -17,7 +17,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class SoundSortField(str, Enum):
|
||||
"""Sound sort field enumeration."""
|
||||
|
||||
|
||||
NAME = "name"
|
||||
FILENAME = "filename"
|
||||
DURATION = "duration"
|
||||
@@ -30,7 +30,7 @@ class SoundSortField(str, Enum):
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order enumeration."""
|
||||
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
@@ -144,18 +144,18 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
"""Search and sort sounds with optional filtering."""
|
||||
try:
|
||||
statement = select(Sound)
|
||||
|
||||
|
||||
# Apply type filter
|
||||
if sound_types:
|
||||
statement = statement.where(col(Sound.type).in_(sound_types))
|
||||
|
||||
|
||||
# Apply search filter
|
||||
if search_query and search_query.strip():
|
||||
search_pattern = f"%{search_query.strip().lower()}%"
|
||||
statement = statement.where(
|
||||
func.lower(Sound.name).like(search_pattern)
|
||||
func.lower(Sound.name).like(search_pattern),
|
||||
)
|
||||
|
||||
|
||||
# Apply sorting
|
||||
if sort_by:
|
||||
sort_column = getattr(Sound, sort_by.value)
|
||||
@@ -166,19 +166,19 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
else:
|
||||
# Default sorting by name ascending
|
||||
statement = statement.order_by(Sound.name.asc())
|
||||
|
||||
|
||||
# Apply pagination
|
||||
if offset > 0:
|
||||
statement = statement.offset(offset)
|
||||
if limit is not None:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
|
||||
search_query, sound_types, sort_by, sort_order
|
||||
search_query, sound_types, sort_by, sort_order,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -189,17 +189,17 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
func.count(Sound.id).label("count"),
|
||||
func.sum(Sound.play_count).label("total_plays"),
|
||||
func.sum(Sound.duration).label("total_duration"),
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||
).where(Sound.type == "SDB")
|
||||
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
row = result.first()
|
||||
|
||||
|
||||
return {
|
||||
"count": row.count if row.count is not None else 0,
|
||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||
"total_size": row.total_size if row.total_size is not None else 0
|
||||
"total_size": row.total_size if row.total_size is not None else 0,
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to get soundboard statistics")
|
||||
@@ -212,17 +212,17 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
func.count(Sound.id).label("count"),
|
||||
func.sum(Sound.play_count).label("total_plays"),
|
||||
func.sum(Sound.duration).label("total_duration"),
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||
).where(Sound.type == "EXT")
|
||||
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
row = result.first()
|
||||
|
||||
|
||||
return {
|
||||
"count": row.count if row.count is not None else 0,
|
||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||
"total_size": row.total_size if row.total_size is not None else 0
|
||||
"total_size": row.total_size if row.total_size is not None else 0,
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to get track statistics")
|
||||
@@ -244,20 +244,20 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
Sound.type,
|
||||
Sound.duration,
|
||||
Sound.created_at,
|
||||
func.count(SoundPlayed.id).label("play_count")
|
||||
func.count(SoundPlayed.id).label("play_count"),
|
||||
)
|
||||
.select_from(SoundPlayed)
|
||||
.join(Sound, SoundPlayed.sound_id == Sound.id)
|
||||
)
|
||||
|
||||
|
||||
# Apply sound type filter
|
||||
if sound_type != "all":
|
||||
statement = statement.where(Sound.type == sound_type.upper())
|
||||
|
||||
|
||||
# Apply date filter if provided
|
||||
if date_filter:
|
||||
statement = statement.where(SoundPlayed.created_at >= date_filter)
|
||||
|
||||
|
||||
# Group by sound and order by play count descending
|
||||
statement = (
|
||||
statement
|
||||
@@ -266,15 +266,15 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
Sound.name,
|
||||
Sound.type,
|
||||
Sound.duration,
|
||||
Sound.created_at
|
||||
Sound.created_at,
|
||||
)
|
||||
.order_by(func.count(SoundPlayed.id).desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
rows = result.all()
|
||||
|
||||
|
||||
# Convert to dictionaries with the play count from the period
|
||||
return [
|
||||
{
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.plan import Plan
|
||||
|
||||
@@ -96,5 +96,5 @@ class UpdateProfileRequest(BaseModel):
|
||||
"""Schema for profile update request."""
|
||||
|
||||
name: str | None = Field(
|
||||
None, min_length=1, max_length=100, description="User display name"
|
||||
None, min_length=1, max_length=100, description="User display name",
|
||||
)
|
||||
|
||||
@@ -434,36 +434,36 @@ class AuthService:
|
||||
async def update_user_profile(self, user: User, data: dict) -> User:
|
||||
"""Update user profile information."""
|
||||
logger.info("Updating profile for user: %s", user.email)
|
||||
|
||||
|
||||
# Only allow updating specific fields
|
||||
allowed_fields = {"name"}
|
||||
update_data = {k: v for k, v in data.items() if k in allowed_fields}
|
||||
|
||||
|
||||
if not update_data:
|
||||
return user
|
||||
|
||||
|
||||
# Update user
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(user, ["plan"])
|
||||
|
||||
|
||||
logger.info("Profile updated successfully for user: %s", user.email)
|
||||
return user
|
||||
|
||||
async def change_user_password(
|
||||
self, user: User, current_password: str | None, new_password: str
|
||||
self, user: User, current_password: str | None, new_password: str,
|
||||
) -> None:
|
||||
"""Change user's password."""
|
||||
# Store user email before any operations to avoid session detachment issues
|
||||
user_email = user.email
|
||||
logger.info("Changing password for user: %s", user_email)
|
||||
|
||||
|
||||
# Store whether user had existing password before we modify it
|
||||
had_existing_password = user.password_hash is not None
|
||||
|
||||
|
||||
# If user has existing password, verify it
|
||||
if had_existing_password:
|
||||
if not current_password:
|
||||
@@ -473,24 +473,24 @@ class AuthService:
|
||||
else:
|
||||
# User doesn't have a password (OAuth-only user), so we're setting their first password
|
||||
logger.info("Setting first password for OAuth user: %s", user_email)
|
||||
|
||||
|
||||
# Hash new password
|
||||
new_password_hash = PasswordUtils.hash_password(new_password)
|
||||
|
||||
|
||||
# Update user
|
||||
user.password_hash = new_password_hash
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Password %s successfully for user: %s",
|
||||
|
||||
logger.info("Password %s successfully for user: %s",
|
||||
"changed" if had_existing_password else "set", user_email)
|
||||
|
||||
async def user_to_response(self, user: User) -> UserResponse:
|
||||
"""Convert User model to UserResponse with plan information."""
|
||||
# Load plan relationship if not already loaded
|
||||
if not hasattr(user, 'plan') or not user.plan:
|
||||
if not hasattr(user, "plan") or not user.plan:
|
||||
await self.session.refresh(user, ["plan"])
|
||||
|
||||
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
|
||||
@@ -56,14 +56,14 @@ class DashboardService:
|
||||
try:
|
||||
# Calculate the date filter based on period
|
||||
date_filter = self._get_date_filter(period)
|
||||
|
||||
|
||||
# Get top sounds from repository
|
||||
top_sounds = await self.sound_repository.get_top_sounds(
|
||||
sound_type=sound_type,
|
||||
date_filter=date_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"id": sound["id"],
|
||||
@@ -86,7 +86,7 @@ class DashboardService:
|
||||
period,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _get_date_filter(self, period: str) -> datetime | None:
|
||||
"""Calculate the date filter based on the period."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
@@ -272,9 +272,8 @@ class PlaylistService:
|
||||
# Ensure position doesn't create gaps - if position is too high, place at end
|
||||
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||
max_position = len(current_sounds)
|
||||
if position > max_position:
|
||||
position = max_position
|
||||
|
||||
position = min(position, max_position)
|
||||
|
||||
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
|
||||
logger.info(
|
||||
"Added sound %s to playlist %s for user %s at position %s",
|
||||
@@ -306,10 +305,10 @@ class PlaylistService:
|
||||
)
|
||||
|
||||
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
|
||||
|
||||
|
||||
# Reorder remaining sounds to eliminate gaps
|
||||
await self._reorder_playlist_positions(playlist_id)
|
||||
|
||||
|
||||
logger.info(
|
||||
"Removed sound %s from playlist %s for user %s and reordered positions",
|
||||
sound_id,
|
||||
@@ -326,7 +325,7 @@ class PlaylistService:
|
||||
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||
if not sounds:
|
||||
return
|
||||
|
||||
|
||||
# Create sequential positions: 0, 1, 2, 3...
|
||||
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
|
||||
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for admin user endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
@@ -76,9 +76,9 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
|
||||
mock_regular = type("User", (), {
|
||||
"id": regular_user.id,
|
||||
"email": regular_user.email,
|
||||
@@ -93,9 +93,9 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
|
||||
mock_get_all.return_value = [mock_admin, mock_regular]
|
||||
|
||||
response = await authenticated_admin_client.get("/api/v1/admin/users/")
|
||||
@@ -130,7 +130,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_all.return_value = [mock_admin]
|
||||
|
||||
@@ -185,7 +185,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_user
|
||||
|
||||
@@ -244,9 +244,9 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
|
||||
updated_mock = type("User", (), {
|
||||
"id": regular_user.id,
|
||||
"email": regular_user.email,
|
||||
@@ -261,9 +261,9 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
|
||||
mock_get_by_id.return_value = mock_user
|
||||
mock_update.return_value = updated_mock
|
||||
|
||||
@@ -278,7 +278,7 @@ class TestAdminUserEndpoints:
|
||||
"name": "Updated Name",
|
||||
"credits": 200,
|
||||
"plan_id": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -299,7 +299,7 @@ class TestAdminUserEndpoints:
|
||||
):
|
||||
response = await authenticated_admin_client.patch(
|
||||
"/api/v1/admin/users/999",
|
||||
json={"name": "Updated Name"}
|
||||
json={"name": "Updated Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -333,12 +333,12 @@ class TestAdminUserEndpoints:
|
||||
"id": 1,
|
||||
"name": "Basic",
|
||||
"max_credits": 100,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_user
|
||||
response = await authenticated_admin_client.patch(
|
||||
"/api/v1/admin/users/2",
|
||||
json={"plan_id": 999}
|
||||
json={"plan_id": 999},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -373,7 +373,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_user
|
||||
mock_update.return_value = mock_user
|
||||
@@ -438,7 +438,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_disabled_user
|
||||
mock_update.return_value = mock_disabled_user
|
||||
@@ -487,4 +487,4 @@ class TestAdminUserEndpoints:
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["name"] == "Basic"
|
||||
assert data[1]["name"] == "Premium"
|
||||
assert data[1]["name"] == "Premium"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for authentication endpoints."""
|
||||
|
||||
from datetime import UTC
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -495,7 +496,7 @@ class TestAuthEndpoints:
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
cookies={"refresh_token": "valid_refresh_token"}
|
||||
cookies={"refresh_token": "valid_refresh_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -520,7 +521,7 @@ class TestAuthEndpoints:
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
cookies={"refresh_token": "valid_refresh_token"}
|
||||
cookies={"refresh_token": "valid_refresh_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
@@ -536,7 +537,7 @@ class TestAuthEndpoints:
|
||||
"""Test OAuth token exchange with invalid code."""
|
||||
response = await test_client.post(
|
||||
"/api/v1/auth/exchange-oauth-token",
|
||||
json={"code": "invalid_code"}
|
||||
json={"code": "invalid_code"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -565,7 +566,7 @@ class TestAuthEndpoints:
|
||||
is_active=test_user.is_active,
|
||||
)
|
||||
mock_update.return_value = updated_user
|
||||
|
||||
|
||||
# Mock the user_to_response to return UserResponse format
|
||||
from app.schemas.auth import UserResponse
|
||||
mock_user_to_response.return_value = UserResponse(
|
||||
@@ -589,7 +590,7 @@ class TestAuthEndpoints:
|
||||
response = await test_client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"name": "Updated Name"},
|
||||
cookies=auth_cookies
|
||||
cookies=auth_cookies,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -601,7 +602,7 @@ class TestAuthEndpoints:
|
||||
"""Test update profile without authentication."""
|
||||
response = await test_client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"name": "Updated Name"}
|
||||
json={"name": "Updated Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -621,9 +622,9 @@ class TestAuthEndpoints:
|
||||
"/api/v1/auth/change-password",
|
||||
json={
|
||||
"current_password": "old_password",
|
||||
"new_password": "new_password"
|
||||
"new_password": "new_password",
|
||||
},
|
||||
cookies=auth_cookies
|
||||
cookies=auth_cookies,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -637,8 +638,8 @@ class TestAuthEndpoints:
|
||||
"/api/v1/auth/change-password",
|
||||
json={
|
||||
"current_password": "old_password",
|
||||
"new_password": "new_password"
|
||||
}
|
||||
"new_password": "new_password",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -652,9 +653,10 @@ class TestAuthEndpoints:
|
||||
) -> None:
|
||||
"""Test get user OAuth providers success."""
|
||||
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.user_oauth import UserOauth
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
mock_oauth_google = UserOauth(
|
||||
id=1,
|
||||
user_id=test_user.id,
|
||||
@@ -662,34 +664,34 @@ class TestAuthEndpoints:
|
||||
provider_user_id="google123",
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
mock_oauth_github = UserOauth(
|
||||
id=2,
|
||||
user_id=test_user.id,
|
||||
provider="github",
|
||||
provider="github",
|
||||
provider_user_id="github456",
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/v1/auth/user-providers",
|
||||
cookies=auth_cookies
|
||||
cookies=auth_cookies,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 3 # password + 2 OAuth providers
|
||||
|
||||
|
||||
# Check password provider (first)
|
||||
assert data[0]["provider"] == "password"
|
||||
assert data[0]["display_name"] == "Password"
|
||||
|
||||
|
||||
# Check OAuth providers
|
||||
assert data[1]["provider"] == "google"
|
||||
assert data[1]["display_name"] == "Google"
|
||||
|
||||
@@ -529,7 +529,7 @@ class TestPlaylistRepository:
|
||||
)
|
||||
test_session.add(sound)
|
||||
sounds.append(sound)
|
||||
|
||||
|
||||
await test_session.commit()
|
||||
await test_session.refresh(playlist)
|
||||
for sound in sounds:
|
||||
@@ -547,7 +547,7 @@ class TestPlaylistRepository:
|
||||
|
||||
# Verify the final positions
|
||||
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
||||
|
||||
|
||||
assert len(playlist_sounds) == 3
|
||||
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
|
||||
assert playlist_sounds[0].position == 0
|
||||
@@ -605,7 +605,7 @@ class TestPlaylistRepository:
|
||||
)
|
||||
test_session.add(sound)
|
||||
sounds.append(sound)
|
||||
|
||||
|
||||
await test_session.commit()
|
||||
await test_session.refresh(playlist)
|
||||
for sound in sounds:
|
||||
@@ -623,7 +623,7 @@ class TestPlaylistRepository:
|
||||
|
||||
# Verify the final positions
|
||||
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
||||
|
||||
|
||||
assert len(playlist_sounds) == 3
|
||||
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
|
||||
assert playlist_sounds[0].position == 0
|
||||
|
||||
@@ -409,7 +409,7 @@ class TestCreditService:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recharge_user_credits_success(
|
||||
self, credit_service, sample_user
|
||||
self, credit_service, sample_user,
|
||||
) -> None:
|
||||
"""Test successful credit recharge for a user."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for dashboard service."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -63,7 +63,7 @@ class TestDashboardService:
|
||||
):
|
||||
"""Test getting soundboard statistics with exception."""
|
||||
mock_sound_repository.get_soundboard_statistics = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
side_effect=Exception("Database error"),
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
@@ -105,7 +105,7 @@ class TestDashboardService:
|
||||
):
|
||||
"""Test getting track statistics with exception."""
|
||||
mock_sound_repository.get_track_statistics = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
side_effect=Exception("Database error"),
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
@@ -198,7 +198,7 @@ class TestDashboardService:
|
||||
):
|
||||
"""Test getting top sounds with exception."""
|
||||
mock_sound_repository.get_top_sounds = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
side_effect=Exception("Database error"),
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
@@ -274,4 +274,4 @@ class TestDashboardService:
|
||||
def test_get_date_filter_unknown_period(self, dashboard_service):
|
||||
"""Test date filter for unknown period."""
|
||||
result = dashboard_service._get_date_filter("unknown_period")
|
||||
assert result is None
|
||||
assert result is None
|
||||
|
||||
@@ -25,9 +25,9 @@ class TestSchedulerService:
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_scheduler(self, scheduler_service) -> None:
|
||||
"""Test starting the scheduler service."""
|
||||
with patch.object(scheduler_service.scheduler, 'add_job') as mock_add_job, \
|
||||
patch.object(scheduler_service.scheduler, 'start') as mock_start:
|
||||
|
||||
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
|
||||
patch.object(scheduler_service.scheduler, "start") as mock_start:
|
||||
|
||||
await scheduler_service.start()
|
||||
|
||||
# Verify job was added
|
||||
@@ -47,7 +47,7 @@ class TestSchedulerService:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_scheduler(self, scheduler_service) -> None:
|
||||
"""Test stopping the scheduler service."""
|
||||
with patch.object(scheduler_service.scheduler, 'shutdown') as mock_shutdown:
|
||||
with patch.object(scheduler_service.scheduler, "shutdown") as mock_shutdown:
|
||||
await scheduler_service.stop()
|
||||
mock_shutdown.assert_called_once()
|
||||
|
||||
@@ -61,7 +61,7 @@ class TestSchedulerService:
|
||||
"total_credits_added": 500,
|
||||
}
|
||||
|
||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
||||
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||
mock_recharge.return_value = mock_stats
|
||||
|
||||
await scheduler_service._daily_credit_recharge()
|
||||
@@ -71,10 +71,10 @@ class TestSchedulerService:
|
||||
@pytest.mark.asyncio
|
||||
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
||||
"""Test daily credit recharge task with failure."""
|
||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
||||
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||
mock_recharge.side_effect = Exception("Database error")
|
||||
|
||||
# Should not raise exception, just log it
|
||||
await scheduler_service._daily_credit_recharge()
|
||||
|
||||
mock_recharge.assert_called_once()
|
||||
mock_recharge.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user