Refactor test cases for improved readability and consistency
All checks were successful
Backend CI / lint (push) Successful in 9m49s
Backend CI / test (push) Successful in 6m15s

- Adjusted function signatures in various test files to enhance clarity by aligning parameters.
- Updated patching syntax for better readability across test cases.
- Improved formatting and spacing in test assertions and mock setups.
- Ensured consistent use of async/await patterns in async test functions.
- Enhanced comments for better understanding of test intentions.
This commit is contained in:
JSC
2025-08-01 20:53:30 +02:00
parent d926779fe4
commit 6068599a47
39 changed files with 691 additions and 286 deletions

View File

@@ -32,7 +32,7 @@ async def get_sound_normalizer_service(
# SCAN ENDPOINTS
@router.post("/scan")
async def scan_sounds(
current_user: Annotated[User, Depends(get_admin_user)],
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
) -> dict[str, ScanResults | str]:
"""Sync the soundboard directory (add/update/delete sounds). Admin only."""
@@ -53,11 +53,11 @@ async def scan_sounds(
@router.post("/scan/custom")
async def scan_custom_directory(
directory: str,
current_user: Annotated[User, Depends(get_admin_user)],
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
sound_type: str = "SDB",
) -> dict[str, ScanResults | str]:
"""Sync a custom directory with the database (add/update/delete sounds). Admin only."""
"""Sync a custom directory with the database. Admin only."""
try:
results = await scanner_service.scan_directory(directory, sound_type)
except ValueError as e:
@@ -80,14 +80,15 @@ async def scan_custom_directory(
# NORMALIZE ENDPOINTS
@router.post("/normalize/all")
async def normalize_all_sounds(
current_user: Annotated[User, Depends(get_admin_user)],
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[
SoundNormalizerService,
Depends(get_sound_normalizer_service),
],
*,
force: Annotated[
bool,
Query( # noqa: FBT002
Query(
description="Force normalization of already normalized sounds",
),
] = False,
@@ -119,14 +120,15 @@ async def normalize_all_sounds(
@router.post("/normalize/type/{sound_type}")
async def normalize_sounds_by_type(
sound_type: str,
current_user: Annotated[User, Depends(get_admin_user)],
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[
SoundNormalizerService,
Depends(get_sound_normalizer_service),
],
*,
force: Annotated[
bool,
Query( # noqa: FBT002
Query(
description="Force normalization of already normalized sounds",
),
] = False,
@@ -167,14 +169,15 @@ async def normalize_sounds_by_type(
@router.post("/normalize/{sound_id}")
async def normalize_sound_by_id(
sound_id: int,
current_user: Annotated[User, Depends(get_admin_user)],
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[
SoundNormalizerService,
Depends(get_sound_normalizer_service),
],
*,
force: Annotated[
bool,
Query( # noqa: FBT002
Query(
description="Force normalization of already normalized sound",
),
] = False,

View File

@@ -2,7 +2,7 @@
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
@@ -110,7 +110,7 @@ async def update_playlist(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
playlist = await playlist_service.update_playlist(
playlist_id=playlist_id,
user_id=current_user.id,

View File

@@ -2,7 +2,7 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db, get_session_factory
@@ -18,7 +18,6 @@ from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
router = APIRouter(prefix="/sounds", tags=["sounds"])
async def get_extraction_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
@@ -43,7 +42,6 @@ async def get_sound_repository(
return SoundRepository(session)
# EXTRACT
@router.post("/extract")
async def create_extraction(
@@ -60,7 +58,8 @@ async def create_extraction(
)
extraction_info = await extraction_service.create_extraction(
url, current_user.id,
url,
current_user.id,
)
# Queue the extraction for background processing
@@ -83,8 +82,6 @@ async def create_extraction(
}
@router.get("/extract/{extraction_id}")
async def get_extraction(
extraction_id: int,
@@ -206,7 +203,6 @@ async def play_sound_with_vlc(
}
@router.post("/stop")
async def stop_all_vlc_instances(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001

View File

@@ -40,8 +40,10 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
def get_session_factory() -> Callable[[], AsyncSession]:
"""Get a session factory function for services."""
def session_factory() -> AsyncSession:
return AsyncSession(engine)
return session_factory

View File

@@ -30,9 +30,7 @@ class Sound(BaseModel, table=True):
is_deletable: bool = Field(default=True, nullable=False)
# constraints
__table_args__ = (
UniqueConstraint("hash", name="uq_sound_hash"),
)
__table_args__ = (UniqueConstraint("hash", name="uq_sound_hash"),)
# relationships
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")

View File

@@ -43,7 +43,9 @@ class BaseRepository[ModelType]:
return result.first()
except Exception:
logger.exception(
"Failed to get %s by ID: %s", self.model.__name__, entity_id,
"Failed to get %s by ID: %s",
self.model.__name__,
entity_id,
)
raise

View File

@@ -91,8 +91,7 @@ class CreditTransactionRepository(BaseRepository[CreditTransaction]):
"""
stmt = (
select(CreditTransaction)
.where(CreditTransaction.success == True) # noqa: E712
select(CreditTransaction).where(CreditTransaction.success == True) # noqa: E712
)
if user_id is not None:

View File

@@ -1,6 +1,5 @@
"""Extraction repository for database operations."""
from sqlalchemy import desc
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -17,12 +16,15 @@ class ExtractionRepository(BaseRepository[Extraction]):
super().__init__(Extraction, session)
async def get_by_service_and_id(
self, service: str, service_id: str,
self,
service: str,
service_id: str,
) -> Extraction | None:
"""Get an extraction by service and service_id."""
result = await self.session.exec(
select(Extraction).where(
Extraction.service == service, Extraction.service_id == service_id,
Extraction.service == service,
Extraction.service_id == service_id,
),
)
return result.first()

View File

@@ -1,6 +1,5 @@
"""Playlist repository for database operations."""
from sqlalchemy import func
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -66,7 +65,9 @@ class PlaylistRepository(BaseRepository[Playlist]):
raise
async def search_by_name(
self, query: str, user_id: int | None = None,
self,
query: str,
user_id: int | None = None,
) -> list[Playlist]:
"""Search playlists by name (case-insensitive)."""
try:
@@ -98,7 +99,10 @@ class PlaylistRepository(BaseRepository[Playlist]):
raise
async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, position: int | None = None,
self,
playlist_id: int,
sound_id: int,
position: int | None = None,
) -> PlaylistSound:
"""Add a sound to a playlist."""
try:
@@ -121,7 +125,9 @@ class PlaylistRepository(BaseRepository[Playlist]):
except Exception:
await self.session.rollback()
logger.exception(
"Failed to add sound %s to playlist %s", sound_id, playlist_id,
"Failed to add sound %s to playlist %s",
sound_id,
playlist_id,
)
raise
else:
@@ -150,12 +156,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
except Exception:
await self.session.rollback()
logger.exception(
"Failed to remove sound %s from playlist %s", sound_id, playlist_id,
"Failed to remove sound %s from playlist %s",
sound_id,
playlist_id,
)
raise
async def reorder_playlist_sounds(
self, playlist_id: int, sound_positions: list[tuple[int, int]],
self,
playlist_id: int,
sound_positions: list[tuple[int, int]],
) -> None:
"""Reorder sounds in a playlist.
@@ -220,6 +230,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
return result.first() is not None
except Exception:
logger.exception(
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id,
"Failed to check if sound %s is in playlist %s",
sound_id,
playlist_id,
)
raise

View File

@@ -91,6 +91,7 @@ class SoundRepository(BaseRepository[Sound]):
return list(result.all())
except Exception:
logger.exception(
"Failed to get unnormalized sounds by type: %s", sound_type,
"Failed to get unnormalized sounds by type: %s",
sound_type,
)
raise

View File

@@ -1,6 +1,5 @@
"""Repository for user OAuth operations."""
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -60,4 +59,3 @@ class UserOauthRepository(BaseRepository[UserOauth]):
raise
else:
return result.first()

View File

@@ -30,17 +30,21 @@ class PlayerStateResponse(BaseModel):
status: str = Field(description="Player status (playing, paused, stopped)")
current_sound: dict[str, Any] | None = Field(
None, description="Current sound information",
None,
description="Current sound information",
)
playlist: dict[str, Any] | None = Field(
None, description="Current playlist information",
None,
description="Current playlist information",
)
position: int = Field(description="Current position in milliseconds")
duration: int | None = Field(
None, description="Total duration in milliseconds",
None,
description="Total duration in milliseconds",
)
volume: int = Field(description="Current volume (0-100)")
mode: str = Field(description="Current playback mode")
index: int | None = Field(
None, description="Current track index in playlist",
None,
description="Current track index in playlist",
)

View File

@@ -156,7 +156,8 @@ class ExtractionService:
# Check if extraction already exists for this service
existing = await self.extraction_repo.get_by_service_and_id(
service_info["service"], service_info["service_id"],
service_info["service"],
service_info["service_id"],
)
if existing and existing.id != extraction_id:
error_msg = (
@@ -181,7 +182,8 @@ class ExtractionService:
# Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media(
extraction_id, extraction_url,
extraction_id,
extraction_url,
)
# Move files to final locations
@@ -227,7 +229,9 @@ class ExtractionService:
except Exception as e:
error_msg = str(e)
logger.exception(
"Failed to process extraction %d: %s", extraction_id, error_msg,
"Failed to process extraction %d: %s",
extraction_id,
error_msg,
)
else:
return {
@@ -262,7 +266,9 @@ class ExtractionService:
}
async def _extract_media(
self, extraction_id: int, extraction_url: str,
self,
extraction_id: int,
extraction_url: str,
) -> tuple[Path, Path | None]:
"""Extract audio and thumbnail using yt-dlp."""
temp_dir = Path(settings.EXTRACTION_TEMP_DIR)

View File

@@ -65,7 +65,8 @@ class ExtractionProcessor:
# The processor will pick it up on the next cycle
else:
logger.warning(
"Extraction %d is already being processed", extraction_id,
"Extraction %d is already being processed",
extraction_id,
)
async def _process_queue(self) -> None:

View File

@@ -35,10 +35,11 @@ async def _is_current_playlist(session: AsyncSession, playlist_id: int) -> bool:
playlist_repo = PlaylistRepository(session)
current_playlist = await playlist_repo.get_current_playlist()
return current_playlist is not None and current_playlist.id == playlist_id
except Exception: # noqa: BLE001
logger.warning("Failed to check if playlist is current", exc_info=True)
return False
else:
return current_playlist is not None and current_playlist.id == playlist_id
class PlaylistService:
@@ -199,7 +200,7 @@ class PlaylistService:
await self.playlist_repo.delete(playlist)
logger.info("Deleted playlist %s for user %s", playlist_id, user_id)
# If the deleted playlist was current, reload player to use main playlist fallback
# If the deleted playlist was current, reload player to use main fallback
if was_current:
await _reload_player_playlist()

View File

@@ -140,7 +140,10 @@ class SoundNormalizerService:
stream = ffmpeg.overwrite_output(stream)
await asyncio.to_thread(
ffmpeg.run, stream, quiet=True, overwrite_output=True,
ffmpeg.run,
stream,
quiet=True,
overwrite_output=True,
)
logger.info("One-pass normalization completed: %s", output_path)
@@ -180,7 +183,10 @@ class SoundNormalizerService:
# Run first pass and capture output
try:
result = await asyncio.to_thread(
ffmpeg.run, stream, capture_stderr=True, quiet=True,
ffmpeg.run,
stream,
capture_stderr=True,
quiet=True,
)
analysis_output = result[1].decode("utf-8")
except ffmpeg.Error as e:
@@ -262,7 +268,10 @@ class SoundNormalizerService:
try:
await asyncio.to_thread(
ffmpeg.run, stream, quiet=True, overwrite_output=True,
ffmpeg.run,
stream,
quiet=True,
overwrite_output=True,
)
logger.info("Two-pass normalization completed: %s", output_path)
except ffmpeg.Error as e:

View File

@@ -40,6 +40,7 @@ def requires_credits(
return True
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
@@ -70,7 +71,8 @@ def requires_credits(
# Validate credits before execution
await credit_service.validate_and_reserve_credits(
user_id, action_type,
user_id,
action_type,
)
# Execute the function
@@ -86,10 +88,14 @@ def requires_credits(
finally:
# Deduct credits based on success
await credit_service.deduct_credits(
user_id, action_type, success=success, metadata=metadata,
user_id,
action_type,
success=success,
metadata=metadata,
)
return wrapper # type: ignore[return-value]
return decorator
@@ -111,6 +117,7 @@ def validate_credits_only(
Decorated function that validates credits only
"""
def decorator(func: F) -> F:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
@@ -141,6 +148,7 @@ def validate_credits_only(
return await func(*args, **kwargs)
return wrapper # type: ignore[return-value]
return decorator
@@ -173,7 +181,8 @@ class CreditManager:
async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type,
self.user_id,
self.action_type,
)
self.validated = True
return self
@@ -189,7 +198,10 @@ class CreditManager:
# If no exception occurred, consider it successful
success = exc_type is None and self.success
await self.credit_service.deduct_credits(
self.user_id, self.action_type, success=success, metadata=self.metadata,
self.user_id,
self.action_type,
success=success,
metadata=self.metadata,
)
def mark_success(self) -> None: