fix: Add missing commas in function calls and improve code formatting
This commit is contained in:
@@ -460,7 +460,7 @@ async def update_profile(
|
||||
"""Update the current user's profile."""
|
||||
try:
|
||||
updated_user = await auth_service.update_user_profile(
|
||||
current_user, request.model_dump(exclude_unset=True)
|
||||
current_user, request.model_dump(exclude_unset=True),
|
||||
)
|
||||
return await auth_service.user_to_response(updated_user)
|
||||
except Exception as e:
|
||||
@@ -482,7 +482,7 @@ async def change_password(
|
||||
user_email = current_user.email
|
||||
try:
|
||||
await auth_service.change_user_password(
|
||||
current_user, request.current_password, request.new_password
|
||||
current_user, request.current_password, request.new_password,
|
||||
)
|
||||
return {"message": "Password changed successfully"}
|
||||
except ValueError as e:
|
||||
|
||||
@@ -41,5 +41,5 @@ async def get_top_sounds(
|
||||
return await dashboard_service.get_top_sounds(
|
||||
sound_type=sound_type,
|
||||
period=period,
|
||||
limit=limit
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.user import User
|
||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||
from app.schemas.common import MessageResponse
|
||||
from app.schemas.playlist import (
|
||||
PlaylistAddSoundRequest,
|
||||
@@ -19,7 +20,6 @@ from app.schemas.playlist import (
|
||||
PlaylistUpdateRequest,
|
||||
)
|
||||
from app.services.playlist import PlaylistService
|
||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||
|
||||
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from app.core.dependencies import get_current_active_user_flexible
|
||||
from app.models.credit_action import CreditActionType
|
||||
from app.models.sound import Sound
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository, SoundSortField, SortOrder
|
||||
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
|
||||
from app.services.credit import CreditService, InsufficientCreditsError
|
||||
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||
|
||||
|
||||
@@ -8,10 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.core.database import get_db
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.auth import AuthService
|
||||
from app.services.dashboard import DashboardService
|
||||
from app.services.oauth import OAuthService
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.utils.auth import JWTUtils, TokenUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Playlist repository for database operations."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from sqlalchemy import func, update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import col, select
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
@@ -447,6 +448,6 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
|
||||
search_query, sort_by, sort_order
|
||||
search_query, sort_by, sort_order,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -153,7 +153,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
if search_query and search_query.strip():
|
||||
search_pattern = f"%{search_query.strip().lower()}%"
|
||||
statement = statement.where(
|
||||
func.lower(Sound.name).like(search_pattern)
|
||||
func.lower(Sound.name).like(search_pattern),
|
||||
)
|
||||
|
||||
# Apply sorting
|
||||
@@ -178,7 +178,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
|
||||
search_query, sound_types, sort_by, sort_order
|
||||
search_query, sound_types, sort_by, sort_order,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -189,7 +189,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
func.count(Sound.id).label("count"),
|
||||
func.sum(Sound.play_count).label("total_plays"),
|
||||
func.sum(Sound.duration).label("total_duration"),
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||
).where(Sound.type == "SDB")
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
@@ -199,7 +199,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
"count": row.count if row.count is not None else 0,
|
||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||
"total_size": row.total_size if row.total_size is not None else 0
|
||||
"total_size": row.total_size if row.total_size is not None else 0,
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to get soundboard statistics")
|
||||
@@ -212,7 +212,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
func.count(Sound.id).label("count"),
|
||||
func.sum(Sound.play_count).label("total_plays"),
|
||||
func.sum(Sound.duration).label("total_duration"),
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||
).where(Sound.type == "EXT")
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
@@ -222,7 +222,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
"count": row.count if row.count is not None else 0,
|
||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||
"total_size": row.total_size if row.total_size is not None else 0
|
||||
"total_size": row.total_size if row.total_size is not None else 0,
|
||||
}
|
||||
except Exception:
|
||||
logger.exception("Failed to get track statistics")
|
||||
@@ -244,7 +244,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
Sound.type,
|
||||
Sound.duration,
|
||||
Sound.created_at,
|
||||
func.count(SoundPlayed.id).label("play_count")
|
||||
func.count(SoundPlayed.id).label("play_count"),
|
||||
)
|
||||
.select_from(SoundPlayed)
|
||||
.join(Sound, SoundPlayed.sound_id == Sound.id)
|
||||
@@ -266,7 +266,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
Sound.name,
|
||||
Sound.type,
|
||||
Sound.duration,
|
||||
Sound.created_at
|
||||
Sound.created_at,
|
||||
)
|
||||
.order_by(func.count(SoundPlayed.id).desc())
|
||||
.limit(limit)
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.plan import Plan
|
||||
|
||||
@@ -96,5 +96,5 @@ class UpdateProfileRequest(BaseModel):
|
||||
"""Schema for profile update request."""
|
||||
|
||||
name: str | None = Field(
|
||||
None, min_length=1, max_length=100, description="User display name"
|
||||
None, min_length=1, max_length=100, description="User display name",
|
||||
)
|
||||
|
||||
@@ -454,7 +454,7 @@ class AuthService:
|
||||
return user
|
||||
|
||||
async def change_user_password(
|
||||
self, user: User, current_password: str | None, new_password: str
|
||||
self, user: User, current_password: str | None, new_password: str,
|
||||
) -> None:
|
||||
"""Change user's password."""
|
||||
# Store user email before any operations to avoid session detachment issues
|
||||
@@ -488,7 +488,7 @@ class AuthService:
|
||||
async def user_to_response(self, user: User) -> UserResponse:
|
||||
"""Convert User model to UserResponse with plan information."""
|
||||
# Load plan relationship if not already loaded
|
||||
if not hasattr(user, 'plan') or not user.plan:
|
||||
if not hasattr(user, "plan") or not user.plan:
|
||||
await self.session.refresh(user, ["plan"])
|
||||
|
||||
return UserResponse(
|
||||
|
||||
@@ -272,8 +272,7 @@ class PlaylistService:
|
||||
# Ensure position doesn't create gaps - if position is too high, place at end
|
||||
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||
max_position = len(current_sounds)
|
||||
if position > max_position:
|
||||
position = max_position
|
||||
position = min(position, max_position)
|
||||
|
||||
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
|
||||
logger.info(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for admin user endpoints."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
@@ -76,7 +76,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
mock_regular = type("User", (), {
|
||||
@@ -93,7 +93,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
mock_get_all.return_value = [mock_admin, mock_regular]
|
||||
@@ -130,7 +130,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_all.return_value = [mock_admin]
|
||||
|
||||
@@ -185,7 +185,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_user
|
||||
|
||||
@@ -244,7 +244,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
updated_mock = type("User", (), {
|
||||
@@ -261,7 +261,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
|
||||
mock_get_by_id.return_value = mock_user
|
||||
@@ -278,7 +278,7 @@ class TestAdminUserEndpoints:
|
||||
"name": "Updated Name",
|
||||
"credits": 200,
|
||||
"plan_id": 1,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -299,7 +299,7 @@ class TestAdminUserEndpoints:
|
||||
):
|
||||
response = await authenticated_admin_client.patch(
|
||||
"/api/v1/admin/users/999",
|
||||
json={"name": "Updated Name"}
|
||||
json={"name": "Updated Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -333,12 +333,12 @@ class TestAdminUserEndpoints:
|
||||
"id": 1,
|
||||
"name": "Basic",
|
||||
"max_credits": 100,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_user
|
||||
response = await authenticated_admin_client.patch(
|
||||
"/api/v1/admin/users/2",
|
||||
json={"plan_id": 999}
|
||||
json={"plan_id": 999},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
@@ -373,7 +373,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_user
|
||||
mock_update.return_value = mock_user
|
||||
@@ -438,7 +438,7 @@ class TestAdminUserEndpoints:
|
||||
"id": test_plan.id,
|
||||
"name": test_plan.name,
|
||||
"max_credits": test_plan.max_credits,
|
||||
})()
|
||||
})(),
|
||||
})()
|
||||
mock_get_by_id.return_value = mock_disabled_user
|
||||
mock_update.return_value = mock_disabled_user
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for authentication endpoints."""
|
||||
|
||||
from datetime import UTC
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -495,7 +496,7 @@ class TestAuthEndpoints:
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
cookies={"refresh_token": "valid_refresh_token"}
|
||||
cookies={"refresh_token": "valid_refresh_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -520,7 +521,7 @@ class TestAuthEndpoints:
|
||||
|
||||
response = await test_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
cookies={"refresh_token": "valid_refresh_token"}
|
||||
cookies={"refresh_token": "valid_refresh_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
@@ -536,7 +537,7 @@ class TestAuthEndpoints:
|
||||
"""Test OAuth token exchange with invalid code."""
|
||||
response = await test_client.post(
|
||||
"/api/v1/auth/exchange-oauth-token",
|
||||
json={"code": "invalid_code"}
|
||||
json={"code": "invalid_code"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
@@ -589,7 +590,7 @@ class TestAuthEndpoints:
|
||||
response = await test_client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"name": "Updated Name"},
|
||||
cookies=auth_cookies
|
||||
cookies=auth_cookies,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -601,7 +602,7 @@ class TestAuthEndpoints:
|
||||
"""Test update profile without authentication."""
|
||||
response = await test_client.patch(
|
||||
"/api/v1/auth/me",
|
||||
json={"name": "Updated Name"}
|
||||
json={"name": "Updated Name"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -621,9 +622,9 @@ class TestAuthEndpoints:
|
||||
"/api/v1/auth/change-password",
|
||||
json={
|
||||
"current_password": "old_password",
|
||||
"new_password": "new_password"
|
||||
"new_password": "new_password",
|
||||
},
|
||||
cookies=auth_cookies
|
||||
cookies=auth_cookies,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
@@ -637,8 +638,8 @@ class TestAuthEndpoints:
|
||||
"/api/v1/auth/change-password",
|
||||
json={
|
||||
"current_password": "old_password",
|
||||
"new_password": "new_password"
|
||||
}
|
||||
"new_password": "new_password",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
@@ -652,8 +653,9 @@ class TestAuthEndpoints:
|
||||
) -> None:
|
||||
"""Test get user OAuth providers success."""
|
||||
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.user_oauth import UserOauth
|
||||
from datetime import datetime, timezone
|
||||
|
||||
mock_oauth_google = UserOauth(
|
||||
id=1,
|
||||
@@ -662,8 +664,8 @@ class TestAuthEndpoints:
|
||||
provider_user_id="google123",
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
mock_oauth_github = UserOauth(
|
||||
id=2,
|
||||
@@ -672,14 +674,14 @@ class TestAuthEndpoints:
|
||||
provider_user_id="github456",
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/v1/auth/user-providers",
|
||||
cookies=auth_cookies
|
||||
cookies=auth_cookies,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@@ -409,7 +409,7 @@ class TestCreditService:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recharge_user_credits_success(
|
||||
self, credit_service, sample_user
|
||||
self, credit_service, sample_user,
|
||||
) -> None:
|
||||
"""Test successful credit recharge for a user."""
|
||||
mock_session = credit_service.db_session_factory()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for dashboard service."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -63,7 +63,7 @@ class TestDashboardService:
|
||||
):
|
||||
"""Test getting soundboard statistics with exception."""
|
||||
mock_sound_repository.get_soundboard_statistics = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
side_effect=Exception("Database error"),
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
@@ -105,7 +105,7 @@ class TestDashboardService:
|
||||
):
|
||||
"""Test getting track statistics with exception."""
|
||||
mock_sound_repository.get_track_statistics = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
side_effect=Exception("Database error"),
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
@@ -198,7 +198,7 @@ class TestDashboardService:
|
||||
):
|
||||
"""Test getting top sounds with exception."""
|
||||
mock_sound_repository.get_top_sounds = AsyncMock(
|
||||
side_effect=Exception("Database error")
|
||||
side_effect=Exception("Database error"),
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
|
||||
@@ -25,8 +25,8 @@ class TestSchedulerService:
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_scheduler(self, scheduler_service) -> None:
|
||||
"""Test starting the scheduler service."""
|
||||
with patch.object(scheduler_service.scheduler, 'add_job') as mock_add_job, \
|
||||
patch.object(scheduler_service.scheduler, 'start') as mock_start:
|
||||
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
|
||||
patch.object(scheduler_service.scheduler, "start") as mock_start:
|
||||
|
||||
await scheduler_service.start()
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestSchedulerService:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_scheduler(self, scheduler_service) -> None:
|
||||
"""Test stopping the scheduler service."""
|
||||
with patch.object(scheduler_service.scheduler, 'shutdown') as mock_shutdown:
|
||||
with patch.object(scheduler_service.scheduler, "shutdown") as mock_shutdown:
|
||||
await scheduler_service.stop()
|
||||
mock_shutdown.assert_called_once()
|
||||
|
||||
@@ -61,7 +61,7 @@ class TestSchedulerService:
|
||||
"total_credits_added": 500,
|
||||
}
|
||||
|
||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
||||
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||
mock_recharge.return_value = mock_stats
|
||||
|
||||
await scheduler_service._daily_credit_recharge()
|
||||
@@ -71,7 +71,7 @@ class TestSchedulerService:
|
||||
@pytest.mark.asyncio
|
||||
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
||||
"""Test daily credit recharge task with failure."""
|
||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
||||
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||
mock_recharge.side_effect = Exception("Database error")
|
||||
|
||||
# Should not raise exception, just log it
|
||||
|
||||
Reference in New Issue
Block a user