Add tests for extraction API endpoints and enhance existing tests
Some checks failed
Backend CI / lint (push) Successful in 9m25s
Backend CI / test (push) Failing after 4m48s

- Implement tests for admin extraction API endpoints including status retrieval, deletion of extractions, and permission checks.
- Add tests for user extraction deletion, ensuring proper handling of permissions and non-existent extractions.
- Enhance sound endpoint tests to include duplicate handling in responses.
- Refactor favorite service tests to utilize mock dependencies for better maintainability and clarity.
- Update sound scanner tests to improve file handling and ensure proper deletion of associated files.
This commit is contained in:
JSC
2025-08-25 21:40:31 +02:00
parent d3ce17f10d
commit 7dee6e320e
15 changed files with 1560 additions and 721 deletions

View File

@@ -2,18 +2,58 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_admin_user from app.core.dependencies import get_admin_user
from app.models.user import User from app.models.user import User
from app.services.extraction import ExtractionService
from app.services.extraction_processor import extraction_processor from app.services.extraction_processor import extraction_processor
router = APIRouter(prefix="/extractions", tags=["admin-extractions"]) router = APIRouter(prefix="/extractions", tags=["admin-extractions"])
async def get_extraction_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
"""Get the extraction service."""
return ExtractionService(session)
@router.get("/status") @router.get("/status")
async def get_extraction_processor_status( async def get_extraction_processor_status(
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001 current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
) -> dict: ) -> dict:
"""Get the status of the extraction processor. Admin only.""" """Get the status of the extraction processor. Admin only."""
return extraction_processor.get_status() return extraction_processor.get_status()
@router.delete("/{extraction_id}")
async def delete_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, str]:
"""Delete any extraction and its associated sound and files. Admin only."""
try:
deleted = await extraction_service.delete_extraction(extraction_id, None)
if not deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete extraction: {e!s}",
) from e
else:
return {
"message": f"Extraction {extraction_id} deleted successfully",
}

View File

@@ -170,7 +170,7 @@ async def get_processing_extractions(
try: try:
# Get all extractions with processing status # Get all extractions with processing status
processing_extractions = await extraction_service.extraction_repo.get_by_status( processing_extractions = await extraction_service.extraction_repo.get_by_status(
"processing" "processing",
) )
# Convert to ExtractionInfo format # Convert to ExtractionInfo format
@@ -196,10 +196,53 @@ async def get_processing_extractions(
} }
result.append(extraction_info) result.append(extraction_info)
return result
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get processing extractions: {e!s}", detail=f"Failed to get processing extractions: {e!s}",
) from e ) from e
else:
return result
@router.delete("/{extraction_id}")
async def delete_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, str]:
"""Delete extraction and associated sound/files. Users can only delete their own."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
deleted = await extraction_service.delete_extraction(
extraction_id, current_user.id,
)
if not deleted:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
) from e
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete extraction: {e!s}",
) from e
else:
return {
"message": f"Extraction {extraction_id} deleted successfully",
}

View File

@@ -35,10 +35,10 @@ class Sound(BaseModel, table=True):
# relationships # relationships
playlist_sounds: list["PlaylistSound"] = Relationship( playlist_sounds: list["PlaylistSound"] = Relationship(
back_populates="sound", cascade_delete=True back_populates="sound", cascade_delete=True,
) )
extractions: list["Extraction"] = Relationship(back_populates="sound") extractions: list["Extraction"] = Relationship(back_populates="sound")
play_history: list["SoundPlayed"] = Relationship( play_history: list["SoundPlayed"] = Relationship(
back_populates="sound", cascade_delete=True back_populates="sound", cascade_delete=True,
) )
favorites: list["Favorite"] = Relationship(back_populates="sound") favorites: list["Favorite"] = Relationship(back_populates="sound")

View File

@@ -44,7 +44,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
result = await self.session.exec( result = await self.session.exec(
select(Extraction) select(Extraction)
.where(Extraction.status == status) .where(Extraction.status == status)
.order_by(Extraction.created_at) .order_by(Extraction.created_at),
) )
return list(result.all()) return list(result.all())

View File

@@ -2,14 +2,16 @@
import asyncio import asyncio
import shutil import shutil
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TypedDict from typing import Any, TypedDict
import yt_dlp import yt_dlp
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings from app.core.config import settings
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.extraction import Extraction
from app.models.sound import Sound from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository from app.repositories.sound import SoundRepository
@@ -21,6 +23,18 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
logger = get_logger(__name__) logger = get_logger(__name__)
@dataclass
class ExtractionContext:
"""Context data for extraction processing."""
extraction_id: int
extraction_url: str
extraction_service: str | None
extraction_service_id: str | None
extraction_title: str | None
user_id: int
class ExtractionInfo(TypedDict): class ExtractionInfo(TypedDict):
"""Type definition for extraction information.""" """Type definition for extraction information."""
@@ -150,8 +164,8 @@ class ExtractionService:
logger.exception("Failed to detect service info for URL: %s", url) logger.exception("Failed to detect service info for URL: %s", url)
return None return None
async def process_extraction(self, extraction_id: int) -> ExtractionInfo: async def _validate_extraction(self, extraction_id: int) -> tuple:
"""Process an extraction job.""" """Validate extraction and return extraction data."""
extraction = await self.extraction_repo.get_by_id(extraction_id) extraction = await self.extraction_repo.get_by_id(extraction_id)
if not extraction: if not extraction:
msg = f"Extraction {extraction_id} not found" msg = f"Extraction {extraction_id} not found"
@@ -173,30 +187,34 @@ class ExtractionService:
user = await self.user_repo.get_by_id(user_id) user = await self.user_repo.get_by_id(user_id)
user_name = user.name if user else None user_name = user.name if user else None
except Exception: except Exception:
logger.warning("Failed to get user %d for extraction", user_id) logger.exception("Failed to get user %d for extraction", user_id)
user_name = None user_name = None
logger.info("Processing extraction %d: %s", extraction_id, extraction_url) return (
extraction,
try:
# Update status to processing
await self.extraction_repo.update(extraction, {"status": "processing"})
# Emit WebSocket event for processing start
await self._emit_extraction_event(
user_id, user_id,
{ extraction_url,
"extraction_id": extraction_id, extraction_service,
"status": "processing", extraction_service_id,
"title": extraction_title or "Processing extraction...", extraction_title,
"url": extraction_url, user_name,
},
) )
# Detect service info if not already available async def _handle_service_detection(
if not extraction_service or not extraction_service_id: self,
logger.info("Detecting service info for extraction %d", extraction_id) extraction: Extraction,
service_info = await self._detect_service_info(extraction_url) context: ExtractionContext,
) -> tuple:
"""Handle service detection and duplicate checking."""
if context.extraction_service and context.extraction_service_id:
return (
context.extraction_service,
context.extraction_service_id,
context.extraction_title,
)
logger.info("Detecting service info for extraction %d", context.extraction_id)
service_info = await self._detect_service_info(context.extraction_url)
if not service_info: if not service_info:
msg = "Unable to detect service information from URL" msg = "Unable to detect service information from URL"
@@ -214,7 +232,7 @@ class ExtractionService:
service_name, service_name,
service_id_val, service_id_val,
) )
if existing and existing.id != extraction_id: if existing and existing.id != context.extraction_id:
error_msg = ( error_msg = (
f"Extraction already exists for " f"Extraction already exists for "
f"{service_info['service']}:{service_info['service_id']}" f"{service_info['service']}:{service_info['service_id']}"
@@ -226,25 +244,36 @@ class ExtractionService:
update_data = { update_data = {
"service": service_info["service"], "service": service_info["service"],
"service_id": service_info["service_id"], "service_id": service_info["service_id"],
"title": service_info.get("title") or extraction_title, "title": service_info.get("title") or context.extraction_title,
} }
await self.extraction_repo.update(extraction, update_data) await self.extraction_repo.update(extraction, update_data)
# Update values for processing # Update values for processing
extraction_service = service_info["service"] new_service = service_info["service"]
extraction_service_id = service_info["service_id"] new_service_id = service_info["service_id"]
extraction_title = service_info.get("title") or extraction_title new_title = service_info.get("title") or context.extraction_title
await self._emit_extraction_event( await self._emit_extraction_event(
user_id, context.user_id,
{ {
"extraction_id": extraction_id, "extraction_id": context.extraction_id,
"status": "processing", "status": "processing",
"title": extraction_title, "title": new_title,
"url": extraction_url, "url": context.extraction_url,
}, },
) )
return new_service, new_service_id, new_title
async def _process_media_files(
self,
extraction_id: int,
extraction_url: str,
extraction_title: str | None,
extraction_service: str,
extraction_service_id: str,
) -> int:
"""Process media files and create sound record."""
# Extract audio and thumbnail # Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media( audio_file, thumbnail_file = await self._extract_media(
extraction_id, extraction_id,
@@ -252,16 +281,15 @@ class ExtractionService:
) )
# Move files to final locations # Move files to final locations
( final_audio_path, final_thumbnail_path = (
final_audio_path, await self._move_files_to_final_location(
final_thumbnail_path,
) = await self._move_files_to_final_location(
audio_file, audio_file,
thumbnail_file, thumbnail_file,
extraction_title, extraction_title,
extraction_service, extraction_service,
extraction_service_id, extraction_service_id,
) )
)
# Create Sound record # Create Sound record
sound = await self._create_sound_record( sound = await self._create_sound_record(
@@ -272,17 +300,24 @@ class ExtractionService:
extraction_service_id, extraction_service_id,
) )
# Store sound_id early to avoid session detachment issues if not sound.id:
sound_id = sound.id
if not sound_id:
msg = "Sound creation failed - no ID returned" msg = "Sound creation failed - no ID returned"
raise RuntimeError(msg) raise RuntimeError(msg)
return sound.id
async def _complete_extraction(
self,
extraction: Extraction,
context: ExtractionContext,
sound_id: int,
) -> None:
"""Complete extraction processing."""
# Normalize the sound # Normalize the sound
await self._normalize_sound(sound_id) await self._normalize_sound(sound_id)
# Add to main playlist # Add to main playlist
await self._add_to_main_playlist(sound_id, user_id) await self._add_to_main_playlist(sound_id, context.user_id)
# Update extraction with success # Update extraction with success
await self.extraction_repo.update( await self.extraction_repo.update(
@@ -296,30 +331,94 @@ class ExtractionService:
# Emit WebSocket event for completion # Emit WebSocket event for completion
await self._emit_extraction_event( await self._emit_extraction_event(
user_id, context.user_id,
{ {
"extraction_id": extraction_id, "extraction_id": context.extraction_id,
"status": "completed", "status": "completed",
"title": extraction_title, "title": context.extraction_title,
"url": extraction_url, "url": context.extraction_url,
"sound_id": sound_id, "sound_id": sound_id,
}, },
) )
logger.info("Successfully processed extraction %d", extraction_id) async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
"""Process an extraction job."""
# Validate extraction and get context data
(
extraction,
user_id,
extraction_url,
extraction_service,
extraction_service_id,
extraction_title,
user_name,
) = await self._validate_extraction(extraction_id)
# Create context object for helper methods
context = ExtractionContext(
extraction_id=extraction_id,
extraction_url=extraction_url,
extraction_service=extraction_service,
extraction_service_id=extraction_service_id,
extraction_title=extraction_title,
user_id=user_id,
)
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
try:
# Update status to processing
await self.extraction_repo.update(extraction, {"status": "processing"})
# Emit WebSocket event for processing start
await self._emit_extraction_event(
context.user_id,
{
"extraction_id": context.extraction_id,
"status": "processing",
"title": context.extraction_title or "Processing extraction...",
"url": context.extraction_url,
},
)
# Handle service detection and duplicate checking
extraction_service, extraction_service_id, extraction_title = (
await self._handle_service_detection(extraction, context)
)
# Update context with potentially new values
context.extraction_service = extraction_service
context.extraction_service_id = extraction_service_id
context.extraction_title = extraction_title
# Process media files and create sound record
sound_id = await self._process_media_files(
context.extraction_id,
context.extraction_url,
context.extraction_title,
extraction_service,
extraction_service_id,
)
# Complete extraction processing
await self._complete_extraction(extraction, context, sound_id)
logger.info("Successfully processed extraction %d", context.extraction_id)
# Get updated extraction to get latest timestamps # Get updated extraction to get latest timestamps
updated_extraction = await self.extraction_repo.get_by_id(extraction_id) updated_extraction = await self.extraction_repo.get_by_id(
context.extraction_id,
)
return { return {
"id": extraction_id, "id": context.extraction_id,
"url": extraction_url, "url": context.extraction_url,
"service": extraction_service, "service": extraction_service,
"service_id": extraction_service_id, "service_id": extraction_service_id,
"title": extraction_title, "title": extraction_title,
"status": "completed", "status": "completed",
"error": None, "error": None,
"sound_id": sound_id, "sound_id": sound_id,
"user_id": user_id, "user_id": context.user_id,
"user_name": user_name, "user_name": user_name,
"created_at": ( "created_at": (
updated_extraction.created_at.isoformat() updated_extraction.created_at.isoformat()
@@ -337,18 +436,18 @@ class ExtractionService:
error_msg = str(e) error_msg = str(e)
logger.exception( logger.exception(
"Failed to process extraction %d: %s", "Failed to process extraction %d: %s",
extraction_id, context.extraction_id,
error_msg, error_msg,
) )
# Emit WebSocket event for failure # Emit WebSocket event for failure
await self._emit_extraction_event( await self._emit_extraction_event(
user_id, context.user_id,
{ {
"extraction_id": extraction_id, "extraction_id": context.extraction_id,
"status": "failed", "status": "failed",
"title": extraction_title or "Extraction failed", "title": context.extraction_title or "Extraction failed",
"url": extraction_url, "url": context.extraction_url,
"error": error_msg, "error": error_msg,
}, },
) )
@@ -363,17 +462,19 @@ class ExtractionService:
) )
# Get updated extraction to get latest timestamps # Get updated extraction to get latest timestamps
updated_extraction = await self.extraction_repo.get_by_id(extraction_id) updated_extraction = await self.extraction_repo.get_by_id(
context.extraction_id,
)
return { return {
"id": extraction_id, "id": context.extraction_id,
"url": extraction_url, "url": context.extraction_url,
"service": extraction_service, "service": context.extraction_service,
"service_id": extraction_service_id, "service_id": context.extraction_service_id,
"title": extraction_title, "title": context.extraction_title,
"status": "failed", "status": "failed",
"error": error_msg, "error": error_msg,
"sound_id": None, "sound_id": None,
"user_id": user_id, "user_id": context.user_id,
"user_name": user_name, "user_name": user_name,
"created_at": ( "created_at": (
updated_extraction.created_at.isoformat() updated_extraction.created_at.isoformat()
@@ -780,3 +881,174 @@ class ExtractionService:
} }
for extraction, user in extraction_user_tuples for extraction, user in extraction_user_tuples
] ]
async def delete_extraction(
self,
extraction_id: int,
user_id: int | None = None,
) -> bool:
"""Delete an extraction and its associated sound and files.
Args:
extraction_id: The ID of the extraction to delete
user_id: Optional user ID for ownership verification (None for admin)
Returns:
True if deletion was successful, False if extraction not found
Raises:
ValueError: If user doesn't own the extraction (when user_id is provided)
"""
logger.info(
"Deleting extraction: %d (user: %s)",
extraction_id,
user_id or "admin",
)
# Get the extraction record
extraction = await self.extraction_repo.get_by_id(extraction_id)
if not extraction:
logger.warning("Extraction %d not found", extraction_id)
return False
# Check ownership if user_id is provided (non-admin request)
if user_id is not None and extraction.user_id != user_id:
msg = "You don't have permission to delete this extraction"
raise ValueError(msg)
# Get associated sound if it exists and capture its attributes immediately
sound_data = None
sound_object = None
if extraction.sound_id:
sound_object = await self.sound_repo.get_by_id(extraction.sound_id)
if sound_object:
# Capture attributes immediately while session is valid
sound_data = {
"id": sound_object.id,
"type": sound_object.type,
"filename": sound_object.filename,
"is_normalized": sound_object.is_normalized,
"normalized_filename": sound_object.normalized_filename,
"thumbnail": sound_object.thumbnail,
}
try:
# Delete the extraction record first
await self.extraction_repo.delete(extraction)
logger.info("Deleted extraction record: %d", extraction_id)
# Check if sound was in current playlist before deletion
sound_was_in_current_playlist = False
if sound_object and sound_data:
sound_was_in_current_playlist = (
await self._check_sound_in_current_playlist(sound_data["id"])
)
# If there's an associated sound, delete it and its files
if sound_object and sound_data:
await self._delete_sound_and_files(sound_object, sound_data)
logger.info(
"Deleted associated sound: %d (%s)",
sound_data["id"],
sound_data["filename"],
)
# Commit the transaction
await self.session.commit()
# Reload player playlist if deleted sound was in current playlist
if sound_was_in_current_playlist and sound_data:
await self._reload_player_playlist()
logger.info(
"Reloaded player playlist after deleting sound %d "
"from current playlist",
sound_data["id"],
)
except Exception:
# Rollback on any error
await self.session.rollback()
logger.exception("Failed to delete extraction %d", extraction_id)
raise
else:
return True
async def _delete_sound_and_files(
self,
sound: Sound,
sound_data: dict[str, Any],
) -> None:
"""Delete a sound record and all its associated files."""
# Collect all file paths to delete using captured attributes
files_to_delete = []
# Original audio file
if sound_data["type"] == "EXT": # Extracted sounds
original_path = Path("sounds/originals/extracted") / sound_data["filename"]
if original_path.exists():
files_to_delete.append(original_path)
# Normalized file
if sound_data["is_normalized"] and sound_data["normalized_filename"]:
normalized_path = (
Path("sounds/normalized/extracted") / sound_data["normalized_filename"]
)
if normalized_path.exists():
files_to_delete.append(normalized_path)
# Thumbnail file
if sound_data["thumbnail"]:
thumbnail_path = (
Path(settings.EXTRACTION_THUMBNAILS_DIR) / sound_data["thumbnail"]
)
if thumbnail_path.exists():
files_to_delete.append(thumbnail_path)
# Delete the sound from database first
await self.sound_repo.delete(sound)
# Delete all associated files
for file_path in files_to_delete:
try:
file_path.unlink()
logger.info("Deleted file: %s", file_path)
except OSError:
logger.exception("Failed to delete file %s", file_path)
# Continue with other files even if one fails
async def _check_sound_in_current_playlist(self, sound_id: int) -> bool:
"""Check if a sound is in the current playlist."""
try:
from app.repositories.playlist import PlaylistRepository # noqa: PLC0415
playlist_repo = PlaylistRepository(self.session)
current_playlist = await playlist_repo.get_current_playlist()
if not current_playlist or not current_playlist.id:
return False
return await playlist_repo.is_sound_in_playlist(
current_playlist.id, sound_id,
)
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
logger.warning(
"Failed to check if sound %s is in current playlist: %s",
sound_id,
e,
exc_info=True,
)
return False
async def _reload_player_playlist(self) -> None:
"""Reload the player playlist after a sound is deleted."""
try:
# Import here to avoid circular import issues
from app.services.player import get_player_service # noqa: PLC0415
player = get_player_service()
await player.reload_playlist()
logger.debug("Player playlist reloaded after sound deletion")
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
# Don't fail the deletion operation if player reload fails
logger.warning("Failed to reload player playlist: %s", e, exc_info=True)

View File

@@ -201,7 +201,7 @@ class ExtractionProcessor:
for extraction in stuck_extractions: for extraction in stuck_extractions:
try: try:
await extraction_service.extraction_repo.update( await extraction_service.extraction_repo.update(
extraction, {"status": "pending", "error": None} extraction, {"status": "pending", "error": None},
) )
reset_count += 1 reset_count += 1
logger.info( logger.info(
@@ -210,12 +210,13 @@ class ExtractionProcessor:
) )
except Exception: except Exception:
logger.exception( logger.exception(
"Failed to reset extraction %d", extraction.id "Failed to reset extraction %d", extraction.id,
) )
await session.commit() await session.commit()
logger.info( logger.info(
"Successfully reset %d stuck extractions from processing to pending", "Successfully reset %d stuck extractions from processing to "
"pending",
reset_count, reset_count,
) )

View File

@@ -1,5 +1,6 @@
"""Sound scanner service for scanning and importing audio files.""" """Sound scanner service for scanning and importing audio files."""
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TypedDict from typing import TypedDict
@@ -13,6 +14,28 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
logger = get_logger(__name__) logger = get_logger(__name__)
@dataclass
class AudioFileInfo:
"""Data class for audio file metadata."""
filename: str
name: str
duration: int
size: int
file_hash: str
@dataclass
class SyncContext:
"""Context data for audio file synchronization."""
file_path: Path
sound_type: str
existing_sound_by_hash: dict | Sound | None
existing_sound_by_filename: dict | Sound | None
file_hash: str
class FileInfo(TypedDict): class FileInfo(TypedDict):
"""Type definition for file information in scan results.""" """Type definition for file information in scan results."""
@@ -75,11 +98,15 @@ class SoundScannerService:
def _get_normalized_path(self, sound_type: str, filename: str) -> Path: def _get_normalized_path(self, sound_type: str, filename: str) -> Path:
"""Get the normalized file path for a sound.""" """Get the normalized file path for a sound."""
directory = self.normalized_directories.get(sound_type, "sounds/normalized/other") directory = self.normalized_directories.get(
sound_type, "sounds/normalized/other",
)
return Path(directory) / filename return Path(directory) / filename
def _rename_normalized_file(self, sound_type: str, old_filename: str, new_filename: str) -> bool: def _rename_normalized_file(
"""Rename a normalized file if it exists. Returns True if renamed, False if not found.""" self, sound_type: str, old_filename: str, new_filename: str,
) -> bool:
"""Rename normalized file if exists. Returns True if renamed, else False."""
old_path = self._get_normalized_path(sound_type, old_filename) old_path = self._get_normalized_path(sound_type, old_filename)
new_path = self._get_normalized_path(sound_type, new_filename) new_path = self._get_normalized_path(sound_type, new_filename)
@@ -89,25 +116,394 @@ class SoundScannerService:
new_path.parent.mkdir(parents=True, exist_ok=True) new_path.parent.mkdir(parents=True, exist_ok=True)
old_path.rename(new_path) old_path.rename(new_path)
logger.info("Renamed normalized file: %s -> %s", old_path, new_path) logger.info("Renamed normalized file: %s -> %s", old_path, new_path)
return True except OSError:
except Exception as e: logger.exception(
logger.error("Failed to rename normalized file %s -> %s: %s", old_path, new_path, e) "Failed to rename normalized file %s -> %s",
old_path,
new_path,
)
return False return False
else:
return True
return False return False
def _delete_normalized_file(self, sound_type: str, filename: str) -> bool: def _delete_normalized_file(self, sound_type: str, filename: str) -> bool:
"""Delete a normalized file if it exists. Returns True if deleted, False if not found.""" """Delete normalized file if exists. Returns True if deleted, else False."""
normalized_path = self._get_normalized_path(sound_type, filename) normalized_path = self._get_normalized_path(sound_type, filename)
if normalized_path.exists(): if normalized_path.exists():
try: try:
normalized_path.unlink() normalized_path.unlink()
logger.info("Deleted normalized file: %s", normalized_path) logger.info("Deleted normalized file: %s", normalized_path)
except OSError:
logger.exception(
"Failed to delete normalized file %s", normalized_path,
)
return False
else:
return True return True
return False
def _extract_sound_attributes(self, sound_data: dict | Sound | None) -> dict:
"""Extract attributes from sound data (dict or Sound object)."""
if sound_data is None:
return {}
if isinstance(sound_data, dict):
return {
"filename": sound_data.get("filename"),
"name": sound_data.get("name"),
"duration": sound_data.get("duration"),
"size": sound_data.get("size"),
"id": sound_data.get("id"),
"object": sound_data.get("sound_object"),
"type": sound_data.get("type"),
"is_normalized": sound_data.get("is_normalized"),
"normalized_filename": sound_data.get("normalized_filename"),
}
# Sound object (for tests)
return {
"filename": sound_data.filename,
"name": sound_data.name,
"duration": sound_data.duration,
"size": sound_data.size,
"id": sound_data.id,
"object": sound_data,
"type": sound_data.type,
"is_normalized": sound_data.is_normalized,
"normalized_filename": sound_data.normalized_filename,
}
def _handle_unchanged_file(
self,
filename: str,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle unchanged file (same hash, same filename)."""
logger.debug("Sound unchanged: %s", filename)
results["skipped"] += 1
results["files"].append({
"filename": filename,
"status": "skipped",
"reason": "file unchanged",
"name": existing_attrs["name"],
"duration": existing_attrs["duration"],
"size": existing_attrs["size"],
"id": existing_attrs["id"],
"error": None,
"changes": None,
})
def _handle_duplicate_file(
self,
filename: str,
existing_filename: str,
file_hash: str,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle duplicate file (same hash, different filename)."""
logger.warning(
"Duplicate file detected: '%s' has same content as existing "
"'%s' (hash: %s). Skipping duplicate file.",
filename,
existing_filename,
file_hash[:8] + "...",
)
results["skipped"] += 1
results["duplicates"] += 1
results["files"].append({
"filename": filename,
"status": "skipped",
"reason": "duplicate content",
"name": existing_attrs["name"],
"duration": existing_attrs["duration"],
"size": existing_attrs["size"],
"id": existing_attrs["id"],
"error": None,
"changes": None,
})
async def _handle_file_rename(
self,
file_info: AudioFileInfo,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle file rename (same hash, different filename)."""
update_data = {
"filename": file_info.filename,
"name": file_info.name,
}
# If the sound has a normalized file, rename it too
if existing_attrs["is_normalized"] and existing_attrs["normalized_filename"]:
old_normalized_base = Path(existing_attrs["normalized_filename"]).name
new_normalized_base = (
Path(file_info.filename).stem
+ Path(existing_attrs["normalized_filename"]).suffix
)
renamed = self._rename_normalized_file(
existing_attrs["type"],
old_normalized_base,
new_normalized_base,
)
if renamed:
update_data["normalized_filename"] = new_normalized_base
logger.info(
"Renamed normalized file: %s -> %s",
old_normalized_base,
new_normalized_base,
)
await self.sound_repo.update(existing_attrs["object"], update_data)
logger.info(
"Detected rename: %s -> %s (ID: %s)",
existing_attrs["filename"],
file_info.filename,
existing_attrs["id"],
)
# Build changes list
changes = ["filename", "name"]
if "normalized_filename" in update_data:
changes.append("normalized_filename")
results["updated"] += 1
results["files"].append({
"filename": file_info.filename,
"status": "updated",
"reason": "file was renamed",
"name": file_info.name,
"duration": existing_attrs["duration"],
"size": existing_attrs["size"],
"id": existing_attrs["id"],
"error": None,
"changes": changes,
# Store old filename to prevent deletion
"old_filename": existing_attrs["filename"],
})
async def _handle_file_modification(
self,
file_info: AudioFileInfo,
existing_attrs: dict,
results: ScanResults,
) -> None:
"""Handle file modification (same filename, different hash)."""
update_data = {
"name": file_info.name,
"duration": file_info.duration,
"size": file_info.size,
"hash": file_info.file_hash,
}
await self.sound_repo.update(existing_attrs["object"], update_data)
logger.info(
"Updated modified sound: %s (ID: %s)",
file_info.name,
existing_attrs["id"],
)
results["updated"] += 1
results["files"].append({
"filename": file_info.filename,
"status": "updated",
"reason": "file was modified",
"name": file_info.name,
"duration": file_info.duration,
"size": file_info.size,
"id": existing_attrs["id"],
"error": None,
"changes": ["hash", "duration", "size", "name"],
})
async def _handle_new_file(
self,
file_info: AudioFileInfo,
sound_type: str,
results: ScanResults,
) -> None:
"""Handle new file (neither hash nor filename exists)."""
sound_data = {
"type": sound_type,
"name": file_info.name,
"filename": file_info.filename,
"duration": file_info.duration,
"size": file_info.size,
"hash": file_info.file_hash,
"is_deletable": False,
"is_music": False,
"is_normalized": False,
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id)
results["added"] += 1
results["files"].append({
"filename": file_info.filename,
"status": "added",
"reason": None,
"name": file_info.name,
"duration": file_info.duration,
"size": file_info.size,
"id": sound.id,
"error": None,
"changes": None,
})
async def _load_existing_sounds(self, sound_type: str) -> tuple[dict, dict]:
"""Load existing sounds and create lookup dictionaries."""
existing_sounds = await self.sound_repo.get_by_type(sound_type)
# Create lookup dictionaries with immediate attribute access
# to avoid session detachment
sounds_by_hash = {}
sounds_by_filename = {}
for sound in existing_sounds:
# Capture all attributes immediately while session is valid
sound_data = {
"id": sound.id,
"hash": sound.hash,
"filename": sound.filename,
"name": sound.name,
"duration": sound.duration,
"size": sound.size,
"type": sound.type,
"is_normalized": sound.is_normalized,
"normalized_filename": sound.normalized_filename,
"sound_object": sound, # Keep reference for database operations
}
sounds_by_hash[sound.hash] = sound_data
sounds_by_filename[sound.filename] = sound_data
return sounds_by_hash, sounds_by_filename
async def _process_audio_files(
self,
scan_path: Path,
sound_type: str,
sounds_by_hash: dict,
sounds_by_filename: dict,
results: ScanResults,
) -> set[str]:
"""Process all audio files in directory and return processed filenames."""
# Get all audio files from directory
audio_files = [
f
for f in scan_path.iterdir()
if f.is_file() and f.suffix.lower() in self.supported_extensions
]
# Process each file in directory
processed_filenames = set()
for file_path in audio_files:
results["scanned"] += 1
filename = file_path.name
processed_filenames.add(filename)
try:
# Calculate hash first to enable hash-based lookup
file_hash = get_file_hash(file_path)
existing_sound_by_hash = sounds_by_hash.get(file_hash)
existing_sound_by_filename = sounds_by_filename.get(filename)
# Create sync context
sync_context = SyncContext(
file_path=file_path,
sound_type=sound_type,
existing_sound_by_hash=existing_sound_by_hash,
existing_sound_by_filename=existing_sound_by_filename,
file_hash=file_hash,
)
await self._sync_audio_file(sync_context, results)
# Check if this was a rename and mark old filename as processed
if results["files"] and results["files"][-1].get("old_filename"):
old_filename = results["files"][-1]["old_filename"]
processed_filenames.add(old_filename)
logger.debug("Marked old filename as processed: %s", old_filename)
# Remove temporary tracking field from results
del results["files"][-1]["old_filename"]
except Exception as e: except Exception as e:
logger.error("Failed to delete normalized file %s: %s", normalized_path, e) logger.exception("Error processing file %s", file_path)
return False results["errors"] += 1
return False results["files"].append({
"filename": filename,
"status": "error",
"reason": None,
"name": None,
"duration": None,
"size": None,
"id": None,
"error": str(e),
"changes": None,
})
return processed_filenames
async def _delete_missing_sounds(
self,
sounds_by_filename: dict,
processed_filenames: set[str],
results: ScanResults,
) -> None:
"""Delete sounds that no longer exist in directory."""
for filename, sound_data in sounds_by_filename.items():
if filename not in processed_filenames:
# Attributes already captured in sound_data dictionary
sound_name = sound_data["name"]
sound_duration = sound_data["duration"]
sound_size = sound_data["size"]
sound_id = sound_data["id"]
sound_object = sound_data["sound_object"]
sound_type = sound_data["type"]
sound_is_normalized = sound_data["is_normalized"]
sound_normalized_filename = sound_data["normalized_filename"]
try:
# Delete the sound from database first
await self.sound_repo.delete(sound_object)
logger.info("Deleted sound no longer in directory: %s", filename)
# If the sound had a normalized file, delete it too
if sound_is_normalized and sound_normalized_filename:
normalized_base = Path(sound_normalized_filename).name
self._delete_normalized_file(sound_type, normalized_base)
results["deleted"] += 1
results["files"].append({
"filename": filename,
"status": "deleted",
"reason": "file no longer exists",
"name": sound_name,
"duration": sound_duration,
"size": sound_size,
"id": sound_id,
"error": None,
"changes": None,
})
except Exception as e:
logger.exception("Error deleting sound %s", filename)
results["errors"] += 1
results["files"].append({
"filename": filename,
"status": "error",
"reason": "failed to delete",
"name": sound_name,
"duration": sound_duration,
"size": sound_size,
"id": sound_id,
"error": str(e),
"changes": None,
})
async def scan_directory( async def scan_directory(
self, self,
@@ -138,136 +534,25 @@ class SoundScannerService:
logger.info("Starting sync of directory: %s", directory_path) logger.info("Starting sync of directory: %s", directory_path)
# Get all existing sounds of this type from database # Load existing sounds from database
existing_sounds = await self.sound_repo.get_by_type(sound_type) sounds_by_hash, sounds_by_filename = await self._load_existing_sounds(
# Create lookup dictionaries with immediate attribute access
# to avoid session detachment
sounds_by_hash = {}
sounds_by_filename = {}
for sound in existing_sounds:
# Capture all attributes immediately while session is valid
sound_data = {
"id": sound.id,
"hash": sound.hash,
"filename": sound.filename,
"name": sound.name,
"duration": sound.duration,
"size": sound.size,
"type": sound.type,
"is_normalized": sound.is_normalized,
"normalized_filename": sound.normalized_filename,
"sound_object": sound, # Keep reference for database operations
}
sounds_by_hash[sound.hash] = sound_data
sounds_by_filename[sound.filename] = sound_data
# Get all audio files from directory
audio_files = [
f
for f in scan_path.iterdir()
if f.is_file() and f.suffix.lower() in self.supported_extensions
]
# Process each file in directory
processed_filenames = set()
for file_path in audio_files:
results["scanned"] += 1
filename = file_path.name
processed_filenames.add(filename)
try:
# Calculate hash first to enable hash-based lookup
file_hash = get_file_hash(file_path)
existing_sound_by_hash = sounds_by_hash.get(file_hash)
existing_sound_by_filename = sounds_by_filename.get(filename)
await self._sync_audio_file(
file_path,
sound_type, sound_type,
existing_sound_by_hash, )
existing_sound_by_filename,
file_hash, # Process audio files in directory
processed_filenames = await self._process_audio_files(
scan_path,
sound_type,
sounds_by_hash,
sounds_by_filename,
results, results,
) )
# Check if this was a rename operation and mark old filename as processed
if results["files"] and results["files"][-1].get("old_filename"):
old_filename = results["files"][-1]["old_filename"]
processed_filenames.add(old_filename)
logger.debug("Marked old filename as processed: %s", old_filename)
# Remove temporary tracking field from results
del results["files"][-1]["old_filename"]
except Exception as e:
logger.exception("Error processing file %s", file_path)
results["errors"] += 1
results["files"].append(
{
"filename": filename,
"status": "error",
"reason": None,
"name": None,
"duration": None,
"size": None,
"id": None,
"error": str(e),
"changes": None,
},
)
# Delete sounds that no longer exist in directory # Delete sounds that no longer exist in directory
for filename, sound_data in sounds_by_filename.items(): await self._delete_missing_sounds(
if filename not in processed_filenames: sounds_by_filename,
# Attributes already captured in sound_data dictionary processed_filenames,
sound_name = sound_data["name"] results,
sound_duration = sound_data["duration"]
sound_size = sound_data["size"]
sound_id = sound_data["id"]
sound_object = sound_data["sound_object"]
sound_type = sound_data["type"]
sound_is_normalized = sound_data["is_normalized"]
sound_normalized_filename = sound_data["normalized_filename"]
try:
# Delete the sound from database first
await self.sound_repo.delete(sound_object)
logger.info("Deleted sound no longer in directory: %s", filename)
# If the sound had a normalized file, delete it too
if sound_is_normalized and sound_normalized_filename:
normalized_base = Path(sound_normalized_filename).name
self._delete_normalized_file(sound_type, normalized_base)
results["deleted"] += 1
results["files"].append(
{
"filename": filename,
"status": "deleted",
"reason": "file no longer exists",
"name": sound_name,
"duration": sound_duration,
"size": sound_size,
"id": sound_id,
"error": None,
"changes": None,
},
)
except Exception as e:
logger.exception("Error deleting sound %s", filename)
results["errors"] += 1
results["files"].append(
{
"filename": filename,
"status": "error",
"reason": "failed to delete",
"name": sound_name,
"duration": sound_duration,
"size": sound_size,
"id": sound_id,
"error": str(e),
"changes": None,
},
) )
logger.info("Sync completed: %s", results) logger.info("Sync completed: %s", results)
@@ -275,231 +560,58 @@ class SoundScannerService:
async def _sync_audio_file( async def _sync_audio_file(
self, self,
file_path: Path, sync_context: SyncContext,
sound_type: str,
existing_sound_by_hash: dict | Sound | None,
existing_sound_by_filename: dict | Sound | None,
file_hash: str,
results: ScanResults, results: ScanResults,
) -> None: ) -> None:
"""Sync a single audio file using hash-first identification strategy.""" """Sync a single audio file using hash-first identification strategy."""
filename = file_path.name filename = sync_context.file_path.name
duration = get_audio_duration(file_path) duration = get_audio_duration(sync_context.file_path)
size = get_file_size(file_path) size = get_file_size(sync_context.file_path)
name = self.extract_name_from_filename(filename) name = self.extract_name_from_filename(filename)
# Extract attributes - handle both dict (normal) and Sound object (tests) # Create file info object
existing_hash_filename = None file_info = AudioFileInfo(
existing_hash_name = None filename=filename,
existing_hash_duration = None name=name,
existing_hash_size = None duration=duration,
existing_hash_id = None size=size,
existing_hash_object = None file_hash=sync_context.file_hash,
existing_hash_type = None )
existing_hash_is_normalized = None
existing_hash_normalized_filename = None
if existing_sound_by_hash is not None: # Extract attributes from existing sounds
if isinstance(existing_sound_by_hash, dict): hash_attrs = self._extract_sound_attributes(sync_context.existing_sound_by_hash)
existing_hash_filename = existing_sound_by_hash["filename"] filename_attrs = self._extract_sound_attributes(
existing_hash_name = existing_sound_by_hash["name"] sync_context.existing_sound_by_filename,
existing_hash_duration = existing_sound_by_hash["duration"] )
existing_hash_size = existing_sound_by_hash["size"]
existing_hash_id = existing_sound_by_hash["id"]
existing_hash_object = existing_sound_by_hash["sound_object"]
existing_hash_type = existing_sound_by_hash["type"]
existing_hash_is_normalized = existing_sound_by_hash["is_normalized"]
existing_hash_normalized_filename = existing_sound_by_hash["normalized_filename"]
else: # Sound object (for tests)
existing_hash_filename = existing_sound_by_hash.filename
existing_hash_name = existing_sound_by_hash.name
existing_hash_duration = existing_sound_by_hash.duration
existing_hash_size = existing_sound_by_hash.size
existing_hash_id = existing_sound_by_hash.id
existing_hash_object = existing_sound_by_hash
existing_hash_type = existing_sound_by_hash.type
existing_hash_is_normalized = existing_sound_by_hash.is_normalized
existing_hash_normalized_filename = existing_sound_by_hash.normalized_filename
existing_filename_id = None
existing_filename_object = None
if existing_sound_by_filename is not None:
if isinstance(existing_sound_by_filename, dict):
existing_filename_id = existing_sound_by_filename["id"]
existing_filename_object = existing_sound_by_filename["sound_object"]
else: # Sound object (for tests)
existing_filename_id = existing_sound_by_filename.id
existing_filename_object = existing_sound_by_filename
# Hash-first identification strategy # Hash-first identification strategy
if existing_sound_by_hash is not None: if sync_context.existing_sound_by_hash is not None:
# Content exists in database (same hash) # Content exists in database (same hash)
if existing_hash_filename == filename: if hash_attrs["filename"] == filename:
# Same hash, same filename - file unchanged # Same hash, same filename - file unchanged
logger.debug("Sound unchanged: %s", filename) self._handle_unchanged_file(filename, hash_attrs, results)
results["skipped"] += 1
results["files"].append(
{
"filename": filename,
"status": "skipped",
"reason": "file unchanged",
"name": existing_hash_name,
"duration": existing_hash_duration,
"size": existing_hash_size,
"id": existing_hash_id,
"error": None,
"changes": None,
},
)
else: else:
# Same hash, different filename - could be rename or duplicate # Same hash, different filename - could be rename or duplicate
# Check if both files exist to determine if it's a duplicate old_file_path = sync_context.file_path.parent / hash_attrs["filename"]
old_file_path = file_path.parent / existing_hash_filename
if old_file_path.exists(): if old_file_path.exists():
# Both files exist with same hash - this is a duplicate # Both files exist with same hash - this is a duplicate
logger.warning( self._handle_duplicate_file(
"Duplicate file detected: '%s' has same content as existing '%s' (hash: %s). "
"Skipping duplicate file.",
filename, filename,
existing_hash_filename, hash_attrs["filename"],
file_hash[:8] + "...", sync_context.file_hash,
) hash_attrs,
results,
results["skipped"] += 1
results["duplicates"] += 1
results["files"].append(
{
"filename": filename,
"status": "skipped",
"reason": "duplicate content",
"name": existing_hash_name,
"duration": existing_hash_duration,
"size": existing_hash_size,
"id": existing_hash_id,
"error": None,
"changes": None,
},
) )
else: else:
# Old file doesn't exist - this is a genuine rename # Old file doesn't exist - this is a genuine rename
update_data = { await self._handle_file_rename(file_info, hash_attrs, results)
"filename": filename,
"name": name,
}
# If the sound has a normalized file, rename it too elif sync_context.existing_sound_by_filename is not None:
if existing_hash_is_normalized and existing_hash_normalized_filename:
# Extract base filename without path for normalized file
old_normalized_base = Path(existing_hash_normalized_filename).name
new_normalized_base = Path(filename).stem + Path(existing_hash_normalized_filename).suffix
renamed = self._rename_normalized_file(
existing_hash_type,
old_normalized_base,
new_normalized_base
)
if renamed:
update_data["normalized_filename"] = new_normalized_base
logger.info(
"Renamed normalized file: %s -> %s",
old_normalized_base,
new_normalized_base
)
await self.sound_repo.update(existing_hash_object, update_data)
logger.info(
"Detected rename: %s -> %s (ID: %s)",
existing_hash_filename,
filename,
existing_hash_id,
)
# Build changes list
changes = ["filename", "name"]
if "normalized_filename" in update_data:
changes.append("normalized_filename")
results["updated"] += 1
results["files"].append(
{
"filename": filename,
"status": "updated",
"reason": "file was renamed",
"name": name,
"duration": existing_hash_duration,
"size": existing_hash_size,
"id": existing_hash_id,
"error": None,
"changes": changes,
# Store old filename to prevent deletion
"old_filename": existing_hash_filename,
},
)
elif existing_sound_by_filename is not None:
# Same filename but different hash - file was modified # Same filename but different hash - file was modified
update_data = { await self._handle_file_modification(file_info, filename_attrs, results)
"name": name,
"duration": duration,
"size": size,
"hash": file_hash,
}
await self.sound_repo.update(existing_filename_object, update_data)
logger.info(
"Updated modified sound: %s (ID: %s)",
name,
existing_filename_id,
)
results["updated"] += 1
results["files"].append(
{
"filename": filename,
"status": "updated",
"reason": "file was modified",
"name": name,
"duration": duration,
"size": size,
"id": existing_filename_id,
"error": None,
"changes": ["hash", "duration", "size", "name"],
},
)
else: else:
# New file - neither hash nor filename exists # New file - neither hash nor filename exists
sound_data = { await self._handle_new_file(file_info, sync_context.sound_type, results)
"type": sound_type,
"name": name,
"filename": filename,
"duration": duration,
"size": size,
"hash": file_hash,
"is_deletable": False,
"is_music": False,
"is_normalized": False,
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id)
results["added"] += 1
results["files"].append(
{
"filename": filename,
"status": "added",
"reason": None,
"name": name,
"duration": duration,
"size": size,
"id": sound.id,
"error": None,
"changes": None,
},
)
async def scan_soundboard_directory(self) -> ScanResults: async def scan_soundboard_directory(self) -> ScanResults:
"""Sync the default soundboard directory.""" """Sync the default soundboard directory."""

View File

@@ -0,0 +1,154 @@
"""Tests for admin extraction API endpoints."""
import pytest
from httpx import AsyncClient
from app.models.extraction import Extraction
from app.models.user import User
class TestAdminExtractionEndpoints:
"""Test admin extraction endpoints."""
@pytest.mark.asyncio
async def test_get_extraction_processor_status(self, authenticated_admin_client):
"""Test getting extraction processor status."""
response = await authenticated_admin_client.get(
"/api/v1/admin/extractions/status",
)
assert response.status_code == 200
data = response.json()
# Check expected status fields (match actual processor status format)
assert "currently_processing" in data
assert "max_concurrent" in data
assert "available_slots" in data
assert "processing_ids" in data
assert isinstance(data["currently_processing"], int)
assert isinstance(data["max_concurrent"], int)
assert isinstance(data["available_slots"], int)
assert isinstance(data["processing_ids"], list)
@pytest.mark.asyncio
async def test_admin_delete_extraction_success(
self,
authenticated_admin_client,
test_session,
test_plan,
):
"""Test admin successfully deleting any extraction."""
# Create a test user
user = User(
name="Test User",
email="test@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)
# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Admin delete the extraction
response = await authenticated_admin_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)
assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"
# Verify extraction was deleted from database
deleted_extraction = await test_session.get(Extraction, extraction.id)
assert deleted_extraction is None
@pytest.mark.asyncio
async def test_admin_delete_extraction_not_found(self, authenticated_admin_client):
"""Test admin deleting non-existent extraction."""
response = await authenticated_admin_client.delete(
"/api/v1/admin/extractions/999",
)
assert response.status_code == 404
data = response.json()
assert "not found" in data["detail"].lower()
@pytest.mark.asyncio
async def test_admin_delete_extraction_any_user(
self,
authenticated_admin_client,
test_session,
test_plan,
):
"""Test admin deleting extraction owned by any user."""
# Create another user and their extraction
other_user = User(
name="Other User",
email="other@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(other_user)
await test_session.commit()
await test_session.refresh(other_user)
extraction = Extraction(
url="https://example.com/video",
user_id=other_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Admin can delete any user's extraction
response = await authenticated_admin_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)
assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"
@pytest.mark.asyncio
async def test_delete_extraction_non_admin(self, authenticated_client, test_user, test_session):
"""Test non-admin user cannot access admin deletion endpoint."""
# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=test_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Non-admin user cannot access admin endpoint
response = await authenticated_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)
assert response.status_code == 403
data = response.json()
assert "permissions" in data["detail"].lower()
@pytest.mark.asyncio
async def test_admin_endpoints_unauthenticated(self, client: AsyncClient):
"""Test admin endpoints require authentication."""
# Status endpoint
response = await client.get("/api/v1/admin/extractions/status")
assert response.status_code == 401
# Delete endpoint
response = await client.delete("/api/v1/admin/extractions/1")
assert response.status_code == 401

View File

@@ -31,6 +31,7 @@ class TestAdminSoundEndpoints:
"deleted": 1, "deleted": 1,
"skipped": 0, "skipped": 0,
"errors": 0, "errors": 0,
"duplicates": 0,
"files": [ "files": [
{ {
"filename": "test1.mp3", "filename": "test1.mp3",
@@ -176,6 +177,7 @@ class TestAdminSoundEndpoints:
"deleted": 0, "deleted": 0,
"skipped": 0, "skipped": 0,
"errors": 0, "errors": 0,
"duplicates": 0,
"files": [ "files": [
{ {
"filename": "custom1.wav", "filename": "custom1.wav",

View File

@@ -229,3 +229,73 @@ class TestExtractionEndpoints:
break break
assert processing_found, "Processing extraction not found in results" assert processing_found, "Processing extraction not found in results"
@pytest.mark.asyncio
async def test_delete_extraction_success(self, authenticated_client, test_user, test_session):
"""Test successful deletion of user's own extraction."""
# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=test_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Delete the extraction
response = await authenticated_client.delete(f"/api/v1/extractions/{extraction.id}")
assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"
# Verify extraction was deleted from database
deleted_extraction = await test_session.get(Extraction, extraction.id)
assert deleted_extraction is None
@pytest.mark.asyncio
async def test_delete_extraction_not_found(self, authenticated_client):
"""Test deleting non-existent extraction."""
response = await authenticated_client.delete("/api/v1/extractions/999")
assert response.status_code == 404
data = response.json()
assert "not found" in data["detail"].lower()
@pytest.mark.asyncio
async def test_delete_extraction_permission_denied(self, authenticated_client, test_session, test_plan):
"""Test deleting another user's extraction."""
# Create extraction owned by different user
other_user = User(
name="Other User",
email="other@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(other_user)
await test_session.commit()
await test_session.refresh(other_user)
extraction = Extraction(
url="https://example.com/video",
user_id=other_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)
# Try to delete other user's extraction
response = await authenticated_client.delete(f"/api/v1/extractions/{extraction.id}")
assert response.status_code == 403
data = response.json()
assert "permission" in data["detail"].lower()
@pytest.mark.asyncio
async def test_delete_extraction_unauthenticated(self, client):
"""Test deleting extraction without authentication."""
response = await client.delete("/api/v1/extractions/1")
assert response.status_code == 401

View File

@@ -1,5 +1,7 @@
"""Tests for favorite API endpoints.""" """Tests for favorite API endpoints."""
from contextlib import suppress
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from httpx import AsyncClient from httpx import AsyncClient
@@ -129,10 +131,8 @@ class TestFavoriteEndpoints:
) -> None: ) -> None:
"""Test successfully adding a sound to favorites.""" """Test successfully adding a sound to favorites."""
# Clean up any existing favorite first # Clean up any existing favorite first
try: with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/sounds/1") await authenticated_client.delete("/api/v1/favorites/sounds/1")
except:
pass # It's ok if it doesn't exist
response = await authenticated_client.post("/api/v1/favorites/sounds/1") response = await authenticated_client.post("/api/v1/favorites/sounds/1")
@@ -176,10 +176,8 @@ class TestFavoriteEndpoints:
) -> None: ) -> None:
"""Test successfully adding a playlist to favorites.""" """Test successfully adding a playlist to favorites."""
# Clean up any existing favorite first # Clean up any existing favorite first
try: with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/playlists/1") await authenticated_client.delete("/api/v1/favorites/playlists/1")
except:
pass # It's ok if it doesn't exist
response = await authenticated_client.post("/api/v1/favorites/playlists/1") response = await authenticated_client.post("/api/v1/favorites/playlists/1")
@@ -473,10 +471,8 @@ class TestFavoriteEndpoints:
) -> None: ) -> None:
"""Test checking if a sound is favorited (false case).""" """Test checking if a sound is favorited (false case)."""
# Make sure sound 1 is not favorited # Make sure sound 1 is not favorited
try: with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/sounds/1") await authenticated_client.delete("/api/v1/favorites/sounds/1")
except:
pass # It's ok if it doesn't exist
response = await authenticated_client.get("/api/v1/favorites/sounds/1/check") response = await authenticated_client.get("/api/v1/favorites/sounds/1/check")
@@ -509,10 +505,8 @@ class TestFavoriteEndpoints:
) -> None: ) -> None:
"""Test checking if a playlist is favorited (false case).""" """Test checking if a playlist is favorited (false case)."""
# Make sure playlist 1 is not favorited # Make sure playlist 1 is not favorited
try: with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/playlists/1") await authenticated_client.delete("/api/v1/favorites/playlists/1")
except:
pass # It's ok if it doesn't exist
response = await authenticated_client.get("/api/v1/favorites/playlists/1/check") response = await authenticated_client.get("/api/v1/favorites/playlists/1/check")

View File

@@ -541,3 +541,143 @@ class TestExtractionService:
assert result[0]["id"] == 1 assert result[0]["id"] == 1
assert result[0]["status"] == "pending" assert result[0]["status"] == "pending"
assert result[0]["user_name"] == "Test User" assert result[0]["user_name"] == "Test User"
@pytest.mark.asyncio
async def test_delete_extraction_with_sound(self, extraction_service, test_user):
"""Test deleting extraction with associated sound and files."""
import tempfile
from pathlib import Path
# Create temporary directories for testing
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
# Set up temporary directories
original_dir = temp_dir_path / "originals" / "extracted"
normalized_dir = temp_dir_path / "normalized" / "extracted"
thumbnail_dir = temp_dir_path / "thumbnails"
original_dir.mkdir(parents=True)
normalized_dir.mkdir(parents=True)
thumbnail_dir.mkdir(parents=True)
# Create test files
audio_file = original_dir / "test_audio.mp3"
normalized_file = normalized_dir / "test_audio.mp3"
thumbnail_file = thumbnail_dir / "test_thumb.jpg"
audio_file.write_text("audio content")
normalized_file.write_text("normalized content")
thumbnail_file.write_text("thumbnail content")
# Create extraction and sound records
extraction = Extraction(
id=1,
url="https://example.com/video",
user_id=test_user.id,
status="completed",
sound_id=1,
)
sound = Sound(
id=1,
type="EXT",
name="Test Audio",
filename="test_audio.mp3",
duration=60000,
size=2048,
hash="test_hash",
is_normalized=True,
normalized_filename="test_audio.mp3",
thumbnail="test_thumb.jpg",
is_deletable=True,
is_music=True,
)
# Mock repository methods
extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction)
extraction_service.sound_repo.get_by_id = AsyncMock(return_value=sound)
extraction_service.extraction_repo.delete = AsyncMock()
extraction_service.sound_repo.delete = AsyncMock()
extraction_service.session.commit = AsyncMock()
extraction_service.session.rollback = AsyncMock()
# Monkey patch the paths in the service method
import app.services.extraction
original_path_class = app.services.extraction.Path
def mock_path(*args: str):
path_str = str(args[0])
if path_str == "sounds/originals/extracted":
return original_dir
if path_str == "sounds/normalized/extracted":
return normalized_dir
if path_str.endswith("thumbnails"):
return thumbnail_dir
return original_path_class(*args)
# Mock the Path constructor and settings
with patch("app.services.extraction.Path", side_effect=mock_path), \
patch("app.services.extraction.settings") as mock_settings:
mock_settings.EXTRACTION_THUMBNAILS_DIR = str(thumbnail_dir)
# Test deletion
result = await extraction_service.delete_extraction(1, test_user.id)
assert result is True
# Verify repository calls
extraction_service.extraction_repo.get_by_id.assert_called_once_with(1)
extraction_service.sound_repo.get_by_id.assert_called_once_with(1)
extraction_service.extraction_repo.delete.assert_called_once_with(extraction)
extraction_service.sound_repo.delete.assert_called_once_with(sound)
extraction_service.session.commit.assert_called_once()
# Verify files were deleted
assert not audio_file.exists()
assert not normalized_file.exists()
assert not thumbnail_file.exists()
@pytest.mark.asyncio
async def test_delete_extraction_not_found(self, extraction_service, test_user):
"""Test deleting non-existent extraction."""
extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=None)
result = await extraction_service.delete_extraction(999, test_user.id)
assert result is False
@pytest.mark.asyncio
async def test_delete_extraction_permission_denied(self, extraction_service, test_user):
"""Test deleting extraction owned by another user."""
extraction = Extraction(
id=1,
url="https://example.com/video",
user_id=999, # Different user ID
status="completed",
)
extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction)
with pytest.raises(ValueError, match="You don't have permission"):
await extraction_service.delete_extraction(1, test_user.id)
@pytest.mark.asyncio
async def test_delete_extraction_admin(self, extraction_service, test_user):
"""Test admin deleting any extraction."""
extraction = Extraction(
id=1,
url="https://example.com/video",
user_id=999, # Different user ID
status="completed",
)
extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction)
extraction_service.extraction_repo.delete = AsyncMock()
extraction_service.session.commit = AsyncMock()
# Admin deletion (user_id=None)
result = await extraction_service.delete_extraction(1, None)
assert result is True
extraction_service.extraction_repo.delete.assert_called_once_with(extraction)

View File

@@ -1,6 +1,7 @@
"""Tests for favorite service.""" """Tests for favorite service."""
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -14,6 +15,31 @@ from app.models.user import User
from app.services.favorite import FavoriteService from app.services.favorite import FavoriteService
@dataclass
class MockRepositories:
"""Container for all mock repositories."""
user_repo: AsyncMock
favorite_repo: AsyncMock
sound_repo: AsyncMock
socket_manager: AsyncMock
@dataclass
class MockServiceDependencies:
"""Container for all mock service dependencies."""
sound_repo_class: AsyncMock
user_repo_class: AsyncMock
favorite_repo_class: AsyncMock
socket_manager: AsyncMock
sound_repo: AsyncMock
user_repo: AsyncMock
favorite_repo: AsyncMock
playlist_repo_class: AsyncMock | None = None
playlist_repo: AsyncMock | None = None
class TestFavoriteService: class TestFavoriteService:
"""Test favorite service operations.""" """Test favorite service operations."""
@@ -71,34 +97,75 @@ class TestFavoriteService:
"playlist_repo": AsyncMock(), "playlist_repo": AsyncMock(),
} }
@patch("app.services.favorite.socket_manager") @pytest_asyncio.fixture
@patch("app.services.favorite.FavoriteRepository") async def mock_sound_favorite_dependencies(self) -> MockServiceDependencies:
@patch("app.services.favorite.UserRepository") """Create mock dependencies for sound favorite operations."""
@patch("app.services.favorite.SoundRepository") with (
patch("app.services.favorite.SoundRepository") as mock_sound_repo_class,
patch("app.services.favorite.UserRepository") as mock_user_repo_class,
patch("app.services.favorite.FavoriteRepository") as mock_favorite_repo_class,
patch("app.services.favorite.socket_manager") as mock_socket_manager,
):
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
mock_favorite_repo = AsyncMock()
mock_sound_repo_class.return_value = mock_sound_repo
mock_user_repo_class.return_value = mock_user_repo
mock_favorite_repo_class.return_value = mock_favorite_repo
yield MockServiceDependencies(
sound_repo_class=mock_sound_repo_class,
user_repo_class=mock_user_repo_class,
favorite_repo_class=mock_favorite_repo_class,
socket_manager=mock_socket_manager,
sound_repo=mock_sound_repo,
user_repo=mock_user_repo,
favorite_repo=mock_favorite_repo,
)
@pytest_asyncio.fixture
async def mock_playlist_favorite_dependencies(self) -> MockServiceDependencies:
"""Create mock dependencies for playlist favorite operations."""
with (
patch("app.services.favorite.UserRepository") as mock_user_repo_class,
patch("app.services.favorite.PlaylistRepository") as mock_playlist_repo_class,
patch("app.services.favorite.FavoriteRepository") as mock_favorite_repo_class,
):
mock_user_repo = AsyncMock()
mock_playlist_repo = AsyncMock()
mock_favorite_repo = AsyncMock()
mock_user_repo_class.return_value = mock_user_repo
mock_playlist_repo_class.return_value = mock_playlist_repo
mock_favorite_repo_class.return_value = mock_favorite_repo
yield MockServiceDependencies(
sound_repo_class=AsyncMock(), # Not used in playlist tests
user_repo_class=mock_user_repo_class,
favorite_repo_class=mock_favorite_repo_class,
socket_manager=AsyncMock(), # Not used in playlist tests
sound_repo=AsyncMock(), # Not used in playlist tests
user_repo=mock_user_repo,
favorite_repo=mock_favorite_repo,
playlist_repo_class=mock_playlist_repo_class,
playlist_repo=mock_playlist_repo,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_sound_favorite_success( async def test_add_sound_favorite_success(
self, self,
mock_sound_repo_class: AsyncMock, mock_sound_favorite_dependencies: MockServiceDependencies,
mock_user_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_socket_manager: AsyncMock,
favorite_service: FavoriteService, favorite_service: FavoriteService,
test_user: User, test_user: User,
test_sound: Sound, test_sound: Sound,
) -> None: ) -> None:
"""Test successfully adding a sound favorite.""" """Test successfully adding a sound favorite."""
# Setup mocks # Setup mocks
mock_favorite_repo = AsyncMock() mocks = mock_sound_favorite_dependencies
mock_user_repo = AsyncMock() mocks.user_repo.get_by_id.return_value = test_user
mock_sound_repo = AsyncMock() mocks.sound_repo.get_by_id.return_value = test_sound
mocks.favorite_repo.get_by_user_and_sound.return_value = None
mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo
mock_user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound
mock_favorite_repo.get_by_user_and_sound.return_value = None
expected_favorite = Favorite( expected_favorite = Favorite(
id=1, id=1,
@@ -106,23 +173,23 @@ class TestFavoriteService:
sound_id=test_sound.id, sound_id=test_sound.id,
playlist_id=None, playlist_id=None,
) )
mock_favorite_repo.create.return_value = expected_favorite mocks.favorite_repo.create.return_value = expected_favorite
mock_favorite_repo.count_sound_favorites.return_value = 1 mocks.favorite_repo.count_sound_favorites.return_value = 1
# Execute # Execute
result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id) result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id)
# Verify # Verify
assert result == expected_favorite assert result == expected_favorite
mock_user_repo.get_by_id.assert_called_once_with(test_user.id) mocks.user_repo.get_by_id.assert_called_once_with(test_user.id)
mock_sound_repo.get_by_id.assert_called_once_with(test_sound.id) mocks.sound_repo.get_by_id.assert_called_once_with(test_sound.id)
mock_favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id) mocks.favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id)
mock_favorite_repo.create.assert_called_once_with({ mocks.favorite_repo.create.assert_called_once_with({
"user_id": test_user.id, "user_id": test_user.id,
"sound_id": test_sound.id, "sound_id": test_sound.id,
"playlist_id": None, "playlist_id": None,
}) })
mock_socket_manager.broadcast_to_all.assert_called_once() mocks.socket_manager.broadcast_to_all.assert_called_once()
@patch("app.services.favorite.UserRepository") @patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -161,62 +228,38 @@ class TestFavoriteService:
with pytest.raises(ValueError, match="Sound with ID 1 not found"): with pytest.raises(ValueError, match="Sound with ID 1 not found"):
await favorite_service.add_sound_favorite(test_user.id, 1) await favorite_service.add_sound_favorite(test_user.id, 1)
@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.SoundRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_sound_favorite_already_exists( async def test_add_sound_favorite_already_exists(
self, self,
mock_user_repo_class: AsyncMock, mock_sound_favorite_dependencies: MockServiceDependencies,
mock_sound_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
favorite_service: FavoriteService, favorite_service: FavoriteService,
test_user: User, test_user: User,
test_sound: Sound, test_sound: Sound,
) -> None: ) -> None:
"""Test adding sound favorite that already exists.""" """Test adding sound favorite that already exists."""
mock_user_repo = AsyncMock() mocks = mock_sound_favorite_dependencies
mock_sound_repo = AsyncMock() mocks.user_repo.get_by_id.return_value = test_user
mock_favorite_repo = AsyncMock() mocks.sound_repo.get_by_id.return_value = test_sound
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo
mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound
existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id) existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id)
mock_favorite_repo.get_by_user_and_sound.return_value = existing_favorite mocks.favorite_repo.get_by_user_and_sound.return_value = existing_favorite
with pytest.raises(ValueError, match="already favorited"): with pytest.raises(ValueError, match="already favorited"):
await favorite_service.add_sound_favorite(test_user.id, test_sound.id) await favorite_service.add_sound_favorite(test_user.id, test_sound.id)
@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.PlaylistRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_playlist_favorite_success( async def test_add_playlist_favorite_success(
self, self,
mock_user_repo_class: AsyncMock, mock_playlist_favorite_dependencies: MockServiceDependencies,
mock_playlist_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
favorite_service: FavoriteService, favorite_service: FavoriteService,
test_user: User, test_user: User,
test_playlist: Playlist, test_playlist: Playlist,
) -> None: ) -> None:
"""Test successfully adding a playlist favorite.""" """Test successfully adding a playlist favorite."""
# Setup mocks # Setup mocks
mock_favorite_repo = AsyncMock() mocks = mock_playlist_favorite_dependencies
mock_user_repo = AsyncMock() mocks.user_repo.get_by_id.return_value = test_user
mock_playlist_repo = AsyncMock() mocks.playlist_repo.get_by_id.return_value = test_playlist
mocks.favorite_repo.get_by_user_and_playlist.return_value = None
mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_playlist_repo_class.return_value = mock_playlist_repo
mock_user_repo.get_by_id.return_value = test_user
mock_playlist_repo.get_by_id.return_value = test_playlist
mock_favorite_repo.get_by_user_and_playlist.return_value = None
expected_favorite = Favorite( expected_favorite = Favorite(
id=1, id=1,
@@ -224,59 +267,45 @@ class TestFavoriteService:
sound_id=None, sound_id=None,
playlist_id=test_playlist.id, playlist_id=test_playlist.id,
) )
mock_favorite_repo.create.return_value = expected_favorite mocks.favorite_repo.create.return_value = expected_favorite
# Execute # Execute
result = await favorite_service.add_playlist_favorite(test_user.id, test_playlist.id) result = await favorite_service.add_playlist_favorite(test_user.id, test_playlist.id)
# Verify # Verify
assert result == expected_favorite assert result == expected_favorite
mock_user_repo.get_by_id.assert_called_once_with(test_user.id) mocks.user_repo.get_by_id.assert_called_once_with(test_user.id)
mock_playlist_repo.get_by_id.assert_called_once_with(test_playlist.id) mocks.playlist_repo.get_by_id.assert_called_once_with(test_playlist.id)
mock_favorite_repo.get_by_user_and_playlist.assert_called_once_with(test_user.id, test_playlist.id) mocks.favorite_repo.get_by_user_and_playlist.assert_called_once_with(test_user.id, test_playlist.id)
mock_favorite_repo.create.assert_called_once_with({ mocks.favorite_repo.create.assert_called_once_with({
"user_id": test_user.id, "user_id": test_user.id,
"sound_id": None, "sound_id": None,
"playlist_id": test_playlist.id, "playlist_id": test_playlist.id,
}) })
@patch("app.services.favorite.socket_manager")
@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.SoundRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_sound_favorite_success( async def test_remove_sound_favorite_success(
self, self,
mock_user_repo_class: AsyncMock, mock_sound_favorite_dependencies: MockServiceDependencies,
mock_sound_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_socket_manager: AsyncMock,
favorite_service: FavoriteService, favorite_service: FavoriteService,
test_user: User, test_user: User,
test_sound: Sound, test_sound: Sound,
) -> None: ) -> None:
"""Test successfully removing a sound favorite.""" """Test successfully removing a sound favorite."""
mock_favorite_repo = AsyncMock() mocks = mock_sound_favorite_dependencies
mock_user_repo = AsyncMock()
mock_sound_repo = AsyncMock()
mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo
existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id) existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id)
mock_favorite_repo.get_by_user_and_sound.return_value = existing_favorite mocks.favorite_repo.get_by_user_and_sound.return_value = existing_favorite
mock_user_repo.get_by_id.return_value = test_user mocks.user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound mocks.sound_repo.get_by_id.return_value = test_sound
mock_favorite_repo.count_sound_favorites.return_value = 0 mocks.favorite_repo.count_sound_favorites.return_value = 0
# Execute # Execute
await favorite_service.remove_sound_favorite(test_user.id, test_sound.id) await favorite_service.remove_sound_favorite(test_user.id, test_sound.id)
# Verify # Verify
mock_favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id) mocks.favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id)
mock_favorite_repo.delete.assert_called_once_with(existing_favorite) mocks.favorite_repo.delete.assert_called_once_with(existing_favorite)
mock_socket_manager.broadcast_to_all.assert_called_once() mocks.socket_manager.broadcast_to_all.assert_called_once()
@patch("app.services.favorite.FavoriteRepository") @patch("app.services.favorite.FavoriteRepository")
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -503,46 +532,31 @@ class TestFavoriteService:
assert result == 3 assert result == 3
mock_favorite_repo.count_playlist_favorites.assert_called_once_with(1) mock_favorite_repo.count_playlist_favorites.assert_called_once_with(1)
@patch("app.services.favorite.socket_manager")
@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.SoundRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_socket_broadcast_error_handling( async def test_socket_broadcast_error_handling(
self, self,
mock_user_repo_class: AsyncMock, mock_sound_favorite_dependencies: MockServiceDependencies,
mock_sound_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_socket_manager: AsyncMock,
favorite_service: FavoriteService, favorite_service: FavoriteService,
test_user: User, test_user: User,
test_sound: Sound, test_sound: Sound,
) -> None: ) -> None:
"""Test that socket broadcast errors don't affect the operation.""" """Test that socket broadcast errors don't affect the operation."""
# Setup mocks mocks = mock_sound_favorite_dependencies
mock_favorite_repo = AsyncMock() mocks.user_repo.get_by_id.return_value = test_user
mock_user_repo = AsyncMock() mocks.sound_repo.get_by_id.return_value = test_sound
mock_sound_repo = AsyncMock() mocks.favorite_repo.get_by_user_and_sound.return_value = None
mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo
mock_user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound
mock_favorite_repo.get_by_user_and_sound.return_value = None
expected_favorite = Favorite(id=1, user_id=test_user.id, sound_id=test_sound.id) expected_favorite = Favorite(id=1, user_id=test_user.id, sound_id=test_sound.id)
mock_favorite_repo.create.return_value = expected_favorite mocks.favorite_repo.create.return_value = expected_favorite
mock_favorite_repo.count_sound_favorites.return_value = 1 mocks.favorite_repo.count_sound_favorites.return_value = 1
# Make socket broadcast raise an exception # Make socket broadcast raise an exception
mock_socket_manager.broadcast_to_all.side_effect = Exception("Socket error") mocks.socket_manager.broadcast_to_all.side_effect = Exception("Socket error")
# Execute - should not raise exception despite socket error # Execute - should not raise exception despite socket error
result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id) result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id)
# Verify operation still succeeded # Verify operation still succeeded
assert result == expected_favorite assert result == expected_favorite
mock_favorite_repo.create.assert_called_once() mocks.favorite_repo.create.assert_called_once()

View File

@@ -8,7 +8,7 @@ import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.sound import Sound from app.models.sound import Sound
from app.services.sound_scanner import SoundScannerService from app.services.sound_scanner import SoundScannerService, SyncContext
class TestSoundScannerService: class TestSoundScannerService:
@@ -155,14 +155,14 @@ class TestSoundScannerService:
# Set the existing sound filename to match temp file for "unchanged" test # Set the existing sound filename to match temp file for "unchanged" test
existing_sound.filename = temp_path.name existing_sound.filename = temp_path.name
await scanner_service._sync_audio_file( sync_context = SyncContext(
temp_path, file_path=temp_path,
"SDB", sound_type="SDB",
existing_sound, # existing_sound_by_hash (same hash) existing_sound_by_hash=existing_sound,
None, # existing_sound_by_filename (no conflict) existing_sound_by_filename=None,
"same_hash", file_hash="same_hash",
results,
) )
await scanner_service._sync_audio_file(sync_context, results)
assert results["skipped"] == 1 assert results["skipped"] == 1
assert results["added"] == 0 assert results["added"] == 0
@@ -210,14 +210,14 @@ class TestSoundScannerService:
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( sync_context = SyncContext(
temp_path, file_path=temp_path,
"SDB", sound_type="SDB",
existing_sound, # existing_sound_by_hash (same hash) existing_sound_by_hash=existing_sound,
None, # existing_sound_by_filename (different filename) existing_sound_by_filename=None,
"same_hash", file_hash="same_hash",
results,
) )
await scanner_service._sync_audio_file(sync_context, results)
# Should be marked as updated (renamed) # Should be marked as updated (renamed)
assert results["updated"] == 1 assert results["updated"] == 1
@@ -257,12 +257,11 @@ class TestSoundScannerService:
# Create temporary directory with renamed file # Create temporary directory with renamed file
import tempfile import tempfile
import os
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
# Create the "renamed" file (same hash, different name) # Create the "renamed" file (same hash, different name)
new_file_path = os.path.join(temp_dir, "new_name.mp3") new_file_path = Path(temp_dir) / "new_name.mp3"
with open(new_file_path, "wb") as f: with new_file_path.open("wb") as f:
f.write(b"test audio content") # This will produce consistent hash f.write(b"test audio content") # This will produce consistent hash
# Mock file operations to return same hash # Mock file operations to return same hash
@@ -308,16 +307,15 @@ class TestSoundScannerService:
# Create temporary directory with both original and duplicate files # Create temporary directory with both original and duplicate files
import tempfile import tempfile
import os
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
# Create both files (simulating duplicate content) # Create both files (simulating duplicate content)
original_path = os.path.join(temp_dir, "original.mp3") original_path = Path(temp_dir) / "original.mp3"
duplicate_path = os.path.join(temp_dir, "duplicate.mp3") duplicate_path = Path(temp_dir) / "duplicate.mp3"
with open(original_path, "wb") as f: with original_path.open("wb") as f:
f.write(b"test audio content") f.write(b"test audio content")
with open(duplicate_path, "wb") as f: with duplicate_path.open("wb") as f:
f.write(b"test audio content") # Same content = same hash f.write(b"test audio content") # Same content = same hash
# Mock file operations # Mock file operations
@@ -375,14 +373,14 @@ class TestSoundScannerService:
"errors": 0, "errors": 0,
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( sync_context = SyncContext(
temp_path, file_path=temp_path,
"SDB", sound_type="SDB",
None, # existing_sound_by_hash existing_sound_by_hash=None,
None, # existing_sound_by_filename existing_sound_by_filename=None,
"test_hash", file_hash="test_hash",
results,
) )
await scanner_service._sync_audio_file(sync_context, results)
assert results["added"] == 1 assert results["added"] == 1
assert results["skipped"] == 0 assert results["skipped"] == 0
@@ -439,14 +437,14 @@ class TestSoundScannerService:
"errors": 0, "errors": 0,
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( sync_context = SyncContext(
temp_path, file_path=temp_path,
"SDB", sound_type="SDB",
None, # existing_sound_by_hash (different hash) existing_sound_by_hash=None,
existing_sound, # existing_sound_by_filename existing_sound_by_filename=existing_sound,
"new_hash", file_hash="new_hash",
results,
) )
await scanner_service._sync_audio_file(sync_context, results)
assert results["updated"] == 1 assert results["updated"] == 1
assert results["added"] == 0 assert results["added"] == 0
@@ -504,14 +502,14 @@ class TestSoundScannerService:
"errors": 0, "errors": 0,
"files": [], "files": [],
} }
await scanner_service._sync_audio_file( sync_context = SyncContext(
temp_path, file_path=temp_path,
"CUSTOM", sound_type="CUSTOM",
None, # existing_sound_by_hash existing_sound_by_hash=None,
None, # existing_sound_by_filename existing_sound_by_filename=None,
"custom_hash", file_hash="custom_hash",
results,
) )
await scanner_service._sync_audio_file(sync_context, results)
assert results["added"] == 1 assert results["added"] == 1
assert results["skipped"] == 0 assert results["skipped"] == 0
@@ -533,19 +531,19 @@ class TestSoundScannerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_audio_file_rename_with_normalized_file( async def test_sync_audio_file_rename_with_normalized_file(
self, test_session, scanner_service self, test_session, scanner_service,
): ):
"""Test that renaming a sound file also renames its normalized file.""" """Test that renaming a sound file also renames its normalized file."""
# Create temporary directories for testing # Create temporary directories for testing
from pathlib import Path
import tempfile import tempfile
from pathlib import Path
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir) temp_dir_path = Path(temp_dir)
# Set up the scanner's normalized directories to use temp dir # Set up the scanner's normalized directories to use temp dir
scanner_service.normalized_directories = { scanner_service.normalized_directories = {
"SDB": str(temp_dir_path / "normalized" / "soundboard") "SDB": str(temp_dir_path / "normalized" / "soundboard"),
} }
# Create the normalized directory # Create the normalized directory
@@ -557,7 +555,6 @@ class TestSoundScannerService:
old_normalized_file.write_text("normalized audio content") old_normalized_file.write_text("normalized audio content")
# Create the audio files (they need to exist for the scanner) # Create the audio files (they need to exist for the scanner)
old_path = temp_dir_path / "old_sound.mp3"
new_path = temp_dir_path / "new_sound.mp3" new_path = temp_dir_path / "new_sound.mp3"
# Create a dummy audio file for the new path # Create a dummy audio file for the new path
@@ -565,8 +562,8 @@ class TestSoundScannerService:
# Mock the audio utility functions since we're using fake files # Mock the audio utility functions since we're using fake files
from unittest.mock import patch from unittest.mock import patch
with patch('app.services.sound_scanner.get_audio_duration', return_value=60000), \ with patch("app.services.sound_scanner.get_audio_duration", return_value=60000), \
patch('app.services.sound_scanner.get_file_size', return_value=2048): patch("app.services.sound_scanner.get_file_size", return_value=2048):
# Create existing sound with normalized file info # Create existing sound with normalized file info
existing_sound = Sound( existing_sound = Sound(
@@ -584,7 +581,7 @@ class TestSoundScannerService:
normalized_hash="normalized_hash", normalized_hash="normalized_hash",
play_count=5, play_count=5,
is_deletable=False, is_deletable=False,
is_music=False is_music=False,
) )
results = { results = {
@@ -602,14 +599,14 @@ class TestSoundScannerService:
scanner_service.sound_repo.update = AsyncMock() scanner_service.sound_repo.update = AsyncMock()
# Simulate rename detection by calling _sync_audio_file # Simulate rename detection by calling _sync_audio_file
await scanner_service._sync_audio_file( sync_context = SyncContext(
new_path, file_path=new_path,
"SDB", sound_type="SDB",
existing_sound, # existing_sound_by_hash (same hash, different filename) existing_sound_by_hash=existing_sound,
None, # existing_sound_by_filename (no file with new name exists) existing_sound_by_filename=None,
"test_hash", file_hash="test_hash",
results,
) )
await scanner_service._sync_audio_file(sync_context, results)
# Verify the results # Verify the results
assert results["updated"] == 1 assert results["updated"] == 1
@@ -635,12 +632,12 @@ class TestSoundScannerService:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_scan_directory_delete_with_normalized_file( async def test_scan_directory_delete_with_normalized_file(
self, test_session, scanner_service self, test_session, scanner_service,
): ):
"""Test that deleting a sound also deletes its normalized file.""" """Test that deleting a sound also deletes its normalized file."""
# Create temporary directories for testing # Create temporary directories for testing
from pathlib import Path
import tempfile import tempfile
from pathlib import Path
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir) temp_dir_path = Path(temp_dir)
@@ -649,7 +646,7 @@ class TestSoundScannerService:
# Set up the scanner's normalized directories to use temp dir # Set up the scanner's normalized directories to use temp dir
scanner_service.normalized_directories = { scanner_service.normalized_directories = {
"SDB": str(temp_dir_path / "normalized" / "soundboard") "SDB": str(temp_dir_path / "normalized" / "soundboard"),
} }
# Create the normalized directory and file # Create the normalized directory and file
@@ -674,7 +671,7 @@ class TestSoundScannerService:
normalized_hash="normalized_hash", normalized_hash="normalized_hash",
play_count=5, play_count=5,
is_deletable=False, is_deletable=False,
is_music=False is_music=False,
) )
# Mock sound repository methods # Mock sound repository methods
@@ -683,8 +680,8 @@ class TestSoundScannerService:
# Mock audio utility functions # Mock audio utility functions
from unittest.mock import patch from unittest.mock import patch
with patch('app.services.sound_scanner.get_audio_duration'), \ with patch("app.services.sound_scanner.get_audio_duration"), \
patch('app.services.sound_scanner.get_file_size'): patch("app.services.sound_scanner.get_file_size"):
# Run scan with empty directory (should trigger deletion) # Run scan with empty directory (should trigger deletion)
results = await scanner_service.scan_directory(str(scan_dir), "SDB") results = await scanner_service.scan_directory(str(scan_dir), "SDB")