Refactor test cases for improved readability and consistency
All checks were successful
Backend CI / lint (push) Successful in 9m49s
Backend CI / test (push) Successful in 6m15s

- Adjusted function signatures in various test files to enhance clarity by aligning parameters.
- Updated patching syntax for better readability across test cases.
- Improved formatting and spacing in test assertions and mock setups.
- Ensured consistent use of async/await patterns in async test functions.
- Enhanced comments for better understanding of test intentions.
This commit is contained in:
JSC
2025-08-01 20:53:30 +02:00
parent d926779fe4
commit 6068599a47
39 changed files with 691 additions and 286 deletions

2
.gitignore vendored
View File

@@ -9,3 +9,5 @@ wheels/
# Virtual environments # Virtual environments
.venv .venv
.env .env
.coverage

View File

@@ -32,7 +32,7 @@ async def get_sound_normalizer_service(
# SCAN ENDPOINTS # SCAN ENDPOINTS
@router.post("/scan") @router.post("/scan")
async def scan_sounds( async def scan_sounds(
current_user: Annotated[User, Depends(get_admin_user)], current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)], scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
) -> dict[str, ScanResults | str]: ) -> dict[str, ScanResults | str]:
"""Sync the soundboard directory (add/update/delete sounds). Admin only.""" """Sync the soundboard directory (add/update/delete sounds). Admin only."""
@@ -53,11 +53,11 @@ async def scan_sounds(
@router.post("/scan/custom") @router.post("/scan/custom")
async def scan_custom_directory( async def scan_custom_directory(
directory: str, directory: str,
current_user: Annotated[User, Depends(get_admin_user)], current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)], scanner_service: Annotated[SoundScannerService, Depends(get_sound_scanner_service)],
sound_type: str = "SDB", sound_type: str = "SDB",
) -> dict[str, ScanResults | str]: ) -> dict[str, ScanResults | str]:
"""Sync a custom directory with the database (add/update/delete sounds). Admin only.""" """Sync a custom directory with the database. Admin only."""
try: try:
results = await scanner_service.scan_directory(directory, sound_type) results = await scanner_service.scan_directory(directory, sound_type)
except ValueError as e: except ValueError as e:
@@ -80,14 +80,15 @@ async def scan_custom_directory(
# NORMALIZE ENDPOINTS # NORMALIZE ENDPOINTS
@router.post("/normalize/all") @router.post("/normalize/all")
async def normalize_all_sounds( async def normalize_all_sounds(
current_user: Annotated[User, Depends(get_admin_user)], current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, SoundNormalizerService,
Depends(get_sound_normalizer_service), Depends(get_sound_normalizer_service),
], ],
*,
force: Annotated[ force: Annotated[
bool, bool,
Query( # noqa: FBT002 Query(
description="Force normalization of already normalized sounds", description="Force normalization of already normalized sounds",
), ),
] = False, ] = False,
@@ -119,14 +120,15 @@ async def normalize_all_sounds(
@router.post("/normalize/type/{sound_type}") @router.post("/normalize/type/{sound_type}")
async def normalize_sounds_by_type( async def normalize_sounds_by_type(
sound_type: str, sound_type: str,
current_user: Annotated[User, Depends(get_admin_user)], current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, SoundNormalizerService,
Depends(get_sound_normalizer_service), Depends(get_sound_normalizer_service),
], ],
*,
force: Annotated[ force: Annotated[
bool, bool,
Query( # noqa: FBT002 Query(
description="Force normalization of already normalized sounds", description="Force normalization of already normalized sounds",
), ),
] = False, ] = False,
@@ -167,14 +169,15 @@ async def normalize_sounds_by_type(
@router.post("/normalize/{sound_id}") @router.post("/normalize/{sound_id}")
async def normalize_sound_by_id( async def normalize_sound_by_id(
sound_id: int, sound_id: int,
current_user: Annotated[User, Depends(get_admin_user)], current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
normalizer_service: Annotated[ normalizer_service: Annotated[
SoundNormalizerService, SoundNormalizerService,
Depends(get_sound_normalizer_service), Depends(get_sound_normalizer_service),
], ],
*,
force: Annotated[ force: Annotated[
bool, bool,
Query( # noqa: FBT002 Query(
description="Force normalization of already normalized sound", description="Force normalization of already normalized sound",
), ),
] = False, ] = False,

View File

@@ -2,7 +2,7 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db

View File

@@ -2,7 +2,7 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db, get_session_factory from app.core.database import get_db, get_session_factory
@@ -18,7 +18,6 @@ from app.services.vlc_player import VLCPlayerService, get_vlc_player_service
router = APIRouter(prefix="/sounds", tags=["sounds"]) router = APIRouter(prefix="/sounds", tags=["sounds"])
async def get_extraction_service( async def get_extraction_service(
session: Annotated[AsyncSession, Depends(get_db)], session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService: ) -> ExtractionService:
@@ -43,7 +42,6 @@ async def get_sound_repository(
return SoundRepository(session) return SoundRepository(session)
# EXTRACT # EXTRACT
@router.post("/extract") @router.post("/extract")
async def create_extraction( async def create_extraction(
@@ -60,7 +58,8 @@ async def create_extraction(
) )
extraction_info = await extraction_service.create_extraction( extraction_info = await extraction_service.create_extraction(
url, current_user.id, url,
current_user.id,
) )
# Queue the extraction for background processing # Queue the extraction for background processing
@@ -83,8 +82,6 @@ async def create_extraction(
} }
@router.get("/extract/{extraction_id}") @router.get("/extract/{extraction_id}")
async def get_extraction( async def get_extraction(
extraction_id: int, extraction_id: int,
@@ -206,7 +203,6 @@ async def play_sound_with_vlc(
} }
@router.post("/stop") @router.post("/stop")
async def stop_all_vlc_instances( async def stop_all_vlc_instances(
current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001 current_user: Annotated[User, Depends(get_current_active_user_flexible)], # noqa: ARG001

View File

@@ -40,8 +40,10 @@ async def get_db() -> AsyncGenerator[AsyncSession, None]:
def get_session_factory() -> Callable[[], AsyncSession]: def get_session_factory() -> Callable[[], AsyncSession]:
"""Get a session factory function for services.""" """Get a session factory function for services."""
def session_factory() -> AsyncSession: def session_factory() -> AsyncSession:
return AsyncSession(engine) return AsyncSession(engine)
return session_factory return session_factory

View File

@@ -30,9 +30,7 @@ class Sound(BaseModel, table=True):
is_deletable: bool = Field(default=True, nullable=False) is_deletable: bool = Field(default=True, nullable=False)
# constraints # constraints
__table_args__ = ( __table_args__ = (UniqueConstraint("hash", name="uq_sound_hash"),)
UniqueConstraint("hash", name="uq_sound_hash"),
)
# relationships # relationships
playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound") playlist_sounds: list["PlaylistSound"] = Relationship(back_populates="sound")

View File

@@ -43,7 +43,9 @@ class BaseRepository[ModelType]:
return result.first() return result.first()
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to get %s by ID: %s", self.model.__name__, entity_id, "Failed to get %s by ID: %s",
self.model.__name__,
entity_id,
) )
raise raise

View File

@@ -91,8 +91,7 @@ class CreditTransactionRepository(BaseRepository[CreditTransaction]):
""" """
stmt = ( stmt = (
select(CreditTransaction) select(CreditTransaction).where(CreditTransaction.success == True) # noqa: E712
.where(CreditTransaction.success == True) # noqa: E712
) )
if user_id is not None: if user_id is not None:

View File

@@ -1,6 +1,5 @@
"""Extraction repository for database operations.""" """Extraction repository for database operations."""
from sqlalchemy import desc from sqlalchemy import desc
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -17,12 +16,15 @@ class ExtractionRepository(BaseRepository[Extraction]):
super().__init__(Extraction, session) super().__init__(Extraction, session)
async def get_by_service_and_id( async def get_by_service_and_id(
self, service: str, service_id: str, self,
service: str,
service_id: str,
) -> Extraction | None: ) -> Extraction | None:
"""Get an extraction by service and service_id.""" """Get an extraction by service and service_id."""
result = await self.session.exec( result = await self.session.exec(
select(Extraction).where( select(Extraction).where(
Extraction.service == service, Extraction.service_id == service_id, Extraction.service == service,
Extraction.service_id == service_id,
), ),
) )
return result.first() return result.first()

View File

@@ -1,6 +1,5 @@
"""Playlist repository for database operations.""" """Playlist repository for database operations."""
from sqlalchemy import func from sqlalchemy import func
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -66,7 +65,9 @@ class PlaylistRepository(BaseRepository[Playlist]):
raise raise
async def search_by_name( async def search_by_name(
self, query: str, user_id: int | None = None, self,
query: str,
user_id: int | None = None,
) -> list[Playlist]: ) -> list[Playlist]:
"""Search playlists by name (case-insensitive).""" """Search playlists by name (case-insensitive)."""
try: try:
@@ -98,7 +99,10 @@ class PlaylistRepository(BaseRepository[Playlist]):
raise raise
async def add_sound_to_playlist( async def add_sound_to_playlist(
self, playlist_id: int, sound_id: int, position: int | None = None, self,
playlist_id: int,
sound_id: int,
position: int | None = None,
) -> PlaylistSound: ) -> PlaylistSound:
"""Add a sound to a playlist.""" """Add a sound to a playlist."""
try: try:
@@ -121,7 +125,9 @@ class PlaylistRepository(BaseRepository[Playlist]):
except Exception: except Exception:
await self.session.rollback() await self.session.rollback()
logger.exception( logger.exception(
"Failed to add sound %s to playlist %s", sound_id, playlist_id, "Failed to add sound %s to playlist %s",
sound_id,
playlist_id,
) )
raise raise
else: else:
@@ -150,12 +156,16 @@ class PlaylistRepository(BaseRepository[Playlist]):
except Exception: except Exception:
await self.session.rollback() await self.session.rollback()
logger.exception( logger.exception(
"Failed to remove sound %s from playlist %s", sound_id, playlist_id, "Failed to remove sound %s from playlist %s",
sound_id,
playlist_id,
) )
raise raise
async def reorder_playlist_sounds( async def reorder_playlist_sounds(
self, playlist_id: int, sound_positions: list[tuple[int, int]], self,
playlist_id: int,
sound_positions: list[tuple[int, int]],
) -> None: ) -> None:
"""Reorder sounds in a playlist. """Reorder sounds in a playlist.
@@ -220,6 +230,8 @@ class PlaylistRepository(BaseRepository[Playlist]):
return result.first() is not None return result.first() is not None
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to check if sound %s is in playlist %s", sound_id, playlist_id, "Failed to check if sound %s is in playlist %s",
sound_id,
playlist_id,
) )
raise raise

View File

@@ -91,6 +91,7 @@ class SoundRepository(BaseRepository[Sound]):
return list(result.all()) return list(result.all())
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to get unnormalized sounds by type: %s", sound_type, "Failed to get unnormalized sounds by type: %s",
sound_type,
) )
raise raise

View File

@@ -1,6 +1,5 @@
"""Repository for user OAuth operations.""" """Repository for user OAuth operations."""
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -60,4 +59,3 @@ class UserOauthRepository(BaseRepository[UserOauth]):
raise raise
else: else:
return result.first() return result.first()

View File

@@ -30,17 +30,21 @@ class PlayerStateResponse(BaseModel):
status: str = Field(description="Player status (playing, paused, stopped)") status: str = Field(description="Player status (playing, paused, stopped)")
current_sound: dict[str, Any] | None = Field( current_sound: dict[str, Any] | None = Field(
None, description="Current sound information", None,
description="Current sound information",
) )
playlist: dict[str, Any] | None = Field( playlist: dict[str, Any] | None = Field(
None, description="Current playlist information", None,
description="Current playlist information",
) )
position: int = Field(description="Current position in milliseconds") position: int = Field(description="Current position in milliseconds")
duration: int | None = Field( duration: int | None = Field(
None, description="Total duration in milliseconds", None,
description="Total duration in milliseconds",
) )
volume: int = Field(description="Current volume (0-100)") volume: int = Field(description="Current volume (0-100)")
mode: str = Field(description="Current playback mode") mode: str = Field(description="Current playback mode")
index: int | None = Field( index: int | None = Field(
None, description="Current track index in playlist", None,
description="Current track index in playlist",
) )

View File

@@ -156,7 +156,8 @@ class ExtractionService:
# Check if extraction already exists for this service # Check if extraction already exists for this service
existing = await self.extraction_repo.get_by_service_and_id( existing = await self.extraction_repo.get_by_service_and_id(
service_info["service"], service_info["service_id"], service_info["service"],
service_info["service_id"],
) )
if existing and existing.id != extraction_id: if existing and existing.id != extraction_id:
error_msg = ( error_msg = (
@@ -181,7 +182,8 @@ class ExtractionService:
# Extract audio and thumbnail # Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media( audio_file, thumbnail_file = await self._extract_media(
extraction_id, extraction_url, extraction_id,
extraction_url,
) )
# Move files to final locations # Move files to final locations
@@ -227,7 +229,9 @@ class ExtractionService:
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
logger.exception( logger.exception(
"Failed to process extraction %d: %s", extraction_id, error_msg, "Failed to process extraction %d: %s",
extraction_id,
error_msg,
) )
else: else:
return { return {
@@ -262,7 +266,9 @@ class ExtractionService:
} }
async def _extract_media( async def _extract_media(
self, extraction_id: int, extraction_url: str, self,
extraction_id: int,
extraction_url: str,
) -> tuple[Path, Path | None]: ) -> tuple[Path, Path | None]:
"""Extract audio and thumbnail using yt-dlp.""" """Extract audio and thumbnail using yt-dlp."""
temp_dir = Path(settings.EXTRACTION_TEMP_DIR) temp_dir = Path(settings.EXTRACTION_TEMP_DIR)

View File

@@ -65,7 +65,8 @@ class ExtractionProcessor:
# The processor will pick it up on the next cycle # The processor will pick it up on the next cycle
else: else:
logger.warning( logger.warning(
"Extraction %d is already being processed", extraction_id, "Extraction %d is already being processed",
extraction_id,
) )
async def _process_queue(self) -> None: async def _process_queue(self) -> None:

View File

@@ -35,10 +35,11 @@ async def _is_current_playlist(session: AsyncSession, playlist_id: int) -> bool:
playlist_repo = PlaylistRepository(session) playlist_repo = PlaylistRepository(session)
current_playlist = await playlist_repo.get_current_playlist() current_playlist = await playlist_repo.get_current_playlist()
return current_playlist is not None and current_playlist.id == playlist_id
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
logger.warning("Failed to check if playlist is current", exc_info=True) logger.warning("Failed to check if playlist is current", exc_info=True)
return False return False
else:
return current_playlist is not None and current_playlist.id == playlist_id
class PlaylistService: class PlaylistService:
@@ -199,7 +200,7 @@ class PlaylistService:
await self.playlist_repo.delete(playlist) await self.playlist_repo.delete(playlist)
logger.info("Deleted playlist %s for user %s", playlist_id, user_id) logger.info("Deleted playlist %s for user %s", playlist_id, user_id)
# If the deleted playlist was current, reload player to use main playlist fallback # If the deleted playlist was current, reload player to use main fallback
if was_current: if was_current:
await _reload_player_playlist() await _reload_player_playlist()

View File

@@ -140,7 +140,10 @@ class SoundNormalizerService:
stream = ffmpeg.overwrite_output(stream) stream = ffmpeg.overwrite_output(stream)
await asyncio.to_thread( await asyncio.to_thread(
ffmpeg.run, stream, quiet=True, overwrite_output=True, ffmpeg.run,
stream,
quiet=True,
overwrite_output=True,
) )
logger.info("One-pass normalization completed: %s", output_path) logger.info("One-pass normalization completed: %s", output_path)
@@ -180,7 +183,10 @@ class SoundNormalizerService:
# Run first pass and capture output # Run first pass and capture output
try: try:
result = await asyncio.to_thread( result = await asyncio.to_thread(
ffmpeg.run, stream, capture_stderr=True, quiet=True, ffmpeg.run,
stream,
capture_stderr=True,
quiet=True,
) )
analysis_output = result[1].decode("utf-8") analysis_output = result[1].decode("utf-8")
except ffmpeg.Error as e: except ffmpeg.Error as e:
@@ -262,7 +268,10 @@ class SoundNormalizerService:
try: try:
await asyncio.to_thread( await asyncio.to_thread(
ffmpeg.run, stream, quiet=True, overwrite_output=True, ffmpeg.run,
stream,
quiet=True,
overwrite_output=True,
) )
logger.info("Two-pass normalization completed: %s", output_path) logger.info("Two-pass normalization completed: %s", output_path)
except ffmpeg.Error as e: except ffmpeg.Error as e:

View File

@@ -40,6 +40,7 @@ def requires_credits(
return True return True
""" """
def decorator(func: F) -> F: def decorator(func: F) -> F:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
@@ -70,7 +71,8 @@ def requires_credits(
# Validate credits before execution # Validate credits before execution
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
user_id, action_type, user_id,
action_type,
) )
# Execute the function # Execute the function
@@ -86,10 +88,14 @@ def requires_credits(
finally: finally:
# Deduct credits based on success # Deduct credits based on success
await credit_service.deduct_credits( await credit_service.deduct_credits(
user_id, action_type, success=success, metadata=metadata, user_id,
action_type,
success=success,
metadata=metadata,
) )
return wrapper # type: ignore[return-value] return wrapper # type: ignore[return-value]
return decorator return decorator
@@ -111,6 +117,7 @@ def validate_credits_only(
Decorated function that validates credits only Decorated function that validates credits only
""" """
def decorator(func: F) -> F: def decorator(func: F) -> F:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 async def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
@@ -141,6 +148,7 @@ def validate_credits_only(
return await func(*args, **kwargs) return await func(*args, **kwargs)
return wrapper # type: ignore[return-value] return wrapper # type: ignore[return-value]
return decorator return decorator
@@ -173,7 +181,8 @@ class CreditManager:
async def __aenter__(self) -> "CreditManager": async def __aenter__(self) -> "CreditManager":
"""Enter context manager - validate credits.""" """Enter context manager - validate credits."""
await self.credit_service.validate_and_reserve_credits( await self.credit_service.validate_and_reserve_credits(
self.user_id, self.action_type, self.user_id,
self.action_type,
) )
self.validated = True self.validated = True
return self return self
@@ -189,7 +198,10 @@ class CreditManager:
# If no exception occurred, consider it successful # If no exception occurred, consider it successful
success = exc_type is None and self.success success = exc_type is None and self.success
await self.credit_service.deduct_credits( await self.credit_service.deduct_credits(
self.user_id, self.action_type, success=success, metadata=self.metadata, self.user_id,
self.action_type,
success=success,
metadata=self.metadata,
) )
def mark_success(self) -> None: def mark_success(self) -> None:

View File

@@ -73,7 +73,9 @@ class TestAdminSoundEndpoints:
) as mock_scan: ) as mock_scan:
mock_scan.return_value = mock_results mock_scan.return_value = mock_results
response = await authenticated_admin_client.post("/api/v1/admin/sounds/scan") response = await authenticated_admin_client.post(
"/api/v1/admin/sounds/scan",
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@@ -114,6 +116,7 @@ class TestAdminSoundEndpoints:
) -> None: ) -> None:
"""Test scanning sounds with non-admin user.""" """Test scanning sounds with non-admin user."""
from fastapi import HTTPException from fastapi import HTTPException
from app.core.dependencies import get_admin_user from app.core.dependencies import get_admin_user
# Override the admin dependency to raise 403 for non-admin users # Override the admin dependency to raise 403 for non-admin users
@@ -150,7 +153,9 @@ class TestAdminSoundEndpoints:
) as mock_scan: ) as mock_scan:
mock_scan.side_effect = Exception("Directory not found") mock_scan.side_effect = Exception("Directory not found")
response = await authenticated_admin_client.post("/api/v1/admin/sounds/scan") response = await authenticated_admin_client.post(
"/api/v1/admin/sounds/scan",
)
assert response.status_code == 500 assert response.status_code == 500
data = response.json() data = response.json()
@@ -300,7 +305,9 @@ class TestAdminSoundEndpoints:
assert len(results["files"]) == 3 assert len(results["files"]) == 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_normalize_all_sounds_unauthenticated(self, client: AsyncClient) -> None: async def test_normalize_all_sounds_unauthenticated(
self, client: AsyncClient,
) -> None:
"""Test normalizing sounds without authentication.""" """Test normalizing sounds without authentication."""
response = await client.post("/api/v1/admin/sounds/normalize/all") response = await client.post("/api/v1/admin/sounds/normalize/all")
@@ -316,6 +323,7 @@ class TestAdminSoundEndpoints:
) -> None: ) -> None:
"""Test normalizing sounds with non-admin user.""" """Test normalizing sounds with non-admin user."""
from fastapi import HTTPException from fastapi import HTTPException
from app.core.dependencies import get_admin_user from app.core.dependencies import get_admin_user
# Override the admin dependency to raise 403 for non-admin users # Override the admin dependency to raise 403 for non-admin users
@@ -331,7 +339,8 @@ class TestAdminSoundEndpoints:
base_url="http://test", base_url="http://test",
) as client: ) as client:
response = await client.post( response = await client.post(
"/api/v1/admin/sounds/normalize/all", headers=headers, "/api/v1/admin/sounds/normalize/all",
headers=headers,
) )
assert response.status_code == 403 assert response.status_code == 403
@@ -405,7 +414,9 @@ class TestAdminSoundEndpoints:
# Verify the service was called with correct type # Verify the service was called with correct type
mock_normalize.assert_called_once_with( mock_normalize.assert_called_once_with(
sound_type="SDB", force=False, one_pass=None, sound_type="SDB",
force=False,
one_pass=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -491,7 +502,7 @@ class TestAdminSoundEndpoints:
) -> None: ) -> None:
"""Test getting extraction processor status.""" """Test getting extraction processor status."""
with patch( with patch(
"app.services.extraction_processor.extraction_processor.get_status" "app.services.extraction_processor.extraction_processor.get_status",
) as mock_get_status: ) as mock_get_status:
mock_status = { mock_status = {
"is_running": True, "is_running": True,
@@ -502,7 +513,7 @@ class TestAdminSoundEndpoints:
mock_get_status.return_value = mock_status mock_get_status.return_value = mock_status
response = await authenticated_admin_client.get( response = await authenticated_admin_client.get(
"/api/v1/admin/sounds/extract/status" "/api/v1/admin/sounds/extract/status",
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -511,7 +522,8 @@ class TestAdminSoundEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_extraction_processor_status_unauthenticated( async def test_get_extraction_processor_status_unauthenticated(
self, client: AsyncClient self,
client: AsyncClient,
) -> None: ) -> None:
"""Test getting extraction processor status without authentication.""" """Test getting extraction processor status without authentication."""
response = await client.get("/api/v1/admin/sounds/extract/status") response = await client.get("/api/v1/admin/sounds/extract/status")
@@ -528,6 +540,7 @@ class TestAdminSoundEndpoints:
) -> None: ) -> None:
"""Test getting extraction processor status with non-admin user.""" """Test getting extraction processor status with non-admin user."""
from fastapi import HTTPException from fastapi import HTTPException
from app.core.dependencies import get_admin_user from app.core.dependencies import get_admin_user
# Override the admin dependency to raise 403 for non-admin users # Override the admin dependency to raise 403 for non-admin users
@@ -543,7 +556,8 @@ class TestAdminSoundEndpoints:
base_url="http://test", base_url="http://test",
) as client: ) as client:
response = await client.get( response = await client.get(
"/api/v1/admin/sounds/extract/status", headers=headers "/api/v1/admin/sounds/extract/status",
headers=headers,
) )
assert response.status_code == 403 assert response.status_code == 403

View File

@@ -54,7 +54,11 @@ class TestApiTokenEndpoints:
expires_at_str = data["expires_at"] expires_at_str = data["expires_at"]
# Handle both ISO format with/without timezone info # Handle both ISO format with/without timezone info
if expires_at_str.endswith("Z") or "+" in expires_at_str or expires_at_str.count("-") > 2: if (
expires_at_str.endswith("Z")
or "+" in expires_at_str
or expires_at_str.count("-") > 2
):
expires_at = datetime.fromisoformat(expires_at_str) expires_at = datetime.fromisoformat(expires_at_str)
else: else:
# Naive datetime, assume UTC # Naive datetime, assume UTC
@@ -84,7 +88,11 @@ class TestApiTokenEndpoints:
expires_at_str = data["expires_at"] expires_at_str = data["expires_at"]
# Handle both ISO format with/without timezone info # Handle both ISO format with/without timezone info
if expires_at_str.endswith("Z") or "+" in expires_at_str or expires_at_str.count("-") > 2: if (
expires_at_str.endswith("Z")
or "+" in expires_at_str
or expires_at_str.count("-") > 2
):
expires_at = datetime.fromisoformat(expires_at_str) expires_at = datetime.fromisoformat(expires_at_str)
else: else:
# Naive datetime, assume UTC # Naive datetime, assume UTC
@@ -116,7 +124,9 @@ class TestApiTokenEndpoints:
assert response.status_code == 422 assert response.status_code == 422
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_api_token_unauthenticated(self, client: AsyncClient) -> None: async def test_generate_api_token_unauthenticated(
self, client: AsyncClient,
) -> None:
"""Test API token generation without authentication.""" """Test API token generation without authentication."""
response = await client.post( response = await client.post(
"/api/v1/auth/api-token", "/api/v1/auth/api-token",
@@ -186,7 +196,9 @@ class TestApiTokenEndpoints:
assert data["is_expired"] is True assert data["is_expired"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_api_token_status_unauthenticated(self, client: AsyncClient) -> None: async def test_get_api_token_status_unauthenticated(
self, client: AsyncClient,
) -> None:
"""Test getting API token status without authentication.""" """Test getting API token status without authentication."""
response = await client.get("/api/v1/auth/api-token/status") response = await client.get("/api/v1/auth/api-token/status")
assert response.status_code == 401 assert response.status_code == 401
@@ -264,7 +276,9 @@ class TestApiTokenEndpoints:
assert "email" in data assert "email" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_authentication_invalid_token(self, client: AsyncClient) -> None: async def test_api_token_authentication_invalid_token(
self, client: AsyncClient,
) -> None:
"""Test authentication with invalid API token.""" """Test authentication with invalid API token."""
headers = {"API-TOKEN": "invalid_token"} headers = {"API-TOKEN": "invalid_token"}
response = await client.get("/api/v1/auth/me", headers=headers) response = await client.get("/api/v1/auth/me", headers=headers)
@@ -297,7 +311,9 @@ class TestApiTokenEndpoints:
assert "API token has expired" in data["detail"] assert "API token has expired" in data["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_authentication_empty_token(self, client: AsyncClient) -> None: async def test_api_token_authentication_empty_token(
self, client: AsyncClient,
) -> None:
"""Test authentication with empty API-TOKEN header.""" """Test authentication with empty API-TOKEN header."""
# Empty token # Empty token
headers = {"API-TOKEN": ""} headers = {"API-TOKEN": ""}

View File

@@ -1,6 +1,5 @@
"""Tests for extraction API endpoints.""" """Tests for extraction API endpoints."""
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@@ -10,7 +9,9 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_extraction_success( async def test_create_extraction_success(
self, test_client: AsyncClient, auth_cookies: dict[str, str], self,
test_client: AsyncClient,
auth_cookies: dict[str, str],
) -> None: ) -> None:
"""Test successful extraction creation.""" """Test successful extraction creation."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -26,7 +27,9 @@ class TestExtractionEndpoints:
assert response.status_code in [200, 400, 500] # Allow any non-auth error assert response.status_code in [200, 400, 500] # Allow any non-auth error
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_extraction_unauthenticated(self, test_client: AsyncClient) -> None: async def test_create_extraction_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test extraction creation without authentication.""" """Test extraction creation without authentication."""
response = await test_client.post( response = await test_client.post(
"/api/v1/sounds/extract", "/api/v1/sounds/extract",
@@ -37,7 +40,9 @@ class TestExtractionEndpoints:
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_extraction_unauthenticated(self, test_client: AsyncClient) -> None: async def test_get_extraction_unauthenticated(
self, test_client: AsyncClient,
) -> None:
"""Test extraction retrieval without authentication.""" """Test extraction retrieval without authentication."""
response = await test_client.get("/api/v1/sounds/extract/1") response = await test_client.get("/api/v1/sounds/extract/1")
@@ -46,7 +51,9 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_processor_status_moved_to_admin( async def test_get_processor_status_moved_to_admin(
self, test_client: AsyncClient, admin_cookies: dict[str, str], self,
test_client: AsyncClient,
admin_cookies: dict[str, str],
) -> None: ) -> None:
"""Test that processor status endpoint was moved to admin.""" """Test that processor status endpoint was moved to admin."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning
@@ -61,7 +68,9 @@ class TestExtractionEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_user_extractions( async def test_get_user_extractions(
self, test_client: AsyncClient, auth_cookies: dict[str, str], self,
test_client: AsyncClient,
auth_cookies: dict[str, str],
) -> None: ) -> None:
"""Test getting user extractions.""" """Test getting user extractions."""
# Set cookies on client instance to avoid deprecation warning # Set cookies on client instance to avoid deprecation warning

View File

@@ -1,6 +1,5 @@
"""Tests for playlist API endpoints.""" """Tests for playlist API endpoints."""
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from httpx import AsyncClient from httpx import AsyncClient
@@ -348,7 +347,8 @@ class TestPlaylistEndpoints:
} }
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload, f"/api/v1/playlists/{playlist_id}",
json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -386,7 +386,8 @@ class TestPlaylistEndpoints:
payload = {"name": "Updated Playlist", "description": "Updated description"} payload = {"name": "Updated Playlist", "description": "Updated description"}
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}", json=payload, f"/api/v1/playlists/{playlist_id}",
json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -613,7 +614,8 @@ class TestPlaylistEndpoints:
payload = {"sound_id": sound_id} payload = {"sound_id": sound_id}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload, f"/api/v1/playlists/{playlist_id}/sounds",
json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -670,7 +672,8 @@ class TestPlaylistEndpoints:
payload = {"sound_id": sound_id, "position": 5} payload = {"sound_id": sound_id, "position": 5}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload, f"/api/v1/playlists/{playlist_id}/sounds",
json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -718,13 +721,15 @@ class TestPlaylistEndpoints:
# Add sound first time # Add sound first time
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload, f"/api/v1/playlists/{playlist_id}/sounds",
json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200
# Try to add same sound again # Try to add same sound again
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload, f"/api/v1/playlists/{playlist_id}/sounds",
json=payload,
) )
assert response.status_code == 400 assert response.status_code == 400
assert "already in this playlist" in response.json()["detail"] assert "already in this playlist" in response.json()["detail"]
@@ -758,7 +763,8 @@ class TestPlaylistEndpoints:
payload = {"sound_id": 99999} payload = {"sound_id": 99999}
response = await authenticated_client.post( response = await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload, f"/api/v1/playlists/{playlist_id}/sounds",
json=payload,
) )
assert response.status_code == 404 assert response.status_code == 404
@@ -806,7 +812,8 @@ class TestPlaylistEndpoints:
# Add sound first # Add sound first
payload = {"sound_id": sound_id} payload = {"sound_id": sound_id}
await authenticated_client.post( await authenticated_client.post(
f"/api/v1/playlists/{playlist_id}/sounds", json=payload, f"/api/v1/playlists/{playlist_id}/sounds",
json=payload,
) )
# Remove sound # Remove sound
@@ -918,11 +925,15 @@ class TestPlaylistEndpoints:
# Reorder sounds - use positions that don't cause constraints # Reorder sounds - use positions that don't cause constraints
# When swapping, we need to be careful about unique constraints # When swapping, we need to be careful about unique constraints
payload = { payload = {
"sound_positions": [[sound1_id, 10], [sound2_id, 5]], # Use different positions to avoid constraints "sound_positions": [
[sound1_id, 10],
[sound2_id, 5],
], # Use different positions to avoid constraints
} }
response = await authenticated_client.put( response = await authenticated_client.put(
f"/api/v1/playlists/{playlist_id}/sounds/reorder", json=payload, f"/api/v1/playlists/{playlist_id}/sounds/reorder",
json=payload,
) )
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -158,7 +158,9 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_missing_parameters( async def test_send_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User, self,
authenticated_client: AsyncClient,
authenticated_user: User,
) -> None: ) -> None:
"""Test sending message with missing parameters.""" """Test sending message with missing parameters."""
# Missing target_user_id # Missing target_user_id
@@ -177,7 +179,9 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_broadcast_message_missing_parameters( async def test_broadcast_message_missing_parameters(
self, authenticated_client: AsyncClient, authenticated_user: User, self,
authenticated_client: AsyncClient,
authenticated_user: User,
) -> None: ) -> None:
"""Test broadcasting message with missing parameters.""" """Test broadcasting message with missing parameters."""
response = await authenticated_client.post("/api/v1/socket/broadcast") response = await authenticated_client.post("/api/v1/socket/broadcast")
@@ -185,7 +189,9 @@ class TestSocketEndpoints:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_message_invalid_user_id( async def test_send_message_invalid_user_id(
self, authenticated_client: AsyncClient, authenticated_user: User, self,
authenticated_client: AsyncClient,
authenticated_user: User,
) -> None: ) -> None:
"""Test sending message with invalid user ID.""" """Test sending message with invalid user ID."""
response = await authenticated_client.post( response = await authenticated_client.post(

View File

@@ -35,10 +35,10 @@ class TestSoundEndpoints:
with ( with (
patch( patch(
"app.services.extraction.ExtractionService.create_extraction" "app.services.extraction.ExtractionService.create_extraction",
) as mock_create, ) as mock_create,
patch( patch(
"app.services.extraction_processor.extraction_processor.queue_extraction" "app.services.extraction_processor.extraction_processor.queue_extraction",
) as mock_queue, ) as mock_queue,
): ):
mock_create.return_value = mock_extraction_info mock_create.return_value = mock_extraction_info
@@ -53,7 +53,10 @@ class TestSoundEndpoints:
data = response.json() data = response.json()
assert data["message"] == "Extraction queued successfully" assert data["message"] == "Extraction queued successfully"
assert data["extraction"]["id"] == 1 assert data["extraction"]["id"] == 1
assert data["extraction"]["url"] == "https://www.youtube.com/watch?v=dQw4w9WgXcQ" assert (
data["extraction"]["url"]
== "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_extraction_unauthenticated(self, client: AsyncClient) -> None: async def test_create_extraction_unauthenticated(self, client: AsyncClient) -> None:
@@ -75,7 +78,7 @@ class TestSoundEndpoints:
) -> None: ) -> None:
"""Test extraction creation with invalid URL.""" """Test extraction creation with invalid URL."""
with patch( with patch(
"app.services.extraction.ExtractionService.create_extraction" "app.services.extraction.ExtractionService.create_extraction",
) as mock_create: ) as mock_create:
mock_create.side_effect = ValueError("Invalid URL") mock_create.side_effect = ValueError("Invalid URL")
@@ -107,7 +110,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.extraction.ExtractionService.get_extraction_by_id" "app.services.extraction.ExtractionService.get_extraction_by_id",
) as mock_get: ) as mock_get:
mock_get.return_value = mock_extraction_info mock_get.return_value = mock_extraction_info
@@ -128,7 +131,7 @@ class TestSoundEndpoints:
) -> None: ) -> None:
"""Test getting non-existent extraction.""" """Test getting non-existent extraction."""
with patch( with patch(
"app.services.extraction.ExtractionService.get_extraction_by_id" "app.services.extraction.ExtractionService.get_extraction_by_id",
) as mock_get: ) as mock_get:
mock_get.return_value = None mock_get.return_value = None
@@ -169,7 +172,7 @@ class TestSoundEndpoints:
] ]
with patch( with patch(
"app.services.extraction.ExtractionService.get_user_extractions" "app.services.extraction.ExtractionService.get_user_extractions",
) as mock_get: ) as mock_get:
mock_get.return_value = mock_extractions mock_get.return_value = mock_extractions
@@ -202,7 +205,9 @@ class TestSoundEndpoints:
with ( with (
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
patch("app.services.credit.CreditService.validate_and_reserve_credits") as mock_validate, patch(
"app.services.credit.CreditService.validate_and_reserve_credits",
) as mock_validate,
patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play, patch("app.services.vlc_player.VLCPlayerService.play_sound") as mock_play,
patch("app.services.credit.CreditService.deduct_credits") as mock_deduct, patch("app.services.credit.CreditService.deduct_credits") as mock_deduct,
): ):
@@ -227,7 +232,9 @@ class TestSoundEndpoints:
authenticated_user: User, authenticated_user: User,
) -> None: ) -> None:
"""Test playing non-existent sound with VLC.""" """Test playing non-existent sound with VLC."""
with patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound: with patch(
"app.repositories.sound.SoundRepository.get_by_id",
) as mock_get_sound:
mock_get_sound.return_value = None mock_get_sound.return_value = None
response = await authenticated_client.post("/api/v1/sounds/play/999") response = await authenticated_client.post("/api/v1/sounds/play/999")
@@ -259,11 +266,14 @@ class TestSoundEndpoints:
with ( with (
patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound, patch("app.repositories.sound.SoundRepository.get_by_id") as mock_get_sound,
patch("app.services.credit.CreditService.validate_and_reserve_credits") as mock_validate, patch(
"app.services.credit.CreditService.validate_and_reserve_credits",
) as mock_validate,
): ):
mock_get_sound.return_value = mock_sound mock_get_sound.return_value = mock_sound
mock_validate.side_effect = InsufficientCreditsError( mock_validate.side_effect = InsufficientCreditsError(
required=1, available=0 required=1,
available=0,
) )
response = await authenticated_client.post("/api/v1/sounds/play/1") response = await authenticated_client.post("/api/v1/sounds/play/1")
@@ -286,7 +296,7 @@ class TestSoundEndpoints:
} }
with patch( with patch(
"app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances" "app.services.vlc_player.VLCPlayerService.stop_all_vlc_instances",
) as mock_stop: ) as mock_stop:
mock_stop.return_value = mock_result mock_stop.return_value = mock_result

View File

@@ -57,7 +57,8 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_no_header( async def test_get_current_user_api_token_no_header(
self, mock_auth_service: AsyncMock, self,
mock_auth_service: AsyncMock,
) -> None: ) -> None:
"""Test API token authentication without API-TOKEN header.""" """Test API token authentication without API-TOKEN header."""
with pytest.raises(HTTPException) as exc_info: with pytest.raises(HTTPException) as exc_info:
@@ -68,7 +69,8 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_empty_token( async def test_get_current_user_api_token_empty_token(
self, mock_auth_service: AsyncMock, self,
mock_auth_service: AsyncMock,
) -> None: ) -> None:
"""Test API token authentication with empty token.""" """Test API token authentication with empty token."""
api_token_header = " " api_token_header = " "
@@ -81,7 +83,8 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_whitespace_token( async def test_get_current_user_api_token_whitespace_token(
self, mock_auth_service: AsyncMock, self,
mock_auth_service: AsyncMock,
) -> None: ) -> None:
"""Test API token authentication with whitespace-only token.""" """Test API token authentication with whitespace-only token."""
api_token_header = " " api_token_header = " "
@@ -94,7 +97,8 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_invalid_token( async def test_get_current_user_api_token_invalid_token(
self, mock_auth_service: AsyncMock, self,
mock_auth_service: AsyncMock,
) -> None: ) -> None:
"""Test API token authentication with invalid token.""" """Test API token authentication with invalid token."""
mock_auth_service.get_user_by_api_token.return_value = None mock_auth_service.get_user_by_api_token.return_value = None
@@ -146,7 +150,8 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_api_token_service_exception( async def test_get_current_user_api_token_service_exception(
self, mock_auth_service: AsyncMock, self,
mock_auth_service: AsyncMock,
) -> None: ) -> None:
"""Test API token authentication with service exception.""" """Test API token authentication with service exception."""
mock_auth_service.get_user_by_api_token.side_effect = Exception( mock_auth_service.get_user_by_api_token.side_effect = Exception(
@@ -186,7 +191,8 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_flexible_falls_back_to_jwt( async def test_get_current_user_flexible_falls_back_to_jwt(
self, mock_auth_service: AsyncMock, self,
mock_auth_service: AsyncMock,
) -> None: ) -> None:
"""Test flexible authentication falls back to JWT when no API token.""" """Test flexible authentication falls back to JWT when no API token."""
# Mock the get_current_user function (normally imported) # Mock the get_current_user function (normally imported)
@@ -197,7 +203,9 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_no_expiry_never_expires( async def test_api_token_no_expiry_never_expires(
self, mock_auth_service: AsyncMock, test_user: User, self,
mock_auth_service: AsyncMock,
test_user: User,
) -> None: ) -> None:
"""Test API token with no expiry date never expires.""" """Test API token with no expiry date never expires."""
test_user.api_token_expires_at = None test_user.api_token_expires_at = None
@@ -211,7 +219,9 @@ class TestApiTokenDependencies:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_token_with_whitespace( async def test_api_token_with_whitespace(
self, mock_auth_service: AsyncMock, test_user: User, self,
mock_auth_service: AsyncMock,
test_user: User,
) -> None: ) -> None:
"""Test API token with leading/trailing whitespace is handled correctly.""" """Test API token with leading/trailing whitespace is handled correctly."""
mock_auth_service.get_user_by_api_token.return_value = test_user mock_auth_service.get_user_by_api_token.return_value = test_user

View File

@@ -202,13 +202,17 @@ class TestCreditTransactionRepository:
"""Test getting transactions by user ID with pagination.""" """Test getting transactions by user ID with pagination."""
# Get first 2 transactions # Get first 2 transactions
first_page = await credit_transaction_repository.get_by_user_id( first_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=0, test_user_id,
limit=2,
offset=0,
) )
assert len(first_page) == PAGE_SIZE assert len(first_page) == PAGE_SIZE
# Get next 2 transactions # Get next 2 transactions
second_page = await credit_transaction_repository.get_by_user_id( second_page = await credit_transaction_repository.get_by_user_id(
test_user_id, limit=2, offset=2, test_user_id,
limit=2,
offset=2,
) )
assert len(second_page) == PAGE_SIZE assert len(second_page) == PAGE_SIZE
@@ -251,14 +255,17 @@ class TestCreditTransactionRepository:
"""Test getting transactions by action type with pagination.""" """Test getting transactions by action type with pagination."""
# Test with limit # Test with limit
transactions = await credit_transaction_repository.get_by_action_type( transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, "vlc_play_sound",
limit=1,
) )
assert len(transactions) == 1 assert len(transactions) == 1
assert transactions[0].action_type == "vlc_play_sound" assert transactions[0].action_type == "vlc_play_sound"
# Test with offset # Test with offset
transactions = await credit_transaction_repository.get_by_action_type( transactions = await credit_transaction_repository.get_by_action_type(
"vlc_play_sound", limit=1, offset=1, "vlc_play_sound",
limit=1,
offset=1,
) )
assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total assert len(transactions) <= 1 # Might be 0 if only 1 VLC transaction in total
@@ -269,7 +276,9 @@ class TestCreditTransactionRepository:
test_transactions: list[CreditTransaction], test_transactions: list[CreditTransaction],
) -> None: ) -> None:
"""Test getting only successful transactions.""" """Test getting only successful transactions."""
successful_transactions = await credit_transaction_repository.get_successful_transactions() successful_transactions = (
await credit_transaction_repository.get_successful_transactions()
)
# Should only return successful transactions # Should only return successful transactions
assert all(t.success is True for t in successful_transactions) assert all(t.success is True for t in successful_transactions)
@@ -285,8 +294,10 @@ class TestCreditTransactionRepository:
test_user_id: int, test_user_id: int,
) -> None: ) -> None:
"""Test getting successful transactions filtered by user.""" """Test getting successful transactions filtered by user."""
successful_transactions = await credit_transaction_repository.get_successful_transactions( successful_transactions = (
user_id=test_user_id, await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id,
)
) )
# Should only return successful transactions for test_user # Should only return successful transactions for test_user
@@ -305,14 +316,18 @@ class TestCreditTransactionRepository:
"""Test getting successful transactions with pagination.""" """Test getting successful transactions with pagination."""
# Get first 2 successful transactions # Get first 2 successful transactions
first_page = await credit_transaction_repository.get_successful_transactions( first_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=0, user_id=test_user_id,
limit=2,
offset=0,
) )
assert len(first_page) == PAGE_SIZE assert len(first_page) == PAGE_SIZE
assert all(t.success is True for t in first_page) assert all(t.success is True for t in first_page)
# Get next successful transaction # Get next successful transaction
second_page = await credit_transaction_repository.get_successful_transactions( second_page = await credit_transaction_repository.get_successful_transactions(
user_id=test_user_id, limit=2, offset=2, user_id=test_user_id,
limit=2,
offset=2,
) )
assert len(second_page) == 1 # Should be 1 remaining assert len(second_page) == 1 # Should be 1 remaining
assert all(t.success is True for t in second_page) assert all(t.success is True for t in second_page)
@@ -328,7 +343,9 @@ class TestCreditTransactionRepository:
all_transactions = await credit_transaction_repository.get_all() all_transactions = await credit_transaction_repository.get_all()
# Should return all transactions # Should return all transactions
assert len(all_transactions) >= MIN_ALL_TRANSACTIONS # 4 from test_transactions + 1 other_user_transaction assert (
len(all_transactions) >= MIN_ALL_TRANSACTIONS
) # 4 from test_transactions + 1 other_user_transaction
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_transaction( async def test_create_transaction(
@@ -374,7 +391,8 @@ class TestCreditTransactionRepository:
} }
updated_transaction = await credit_transaction_repository.update( updated_transaction = await credit_transaction_repository.update(
transaction, update_data, transaction,
update_data,
) )
assert updated_transaction.id == transaction.id assert updated_transaction.id == transaction.id
@@ -412,7 +430,9 @@ class TestCreditTransactionRepository:
# Verify transaction is deleted # Verify transaction is deleted
assert transaction_id is not None assert transaction_id is not None
deleted_transaction = await credit_transaction_repository.get_by_id(transaction_id) deleted_transaction = await credit_transaction_repository.get_by_id(
transaction_id,
)
assert deleted_transaction is None assert deleted_transaction is None
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -407,7 +407,8 @@ class TestPlaylistRepository:
# Test the repository method # Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist( playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id, playlist_id,
sound_id,
) )
assert playlist_sound.playlist_id == playlist_id assert playlist_sound.playlist_id == playlist_id
@@ -472,7 +473,9 @@ class TestPlaylistRepository:
# Test the repository method # Test the repository method
playlist_sound = await playlist_repository.add_sound_to_playlist( playlist_sound = await playlist_repository.add_sound_to_playlist(
playlist_id, sound_id, position=5, playlist_id,
sound_id,
position=5,
) )
assert playlist_sound.position == TEST_POSITION assert playlist_sound.position == TEST_POSITION
@@ -535,17 +538,20 @@ class TestPlaylistRepository:
# Verify it was added # Verify it was added
assert await playlist_repository.is_sound_in_playlist( assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id, playlist_id,
sound_id,
) )
# Remove the sound # Remove the sound
await playlist_repository.remove_sound_from_playlist( await playlist_repository.remove_sound_from_playlist(
playlist_id, sound_id, playlist_id,
sound_id,
) )
# Verify it was removed # Verify it was removed
assert not await playlist_repository.is_sound_in_playlist( assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id, playlist_id,
sound_id,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -732,7 +738,8 @@ class TestPlaylistRepository:
# Initially not in playlist # Initially not in playlist
assert not await playlist_repository.is_sound_in_playlist( assert not await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id, playlist_id,
sound_id,
) )
# Add sound # Add sound
@@ -740,7 +747,8 @@ class TestPlaylistRepository:
# Now in playlist # Now in playlist
assert await playlist_repository.is_sound_in_playlist( assert await playlist_repository.is_sound_in_playlist(
playlist_id, sound_id, playlist_id,
sound_id,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -794,16 +802,21 @@ class TestPlaylistRepository:
# Add sounds to playlist # Add sounds to playlist
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound1_id, position=0, playlist_id,
sound1_id,
position=0,
) )
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound2_id, position=1, playlist_id,
sound2_id,
position=1,
) )
# Reorder sounds - use different positions to avoid constraint issues # Reorder sounds - use different positions to avoid constraint issues
sound_positions = [(sound1_id, 10), (sound2_id, 5)] sound_positions = [(sound1_id, 10), (sound2_id, 5)]
await playlist_repository.reorder_playlist_sounds( await playlist_repository.reorder_playlist_sounds(
playlist_id, sound_positions, playlist_id,
sound_positions,
) )
# Verify new order # Verify new order
@@ -863,10 +876,14 @@ class TestPlaylistRepository:
# Add sounds to playlist at positions 0 and 1 # Add sounds to playlist at positions 0 and 1
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound1_id, position=0, playlist_id,
sound1_id,
position=0,
) )
await playlist_repository.add_sound_to_playlist( await playlist_repository.add_sound_to_playlist(
playlist_id, sound2_id, position=1, playlist_id,
sound2_id,
position=1,
) )
# Verify initial order # Verify initial order
@@ -878,7 +895,8 @@ class TestPlaylistRepository:
# Swap positions - this used to cause unique constraint violation # Swap positions - this used to cause unique constraint violation
sound_positions = [(sound1_id, 1), (sound2_id, 0)] sound_positions = [(sound1_id, 1), (sound2_id, 0)]
await playlist_repository.reorder_playlist_sounds( await playlist_repository.reorder_playlist_sounds(
playlist_id, sound_positions, playlist_id,
sound_positions,
) )
# Verify swapped order # Verify swapped order

View File

@@ -61,7 +61,8 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by provider user ID when it exists.""" """Test getting OAuth by provider user ID when it exists."""
oauth = await user_oauth_repository.get_by_provider_user_id( oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "google_123456", "google",
"google_123456",
) )
assert oauth is not None assert oauth is not None
@@ -77,7 +78,8 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by provider user ID when it doesn't exist.""" """Test getting OAuth by provider user ID when it doesn't exist."""
oauth = await user_oauth_repository.get_by_provider_user_id( oauth = await user_oauth_repository.get_by_provider_user_id(
"google", "nonexistent_id", "google",
"nonexistent_id",
) )
assert oauth is None assert oauth is None
@@ -91,7 +93,8 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by user ID and provider when it exists.""" """Test getting OAuth by user ID and provider when it exists."""
oauth = await user_oauth_repository.get_by_user_id_and_provider( oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google", test_user_id,
"google",
) )
assert oauth is not None assert oauth is not None
@@ -107,7 +110,8 @@ class TestUserOauthRepository:
) -> None: ) -> None:
"""Test getting OAuth by user ID and provider when it doesn't exist.""" """Test getting OAuth by user ID and provider when it doesn't exist."""
oauth = await user_oauth_repository.get_by_user_id_and_provider( oauth = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github", test_user_id,
"github",
) )
assert oauth is None assert oauth is None
@@ -186,7 +190,8 @@ class TestUserOauthRepository:
# Verify it's deleted by trying to find it # Verify it's deleted by trying to find it
deleted_oauth = await user_oauth_repository.get_by_provider_user_id( deleted_oauth = await user_oauth_repository.get_by_provider_user_id(
"twitter", "twitter_456", "twitter",
"twitter_456",
) )
assert deleted_oauth is None assert deleted_oauth is None
@@ -243,10 +248,12 @@ class TestUserOauthRepository:
# Verify both exist by querying back from database # Verify both exist by querying back from database
found_google = await user_oauth_repository.get_by_user_id_and_provider( found_google = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "google", test_user_id,
"google",
) )
found_github = await user_oauth_repository.get_by_user_id_and_provider( found_github = await user_oauth_repository.get_by_user_id_and_provider(
test_user_id, "github", test_user_id,
"github",
) )
assert found_google is not None assert found_google is not None
@@ -260,10 +267,12 @@ class TestUserOauthRepository:
# Verify we can also find them by provider_user_id # Verify we can also find them by provider_user_id
found_google_by_provider = await user_oauth_repository.get_by_provider_user_id( found_google_by_provider = await user_oauth_repository.get_by_provider_user_id(
"google", "google_user_1", "google",
"google_user_1",
) )
found_github_by_provider = await user_oauth_repository.get_by_provider_user_id( found_github_by_provider = await user_oauth_repository.get_by_provider_user_id(
"github", "github_user_1", "github",
"github_user_1",
) )
assert found_google_by_provider is not None assert found_google_by_provider is not None

View File

@@ -48,7 +48,9 @@ class TestCreditService:
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND) result = await credit_service.check_credits(
1, CreditActionType.VLC_PLAY_SOUND,
)
assert result is True assert result is True
mock_repo.get_by_id.assert_called_once_with(1) mock_repo.get_by_id.assert_called_once_with(1)
@@ -72,7 +74,9 @@ class TestCreditService:
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user mock_repo.get_by_id.return_value = poor_user
result = await credit_service.check_credits(1, CreditActionType.VLC_PLAY_SOUND) result = await credit_service.check_credits(
1, CreditActionType.VLC_PLAY_SOUND,
)
assert result is False assert result is False
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -87,13 +91,17 @@ class TestCreditService:
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = None mock_repo.get_by_id.return_value = None
result = await credit_service.check_credits(999, CreditActionType.VLC_PLAY_SOUND) result = await credit_service.check_credits(
999, CreditActionType.VLC_PLAY_SOUND,
)
assert result is False assert result is False
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_and_reserve_credits_success(self, credit_service, sample_user) -> None: async def test_validate_and_reserve_credits_success(
self, credit_service, sample_user,
) -> None:
"""Test successful credit validation and reservation.""" """Test successful credit validation and reservation."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
@@ -103,7 +111,8 @@ class TestCreditService:
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
user, action = await credit_service.validate_and_reserve_credits( user, action = await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND, 1,
CreditActionType.VLC_PLAY_SOUND,
) )
assert user == sample_user assert user == sample_user
@@ -112,7 +121,9 @@ class TestCreditService:
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_and_reserve_credits_insufficient(self, credit_service) -> None: async def test_validate_and_reserve_credits_insufficient(
self, credit_service,
) -> None:
"""Test credit validation with insufficient credits.""" """Test credit validation with insufficient credits."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
poor_user = User( poor_user = User(
@@ -131,7 +142,8 @@ class TestCreditService:
with pytest.raises(InsufficientCreditsError) as exc_info: with pytest.raises(InsufficientCreditsError) as exc_info:
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
1, CreditActionType.VLC_PLAY_SOUND, 1,
CreditActionType.VLC_PLAY_SOUND,
) )
assert exc_info.value.required == 1 assert exc_info.value.required == 1
@@ -139,7 +151,9 @@ class TestCreditService:
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_and_reserve_credits_user_not_found(self, credit_service) -> None: async def test_validate_and_reserve_credits_user_not_found(
self, credit_service,
) -> None:
"""Test credit validation when user is not found.""" """Test credit validation when user is not found."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
@@ -150,7 +164,8 @@ class TestCreditService:
with pytest.raises(ValueError, match="User 999 not found"): with pytest.raises(ValueError, match="User 999 not found"):
await credit_service.validate_and_reserve_credits( await credit_service.validate_and_reserve_credits(
999, CreditActionType.VLC_PLAY_SOUND, 999,
CreditActionType.VLC_PLAY_SOUND,
) )
mock_session.close.assert_called_once() mock_session.close.assert_called_once()
@@ -160,15 +175,20 @@ class TestCreditService:
"""Test successful credit deduction.""" """Test successful credit deduction."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with (
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.UserRepository") as mock_repo_class,
patch("app.services.credit.socket_manager") as mock_socket_manager,
):
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
await credit_service.deduct_credits( await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, success=True, metadata={"test": "data"}, 1,
CreditActionType.VLC_PLAY_SOUND,
success=True,
metadata={"test": "data"},
) )
# Verify user credits were updated # Verify user credits were updated
@@ -180,7 +200,9 @@ class TestCreditService:
# Verify socket event was emitted # Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with( mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", { "1",
"user_credits_changed",
{
"user_id": "1", "user_id": "1",
"credits_before": 10, "credits_before": 10,
"credits_after": 9, "credits_after": 9,
@@ -202,19 +224,25 @@ class TestCreditService:
assert json.loads(added_transaction.metadata_json) == {"test": "data"} assert json.loads(added_transaction.metadata_json) == {"test": "data"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deduct_credits_failed_action_requires_success(self, credit_service, sample_user) -> None: async def test_deduct_credits_failed_action_requires_success(
self, credit_service, sample_user,
) -> None:
"""Test credit deduction when action failed but requires success.""" """Test credit deduction when action failed but requires success."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with (
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.UserRepository") as mock_repo_class,
patch("app.services.credit.socket_manager") as mock_socket_manager,
):
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
await credit_service.deduct_credits( await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, success=False, # Action failed 1,
CreditActionType.VLC_PLAY_SOUND,
success=False, # Action failed
) )
# Verify user credits were NOT updated (action requires success) # Verify user credits were NOT updated (action requires success)
@@ -247,8 +275,10 @@ class TestCreditService:
plan_id=1, plan_id=1,
) )
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with (
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.UserRepository") as mock_repo_class,
patch("app.services.credit.socket_manager") as mock_socket_manager,
):
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = poor_user mock_repo.get_by_id.return_value = poor_user
@@ -256,7 +286,9 @@ class TestCreditService:
with pytest.raises(InsufficientCreditsError): with pytest.raises(InsufficientCreditsError):
await credit_service.deduct_credits( await credit_service.deduct_credits(
1, CreditActionType.VLC_PLAY_SOUND, success=True, 1,
CreditActionType.VLC_PLAY_SOUND,
success=True,
) )
# Verify no socket event was emitted since credits could not be deducted # Verify no socket event was emitted since credits could not be deducted
@@ -270,15 +302,20 @@ class TestCreditService:
"""Test adding credits to user account.""" """Test adding credits to user account."""
mock_session = credit_service.db_session_factory() mock_session = credit_service.db_session_factory()
with patch("app.services.credit.UserRepository") as mock_repo_class, \ with (
patch("app.services.credit.socket_manager") as mock_socket_manager: patch("app.services.credit.UserRepository") as mock_repo_class,
patch("app.services.credit.socket_manager") as mock_socket_manager,
):
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
mock_repo.get_by_id.return_value = sample_user mock_repo.get_by_id.return_value = sample_user
mock_socket_manager.send_to_user = AsyncMock() mock_socket_manager.send_to_user = AsyncMock()
await credit_service.add_credits( await credit_service.add_credits(
1, 5, "Bonus credits", {"reason": "signup"}, 1,
5,
"Bonus credits",
{"reason": "signup"},
) )
# Verify user credits were updated # Verify user credits were updated
@@ -290,7 +327,9 @@ class TestCreditService:
# Verify socket event was emitted # Verify socket event was emitted
mock_socket_manager.send_to_user.assert_called_once_with( mock_socket_manager.send_to_user.assert_called_once_with(
"1", "user_credits_changed", { "1",
"user_credits_changed",
{
"user_id": "1", "user_id": "1",
"credits_before": 10, "credits_before": 10,
"credits_after": 15, "credits_after": 15,

View File

@@ -53,7 +53,9 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_youtube( async def test_detect_service_info_youtube(
self, mock_ydl_class, extraction_service, self,
mock_ydl_class,
extraction_service,
) -> None: ) -> None:
"""Test service detection for YouTube.""" """Test service detection for YouTube."""
mock_ydl = Mock() mock_ydl = Mock()
@@ -78,7 +80,9 @@ class TestExtractionService:
@patch("app.services.extraction.yt_dlp.YoutubeDL") @patch("app.services.extraction.yt_dlp.YoutubeDL")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_service_info_failure( async def test_detect_service_info_failure(
self, mock_ydl_class, extraction_service, self,
mock_ydl_class,
extraction_service,
) -> None: ) -> None:
"""Test service detection failure.""" """Test service detection failure."""
mock_ydl = Mock() mock_ydl = Mock()
@@ -170,7 +174,9 @@ class TestExtractionService:
assert result["status"] == "pending" assert result["status"] == "pending"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_process_extraction_with_service_detection(self, extraction_service) -> None: async def test_process_extraction_with_service_detection(
self, extraction_service,
) -> None:
"""Test extraction processing with service detection.""" """Test extraction processing with service detection."""
extraction_id = 1 extraction_id = 1
@@ -202,14 +208,18 @@ class TestExtractionService:
with ( with (
patch.object( patch.object(
extraction_service, "_detect_service_info", return_value=service_info, extraction_service,
"_detect_service_info",
return_value=service_info,
), ),
patch.object(extraction_service, "_extract_media") as mock_extract, patch.object(extraction_service, "_extract_media") as mock_extract,
patch.object( patch.object(
extraction_service, "_move_files_to_final_location", extraction_service,
"_move_files_to_final_location",
) as mock_move, ) as mock_move,
patch.object( patch.object(
extraction_service, "_create_sound_record", extraction_service,
"_create_sound_record",
) as mock_create_sound, ) as mock_create_sound,
patch.object(extraction_service, "_normalize_sound"), patch.object(extraction_service, "_normalize_sound"),
patch.object(extraction_service, "_add_to_main_playlist"), patch.object(extraction_service, "_add_to_main_playlist"),
@@ -289,11 +299,13 @@ class TestExtractionService:
with ( with (
patch( patch(
"app.services.extraction.get_audio_duration", return_value=240000, "app.services.extraction.get_audio_duration",
return_value=240000,
), ),
patch("app.services.extraction.get_file_size", return_value=1024), patch("app.services.extraction.get_file_size", return_value=1024),
patch( patch(
"app.services.extraction.get_file_hash", return_value="test_hash", "app.services.extraction.get_file_hash",
return_value="test_hash",
), ),
): ):
extraction_service.sound_repo.create = AsyncMock( extraction_service.sound_repo.create = AsyncMock(

View File

@@ -29,7 +29,9 @@ class TestExtractionProcessor:
"""Test starting and stopping the processor.""" """Test starting and stopping the processor."""
# Mock the _process_queue method to avoid actual processing # Mock the _process_queue method to avoid actual processing
with patch.object( with patch.object(
processor, "_process_queue", new_callable=AsyncMock, processor,
"_process_queue",
new_callable=AsyncMock,
): ):
# Start the processor # Start the processor
await processor.start() await processor.start()
@@ -229,7 +231,9 @@ class TestExtractionProcessor:
"app.services.extraction_processor.AsyncSession", "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch.object( patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock, processor,
"_process_single_extraction",
new_callable=AsyncMock,
), ),
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",
@@ -274,7 +278,9 @@ class TestExtractionProcessor:
"app.services.extraction_processor.AsyncSession", "app.services.extraction_processor.AsyncSession",
) as mock_session_class, ) as mock_session_class,
patch.object( patch.object(
processor, "_process_single_extraction", new_callable=AsyncMock, processor,
"_process_single_extraction",
new_callable=AsyncMock,
), ),
patch( patch(
"app.services.extraction_processor.ExtractionService", "app.services.extraction_processor.ExtractionService",

View File

@@ -131,11 +131,15 @@ class TestPlayerService:
yield mock yield mock
@pytest.fixture @pytest.fixture
def player_service(self, mock_db_session_factory, mock_vlc_instance, mock_socket_manager): def player_service(
self, mock_db_session_factory, mock_vlc_instance, mock_socket_manager,
):
"""Create a player service instance for testing.""" """Create a player service instance for testing."""
return PlayerService(mock_db_session_factory) return PlayerService(mock_db_session_factory)
def test_init_creates_player_service(self, mock_db_session_factory, mock_vlc_instance) -> None: def test_init_creates_player_service(
self, mock_db_session_factory, mock_vlc_instance,
) -> None:
"""Test that player service initializes correctly.""" """Test that player service initializes correctly."""
with patch("app.services.player.socket_manager"): with patch("app.services.player.socket_manager"):
service = PlayerService(mock_db_session_factory) service = PlayerService(mock_db_session_factory)
@@ -152,7 +156,9 @@ class TestPlayerService:
assert service._loop is None assert service._loop is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_start_initializes_service(self, player_service, mock_vlc_instance) -> None: async def test_start_initializes_service(
self, player_service, mock_vlc_instance,
) -> None:
"""Test that start method initializes the service.""" """Test that start method initializes the service."""
with patch.object(player_service, "reload_playlist", new_callable=AsyncMock): with patch.object(player_service, "reload_playlist", new_callable=AsyncMock):
await player_service.start() await player_service.start()
@@ -197,7 +203,9 @@ class TestPlayerService:
mock_file_path.exists.return_value = True mock_file_path.exists.return_value = True
mock_path.return_value = mock_file_path mock_path.return_value = mock_file_path
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock): with patch.object(
player_service, "_broadcast_state", new_callable=AsyncMock,
):
mock_media = Mock() mock_media = Mock()
player_service._vlc_instance.media_new.return_value = mock_media player_service._vlc_instance.media_new.return_value = mock_media
player_service._player.play.return_value = 0 # Success player_service._player.play.return_value = 0 # Success
@@ -252,7 +260,9 @@ class TestPlayerService:
"""Test pausing when not playing does nothing.""" """Test pausing when not playing does nothing."""
player_service.state.status = PlayerStatus.STOPPED player_service.state.status = PlayerStatus.STOPPED
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock) as mock_broadcast: with patch.object(
player_service, "_broadcast_state", new_callable=AsyncMock,
) as mock_broadcast:
await player_service.pause() await player_service.pause()
assert player_service.state.status == PlayerStatus.STOPPED assert player_service.state.status == PlayerStatus.STOPPED
@@ -264,8 +274,12 @@ class TestPlayerService:
player_service.state.status = PlayerStatus.PLAYING player_service.state.status = PlayerStatus.PLAYING
player_service.state.current_sound_position = 5000 player_service.state.current_sound_position = 5000
with patch.object(player_service, "_process_play_count", new_callable=AsyncMock): with patch.object(
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock): player_service, "_process_play_count", new_callable=AsyncMock,
):
with patch.object(
player_service, "_broadcast_state", new_callable=AsyncMock,
):
await player_service.stop_playback() await player_service.stop_playback()
assert player_service.state.status == PlayerStatus.STOPPED assert player_service.state.status == PlayerStatus.STOPPED
@@ -314,7 +328,9 @@ class TestPlayerService:
"""Test seeking when stopped does nothing.""" """Test seeking when stopped does nothing."""
player_service.state.status = PlayerStatus.STOPPED player_service.state.status = PlayerStatus.STOPPED
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock) as mock_broadcast: with patch.object(
player_service, "_broadcast_state", new_callable=AsyncMock,
) as mock_broadcast:
await player_service.seek(15000) await player_service.seek(15000)
player_service._player.set_position.assert_not_called() player_service._player.set_position.assert_not_called()
@@ -364,7 +380,9 @@ class TestPlayerService:
mock_playlist = Mock() mock_playlist = Mock()
mock_playlist.id = 1 mock_playlist.id = 1
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
mock_repo.get_current_playlist.return_value = mock_playlist # Return current playlist directly mock_repo.get_current_playlist.return_value = (
mock_playlist # Return current playlist directly
)
# Mock sounds # Mock sounds
sound1 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000) sound1 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000)
@@ -372,7 +390,9 @@ class TestPlayerService:
mock_sounds = [sound1, sound2] mock_sounds = [sound1, sound2]
mock_repo.get_playlist_sounds.return_value = mock_sounds mock_repo.get_playlist_sounds.return_value = mock_sounds
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock): with patch.object(
player_service, "_broadcast_state", new_callable=AsyncMock,
):
await player_service.reload_playlist() await player_service.reload_playlist()
assert player_service.state.playlist_id == 1 assert player_service.state.playlist_id == 1
@@ -394,7 +414,9 @@ class TestPlayerService:
sound2 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000) sound2 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000)
sounds = [sound1, sound2] sounds = [sound1, sound2]
with patch.object(player_service, "_stop_playback", new_callable=AsyncMock) as mock_stop: with patch.object(
player_service, "_stop_playback", new_callable=AsyncMock,
) as mock_stop:
await player_service._handle_playlist_id_changed(1, 2, sounds) await player_service._handle_playlist_id_changed(1, 2, sounds)
# Should stop playback and set first track as current # Should stop playback and set first track as current
@@ -404,11 +426,15 @@ class TestPlayerService:
assert player_service.state.current_sound_id == 1 assert player_service.state.current_sound_id == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_playlist_id_changed_empty_playlist(self, player_service) -> None: async def test_handle_playlist_id_changed_empty_playlist(
self, player_service,
) -> None:
"""Test handling playlist ID change with empty playlist.""" """Test handling playlist ID change with empty playlist."""
player_service.state.status = PlayerStatus.PLAYING player_service.state.status = PlayerStatus.PLAYING
with patch.object(player_service, "_stop_playback", new_callable=AsyncMock) as mock_stop: with patch.object(
player_service, "_stop_playback", new_callable=AsyncMock,
) as mock_stop:
await player_service._handle_playlist_id_changed(1, 2, []) await player_service._handle_playlist_id_changed(1, 2, [])
mock_stop.assert_called_once() mock_stop.assert_called_once()
@@ -417,7 +443,9 @@ class TestPlayerService:
assert player_service.state.current_sound_id is None assert player_service.state.current_sound_id is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_same_playlist_track_exists_same_index(self, player_service) -> None: async def test_handle_same_playlist_track_exists_same_index(
self, player_service,
) -> None:
"""Test handling same playlist when track exists at same index.""" """Test handling same playlist when track exists at same index."""
sound1 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000) sound1 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000)
sound2 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000) sound2 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000)
@@ -426,11 +454,15 @@ class TestPlayerService:
await player_service._handle_same_playlist_track_check(1, 0, sounds) await player_service._handle_same_playlist_track_check(1, 0, sounds)
# Should update sound object reference but keep same index # Should update sound object reference but keep same index
assert player_service.state.current_sound_index == 0 # Should be set to 0 from new_index assert (
player_service.state.current_sound_index == 0
) # Should be set to 0 from new_index
assert player_service.state.current_sound == sound1 assert player_service.state.current_sound == sound1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_same_playlist_track_exists_different_index(self, player_service) -> None: async def test_handle_same_playlist_track_exists_different_index(
self, player_service,
) -> None:
"""Test handling same playlist when track exists at different index.""" """Test handling same playlist when track exists at different index."""
sound1 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000) sound1 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000)
sound2 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000) sound2 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000)
@@ -450,7 +482,9 @@ class TestPlayerService:
sound2 = Sound(id=3, name="Song 3", filename="song3.mp3", duration=60000) sound2 = Sound(id=3, name="Song 3", filename="song3.mp3", duration=60000)
sounds = [sound1, sound2] # Track with ID 1 is missing sounds = [sound1, sound2] # Track with ID 1 is missing
with patch.object(player_service, "_handle_track_removed", new_callable=AsyncMock) as mock_removed: with patch.object(
player_service, "_handle_track_removed", new_callable=AsyncMock,
) as mock_removed:
await player_service._handle_same_playlist_track_check(1, 0, sounds) await player_service._handle_same_playlist_track_check(1, 0, sounds)
mock_removed.assert_called_once_with(1, sounds) mock_removed.assert_called_once_with(1, sounds)
@@ -461,7 +495,9 @@ class TestPlayerService:
sound1 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000) sound1 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000)
sounds = [sound1] sounds = [sound1]
with patch.object(player_service, "_stop_playback", new_callable=AsyncMock) as mock_stop: with patch.object(
player_service, "_stop_playback", new_callable=AsyncMock,
) as mock_stop:
await player_service._handle_track_removed(1, sounds) await player_service._handle_track_removed(1, sounds)
mock_stop.assert_called_once() mock_stop.assert_called_once()
@@ -474,7 +510,9 @@ class TestPlayerService:
"""Test handling when current track is removed with empty playlist.""" """Test handling when current track is removed with empty playlist."""
player_service.state.status = PlayerStatus.PLAYING player_service.state.status = PlayerStatus.PLAYING
with patch.object(player_service, "_stop_playback", new_callable=AsyncMock) as mock_stop: with patch.object(
player_service, "_stop_playback", new_callable=AsyncMock,
) as mock_stop:
await player_service._handle_track_removed(1, []) await player_service._handle_track_removed(1, [])
mock_stop.assert_called_once() mock_stop.assert_called_once()
@@ -562,14 +600,20 @@ class TestPlayerService:
mock_playlist = Mock() mock_playlist = Mock()
mock_playlist.id = 2 # Different ID mock_playlist.id = 2 # Different ID
mock_playlist.name = "New Playlist" mock_playlist.name = "New Playlist"
mock_repo.get_current_playlist.return_value = mock_playlist # Return current playlist directly mock_repo.get_current_playlist.return_value = (
mock_playlist # Return current playlist directly
)
sound1 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000) sound1 = Sound(id=1, name="Song 1", filename="song1.mp3", duration=30000)
mock_sounds = [sound1] mock_sounds = [sound1]
mock_repo.get_playlist_sounds.return_value = mock_sounds mock_repo.get_playlist_sounds.return_value = mock_sounds
with patch.object(player_service, "_stop_playback", new_callable=AsyncMock) as mock_stop: with patch.object(
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock): player_service, "_stop_playback", new_callable=AsyncMock,
) as mock_stop:
with patch.object(
player_service, "_broadcast_state", new_callable=AsyncMock,
):
await player_service.reload_playlist() await player_service.reload_playlist()
# Should stop and reset to first track # Should stop and reset to first track
@@ -597,7 +641,9 @@ class TestPlayerService:
mock_playlist = Mock() mock_playlist = Mock()
mock_playlist.id = 1 mock_playlist.id = 1
mock_playlist.name = "Same Playlist" mock_playlist.name = "Same Playlist"
mock_repo.get_current_playlist.return_value = mock_playlist # Return current playlist directly mock_repo.get_current_playlist.return_value = (
mock_playlist # Return current playlist directly
)
# Track 2 moved to index 0 # Track 2 moved to index 0
sound1 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000) sound1 = Sound(id=2, name="Song 2", filename="song2.mp3", duration=45000)
@@ -605,7 +651,9 @@ class TestPlayerService:
mock_sounds = [sound1, sound2] # Track 2 now at index 0 mock_sounds = [sound1, sound2] # Track 2 now at index 0
mock_repo.get_playlist_sounds.return_value = mock_sounds mock_repo.get_playlist_sounds.return_value = mock_sounds
with patch.object(player_service, "_broadcast_state", new_callable=AsyncMock): with patch.object(
player_service, "_broadcast_state", new_callable=AsyncMock,
):
await player_service.reload_playlist() await player_service.reload_playlist()
# Should update index but keep same track # Should update index but keep same track
@@ -614,7 +662,6 @@ class TestPlayerService:
assert player_service.state.current_sound_id == 2 # Same track assert player_service.state.current_sound_id == 2 # Same track
assert player_service.state.current_sound == sound1 assert player_service.state.current_sound == sound1
def test_get_next_index_continuous_mode(self, player_service) -> None: def test_get_next_index_continuous_mode(self, player_service) -> None:
"""Test getting next index in continuous mode.""" """Test getting next index in continuous mode."""
player_service.state.mode = PlayerMode.CONTINUOUS player_service.state.mode = PlayerMode.CONTINUOUS
@@ -734,7 +781,8 @@ class TestPlayerService:
# Verify sound play count was updated # Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
mock_sound, {"play_count": 6}, mock_sound,
{"play_count": 6},
) )
# Verify SoundPlayed record was created with None user_id for player # Verify SoundPlayed record was created with None user_id for player

View File

@@ -98,7 +98,11 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_success( async def test_connect_handler_success(
self, mock_decode, mock_extract_token, socket_manager, mock_sio, self,
mock_decode,
mock_extract_token,
socket_manager,
mock_sio,
) -> None: ) -> None:
"""Test successful connection with valid token.""" """Test successful connection with valid token."""
# Setup mocks # Setup mocks
@@ -132,7 +136,10 @@ class TestSocketManager:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
async def test_connect_handler_no_token( async def test_connect_handler_no_token(
self, mock_extract_token, socket_manager, mock_sio, self,
mock_extract_token,
socket_manager,
mock_sio,
) -> None: ) -> None:
"""Test connection with no access token.""" """Test connection with no access token."""
# Setup mocks # Setup mocks
@@ -165,7 +172,11 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_invalid_token( async def test_connect_handler_invalid_token(
self, mock_decode, mock_extract_token, socket_manager, mock_sio, self,
mock_decode,
mock_extract_token,
socket_manager,
mock_sio,
) -> None: ) -> None:
"""Test connection with invalid token.""" """Test connection with invalid token."""
# Setup mocks # Setup mocks
@@ -199,7 +210,11 @@ class TestSocketManager:
@patch("app.services.socket.extract_access_token_from_cookies") @patch("app.services.socket.extract_access_token_from_cookies")
@patch("app.services.socket.JWTUtils.decode_access_token") @patch("app.services.socket.JWTUtils.decode_access_token")
async def test_connect_handler_missing_user_id( async def test_connect_handler_missing_user_id(
self, mock_decode, mock_extract_token, socket_manager, mock_sio, self,
mock_decode,
mock_extract_token,
socket_manager,
mock_sio,
) -> None: ) -> None:
"""Test connection with token missing user ID.""" """Test connection with token missing user ID."""
# Setup mocks # Setup mocks
@@ -254,7 +269,9 @@ class TestSocketManager:
assert "123" not in socket_manager.user_rooms assert "123" not in socket_manager.user_rooms
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_disconnect_handler_unknown_socket(self, socket_manager, mock_sio) -> None: async def test_disconnect_handler_unknown_socket(
self, socket_manager, mock_sio,
) -> None:
"""Test disconnect handler with unknown socket.""" """Test disconnect handler with unknown socket."""
# Access the disconnect handler directly # Access the disconnect handler directly
handlers = {} handlers = {}

View File

@@ -154,7 +154,9 @@ class TestSoundNormalizerService:
assert result["id"] == 1 assert result["id"] == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_normalize_sound_force_already_normalized(self, normalizer_service) -> None: async def test_normalize_sound_force_already_normalized(
self, normalizer_service,
) -> None:
"""Test force normalizing a sound that's already normalized.""" """Test force normalizing a sound that's already normalized."""
sound = Sound( sound = Sound(
id=1, id=1,
@@ -172,14 +174,17 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path, patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path, patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object( patch.object(
normalizer_service, "_normalize_audio_two_pass", normalizer_service,
"_normalize_audio_two_pass",
), ),
patch( patch(
"app.services.sound_normalizer.get_audio_duration", return_value=6000, "app.services.sound_normalizer.get_audio_duration",
return_value=6000,
), ),
patch("app.services.sound_normalizer.get_file_size", return_value=2048), patch("app.services.sound_normalizer.get_file_size", return_value=2048),
patch( patch(
"app.services.sound_normalizer.get_file_hash", return_value="new_hash", "app.services.sound_normalizer.get_file_hash",
return_value="new_hash",
), ),
): ):
# Setup path mocks # Setup path mocks
@@ -245,14 +250,17 @@ class TestSoundNormalizerService:
patch.object(normalizer_service, "_get_original_path") as mock_orig_path, patch.object(normalizer_service, "_get_original_path") as mock_orig_path,
patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path, patch.object(normalizer_service, "_get_normalized_path") as mock_norm_path,
patch.object( patch.object(
normalizer_service, "_normalize_audio_one_pass", normalizer_service,
"_normalize_audio_one_pass",
) as mock_normalize, ) as mock_normalize,
patch( patch(
"app.services.sound_normalizer.get_audio_duration", return_value=5500, "app.services.sound_normalizer.get_audio_duration",
return_value=5500,
), ),
patch("app.services.sound_normalizer.get_file_size", return_value=1500), patch("app.services.sound_normalizer.get_file_size", return_value=1500),
patch( patch(
"app.services.sound_normalizer.get_file_hash", return_value="norm_hash", "app.services.sound_normalizer.get_file_hash",
return_value="norm_hash",
), ),
): ):
# Setup path mocks # Setup path mocks
@@ -275,7 +283,9 @@ class TestSoundNormalizerService:
mock_normalize.assert_called_once() mock_normalize.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_normalize_sound_normalization_error(self, normalizer_service) -> None: async def test_normalize_sound_normalization_error(
self, normalizer_service,
) -> None:
"""Test handling normalization errors.""" """Test handling normalization errors."""
sound = Sound( sound = Sound(
id=1, id=1,
@@ -300,7 +310,8 @@ class TestSoundNormalizerService:
with ( with (
patch("pathlib.Path.exists", return_value=True), patch("pathlib.Path.exists", return_value=True),
patch.object( patch.object(
normalizer_service, "_normalize_audio_two_pass", normalizer_service,
"_normalize_audio_two_pass",
) as mock_normalize, ) as mock_normalize,
): ):
mock_normalize.side_effect = Exception("Normalization failed") mock_normalize.side_effect = Exception("Normalization failed")
@@ -529,7 +540,11 @@ class TestSoundNormalizerService:
# Verify ffmpeg chain was called correctly # Verify ffmpeg chain was called correctly
mock_ffmpeg.input.assert_called_once_with(str(input_path)) mock_ffmpeg.input.assert_called_once_with(str(input_path))
mock_ffmpeg.filter.assert_called_once_with( mock_ffmpeg.filter.assert_called_once_with(
mock_stream, "loudnorm", I=-23, TP=-2, LRA=7, mock_stream,
"loudnorm",
I=-23,
TP=-2,
LRA=7,
) )
mock_ffmpeg.output.assert_called_once() mock_ffmpeg.output.assert_called_once()
mock_ffmpeg.run.assert_called_once() mock_ffmpeg.run.assert_called_once()

View File

@@ -153,7 +153,10 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results, temp_path,
"SDB",
existing_sound,
results,
) )
assert results["skipped"] == 1 assert results["skipped"] == 1
@@ -257,7 +260,10 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "SDB", existing_sound, results, temp_path,
"SDB",
existing_sound,
results,
) )
assert results["updated"] == 1 assert results["updated"] == 1
@@ -296,7 +302,8 @@ class TestSoundScannerService:
# Mock file operations # Mock file operations
with ( with (
patch( patch(
"app.services.sound_scanner.get_file_hash", return_value="custom_hash", "app.services.sound_scanner.get_file_hash",
return_value="custom_hash",
), ),
patch("app.services.sound_scanner.get_audio_duration", return_value=60000), patch("app.services.sound_scanner.get_audio_duration", return_value=60000),
patch("app.services.sound_scanner.get_file_size", return_value=2048), patch("app.services.sound_scanner.get_file_size", return_value=2048),
@@ -316,7 +323,10 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( await scanner_service._sync_audio_file(
temp_path, "CUSTOM", None, results, temp_path,
"CUSTOM",
None,
results,
) )
assert results["added"] == 1 assert results["added"] == 1

View File

@@ -80,6 +80,7 @@ class TestVLCPlayerService:
# Mock Path to return True for the first absolute path # Mock Path to return True for the first absolute path
with patch("app.services.vlc_player.Path") as mock_path: with patch("app.services.vlc_player.Path") as mock_path:
def path_side_effect(path_str): def path_side_effect(path_str):
mock_instance = Mock() mock_instance = Mock()
mock_instance.exists.return_value = str(path_str) == "/usr/bin/vlc" mock_instance.exists.return_value = str(path_str) == "/usr/bin/vlc"
@@ -105,11 +106,13 @@ class TestVLCPlayerService:
service = VLCPlayerService() service = VLCPlayerService()
assert service.vlc_executable == "vlc" assert service.vlc_executable == "vlc"
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_success( async def test_play_sound_success(
self, mock_subprocess, vlc_service, sample_sound, self,
mock_subprocess,
vlc_service,
sample_sound,
) -> None: ) -> None:
"""Test successful sound playback.""" """Test successful sound playback."""
# Mock subprocess # Mock subprocess
@@ -142,7 +145,9 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_play_sound_file_not_found( async def test_play_sound_file_not_found(
self, vlc_service, sample_sound, self,
vlc_service,
sample_sound,
) -> None: ) -> None:
"""Test sound playback when file doesn't exist.""" """Test sound playback when file doesn't exist."""
# Mock the file path utility to return a non-existent path # Mock the file path utility to return a non-existent path
@@ -158,7 +163,10 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_subprocess_error( async def test_play_sound_subprocess_error(
self, mock_subprocess, vlc_service, sample_sound, self,
mock_subprocess,
vlc_service,
sample_sound,
) -> None: ) -> None:
"""Test sound playback when subprocess fails.""" """Test sound playback when subprocess fails."""
# Mock the file path utility to return an existing path # Mock the file path utility to return an existing path
@@ -176,7 +184,9 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_success(self, mock_subprocess, vlc_service) -> None: async def test_stop_all_vlc_instances_success(
self, mock_subprocess, vlc_service,
) -> None:
"""Test successful stopping of all VLC instances.""" """Test successful stopping of all VLC instances."""
# Mock pgrep process (find VLC processes) # Mock pgrep process (find VLC processes)
mock_find_process = Mock() mock_find_process = Mock()
@@ -212,7 +222,9 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_no_processes( async def test_stop_all_vlc_instances_no_processes(
self, mock_subprocess, vlc_service, self,
mock_subprocess,
vlc_service,
) -> None: ) -> None:
"""Test stopping VLC instances when none are running.""" """Test stopping VLC instances when none are running."""
# Mock pgrep process (no VLC processes found) # Mock pgrep process (no VLC processes found)
@@ -232,7 +244,9 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_partial_kill( async def test_stop_all_vlc_instances_partial_kill(
self, mock_subprocess, vlc_service, self,
mock_subprocess,
vlc_service,
) -> None: ) -> None:
"""Test stopping VLC instances when some processes remain.""" """Test stopping VLC instances when some processes remain."""
# Mock pgrep process (find VLC processes) # Mock pgrep process (find VLC processes)
@@ -266,7 +280,9 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_stop_all_vlc_instances_error(self, mock_subprocess, vlc_service) -> None: async def test_stop_all_vlc_instances_error(
self, mock_subprocess, vlc_service,
) -> None:
"""Test stopping VLC instances when an error occurs.""" """Test stopping VLC instances when an error occurs."""
# Mock subprocess exception # Mock subprocess exception
mock_subprocess.side_effect = Exception("Command failed") mock_subprocess.side_effect = Exception("Command failed")
@@ -287,6 +303,7 @@ class TestVLCPlayerService:
# Clear the global instance # Clear the global instance
import app.services.vlc_player import app.services.vlc_player
app.services.vlc_player.vlc_player_service = None app.services.vlc_player.vlc_player_service = None
# First call should create new instance # First call should create new instance
@@ -304,7 +321,10 @@ class TestVLCPlayerService:
@pytest.mark.asyncio @pytest.mark.asyncio
@patch("app.services.vlc_player.asyncio.create_subprocess_exec") @patch("app.services.vlc_player.asyncio.create_subprocess_exec")
async def test_play_sound_with_play_count_tracking( async def test_play_sound_with_play_count_tracking(
self, mock_subprocess, vlc_service_with_db, sample_sound, self,
mock_subprocess,
vlc_service_with_db,
sample_sound,
) -> None: ) -> None:
"""Test sound playback with play count tracking.""" """Test sound playback with play count tracking."""
# Mock subprocess # Mock subprocess
@@ -320,11 +340,17 @@ class TestVLCPlayerService:
mock_sound_repo = AsyncMock() mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock() mock_user_repo = AsyncMock()
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo): with patch(
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo): "app.services.vlc_player.SoundRepository", return_value=mock_sound_repo,
):
with patch(
"app.services.vlc_player.UserRepository", return_value=mock_user_repo,
):
with patch("app.services.vlc_player.socket_manager") as mock_socket: with patch("app.services.vlc_player.socket_manager") as mock_socket:
# Mock the file path utility # Mock the file path utility
with patch("app.services.vlc_player.get_sound_file_path") as mock_get_path: with patch(
"app.services.vlc_player.get_sound_file_path",
) as mock_get_path:
mock_path = Mock() mock_path = Mock()
mock_path.exists.return_value = True mock_path.exists.return_value = True
mock_get_path.return_value = mock_path mock_get_path.return_value = mock_path
@@ -397,8 +423,12 @@ class TestVLCPlayerService:
role="admin", role="admin",
) )
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo): with patch(
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo): "app.services.vlc_player.SoundRepository", return_value=mock_sound_repo,
):
with patch(
"app.services.vlc_player.UserRepository", return_value=mock_user_repo,
):
with patch("app.services.vlc_player.socket_manager") as mock_socket: with patch("app.services.vlc_player.socket_manager") as mock_socket:
# Setup mocks # Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound mock_sound_repo.get_by_id.return_value = test_sound
@@ -412,7 +442,8 @@ class TestVLCPlayerService:
# Verify sound repository calls # Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1) mock_sound_repo.get_by_id.assert_called_once_with(1)
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 1}, test_sound,
{"play_count": 1},
) )
# Verify user repository calls # Verify user repository calls
@@ -442,7 +473,9 @@ class TestVLCPlayerService:
# The method should return early without doing anything # The method should return early without doing anything
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_record_play_count_always_creates_record(self, vlc_service_with_db) -> None: async def test_record_play_count_always_creates_record(
self, vlc_service_with_db,
) -> None:
"""Test play count recording always creates a new SoundPlayed record.""" """Test play count recording always creates a new SoundPlayed record."""
# Mock session and repositories # Mock session and repositories
mock_session = AsyncMock() mock_session = AsyncMock()
@@ -469,28 +502,33 @@ class TestVLCPlayerService:
role="admin", role="admin",
) )
with patch("app.services.vlc_player.SoundRepository", return_value=mock_sound_repo): with patch(
with patch("app.services.vlc_player.UserRepository", return_value=mock_user_repo): "app.services.vlc_player.SoundRepository", return_value=mock_sound_repo,
):
with patch(
"app.services.vlc_player.UserRepository", return_value=mock_user_repo,
):
with patch("app.services.vlc_player.socket_manager") as mock_socket: with patch("app.services.vlc_player.socket_manager") as mock_socket:
# Setup mocks # Setup mocks
mock_sound_repo.get_by_id.return_value = test_sound mock_sound_repo.get_by_id.return_value = test_sound
mock_user_repo.get_by_id.return_value = admin_user mock_user_repo.get_by_id.return_value = admin_user
# Mock socket broadcast # Mock socket broadcast
mock_socket.broadcast_to_all = AsyncMock() mock_socket.broadcast_to_all = AsyncMock()
await vlc_service_with_db._record_play_count(1, "Test Sound") await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound play count was updated # Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with( mock_sound_repo.update.assert_called_once_with(
test_sound, {"play_count": 6}, test_sound,
) {"play_count": 6},
)
# Verify new SoundPlayed record was always added # Verify new SoundPlayed record was always added
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
# Verify commit happened # Verify commit happened
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
def test_uses_shared_sound_path_utility(self, vlc_service, sample_sound) -> None: def test_uses_shared_sound_path_utility(self, vlc_service, sample_sound) -> None:
"""Test that VLC service uses the shared sound path utility.""" """Test that VLC service uses the shared sound path utility."""

View File

@@ -19,8 +19,8 @@ from app.utils.audio import (
SHA256_HASH_LENGTH = 64 SHA256_HASH_LENGTH = 64
BINARY_FILE_SIZE = 700 BINARY_FILE_SIZE = 700
EXPECTED_DURATION_MS_1 = 123456 # 123.456 seconds * 1000 EXPECTED_DURATION_MS_1 = 123456 # 123.456 seconds * 1000
EXPECTED_DURATION_MS_2 = 60000 # 60 seconds * 1000 EXPECTED_DURATION_MS_2 = 60000 # 60 seconds * 1000
EXPECTED_DURATION_MS_3 = 45123 # 45.123 seconds * 1000 EXPECTED_DURATION_MS_3 = 45123 # 45.123 seconds * 1000
class TestAudioUtils: class TestAudioUtils:
@@ -220,7 +220,8 @@ class TestAudioUtils:
@patch("app.utils.audio.ffmpeg.probe") @patch("app.utils.audio.ffmpeg.probe")
def test_get_audio_duration_fractional_duration( def test_get_audio_duration_fractional_duration(
self, mock_probe: MagicMock, self,
mock_probe: MagicMock,
) -> None: ) -> None:
"""Test audio duration extraction with fractional seconds.""" """Test audio duration extraction with fractional seconds."""
# Mock ffmpeg.probe to return fractional duration # Mock ffmpeg.probe to return fractional duration

View File

@@ -27,13 +27,17 @@ class TestRequiresCreditsDecorator:
return service return service
@pytest.fixture @pytest.fixture
def credit_service_factory(self, mock_credit_service: AsyncMock) -> Callable[[], AsyncMock]: def credit_service_factory(
self, mock_credit_service: AsyncMock,
) -> Callable[[], AsyncMock]:
"""Create a credit service factory.""" """Create a credit service factory."""
return lambda: mock_credit_service return lambda: mock_credit_service
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decorator_success( async def test_decorator_success(
self, credit_service_factory: Callable[[], AsyncMock], mock_credit_service: AsyncMock, self,
credit_service_factory: Callable[[], AsyncMock],
mock_credit_service: AsyncMock,
) -> None: ) -> None:
"""Test decorator with successful action.""" """Test decorator with successful action."""
@@ -49,15 +53,21 @@ class TestRequiresCreditsDecorator:
assert result == "Success: test" assert result == "Success: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, 123,
CreditActionType.VLC_PLAY_SOUND,
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, success=True, metadata=None, 123,
CreditActionType.VLC_PLAY_SOUND,
success=True,
metadata=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decorator_with_metadata( async def test_decorator_with_metadata(
self, credit_service_factory: Callable[[], AsyncMock], mock_credit_service: AsyncMock, self,
credit_service_factory: Callable[[], AsyncMock],
mock_credit_service: AsyncMock,
) -> None: ) -> None:
"""Test decorator with metadata extraction.""" """Test decorator with metadata extraction."""
@@ -76,14 +86,20 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123, sound_name="test.mp3") await test_action(user_id=123, sound_name="test.mp3")
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, 123,
CreditActionType.VLC_PLAY_SOUND,
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, success=True, metadata={"sound_name": "test.mp3"}, 123,
CreditActionType.VLC_PLAY_SOUND,
success=True,
metadata={"sound_name": "test.mp3"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decorator_failed_action(self, credit_service_factory, mock_credit_service) -> None: async def test_decorator_failed_action(
self, credit_service_factory, mock_credit_service,
) -> None:
"""Test decorator with failed action.""" """Test decorator with failed action."""
@requires_credits( @requires_credits(
@@ -98,11 +114,16 @@ class TestRequiresCreditsDecorator:
assert result is False assert result is False
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, success=False, metadata=None, 123,
CreditActionType.VLC_PLAY_SOUND,
success=False,
metadata=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decorator_exception_in_action(self, credit_service_factory, mock_credit_service) -> None: async def test_decorator_exception_in_action(
self, credit_service_factory, mock_credit_service,
) -> None:
"""Test decorator when action raises exception.""" """Test decorator when action raises exception."""
@requires_credits( @requires_credits(
@@ -118,13 +139,20 @@ class TestRequiresCreditsDecorator:
await test_action(user_id=123) await test_action(user_id=123)
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, success=False, metadata=None, 123,
CreditActionType.VLC_PLAY_SOUND,
success=False,
metadata=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decorator_insufficient_credits(self, credit_service_factory, mock_credit_service) -> None: async def test_decorator_insufficient_credits(
self, credit_service_factory, mock_credit_service,
) -> None:
"""Test decorator with insufficient credits.""" """Test decorator with insufficient credits."""
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0) mock_credit_service.validate_and_reserve_credits.side_effect = (
InsufficientCreditsError(1, 0)
)
@requires_credits( @requires_credits(
CreditActionType.VLC_PLAY_SOUND, CreditActionType.VLC_PLAY_SOUND,
@@ -141,7 +169,9 @@ class TestRequiresCreditsDecorator:
mock_credit_service.deduct_credits.assert_not_called() mock_credit_service.deduct_credits.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decorator_user_id_in_args(self, credit_service_factory, mock_credit_service) -> None: async def test_decorator_user_id_in_args(
self, credit_service_factory, mock_credit_service,
) -> None:
"""Test decorator extracting user_id from positional args.""" """Test decorator extracting user_id from positional args."""
@requires_credits( @requires_credits(
@@ -156,7 +186,8 @@ class TestRequiresCreditsDecorator:
assert result == "test" assert result == "test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, 123,
CreditActionType.VLC_PLAY_SOUND,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -186,12 +217,16 @@ class TestValidateCreditsOnlyDecorator:
return service return service
@pytest.fixture @pytest.fixture
def credit_service_factory(self, mock_credit_service: AsyncMock) -> Callable[[], AsyncMock]: def credit_service_factory(
self, mock_credit_service: AsyncMock,
) -> Callable[[], AsyncMock]:
"""Create a credit service factory.""" """Create a credit service factory."""
return lambda: mock_credit_service return lambda: mock_credit_service
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_only_decorator(self, credit_service_factory, mock_credit_service) -> None: async def test_validate_only_decorator(
self, credit_service_factory, mock_credit_service,
) -> None:
"""Test validate_credits_only decorator.""" """Test validate_credits_only decorator."""
@validate_credits_only( @validate_credits_only(
@@ -206,7 +241,8 @@ class TestValidateCreditsOnlyDecorator:
assert result == "Validated: test" assert result == "Validated: test"
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, 123,
CreditActionType.VLC_PLAY_SOUND,
) )
# Should not deduct credits, only validate # Should not deduct credits, only validate
mock_credit_service.deduct_credits.assert_not_called() mock_credit_service.deduct_credits.assert_not_called()
@@ -235,10 +271,14 @@ class TestCreditManager:
manager.mark_success() manager.mark_success()
mock_credit_service.validate_and_reserve_credits.assert_called_once_with( mock_credit_service.validate_and_reserve_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, 123,
CreditActionType.VLC_PLAY_SOUND,
) )
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, success=True, metadata={"test": "data"}, 123,
CreditActionType.VLC_PLAY_SOUND,
success=True,
metadata={"test": "data"},
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -253,7 +293,10 @@ class TestCreditManager:
pass pass
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, success=False, metadata=None, 123,
CreditActionType.VLC_PLAY_SOUND,
success=False,
metadata=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -269,13 +312,18 @@ class TestCreditManager:
raise ValueError(msg) raise ValueError(msg)
mock_credit_service.deduct_credits.assert_called_once_with( mock_credit_service.deduct_credits.assert_called_once_with(
123, CreditActionType.VLC_PLAY_SOUND, success=False, metadata=None, 123,
CreditActionType.VLC_PLAY_SOUND,
success=False,
metadata=None,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_credit_manager_validation_failure(self, mock_credit_service) -> None: async def test_credit_manager_validation_failure(self, mock_credit_service) -> None:
"""Test CreditManager when validation fails.""" """Test CreditManager when validation fails."""
mock_credit_service.validate_and_reserve_credits.side_effect = InsufficientCreditsError(1, 0) mock_credit_service.validate_and_reserve_credits.side_effect = (
InsufficientCreditsError(1, 0)
)
with pytest.raises(InsufficientCreditsError): with pytest.raises(InsufficientCreditsError):
async with CreditManager( async with CreditManager(