fix: Add missing commas in function calls and improve code formatting
Some checks failed
Backend CI / lint (push) Failing after 4m51s
Backend CI / test (push) Successful in 4m19s

This commit is contained in:
JSC
2025-08-12 23:37:38 +02:00
parent d3d7edb287
commit f094fbf140
18 changed files with 135 additions and 133 deletions

View File

@@ -460,7 +460,7 @@ async def update_profile(
"""Update the current user's profile."""
try:
updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True)
current_user, request.model_dump(exclude_unset=True),
)
return await auth_service.user_to_response(updated_user)
except Exception as e:
@@ -482,7 +482,7 @@ async def change_password(
user_email = current_user.email
try:
await auth_service.change_user_password(
current_user, request.current_password, request.new_password
current_user, request.current_password, request.new_password,
)
return {"message": "Password changed successfully"}
except ValueError as e:
@@ -505,7 +505,7 @@ async def get_user_providers(
) -> list[dict[str, str]]:
"""Get the current user's connected authentication providers."""
providers = []
# Add password provider if user has password
if current_user.password_hash:
providers.append({
@@ -513,7 +513,7 @@ async def get_user_providers(
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
})
# Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
for oauth in oauth_providers:
@@ -522,11 +522,11 @@ async def get_user_providers(
display_name = "GitHub"
elif oauth.provider == "google":
display_name = "Google"
providers.append({
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
})
return providers

View File

@@ -41,5 +41,5 @@ async def get_top_sounds(
return await dashboard_service.get_top_sounds(
sound_type=sound_type,
period=period,
limit=limit
limit=limit,
)

View File

@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
from app.schemas.common import MessageResponse
from app.schemas.playlist import (
PlaylistAddSoundRequest,
@@ -19,7 +20,6 @@ from app.schemas.playlist import (
PlaylistUpdateRequest,
)
from app.services.playlist import PlaylistService
from app.repositories.playlist import PlaylistSortField, SortOrder
router = APIRouter(prefix="/playlists", tags=["playlists"])

View File

@@ -10,7 +10,7 @@ from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.sound import Sound
from app.models.user import User
from app.repositories.sound import SoundRepository, SoundSortField, SortOrder
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service

View File

@@ -8,10 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.logging import get_logger
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.services.auth import AuthService
from app.services.dashboard import DashboardService
from app.services.oauth import OAuthService
from app.repositories.sound import SoundRepository
from app.utils.auth import JWTUtils, TokenUtils
logger = get_logger(__name__)

View File

@@ -1,9 +1,10 @@
"""Playlist repository for database operations."""
from enum import Enum
from sqlalchemy import func, update
from sqlalchemy.orm import selectinload
from sqlmodel import col, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
class PlaylistSortField(str, Enum):
"""Playlist sort field enumeration."""
NAME = "name"
GENRE = "genre"
CREATED_AT = "created_at"
@@ -29,7 +30,7 @@ class PlaylistSortField(str, Enum):
class SortOrder(str, Enum):
"""Sort order enumeration."""
ASC = "asc"
DESC = "desc"
@@ -154,7 +155,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Use a two-step approach to avoid unique constraint violations:
# 1. Move all affected positions to negative temporary positions
# 2. Then move them to their final positions
# Step 1: Move to temporary negative positions
update_to_negative = (
update(PlaylistSound)
@@ -166,8 +167,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
)
await self.session.exec(update_to_negative)
await self.session.commit()
# Step 2: Move from temporary negative positions to final positions
# Step 2: Move from temporary negative positions to final positions
update_to_final = (
update(PlaylistSound)
.where(
@@ -337,15 +338,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name)
)
# Apply filters
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply sorting
if sort_by == PlaylistSortField.SOUND_COUNT:
if sort_order == SortOrder.DESC:
@@ -360,7 +361,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
else:
# Default sorting by name
subquery = subquery.order_by(Playlist.name.asc())
else:
# Simple query without stats-based sorting
subquery = (
@@ -385,15 +386,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name)
)
# Apply filters
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply sorting
if sort_by:
if sort_by == PlaylistSortField.NAME:
@@ -406,7 +407,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
sort_column = Playlist.updated_at
else:
sort_column = Playlist.name
if sort_order == SortOrder.DESC:
subquery = subquery.order_by(sort_column.desc())
else:
@@ -414,16 +415,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
else:
# Default sorting by name ascending
subquery = subquery.order_by(Playlist.name.asc())
# Apply pagination
if offset > 0:
subquery = subquery.offset(offset)
if limit is not None:
subquery = subquery.limit(limit)
result = await self.session.exec(subquery)
rows = result.all()
# Convert to dictionary format
playlists = []
for row in rows:
@@ -442,11 +443,11 @@ class PlaylistRepository(BaseRepository[Playlist]):
"sound_count": row.sound_count or 0,
"total_duration": row.total_duration or 0,
})
return playlists
except Exception:
logger.exception(
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
search_query, sort_by, sort_order
search_query, sort_by, sort_order,
)
raise

View File

@@ -17,7 +17,7 @@ logger = get_logger(__name__)
class SoundSortField(str, Enum):
"""Sound sort field enumeration."""
NAME = "name"
FILENAME = "filename"
DURATION = "duration"
@@ -30,7 +30,7 @@ class SoundSortField(str, Enum):
class SortOrder(str, Enum):
"""Sort order enumeration."""
ASC = "asc"
DESC = "desc"
@@ -144,18 +144,18 @@ class SoundRepository(BaseRepository[Sound]):
"""Search and sort sounds with optional filtering."""
try:
statement = select(Sound)
# Apply type filter
if sound_types:
statement = statement.where(col(Sound.type).in_(sound_types))
# Apply search filter
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
statement = statement.where(
func.lower(Sound.name).like(search_pattern)
func.lower(Sound.name).like(search_pattern),
)
# Apply sorting
if sort_by:
sort_column = getattr(Sound, sort_by.value)
@@ -166,19 +166,19 @@ class SoundRepository(BaseRepository[Sound]):
else:
# Default sorting by name ascending
statement = statement.order_by(Sound.name.asc())
# Apply pagination
if offset > 0:
statement = statement.offset(offset)
if limit is not None:
statement = statement.limit(limit)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception(
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
search_query, sound_types, sort_by, sort_order
search_query, sound_types, sort_by, sort_order,
)
raise
@@ -189,17 +189,17 @@ class SoundRepository(BaseRepository[Sound]):
func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"),
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
).where(Sound.type == "SDB")
result = await self.session.exec(statement)
row = result.first()
return {
"count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": row.total_duration if row.total_duration is not None else 0,
"total_size": row.total_size if row.total_size is not None else 0
"total_size": row.total_size if row.total_size is not None else 0,
}
except Exception:
logger.exception("Failed to get soundboard statistics")
@@ -212,17 +212,17 @@ class SoundRepository(BaseRepository[Sound]):
func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"),
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
).where(Sound.type == "EXT")
result = await self.session.exec(statement)
row = result.first()
return {
"count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": row.total_duration if row.total_duration is not None else 0,
"total_size": row.total_size if row.total_size is not None else 0
"total_size": row.total_size if row.total_size is not None else 0,
}
except Exception:
logger.exception("Failed to get track statistics")
@@ -244,20 +244,20 @@ class SoundRepository(BaseRepository[Sound]):
Sound.type,
Sound.duration,
Sound.created_at,
func.count(SoundPlayed.id).label("play_count")
func.count(SoundPlayed.id).label("play_count"),
)
.select_from(SoundPlayed)
.join(Sound, SoundPlayed.sound_id == Sound.id)
)
# Apply sound type filter
if sound_type != "all":
statement = statement.where(Sound.type == sound_type.upper())
# Apply date filter if provided
if date_filter:
statement = statement.where(SoundPlayed.created_at >= date_filter)
# Group by sound and order by play count descending
statement = (
statement
@@ -266,15 +266,15 @@ class SoundRepository(BaseRepository[Sound]):
Sound.name,
Sound.type,
Sound.duration,
Sound.created_at
Sound.created_at,
)
.order_by(func.count(SoundPlayed.id).desc())
.limit(limit)
)
result = await self.session.exec(statement)
rows = result.all()
# Convert to dictionaries with the play count from the period
return [
{

View File

@@ -2,9 +2,9 @@
from typing import Any
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.logging import get_logger
from app.models.plan import Plan

View File

@@ -96,5 +96,5 @@ class UpdateProfileRequest(BaseModel):
"""Schema for profile update request."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User display name"
None, min_length=1, max_length=100, description="User display name",
)

View File

@@ -434,36 +434,36 @@ class AuthService:
async def update_user_profile(self, user: User, data: dict) -> User:
"""Update user profile information."""
logger.info("Updating profile for user: %s", user.email)
# Only allow updating specific fields
allowed_fields = {"name"}
update_data = {k: v for k, v in data.items() if k in allowed_fields}
if not update_data:
return user
# Update user
for field, value in update_data.items():
setattr(user, field, value)
self.session.add(user)
await self.session.commit()
await self.session.refresh(user, ["plan"])
logger.info("Profile updated successfully for user: %s", user.email)
return user
async def change_user_password(
self, user: User, current_password: str | None, new_password: str
self, user: User, current_password: str | None, new_password: str,
) -> None:
"""Change user's password."""
# Store user email before any operations to avoid session detachment issues
user_email = user.email
logger.info("Changing password for user: %s", user_email)
# Store whether user had existing password before we modify it
had_existing_password = user.password_hash is not None
# If user has existing password, verify it
if had_existing_password:
if not current_password:
@@ -473,24 +473,24 @@ class AuthService:
else:
# User doesn't have a password (OAuth-only user), so we're setting their first password
logger.info("Setting first password for OAuth user: %s", user_email)
# Hash new password
new_password_hash = PasswordUtils.hash_password(new_password)
# Update user
user.password_hash = new_password_hash
self.session.add(user)
await self.session.commit()
logger.info("Password %s successfully for user: %s",
logger.info("Password %s successfully for user: %s",
"changed" if had_existing_password else "set", user_email)
async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information."""
# Load plan relationship if not already loaded
if not hasattr(user, 'plan') or not user.plan:
if not hasattr(user, "plan") or not user.plan:
await self.session.refresh(user, ["plan"])
return UserResponse(
id=user.id,
email=user.email,

View File

@@ -56,14 +56,14 @@ class DashboardService:
try:
# Calculate the date filter based on period
date_filter = self._get_date_filter(period)
# Get top sounds from repository
top_sounds = await self.sound_repository.get_top_sounds(
sound_type=sound_type,
date_filter=date_filter,
limit=limit,
)
return [
{
"id": sound["id"],
@@ -86,7 +86,7 @@ class DashboardService:
period,
)
raise
def _get_date_filter(self, period: str) -> datetime | None:
"""Calculate the date filter based on the period."""
now = datetime.now(UTC)

View File

@@ -272,9 +272,8 @@ class PlaylistService:
# Ensure position doesn't create gaps - if position is too high, place at end
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
max_position = len(current_sounds)
if position > max_position:
position = max_position
position = min(position, max_position)
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info(
"Added sound %s to playlist %s for user %s at position %s",
@@ -306,10 +305,10 @@ class PlaylistService:
)
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
# Reorder remaining sounds to eliminate gaps
await self._reorder_playlist_positions(playlist_id)
logger.info(
"Removed sound %s from playlist %s for user %s and reordered positions",
sound_id,
@@ -326,7 +325,7 @@ class PlaylistService:
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
if not sounds:
return
# Create sequential positions: 0, 1, 2, 3...
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)

View File

@@ -1,6 +1,6 @@
"""Tests for admin user endpoints."""
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import Mock, patch
import pytest
from httpx import AsyncClient
@@ -76,9 +76,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
mock_regular = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
@@ -93,9 +93,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
mock_get_all.return_value = [mock_admin, mock_regular]
response = await authenticated_admin_client.get("/api/v1/admin/users/")
@@ -130,7 +130,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
mock_get_all.return_value = [mock_admin]
@@ -185,7 +185,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
mock_get_by_id.return_value = mock_user
@@ -244,9 +244,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
updated_mock = type("User", (), {
"id": regular_user.id,
"email": regular_user.email,
@@ -261,9 +261,9 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
mock_get_by_id.return_value = mock_user
mock_update.return_value = updated_mock
@@ -278,7 +278,7 @@ class TestAdminUserEndpoints:
"name": "Updated Name",
"credits": 200,
"plan_id": 1,
}
},
)
assert response.status_code == 200
@@ -299,7 +299,7 @@ class TestAdminUserEndpoints:
):
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/999",
json={"name": "Updated Name"}
json={"name": "Updated Name"},
)
assert response.status_code == 404
@@ -333,12 +333,12 @@ class TestAdminUserEndpoints:
"id": 1,
"name": "Basic",
"max_credits": 100,
})()
})(),
})()
mock_get_by_id.return_value = mock_user
response = await authenticated_admin_client.patch(
"/api/v1/admin/users/2",
json={"plan_id": 999}
json={"plan_id": 999},
)
assert response.status_code == 404
@@ -373,7 +373,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
mock_get_by_id.return_value = mock_user
mock_update.return_value = mock_user
@@ -438,7 +438,7 @@ class TestAdminUserEndpoints:
"id": test_plan.id,
"name": test_plan.name,
"max_credits": test_plan.max_credits,
})()
})(),
})()
mock_get_by_id.return_value = mock_disabled_user
mock_update.return_value = mock_disabled_user
@@ -487,4 +487,4 @@ class TestAdminUserEndpoints:
data = response.json()
assert len(data) == 2
assert data[0]["name"] == "Basic"
assert data[1]["name"] == "Premium"
assert data[1]["name"] == "Premium"

View File

@@ -1,5 +1,6 @@
"""Tests for authentication endpoints."""
from datetime import UTC
from typing import Any
from unittest.mock import patch
@@ -495,7 +496,7 @@ class TestAuthEndpoints:
response = await test_client.post(
"/api/v1/auth/refresh",
cookies={"refresh_token": "valid_refresh_token"}
cookies={"refresh_token": "valid_refresh_token"},
)
assert response.status_code == 200
@@ -520,7 +521,7 @@ class TestAuthEndpoints:
response = await test_client.post(
"/api/v1/auth/refresh",
cookies={"refresh_token": "valid_refresh_token"}
cookies={"refresh_token": "valid_refresh_token"},
)
assert response.status_code == 500
@@ -536,7 +537,7 @@ class TestAuthEndpoints:
"""Test OAuth token exchange with invalid code."""
response = await test_client.post(
"/api/v1/auth/exchange-oauth-token",
json={"code": "invalid_code"}
json={"code": "invalid_code"},
)
assert response.status_code == 400
@@ -565,7 +566,7 @@ class TestAuthEndpoints:
is_active=test_user.is_active,
)
mock_update.return_value = updated_user
# Mock the user_to_response to return UserResponse format
from app.schemas.auth import UserResponse
mock_user_to_response.return_value = UserResponse(
@@ -589,7 +590,7 @@ class TestAuthEndpoints:
response = await test_client.patch(
"/api/v1/auth/me",
json={"name": "Updated Name"},
cookies=auth_cookies
cookies=auth_cookies,
)
assert response.status_code == 200
@@ -601,7 +602,7 @@ class TestAuthEndpoints:
"""Test update profile without authentication."""
response = await test_client.patch(
"/api/v1/auth/me",
json={"name": "Updated Name"}
json={"name": "Updated Name"},
)
assert response.status_code == 401
@@ -621,9 +622,9 @@ class TestAuthEndpoints:
"/api/v1/auth/change-password",
json={
"current_password": "old_password",
"new_password": "new_password"
"new_password": "new_password",
},
cookies=auth_cookies
cookies=auth_cookies,
)
assert response.status_code == 200
@@ -637,8 +638,8 @@ class TestAuthEndpoints:
"/api/v1/auth/change-password",
json={
"current_password": "old_password",
"new_password": "new_password"
}
"new_password": "new_password",
},
)
assert response.status_code == 401
@@ -652,9 +653,10 @@ class TestAuthEndpoints:
) -> None:
"""Test get user OAuth providers success."""
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
from datetime import datetime
from app.models.user_oauth import UserOauth
from datetime import datetime, timezone
mock_oauth_google = UserOauth(
id=1,
user_id=test_user.id,
@@ -662,34 +664,34 @@ class TestAuthEndpoints:
provider_user_id="google123",
email="test@example.com",
name="Test User",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
mock_oauth_github = UserOauth(
id=2,
user_id=test_user.id,
provider="github",
provider="github",
provider_user_id="github456",
email="test@example.com",
name="Test User",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
response = await test_client.get(
"/api/v1/auth/user-providers",
cookies=auth_cookies
cookies=auth_cookies,
)
assert response.status_code == 200
data = response.json()
assert len(data) == 3 # password + 2 OAuth providers
# Check password provider (first)
assert data[0]["provider"] == "password"
assert data[0]["display_name"] == "Password"
# Check OAuth providers
assert data[1]["provider"] == "google"
assert data[1]["display_name"] == "Google"

View File

@@ -529,7 +529,7 @@ class TestPlaylistRepository:
)
test_session.add(sound)
sounds.append(sound)
await test_session.commit()
await test_session.refresh(playlist)
for sound in sounds:
@@ -547,7 +547,7 @@ class TestPlaylistRepository:
# Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
assert playlist_sounds[0].position == 0
@@ -605,7 +605,7 @@ class TestPlaylistRepository:
)
test_session.add(sound)
sounds.append(sound)
await test_session.commit()
await test_session.refresh(playlist)
for sound in sounds:
@@ -623,7 +623,7 @@ class TestPlaylistRepository:
# Verify the final positions
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
assert len(playlist_sounds) == 3
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
assert playlist_sounds[0].position == 0

View File

@@ -409,7 +409,7 @@ class TestCreditService:
@pytest.mark.asyncio
async def test_recharge_user_credits_success(
self, credit_service, sample_user
self, credit_service, sample_user,
) -> None:
"""Test successful credit recharge for a user."""
mock_session = credit_service.db_session_factory()

View File

@@ -1,6 +1,6 @@
"""Tests for dashboard service."""
from datetime import UTC, datetime, timedelta
from datetime import UTC, datetime
from unittest.mock import AsyncMock, Mock, patch
import pytest
@@ -63,7 +63,7 @@ class TestDashboardService:
):
"""Test getting soundboard statistics with exception."""
mock_sound_repository.get_soundboard_statistics = AsyncMock(
side_effect=Exception("Database error")
side_effect=Exception("Database error"),
)
with pytest.raises(Exception, match="Database error"):
@@ -105,7 +105,7 @@ class TestDashboardService:
):
"""Test getting track statistics with exception."""
mock_sound_repository.get_track_statistics = AsyncMock(
side_effect=Exception("Database error")
side_effect=Exception("Database error"),
)
with pytest.raises(Exception, match="Database error"):
@@ -198,7 +198,7 @@ class TestDashboardService:
):
"""Test getting top sounds with exception."""
mock_sound_repository.get_top_sounds = AsyncMock(
side_effect=Exception("Database error")
side_effect=Exception("Database error"),
)
with pytest.raises(Exception, match="Database error"):
@@ -274,4 +274,4 @@ class TestDashboardService:
def test_get_date_filter_unknown_period(self, dashboard_service):
"""Test date filter for unknown period."""
result = dashboard_service._get_date_filter("unknown_period")
assert result is None
assert result is None

View File

@@ -25,9 +25,9 @@ class TestSchedulerService:
@pytest.mark.asyncio
async def test_start_scheduler(self, scheduler_service) -> None:
"""Test starting the scheduler service."""
with patch.object(scheduler_service.scheduler, 'add_job') as mock_add_job, \
patch.object(scheduler_service.scheduler, 'start') as mock_start:
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
patch.object(scheduler_service.scheduler, "start") as mock_start:
await scheduler_service.start()
# Verify job was added
@@ -47,7 +47,7 @@ class TestSchedulerService:
@pytest.mark.asyncio
async def test_stop_scheduler(self, scheduler_service) -> None:
"""Test stopping the scheduler service."""
with patch.object(scheduler_service.scheduler, 'shutdown') as mock_shutdown:
with patch.object(scheduler_service.scheduler, "shutdown") as mock_shutdown:
await scheduler_service.stop()
mock_shutdown.assert_called_once()
@@ -61,7 +61,7 @@ class TestSchedulerService:
"total_credits_added": 500,
}
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
mock_recharge.return_value = mock_stats
await scheduler_service._daily_credit_recharge()
@@ -71,10 +71,10 @@ class TestSchedulerService:
@pytest.mark.asyncio
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
"""Test daily credit recharge task with failure."""
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
mock_recharge.side_effect = Exception("Database error")
# Should not raise exception, just log it
await scheduler_service._daily_credit_recharge()
mock_recharge.assert_called_once()
mock_recharge.assert_called_once()