fix: Add missing commas in function calls and improve code formatting
This commit is contained in:
@@ -460,7 +460,7 @@ async def update_profile(
|
|||||||
"""Update the current user's profile."""
|
"""Update the current user's profile."""
|
||||||
try:
|
try:
|
||||||
updated_user = await auth_service.update_user_profile(
|
updated_user = await auth_service.update_user_profile(
|
||||||
current_user, request.model_dump(exclude_unset=True)
|
current_user, request.model_dump(exclude_unset=True),
|
||||||
)
|
)
|
||||||
return await auth_service.user_to_response(updated_user)
|
return await auth_service.user_to_response(updated_user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -482,7 +482,7 @@ async def change_password(
|
|||||||
user_email = current_user.email
|
user_email = current_user.email
|
||||||
try:
|
try:
|
||||||
await auth_service.change_user_password(
|
await auth_service.change_user_password(
|
||||||
current_user, request.current_password, request.new_password
|
current_user, request.current_password, request.new_password,
|
||||||
)
|
)
|
||||||
return {"message": "Password changed successfully"}
|
return {"message": "Password changed successfully"}
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -505,7 +505,7 @@ async def get_user_providers(
|
|||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""Get the current user's connected authentication providers."""
|
"""Get the current user's connected authentication providers."""
|
||||||
providers = []
|
providers = []
|
||||||
|
|
||||||
# Add password provider if user has password
|
# Add password provider if user has password
|
||||||
if current_user.password_hash:
|
if current_user.password_hash:
|
||||||
providers.append({
|
providers.append({
|
||||||
@@ -513,7 +513,7 @@ async def get_user_providers(
|
|||||||
"display_name": "Password",
|
"display_name": "Password",
|
||||||
"connected_at": current_user.created_at.isoformat(),
|
"connected_at": current_user.created_at.isoformat(),
|
||||||
})
|
})
|
||||||
|
|
||||||
# Get OAuth providers from the database
|
# Get OAuth providers from the database
|
||||||
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
||||||
for oauth in oauth_providers:
|
for oauth in oauth_providers:
|
||||||
@@ -522,11 +522,11 @@ async def get_user_providers(
|
|||||||
display_name = "GitHub"
|
display_name = "GitHub"
|
||||||
elif oauth.provider == "google":
|
elif oauth.provider == "google":
|
||||||
display_name = "Google"
|
display_name = "Google"
|
||||||
|
|
||||||
providers.append({
|
providers.append({
|
||||||
"provider": oauth.provider,
|
"provider": oauth.provider,
|
||||||
"display_name": display_name,
|
"display_name": display_name,
|
||||||
"connected_at": oauth.created_at.isoformat(),
|
"connected_at": oauth.created_at.isoformat(),
|
||||||
})
|
})
|
||||||
|
|
||||||
return providers
|
return providers
|
||||||
|
|||||||
@@ -41,5 +41,5 @@ async def get_top_sounds(
|
|||||||
return await dashboard_service.get_top_sounds(
|
return await dashboard_service.get_top_sounds(
|
||||||
sound_type=sound_type,
|
sound_type=sound_type,
|
||||||
period=period,
|
period=period,
|
||||||
limit=limit
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.dependencies import get_current_active_user_flexible
|
from app.core.dependencies import get_current_active_user_flexible
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.playlist import PlaylistSortField, SortOrder
|
||||||
from app.schemas.common import MessageResponse
|
from app.schemas.common import MessageResponse
|
||||||
from app.schemas.playlist import (
|
from app.schemas.playlist import (
|
||||||
PlaylistAddSoundRequest,
|
PlaylistAddSoundRequest,
|
||||||
@@ -19,7 +20,6 @@ from app.schemas.playlist import (
|
|||||||
PlaylistUpdateRequest,
|
PlaylistUpdateRequest,
|
||||||
)
|
)
|
||||||
from app.services.playlist import PlaylistService
|
from app.services.playlist import PlaylistService
|
||||||
from app.repositories.playlist import PlaylistSortField, SortOrder
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
router = APIRouter(prefix="/playlists", tags=["playlists"])
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from app.core.dependencies import get_current_active_user_flexible
|
|||||||
from app.models.credit_action import CreditActionType
|
from app.models.credit_action import CreditActionType
|
||||||
from app.models.sound import Sound
|
from app.models.sound import Sound
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.sound import SoundRepository, SoundSortField, SortOrder
|
from app.repositories.sound import SortOrder, SoundRepository, SoundSortField
|
||||||
from app.services.credit import CreditService, InsufficientCreditsError
|
from app.services.credit import CreditService, InsufficientCreditsError
|
||||||
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.repositories.sound import SoundRepository
|
||||||
from app.services.auth import AuthService
|
from app.services.auth import AuthService
|
||||||
from app.services.dashboard import DashboardService
|
from app.services.dashboard import DashboardService
|
||||||
from app.services.oauth import OAuthService
|
from app.services.oauth import OAuthService
|
||||||
from app.repositories.sound import SoundRepository
|
|
||||||
from app.utils.auth import JWTUtils, TokenUtils
|
from app.utils.auth import JWTUtils, TokenUtils
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""Playlist repository for database operations."""
|
"""Playlist repository for database operations."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from sqlalchemy import func, update
|
from sqlalchemy import func, update
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlmodel import col, select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
@@ -18,7 +19,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
class PlaylistSortField(str, Enum):
|
class PlaylistSortField(str, Enum):
|
||||||
"""Playlist sort field enumeration."""
|
"""Playlist sort field enumeration."""
|
||||||
|
|
||||||
NAME = "name"
|
NAME = "name"
|
||||||
GENRE = "genre"
|
GENRE = "genre"
|
||||||
CREATED_AT = "created_at"
|
CREATED_AT = "created_at"
|
||||||
@@ -29,7 +30,7 @@ class PlaylistSortField(str, Enum):
|
|||||||
|
|
||||||
class SortOrder(str, Enum):
|
class SortOrder(str, Enum):
|
||||||
"""Sort order enumeration."""
|
"""Sort order enumeration."""
|
||||||
|
|
||||||
ASC = "asc"
|
ASC = "asc"
|
||||||
DESC = "desc"
|
DESC = "desc"
|
||||||
|
|
||||||
@@ -154,7 +155,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
# Use a two-step approach to avoid unique constraint violations:
|
# Use a two-step approach to avoid unique constraint violations:
|
||||||
# 1. Move all affected positions to negative temporary positions
|
# 1. Move all affected positions to negative temporary positions
|
||||||
# 2. Then move them to their final positions
|
# 2. Then move them to their final positions
|
||||||
|
|
||||||
# Step 1: Move to temporary negative positions
|
# Step 1: Move to temporary negative positions
|
||||||
update_to_negative = (
|
update_to_negative = (
|
||||||
update(PlaylistSound)
|
update(PlaylistSound)
|
||||||
@@ -166,8 +167,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
)
|
)
|
||||||
await self.session.exec(update_to_negative)
|
await self.session.exec(update_to_negative)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
|
|
||||||
# Step 2: Move from temporary negative positions to final positions
|
# Step 2: Move from temporary negative positions to final positions
|
||||||
update_to_final = (
|
update_to_final = (
|
||||||
update(PlaylistSound)
|
update(PlaylistSound)
|
||||||
.where(
|
.where(
|
||||||
@@ -337,15 +338,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
||||||
.group_by(Playlist.id, User.name)
|
.group_by(Playlist.id, User.name)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if search_query and search_query.strip():
|
if search_query and search_query.strip():
|
||||||
search_pattern = f"%{search_query.strip().lower()}%"
|
search_pattern = f"%{search_query.strip().lower()}%"
|
||||||
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
|
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
|
||||||
|
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
subquery = subquery.where(Playlist.user_id == user_id)
|
subquery = subquery.where(Playlist.user_id == user_id)
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
if sort_by == PlaylistSortField.SOUND_COUNT:
|
if sort_by == PlaylistSortField.SOUND_COUNT:
|
||||||
if sort_order == SortOrder.DESC:
|
if sort_order == SortOrder.DESC:
|
||||||
@@ -360,7 +361,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
else:
|
else:
|
||||||
# Default sorting by name
|
# Default sorting by name
|
||||||
subquery = subquery.order_by(Playlist.name.asc())
|
subquery = subquery.order_by(Playlist.name.asc())
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Simple query without stats-based sorting
|
# Simple query without stats-based sorting
|
||||||
subquery = (
|
subquery = (
|
||||||
@@ -385,15 +386,15 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
.join(Sound, PlaylistSound.sound_id == Sound.id, isouter=True)
|
||||||
.group_by(Playlist.id, User.name)
|
.group_by(Playlist.id, User.name)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
if search_query and search_query.strip():
|
if search_query and search_query.strip():
|
||||||
search_pattern = f"%{search_query.strip().lower()}%"
|
search_pattern = f"%{search_query.strip().lower()}%"
|
||||||
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
|
subquery = subquery.where(func.lower(Playlist.name).like(search_pattern))
|
||||||
|
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
subquery = subquery.where(Playlist.user_id == user_id)
|
subquery = subquery.where(Playlist.user_id == user_id)
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
if sort_by:
|
if sort_by:
|
||||||
if sort_by == PlaylistSortField.NAME:
|
if sort_by == PlaylistSortField.NAME:
|
||||||
@@ -406,7 +407,7 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
sort_column = Playlist.updated_at
|
sort_column = Playlist.updated_at
|
||||||
else:
|
else:
|
||||||
sort_column = Playlist.name
|
sort_column = Playlist.name
|
||||||
|
|
||||||
if sort_order == SortOrder.DESC:
|
if sort_order == SortOrder.DESC:
|
||||||
subquery = subquery.order_by(sort_column.desc())
|
subquery = subquery.order_by(sort_column.desc())
|
||||||
else:
|
else:
|
||||||
@@ -414,16 +415,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
else:
|
else:
|
||||||
# Default sorting by name ascending
|
# Default sorting by name ascending
|
||||||
subquery = subquery.order_by(Playlist.name.asc())
|
subquery = subquery.order_by(Playlist.name.asc())
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
if offset > 0:
|
if offset > 0:
|
||||||
subquery = subquery.offset(offset)
|
subquery = subquery.offset(offset)
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
subquery = subquery.limit(limit)
|
subquery = subquery.limit(limit)
|
||||||
|
|
||||||
result = await self.session.exec(subquery)
|
result = await self.session.exec(subquery)
|
||||||
rows = result.all()
|
rows = result.all()
|
||||||
|
|
||||||
# Convert to dictionary format
|
# Convert to dictionary format
|
||||||
playlists = []
|
playlists = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
@@ -442,11 +443,11 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
|||||||
"sound_count": row.sound_count or 0,
|
"sound_count": row.sound_count or 0,
|
||||||
"total_duration": row.total_duration or 0,
|
"total_duration": row.total_duration or 0,
|
||||||
})
|
})
|
||||||
|
|
||||||
return playlists
|
return playlists
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
|
"Failed to search and sort playlists: query=%s, sort_by=%s, sort_order=%s",
|
||||||
search_query, sort_by, sort_order
|
search_query, sort_by, sort_order,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
class SoundSortField(str, Enum):
|
class SoundSortField(str, Enum):
|
||||||
"""Sound sort field enumeration."""
|
"""Sound sort field enumeration."""
|
||||||
|
|
||||||
NAME = "name"
|
NAME = "name"
|
||||||
FILENAME = "filename"
|
FILENAME = "filename"
|
||||||
DURATION = "duration"
|
DURATION = "duration"
|
||||||
@@ -30,7 +30,7 @@ class SoundSortField(str, Enum):
|
|||||||
|
|
||||||
class SortOrder(str, Enum):
|
class SortOrder(str, Enum):
|
||||||
"""Sort order enumeration."""
|
"""Sort order enumeration."""
|
||||||
|
|
||||||
ASC = "asc"
|
ASC = "asc"
|
||||||
DESC = "desc"
|
DESC = "desc"
|
||||||
|
|
||||||
@@ -144,18 +144,18 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
"""Search and sort sounds with optional filtering."""
|
"""Search and sort sounds with optional filtering."""
|
||||||
try:
|
try:
|
||||||
statement = select(Sound)
|
statement = select(Sound)
|
||||||
|
|
||||||
# Apply type filter
|
# Apply type filter
|
||||||
if sound_types:
|
if sound_types:
|
||||||
statement = statement.where(col(Sound.type).in_(sound_types))
|
statement = statement.where(col(Sound.type).in_(sound_types))
|
||||||
|
|
||||||
# Apply search filter
|
# Apply search filter
|
||||||
if search_query and search_query.strip():
|
if search_query and search_query.strip():
|
||||||
search_pattern = f"%{search_query.strip().lower()}%"
|
search_pattern = f"%{search_query.strip().lower()}%"
|
||||||
statement = statement.where(
|
statement = statement.where(
|
||||||
func.lower(Sound.name).like(search_pattern)
|
func.lower(Sound.name).like(search_pattern),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply sorting
|
# Apply sorting
|
||||||
if sort_by:
|
if sort_by:
|
||||||
sort_column = getattr(Sound, sort_by.value)
|
sort_column = getattr(Sound, sort_by.value)
|
||||||
@@ -166,19 +166,19 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
else:
|
else:
|
||||||
# Default sorting by name ascending
|
# Default sorting by name ascending
|
||||||
statement = statement.order_by(Sound.name.asc())
|
statement = statement.order_by(Sound.name.asc())
|
||||||
|
|
||||||
# Apply pagination
|
# Apply pagination
|
||||||
if offset > 0:
|
if offset > 0:
|
||||||
statement = statement.offset(offset)
|
statement = statement.offset(offset)
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
statement = statement.limit(limit)
|
statement = statement.limit(limit)
|
||||||
|
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
|
"Failed to search and sort sounds: query=%s, types=%s, sort_by=%s, sort_order=%s",
|
||||||
search_query, sound_types, sort_by, sort_order
|
search_query, sound_types, sort_by, sort_order,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -189,17 +189,17 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
func.count(Sound.id).label("count"),
|
func.count(Sound.id).label("count"),
|
||||||
func.sum(Sound.play_count).label("total_plays"),
|
func.sum(Sound.play_count).label("total_plays"),
|
||||||
func.sum(Sound.duration).label("total_duration"),
|
func.sum(Sound.duration).label("total_duration"),
|
||||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||||
).where(Sound.type == "SDB")
|
).where(Sound.type == "SDB")
|
||||||
|
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
row = result.first()
|
row = result.first()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"count": row.count if row.count is not None else 0,
|
"count": row.count if row.count is not None else 0,
|
||||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||||
"total_size": row.total_size if row.total_size is not None else 0
|
"total_size": row.total_size if row.total_size is not None else 0,
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to get soundboard statistics")
|
logger.exception("Failed to get soundboard statistics")
|
||||||
@@ -212,17 +212,17 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
func.count(Sound.id).label("count"),
|
func.count(Sound.id).label("count"),
|
||||||
func.sum(Sound.play_count).label("total_plays"),
|
func.sum(Sound.play_count).label("total_plays"),
|
||||||
func.sum(Sound.duration).label("total_duration"),
|
func.sum(Sound.duration).label("total_duration"),
|
||||||
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size")
|
func.sum(Sound.size + func.coalesce(Sound.normalized_size, 0)).label("total_size"),
|
||||||
).where(Sound.type == "EXT")
|
).where(Sound.type == "EXT")
|
||||||
|
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
row = result.first()
|
row = result.first()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"count": row.count if row.count is not None else 0,
|
"count": row.count if row.count is not None else 0,
|
||||||
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
"total_plays": row.total_plays if row.total_plays is not None else 0,
|
||||||
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
"total_duration": row.total_duration if row.total_duration is not None else 0,
|
||||||
"total_size": row.total_size if row.total_size is not None else 0
|
"total_size": row.total_size if row.total_size is not None else 0,
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to get track statistics")
|
logger.exception("Failed to get track statistics")
|
||||||
@@ -244,20 +244,20 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
Sound.type,
|
Sound.type,
|
||||||
Sound.duration,
|
Sound.duration,
|
||||||
Sound.created_at,
|
Sound.created_at,
|
||||||
func.count(SoundPlayed.id).label("play_count")
|
func.count(SoundPlayed.id).label("play_count"),
|
||||||
)
|
)
|
||||||
.select_from(SoundPlayed)
|
.select_from(SoundPlayed)
|
||||||
.join(Sound, SoundPlayed.sound_id == Sound.id)
|
.join(Sound, SoundPlayed.sound_id == Sound.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply sound type filter
|
# Apply sound type filter
|
||||||
if sound_type != "all":
|
if sound_type != "all":
|
||||||
statement = statement.where(Sound.type == sound_type.upper())
|
statement = statement.where(Sound.type == sound_type.upper())
|
||||||
|
|
||||||
# Apply date filter if provided
|
# Apply date filter if provided
|
||||||
if date_filter:
|
if date_filter:
|
||||||
statement = statement.where(SoundPlayed.created_at >= date_filter)
|
statement = statement.where(SoundPlayed.created_at >= date_filter)
|
||||||
|
|
||||||
# Group by sound and order by play count descending
|
# Group by sound and order by play count descending
|
||||||
statement = (
|
statement = (
|
||||||
statement
|
statement
|
||||||
@@ -266,15 +266,15 @@ class SoundRepository(BaseRepository[Sound]):
|
|||||||
Sound.name,
|
Sound.name,
|
||||||
Sound.type,
|
Sound.type,
|
||||||
Sound.duration,
|
Sound.duration,
|
||||||
Sound.created_at
|
Sound.created_at,
|
||||||
)
|
)
|
||||||
.order_by(func.count(SoundPlayed.id).desc())
|
.order_by(func.count(SoundPlayed.id).desc())
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
rows = result.all()
|
rows = result.all()
|
||||||
|
|
||||||
# Convert to dictionaries with the play count from the period
|
# Convert to dictionaries with the play count from the period
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
|
|||||||
@@ -96,5 +96,5 @@ class UpdateProfileRequest(BaseModel):
|
|||||||
"""Schema for profile update request."""
|
"""Schema for profile update request."""
|
||||||
|
|
||||||
name: str | None = Field(
|
name: str | None = Field(
|
||||||
None, min_length=1, max_length=100, description="User display name"
|
None, min_length=1, max_length=100, description="User display name",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -434,36 +434,36 @@ class AuthService:
|
|||||||
async def update_user_profile(self, user: User, data: dict) -> User:
|
async def update_user_profile(self, user: User, data: dict) -> User:
|
||||||
"""Update user profile information."""
|
"""Update user profile information."""
|
||||||
logger.info("Updating profile for user: %s", user.email)
|
logger.info("Updating profile for user: %s", user.email)
|
||||||
|
|
||||||
# Only allow updating specific fields
|
# Only allow updating specific fields
|
||||||
allowed_fields = {"name"}
|
allowed_fields = {"name"}
|
||||||
update_data = {k: v for k, v in data.items() if k in allowed_fields}
|
update_data = {k: v for k, v in data.items() if k in allowed_fields}
|
||||||
|
|
||||||
if not update_data:
|
if not update_data:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
# Update user
|
# Update user
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(user, field, value)
|
setattr(user, field, value)
|
||||||
|
|
||||||
self.session.add(user)
|
self.session.add(user)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(user, ["plan"])
|
await self.session.refresh(user, ["plan"])
|
||||||
|
|
||||||
logger.info("Profile updated successfully for user: %s", user.email)
|
logger.info("Profile updated successfully for user: %s", user.email)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def change_user_password(
|
async def change_user_password(
|
||||||
self, user: User, current_password: str | None, new_password: str
|
self, user: User, current_password: str | None, new_password: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Change user's password."""
|
"""Change user's password."""
|
||||||
# Store user email before any operations to avoid session detachment issues
|
# Store user email before any operations to avoid session detachment issues
|
||||||
user_email = user.email
|
user_email = user.email
|
||||||
logger.info("Changing password for user: %s", user_email)
|
logger.info("Changing password for user: %s", user_email)
|
||||||
|
|
||||||
# Store whether user had existing password before we modify it
|
# Store whether user had existing password before we modify it
|
||||||
had_existing_password = user.password_hash is not None
|
had_existing_password = user.password_hash is not None
|
||||||
|
|
||||||
# If user has existing password, verify it
|
# If user has existing password, verify it
|
||||||
if had_existing_password:
|
if had_existing_password:
|
||||||
if not current_password:
|
if not current_password:
|
||||||
@@ -473,24 +473,24 @@ class AuthService:
|
|||||||
else:
|
else:
|
||||||
# User doesn't have a password (OAuth-only user), so we're setting their first password
|
# User doesn't have a password (OAuth-only user), so we're setting their first password
|
||||||
logger.info("Setting first password for OAuth user: %s", user_email)
|
logger.info("Setting first password for OAuth user: %s", user_email)
|
||||||
|
|
||||||
# Hash new password
|
# Hash new password
|
||||||
new_password_hash = PasswordUtils.hash_password(new_password)
|
new_password_hash = PasswordUtils.hash_password(new_password)
|
||||||
|
|
||||||
# Update user
|
# Update user
|
||||||
user.password_hash = new_password_hash
|
user.password_hash = new_password_hash
|
||||||
self.session.add(user)
|
self.session.add(user)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
|
|
||||||
logger.info("Password %s successfully for user: %s",
|
logger.info("Password %s successfully for user: %s",
|
||||||
"changed" if had_existing_password else "set", user_email)
|
"changed" if had_existing_password else "set", user_email)
|
||||||
|
|
||||||
async def user_to_response(self, user: User) -> UserResponse:
|
async def user_to_response(self, user: User) -> UserResponse:
|
||||||
"""Convert User model to UserResponse with plan information."""
|
"""Convert User model to UserResponse with plan information."""
|
||||||
# Load plan relationship if not already loaded
|
# Load plan relationship if not already loaded
|
||||||
if not hasattr(user, 'plan') or not user.plan:
|
if not hasattr(user, "plan") or not user.plan:
|
||||||
await self.session.refresh(user, ["plan"])
|
await self.session.refresh(user, ["plan"])
|
||||||
|
|
||||||
return UserResponse(
|
return UserResponse(
|
||||||
id=user.id,
|
id=user.id,
|
||||||
email=user.email,
|
email=user.email,
|
||||||
|
|||||||
@@ -56,14 +56,14 @@ class DashboardService:
|
|||||||
try:
|
try:
|
||||||
# Calculate the date filter based on period
|
# Calculate the date filter based on period
|
||||||
date_filter = self._get_date_filter(period)
|
date_filter = self._get_date_filter(period)
|
||||||
|
|
||||||
# Get top sounds from repository
|
# Get top sounds from repository
|
||||||
top_sounds = await self.sound_repository.get_top_sounds(
|
top_sounds = await self.sound_repository.get_top_sounds(
|
||||||
sound_type=sound_type,
|
sound_type=sound_type,
|
||||||
date_filter=date_filter,
|
date_filter=date_filter,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"id": sound["id"],
|
"id": sound["id"],
|
||||||
@@ -86,7 +86,7 @@ class DashboardService:
|
|||||||
period,
|
period,
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _get_date_filter(self, period: str) -> datetime | None:
|
def _get_date_filter(self, period: str) -> datetime | None:
|
||||||
"""Calculate the date filter based on the period."""
|
"""Calculate the date filter based on the period."""
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
|
|||||||
@@ -272,9 +272,8 @@ class PlaylistService:
|
|||||||
# Ensure position doesn't create gaps - if position is too high, place at end
|
# Ensure position doesn't create gaps - if position is too high, place at end
|
||||||
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
current_sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||||
max_position = len(current_sounds)
|
max_position = len(current_sounds)
|
||||||
if position > max_position:
|
position = min(position, max_position)
|
||||||
position = max_position
|
|
||||||
|
|
||||||
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
|
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Added sound %s to playlist %s for user %s at position %s",
|
"Added sound %s to playlist %s for user %s at position %s",
|
||||||
@@ -306,10 +305,10 @@ class PlaylistService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
|
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
|
||||||
|
|
||||||
# Reorder remaining sounds to eliminate gaps
|
# Reorder remaining sounds to eliminate gaps
|
||||||
await self._reorder_playlist_positions(playlist_id)
|
await self._reorder_playlist_positions(playlist_id)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removed sound %s from playlist %s for user %s and reordered positions",
|
"Removed sound %s from playlist %s for user %s and reordered positions",
|
||||||
sound_id,
|
sound_id,
|
||||||
@@ -326,7 +325,7 @@ class PlaylistService:
|
|||||||
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
|
||||||
if not sounds:
|
if not sounds:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create sequential positions: 0, 1, 2, 3...
|
# Create sequential positions: 0, 1, 2, 3...
|
||||||
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
|
sound_positions = [(sound.id, index) for index, sound in enumerate(sounds)]
|
||||||
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
|
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for admin user endpoints."""
|
"""Tests for admin user endpoints."""
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
@@ -76,9 +76,9 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
mock_regular = type("User", (), {
|
mock_regular = type("User", (), {
|
||||||
"id": regular_user.id,
|
"id": regular_user.id,
|
||||||
"email": regular_user.email,
|
"email": regular_user.email,
|
||||||
@@ -93,9 +93,9 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
mock_get_all.return_value = [mock_admin, mock_regular]
|
mock_get_all.return_value = [mock_admin, mock_regular]
|
||||||
|
|
||||||
response = await authenticated_admin_client.get("/api/v1/admin/users/")
|
response = await authenticated_admin_client.get("/api/v1/admin/users/")
|
||||||
@@ -130,7 +130,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_all.return_value = [mock_admin]
|
mock_get_all.return_value = [mock_admin]
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
|
|
||||||
@@ -244,9 +244,9 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
updated_mock = type("User", (), {
|
updated_mock = type("User", (), {
|
||||||
"id": regular_user.id,
|
"id": regular_user.id,
|
||||||
"email": regular_user.email,
|
"email": regular_user.email,
|
||||||
@@ -261,9 +261,9 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
|
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
mock_update.return_value = updated_mock
|
mock_update.return_value = updated_mock
|
||||||
|
|
||||||
@@ -278,7 +278,7 @@ class TestAdminUserEndpoints:
|
|||||||
"name": "Updated Name",
|
"name": "Updated Name",
|
||||||
"credits": 200,
|
"credits": 200,
|
||||||
"plan_id": 1,
|
"plan_id": 1,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -299,7 +299,7 @@ class TestAdminUserEndpoints:
|
|||||||
):
|
):
|
||||||
response = await authenticated_admin_client.patch(
|
response = await authenticated_admin_client.patch(
|
||||||
"/api/v1/admin/users/999",
|
"/api/v1/admin/users/999",
|
||||||
json={"name": "Updated Name"}
|
json={"name": "Updated Name"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
@@ -333,12 +333,12 @@ class TestAdminUserEndpoints:
|
|||||||
"id": 1,
|
"id": 1,
|
||||||
"name": "Basic",
|
"name": "Basic",
|
||||||
"max_credits": 100,
|
"max_credits": 100,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
response = await authenticated_admin_client.patch(
|
response = await authenticated_admin_client.patch(
|
||||||
"/api/v1/admin/users/2",
|
"/api/v1/admin/users/2",
|
||||||
json={"plan_id": 999}
|
json={"plan_id": 999},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
@@ -373,7 +373,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_user
|
mock_get_by_id.return_value = mock_user
|
||||||
mock_update.return_value = mock_user
|
mock_update.return_value = mock_user
|
||||||
@@ -438,7 +438,7 @@ class TestAdminUserEndpoints:
|
|||||||
"id": test_plan.id,
|
"id": test_plan.id,
|
||||||
"name": test_plan.name,
|
"name": test_plan.name,
|
||||||
"max_credits": test_plan.max_credits,
|
"max_credits": test_plan.max_credits,
|
||||||
})()
|
})(),
|
||||||
})()
|
})()
|
||||||
mock_get_by_id.return_value = mock_disabled_user
|
mock_get_by_id.return_value = mock_disabled_user
|
||||||
mock_update.return_value = mock_disabled_user
|
mock_update.return_value = mock_disabled_user
|
||||||
@@ -487,4 +487,4 @@ class TestAdminUserEndpoints:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data) == 2
|
assert len(data) == 2
|
||||||
assert data[0]["name"] == "Basic"
|
assert data[0]["name"] == "Basic"
|
||||||
assert data[1]["name"] == "Premium"
|
assert data[1]["name"] == "Premium"
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Tests for authentication endpoints."""
|
"""Tests for authentication endpoints."""
|
||||||
|
|
||||||
|
from datetime import UTC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@@ -495,7 +496,7 @@ class TestAuthEndpoints:
|
|||||||
|
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/refresh",
|
"/api/v1/auth/refresh",
|
||||||
cookies={"refresh_token": "valid_refresh_token"}
|
cookies={"refresh_token": "valid_refresh_token"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -520,7 +521,7 @@ class TestAuthEndpoints:
|
|||||||
|
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/refresh",
|
"/api/v1/auth/refresh",
|
||||||
cookies={"refresh_token": "valid_refresh_token"}
|
cookies={"refresh_token": "valid_refresh_token"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
@@ -536,7 +537,7 @@ class TestAuthEndpoints:
|
|||||||
"""Test OAuth token exchange with invalid code."""
|
"""Test OAuth token exchange with invalid code."""
|
||||||
response = await test_client.post(
|
response = await test_client.post(
|
||||||
"/api/v1/auth/exchange-oauth-token",
|
"/api/v1/auth/exchange-oauth-token",
|
||||||
json={"code": "invalid_code"}
|
json={"code": "invalid_code"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
@@ -565,7 +566,7 @@ class TestAuthEndpoints:
|
|||||||
is_active=test_user.is_active,
|
is_active=test_user.is_active,
|
||||||
)
|
)
|
||||||
mock_update.return_value = updated_user
|
mock_update.return_value = updated_user
|
||||||
|
|
||||||
# Mock the user_to_response to return UserResponse format
|
# Mock the user_to_response to return UserResponse format
|
||||||
from app.schemas.auth import UserResponse
|
from app.schemas.auth import UserResponse
|
||||||
mock_user_to_response.return_value = UserResponse(
|
mock_user_to_response.return_value = UserResponse(
|
||||||
@@ -589,7 +590,7 @@ class TestAuthEndpoints:
|
|||||||
response = await test_client.patch(
|
response = await test_client.patch(
|
||||||
"/api/v1/auth/me",
|
"/api/v1/auth/me",
|
||||||
json={"name": "Updated Name"},
|
json={"name": "Updated Name"},
|
||||||
cookies=auth_cookies
|
cookies=auth_cookies,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -601,7 +602,7 @@ class TestAuthEndpoints:
|
|||||||
"""Test update profile without authentication."""
|
"""Test update profile without authentication."""
|
||||||
response = await test_client.patch(
|
response = await test_client.patch(
|
||||||
"/api/v1/auth/me",
|
"/api/v1/auth/me",
|
||||||
json={"name": "Updated Name"}
|
json={"name": "Updated Name"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@@ -621,9 +622,9 @@ class TestAuthEndpoints:
|
|||||||
"/api/v1/auth/change-password",
|
"/api/v1/auth/change-password",
|
||||||
json={
|
json={
|
||||||
"current_password": "old_password",
|
"current_password": "old_password",
|
||||||
"new_password": "new_password"
|
"new_password": "new_password",
|
||||||
},
|
},
|
||||||
cookies=auth_cookies
|
cookies=auth_cookies,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -637,8 +638,8 @@ class TestAuthEndpoints:
|
|||||||
"/api/v1/auth/change-password",
|
"/api/v1/auth/change-password",
|
||||||
json={
|
json={
|
||||||
"current_password": "old_password",
|
"current_password": "old_password",
|
||||||
"new_password": "new_password"
|
"new_password": "new_password",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
@@ -652,9 +653,10 @@ class TestAuthEndpoints:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test get user OAuth providers success."""
|
"""Test get user OAuth providers success."""
|
||||||
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
|
with patch("app.services.auth.AuthService.get_user_oauth_providers") as mock_providers:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from app.models.user_oauth import UserOauth
|
from app.models.user_oauth import UserOauth
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
mock_oauth_google = UserOauth(
|
mock_oauth_google = UserOauth(
|
||||||
id=1,
|
id=1,
|
||||||
user_id=test_user.id,
|
user_id=test_user.id,
|
||||||
@@ -662,34 +664,34 @@ class TestAuthEndpoints:
|
|||||||
provider_user_id="google123",
|
provider_user_id="google123",
|
||||||
email="test@example.com",
|
email="test@example.com",
|
||||||
name="Test User",
|
name="Test User",
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=datetime.now(UTC),
|
||||||
updated_at=datetime.now(timezone.utc),
|
updated_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
mock_oauth_github = UserOauth(
|
mock_oauth_github = UserOauth(
|
||||||
id=2,
|
id=2,
|
||||||
user_id=test_user.id,
|
user_id=test_user.id,
|
||||||
provider="github",
|
provider="github",
|
||||||
provider_user_id="github456",
|
provider_user_id="github456",
|
||||||
email="test@example.com",
|
email="test@example.com",
|
||||||
name="Test User",
|
name="Test User",
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=datetime.now(UTC),
|
||||||
updated_at=datetime.now(timezone.utc),
|
updated_at=datetime.now(UTC),
|
||||||
)
|
)
|
||||||
mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
|
mock_providers.return_value = [mock_oauth_google, mock_oauth_github]
|
||||||
|
|
||||||
response = await test_client.get(
|
response = await test_client.get(
|
||||||
"/api/v1/auth/user-providers",
|
"/api/v1/auth/user-providers",
|
||||||
cookies=auth_cookies
|
cookies=auth_cookies,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert len(data) == 3 # password + 2 OAuth providers
|
assert len(data) == 3 # password + 2 OAuth providers
|
||||||
|
|
||||||
# Check password provider (first)
|
# Check password provider (first)
|
||||||
assert data[0]["provider"] == "password"
|
assert data[0]["provider"] == "password"
|
||||||
assert data[0]["display_name"] == "Password"
|
assert data[0]["display_name"] == "Password"
|
||||||
|
|
||||||
# Check OAuth providers
|
# Check OAuth providers
|
||||||
assert data[1]["provider"] == "google"
|
assert data[1]["provider"] == "google"
|
||||||
assert data[1]["display_name"] == "Google"
|
assert data[1]["display_name"] == "Google"
|
||||||
|
|||||||
@@ -529,7 +529,7 @@ class TestPlaylistRepository:
|
|||||||
)
|
)
|
||||||
test_session.add(sound)
|
test_session.add(sound)
|
||||||
sounds.append(sound)
|
sounds.append(sound)
|
||||||
|
|
||||||
await test_session.commit()
|
await test_session.commit()
|
||||||
await test_session.refresh(playlist)
|
await test_session.refresh(playlist)
|
||||||
for sound in sounds:
|
for sound in sounds:
|
||||||
@@ -547,7 +547,7 @@ class TestPlaylistRepository:
|
|||||||
|
|
||||||
# Verify the final positions
|
# Verify the final positions
|
||||||
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
||||||
|
|
||||||
assert len(playlist_sounds) == 3
|
assert len(playlist_sounds) == 3
|
||||||
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
|
assert playlist_sounds[0].sound_id == sound_ids[0] # Original sound 0 stays at position 0
|
||||||
assert playlist_sounds[0].position == 0
|
assert playlist_sounds[0].position == 0
|
||||||
@@ -605,7 +605,7 @@ class TestPlaylistRepository:
|
|||||||
)
|
)
|
||||||
test_session.add(sound)
|
test_session.add(sound)
|
||||||
sounds.append(sound)
|
sounds.append(sound)
|
||||||
|
|
||||||
await test_session.commit()
|
await test_session.commit()
|
||||||
await test_session.refresh(playlist)
|
await test_session.refresh(playlist)
|
||||||
for sound in sounds:
|
for sound in sounds:
|
||||||
@@ -623,7 +623,7 @@ class TestPlaylistRepository:
|
|||||||
|
|
||||||
# Verify the final positions
|
# Verify the final positions
|
||||||
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
playlist_sounds = await playlist_repository.get_playlist_sound_entries(playlist_id)
|
||||||
|
|
||||||
assert len(playlist_sounds) == 3
|
assert len(playlist_sounds) == 3
|
||||||
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
|
assert playlist_sounds[0].sound_id == sound_ids[2] # New sound 2 inserted at position 0
|
||||||
assert playlist_sounds[0].position == 0
|
assert playlist_sounds[0].position == 0
|
||||||
|
|||||||
@@ -409,7 +409,7 @@ class TestCreditService:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_recharge_user_credits_success(
|
async def test_recharge_user_credits_success(
|
||||||
self, credit_service, sample_user
|
self, credit_service, sample_user,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful credit recharge for a user."""
|
"""Test successful credit recharge for a user."""
|
||||||
mock_session = credit_service.db_session_factory()
|
mock_session = credit_service.db_session_factory()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for dashboard service."""
|
"""Tests for dashboard service."""
|
||||||
|
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -63,7 +63,7 @@ class TestDashboardService:
|
|||||||
):
|
):
|
||||||
"""Test getting soundboard statistics with exception."""
|
"""Test getting soundboard statistics with exception."""
|
||||||
mock_sound_repository.get_soundboard_statistics = AsyncMock(
|
mock_sound_repository.get_soundboard_statistics = AsyncMock(
|
||||||
side_effect=Exception("Database error")
|
side_effect=Exception("Database error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Database error"):
|
with pytest.raises(Exception, match="Database error"):
|
||||||
@@ -105,7 +105,7 @@ class TestDashboardService:
|
|||||||
):
|
):
|
||||||
"""Test getting track statistics with exception."""
|
"""Test getting track statistics with exception."""
|
||||||
mock_sound_repository.get_track_statistics = AsyncMock(
|
mock_sound_repository.get_track_statistics = AsyncMock(
|
||||||
side_effect=Exception("Database error")
|
side_effect=Exception("Database error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Database error"):
|
with pytest.raises(Exception, match="Database error"):
|
||||||
@@ -198,7 +198,7 @@ class TestDashboardService:
|
|||||||
):
|
):
|
||||||
"""Test getting top sounds with exception."""
|
"""Test getting top sounds with exception."""
|
||||||
mock_sound_repository.get_top_sounds = AsyncMock(
|
mock_sound_repository.get_top_sounds = AsyncMock(
|
||||||
side_effect=Exception("Database error")
|
side_effect=Exception("Database error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(Exception, match="Database error"):
|
with pytest.raises(Exception, match="Database error"):
|
||||||
@@ -274,4 +274,4 @@ class TestDashboardService:
|
|||||||
def test_get_date_filter_unknown_period(self, dashboard_service):
|
def test_get_date_filter_unknown_period(self, dashboard_service):
|
||||||
"""Test date filter for unknown period."""
|
"""Test date filter for unknown period."""
|
||||||
result = dashboard_service._get_date_filter("unknown_period")
|
result = dashboard_service._get_date_filter("unknown_period")
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|||||||
@@ -25,9 +25,9 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_start_scheduler(self, scheduler_service) -> None:
|
async def test_start_scheduler(self, scheduler_service) -> None:
|
||||||
"""Test starting the scheduler service."""
|
"""Test starting the scheduler service."""
|
||||||
with patch.object(scheduler_service.scheduler, 'add_job') as mock_add_job, \
|
with patch.object(scheduler_service.scheduler, "add_job") as mock_add_job, \
|
||||||
patch.object(scheduler_service.scheduler, 'start') as mock_start:
|
patch.object(scheduler_service.scheduler, "start") as mock_start:
|
||||||
|
|
||||||
await scheduler_service.start()
|
await scheduler_service.start()
|
||||||
|
|
||||||
# Verify job was added
|
# Verify job was added
|
||||||
@@ -47,7 +47,7 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stop_scheduler(self, scheduler_service) -> None:
|
async def test_stop_scheduler(self, scheduler_service) -> None:
|
||||||
"""Test stopping the scheduler service."""
|
"""Test stopping the scheduler service."""
|
||||||
with patch.object(scheduler_service.scheduler, 'shutdown') as mock_shutdown:
|
with patch.object(scheduler_service.scheduler, "shutdown") as mock_shutdown:
|
||||||
await scheduler_service.stop()
|
await scheduler_service.stop()
|
||||||
mock_shutdown.assert_called_once()
|
mock_shutdown.assert_called_once()
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ class TestSchedulerService:
|
|||||||
"total_credits_added": 500,
|
"total_credits_added": 500,
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||||
mock_recharge.return_value = mock_stats
|
mock_recharge.return_value = mock_stats
|
||||||
|
|
||||||
await scheduler_service._daily_credit_recharge()
|
await scheduler_service._daily_credit_recharge()
|
||||||
@@ -71,10 +71,10 @@ class TestSchedulerService:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
|
||||||
"""Test daily credit recharge task with failure."""
|
"""Test daily credit recharge task with failure."""
|
||||||
with patch.object(scheduler_service.credit_service, 'recharge_all_users_credits') as mock_recharge:
|
with patch.object(scheduler_service.credit_service, "recharge_all_users_credits") as mock_recharge:
|
||||||
mock_recharge.side_effect = Exception("Database error")
|
mock_recharge.side_effect = Exception("Database error")
|
||||||
|
|
||||||
# Should not raise exception, just log it
|
# Should not raise exception, just log it
|
||||||
await scheduler_service._daily_credit_recharge()
|
await scheduler_service._daily_credit_recharge()
|
||||||
|
|
||||||
mock_recharge.assert_called_once()
|
mock_recharge.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user