fix: Add missing commas in function calls and improve code formatting
Some checks failed
Backend CI / lint (push) Failing after 4m51s
Backend CI / test (push) Successful in 4m19s

This commit is contained in:
JSC
2025-08-12 23:37:38 +02:00
parent d3d7edb287
commit f094fbf140
18 changed files with 135 additions and 133 deletions

View File

@@ -460,7 +460,7 @@ async def update_profile(
"""Update the current user's profile."""
try:
updated_user = await auth_service.update_user_profile(
current_user, request.model_dump(exclude_unset=True)
current_user, request.model_dump(exclude_unset=True),
)
return await auth_service.user_to_response(updated_user)
except Exception as e:
@@ -482,7 +482,7 @@ async def change_password(
user_email = current_user.email
try:
await auth_service.change_user_password(
current_user, request.current_password, request.new_password
current_user, request.current_password, request.new_password,
)
return {"message": "Password changed successfully"}
except ValueError as e:
@@ -505,7 +505,7 @@ async def get_user_providers(
) -> list[dict[str, str]]:
"""Get the current user's connected authentication providers."""
providers = []
# Add password provider if user has password
if current_user.password_hash:
providers.append({
@@ -513,7 +513,7 @@ async def get_user_providers(
"display_name": "Password",
"connected_at": current_user.created_at.isoformat(),
})
# Get OAuth providers from the database
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
for oauth in oauth_providers:
@@ -522,11 +522,11 @@ async def get_user_providers(
display_name = "GitHub"
elif oauth.provider == "google":
display_name = "Google"
providers.append({
"provider": oauth.provider,
"display_name": display_name,
"connected_at": oauth.created_at.isoformat(),
})
return providers

View File

@@ -41,5 +41,5 @@ async def get_top_sounds(
return await dashboard_service.get_top_sounds(
sound_type=sound_type,
period=period,
limit=limit
limit=limit,
)

View File

@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.repositories.playlist import PlaylistSortField, SortOrder
from app.schemas.common import MessageResponse
from app.schemas.playlist import (
PlaylistAddSoundRequest,
@@ -19,7 +20,6 @@ from app.schemas.playlist import (
PlaylistUpdateRequest,
)
from app.services.playlist import PlaylistService
from app.repositories.playlist import PlaylistSortField, SortOrder
router = APIRouter(prefix="/playlists", tags=["playlists"])

View File

@@ -10,7 +10,7 @@ from app.core.dependencies import get_current_active_user_flexible
from app.models.credit_action import CreditActionType
from app.models.sound import Sound
from app.models.user import User
from app.repositories.sound import SoundRepository, SoundSortField, SortOrder
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
from app.services.credit import CreditService, InsufficientCreditsError
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service

View File

@@ -8,10 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.logging import get_logger
from app.models.user import User
from app.repositories.sound import SoundRepository
from app.services.auth import AuthService
from app.services.dashboard import DashboardService
from app.services.oauth import OAuthService
from app.repositories.sound import SoundRepository
from app.utils.auth import JWTUtils, TokenUtils
logger = get_logger(__name__)

View File

@@ -1,9 +1,10 @@
"""Playlist repository for database operations."""
from enum import Enum
from sqlalchemy import func, update
from sqlalchemy.orm import selectinload
from sqlmodel import col, select
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
class PlaylistSortField(str, Enum):
"""Playlist sort field enumeration."""
NAME = "name"
GENRE = "genre"
CREATED_AT = "created_at"
@@ -29,7 +30,7 @@ class PlaylistSortField(str, Enum):
class SortOrder(str, Enum):
"""Sort order enumeration."""
ASC = "asc"
DESC = "desc"
@@ -154,7 +155,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
# Use a two-step approach to avoid unique constraint violations:
# 1. Move all affected positions to negative temporary positions
# 2. Then move them to their final positions
# Step 1: Move to temporary negative positions
update_to_negative = (
update(PlaylistSound)
@@ -166,8 +167,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
)
await self.session.exec(update_to_negative)
await self.session.commit()
# Step 2: Move from temporary negative positions to final positions
# Step 2: Move from temporary negative positions to final positions
update_to_final = (
update(PlaylistSound)
.where(
@@ -337,15 +338,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name)
)
# Apply filters
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply sorting
if sort_by == PlaylistSortField.SOUND_COUNT:
if sort_order == SortOrder.DESC:
@@ -360,7 +361,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
else:
# Default sorting by name
subquery = subquery.order_by(Playlist.name.asc())
else:
# Simple query without stats-based sorting
subquery = (
@@ -385,15 +386,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
.group_by(Playlist.id, User.name)
)
# Apply filters
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
if user_id is not None:
subquery = subquery.where(Playlist.user_id == user_id)
# Apply sorting
if sort_by:
if sort_by == PlaylistSortField.NAME:
@@ -406,7 +407,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
sort_column = Playlist.updated_at
else:
sort_column = Playlist.name
if sort_order == SortOrder.DESC:
subquery = subquery.order_by(sort_column.desc())
else:
@@ -414,16 +415,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
else:
# Default sorting by name ascending
subquery = subquery.order_by(Playlist.name.asc())
# Apply pagination
if offset > 0:
subquery = subquery.offset(offset)
if limit is not None:
subquery = subquery.limit(limit)
result = await self.session.exec(subquery)
rows = result.all()
# Convert to dictionary format
playlists = []
for row in rows:
@@ -442,11 +443,11 @@ class PlaylistRepository(BaseRepository[Playlist]):
"sound_count": row.sound_count or 0,
"total_duration": row.total_duration or 0,
})
return playlists
except Exception:
logger.exception(
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
search_query, sort_by, sort_order
search_query, sort_by, sort_order,
)
raise

View File

@@ -17,7 +17,7 @@ logger = get_logger(__name__)
class SoundSortField(str, Enum):
"""Sound sort field enumeration."""
NAME = "name"
FILENAME = "filename"
DURATION = "duration"
@@ -30,7 +30,7 @@ class SoundSortField(str, Enum):
class SortOrder(str, Enum):
"""Sort order enumeration."""
ASC = "asc"
DESC = "desc"
@@ -144,18 +144,18 @@ class SoundRepository(BaseRepository[Sound]):
"""Search and sort sounds with optional filtering."""
try:
statement = select(Sound)
# Apply type filter
if sound_types:
statement = statement.where(col(Sound.type).in_(sound_types))
# Apply search filter
if search_query and search_query.strip():
search_pattern = f"%{search_query.strip().lower()}%"
statement = statement.where(
func.lower(Sound.name).like(search_pattern)
func.lower(Sound.name).like(search_pattern),
)
# Apply sorting
if sort_by:
sort_column = getattr(Sound, sort_by.value)
@@ -166,19 +166,19 @@ class SoundRepository(BaseRepository[Sound]):
else:
# Default sorting by name ascending
statement = statement.order_by(Sound.name.asc())
# Apply pagination
if offset > 0:
statement = statement.offset(offset)
if limit is not None:
statement = statement.limit(limit)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception(
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
search_query, sound_types, sort_by, sort_order
search_query, sound_types, sort_by, sort_order,
)
raise
@@ -189,17 +189,17 @@ class SoundRepository(BaseRepository[Sound]):
func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"),
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
).where(Sound.type == "SDB")
result = await self.session.exec(statement)
row = result.first()
return {
"count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": row.total_duration if row.total_duration is not None else 0,
"total_size": row.total_size if row.total_size is not None else 0
"total_size": row.total_size if row.total_size is not None else 0,
}
except Exception:
logger.exception("Failed to get soundboard statistics")
@@ -212,17 +212,17 @@ class SoundRepository(BaseRepository[Sound]):
func.count(Sound.id).label("count"),
func.sum(Sound.play_count).label("total_plays"),
func.sum(Sound.duration).label("total_duration"),
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
).where(Sound.type == "EXT")
result = await self.session.exec(statement)
row = result.first()
return {
"count": row.count if row.count is not None else 0,
"total_plays": row.total_plays if row.total_plays is not None else 0,
"total_duration": row.total_duration if row.total_duration is not None else 0,
"total_size": row.total_size if row.total_size is not None else 0
"total_size": row.total_size if row.total_size is not None else 0,
}
except Exception:
logger.exception("Failed to get track statistics")
@@ -244,20 +244,20 @@ class SoundRepository(BaseRepository[Sound]):
Sound.type,
Sound.duration,
Sound.created_at,
func.count(SoundPlayed.id).label("play_count")
func.count(SoundPlayed.id).label("play_count"),
)
.select_from(SoundPlayed)
.join(Sound, SoundPlayed.sound_id == Sound.id)
)
# Apply sound type filter
if sound_type != "all":
statement = statement.where(Sound.type == sound_type.upper())
# Apply date filter if provided
if date_filter:
statement = statement.where(SoundPlayed.created_at >= date_filter)
# Group by sound and order by play count descending
statement = (
statement
@@ -266,15 +266,15 @@ class SoundRepository(BaseRepository[Sound]):
Sound.name,
Sound.type,
Sound.duration,
Sound.created_at
Sound.created_at,
)
.order_by(func.count(SoundPlayed.id).desc())
.limit(limit)
)
result = await self.session.exec(statement)
rows = result.all()
# Convert to dictionaries with the play count from the period
return [
{

View File

@@ -2,9 +2,9 @@
from typing import Any
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import selectinload
from app.core.logging import get_logger
from app.models.plan import Plan

View File

@@ -96,5 +96,5 @@ class UpdateProfileRequest(BaseModel):
"""Schema for profile update request."""
name: str | None = Field(
None, min_length=1, max_length=100, description="User display name"
None, min_length=1, max_length=100, description="User display name",
)

View File

@@ -434,36 +434,36 @@ class AuthService:
async def update_user_profile(self, user: User, data: dict) -> User:
"""Update user profile information."""
logger.info("Updating profile for user: %s", user.email)
# Only allow updating specific fields
allowed_fields = {"name"}
update_data = {k: v for k, v in data.items() if k in allowed_fields}
if not update_data:
return user
# Update user
for field, value in update_data.items():
setattr(user, field, value)
self.session.add(user)
await self.session.commit()
await self.session.refresh(user, ["plan"])
logger.info("Profile updated successfully for user: %s", user.email)
return user
async def change_user_password(
self, user: User, current_password: str | None, new_password: str
self, user: User, current_password: str | None, new_password: str,
) -> None:
"""Change user's password."""
# Store user email before any operations to avoid session detachment issues
user_email = user.email
logger.info("Changing password for user: %s", user_email)
# Store whether user had existing password before we modify it
had_existing_password = user.password_hash is not None
# If user has existing password, verify it
if had_existing_password:
if not current_password:
@@ -473,24 +473,24 @@ class AuthService:
else:
# User doesn't have a password (OAuth-only user), so we're setting their first password
logger.info("Setting first password for OAuth user: %s", user_email)
# Hash new password
new_password_hash = PasswordUtils.hash_password(new_password)
# Update user
user.password_hash = new_password_hash
self.session.add(user)
await self.session.commit()
logger.info("Password %s successfully for user: %s",
logger.info("Password %s successfully for user: %s",
"changed" if had_existing_password else "set", user_email)
async def user_to_response(self, user: User) -> UserResponse:
"""Convert User model to UserResponse with plan information."""
# Load plan relationship if not already loaded
if not hasattr(user, 'plan') or not user.plan:
if not hasattr(user, "plan") or not user.plan:
await self.session.refresh(user, ["plan"])
return UserResponse(
id=user.id,
email=user.email,

View File

@@ -56,14 +56,14 @@ class DashboardService:
try:
# Calculate the date filter based on period
date_filter = self._get_date_filter(period)
# Get top sounds from repository
top_sounds = await self.sound_repository.get_top_sounds(
sound_type=sound_type,
date_filter=date_filter,
limit=limit,
)
return [
{
"id": sound["id"],
@@ -86,7 +86,7 @@ class DashboardService:
period,
)
raise
def _get_date_filter(self, period: str) -> datetime | None:
"""Calculate the date filter based on the period."""
now = datetime.now(UTC)

View File

@@ -272,9 +272,8 @@ class PlaylistService:
# Ensure position doesn't create gaps - if position is too high, place at end
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
max_position = len(current_sounds)
if position > max_position:
position = max_position
position = min(position, max_position)
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info(
"Added sound %s to playlist %s for user %s at position %s",
@@ -306,10 +305,10 @@ class PlaylistService:
)
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
# Reorder remaining sounds to eliminate gaps
await self._reorder_playlist_positions(playlist_id)
logger.info(
"Removed sound %s from playlist %s for user %s and reordered positions",
sound_id,
@@ -326,7 +325,7 @@ class PlaylistService:
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
if not sounds:
return
# Create sequential positions: 0, 1, 2, 3...
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)