Add comprehensive tests for playlist service and refactor socket service tests

- Introduced a new test suite for the PlaylistService covering various functionalities including creation, retrieval, updating, and deletion of playlists.
- Added tests for handling sounds within playlists, ensuring correct behavior when adding/removing sounds and managing current playlists.
- Refactored socket service tests for improved readability by adjusting function signatures.
- Cleaned up unnecessary whitespace in sound normalizer and sound scanner tests for consistency.
- Enhanced audio utility tests to ensure accurate hash and size calculations, including edge cases for nonexistent files.
- Removed redundant blank lines in cookie utility tests for cleaner code.
This commit is contained in:
JSC
2025-07-29 19:25:46 +02:00
parent 301b5dd794
commit 5ed19c8f0f
31 changed files with 4248 additions and 194 deletions

View File

@@ -2,13 +2,14 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1 import auth, main, socket, sounds from app.api.v1 import auth, main, playlists, socket, sounds
# V1 API router with v1 prefix # V1 API router with v1 prefix
api_router = APIRouter(prefix="/v1") api_router = APIRouter(prefix="/v1")
# Include all route modules # Include all route modules
api_router.include_router(auth.router, tags=["authentication"])
api_router.include_router(main.router, tags=["main"]) api_router.include_router(main.router, tags=["main"])
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) api_router.include_router(playlists.router, tags=["playlists"])
api_router.include_router(socket.router, tags=["socket"]) api_router.include_router(socket.router, tags=["socket"])
api_router.include_router(sounds.router, tags=["sounds"]) api_router.include_router(sounds.router, tags=["sounds"])

View File

@@ -28,7 +28,7 @@ from app.services.auth import AuthService
from app.services.oauth import OAuthService from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils, TokenUtils from app.utils.auth import JWTUtils, TokenUtils
router = APIRouter() router = APIRouter(prefix="/auth", tags=["authentication"])
logger = get_logger(__name__) logger = get_logger(__name__)
# Global temporary storage for OAuth codes (in production, use Redis with TTL) # Global temporary storage for OAuth codes (in production, use Redis with TTL)
@@ -459,7 +459,8 @@ async def generate_api_token(
) )
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to generate API token for user: %s", current_user.email, "Failed to generate API token for user: %s",
current_user.email,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -495,7 +496,8 @@ async def revoke_api_token(
await auth_service.revoke_api_token(current_user) await auth_service.revoke_api_token(current_user)
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to revoke API token for user: %s", current_user.email, "Failed to revoke API token for user: %s",
current_user.email,
) )
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

328
app/api/v1/playlists.py Normal file
View File

@@ -0,0 +1,328 @@
"""Playlist management API endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.playlist import Playlist
from app.models.sound import Sound
from app.models.user import User
from app.services.playlist import PlaylistService
router = APIRouter(prefix="/playlists", tags=["playlists"])
class PlaylistCreateRequest(BaseModel):
"""Request model for creating a playlist."""
name: str
description: str | None = None
genre: str | None = None
class PlaylistUpdateRequest(BaseModel):
"""Request model for updating a playlist."""
name: str | None = None
description: str | None = None
genre: str | None = None
is_current: bool | None = None
class PlaylistResponse(BaseModel):
"""Response model for playlist data."""
id: int
name: str
description: str | None
genre: str | None
is_main: bool
is_current: bool
is_deletable: bool
created_at: str
updated_at: str | None
@classmethod
def from_playlist(cls, playlist: Playlist) -> "PlaylistResponse":
"""Create response from playlist model."""
return cls(
id=playlist.id,
name=playlist.name,
description=playlist.description,
genre=playlist.genre,
is_main=playlist.is_main,
is_current=playlist.is_current,
is_deletable=playlist.is_deletable,
created_at=playlist.created_at.isoformat(),
updated_at=playlist.updated_at.isoformat() if playlist.updated_at else None,
)
class SoundResponse(BaseModel):
"""Response model for sound data in playlists."""
id: int
name: str
filename: str
type: str
duration: int | None
size: int | None
play_count: int
created_at: str
@classmethod
def from_sound(cls, sound: Sound) -> "SoundResponse":
"""Create response from sound model."""
return cls(
id=sound.id,
name=sound.name,
filename=sound.filename,
type=sound.type,
duration=sound.duration,
size=sound.size,
play_count=sound.play_count,
created_at=sound.created_at.isoformat(),
)
class AddSoundRequest(BaseModel):
"""Request model for adding a sound to a playlist."""
sound_id: int
position: int | None = None
class ReorderRequest(BaseModel):
"""Request model for reordering sounds in a playlist."""
sound_positions: list[tuple[int, int]]
class PlaylistStatsResponse(BaseModel):
"""Response model for playlist statistics."""
sound_count: int
total_duration_ms: int
total_play_count: int
async def get_playlist_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> PlaylistService:
"""Get the playlist service."""
return PlaylistService(session)
@router.get("/")
async def get_all_playlists(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[PlaylistResponse]:
"""Get all playlists from all users."""
playlists = await playlist_service.get_all_playlists()
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
@router.get("/user")
async def get_user_playlists(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[PlaylistResponse]:
"""Get playlists for the current user only."""
playlists = await playlist_service.get_user_playlists(current_user.id)
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
@router.get("/main")
async def get_main_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Get the global main playlist."""
playlist = await playlist_service.get_main_playlist()
return PlaylistResponse.from_playlist(playlist)
@router.get("/current")
async def get_current_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Get the user's current playlist (falls back to main playlist)."""
playlist = await playlist_service.get_current_playlist(current_user.id)
return PlaylistResponse.from_playlist(playlist)
@router.post("/")
async def create_playlist(
request: PlaylistCreateRequest,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Create a new playlist."""
playlist = await playlist_service.create_playlist(
user_id=current_user.id,
name=request.name,
description=request.description,
genre=request.genre,
)
return PlaylistResponse.from_playlist(playlist)
@router.get("/{playlist_id}")
async def get_playlist(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Get a specific playlist."""
playlist = await playlist_service.get_playlist_by_id(playlist_id)
return PlaylistResponse.from_playlist(playlist)
@router.put("/{playlist_id}")
async def update_playlist(
playlist_id: int,
request: PlaylistUpdateRequest,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Update a playlist."""
playlist = await playlist_service.update_playlist(
playlist_id=playlist_id,
user_id=current_user.id,
name=request.name,
description=request.description,
genre=request.genre,
is_current=request.is_current,
)
return PlaylistResponse.from_playlist(playlist)
@router.delete("/{playlist_id}")
async def delete_playlist(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> dict[str, str]:
"""Delete a playlist."""
await playlist_service.delete_playlist(playlist_id, current_user.id)
return {"message": "Playlist deleted successfully"}
@router.get("/search/{query}")
async def search_playlists(
query: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[PlaylistResponse]:
"""Search all playlists by name."""
playlists = await playlist_service.search_all_playlists(query)
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
@router.get("/user/search/{query}")
async def search_user_playlists(
query: str,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[PlaylistResponse]:
"""Search current user's playlists by name."""
playlists = await playlist_service.search_playlists(query, current_user.id)
return [PlaylistResponse.from_playlist(playlist) for playlist in playlists]
@router.get("/{playlist_id}/sounds")
async def get_playlist_sounds(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> list[SoundResponse]:
"""Get all sounds in a playlist."""
sounds = await playlist_service.get_playlist_sounds(playlist_id)
return [SoundResponse.from_sound(sound) for sound in sounds]
@router.post("/{playlist_id}/sounds")
async def add_sound_to_playlist(
playlist_id: int,
request: AddSoundRequest,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> dict[str, str]:
"""Add a sound to a playlist."""
await playlist_service.add_sound_to_playlist(
playlist_id=playlist_id,
sound_id=request.sound_id,
user_id=current_user.id,
position=request.position,
)
return {"message": "Sound added to playlist successfully"}
@router.delete("/{playlist_id}/sounds/{sound_id}")
async def remove_sound_from_playlist(
playlist_id: int,
sound_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> dict[str, str]:
"""Remove a sound from a playlist."""
await playlist_service.remove_sound_from_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=current_user.id,
)
return {"message": "Sound removed from playlist successfully"}
@router.put("/{playlist_id}/sounds/reorder")
async def reorder_playlist_sounds(
playlist_id: int,
request: ReorderRequest,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> dict[str, str]:
"""Reorder sounds in a playlist."""
await playlist_service.reorder_playlist_sounds(
playlist_id=playlist_id,
user_id=current_user.id,
sound_positions=request.sound_positions,
)
return {"message": "Playlist sounds reordered successfully"}
@router.put("/{playlist_id}/set-current")
async def set_current_playlist(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistResponse:
"""Set a playlist as the current playlist."""
playlist = await playlist_service.set_current_playlist(playlist_id, current_user.id)
return PlaylistResponse.from_playlist(playlist)
@router.delete("/current")
async def unset_current_playlist(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> dict[str, str]:
"""Unset the current playlist."""
await playlist_service.unset_current_playlist(current_user.id)
return {"message": "Current playlist unset successfully"}
@router.get("/{playlist_id}/stats")
async def get_playlist_stats(
playlist_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
playlist_service: Annotated[PlaylistService, Depends(get_playlist_service)],
) -> PlaylistStatsResponse:
"""Get statistics for a playlist."""
stats = await playlist_service.get_playlist_stats(playlist_id)
return PlaylistStatsResponse(**stats)

View File

@@ -30,9 +30,7 @@ class Settings(BaseSettings):
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# JWT Configuration # JWT Configuration
JWT_SECRET_KEY: str = ( JWT_SECRET_KEY: str = "your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
)
JWT_ALGORITHM: str = "HS256" JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7

View File

@@ -0,0 +1,273 @@
"""Playlist repository for database operations."""
from typing import Any
from sqlalchemy import func
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.playlist import Playlist
from app.models.playlist_sound import PlaylistSound
from app.models.sound import Sound
logger = get_logger(__name__)
class PlaylistRepository:
"""Repository for playlist operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the playlist repository."""
self.session = session
async def get_by_id(self, playlist_id: int) -> Playlist | None:
"""Get a playlist by ID."""
try:
statement = select(Playlist).where(Playlist.id == playlist_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get playlist by ID: %s", playlist_id)
raise
async def get_by_name(self, name: str) -> Playlist | None:
"""Get a playlist by name."""
try:
statement = select(Playlist).where(Playlist.name == name)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get playlist by name: %s", name)
raise
async def get_by_user_id(self, user_id: int) -> list[Playlist]:
"""Get all playlists for a user."""
try:
statement = select(Playlist).where(Playlist.user_id == user_id)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get playlists for user: %s", user_id)
raise
async def get_all(self) -> list[Playlist]:
"""Get all playlists from all users."""
try:
statement = select(Playlist)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all playlists")
raise
async def get_main_playlist(self) -> Playlist | None:
"""Get the global main playlist."""
try:
statement = select(Playlist).where(
Playlist.is_main == True, # noqa: E712
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get main playlist")
raise
async def get_current_playlist(self, user_id: int) -> Playlist | None:
"""Get the user's current playlist."""
try:
statement = select(Playlist).where(
Playlist.user_id == user_id,
Playlist.is_current == True, # noqa: E712
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get current playlist for user: %s", user_id)
raise
async def create(self, playlist_data: dict[str, Any]) -> Playlist:
"""Create a new playlist."""
try:
playlist = Playlist(**playlist_data)
self.session.add(playlist)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to create playlist")
raise
else:
logger.info("Created new playlist: %s", playlist.name)
return playlist
async def update(self, playlist: Playlist, update_data: dict[str, Any]) -> Playlist:
"""Update a playlist."""
try:
for field, value in update_data.items():
setattr(playlist, field, value)
await self.session.commit()
await self.session.refresh(playlist)
except Exception:
await self.session.rollback()
logger.exception("Failed to update playlist")
raise
else:
logger.info("Updated playlist: %s", playlist.name)
return playlist
async def delete(self, playlist: Playlist) -> None:
"""Delete a playlist."""
try:
await self.session.delete(playlist)
await self.session.commit()
logger.info("Deleted playlist: %s", playlist.name)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete playlist")
raise
async def search_by_name(
self, query: str, user_id: int | None = None
) -> list[Playlist]:
"""Search playlists by name (case-insensitive)."""
try:
statement = select(Playlist).where(
func.lower(Playlist.name).like(f"%{query.lower()}%"),
)
if user_id is not None:
statement = statement.where(Playlist.user_id == user_id)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to search playlists by name: %s", query)
raise
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
"""Get all sounds in a playlist, ordered by position."""
try:
statement = (
select(Sound)
.join(PlaylistSound)
.where(PlaylistSound.playlist_id == playlist_id)
.order_by(PlaylistSound.position)
)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get sounds for playlist: %s", playlist_id)
raise
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, position: int | None = None
) -> PlaylistSound:
"""Add a sound to a playlist."""
try:
if position is None:
# Get the next available position
statement = select(
func.coalesce(func.max(PlaylistSound.position), -1) + 1
).where(PlaylistSound.playlist_id == playlist_id)
result = await self.session.exec(statement)
position = result.first() or 0
playlist_sound = PlaylistSound(
playlist_id=playlist_id,
sound_id=sound_id,
position=position,
)
self.session.add(playlist_sound)
await self.session.commit()
await self.session.refresh(playlist_sound)
except Exception:
await self.session.rollback()
logger.exception(
"Failed to add sound %s to playlist %s", sound_id, playlist_id
)
raise
else:
logger.info(
"Added sound %s to playlist %s at position %s",
sound_id,
playlist_id,
position,
)
return playlist_sound
async def remove_sound_from_playlist(self, playlist_id: int, sound_id: int) -> None:
"""Remove a sound from a playlist."""
try:
statement = select(PlaylistSound).where(
PlaylistSound.playlist_id == playlist_id,
PlaylistSound.sound_id == sound_id,
)
result = await self.session.exec(statement)
playlist_sound = result.first()
if playlist_sound:
await self.session.delete(playlist_sound)
await self.session.commit()
logger.info("Removed sound %s from playlist %s", sound_id, playlist_id)
except Exception:
await self.session.rollback()
logger.exception(
"Failed to remove sound %s from playlist %s", sound_id, playlist_id
)
raise
async def reorder_playlist_sounds(
self, playlist_id: int, sound_positions: list[tuple[int, int]]
) -> None:
"""Reorder sounds in a playlist.
Args:
playlist_id: The playlist ID
sound_positions: List of (sound_id, new_position) tuples
"""
try:
for sound_id, new_position in sound_positions:
statement = select(PlaylistSound).where(
PlaylistSound.playlist_id == playlist_id,
PlaylistSound.sound_id == sound_id,
)
result = await self.session.exec(statement)
playlist_sound = result.first()
if playlist_sound:
playlist_sound.position = new_position
await self.session.commit()
logger.info("Reordered sounds in playlist %s", playlist_id)
except Exception:
await self.session.rollback()
logger.exception("Failed to reorder sounds in playlist %s", playlist_id)
raise
async def get_playlist_sound_count(self, playlist_id: int) -> int:
"""Get the number of sounds in a playlist."""
try:
statement = select(func.count(PlaylistSound.id)).where(
PlaylistSound.playlist_id == playlist_id
)
result = await self.session.exec(statement)
return result.first() or 0
except Exception:
logger.exception("Failed to get sound count for playlist: %s", playlist_id)
raise
async def is_sound_in_playlist(self, playlist_id: int, sound_id: int) -> bool:
"""Check if a sound is already in a playlist."""
try:
statement = select(PlaylistSound).where(
PlaylistSound.playlist_id == playlist_id,
PlaylistSound.sound_id == sound_id,
)
result = await self.session.exec(statement)
return result.first() is not None
except Exception:
logger.exception(
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id
)
raise

View File

@@ -116,11 +116,7 @@ class SoundRepository:
async def get_popular_sounds(self, limit: int = 10) -> list[Sound]: async def get_popular_sounds(self, limit: int = 10) -> list[Sound]:
"""Get the most played sounds.""" """Get the most played sounds."""
try: try:
statement = ( statement = select(Sound).order_by(desc(Sound.play_count)).limit(limit)
select(Sound)
.order_by(desc(Sound.play_count))
.limit(limit)
)
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
except Exception: except Exception:
@@ -147,5 +143,7 @@ class SoundRepository:
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
except Exception: except Exception:
logger.exception("Failed to get unnormalized sounds by type: %s", sound_type) logger.exception(
"Failed to get unnormalized sounds by type: %s", sound_type
)
raise raise

View File

@@ -51,6 +51,7 @@ class UserRepository:
async def create(self, user_data: dict[str, Any]) -> User: async def create(self, user_data: dict[str, Any]) -> User:
"""Create a new user.""" """Create a new user."""
def _raise_plan_not_found() -> None: def _raise_plan_not_found() -> None:
msg = "Default plan not found" msg = "Default plan not found"
raise ValueError(msg) raise ValueError(msg)

View File

@@ -14,6 +14,7 @@ from app.models.extraction import Extraction
from app.models.sound import Sound from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
from app.services.playlist import PlaylistService
from app.services.sound_normalizer import SoundNormalizerService from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
@@ -41,6 +42,7 @@ class ExtractionService:
self.session = session self.session = session
self.extraction_repo = ExtractionRepository(session) self.extraction_repo = ExtractionRepository(session)
self.sound_repo = SoundRepository(session) self.sound_repo = SoundRepository(session)
self.playlist_service = PlaylistService(session)
# Ensure required directories exist # Ensure required directories exist
self._ensure_directories() self._ensure_directories()
@@ -447,20 +449,18 @@ class ExtractionService:
async def _add_to_main_playlist(self, sound: Sound, user_id: int) -> None: async def _add_to_main_playlist(self, sound: Sound, user_id: int) -> None:
"""Add the sound to the user's main playlist.""" """Add the sound to the user's main playlist."""
try: try:
# This is a placeholder - implement based on your playlist logic await self.playlist_service.add_sound_to_main_playlist(sound.id, user_id)
# For now, we'll just log that we would add it to the main playlist
logger.info( logger.info(
"Would add sound %d to main playlist for user %d", "Added sound %d to main playlist for user %d",
sound.id, sound.id,
user_id, user_id,
) )
except Exception as e: except Exception:
logger.exception( logger.exception(
"Error adding sound %d to main playlist for user %d: %s", "Error adding sound %d to main playlist for user %d",
sound.id, sound.id,
user_id, user_id,
e,
) )
# Don't fail the extraction if playlist addition fails # Don't fail the extraction if playlist addition fails

316
app/services/playlist.py Normal file
View File

@@ -0,0 +1,316 @@
"""Playlist service for business logic operations."""
from typing import Any
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.playlist import Playlist
from app.models.sound import Sound
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
logger = get_logger(__name__)
class PlaylistService:
"""Service for playlist operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the playlist service."""
self.session = session
self.playlist_repo = PlaylistRepository(session)
self.sound_repo = SoundRepository(session)
async def get_playlist_by_id(self, playlist_id: int) -> Playlist:
"""Get a playlist by ID."""
playlist = await self.playlist_repo.get_by_id(playlist_id)
if not playlist:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Playlist not found",
)
return playlist
async def get_user_playlists(self, user_id: int) -> list[Playlist]:
"""Get all playlists for a user."""
return await self.playlist_repo.get_by_user_id(user_id)
async def get_all_playlists(self) -> list[Playlist]:
"""Get all playlists from all users."""
return await self.playlist_repo.get_all()
async def get_main_playlist(self) -> Playlist:
"""Get the global main playlist."""
main_playlist = await self.playlist_repo.get_main_playlist()
if not main_playlist:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Main playlist not found. Make sure to run database seeding."
)
return main_playlist
async def get_current_playlist(self, user_id: int) -> Playlist:
"""Get the user's current playlist, fallback to main playlist."""
current_playlist = await self.playlist_repo.get_current_playlist(user_id)
if current_playlist:
return current_playlist
# Fallback to main playlist if no current playlist is set
return await self.get_main_playlist()
async def create_playlist(
self,
user_id: int,
name: str,
description: str | None = None,
genre: str | None = None,
is_main: bool = False,
is_current: bool = False,
is_deletable: bool = True,
) -> Playlist:
"""Create a new playlist."""
# Check if name already exists for this user
existing_playlist = await self.playlist_repo.get_by_name(name)
if existing_playlist and existing_playlist.user_id == user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="A playlist with this name already exists",
)
# If this is set as current, unset the previous current playlist
if is_current:
await self._unset_current_playlist(user_id)
playlist_data = {
"user_id": user_id,
"name": name,
"description": description,
"genre": genre,
"is_main": is_main,
"is_current": is_current,
"is_deletable": is_deletable,
}
playlist = await self.playlist_repo.create(playlist_data)
logger.info("Created playlist '%s' for user %s", name, user_id)
return playlist
async def update_playlist(
self,
playlist_id: int,
user_id: int,
name: str | None = None,
description: str | None = None,
genre: str | None = None,
is_current: bool | None = None,
) -> Playlist:
"""Update a playlist."""
playlist = await self.get_playlist_by_id(playlist_id)
update_data: dict[str, Any] = {}
if name is not None:
# Check if new name conflicts with existing playlist
existing_playlist = await self.playlist_repo.get_by_name(name)
if (
existing_playlist
and existing_playlist.id != playlist_id
and existing_playlist.user_id == user_id
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="A playlist with this name already exists",
)
update_data["name"] = name
if description is not None:
update_data["description"] = description
if genre is not None:
update_data["genre"] = genre
if is_current is not None:
if is_current:
await self._unset_current_playlist(user_id)
update_data["is_current"] = is_current
if update_data:
playlist = await self.playlist_repo.update(playlist, update_data)
logger.info("Updated playlist %s for user %s", playlist_id, user_id)
return playlist
async def delete_playlist(self, playlist_id: int, user_id: int) -> None:
"""Delete a playlist."""
playlist = await self.get_playlist_by_id(playlist_id)
if not playlist.is_deletable:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="This playlist cannot be deleted",
)
# Check if this is the current playlist
was_current = playlist.is_current
await self.playlist_repo.delete(playlist)
logger.info("Deleted playlist %s for user %s", playlist_id, user_id)
# If the deleted playlist was current, set main playlist as current
if was_current:
await self._set_main_as_current(user_id)
async def search_playlists(self, query: str, user_id: int) -> list[Playlist]:
"""Search user's playlists by name."""
return await self.playlist_repo.search_by_name(query, user_id)
async def search_all_playlists(self, query: str) -> list[Playlist]:
"""Search all playlists by name."""
return await self.playlist_repo.search_by_name(query)
async def get_playlist_sounds(self, playlist_id: int) -> list[Sound]:
"""Get all sounds in a playlist."""
await self.get_playlist_by_id(playlist_id) # Verify playlist exists
return await self.playlist_repo.get_playlist_sounds(playlist_id)
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, user_id: int, position: int | None = None
) -> None:
"""Add a sound to a playlist."""
# Verify playlist exists
await self.get_playlist_by_id(playlist_id)
# Verify sound exists
sound = await self.sound_repo.get_by_id(sound_id)
if not sound:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Sound not found",
)
# Check if sound is already in playlist
if await self.playlist_repo.is_sound_in_playlist(playlist_id, sound_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Sound is already in this playlist",
)
await self.playlist_repo.add_sound_to_playlist(playlist_id, sound_id, position)
logger.info(
"Added sound %s to playlist %s for user %s", sound_id, playlist_id, user_id
)
async def remove_sound_from_playlist(
self, playlist_id: int, sound_id: int, user_id: int
) -> None:
"""Remove a sound from a playlist."""
# Verify playlist exists
await self.get_playlist_by_id(playlist_id)
# Check if sound is in playlist
if not await self.playlist_repo.is_sound_in_playlist(playlist_id, sound_id):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Sound not found in this playlist",
)
await self.playlist_repo.remove_sound_from_playlist(playlist_id, sound_id)
logger.info(
"Removed sound %s from playlist %s for user %s",
sound_id,
playlist_id,
user_id,
)
async def reorder_playlist_sounds(
self, playlist_id: int, user_id: int, sound_positions: list[tuple[int, int]]
) -> None:
"""Reorder sounds in a playlist."""
# Verify playlist exists
await self.get_playlist_by_id(playlist_id)
# Validate all sounds are in the playlist
for sound_id, _ in sound_positions:
if not await self.playlist_repo.is_sound_in_playlist(playlist_id, sound_id):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Sound {sound_id} is not in this playlist",
)
await self.playlist_repo.reorder_playlist_sounds(playlist_id, sound_positions)
logger.info("Reordered sounds in playlist %s for user %s", playlist_id, user_id)
async def set_current_playlist(self, playlist_id: int, user_id: int) -> Playlist:
"""Set a playlist as the current playlist."""
playlist = await self.get_playlist_by_id(playlist_id)
# Unset previous current playlist
await self._unset_current_playlist(user_id)
# Set new current playlist
playlist = await self.playlist_repo.update(playlist, {"is_current": True})
logger.info("Set playlist %s as current for user %s", playlist_id, user_id)
return playlist
async def unset_current_playlist(self, user_id: int) -> None:
"""Unset the current playlist and set main playlist as current."""
await self._unset_current_playlist(user_id)
await self._set_main_as_current(user_id)
logger.info(
"Unset current playlist and set main as current for user %s", user_id
)
async def get_playlist_stats(self, playlist_id: int) -> dict[str, Any]:
"""Get statistics for a playlist."""
await self.get_playlist_by_id(playlist_id) # Verify playlist exists
sound_count = await self.playlist_repo.get_playlist_sound_count(playlist_id)
sounds = await self.playlist_repo.get_playlist_sounds(playlist_id)
total_duration = sum(sound.duration or 0 for sound in sounds)
total_plays = sum(sound.play_count or 0 for sound in sounds)
return {
"sound_count": sound_count,
"total_duration_ms": total_duration,
"total_play_count": total_plays,
}
async def add_sound_to_main_playlist(self, sound_id: int, user_id: int) -> None:
"""Add a sound to the global main playlist."""
main_playlist = await self.get_main_playlist()
if main_playlist.id is None:
raise ValueError("Main playlist has no ID")
# Check if sound is already in main playlist
if not await self.playlist_repo.is_sound_in_playlist(
main_playlist.id, sound_id
):
await self.playlist_repo.add_sound_to_playlist(main_playlist.id, sound_id)
logger.info(
"Added sound %s to main playlist for user %s",
sound_id,
user_id,
)
async def _unset_current_playlist(self, user_id: int) -> None:
"""Unset the current playlist for a user."""
current_playlist = await self.playlist_repo.get_current_playlist(user_id)
if current_playlist:
await self.playlist_repo.update(current_playlist, {"is_current": False})
async def _set_main_as_current(self, user_id: int) -> None:
"""Unset current playlist so main playlist becomes the fallback current."""
# Just ensure no user playlist is marked as current
# The get_current_playlist method will fallback to main playlist
await self._unset_current_playlist(user_id)
logger.info(
"Unset current playlist for user %s, main playlist is now fallback",
user_id,
)

View File

@@ -107,7 +107,6 @@ class SoundNormalizerService:
original_dir = type_to_original_dir.get(sound_type, "sounds/originals/other") original_dir = type_to_original_dir.get(sound_type, "sounds/originals/other")
return Path(original_dir) / filename return Path(original_dir) / filename
async def _normalize_audio_one_pass( async def _normalize_audio_one_pass(
self, self,
input_path: Path, input_path: Path,
@@ -178,9 +177,12 @@ class SoundNormalizerService:
result = ffmpeg.run(stream, capture_stderr=True, quiet=True) result = ffmpeg.run(stream, capture_stderr=True, quiet=True)
analysis_output = result[1].decode("utf-8") analysis_output = result[1].decode("utf-8")
except ffmpeg.Error as e: except ffmpeg.Error as e:
logger.error("FFmpeg first pass failed for %s. Stdout: %s, Stderr: %s", logger.error(
input_path, e.stdout.decode() if e.stdout else "None", "FFmpeg first pass failed for %s. Stdout: %s, Stderr: %s",
e.stderr.decode() if e.stderr else "None") input_path,
e.stdout.decode() if e.stdout else "None",
e.stderr.decode() if e.stderr else "None",
)
raise raise
# Extract loudnorm measurements from the output # Extract loudnorm measurements from the output
@@ -190,7 +192,9 @@ class SoundNormalizerService:
# Find JSON in the output # Find JSON in the output
json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output) json_match = re.search(r'\{[^{}]*"input_i"[^{}]*\}', analysis_output)
if not json_match: if not json_match:
logger.error("Could not find JSON in loudnorm output: %s", analysis_output) logger.error(
"Could not find JSON in loudnorm output: %s", analysis_output
)
raise ValueError("Could not extract loudnorm analysis data") raise ValueError("Could not extract loudnorm analysis data")
logger.debug("Found JSON match: %s", json_match.group()) logger.debug("Found JSON match: %s", json_match.group())
@@ -198,11 +202,18 @@ class SoundNormalizerService:
# Check for invalid values that would cause second pass to fail # Check for invalid values that would cause second pass to fail
invalid_values = ["-inf", "inf", "nan"] invalid_values = ["-inf", "inf", "nan"]
for key in ["input_i", "input_lra", "input_tp", "input_thresh", "target_offset"]: for key in [
"input_i",
"input_lra",
"input_tp",
"input_thresh",
"target_offset",
]:
if str(analysis_data.get(key, "")).lower() in invalid_values: if str(analysis_data.get(key, "")).lower() in invalid_values:
logger.warning( logger.warning(
"Invalid analysis value for %s: %s. Falling back to one-pass normalization.", "Invalid analysis value for %s: %s. Falling back to one-pass normalization.",
key, analysis_data.get(key) key,
analysis_data.get(key),
) )
# Fall back to one-pass normalization # Fall back to one-pass normalization
await self._normalize_audio_one_pass(input_path, output_path) await self._normalize_audio_one_pass(input_path, output_path)
@@ -241,9 +252,12 @@ class SoundNormalizerService:
ffmpeg.run(stream, quiet=True, overwrite_output=True) ffmpeg.run(stream, quiet=True, overwrite_output=True)
logger.info("Two-pass normalization completed: %s", output_path) logger.info("Two-pass normalization completed: %s", output_path)
except ffmpeg.Error as e: except ffmpeg.Error as e:
logger.error("FFmpeg second pass failed for %s. Stdout: %s, Stderr: %s", logger.error(
input_path, e.stdout.decode() if e.stdout else "None", "FFmpeg second pass failed for %s. Stdout: %s, Stderr: %s",
e.stderr.decode() if e.stderr else "None") input_path,
e.stdout.decode() if e.stdout else "None",
e.stderr.decode() if e.stderr else "None",
)
raise raise
except Exception as e: except Exception as e:

View File

@@ -56,7 +56,6 @@ class SoundScannerService:
".aac", ".aac",
} }
def extract_name_from_filename(self, filename: str) -> str: def extract_name_from_filename(self, filename: str) -> str:
"""Extract a clean name from filename.""" """Extract a clean name from filename."""
# Remove extension # Remove extension

View File

@@ -1,7 +1,6 @@
"""Cookie parsing utilities for WebSocket authentication.""" """Cookie parsing utilities for WebSocket authentication."""
def parse_cookies(cookie_header: str) -> dict[str, str]: def parse_cookies(cookie_header: str) -> dict[str, str]:
"""Parse HTTP cookie header into a dictionary.""" """Parse HTTP cookie header into a dictionary."""
cookies = {} cookies = {}

View File

@@ -14,7 +14,9 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_api_token_success( async def test_generate_api_token_success(
self, authenticated_client: AsyncClient, authenticated_user: User, self,
authenticated_client: AsyncClient,
authenticated_user: User,
): ):
"""Test successful API token generation.""" """Test successful API token generation."""
request_data = {"expires_days": 30} request_data = {"expires_days": 30}
@@ -33,6 +35,7 @@ class TestApiTokenEndpoints:
# Verify token format (should be URL-safe base64) # Verify token format (should be URL-safe base64)
import base64 import base64
try: try:
base64.urlsafe_b64decode(data["api_token"] + "===") # Add padding base64.urlsafe_b64decode(data["api_token"] + "===") # Add padding
except Exception: except Exception:
@@ -40,7 +43,8 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_api_token_default_expiry( async def test_generate_api_token_default_expiry(
self, authenticated_client: AsyncClient, self,
authenticated_client: AsyncClient,
): ):
"""Test API token generation with default expiry.""" """Test API token generation with default expiry."""
response = await authenticated_client.post("/api/v1/auth/api-token", json={}) response = await authenticated_client.post("/api/v1/auth/api-token", json={})
@@ -65,7 +69,8 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_api_token_custom_expiry( async def test_generate_api_token_custom_expiry(
self, authenticated_client: AsyncClient, self,
authenticated_client: AsyncClient,
): ):
"""Test API token generation with custom expiry.""" """Test API token generation with custom expiry."""
expires_days = 90 expires_days = 90
@@ -96,7 +101,8 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_api_token_validation_errors( async def test_generate_api_token_validation_errors(
self, authenticated_client: AsyncClient, self,
authenticated_client: AsyncClient,
): ):
"""Test API token generation with validation errors.""" """Test API token generation with validation errors."""
# Test minimum validation # Test minimum validation
@@ -124,7 +130,8 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_api_token_status_no_token( async def test_get_api_token_status_no_token(
self, authenticated_client: AsyncClient, self,
authenticated_client: AsyncClient,
): ):
"""Test getting API token status when user has no token.""" """Test getting API token status when user has no token."""
response = await authenticated_client.get("/api/v1/auth/api-token/status") response = await authenticated_client.get("/api/v1/auth/api-token/status")
@@ -138,7 +145,8 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_api_token_status_with_token( async def test_get_api_token_status_with_token(
self, authenticated_client: AsyncClient, self,
authenticated_client: AsyncClient,
): ):
"""Test getting API token status when user has a token.""" """Test getting API token status when user has a token."""
# First generate a token # First generate a token
@@ -159,14 +167,18 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_api_token_status_expired_token( async def test_get_api_token_status_expired_token(
self, authenticated_client: AsyncClient, authenticated_user: User, self,
authenticated_client: AsyncClient,
authenticated_user: User,
): ):
"""Test getting API token status with expired token.""" """Test getting API token status with expired token."""
# Mock expired token # Mock expired token
with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True): with patch("app.utils.auth.TokenUtils.is_token_expired", return_value=True):
# Set a token on the user # Set a token on the user
authenticated_user.api_token = "expired_token" authenticated_user.api_token = "expired_token"
authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta(days=1) authenticated_user.api_token_expires_at = datetime.now(UTC) - timedelta(
days=1
)
response = await authenticated_client.get("/api/v1/auth/api-token/status") response = await authenticated_client.get("/api/v1/auth/api-token/status")
@@ -185,7 +197,8 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_revoke_api_token_success( async def test_revoke_api_token_success(
self, authenticated_client: AsyncClient, self,
authenticated_client: AsyncClient,
): ):
"""Test successful API token revocation.""" """Test successful API token revocation."""
# First generate a token # First generate a token
@@ -195,7 +208,9 @@ class TestApiTokenEndpoints:
) )
# Verify token exists # Verify token exists
status_response = await authenticated_client.get("/api/v1/auth/api-token/status") status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status"
)
assert status_response.json()["has_token"] is True assert status_response.json()["has_token"] is True
# Revoke the token # Revoke the token
@@ -206,12 +221,15 @@ class TestApiTokenEndpoints:
assert data["message"] == "API token revoked successfully" assert data["message"] == "API token revoked successfully"
# Verify token is gone # Verify token is gone
status_response = await authenticated_client.get("/api/v1/auth/api-token/status") status_response = await authenticated_client.get(
"/api/v1/auth/api-token/status"
)
assert status_response.json()["has_token"] is False assert status_response.json()["has_token"] is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_revoke_api_token_no_token( async def test_revoke_api_token_no_token(
self, authenticated_client: AsyncClient, self,
authenticated_client: AsyncClient,
): ):
"""Test revoking API token when user has no token.""" """Test revoking API token when user has no token."""
response = await authenticated_client.delete("/api/v1/auth/api-token") response = await authenticated_client.delete("/api/v1/auth/api-token")
@@ -228,7 +246,9 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_authentication_success( async def test_api_token_authentication_success(
self, client: AsyncClient, authenticated_client: AsyncClient, self,
client: AsyncClient,
authenticated_client: AsyncClient,
): ):
"""Test successful authentication using API token.""" """Test successful authentication using API token."""
# Generate API token # Generate API token
@@ -259,7 +279,9 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_authentication_expired_token( async def test_api_token_authentication_expired_token(
self, client: AsyncClient, authenticated_client: AsyncClient, self,
client: AsyncClient,
authenticated_client: AsyncClient,
): ):
"""Test authentication with expired API token.""" """Test authentication with expired API token."""
# Generate API token # Generate API token
@@ -299,7 +321,10 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_authentication_inactive_user( async def test_api_token_authentication_inactive_user(
self, client: AsyncClient, authenticated_client: AsyncClient, authenticated_user: User, self,
client: AsyncClient,
authenticated_client: AsyncClient,
authenticated_user: User,
): ):
"""Test authentication with API token for inactive user.""" """Test authentication with API token for inactive user."""
# Generate API token # Generate API token
@@ -322,7 +347,10 @@ class TestApiTokenEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flexible_authentication_prefers_api_token( async def test_flexible_authentication_prefers_api_token(
self, client: AsyncClient, authenticated_client: AsyncClient, auth_cookies: dict[str, str], self,
client: AsyncClient,
authenticated_client: AsyncClient,
auth_cookies: dict[str, str],
): ):
"""Test that flexible authentication prefers API token over cookie.""" """Test that flexible authentication prefers API token over cookie."""
# Generate API token # Generate API token

View File

@@ -73,7 +73,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_duplicate_email( async def test_register_duplicate_email(
self, test_client: AsyncClient, test_user: User, self,
test_client: AsyncClient,
test_user: User,
) -> None: ) -> None:
"""Test registration with duplicate email.""" """Test registration with duplicate email."""
user_data = { user_data = {
@@ -128,7 +130,10 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_success( async def test_login_success(
self, test_client: AsyncClient, test_user: User, test_login_data: dict[str, str], self,
test_client: AsyncClient,
test_user: User,
test_login_data: dict[str, str],
) -> None: ) -> None:
"""Test successful user login.""" """Test successful user login."""
response = await test_client.post("/api/v1/auth/login", json=test_login_data) response = await test_client.post("/api/v1/auth/login", json=test_login_data)
@@ -161,7 +166,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_invalid_password( async def test_login_invalid_password(
self, test_client: AsyncClient, test_user: User, self,
test_client: AsyncClient,
test_user: User,
) -> None: ) -> None:
"""Test login with invalid password.""" """Test login with invalid password."""
login_data = {"email": test_user.email, "password": "wrongpassword"} login_data = {"email": test_user.email, "password": "wrongpassword"}
@@ -183,7 +190,10 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_success( async def test_get_current_user_success(
self, test_client: AsyncClient, test_user: User, auth_cookies: dict[str, str], self,
test_client: AsyncClient,
test_user: User,
auth_cookies: dict[str, str],
) -> None: ) -> None:
"""Test getting current user info successfully.""" """Test getting current user info successfully."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -210,7 +220,8 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_invalid_token( async def test_get_current_user_invalid_token(
self, test_client: AsyncClient, self,
test_client: AsyncClient,
) -> None: ) -> None:
"""Test getting current user with invalid token.""" """Test getting current user with invalid token."""
# Set invalid cookies on client instance # Set invalid cookies on client instance
@@ -223,7 +234,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_expired_token( async def test_get_current_user_expired_token(
self, test_client: AsyncClient, test_user: User, self,
test_client: AsyncClient,
test_user: User,
) -> None: ) -> None:
"""Test getting current user with expired token.""" """Test getting current user with expired token."""
from datetime import timedelta from datetime import timedelta
@@ -237,7 +250,8 @@ class TestAuthEndpoints:
"role": "user", "role": "user",
} }
expired_token = JWTUtils.create_access_token( expired_token = JWTUtils.create_access_token(
token_data, expires_delta=timedelta(seconds=-1), token_data,
expires_delta=timedelta(seconds=-1),
) )
# Set expired cookies on client instance # Set expired cookies on client instance
@@ -262,7 +276,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_access_with_user_role( async def test_admin_access_with_user_role(
self, test_client: AsyncClient, auth_cookies: dict[str, str], self,
test_client: AsyncClient,
auth_cookies: dict[str, str],
) -> None: ) -> None:
"""Test that regular users cannot access admin endpoints.""" """Test that regular users cannot access admin endpoints."""
# This test would be for admin-only endpoints when they're created # This test would be for admin-only endpoints when they're created
@@ -293,7 +309,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_access_with_admin_role( async def test_admin_access_with_admin_role(
self, test_client: AsyncClient, admin_cookies: dict[str, str], self,
test_client: AsyncClient,
admin_cookies: dict[str, str],
) -> None: ) -> None:
"""Test that admin users can access admin endpoints.""" """Test that admin users can access admin endpoints."""
from app.core.dependencies import get_admin_user from app.core.dependencies import get_admin_user
@@ -357,7 +375,8 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_authorize_invalid_provider( async def test_oauth_authorize_invalid_provider(
self, test_client: AsyncClient, self,
test_client: AsyncClient,
) -> None: ) -> None:
"""Test OAuth authorization with invalid provider.""" """Test OAuth authorization with invalid provider."""
response = await test_client.get("/api/v1/auth/invalid/authorize") response = await test_client.get("/api/v1/auth/invalid/authorize")
@@ -368,7 +387,9 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_callback_new_user( async def test_oauth_callback_new_user(
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any], self,
test_client: AsyncClient,
ensure_plans: tuple[Any, Any],
) -> None: ) -> None:
"""Test OAuth callback for new user creation.""" """Test OAuth callback for new user creation."""
# Mock OAuth user info # Mock OAuth user info
@@ -400,7 +421,10 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_callback_existing_user_link( async def test_oauth_callback_existing_user_link(
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any], self,
test_client: AsyncClient,
test_user: Any,
ensure_plans: tuple[Any, Any],
) -> None: ) -> None:
"""Test OAuth callback for linking to existing user.""" """Test OAuth callback for linking to existing user."""
# Mock OAuth user info with same email as test user # Mock OAuth user info with same email as test user
@@ -442,7 +466,8 @@ class TestAuthEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_oauth_callback_invalid_provider( async def test_oauth_callback_invalid_provider(
self, test_client: AsyncClient, self,
test_client: AsyncClient,
) -> None: ) -> None:
"""Test OAuth callback with invalid provider.""" """Test OAuth callback with invalid provider."""
response = await test_client.get( response = await test_client.get(

File diff suppressed because it is too large Load Diff

View File

@@ -22,7 +22,12 @@ class TestSocketEndpoints:
"""Test socket API endpoints.""" """Test socket API endpoints."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_socket_status_authenticated(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): async def test_get_socket_status_authenticated(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
mock_socket_manager,
):
"""Test getting socket status for authenticated user.""" """Test getting socket status for authenticated user."""
response = await authenticated_client.get("/api/v1/socket/status") response = await authenticated_client.get("/api/v1/socket/status")
@@ -43,7 +48,12 @@ class TestSocketEndpoints:
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_to_user_success(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): async def test_send_message_to_user_success(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
mock_socket_manager,
):
"""Test sending message to specific user successfully.""" """Test sending message to specific user successfully."""
target_user_id = 2 target_user_id = 2
message = "Hello there!" message = "Hello there!"
@@ -72,7 +82,12 @@ class TestSocketEndpoints:
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_to_user_not_connected(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): async def test_send_message_to_user_not_connected(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
mock_socket_manager,
):
"""Test sending message to user who is not connected.""" """Test sending message to user who is not connected."""
target_user_id = 999 target_user_id = 999
message = "Hello there!" message = "Hello there!"
@@ -102,7 +117,12 @@ class TestSocketEndpoints:
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_broadcast_message_success(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): async def test_broadcast_message_success(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
mock_socket_manager,
):
"""Test broadcasting message to all users successfully.""" """Test broadcasting message to all users successfully."""
message = "Important announcement!" message = "Important announcement!"
@@ -137,7 +157,9 @@ class TestSocketEndpoints:
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_missing_parameters(self, authenticated_client: AsyncClient, authenticated_user: User): async def test_send_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User
):
"""Test sending message with missing parameters.""" """Test sending message with missing parameters."""
# Missing target_user_id # Missing target_user_id
response = await authenticated_client.post( response = await authenticated_client.post(
@@ -154,13 +176,17 @@ class TestSocketEndpoints:
assert response.status_code == 422 assert response.status_code == 422
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_broadcast_message_missing_parameters(self, authenticated_client: AsyncClient, authenticated_user: User): async def test_broadcast_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User
):
"""Test broadcasting message with missing parameters.""" """Test broadcasting message with missing parameters."""
response = await authenticated_client.post("/api/v1/socket/broadcast") response = await authenticated_client.post("/api/v1/socket/broadcast")
assert response.status_code == 422 assert response.status_code == 422
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_invalid_user_id(self, authenticated_client: AsyncClient, authenticated_user: User): async def test_send_message_invalid_user_id(
self, authenticated_client: AsyncClient, authenticated_user: User
):
"""Test sending message with invalid user ID.""" """Test sending message with invalid user ID."""
response = await authenticated_client.post( response = await authenticated_client.post(
"/api/v1/socket/send-message", "/api/v1/socket/send-message",
@@ -169,10 +195,19 @@ class TestSocketEndpoints:
assert response.status_code == 422 assert response.status_code == 422
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_socket_status_shows_user_connection(self, authenticated_client: AsyncClient, authenticated_user: User, mock_socket_manager): async def test_socket_status_shows_user_connection(
self,
authenticated_client: AsyncClient,
authenticated_user: User,
mock_socket_manager,
):
"""Test that socket status correctly shows if user is connected.""" """Test that socket status correctly shows if user is connected."""
# Test when user is connected # Test when user is connected
mock_socket_manager.get_connected_users.return_value = [str(authenticated_user.id), "2", "3"] mock_socket_manager.get_connected_users.return_value = [
str(authenticated_user.id),
"2",
"3",
]
response = await authenticated_client.get("/api/v1/socket/status") response = await authenticated_client.get("/api/v1/socket/status")
data = response.json() data = response.json()

View File

@@ -870,7 +870,6 @@ class TestSoundEndpoints:
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
mock_get_sound.return_value = mock_sound mock_get_sound.return_value = mock_sound
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
@@ -950,7 +949,6 @@ class TestSoundEndpoints:
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
mock_get_sound.return_value = mock_sound mock_get_sound.return_value = mock_sound
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
@@ -1003,7 +1001,6 @@ class TestSoundEndpoints:
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
mock_get_sound.return_value = mock_sound mock_get_sound.return_value = mock_sound
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result
@@ -1059,7 +1056,6 @@ class TestSoundEndpoints:
) as mock_normalize_sound, ) as mock_normalize_sound,
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
): ):
mock_get_sound.return_value = mock_sound mock_get_sound.return_value = mock_sound
mock_normalize_sound.return_value = mock_result mock_normalize_sound.return_value = mock_result

View File

@@ -103,7 +103,8 @@ async def test_client(test_app) -> AsyncGenerator[AsyncClient, None]:
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def authenticated_client( async def authenticated_client(
test_app: FastAPI, auth_cookies: dict[str, str], test_app: FastAPI,
auth_cookies: dict[str, str],
) -> AsyncGenerator[AsyncClient, None]: ) -> AsyncGenerator[AsyncClient, None]:
"""Create a test HTTP client with authentication cookies.""" """Create a test HTTP client with authentication cookies."""
async with AsyncClient( async with AsyncClient(
@@ -116,7 +117,8 @@ async def authenticated_client(
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def authenticated_admin_client( async def authenticated_admin_client(
test_app: FastAPI, admin_cookies: dict[str, str], test_app: FastAPI,
admin_cookies: dict[str, str],
) -> AsyncGenerator[AsyncClient, None]: ) -> AsyncGenerator[AsyncClient, None]:
"""Create a test HTTP client with admin authentication cookies.""" """Create a test HTTP client with admin authentication cookies."""
async with AsyncClient( async with AsyncClient(
@@ -211,7 +213,8 @@ async def ensure_plans(test_session: AsyncSession) -> tuple[Plan, Plan]:
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def test_user( async def test_user(
test_session: AsyncSession, ensure_plans: tuple[Plan, Plan], test_session: AsyncSession,
ensure_plans: tuple[Plan, Plan],
) -> User: ) -> User:
"""Create a test user.""" """Create a test user."""
user = User( user = User(
@@ -231,7 +234,8 @@ async def test_user(
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def admin_user( async def admin_user(
test_session: AsyncSession, ensure_plans: tuple[Plan, Plan], test_session: AsyncSession,
ensure_plans: tuple[Plan, Plan],
) -> User: ) -> User:
"""Create a test admin user.""" """Create a test admin user."""
user = User( user = User(

View File

@@ -36,7 +36,9 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_success( async def test_get_current_user_api_token_success(
self, mock_auth_service, test_user, self,
mock_auth_service,
test_user,
): ):
"""Test successful API token authentication.""" """Test successful API token authentication."""
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
@@ -46,7 +48,9 @@ class TestApiTokenDependencies:
result = await get_current_user_api_token(mock_auth_service, api_token_header) result = await get_current_user_api_token(mock_auth_service, api_token_header)
assert result == test_user assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123") mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_no_header(self, mock_auth_service): async def test_get_current_user_api_token_no_header(self, mock_auth_service):
@@ -94,7 +98,9 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_expired_token( async def test_get_current_user_api_token_expired_token(
self, mock_auth_service, test_user, self,
mock_auth_service,
test_user,
): ):
"""Test API token authentication with expired token.""" """Test API token authentication with expired token."""
# Set expired token # Set expired token
@@ -111,7 +117,9 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_inactive_user( async def test_get_current_user_api_token_inactive_user(
self, mock_auth_service, test_user, self,
mock_auth_service,
test_user,
): ):
"""Test API token authentication with inactive user.""" """Test API token authentication with inactive user."""
test_user.is_active = False test_user.is_active = False
@@ -126,9 +134,13 @@ class TestApiTokenDependencies:
assert "Account is deactivated" in exc_info.value.detail assert "Account is deactivated" in exc_info.value.detail
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_service_exception(self, mock_auth_service): async def test_get_current_user_api_token_service_exception(
self, mock_auth_service
):
"""Test API token authentication with service exception.""" """Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception("Database error") mock_auth_service.get_user_by_api_token.side_effect = Exception(
"Database error"
)
api_token_header = "test_token" api_token_header = "test_token"
@@ -140,7 +152,9 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_flexible_uses_api_token( async def test_get_current_user_flexible_uses_api_token(
self, mock_auth_service, test_user, self,
mock_auth_service,
test_user,
): ):
"""Test flexible authentication uses API token when available.""" """Test flexible authentication uses API token when available."""
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user
@@ -149,11 +163,15 @@ class TestApiTokenDependencies:
access_token = "jwt_token" access_token = "jwt_token"
result = await get_current_user_flexible( result = await get_current_user_flexible(
mock_auth_service, access_token, api_token_header, mock_auth_service,
access_token,
api_token_header,
) )
assert result == test_user assert result == test_user
mock_auth_service.get_user_by_api_token.assert_called_once_with("test_api_token_123") mock_auth_service.get_user_by_api_token.assert_called_once_with(
"test_api_token_123"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_flexible_falls_back_to_jwt(self, mock_auth_service): async def test_get_current_user_flexible_falls_back_to_jwt(self, mock_auth_service):
@@ -165,7 +183,9 @@ class TestApiTokenDependencies:
await get_current_user_flexible(mock_auth_service, "jwt_token", None) await get_current_user_flexible(mock_auth_service, "jwt_token", None)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_no_expiry_never_expires(self, mock_auth_service, test_user): async def test_api_token_no_expiry_never_expires(
self, mock_auth_service, test_user
):
"""Test API token with no expiry date never expires.""" """Test API token with no expiry date never expires."""
test_user.api_token_expires_at = None test_user.api_token_expires_at = None
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user

View File

@@ -0,0 +1,828 @@
"""Tests for playlist repository."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.playlist import Playlist
from app.models.sound import Sound
from app.models.user import User
from app.repositories.playlist import PlaylistRepository
class TestPlaylistRepository:
"""Test playlist repository operations."""
@pytest_asyncio.fixture
async def playlist_repository(
self,
test_session: AsyncSession,
) -> AsyncGenerator[PlaylistRepository, None]:
"""Create a playlist repository instance."""
yield PlaylistRepository(test_session)
@pytest_asyncio.fixture
async def test_playlist(
self,
test_session: AsyncSession,
test_user: User,
) -> Playlist:
"""Create a test playlist."""
playlist = Playlist(
user_id=test_user.id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
return playlist
@pytest_asyncio.fixture
async def main_playlist(
self,
test_session: AsyncSession,
) -> Playlist:
"""Create a main playlist."""
playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
return playlist
@pytest_asyncio.fixture
async def test_sound(
self,
test_session: AsyncSession,
) -> Sound:
"""Create a test sound."""
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
return sound
@pytest.mark.asyncio
async def test_get_by_id_existing(
self,
playlist_repository: PlaylistRepository,
test_playlist: Playlist,
) -> None:
"""Test getting playlist by ID when playlist exists."""
assert test_playlist.id is not None
playlist = await playlist_repository.get_by_id(test_playlist.id)
assert playlist is not None
assert playlist.id == test_playlist.id
assert playlist.name == test_playlist.name
assert playlist.description == test_playlist.description
@pytest.mark.asyncio
async def test_get_by_id_nonexistent(
self,
playlist_repository: PlaylistRepository,
) -> None:
"""Test getting playlist by ID when playlist doesn't exist."""
playlist = await playlist_repository.get_by_id(99999)
assert playlist is None
@pytest.mark.asyncio
async def test_get_by_name_existing(
self,
playlist_repository: PlaylistRepository,
test_playlist: Playlist,
) -> None:
"""Test getting playlist by name when playlist exists."""
playlist = await playlist_repository.get_by_name(test_playlist.name)
assert playlist is not None
assert playlist.id == test_playlist.id
assert playlist.name == test_playlist.name
@pytest.mark.asyncio
async def test_get_by_name_nonexistent(
self,
playlist_repository: PlaylistRepository,
) -> None:
"""Test getting playlist by name when playlist doesn't exist."""
playlist = await playlist_repository.get_by_name("Nonexistent Playlist")
assert playlist is None
@pytest.mark.asyncio
async def test_get_by_user_id(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test getting playlists by user ID."""
# Create test user within this test
from app.utils.auth import PasswordUtils
user = User(
email="test@example.com",
name="Test User",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist for this user
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
await test_session.commit()
# Test the repository method
playlists = await playlist_repository.get_by_user_id(user_id)
# Should only return user's playlists, not the main playlist (user_id=None)
assert len(playlists) == 1
assert playlists[0].name == "Test Playlist"
@pytest.mark.asyncio
async def test_get_main_playlist(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
) -> None:
"""Test getting main playlist."""
# Create main playlist within this test
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(main_playlist)
# Extract ID before async call
main_playlist_id = main_playlist.id
# Test the repository method
playlist = await playlist_repository.get_main_playlist()
assert playlist is not None
assert playlist.id == main_playlist_id
assert playlist.is_main is True
@pytest.mark.asyncio
async def test_get_current_playlist(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test getting current playlist when none is set."""
# Create test user within this test
from app.utils.auth import PasswordUtils
user = User(
email="test2@example.com",
name="Test User 2",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Test the repository method - should return None when no current playlist
playlist = await playlist_repository.get_current_playlist(user_id)
# Should return None since no user playlist is marked as current
assert playlist is None
@pytest.mark.asyncio
async def test_create_playlist(
self,
playlist_repository: PlaylistRepository,
test_user: User,
) -> None:
"""Test creating a new playlist."""
playlist_data = {
"user_id": test_user.id,
"name": "New Playlist",
"description": "A new playlist",
"genre": "rock",
"is_main": False,
"is_current": False,
"is_deletable": True,
}
playlist = await playlist_repository.create(playlist_data)
assert playlist.id is not None
assert playlist.name == "New Playlist"
assert playlist.description == "A new playlist"
assert playlist.genre == "rock"
assert playlist.is_main is False
assert playlist.is_current is False
assert playlist.is_deletable is True
@pytest.mark.asyncio
async def test_update_playlist(
self,
playlist_repository: PlaylistRepository,
test_playlist: Playlist,
) -> None:
"""Test updating a playlist."""
update_data = {
"name": "Updated Playlist",
"description": "Updated description",
"genre": "jazz",
}
updated_playlist = await playlist_repository.update(test_playlist, update_data)
assert updated_playlist.name == "Updated Playlist"
assert updated_playlist.description == "Updated description"
assert updated_playlist.genre == "jazz"
@pytest.mark.asyncio
async def test_delete_playlist(
self,
playlist_repository: PlaylistRepository,
test_playlist: Playlist,
) -> None:
"""Test deleting a playlist."""
playlist_id = test_playlist.id
await playlist_repository.delete(test_playlist)
# Verify playlist is deleted
deleted_playlist = await playlist_repository.get_by_id(playlist_id)
assert deleted_playlist is None
@pytest.mark.asyncio
async def test_search_by_name(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test searching playlists by name."""
# Create test user within this test
from app.utils.auth import PasswordUtils
user = User(
email="test3@example.com",
name="Test User 3",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist
test_playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(test_playlist)
# Create main playlist
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(main_playlist)
await test_session.commit()
# Search for all playlists (no user filter)
all_results = await playlist_repository.search_by_name("playlist")
assert len(all_results) >= 2 # Should include both user and main playlists
# Search with user filter
user_results = await playlist_repository.search_by_name("playlist", user_id)
assert len(user_results) == 1 # Only user's playlists, not main playlist
# Search for specific playlist
test_results = await playlist_repository.search_by_name("test", user_id)
assert len(test_results) == 1
assert test_results[0].name == "Test Playlist"
@pytest.mark.asyncio
async def test_add_sound_to_playlist(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test adding a sound to a playlist."""
# Create test user within this test
from app.utils.auth import PasswordUtils
user = User(
email="test4@example.com",
name="Test User 4",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Create test playlist
playlist = Playlist(
user_id=user.id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
# Create test sound
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async call
playlist_id = playlist.id
sound_id = sound.id
# Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id
)
assert playlist_sound.playlist_id == playlist_id
assert playlist_sound.sound_id == sound_id
assert playlist_sound.position == 0
@pytest.mark.asyncio
async def test_add_sound_to_playlist_with_position(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test adding a sound to a playlist with specific position."""
# Create test user within this test
from app.utils.auth import PasswordUtils
user = User(
email="test5@example.com",
name="Test User 5",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Extract user ID immediately after refresh
user_id = user.id
# Create test playlist
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
# Create test sound
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async call
playlist_id = playlist.id
sound_id = sound.id
# Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id, position=5
)
assert playlist_sound.position == 5
@pytest.mark.asyncio
async def test_remove_sound_from_playlist(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test removing a sound from a playlist."""
# Create objects within this test
from app.utils.auth import PasswordUtils
user = User(
email="test@example.com",
name="Test User",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# First add the sound
await playlist_repository.add_sound_to_playlist(playlist_id, sound_id)
# Verify it was added
assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
)
# Remove the sound
await playlist_repository.remove_sound_from_playlist(
playlist_id, sound_id
)
# Verify it was removed
assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
)
@pytest.mark.asyncio
async def test_get_playlist_sounds(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test getting sounds in a playlist."""
# Create objects within this test
from app.utils.auth import PasswordUtils
user = User(
email="test@example.com",
name="Test User",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Initially empty
sounds = await playlist_repository.get_playlist_sounds(playlist_id)
assert len(sounds) == 0
# Add sound
await playlist_repository.add_sound_to_playlist(playlist_id, sound_id)
# Check sounds
sounds = await playlist_repository.get_playlist_sounds(playlist_id)
assert len(sounds) == 1
assert sounds[0].id == sound_id
@pytest.mark.asyncio
async def test_get_playlist_sound_count(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test getting sound count in a playlist."""
# Create objects within this test
from app.utils.auth import PasswordUtils
user = User(
email="test@example.com",
name="Test User",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Initially empty
count = await playlist_repository.get_playlist_sound_count(playlist_id)
assert count == 0
# Add sound
await playlist_repository.add_sound_to_playlist(playlist_id, sound_id)
# Check count
count = await playlist_repository.get_playlist_sound_count(playlist_id)
assert count == 1
@pytest.mark.asyncio
async def test_is_sound_in_playlist(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test checking if sound is in playlist."""
# Create objects within this test
from app.utils.auth import PasswordUtils
user = User(
email="test@example.com",
name="Test User",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=0,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Initially not in playlist
assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
)
# Add sound
await playlist_repository.add_sound_to_playlist(playlist_id, sound_id)
# Now in playlist
assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id
)
@pytest.mark.asyncio
async def test_reorder_playlist_sounds(
self,
playlist_repository: PlaylistRepository,
test_session: AsyncSession,
ensure_plans,
) -> None:
"""Test reordering sounds in a playlist."""
# Create objects within this test
from app.utils.auth import PasswordUtils
user = User(
email="test@example.com",
name="Test User",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
user_id = user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
# Create multiple sounds
sound1 = Sound(name="Sound 1", filename="sound1.mp3", type="SDB", hash="hash1")
sound2 = Sound(name="Sound 2", filename="sound2.mp3", type="SDB", hash="hash2")
test_session.add_all([playlist, sound1, sound2])
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound1)
await test_session.refresh(sound2)
# Extract IDs before async calls
playlist_id = playlist.id
sound1_id = sound1.id
sound2_id = sound2.id
# Add sounds to playlist
await playlist_repository.add_sound_to_playlist(
playlist_id, sound1_id, position=0
)
await playlist_repository.add_sound_to_playlist(
playlist_id, sound2_id, position=1
)
# Reorder sounds - use different positions to avoid constraint issues
sound_positions = [(sound1_id, 10), (sound2_id, 5)]
await playlist_repository.reorder_playlist_sounds(
playlist_id, sound_positions
)
# Verify new order
sounds = await playlist_repository.get_playlist_sounds(playlist_id)
assert len(sounds) == 2
assert sounds[0].id == sound2_id # sound2 now at position 5
assert sounds[1].id == sound1_id # sound1 now at position 10

View File

@@ -48,11 +48,15 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_duplicate_email( async def test_register_duplicate_email(
self, auth_service: AuthService, test_user: User, self,
auth_service: AuthService,
test_user: User,
) -> None: ) -> None:
"""Test registration with duplicate email.""" """Test registration with duplicate email."""
request = UserRegisterRequest( request = UserRegisterRequest(
email=test_user.email, password="password123", name="Another User", email=test_user.email,
password="password123",
name="Another User",
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -89,7 +93,8 @@ class TestAuthService:
async def test_login_invalid_email(self, auth_service: AuthService) -> None: async def test_login_invalid_email(self, auth_service: AuthService) -> None:
"""Test login with invalid email.""" """Test login with invalid email."""
request = UserLoginRequest( request = UserLoginRequest(
email="nonexistent@example.com", password="password123", email="nonexistent@example.com",
password="password123",
) )
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -100,7 +105,9 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_invalid_password( async def test_login_invalid_password(
self, auth_service: AuthService, test_user: User, self,
auth_service: AuthService,
test_user: User,
) -> None: ) -> None:
"""Test login with invalid password.""" """Test login with invalid password."""
request = UserLoginRequest(email=test_user.email, password="wrongpassword") request = UserLoginRequest(email=test_user.email, password="wrongpassword")
@@ -113,7 +120,10 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_inactive_user( async def test_login_inactive_user(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession, self,
auth_service: AuthService,
test_user: User,
test_session: AsyncSession,
) -> None: ) -> None:
"""Test login with inactive user.""" """Test login with inactive user."""
# Store the email before deactivating # Store the email before deactivating
@@ -133,7 +143,10 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_user_without_password( async def test_login_user_without_password(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession, self,
auth_service: AuthService,
test_user: User,
test_session: AsyncSession,
) -> None: ) -> None:
"""Test login with user that has no password hash.""" """Test login with user that has no password hash."""
# Store the email before removing password # Store the email before removing password
@@ -153,7 +166,9 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_success( async def test_get_current_user_success(
self, auth_service: AuthService, test_user: User, self,
auth_service: AuthService,
test_user: User,
) -> None: ) -> None:
"""Test getting current user successfully.""" """Test getting current user successfully."""
user = await auth_service.get_current_user(test_user.id) user = await auth_service.get_current_user(test_user.id)
@@ -174,7 +189,10 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_inactive( async def test_get_current_user_inactive(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession, self,
auth_service: AuthService,
test_user: User,
test_session: AsyncSession,
) -> None: ) -> None:
"""Test getting current user when user is inactive.""" """Test getting current user when user is inactive."""
# Store the user ID before deactivating # Store the user ID before deactivating
@@ -192,7 +210,9 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_access_token( async def test_create_access_token(
self, auth_service: AuthService, test_user: User, self,
auth_service: AuthService,
test_user: User,
) -> None: ) -> None:
"""Test access token creation.""" """Test access token creation."""
token_response = auth_service._create_access_token(test_user) token_response = auth_service._create_access_token(test_user)
@@ -211,7 +231,10 @@ class TestAuthService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_user_response( async def test_create_user_response(
self, auth_service: AuthService, test_user: User, test_session: AsyncSession, self,
auth_service: AuthService,
test_user: User,
test_session: AsyncSession,
) -> None: ) -> None:
"""Test user response creation.""" """Test user response creation."""
# Ensure plan relationship is loaded # Ensure plan relationship is loaded

View File

@@ -52,7 +52,9 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_youtube(self, mock_ydl_class, extraction_service): async def test_detect_service_info_youtube(
self, mock_ydl_class, extraction_service
):
"""Test service detection for YouTube.""" """Test service detection for YouTube."""
mock_ydl = Mock() mock_ydl = Mock()
mock_ydl_class.return_value.__enter__.return_value = mock_ydl mock_ydl_class.return_value.__enter__.return_value = mock_ydl
@@ -75,7 +77,9 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_failure(self, mock_ydl_class, extraction_service): async def test_detect_service_info_failure(
self, mock_ydl_class, extraction_service
):
"""Test service detection failure.""" """Test service detection failure."""
mock_ydl = Mock() mock_ydl = Mock()
mock_ydl_class.return_value.__enter__.return_value = mock_ydl mock_ydl_class.return_value.__enter__.return_value = mock_ydl
@@ -201,8 +205,12 @@ class TestExtractionService:
extraction_service, "_detect_service_info", return_value=service_info extraction_service, "_detect_service_info", return_value=service_info
), ),
patch.object(extraction_service, "_extract_media") as mock_extract, patch.object(extraction_service, "_extract_media") as mock_extract,
patch.object(extraction_service, "_move_files_to_final_location") as mock_move, patch.object(
patch.object(extraction_service, "_create_sound_record") as mock_create_sound, extraction_service, "_move_files_to_final_location"
) as mock_move,
patch.object(
extraction_service, "_create_sound_record"
) as mock_create_sound,
patch.object(extraction_service, "_normalize_sound") as mock_normalize, patch.object(extraction_service, "_normalize_sound") as mock_normalize,
patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist, patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist,
): ):
@@ -288,7 +296,6 @@ class TestExtractionService:
"app.services.extraction.get_file_hash", return_value="test_hash" "app.services.extraction.get_file_hash", return_value="test_hash"
), ),
): ):
extraction_service.sound_repo.create = AsyncMock( extraction_service.sound_repo.create = AsyncMock(
return_value=mock_sound return_value=mock_sound
) )

View File

@@ -29,8 +29,9 @@ class TestExtractionProcessor:
async def test_start_and_stop(self, processor): async def test_start_and_stop(self, processor):
"""Test starting and stopping the processor.""" """Test starting and stopping the processor."""
# Mock the _process_queue method to avoid actual processing # Mock the _process_queue method to avoid actual processing
with patch.object(processor, "_process_queue", new_callable=AsyncMock) as mock_process: with patch.object(
processor, "_process_queue", new_callable=AsyncMock
) as mock_process:
# Start the processor # Start the processor
await processor.start() await processor.start()
assert processor.processor_task is not None assert processor.processor_task is not None
@@ -44,7 +45,6 @@ class TestExtractionProcessor:
async def test_start_already_running(self, processor): async def test_start_already_running(self, processor):
"""Test starting processor when already running.""" """Test starting processor when already running."""
with patch.object(processor, "_process_queue", new_callable=AsyncMock): with patch.object(processor, "_process_queue", new_callable=AsyncMock):
# Start first time # Start first time
await processor.start() await processor.start()
first_task = processor.processor_task first_task = processor.processor_task
@@ -150,7 +150,6 @@ class TestExtractionProcessor:
return_value=mock_service, return_value=mock_service,
), ),
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session_class.return_value.__aenter__.return_value = mock_session mock_session_class.return_value.__aenter__.return_value = mock_session
@@ -176,7 +175,6 @@ class TestExtractionProcessor:
return_value=mock_service, return_value=mock_service,
), ),
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session_class.return_value.__aenter__.return_value = mock_session mock_session_class.return_value.__aenter__.return_value = mock_session
@@ -207,7 +205,6 @@ class TestExtractionProcessor:
return_value=mock_service, return_value=mock_service,
), ),
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session_class.return_value.__aenter__.return_value = mock_session mock_session_class.return_value.__aenter__.return_value = mock_session
@@ -232,14 +229,15 @@ class TestExtractionProcessor:
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession"
) as mock_session_class, ) as mock_session_class,
patch.object(processor, "_process_single_extraction", new_callable=AsyncMock) as mock_process, patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock
) as mock_process,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
return_value=mock_service, return_value=mock_service,
), ),
patch("asyncio.create_task") as mock_create_task, patch("asyncio.create_task") as mock_create_task,
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session_class.return_value.__aenter__.return_value = mock_session mock_session_class.return_value.__aenter__.return_value = mock_session
@@ -276,14 +274,15 @@ class TestExtractionProcessor:
patch( patch(
"app.services.extraction_processor.AsyncSession" "app.services.extraction_processor.AsyncSession"
) as mock_session_class, ) as mock_session_class,
patch.object(processor, "_process_single_extraction", new_callable=AsyncMock) as mock_process, patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock
) as mock_process,
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
return_value=mock_service, return_value=mock_service,
), ),
patch("asyncio.create_task") as mock_create_task, patch("asyncio.create_task") as mock_create_task,
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_session_class.return_value.__aenter__.return_value = mock_session mock_session_class.return_value.__aenter__.return_value = mock_session

View File

@@ -0,0 +1,971 @@
"""Tests for playlist service."""
from collections.abc import AsyncGenerator
import pytest
import pytest_asyncio
from fastapi import HTTPException
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.playlist import Playlist
from app.models.sound import Sound
from app.models.user import User
from app.services.playlist import PlaylistService
class TestPlaylistService:
"""Test playlist service operations."""
@pytest_asyncio.fixture
async def playlist_service(
self,
test_session: AsyncSession,
) -> AsyncGenerator[PlaylistService, None]:
"""Create a playlist service instance."""
yield PlaylistService(test_session)
@pytest_asyncio.fixture
async def test_playlist(
self,
test_session: AsyncSession,
test_user: User,
) -> Playlist:
"""Create a test playlist."""
# Extract user_id from test_user within the fixture
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
return playlist
@pytest_asyncio.fixture
async def current_playlist(
self,
test_session: AsyncSession,
test_user: User,
) -> Playlist:
"""Create a current playlist."""
# Extract user_id from test_user within the fixture
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Current Playlist",
description="Currently active playlist",
is_main=False,
is_current=True,
is_deletable=True,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
return playlist
@pytest_asyncio.fixture
async def main_playlist(
self,
test_session: AsyncSession,
) -> Playlist:
"""Create a main playlist."""
playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
return playlist
@pytest_asyncio.fixture
async def test_sound(
self,
test_session: AsyncSession,
) -> Sound:
"""Create a test sound."""
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(sound)
return sound
@pytest_asyncio.fixture
async def other_user(
self,
test_session: AsyncSession,
ensure_plans,
) -> User:
"""Create another test user."""
from app.utils.auth import PasswordUtils
user = User(
email="other@example.com",
name="Other User",
password_hash=PasswordUtils.hash_password("password123"),
role="user",
is_active=True,
plan_id=ensure_plans[0].id,
credits=100,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
return user
@pytest.mark.asyncio
async def test_get_playlist_by_id_success(
self,
playlist_service: PlaylistService,
test_user: User,
test_playlist: Playlist,
) -> None:
"""Test getting playlist by ID successfully."""
assert test_playlist.id is not None
playlist = await playlist_service.get_playlist_by_id(test_playlist.id)
assert playlist.id == test_playlist.id
assert playlist.name == test_playlist.name
@pytest.mark.asyncio
async def test_get_playlist_by_id_not_found(
self,
playlist_service: PlaylistService,
test_user: User,
) -> None:
"""Test getting non-existent playlist."""
with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_playlist_by_id(99999)
assert exc_info.value.status_code == 404
assert "not found" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_main_playlist_existing(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test getting existing main playlist."""
# Create main playlist manually
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_deletable=False,
)
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(main_playlist)
playlist = await playlist_service.get_main_playlist()
assert playlist.id == main_playlist.id
assert playlist.is_main is True
@pytest.mark.asyncio
async def test_get_main_playlist_create_if_not_exists(
self,
playlist_service: PlaylistService,
test_user: User,
) -> None:
"""Test that service fails if main playlist doesn't exist."""
# Should raise an HTTPException if no main playlist exists
with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_main_playlist()
assert exc_info.value.status_code == 500
assert "Main playlist not found" in exc_info.value.detail
@pytest.mark.asyncio
async def test_create_playlist_success(
self,
playlist_service: PlaylistService,
test_user: User,
) -> None:
"""Test creating a new playlist successfully."""
user_id = test_user.id # Extract user_id while session is available
playlist = await playlist_service.create_playlist(
user_id=user_id,
name="New Playlist",
description="A new playlist",
genre="rock",
)
assert playlist.name == "New Playlist"
assert playlist.description == "A new playlist"
assert playlist.genre == "rock"
assert playlist.user_id == user_id
assert playlist.is_main is False
assert playlist.is_current is False
assert playlist.is_deletable is True
@pytest.mark.asyncio
async def test_create_playlist_duplicate_name(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test creating playlist with duplicate name fails."""
# Create test playlist within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
# Extract name before async call
playlist_name = playlist.name
with pytest.raises(HTTPException) as exc_info:
await playlist_service.create_playlist(
user_id=user_id,
name=playlist_name, # Same name as existing playlist
)
assert exc_info.value.status_code == 400
assert "already exists" in exc_info.value.detail
@pytest.mark.asyncio
async def test_create_playlist_as_current(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test creating a playlist as current unsets previous current."""
# Create current playlist within this test
user_id = test_user.id
current_playlist = Playlist(
user_id=user_id,
name="Current Playlist",
description="Currently active playlist",
is_main=False,
is_current=True,
is_deletable=True,
)
test_session.add(current_playlist)
await test_session.commit()
await test_session.refresh(current_playlist)
# Verify the existing current playlist
assert current_playlist.is_current is True
# Extract ID before async call
current_playlist_id = current_playlist.id
# Create new playlist as current
new_playlist = await playlist_service.create_playlist(
user_id=user_id,
name="New Current Playlist",
is_current=True,
)
assert new_playlist.is_current is True
# Verify the old current playlist is no longer current
# We need to refresh the old playlist from the database
old_playlist = await playlist_service.get_playlist_by_id(current_playlist_id)
assert old_playlist.is_current is False
@pytest.mark.asyncio
async def test_update_playlist_success(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test updating a playlist successfully."""
# Create test playlist within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
# Extract IDs before async call
playlist_id = playlist.id
updated_playlist = await playlist_service.update_playlist(
playlist_id=playlist_id,
user_id=user_id,
name="Updated Name",
description="Updated description",
genre="jazz",
)
assert updated_playlist.name == "Updated Name"
assert updated_playlist.description == "Updated description"
assert updated_playlist.genre == "jazz"
@pytest.mark.asyncio
async def test_update_playlist_set_current(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test setting a playlist as current via update."""
# Create test playlist within this test
user_id = test_user.id
test_playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(test_playlist)
current_playlist = Playlist(
user_id=user_id,
name="Current Playlist",
description="Currently active playlist",
is_main=False,
is_current=True,
is_deletable=True,
)
test_session.add(current_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(current_playlist)
# Extract IDs before async calls
test_playlist_id = test_playlist.id
current_playlist_id = current_playlist.id
# Verify initial state
assert test_playlist.is_current is False
assert current_playlist.is_current is True
# Update playlist to be current
updated_playlist = await playlist_service.update_playlist(
playlist_id=test_playlist_id,
user_id=user_id,
is_current=True,
)
assert updated_playlist.is_current is True
# Verify old current playlist is no longer current
old_current = await playlist_service.get_playlist_by_id(current_playlist_id)
assert old_current.is_current is False
@pytest.mark.asyncio
async def test_delete_playlist_success(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test deleting a playlist successfully."""
# Create test playlist within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
await test_session.commit()
await test_session.refresh(playlist)
# Extract ID before async call
playlist_id = playlist.id
await playlist_service.delete_playlist(playlist_id, user_id)
# Verify playlist is deleted
with pytest.raises(HTTPException) as exc_info:
await playlist_service.get_playlist_by_id(playlist_id)
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_delete_current_playlist_sets_main_as_current(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test deleting current playlist sets main as current."""
# Create main playlist first (required by service)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(main_playlist)
# Create current playlist within this test
user_id = test_user.id
current_playlist = Playlist(
user_id=user_id,
name="Current Playlist",
description="Currently active playlist",
is_main=False,
is_current=True,
is_deletable=True,
)
test_session.add(current_playlist)
await test_session.commit()
await test_session.refresh(current_playlist)
# Extract ID before async call
current_playlist_id = current_playlist.id
# Delete the current playlist
await playlist_service.delete_playlist(current_playlist_id, user_id)
# Verify main playlist is now fallback current (main playlist doesn't have is_current=True)
# The service returns main playlist when no current is set
current = await playlist_service.get_current_playlist(user_id)
assert current.is_main is True
@pytest.mark.asyncio
async def test_delete_non_deletable_playlist(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test deleting non-deletable playlist fails."""
# Extract user ID immediately
user_id = test_user.id
# Create non-deletable playlist
non_deletable = Playlist(
user_id=user_id,
name="Non-deletable",
is_deletable=False,
)
test_session.add(non_deletable)
await test_session.commit()
await test_session.refresh(non_deletable)
# Extract ID before async call
non_deletable_id = non_deletable.id
with pytest.raises(HTTPException) as exc_info:
await playlist_service.delete_playlist(non_deletable_id, user_id)
assert exc_info.value.status_code == 400
assert "cannot be deleted" in exc_info.value.detail
@pytest.mark.asyncio
async def test_add_sound_to_playlist_success(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test adding sound to playlist successfully."""
# Create test playlist and sound within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
await playlist_service.add_sound_to_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=user_id,
)
# Verify sound was added
sounds = await playlist_service.get_playlist_sounds(playlist_id)
assert len(sounds) == 1
assert sounds[0].id == sound_id
@pytest.mark.asyncio
async def test_add_sound_to_playlist_already_exists(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test adding sound that's already in playlist fails."""
# Create test playlist and sound within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Add sound first time
await playlist_service.add_sound_to_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=user_id,
)
# Try to add same sound again
with pytest.raises(HTTPException) as exc_info:
await playlist_service.add_sound_to_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=user_id,
)
assert exc_info.value.status_code == 400
assert "already in this playlist" in exc_info.value.detail
@pytest.mark.asyncio
async def test_remove_sound_from_playlist_success(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test removing sound from playlist successfully."""
# Create test playlist and sound within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Add sound first
await playlist_service.add_sound_to_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=user_id,
)
# Remove sound
await playlist_service.remove_sound_from_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=user_id,
)
# Verify sound was removed
sounds = await playlist_service.get_playlist_sounds(playlist_id)
assert len(sounds) == 0
@pytest.mark.asyncio
async def test_remove_sound_not_in_playlist(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test removing sound that's not in playlist fails."""
# Create test playlist and sound within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
with pytest.raises(HTTPException) as exc_info:
await playlist_service.remove_sound_from_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=user_id,
)
assert exc_info.value.status_code == 404
assert "not found in this playlist" in exc_info.value.detail
@pytest.mark.asyncio
async def test_set_current_playlist(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test setting a playlist as current."""
# Create test playlists within this test
user_id = test_user.id
test_playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(test_playlist)
current_playlist = Playlist(
user_id=user_id,
name="Current Playlist",
description="Currently active playlist",
is_main=False,
is_current=True,
is_deletable=True,
)
test_session.add(current_playlist)
await test_session.commit()
await test_session.refresh(test_playlist)
await test_session.refresh(current_playlist)
# Extract IDs before async calls
test_playlist_id = test_playlist.id
current_playlist_id = current_playlist.id
# Verify initial state
assert current_playlist.is_current is True
assert test_playlist.is_current is False
# Set test_playlist as current
updated_playlist = await playlist_service.set_current_playlist(
test_playlist_id, user_id
)
assert updated_playlist.is_current is True
# Verify old current is no longer current
old_current = await playlist_service.get_playlist_by_id(current_playlist_id)
assert old_current.is_current is False
@pytest.mark.asyncio
async def test_unset_current_playlist_sets_main_as_current(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test unsetting current playlist falls back to main playlist."""
# Create test playlists within this test
user_id = test_user.id
current_playlist = Playlist(
user_id=user_id,
name="Current Playlist",
description="Currently active playlist",
is_main=False,
is_current=True,
is_deletable=True,
)
test_session.add(current_playlist)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(current_playlist)
await test_session.refresh(main_playlist)
# Extract IDs before async calls
current_playlist_id = current_playlist.id
main_playlist_id = main_playlist.id
# Verify initial state
assert current_playlist.is_current is True
# Unset current playlist
await playlist_service.unset_current_playlist(user_id)
# Verify get_current_playlist returns main playlist as fallback
current = await playlist_service.get_current_playlist(user_id)
assert current.id == main_playlist_id
assert current.is_main is True
# Verify old current is no longer current
old_current = await playlist_service.get_playlist_by_id(current_playlist_id)
assert old_current.is_current is False
@pytest.mark.asyncio
async def test_get_playlist_stats(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test getting playlist statistics."""
# Create test playlist and sound within this test
user_id = test_user.id
playlist = Playlist(
user_id=user_id,
name="Test Playlist",
description="A test playlist",
genre="test",
is_main=False,
is_current=False,
is_deletable=True,
)
test_session.add(playlist)
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
await test_session.commit()
await test_session.refresh(playlist)
await test_session.refresh(sound)
# Extract IDs before async calls
playlist_id = playlist.id
sound_id = sound.id
# Initially empty playlist
stats = await playlist_service.get_playlist_stats(playlist_id)
assert stats["sound_count"] == 0
assert stats["total_duration_ms"] == 0
assert stats["total_play_count"] == 0
# Add sound to playlist
await playlist_service.add_sound_to_playlist(
playlist_id=playlist_id,
sound_id=sound_id,
user_id=user_id,
)
# Check stats again
stats = await playlist_service.get_playlist_stats(playlist_id)
assert stats["sound_count"] == 1
assert stats["total_duration_ms"] == 5000 # From test_sound fixture
assert stats["total_play_count"] == 10 # From test_sound fixture
@pytest.mark.asyncio
async def test_add_sound_to_main_playlist(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test adding sound to main playlist."""
# Create test sound and main playlist within this test
user_id = test_user.id
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(sound)
await test_session.refresh(main_playlist)
# Extract IDs before async calls
sound_id = sound.id
main_playlist_id = main_playlist.id
# Add sound to main playlist
await playlist_service.add_sound_to_main_playlist(sound_id, user_id)
# Verify sound was added to main playlist
sounds = await playlist_service.get_playlist_sounds(main_playlist_id)
assert len(sounds) == 1
assert sounds[0].id == sound_id
@pytest.mark.asyncio
async def test_add_sound_to_main_playlist_already_exists(
self,
playlist_service: PlaylistService,
test_user: User,
test_session: AsyncSession,
) -> None:
"""Test adding sound to main playlist when it already exists (should not duplicate)."""
# Create test sound and main playlist within this test
user_id = test_user.id
sound = Sound(
name="Test Sound",
filename="test.mp3",
type="SDB",
duration=5000,
size=1024,
hash="test_hash",
play_count=10,
)
test_session.add(sound)
main_playlist = Playlist(
user_id=None,
name="Main Playlist",
description="Main playlist",
is_main=True,
is_current=False,
is_deletable=False,
)
test_session.add(main_playlist)
await test_session.commit()
await test_session.refresh(sound)
await test_session.refresh(main_playlist)
# Extract IDs before async calls
sound_id = sound.id
main_playlist_id = main_playlist.id
# Add sound to main playlist twice
await playlist_service.add_sound_to_main_playlist(sound_id, user_id)
await playlist_service.add_sound_to_main_playlist(sound_id, user_id)
# Verify sound is only added once
sounds = await playlist_service.get_playlist_sounds(main_playlist_id)
assert len(sounds) == 1
assert sounds[0].id == sound_id

View File

@@ -97,7 +97,9 @@ class TestSocketManager:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_success(self, mock_decode, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_success(
self, mock_decode, mock_extract_token, socket_manager, mock_sio
):
"""Test successful connection with valid token.""" """Test successful connection with valid token."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = "valid_token" mock_extract_token.return_value = "valid_token"
@@ -130,7 +132,9 @@ class TestSocketManager:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
async def test_connect_handler_no_token(self, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_no_token(
self, mock_extract_token, socket_manager, mock_sio
):
"""Test connection with no access token.""" """Test connection with no access token."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = None mock_extract_token.return_value = None
@@ -162,7 +166,9 @@ class TestSocketManager:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_invalid_token(self, mock_decode, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_invalid_token(
self, mock_decode, mock_extract_token, socket_manager, mock_sio
):
"""Test connection with invalid token.""" """Test connection with invalid token."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = "invalid_token" mock_extract_token.return_value = "invalid_token"
@@ -195,7 +201,9 @@ class TestSocketManager:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_missing_user_id(self, mock_decode, mock_extract_token, socket_manager, mock_sio): async def test_connect_handler_missing_user_id(
self, mock_decode, mock_extract_token, socket_manager, mock_sio
):
"""Test connection with token missing user ID.""" """Test connection with token missing user ID."""
# Setup mocks # Setup mocks
mock_extract_token.return_value = "token_without_user_id" mock_extract_token.return_value = "token_without_user_id"

View File

@@ -182,7 +182,6 @@ class TestSoundNormalizerService:
"app.services.sound_normalizer.get_file_hash", return_value="new_hash" "app.services.sound_normalizer.get_file_hash", return_value="new_hash"
), ),
): ):
# Setup path mocks # Setup path mocks
mock_orig_path.return_value = Path("/fake/original.mp3") mock_orig_path.return_value = Path("/fake/original.mp3")
mock_norm_path.return_value = Path("/fake/normalized.mp3") mock_norm_path.return_value = Path("/fake/normalized.mp3")
@@ -256,7 +255,6 @@ class TestSoundNormalizerService:
"app.services.sound_normalizer.get_file_hash", return_value="norm_hash" "app.services.sound_normalizer.get_file_hash", return_value="norm_hash"
), ),
): ):
# Setup path mocks # Setup path mocks
mock_orig_path.return_value = Path("/fake/original.mp3") mock_orig_path.return_value = Path("/fake/original.mp3")
mock_norm_path.return_value = Path("/fake/normalized.mp3") mock_norm_path.return_value = Path("/fake/normalized.mp3")
@@ -294,7 +292,6 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path, patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path, patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
): ):
# Setup path mocks # Setup path mocks
mock_orig_path.return_value = Path("/fake/original.mp3") mock_orig_path.return_value = Path("/fake/original.mp3")
mock_norm_path.return_value = Path("/fake/normalized.mp3") mock_norm_path.return_value = Path("/fake/normalized.mp3")
@@ -306,7 +303,6 @@ class TestSoundNormalizerService:
normalizer_service, "_normalize_audio_two_pass" normalizer_service, "_normalize_audio_two_pass"
) as mock_normalize, ) as mock_normalize,
): ):
mock_normalize.side_effect = Exception("Normalization failed") mock_normalize.side_effect = Exception("Normalization failed")
result = await normalizer_service.normalize_sound(sound) result = await normalizer_service.normalize_sound(sound)

View File

@@ -41,6 +41,7 @@ class TestSoundScannerService:
try: try:
from app.utils.audio import get_file_hash from app.utils.audio import get_file_hash
hash_value = get_file_hash(temp_path) hash_value = get_file_hash(temp_path)
assert len(hash_value) == 64 # SHA-256 hash length assert len(hash_value) == 64 # SHA-256 hash length
assert isinstance(hash_value, str) assert isinstance(hash_value, str)
@@ -56,6 +57,7 @@ class TestSoundScannerService:
try: try:
from app.utils.audio import get_file_size from app.utils.audio import get_file_size
size = get_file_size(temp_path) size = get_file_size(temp_path)
assert size > 0 assert size > 0
assert isinstance(size, int) assert isinstance(size, int)
@@ -83,6 +85,7 @@ class TestSoundScannerService:
temp_path = Path("/fake/path/test.mp3") temp_path = Path("/fake/path/test.mp3")
from app.utils.audio import get_audio_duration from app.utils.audio import get_audio_duration
duration = get_audio_duration(temp_path) duration = get_audio_duration(temp_path)
assert duration == 123456 # 123.456 seconds * 1000 = 123456 ms assert duration == 123456 # 123.456 seconds * 1000 = 123456 ms
@@ -95,6 +98,7 @@ class TestSoundScannerService:
temp_path = Path("/fake/path/test.mp3") temp_path = Path("/fake/path/test.mp3")
from app.utils.audio import get_audio_duration from app.utils.audio import get_audio_duration
duration = get_audio_duration(temp_path) duration = get_audio_duration(temp_path)
assert duration == 0 assert duration == 0
@@ -129,10 +133,11 @@ class TestSoundScannerService:
) )
# Mock file operations to return same hash # Mock file operations to return same hash
with patch("app.services.sound_scanner.get_file_hash", return_value="same_hash"), \ with (
patch("app.services.sound_scanner.get_audio_duration", return_value=120000), \ patch("app.services.sound_scanner.get_file_hash", return_value="same_hash"),
patch("app.services.sound_scanner.get_file_size", return_value=1024): patch("app.services.sound_scanner.get_audio_duration", return_value=120000),
patch("app.services.sound_scanner.get_file_size", return_value=1024),
):
# Create a temporary file # Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
temp_path = Path(f.name) temp_path = Path(f.name)
@@ -175,10 +180,11 @@ class TestSoundScannerService:
scanner_service.sound_repo.create = AsyncMock(return_value=created_sound) scanner_service.sound_repo.create = AsyncMock(return_value=created_sound)
# Mock file operations # Mock file operations
with patch("app.services.sound_scanner.get_file_hash", return_value="test_hash"), \ with (
patch("app.services.sound_scanner.get_audio_duration", return_value=120000), \ patch("app.services.sound_scanner.get_file_hash", return_value="test_hash"),
patch("app.services.sound_scanner.get_file_size", return_value=1024): patch("app.services.sound_scanner.get_audio_duration", return_value=120000),
patch("app.services.sound_scanner.get_file_size", return_value=1024),
):
# Create a temporary file # Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
temp_path = Path(f.name) temp_path = Path(f.name)
@@ -208,7 +214,9 @@ class TestSoundScannerService:
assert call_args["duration"] == 120000 # Duration in ms assert call_args["duration"] == 120000 # Duration in ms
assert call_args["size"] == 1024 assert call_args["size"] == 1024
assert call_args["hash"] == "test_hash" assert call_args["hash"] == "test_hash"
assert call_args["is_deletable"] is False # SDB sounds are not deletable assert (
call_args["is_deletable"] is False
) # SDB sounds are not deletable
finally: finally:
temp_path.unlink() temp_path.unlink()
@@ -229,10 +237,11 @@ class TestSoundScannerService:
scanner_service.sound_repo.update = AsyncMock(return_value=existing_sound) scanner_service.sound_repo.update = AsyncMock(return_value=existing_sound)
# Mock file operations to return new values # Mock file operations to return new values
with patch("app.services.sound_scanner.get_file_hash", return_value="new_hash"), \ with (
patch("app.services.sound_scanner.get_audio_duration", return_value=120000), \ patch("app.services.sound_scanner.get_file_hash", return_value="new_hash"),
patch("app.services.sound_scanner.get_file_size", return_value=1024): patch("app.services.sound_scanner.get_audio_duration", return_value=120000),
patch("app.services.sound_scanner.get_file_size", return_value=1024),
):
# Create a temporary file # Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f: with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
temp_path = Path(f.name) temp_path = Path(f.name)
@@ -259,7 +268,9 @@ class TestSoundScannerService:
assert results["files"][0]["reason"] == "file was modified" assert results["files"][0]["reason"] == "file was modified"
# Verify sound_repo.update was called with correct data # Verify sound_repo.update was called with correct data
call_args = scanner_service.sound_repo.update.call_args[0][1] # update_data call_args = scanner_service.sound_repo.update.call_args[0][
1
] # update_data
assert call_args["duration"] == 120000 assert call_args["duration"] == 120000
assert call_args["size"] == 1024 assert call_args["size"] == 1024
assert call_args["hash"] == "new_hash" assert call_args["hash"] == "new_hash"
@@ -283,10 +294,13 @@ class TestSoundScannerService:
scanner_service.sound_repo.create = AsyncMock(return_value=created_sound) scanner_service.sound_repo.create = AsyncMock(return_value=created_sound)
# Mock file operations # Mock file operations
with patch("app.services.sound_scanner.get_file_hash", return_value="custom_hash"), \ with (
patch("app.services.sound_scanner.get_audio_duration", return_value=60000), \ patch(
patch("app.services.sound_scanner.get_file_size", return_value=2048): "app.services.sound_scanner.get_file_hash", return_value="custom_hash"
),
patch("app.services.sound_scanner.get_audio_duration", return_value=60000),
patch("app.services.sound_scanner.get_file_size", return_value=2048),
):
# Create a temporary file # Create a temporary file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
temp_path = Path(f.name) temp_path = Path(f.name)
@@ -301,7 +315,9 @@ class TestSoundScannerService:
"errors": 0, "errors": 0,
"files": [], "files": [],
} }
await scanner_service._sync_audio_file(temp_path, "CUSTOM", None, results) await scanner_service._sync_audio_file(
temp_path, "CUSTOM", None, results
)
assert results["added"] == 1 assert results["added"] == 1
assert results["skipped"] == 0 assert results["skipped"] == 0

View File

@@ -39,7 +39,7 @@ class TestAudioUtils:
def test_get_file_hash_binary_content(self): def test_get_file_hash_binary_content(self):
"""Test file hash calculation with binary content.""" """Test file hash calculation with binary content."""
# Create a temporary file with binary content # Create a temporary file with binary content
test_bytes = b"\x00\x01\x02\x03\xFF\xFE\xFD" test_bytes = b"\x00\x01\x02\x03\xff\xfe\xfd"
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f: with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
f.write(test_bytes) f.write(test_bytes)
temp_path = Path(f.name) temp_path = Path(f.name)
@@ -146,7 +146,7 @@ class TestAudioUtils:
def test_get_file_size_binary_file(self): def test_get_file_size_binary_file(self):
"""Test file size calculation for binary file.""" """Test file size calculation for binary file."""
# Create a temporary file with binary content # Create a temporary file with binary content
test_bytes = b"\x00\x01\x02\x03\xFF\xFE\xFD" * 100 # 700 bytes test_bytes = b"\x00\x01\x02\x03\xff\xfe\xfd" * 100 # 700 bytes
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f: with tempfile.NamedTemporaryFile(mode="wb", delete=False) as f:
f.write(test_bytes) f.write(test_bytes)
temp_path = Path(f.name) temp_path = Path(f.name)

View File

@@ -1,6 +1,5 @@
"""Tests for cookie utilities.""" """Tests for cookie utilities."""
from app.utils.cookies import extract_access_token_from_cookies, parse_cookies from app.utils.cookies import extract_access_token_from_cookies, parse_cookies