fix: Add missing commas in function calls and improve code formatting
This commit is contained in:
@@ -460,7 +460,7 @@ async def update_profile(
|
|||||||
"""Update the current user's profile."""
|
"""Update the current user's profile."""
|
||||||
try:
|
try:
|
||||||
updated_user = await auth_service.update_user_profile(
|
updated_user = await auth_service.update_user_profile(
|
||||||
current_user, request.model_dump(exclude_unset=True)
|
current_user, request.model_dump(exclude_unset=True),
|
||||||
)
|
)
|
||||||
return await auth_service.user_to_response(updated_user)
|
return await auth_service.user_to_response(updated_user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -482,7 +482,7 @@ async def change_password(
|
|||||||
user_email = current_user.email
|
user_email = current_user.email
|
||||||
try:
|
try:
|
||||||
await auth_service.change_user_password(
|
await auth_service.change_user_password(
|
||||||
current_user, request.current_password, request.new_password
|
current_user, request.current_password, request.new_password,
|
||||||
)
|
)
|
||||||
return {"message": "Password changed successfully"}
|
return {"message": "Password changed successfully"}
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|||||||
@@ -41,5 +41,5 @@ async def get_top_sounds(
|
|||||||
return await dashboard_service.get_top_sounds(
|
return await dashboard_service.get_top_sounds(
|
||||||
sound_type=sound_type,
|
sound_type=sound_type,
|
||||||
period=period,
|
period=period,
|
||||||
limit=limit
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.dependencies import get_current_active_user_flexible
|
from app.core.dependencies import get_current_active_user_flexible
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.playlist import (
|
from app.schemas.playlist import (
|
||||||
PlaylistAddSoundRequest,
|
PlaylistAddSoundRequest,
|
||||||
@@ -19,7 +20,6 @@ from app.schemas.playlist import (
|
|||||||
PlaylistUpdateRequest,
|
PlaylistUpdateRequest,
|
||||||
)
|
)
|
||||||
from app.services.playlist import PlaylistService
|
from app.services.playlist import PlaylistService
|
||||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from app.core.dependencies import get_current_active_user_flexible
|
|||||||
from app.models.credit_action import CreditActionType
|
from app.models.credit_action import CreditActionType
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.sound import SoundRepository, SoundSortField, SortOrder
|
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
|
||||||
from app.services.credit import CreditService, InsufficientCreditsError
|
from app.services.credit import CreditService, InsufficientCreditsError
|
||||||
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.sound import SoundRepository
|
||||||
from app.services.auth import AuthService
|
from app.services.auth import AuthService
|
||||||
from app.services.dashboard import DashboardService
|
from app.services.dashboard import DashboardService
|
||||||
from app.services.oauth import OAuthService
|
from app.services.oauth import OAuthService
|
||||||
from app.repositories.sound import SoundRepository
|
|
||||||
from app.utils.auth import JWTUtils, TokenUtils
|
from app.utils.auth import JWTUtils, TokenUtils
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Playlist repository for database operations."""
|
"""Playlist repository for database operations."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from sqlalchemy import func, update
|
from sqlalchemy import func, update
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlmodel import col, select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
@@ -447,6 +448,6 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
|
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
|
||||||
search_query, sort_by, sort_order
|
search_query, sort_by, sort_order,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
if search_query and search_query.strip():
|
if search_query and search_query.strip():
|
||||||
search_pattern = f"%{search_query.strip().lower()}%"
|
search_pattern = f"%{search_query.strip().lower()}%"
|
||||||
statement = statement.where(
|
statement = statement.where(
|
||||||
func.lower(Sound.name).like(search_pattern)
|
func.lower(Sound.name).like(search_pattern),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
@@ -178,7 +178,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
|
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
|
||||||
search_query, sound_types, sort_by, sort_order
|
search_query, sound_types, sort_by, sort_order,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
func.count(Sound.id).label("count"),
|
func.count(Sound.id).label("count"),
|
||||||
func.sum(Sound.play_count).label("total_plays"),
|
func.sum(Sound.play_count).label("total_plays"),
|
||||||
func.sum(Sound.duration).label("total_duration"),
|
func.sum(Sound.duration).label("total_duration"),
|
||||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||||
).where(Sound.type == "SDB")
|
).where(Sound.type == "SDB")
|
||||||
|
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
@@ -199,7 +199,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
"count": row.count if row.count is not None else 0,
|
"count": row.count if row.count is not None else 0,
|
||||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||||
"total_size": row.total_size if row.total_size is not None else 0
|
"total_size": row.total_size if row.total_size is not None else 0,
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to get soundboard statistics")
|
logger.exception("Failed to get soundboard statistics")
|
||||||
@@ -212,7 +212,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
func.count(Sound.id).label("count"),
|
func.count(Sound.id).label("count"),
|
||||||
func.sum(Sound.play_count).label("total_plays"),
|
func.sum(Sound.play_count).label("total_plays"),
|
||||||
func.sum(Sound.duration).label("total_duration"),
|
func.sum(Sound.duration).label("total_duration"),
|
||||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||||
).where(Sound.type == "EXT")
|
).where(Sound.type == "EXT")
|
||||||
|
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
@@ -222,7 +222,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
"count": row.count if row.count is not None else 0,
|
"count": row.count if row.count is not None else 0,
|
||||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||||
"total_size": row.total_size if row.total_size is not None else 0
|
"total_size": row.total_size if row.total_size is not None else 0,
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to get track statistics")
|
logger.exception("Failed to get track statistics")
|
||||||
@@ -244,7 +244,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
Sound.type,
|
Sound.type,
|
||||||
Sound.duration,
|
Sound.duration,
|
||||||
Sound.created_at,
|
Sound.created_at,
|
||||||
func.count(SoundPlayed.id).label("play_count")
|
func.count(SoundPlayed.id).label("play_count"),
|
||||||
)
|
)
|
||||||
.select_from(SoundPlayed)
|
.select_from(SoundPlayed)
|
||||||
.join(Sound, SoundPlayed.sound_id == Sound.id)
|
.join(Sound, SoundPlayed.sound_id == Sound.id)
|
||||||
@@ -266,7 +266,7 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
Sound.name,
|
Sound.name,
|
||||||
Sound.type,
|
Sound.type,
|
||||||
Sound.duration,
|
Sound.duration,
|
||||||
Sound.created_at
|
Sound.created_at,
|
||||||
)
|
)
|
||||||
.order_by(func.count(SoundPlayed.id).desc())
|
.order_by(func.count(SoundPlayed.id).desc())
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
|
|||||||
@@ -96,5 +96,5 @@ class UpdateProfileRequest(BaseModel):
|
|||||||
"""Schema for profile update request."""
|
"""Schema for profile update request."""
|
||||||
|
|
||||||
name: str | None = Field(
|
name: str | None = Field(
|
||||||
None, min_length=1, max_length=100, description="User display name"
|
None, min_length=1, max_length=100, description="User display name",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -454,7 +454,7 @@ class AuthService:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
async def change_user_password(
|
async def change_user_password(
|
||||||
self, user: User, current_password: str | None, new_password: str
|
self, user: User, current_password: str | None, new_password: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Change user's password."""
|
"""Change user's password."""
|
||||||
# Store user email before any operations to avoid session detachment issues
|
# Store user email before any operations to avoid session detachment issues
|
||||||
@@ -488,7 +488,7 @@ class AuthService:
|
|||||||
async def user_to_response(self, user: User) -> UserResponse:
|
async def user_to_response(self, user: User) -> UserResponse:
|
||||||
"""Convert User model to UserResponse with plan information."""
|
"""Convert User model to UserResponse with plan information."""
|
||||||
# Load plan relationship if not already loaded
|
# Load plan relationship if not already loaded
|
||||||
if not hasattr(user, 'plan') or not user.plan:
|
if not hasattr(user, "plan") or not user.plan:
|
||||||
await self.session.refresh(user, ["plan"])
|
await self.session.refresh(user, ["plan"])
|
||||||
|
|
||||||
return UserResponse(
|
return UserResponse(
|
||||||
|
|||||||
@@ -272,8 +272,7 @@ class PlaylistService:
|
|||||||
# Ensure position doesn't create gaps - if position is too high, place at end
|
# Ensure position doesn't create gaps - if position is too high, place at end
|
||||||
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||||
max_position = len(current_sounds)
|
max_position = len(current_sounds)
|
||||||
if position > max_position:
|
position = min(position, max_position)
|
||||||
position = max_position
|
|
||||||
|
|
||||||
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
|
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for admin user endpoints."""
|
"""Tests for admin user endpoints."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
@@ -76,7 +76,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
mock_regular = type("User", (), {
|
mock_regular = type("User", (), {
|
||||||
@@ -93,7 +93,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
mock_get_all.return_value = [mock_admin, mock_regular]
|
mock_get_all.return_value = [mock_admin, mock_regular]
|
||||||
@@ -130,7 +130,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_all.return_value = [mock_admin]
|
mock_get_all.return_value = [mock_admin]
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
|
|
||||||
@@ -244,7 +244,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
updated_mock = type("User", (), {
|
updated_mock = type("User", (), {
|
||||||
@@ -261,7 +261,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
@@ -278,7 +278,7 @@ class TestAdminUserEndpoints:
|
|||||||
"name": "Updated Name",
|
"name": "Updated Name",
|
||||||
"credits": 200,
|
"credits": 200,
|
||||||
"plan_id": 1,
|
"plan_id": 1,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -299,7 +299,7 @@ class TestAdminUserEndpoints:
|
|||||||
):
|
):
|
||||||
response = await authenticated_admin_client.patch(
|
response = await authenticated_admin_client.patch(
|
||||||
"/api/v1/admin/users/999",
|
"/api/v1/admin/users/999",
|
||||||
json={"name": "Updated Name"}
|
json={"name": "Updated Name"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
@@ -333,12 +333,12 @@ class TestAdminUserEndpoints:
|
|||||||
"id": 1,
|
"id": 1,
|
||||||
"name": "Basic",
|
"name": "Basic",
|
||||||
"max_credits": 100,
|
"max_credits": 100,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
response = await authenticated_admin_client.patch(
|
response = await authenticated_admin_client.patch(
|
||||||
"/api/v1/admin/users/2",
|
"/api/v1/admin/users/2",
|
||||||
json={"plan_id": 999}
|
json={"plan_id": 999},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
@@ -373,7 +373,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
mock_update.return_value = mock_user
|
mock_update.return_value = mock_user
|
||||||
@@ -438,7 +438,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_disabled_user
|
mock_get_by_id.return_value = mock_disabled_user
|
||||||
mock_update.return_value = mock_disabled_user
|
mock_update.return_value = mock_disabled_user
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Tests for authentication endpoints."""
|
"""Tests for authentication endpoints."""
|
||||||
|
|
||||||
|
from datetime import UTC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -495,7 +496,7 @@ class TestAuthEndpoints:
|
|||||||
|
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/refresh",
|
"/api/v1/auth/refresh",
|
||||||
cookies={"refresh_token": "valid_refresh_token"}
|
cookies={"refresh_token": "valid_refresh_token"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -520,7 +521,7 @@ class TestAuthEndpoints:
|
|||||||
|
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/refresh",
|
"/api/v1/auth/refresh",
|
||||||
cookies={"refresh_token": "valid_refresh_token"}
|
cookies={"refresh_token": "valid_refresh_token"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
@@ -536,7 +537,7 @@ class TestAuthEndpoints:
|
|||||||
"""Test OAuth token exchange with invalid code."""
|
"""Test OAuth token exchange with invalid code."""
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/exchange-oauth-token",
|
"/api/v1/auth/exchange-oauth-token",
|
||||||
json={"code": "invalid_code"}
|
json={"code": "invalid_code"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
@@ -589,7 +590,7 @@ class TestAuthEndpoints:
|
|||||||
response = await test_client.patch(
|
response = await test_client.patch(
|
||||||
"/api/v1/auth/me",
|
"/api/v1/auth/me",
|
||||||
json={"name": "Updated Name"},
|
json={"name": "Updated Name"},
|
||||||
cookies=auth_cookies
|
cookies=auth_cookies,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -601,7 +602,7 @@ class TestAuthEndpoints:
|
|||||||
"""Test update profile without authentication."""
|
"""Test update profile without authentication."""
|
||||||
response = await test_client.patch(
|
response = await test_client.patch(
|
||||||
"/api/v1/auth/me",
|
"/api/v1/auth/me",
|
||||||
json={"name": "Updated Name"}
|
json={"name": "Updated Name"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@@ -621,9 +622,9 @@ class TestAuthEndpoints:
|
|||||||
"/api/v1/auth/change-password",
|
"/api/v1/auth/change-password",
|
||||||
json={
|
json={
|
||||||
"current_password": "old_password",
|
"current_password": "old_password",
|
||||||
"new_password": "new_password"
|
"new_password": "new_password",
|
||||||
},
|
},
|
||||||
cookies=auth_cookies
|
cookies=auth_cookies,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -637,8 +638,8 @@ class TestAuthEndpoints:
|
|||||||
"/api/v1/auth/change-password",
|
"/api/v1/auth/change-password",
|
||||||
json={
|
json={
|
||||||
"current_password": "old_password",
|
"current_password": "old_password",
|
||||||
"new_password": "new_password"
|
"new_password": "new_password",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@@ -652,8 +653,9 @@ class TestAuthEndpoints:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test get user OAuth providers success."""
|
"""Test get user OAuth providers success."""
|
||||||
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
|
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from app.models.user_oauth import UserOauth
|
from app.models.user_oauth import UserOauth
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
mock_oauth_google = UserOauth(
|
mock_oauth_google = UserOauth(
|
||||||
id=1,
|
id=1,
|
||||||
@@ -662,8 +664,8 @@ class TestAuthEndpoints:
|
|||||||
provider_user_id="google123",
|
provider_user_id="google123",
|
||||||
email="test@example.com",
|
email="test@example.com",
|
||||||
name="Test User",
|
name="Test User",
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=datetime.now(UTC),
|
||||||
updated_at=datetime.now(timezone.utc),
|
updated_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
mock_oauth_github = UserOauth(
|
mock_oauth_github = UserOauth(
|
||||||
id=2,
|
id=2,
|
||||||
@@ -672,14 +674,14 @@ class TestAuthEndpoints:
|
|||||||
provider_user_id="github456",
|
provider_user_id="github456",
|
||||||
email="test@example.com",
|
email="test@example.com",
|
||||||
name="Test User",
|
name="Test User",
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=datetime.now(UTC),
|
||||||
updated_at=datetime.now(timezone.utc),
|
updated_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
|
mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
|
||||||
|
|
||||||
response = await test_client.get(
|
response = await test_client.get(
|
||||||
"/api/v1/auth/user-providers",
|
"/api/v1/auth/user-providers",
|
||||||
cookies=auth_cookies
|
cookies=auth_cookies,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -409,7 +409,7 @@ class TestCreditService:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_recharge_user_credits_success(
|
async def test_recharge_user_credits_success(
|
||||||
self, credit_service, sample_user
|
self, credit_service, sample_user,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful credit recharge for a user."""
|
"""Test successful credit recharge for a user."""
|
||||||
mock_session = credit_service.db_session_factory()
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for dashboard service."""
|
"""Tests for dashboard service."""
|
||||||
|
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -63,7 +63,7 @@ class TestDashboardService:
|
|||||||
):
|
):
|
||||||
"""Test getting soundboard statistics with exception."""
|
"""Test getting soundboard statistics with exception."""
|
||||||
mock_sound_repository.get_soundboard_statistics = AsyncMock(
|
mock_sound_repository.get_soundboard_statistics = AsyncMock(
|
||||||
side_effect=Exception("Database error")
|
side_effect=Exception("Database error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Database error"):
|
with pytest.raises(Exception, match="Database error"):
|
||||||
@@ -105,7 +105,7 @@ class TestDashboardService:
|
|||||||
):
|
):
|
||||||
"""Test getting track statistics with exception."""
|
"""Test getting track statistics with exception."""
|
||||||
mock_sound_repository.get_track_statistics = AsyncMock(
|
mock_sound_repository.get_track_statistics = AsyncMock(
|
||||||
side_effect=Exception("Database error")
|
side_effect=Exception("Database error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Database error"):
|
with pytest.raises(Exception, match="Database error"):
|
||||||
@@ -198,7 +198,7 @@ class TestDashboardService:
|
|||||||
):
|
):
|
||||||
"""Test getting top sounds with exception."""
|
"""Test getting top sounds with exception."""
|
||||||
mock_sound_repository.get_top_sounds = AsyncMock(
|
mock_sound_repository.get_top_sounds = AsyncMock(
|
||||||
side_effect=Exception("Database error")
|
side_effect=Exception("Database error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Database error"):
|
with pytest.raises(Exception, match="Database error"):
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_scheduler(self, scheduler_service) -> None:
|
async def test_start_scheduler(self, scheduler_service) -> None:
|
||||||
"""Test starting the scheduler service."""
|
"""Test starting the scheduler service."""
|
||||||
with patch.object(scheduler_service.scheduler, 'add_job') as mock_add_job, \
|
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
|
||||||
patch.object(scheduler_service.scheduler, 'start') as mock_start:
|
patch.object(scheduler_service.scheduler, "start") as mock_start:
|
||||||
|
|
||||||
await scheduler_service.start()
|
await scheduler_service.start()
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_scheduler(self, scheduler_service) -> None:
|
async def test_stop_scheduler(self, scheduler_service) -> None:
|
||||||
"""Test stopping the scheduler service."""
|
"""Test stopping the scheduler service."""
|
||||||
with patch.object(scheduler_service.scheduler, 'shutdown') as mock_shutdown:
|
with patch.object(scheduler_service.scheduler, "shutdown") as mock_shutdown:
|
||||||
await scheduler_service.stop()
|
await scheduler_service.stop()
|
||||||
mock_shutdown.assert_called_once()
|
mock_shutdown.assert_called_once()
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ class TestSchedulerService:
|
|||||||
"total_credits_added": 500,
|
"total_credits_added": 500,
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||||
mock_recharge.return_value = mock_stats
|
mock_recharge.return_value = mock_stats
|
||||||
|
|
||||||
await scheduler_service._daily_credit_recharge()
|
await scheduler_service._daily_credit_recharge()
|
||||||
@@ -71,7 +71,7 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
||||||
"""Test daily credit recharge task with failure."""
|
"""Test daily credit recharge task with failure."""
|
||||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||||
mock_recharge.side_effect = Exception("Database error")
|
mock_recharge.side_effect = Exception("Database error")
|
||||||
|
|
||||||
# Should not raise exception, just log it
|
# Should not raise exception, just log it
|
||||||
|
|||||||
Reference in New Issue
Block a user