Refactor test cases for improved readability and consistency
- Adjusted function signatures in various test files to enhance clarity by aligning parameters. - Updated patching syntax for better readability across test cases. - Improved formatting and spacing in test assertions and mock setups. - Ensured consistent use of async/await patterns in async test functions. - Enhanced comments for better understanding of test intentions.
This commit is contained in:
@@ -32,7 +32,7 @@ async def get_sound_normalizer_service(
|
||||
# SCAN ENDPOINTS
|
||||
@router.post("/scan")
|
||||
async def scan_sounds(
|
||||
current_user: Annotated[User, Depends(get_admin_user)],
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
|
||||
) -> dict[str, ScanResults | str]:
|
||||
"""Sync the soundboard directory (add/update/delete sounds). Admin only."""
|
||||
@@ -53,11 +53,11 @@ async def scan_sounds(
|
||||
@router.post("/scan/custom")
|
||||
async def scan_custom_directory(
|
||||
directory: str,
|
||||
current_user: Annotated[User, Depends(get_admin_user)],
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
|
||||
sound_type: str = "SDB",
|
||||
) -> dict[str, ScanResults | str]:
|
||||
"""Sync a custom directory with the database (add/update/delete sounds). Admin only."""
|
||||
"""Sync a custom directory with the database. Admin only."""
|
||||
try:
|
||||
results = await scanner_service.scan_directory(directory, sound_type)
|
||||
except ValueError as e:
|
||||
@@ -80,14 +80,15 @@ async def scan_custom_directory(
|
||||
# NORMALIZE ENDPOINTS
|
||||
@router.post("/normalize/all")
|
||||
async def normalize_all_sounds(
|
||||
current_user: Annotated[User, Depends(get_admin_user)],
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService,
|
||||
Depends(get_sound_normalizer_service),
|
||||
],
|
||||
*,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query( # noqa: FBT002
|
||||
Query(
|
||||
description="Force normalization of already normalized sounds",
|
||||
),
|
||||
] = False,
|
||||
@@ -119,14 +120,15 @@ async def normalize_all_sounds(
|
||||
@router.post("/normalize/type/{sound_type}")
|
||||
async def normalize_sounds_by_type(
|
||||
sound_type: str,
|
||||
current_user: Annotated[User, Depends(get_admin_user)],
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService,
|
||||
Depends(get_sound_normalizer_service),
|
||||
],
|
||||
*,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query( # noqa: FBT002
|
||||
Query(
|
||||
description="Force normalization of already normalized sounds",
|
||||
),
|
||||
] = False,
|
||||
@@ -167,14 +169,15 @@ async def normalize_sounds_by_type(
|
||||
@router.post("/normalize/{sound_id}")
|
||||
async def normalize_sound_by_id(
|
||||
sound_id: int,
|
||||
current_user: Annotated[User, Depends(get_admin_user)],
|
||||
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
|
||||
normalizer_service: Annotated[
|
||||
SoundNormalizerService,
|
||||
Depends(get_sound_normalizer_service),
|
||||
],
|
||||
*,
|
||||
force: Annotated[
|
||||
bool,
|
||||
Query( # noqa: FBT002
|
||||
Query(
|
||||
description="Force normalization of already normalized sound",
|
||||
),
|
||||
] = False,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
@@ -110,7 +110,7 @@ async def update_playlist(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User ID not available",
|
||||
)
|
||||
|
||||
|
||||
playlist = await playlist_service.update_playlist(
|
||||
playlist_id=playlist_id,
|
||||
user_id=current_user.id,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db, get_session_factory
|
||||
@@ -18,7 +18,6 @@ from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
|
||||
router = APIRouter(prefix="/sounds", tags=["sounds"])
|
||||
|
||||
|
||||
|
||||
async def get_extraction_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> ExtractionService:
|
||||
@@ -43,7 +42,6 @@ async def get_sound_repository(
|
||||
return SoundRepository(session)
|
||||
|
||||
|
||||
|
||||
# EXTRACT
|
||||
@router.post("/extract")
|
||||
async def create_extraction(
|
||||
@@ -60,7 +58,8 @@ async def create_extraction(
|
||||
)
|
||||
|
||||
extraction_info = await extraction_service.create_extraction(
|
||||
url, current_user.id,
|
||||
url,
|
||||
current_user.id,
|
||||
)
|
||||
|
||||
# Queue the extraction for background processing
|
||||
@@ -83,8 +82,6 @@ async def create_extraction(
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/extract/{extraction_id}")
|
||||
async def get_extraction(
|
||||
extraction_id: int,
|
||||
@@ -206,7 +203,6 @@ async def play_sound_with_vlc(
|
||||
}
|
||||
|
||||
|
||||
|
||||
@router.post("/stop")
|
||||
async def stop_all_vlc_instances(
|
||||
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001
|
||||
|
||||
@@ -40,8 +40,10 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
|
||||
def get_session_factory() -> Callable[[], AsyncSession]:
|
||||
"""Get a session factory function for services."""
|
||||
|
||||
def session_factory() -> AsyncSession:
|
||||
return AsyncSession(engine)
|
||||
|
||||
return session_factory
|
||||
|
||||
|
||||
|
||||
@@ -30,9 +30,7 @@ class Sound(BaseModel, table=True):
|
||||
is_deletable: bool = Field(default=True, nullable=False)
|
||||
|
||||
# constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint("hash", name="uq_sound_hash"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("hash", name="uq_sound_hash"),)
|
||||
|
||||
# relationships
|
||||
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")
|
||||
|
||||
@@ -43,7 +43,9 @@ class BaseRepository[ModelType]:
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get %s by ID: %s", self.model.__name__, entity_id,
|
||||
"Failed to get %s by ID: %s",
|
||||
self.model.__name__,
|
||||
entity_id,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -91,8 +91,7 @@ class CreditTransactionRepository(BaseRepository[CreditTransaction]):
|
||||
|
||||
"""
|
||||
stmt = (
|
||||
select(CreditTransaction)
|
||||
.where(CreditTransaction.success == True) # noqa: E712
|
||||
select(CreditTransaction).where(CreditTransaction.success == True) # noqa: E712
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Extraction repository for database operations."""
|
||||
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -17,12 +16,15 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
super().__init__(Extraction, session)
|
||||
|
||||
async def get_by_service_and_id(
|
||||
self, service: str, service_id: str,
|
||||
self,
|
||||
service: str,
|
||||
service_id: str,
|
||||
) -> Extraction | None:
|
||||
"""Get an extraction by service and service_id."""
|
||||
result = await self.session.exec(
|
||||
select(Extraction).where(
|
||||
Extraction.service == service, Extraction.service_id == service_id,
|
||||
Extraction.service == service,
|
||||
Extraction.service_id == service_id,
|
||||
),
|
||||
)
|
||||
return result.first()
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Playlist repository for database operations."""
|
||||
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -66,7 +65,9 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
raise
|
||||
|
||||
async def search_by_name(
|
||||
self, query: str, user_id: int | None = None,
|
||||
self,
|
||||
query: str,
|
||||
user_id: int | None = None,
|
||||
) -> list[Playlist]:
|
||||
"""Search playlists by name (case-insensitive)."""
|
||||
try:
|
||||
@@ -98,7 +99,10 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
raise
|
||||
|
||||
async def add_sound_to_playlist(
|
||||
self, playlist_id: int, sound_id: int, position: int | None = None,
|
||||
self,
|
||||
playlist_id: int,
|
||||
sound_id: int,
|
||||
position: int | None = None,
|
||||
) -> PlaylistSound:
|
||||
"""Add a sound to a playlist."""
|
||||
try:
|
||||
@@ -121,7 +125,9 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception(
|
||||
"Failed to add sound %s to playlist %s", sound_id, playlist_id,
|
||||
"Failed to add sound %s to playlist %s",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
@@ -150,12 +156,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception(
|
||||
"Failed to remove sound %s from playlist %s", sound_id, playlist_id,
|
||||
"Failed to remove sound %s from playlist %s",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def reorder_playlist_sounds(
|
||||
self, playlist_id: int, sound_positions: list[tuple[int, int]],
|
||||
self,
|
||||
playlist_id: int,
|
||||
sound_positions: list[tuple[int, int]],
|
||||
) -> None:
|
||||
"""Reorder sounds in a playlist.
|
||||
|
||||
@@ -220,6 +230,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
return result.first() is not None
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id,
|
||||
"Failed to check if sound %s is in playlist %s",
|
||||
sound_id,
|
||||
playlist_id,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -91,6 +91,7 @@ class SoundRepository(BaseRepository[Sound]):
|
||||
return list(result.all())
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get unnormalized sounds by type: %s", sound_type,
|
||||
"Failed to get unnormalized sounds by type: %s",
|
||||
sound_type,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Repository for user OAuth operations."""
|
||||
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -60,4 +59,3 @@ class UserOauthRepository(BaseRepository[UserOauth]):
|
||||
raise
|
||||
else:
|
||||
return result.first()
|
||||
|
||||
|
||||
@@ -30,17 +30,21 @@ class PlayerStateResponse(BaseModel):
|
||||
|
||||
status: str = Field(description="Player status (playing, paused, stopped)")
|
||||
current_sound: dict[str, Any] | None = Field(
|
||||
None, description="Current sound information",
|
||||
None,
|
||||
description="Current sound information",
|
||||
)
|
||||
playlist: dict[str, Any] | None = Field(
|
||||
None, description="Current playlist information",
|
||||
None,
|
||||
description="Current playlist information",
|
||||
)
|
||||
position: int = Field(description="Current position in milliseconds")
|
||||
duration: int | None = Field(
|
||||
None, description="Total duration in milliseconds",
|
||||
None,
|
||||
description="Total duration in milliseconds",
|
||||
)
|
||||
volume: int = Field(description="Current volume (0-100)")
|
||||
mode: str = Field(description="Current playback mode")
|
||||
index: int | None = Field(
|
||||
None, description="Current track index in playlist",
|
||||
None,
|
||||
description="Current track index in playlist",
|
||||
)
|
||||
|
||||
@@ -156,7 +156,8 @@ class ExtractionService:
|
||||
|
||||
# Check if extraction already exists for this service
|
||||
existing = await self.extraction_repo.get_by_service_and_id(
|
||||
service_info["service"], service_info["service_id"],
|
||||
service_info["service"],
|
||||
service_info["service_id"],
|
||||
)
|
||||
if existing and existing.id != extraction_id:
|
||||
error_msg = (
|
||||
@@ -181,7 +182,8 @@ class ExtractionService:
|
||||
|
||||
# Extract audio and thumbnail
|
||||
audio_file, thumbnail_file = await self._extract_media(
|
||||
extraction_id, extraction_url,
|
||||
extraction_id,
|
||||
extraction_url,
|
||||
)
|
||||
|
||||
# Move files to final locations
|
||||
@@ -227,7 +229,9 @@ class ExtractionService:
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(
|
||||
"Failed to process extraction %d: %s", extraction_id, error_msg,
|
||||
"Failed to process extraction %d: %s",
|
||||
extraction_id,
|
||||
error_msg,
|
||||
)
|
||||
else:
|
||||
return {
|
||||
@@ -262,7 +266,9 @@ class ExtractionService:
|
||||
}
|
||||
|
||||
async def _extract_media(
|
||||
self, extraction_id: int, extraction_url: str,
|
||||
self,
|
||||
extraction_id: int,
|
||||
extraction_url: str,
|
||||
) -> tuple[Path, Path | None]:
|
||||
"""Extract audio and thumbnail using yt-dlp."""
|
||||
temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
|
||||
|
||||
@@ -65,7 +65,8 @@ class ExtractionProcessor:
|
||||
# The processor will pick it up on the next cycle
|
||||
else:
|
||||
logger.warning(
|
||||
"Extraction %d is already being processed", extraction_id,
|
||||
"Extraction %d is already being processed",
|
||||
extraction_id,
|
||||
)
|
||||
|
||||
async def _process_queue(self) -> None:
|
||||
|
||||
@@ -35,10 +35,11 @@ async def _is_current_playlist(session: AsyncSession, playlist_id: int) -> bool:
|
||||
|
||||
playlist_repo = PlaylistRepository(session)
|
||||
current_playlist = await playlist_repo.get_current_playlist()
|
||||
return current_playlist is not None and current_playlist.id == playlist_id
|
||||
except Exception: # noqa: BLE001
|
||||
logger.warning("Failed to check if playlist is current", exc_info=True)
|
||||
return False
|
||||
else:
|
||||
return current_playlist is not None and current_playlist.id == playlist_id
|
||||
|
||||
|
||||
class PlaylistService:
|
||||
@@ -199,7 +200,7 @@ class PlaylistService:
|
||||
await self.playlist_repo.delete(playlist)
|
||||
logger.info("Deleted playlist %s for user %s", playlist_id, user_id)
|
||||
|
||||
# If the deleted playlist was current, reload player to use main playlist fallback
|
||||
# If the deleted playlist was current, reload player to use main fallback
|
||||
if was_current:
|
||||
await _reload_player_playlist()
|
||||
|
||||
|
||||
@@ -140,7 +140,10 @@ class SoundNormalizerService:
|
||||
stream = ffmpeg.overwrite_output(stream)
|
||||
|
||||
await asyncio.to_thread(
|
||||
ffmpeg.run, stream, quiet=True, overwrite_output=True,
|
||||
ffmpeg.run,
|
||||
stream,
|
||||
quiet=True,
|
||||
overwrite_output=True,
|
||||
)
|
||||
logger.info("One-pass normalization completed: %s", output_path)
|
||||
|
||||
@@ -180,7 +183,10 @@ class SoundNormalizerService:
|
||||
# Run first pass and capture output
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
ffmpeg.run, stream, capture_stderr=True, quiet=True,
|
||||
ffmpeg.run,
|
||||
stream,
|
||||
capture_stderr=True,
|
||||
quiet=True,
|
||||
)
|
||||
analysis_output = result[1].decode("utf-8")
|
||||
except ffmpeg.Error as e:
|
||||
@@ -262,7 +268,10 @@ class SoundNormalizerService:
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
ffmpeg.run, stream, quiet=True, overwrite_output=True,
|
||||
ffmpeg.run,
|
||||
stream,
|
||||
quiet=True,
|
||||
overwrite_output=True,
|
||||
)
|
||||
logger.info("Two-pass normalization completed: %s", output_path)
|
||||
except ffmpeg.Error as e:
|
||||
|
||||
@@ -40,6 +40,7 @@ def requires_credits(
|
||||
return True
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
@@ -70,7 +71,8 @@ def requires_credits(
|
||||
|
||||
# Validate credits before execution
|
||||
await credit_service.validate_and_reserve_credits(
|
||||
user_id, action_type,
|
||||
user_id,
|
||||
action_type,
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
@@ -86,10 +88,14 @@ def requires_credits(
|
||||
finally:
|
||||
# Deduct credits based on success
|
||||
await credit_service.deduct_credits(
|
||||
user_id, action_type, success=success, metadata=metadata,
|
||||
user_id,
|
||||
action_type,
|
||||
success=success,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -111,6 +117,7 @@ def validate_credits_only(
|
||||
Decorated function that validates credits only
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
@@ -141,6 +148,7 @@ def validate_credits_only(
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@@ -173,7 +181,8 @@ class CreditManager:
|
||||
async def __aenter__(self) -> "CreditManager":
|
||||
"""Enter context manager - validate credits."""
|
||||
await self.credit_service.validate_and_reserve_credits(
|
||||
self.user_id, self.action_type,
|
||||
self.user_id,
|
||||
self.action_type,
|
||||
)
|
||||
self.validated = True
|
||||
return self
|
||||
@@ -189,7 +198,10 @@ class CreditManager:
|
||||
# If no exception occurred, consider it successful
|
||||
success = exc_type is None and self.success
|
||||
await self.credit_service.deduct_credits(
|
||||
self.user_id, self.action_type, success=success, metadata=self.metadata,
|
||||
self.user_id,
|
||||
self.action_type,
|
||||
success=success,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
|
||||
def mark_success(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user