diff --git a/app/api/v1/admin/extractions.py b/app/api/v1/admin/extractions.py index f0a9ba5..d02e113 100644 --- a/app/api/v1/admin/extractions.py +++ b/app/api/v1/admin/extractions.py @@ -2,18 +2,58 @@ from typing import Annotated -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException, status +from sqlmodel.ext.asyncio.session import AsyncSession +from app.core.database import get_db from app.core.dependencies import get_admin_user from app.models.user import User +from app.services.extraction import ExtractionService from app.services.extraction_processor import extraction_processor router = APIRouter(prefix="/extractions", tags=["admin-extractions"]) +async def get_extraction_service( + session: Annotated[AsyncSession, Depends(get_db)], +) -> ExtractionService: + """Get the extraction service.""" + return ExtractionService(session) + + @router.get("/status") async def get_extraction_processor_status( current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001 ) -> dict: """Get the status of the extraction processor. Admin only.""" return extraction_processor.get_status() + + +@router.delete("/{extraction_id}") +async def delete_extraction( + extraction_id: int, + current_user: Annotated[User, Depends(get_admin_user)], # noqa: ARG001 + extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], +) -> dict[str, str]: + """Delete any extraction and its associated sound and files. Admin only.""" + try: + deleted = await extraction_service.delete_extraction(extraction_id, None) + + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Extraction {extraction_id} not found", + ) + + except HTTPException: + # Re-raise HTTPExceptions without wrapping them + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete extraction: {e!s}", + ) from e + else: + return { + "message": f"Extraction {extraction_id} deleted successfully", + } diff --git a/app/api/v1/extractions.py b/app/api/v1/extractions.py index d48b768..36c0e75 100644 --- a/app/api/v1/extractions.py +++ b/app/api/v1/extractions.py @@ -170,7 +170,7 @@ async def get_processing_extractions( try: # Get all extractions with processing status processing_extractions = await extraction_service.extraction_repo.get_by_status( - "processing" + "processing", ) # Convert to ExtractionInfo format @@ -196,10 +196,53 @@ async def get_processing_extractions( } result.append(extraction_info) - return result - except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to get processing extractions: {e!s}", ) from e + else: + return result + + +@router.delete("/{extraction_id}") +async def delete_extraction( + extraction_id: int, + current_user: Annotated[User, Depends(get_current_active_user_flexible)], + extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)], +) -> dict[str, str]: + """Delete extraction and associated sound/files. Users can only delete their own.""" + try: + if current_user.id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User ID not available", + ) + + deleted = await extraction_service.delete_extraction( + extraction_id, current_user.id, + ) + + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Extraction {extraction_id} not found", + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) from e + except HTTPException: + # Re-raise HTTPExceptions without wrapping them + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete extraction: {e!s}", + ) from e + else: + return { + "message": f"Extraction {extraction_id} deleted successfully", + } diff --git a/app/models/sound.py b/app/models/sound.py index fc28ae8..0156715 100644 --- a/app/models/sound.py +++ b/app/models/sound.py @@ -35,10 +35,10 @@ class Sound(BaseModel, table=True): # relationships playlist_sounds: list["PlaylistSound"] = Relationship( - back_populates="sound", cascade_delete=True + back_populates="sound", cascade_delete=True, ) extractions: list["Extraction"] = Relationship(back_populates="sound") play_history: list["SoundPlayed"] = Relationship( - back_populates="sound", cascade_delete=True + back_populates="sound", cascade_delete=True, ) favorites: list["Favorite"] = Relationship(back_populates="sound") diff --git a/app/repositories/extraction.py b/app/repositories/extraction.py index da7beb6..c38a589 100644 --- a/app/repositories/extraction.py +++ b/app/repositories/extraction.py @@ -44,7 +44,7 @@ class ExtractionRepository(BaseRepository[Extraction]): result = await self.session.exec( select(Extraction) .where(Extraction.status == status) - .order_by(Extraction.created_at) + .order_by(Extraction.created_at), ) return list(result.all()) diff --git a/app/services/extraction.py b/app/services/extraction.py index 742e0b1..6215138 100644 --- a/app/services/extraction.py +++ b/app/services/extraction.py @@ -2,14 +2,16 @@ import asyncio import shutil +from dataclasses import dataclass from pathlib import Path -from typing import TypedDict +from typing import Any, TypedDict import yt_dlp from sqlmodel.ext.asyncio.session import AsyncSession from app.core.config import settings from app.core.logging import get_logger +from app.models.extraction import Extraction from app.models.sound import Sound from app.repositories.extraction import ExtractionRepository from app.repositories.sound import SoundRepository @@ -21,6 +23,18 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size logger = get_logger(__name__) +@dataclass +class ExtractionContext: + """Context data for extraction processing.""" + + extraction_id: int + extraction_url: str + extraction_service: str | None + extraction_service_id: str | None + extraction_title: str | None + user_id: int + + class ExtractionInfo(TypedDict): """Type definition for extraction information.""" @@ -150,8 +164,8 @@ class ExtractionService: logger.exception("Failed to detect service info for URL: %s", url) return None - async def process_extraction(self, extraction_id: int) -> ExtractionInfo: - """Process an extraction job.""" + async def _validate_extraction(self, extraction_id: int) -> tuple: + """Validate extraction and return extraction data.""" extraction = await self.extraction_repo.get_by_id(extraction_id) if not extraction: msg = f"Extraction {extraction_id} not found" @@ -173,9 +187,183 @@ class ExtractionService: user = await self.user_repo.get_by_id(user_id) user_name = user.name if user else None except Exception: - logger.warning("Failed to get user %d for extraction", user_id) + logger.exception("Failed to get user %d for extraction", user_id) user_name = None + return ( + extraction, + user_id, + extraction_url, + extraction_service, + extraction_service_id, + extraction_title, + user_name, + ) + + async def _handle_service_detection( + self, + extraction: Extraction, + context: ExtractionContext, + ) -> tuple: + """Handle service detection and duplicate checking.""" + if context.extraction_service and context.extraction_service_id: + return ( + context.extraction_service, + context.extraction_service_id, + context.extraction_title, + ) + + logger.info("Detecting service info for extraction %d", context.extraction_id) + service_info = await self._detect_service_info(context.extraction_url) + + if not service_info: + msg = "Unable to detect service information from URL" + raise ValueError(msg) + + # Check if extraction already exists for this service + service_name = service_info["service"] + service_id_val = service_info["service_id"] + + if not service_name or not service_id_val: + msg = "Service info is incomplete" + raise ValueError(msg) + + existing = await self.extraction_repo.get_by_service_and_id( + service_name, + service_id_val, + ) + if existing and existing.id != context.extraction_id: + error_msg = ( + f"Extraction already exists for " + f"{service_info['service']}:{service_info['service_id']}" + ) + logger.warning(error_msg) + raise ValueError(error_msg) + + # Update extraction with service info + update_data = { + "service": service_info["service"], + "service_id": service_info["service_id"], + "title": service_info.get("title") or context.extraction_title, + } + await self.extraction_repo.update(extraction, update_data) + + # Update values for processing + new_service = service_info["service"] + new_service_id = service_info["service_id"] + new_title = service_info.get("title") or context.extraction_title + + await self._emit_extraction_event( + context.user_id, + { + "extraction_id": context.extraction_id, + "status": "processing", + "title": new_title, + "url": context.extraction_url, + }, + ) + + return new_service, new_service_id, new_title + + async def _process_media_files( + self, + extraction_id: int, + extraction_url: str, + extraction_title: str | None, + extraction_service: str, + extraction_service_id: str, + ) -> int: + """Process media files and create sound record.""" + # Extract audio and thumbnail + audio_file, thumbnail_file = await self._extract_media( + extraction_id, + extraction_url, + ) + + # Move files to final locations + final_audio_path, final_thumbnail_path = ( + await self._move_files_to_final_location( + audio_file, + thumbnail_file, + extraction_title, + extraction_service, + extraction_service_id, + ) + ) + + # Create Sound record + sound = await self._create_sound_record( + final_audio_path, + final_thumbnail_path, + extraction_title, + extraction_service, + extraction_service_id, + ) + + if not sound.id: + msg = "Sound creation failed - no ID returned" + raise RuntimeError(msg) + + return sound.id + + async def _complete_extraction( + self, + extraction: Extraction, + context: ExtractionContext, + sound_id: int, + ) -> None: + """Complete extraction processing.""" + # Normalize the sound + await self._normalize_sound(sound_id) + + # Add to main playlist + await self._add_to_main_playlist(sound_id, context.user_id) + + # Update extraction with success + await self.extraction_repo.update( + extraction, + { + "status": "completed", + "sound_id": sound_id, + "error": None, + }, + ) + + # Emit WebSocket event for completion + await self._emit_extraction_event( + context.user_id, + { + "extraction_id": context.extraction_id, + "status": "completed", + "title": context.extraction_title, + "url": context.extraction_url, + "sound_id": sound_id, + }, + ) + + async def process_extraction(self, extraction_id: int) -> ExtractionInfo: + """Process an extraction job.""" + # Validate extraction and get context data + ( + extraction, + user_id, + extraction_url, + extraction_service, + extraction_service_id, + extraction_title, + user_name, + ) = await self._validate_extraction(extraction_id) + + # Create context object for helper methods + context = ExtractionContext( + extraction_id=extraction_id, + extraction_url=extraction_url, + extraction_service=extraction_service, + extraction_service_id=extraction_service_id, + extraction_title=extraction_title, + user_id=user_id, + ) + logger.info("Processing extraction %d: %s", extraction_id, extraction_url) try: @@ -184,142 +372,53 @@ class ExtractionService: # Emit WebSocket event for processing start await self._emit_extraction_event( - user_id, + context.user_id, { - "extraction_id": extraction_id, + "extraction_id": context.extraction_id, "status": "processing", - "title": extraction_title or "Processing extraction...", - "url": extraction_url, + "title": context.extraction_title or "Processing extraction...", + "url": context.extraction_url, }, ) - # Detect service info if not already available - if not extraction_service or not extraction_service_id: - logger.info("Detecting service info for extraction %d", extraction_id) - service_info = await self._detect_service_info(extraction_url) - - if not service_info: - msg = "Unable to detect service information from URL" - raise ValueError(msg) - - # Check if extraction already exists for this service - service_name = service_info["service"] - service_id_val = service_info["service_id"] - - if not service_name or not service_id_val: - msg = "Service info is incomplete" - raise ValueError(msg) - - existing = await self.extraction_repo.get_by_service_and_id( - service_name, - service_id_val, - ) - if existing and existing.id != extraction_id: - error_msg = ( - f"Extraction already exists for " - f"{service_info['service']}:{service_info['service_id']}" - ) - logger.warning(error_msg) - raise ValueError(error_msg) - - # Update extraction with service info - update_data = { - "service": service_info["service"], - "service_id": service_info["service_id"], - "title": service_info.get("title") or extraction_title, - } - await self.extraction_repo.update(extraction, update_data) - - # Update values for processing - extraction_service = service_info["service"] - extraction_service_id = service_info["service_id"] - extraction_title = service_info.get("title") or extraction_title - - await self._emit_extraction_event( - user_id, - { - "extraction_id": extraction_id, - "status": "processing", - "title": extraction_title, - "url": extraction_url, - }, - ) - - # Extract audio and thumbnail - audio_file, thumbnail_file = await self._extract_media( - extraction_id, - extraction_url, + # Handle service detection and duplicate checking + extraction_service, extraction_service_id, extraction_title = ( + await self._handle_service_detection(extraction, context) ) - # Move files to final locations - ( - final_audio_path, - final_thumbnail_path, - ) = await self._move_files_to_final_location( - audio_file, - thumbnail_file, - extraction_title, + # Update context with potentially new values + context.extraction_service = extraction_service + context.extraction_service_id = extraction_service_id + context.extraction_title = extraction_title + + # Process media files and create sound record + sound_id = await self._process_media_files( + context.extraction_id, + context.extraction_url, + context.extraction_title, extraction_service, extraction_service_id, ) - # Create Sound record - sound = await self._create_sound_record( - final_audio_path, - final_thumbnail_path, - extraction_title, - extraction_service, - extraction_service_id, - ) + # Complete extraction processing + await self._complete_extraction(extraction, context, sound_id) - # Store sound_id early to avoid session detachment issues - sound_id = sound.id - if not sound_id: - msg = "Sound creation failed - no ID returned" - raise RuntimeError(msg) - - # Normalize the sound - await self._normalize_sound(sound_id) - - # Add to main playlist - await self._add_to_main_playlist(sound_id, user_id) - - # Update extraction with success - await self.extraction_repo.update( - extraction, - { - "status": "completed", - "sound_id": sound_id, - "error": None, - }, - ) - - # Emit WebSocket event for completion - await self._emit_extraction_event( - user_id, - { - "extraction_id": extraction_id, - "status": "completed", - "title": extraction_title, - "url": extraction_url, - "sound_id": sound_id, - }, - ) - - logger.info("Successfully processed extraction %d", extraction_id) + logger.info("Successfully processed extraction %d", context.extraction_id) # Get updated extraction to get latest timestamps - updated_extraction = await self.extraction_repo.get_by_id(extraction_id) + updated_extraction = await self.extraction_repo.get_by_id( + context.extraction_id, + ) return { - "id": extraction_id, - "url": extraction_url, + "id": context.extraction_id, + "url": context.extraction_url, "service": extraction_service, "service_id": extraction_service_id, "title": extraction_title, "status": "completed", "error": None, "sound_id": sound_id, - "user_id": user_id, + "user_id": context.user_id, "user_name": user_name, "created_at": ( updated_extraction.created_at.isoformat() @@ -337,18 +436,18 @@ class ExtractionService: error_msg = str(e) logger.exception( "Failed to process extraction %d: %s", - extraction_id, + context.extraction_id, error_msg, ) # Emit WebSocket event for failure await self._emit_extraction_event( - user_id, + context.user_id, { - "extraction_id": extraction_id, + "extraction_id": context.extraction_id, "status": "failed", - "title": extraction_title or "Extraction failed", - "url": extraction_url, + "title": context.extraction_title or "Extraction failed", + "url": context.extraction_url, "error": error_msg, }, ) @@ -363,17 +462,19 @@ class ExtractionService: ) # Get updated extraction to get latest timestamps - updated_extraction = await self.extraction_repo.get_by_id(extraction_id) + updated_extraction = await self.extraction_repo.get_by_id( + context.extraction_id, + ) return { - "id": extraction_id, - "url": extraction_url, - "service": extraction_service, - "service_id": extraction_service_id, - "title": extraction_title, + "id": context.extraction_id, + "url": context.extraction_url, + "service": context.extraction_service, + "service_id": context.extraction_service_id, + "title": context.extraction_title, "status": "failed", "error": error_msg, "sound_id": None, - "user_id": user_id, + "user_id": context.user_id, "user_name": user_name, "created_at": ( updated_extraction.created_at.isoformat() @@ -780,3 +881,174 @@ class ExtractionService: } for extraction, user in extraction_user_tuples ] + + async def delete_extraction( + self, + extraction_id: int, + user_id: int | None = None, + ) -> bool: + """Delete an extraction and its associated sound and files. + + Args: + extraction_id: The ID of the extraction to delete + user_id: Optional user ID for ownership verification (None for admin) + + Returns: + True if deletion was successful, False if extraction not found + + Raises: + ValueError: If user doesn't own the extraction (when user_id is provided) + + """ + logger.info( + "Deleting extraction: %d (user: %s)", + extraction_id, + user_id or "admin", + ) + + # Get the extraction record + extraction = await self.extraction_repo.get_by_id(extraction_id) + if not extraction: + logger.warning("Extraction %d not found", extraction_id) + return False + + # Check ownership if user_id is provided (non-admin request) + if user_id is not None and extraction.user_id != user_id: + msg = "You don't have permission to delete this extraction" + raise ValueError(msg) + + # Get associated sound if it exists and capture its attributes immediately + sound_data = None + sound_object = None + if extraction.sound_id: + sound_object = await self.sound_repo.get_by_id(extraction.sound_id) + if sound_object: + # Capture attributes immediately while session is valid + sound_data = { + "id": sound_object.id, + "type": sound_object.type, + "filename": sound_object.filename, + "is_normalized": sound_object.is_normalized, + "normalized_filename": sound_object.normalized_filename, + "thumbnail": sound_object.thumbnail, + } + + try: + # Delete the extraction record first + await self.extraction_repo.delete(extraction) + logger.info("Deleted extraction record: %d", extraction_id) + + # Check if sound was in current playlist before deletion + sound_was_in_current_playlist = False + if sound_object and sound_data: + sound_was_in_current_playlist = ( + await self._check_sound_in_current_playlist(sound_data["id"]) + ) + + # If there's an associated sound, delete it and its files + if sound_object and sound_data: + await self._delete_sound_and_files(sound_object, sound_data) + logger.info( + "Deleted associated sound: %d (%s)", + sound_data["id"], + sound_data["filename"], + ) + + # Commit the transaction + await self.session.commit() + + # Reload player playlist if deleted sound was in current playlist + if sound_was_in_current_playlist and sound_data: + await self._reload_player_playlist() + logger.info( + "Reloaded player playlist after deleting sound %d " + "from current playlist", + sound_data["id"], + ) + + except Exception: + # Rollback on any error + await self.session.rollback() + logger.exception("Failed to delete extraction %d", extraction_id) + raise + else: + return True + + async def _delete_sound_and_files( + self, + sound: Sound, + sound_data: dict[str, Any], + ) -> None: + """Delete a sound record and all its associated files.""" + # Collect all file paths to delete using captured attributes + files_to_delete = [] + + # Original audio file + if sound_data["type"] == "EXT": # Extracted sounds + original_path = Path("sounds/originals/extracted") / sound_data["filename"] + if original_path.exists(): + files_to_delete.append(original_path) + + # Normalized file + if sound_data["is_normalized"] and sound_data["normalized_filename"]: + normalized_path = ( + Path("sounds/normalized/extracted") / sound_data["normalized_filename"] + ) + if normalized_path.exists(): + files_to_delete.append(normalized_path) + + # Thumbnail file + if sound_data["thumbnail"]: + thumbnail_path = ( + Path(settings.EXTRACTION_THUMBNAILS_DIR) / sound_data["thumbnail"] + ) + if thumbnail_path.exists(): + files_to_delete.append(thumbnail_path) + + # Delete the sound from database first + await self.sound_repo.delete(sound) + + # Delete all associated files + for file_path in files_to_delete: + try: + file_path.unlink() + logger.info("Deleted file: %s", file_path) + except OSError: + logger.exception("Failed to delete file %s", file_path) + # Continue with other files even if one fails + + async def _check_sound_in_current_playlist(self, sound_id: int) -> bool: + """Check if a sound is in the current playlist.""" + try: + from app.repositories.playlist import PlaylistRepository # noqa: PLC0415 + + playlist_repo = PlaylistRepository(self.session) + current_playlist = await playlist_repo.get_current_playlist() + + if not current_playlist or not current_playlist.id: + return False + + return await playlist_repo.is_sound_in_playlist( + current_playlist.id, sound_id, + ) + except (ImportError, AttributeError, ValueError, RuntimeError) as e: + logger.warning( + "Failed to check if sound %s is in current playlist: %s", + sound_id, + e, + exc_info=True, + ) + return False + + async def _reload_player_playlist(self) -> None: + """Reload the player playlist after a sound is deleted.""" + try: + # Import here to avoid circular import issues + from app.services.player import get_player_service # noqa: PLC0415 + + player = get_player_service() + await player.reload_playlist() + logger.debug("Player playlist reloaded after sound deletion") + except (ImportError, AttributeError, ValueError, RuntimeError) as e: + # Don't fail the deletion operation if player reload fails + logger.warning("Failed to reload player playlist: %s", e, exc_info=True) diff --git a/app/services/extraction_processor.py b/app/services/extraction_processor.py index 31dcccf..59f268d 100644 --- a/app/services/extraction_processor.py +++ b/app/services/extraction_processor.py @@ -201,7 +201,7 @@ class ExtractionProcessor: for extraction in stuck_extractions: try: await extraction_service.extraction_repo.update( - extraction, {"status": "pending", "error": None} + extraction, {"status": "pending", "error": None}, ) reset_count += 1 logger.info( @@ -210,12 +210,13 @@ class ExtractionProcessor: ) except Exception: logger.exception( - "Failed to reset extraction %d", extraction.id + "Failed to reset extraction %d", extraction.id, ) await session.commit() logger.info( - "Successfully reset %d stuck extractions from processing to pending", + "Successfully reset %d stuck extractions from processing to " + "pending", reset_count, ) diff --git a/app/services/sound_scanner.py b/app/services/sound_scanner.py index 0c95c72..05a5158 100644 --- a/app/services/sound_scanner.py +++ b/app/services/sound_scanner.py @@ -1,5 +1,6 @@ """Sound scanner service for scanning and importing audio files.""" +from dataclasses import dataclass from pathlib import Path from typing import TypedDict @@ -13,6 +14,28 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size logger = get_logger(__name__) +@dataclass +class AudioFileInfo: + """Data class for audio file metadata.""" + + filename: str + name: str + duration: int + size: int + file_hash: str + + +@dataclass +class SyncContext: + """Context data for audio file synchronization.""" + + file_path: Path + sound_type: str + existing_sound_by_hash: dict | Sound | None + existing_sound_by_filename: dict | Sound | None + file_hash: str + + class FileInfo(TypedDict): """Type definition for file information in scan results.""" @@ -56,7 +79,7 @@ class SoundScannerService: ".m4a", ".aac", } - + # Directory mappings for normalized files (matching sound_normalizer) self.normalized_directories = { "SDB": "sounds/normalized/soundboard", @@ -72,43 +95,416 @@ class SoundScannerService: name = name.replace("_", " ").replace("-", " ") # Capitalize words return " ".join(word.capitalize() for word in name.split()) - + def _get_normalized_path(self, sound_type: str, filename: str) -> Path: """Get the normalized file path for a sound.""" - directory = self.normalized_directories.get(sound_type, "sounds/normalized/other") + directory = self.normalized_directories.get( + sound_type, "sounds/normalized/other", + ) return Path(directory) / filename - - def _rename_normalized_file(self, sound_type: str, old_filename: str, new_filename: str) -> bool: - """Rename a normalized file if it exists. Returns True if renamed, False if not found.""" + + def _rename_normalized_file( + self, sound_type: str, old_filename: str, new_filename: str, + ) -> bool: + """Rename normalized file if exists. Returns True if renamed, else False.""" old_path = self._get_normalized_path(sound_type, old_filename) new_path = self._get_normalized_path(sound_type, new_filename) - + if old_path.exists(): try: # Ensure the directory exists new_path.parent.mkdir(parents=True, exist_ok=True) old_path.rename(new_path) logger.info("Renamed normalized file: %s -> %s", old_path, new_path) - return True - except Exception as e: - logger.error("Failed to rename normalized file %s -> %s: %s", old_path, new_path, e) + except OSError: + logger.exception( + "Failed to rename normalized file %s -> %s", + old_path, + new_path, + ) return False + else: + return True return False - + def _delete_normalized_file(self, sound_type: str, filename: str) -> bool: - """Delete a normalized file if it exists. Returns True if deleted, False if not found.""" + """Delete normalized file if exists. Returns True if deleted, else False.""" normalized_path = self._get_normalized_path(sound_type, filename) - + if normalized_path.exists(): try: normalized_path.unlink() logger.info("Deleted normalized file: %s", normalized_path) - return True - except Exception as e: - logger.error("Failed to delete normalized file %s: %s", normalized_path, e) + except OSError: + logger.exception( + "Failed to delete normalized file %s", normalized_path, + ) return False + else: + return True return False + def _extract_sound_attributes(self, sound_data: dict | Sound | None) -> dict: + """Extract attributes from sound data (dict or Sound object).""" + if sound_data is None: + return {} + + if isinstance(sound_data, dict): + return { + "filename": sound_data.get("filename"), + "name": sound_data.get("name"), + "duration": sound_data.get("duration"), + "size": sound_data.get("size"), + "id": sound_data.get("id"), + "object": sound_data.get("sound_object"), + "type": sound_data.get("type"), + "is_normalized": sound_data.get("is_normalized"), + "normalized_filename": sound_data.get("normalized_filename"), + } + # Sound object (for tests) + return { + "filename": sound_data.filename, + "name": sound_data.name, + "duration": sound_data.duration, + "size": sound_data.size, + "id": sound_data.id, + "object": sound_data, + "type": sound_data.type, + "is_normalized": sound_data.is_normalized, + "normalized_filename": sound_data.normalized_filename, + } + + def _handle_unchanged_file( + self, + filename: str, + existing_attrs: dict, + results: ScanResults, + ) -> None: + """Handle unchanged file (same hash, same filename).""" + logger.debug("Sound unchanged: %s", filename) + results["skipped"] += 1 + results["files"].append({ + "filename": filename, + "status": "skipped", + "reason": "file unchanged", + "name": existing_attrs["name"], + "duration": existing_attrs["duration"], + "size": existing_attrs["size"], + "id": existing_attrs["id"], + "error": None, + "changes": None, + }) + + def _handle_duplicate_file( + self, + filename: str, + existing_filename: str, + file_hash: str, + existing_attrs: dict, + results: ScanResults, + ) -> None: + """Handle duplicate file (same hash, different filename).""" + logger.warning( + "Duplicate file detected: '%s' has same content as existing " + "'%s' (hash: %s). Skipping duplicate file.", + filename, + existing_filename, + file_hash[:8] + "...", + ) + results["skipped"] += 1 + results["duplicates"] += 1 + results["files"].append({ + "filename": filename, + "status": "skipped", + "reason": "duplicate content", + "name": existing_attrs["name"], + "duration": existing_attrs["duration"], + "size": existing_attrs["size"], + "id": existing_attrs["id"], + "error": None, + "changes": None, + }) + + async def _handle_file_rename( + self, + file_info: AudioFileInfo, + existing_attrs: dict, + results: ScanResults, + ) -> None: + """Handle file rename (same hash, different filename).""" + update_data = { + "filename": file_info.filename, + "name": file_info.name, + } + + # If the sound has a normalized file, rename it too + if existing_attrs["is_normalized"] and existing_attrs["normalized_filename"]: + old_normalized_base = Path(existing_attrs["normalized_filename"]).name + new_normalized_base = ( + Path(file_info.filename).stem + + Path(existing_attrs["normalized_filename"]).suffix + ) + + renamed = self._rename_normalized_file( + existing_attrs["type"], + old_normalized_base, + new_normalized_base, + ) + + if renamed: + update_data["normalized_filename"] = new_normalized_base + logger.info( + "Renamed normalized file: %s -> %s", + old_normalized_base, + new_normalized_base, + ) + + await self.sound_repo.update(existing_attrs["object"], update_data) + logger.info( + "Detected rename: %s -> %s (ID: %s)", + existing_attrs["filename"], + file_info.filename, + existing_attrs["id"], + ) + + # Build changes list + changes = ["filename", "name"] + if "normalized_filename" in update_data: + changes.append("normalized_filename") + + results["updated"] += 1 + results["files"].append({ + "filename": file_info.filename, + "status": "updated", + "reason": "file was renamed", + "name": file_info.name, + "duration": existing_attrs["duration"], + "size": existing_attrs["size"], + "id": existing_attrs["id"], + "error": None, + "changes": changes, + # Store old filename to prevent deletion + "old_filename": existing_attrs["filename"], + }) + + async def _handle_file_modification( + self, + file_info: AudioFileInfo, + existing_attrs: dict, + results: ScanResults, + ) -> None: + """Handle file modification (same filename, different hash).""" + update_data = { + "name": file_info.name, + "duration": file_info.duration, + "size": file_info.size, + "hash": file_info.file_hash, + } + + await self.sound_repo.update(existing_attrs["object"], update_data) + logger.info( + "Updated modified sound: %s (ID: %s)", + file_info.name, + existing_attrs["id"], + ) + + results["updated"] += 1 + results["files"].append({ + "filename": file_info.filename, + "status": "updated", + "reason": "file was modified", + "name": file_info.name, + "duration": file_info.duration, + "size": file_info.size, + "id": existing_attrs["id"], + "error": None, + "changes": ["hash", "duration", "size", "name"], + }) + + async def _handle_new_file( + self, + file_info: AudioFileInfo, + sound_type: str, + results: ScanResults, + ) -> None: + """Handle new file (neither hash nor filename exists).""" + sound_data = { + "type": sound_type, + "name": file_info.name, + "filename": file_info.filename, + "duration": file_info.duration, + "size": file_info.size, + "hash": file_info.file_hash, + "is_deletable": False, + "is_music": False, + "is_normalized": False, + "play_count": 0, + } + + sound = await self.sound_repo.create(sound_data) + logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id) + + results["added"] += 1 + results["files"].append({ + "filename": file_info.filename, + "status": "added", + "reason": None, + "name": file_info.name, + "duration": file_info.duration, + "size": file_info.size, + "id": sound.id, + "error": None, + "changes": None, + }) + + async def _load_existing_sounds(self, sound_type: str) -> tuple[dict, dict]: + """Load existing sounds and create lookup dictionaries.""" + existing_sounds = await self.sound_repo.get_by_type(sound_type) + + # Create lookup dictionaries with immediate attribute access + # to avoid session detachment + sounds_by_hash = {} + sounds_by_filename = {} + + for sound in existing_sounds: + # Capture all attributes immediately while session is valid + sound_data = { + "id": sound.id, + "hash": sound.hash, + "filename": sound.filename, + "name": sound.name, + "duration": sound.duration, + "size": sound.size, + "type": sound.type, + "is_normalized": sound.is_normalized, + "normalized_filename": sound.normalized_filename, + "sound_object": sound, # Keep reference for database operations + } + sounds_by_hash[sound.hash] = sound_data + sounds_by_filename[sound.filename] = sound_data + + return sounds_by_hash, sounds_by_filename + + async def _process_audio_files( + self, + scan_path: Path, + sound_type: str, + sounds_by_hash: dict, + sounds_by_filename: dict, + results: ScanResults, + ) -> set[str]: + """Process all audio files in directory and return processed filenames.""" + # Get all audio files from directory + audio_files = [ + f + for f in scan_path.iterdir() + if f.is_file() and f.suffix.lower() in self.supported_extensions + ] + + # Process each file in directory + processed_filenames = set() + for file_path in audio_files: + results["scanned"] += 1 + filename = file_path.name + processed_filenames.add(filename) + + try: + # Calculate hash first to enable hash-based lookup + file_hash = get_file_hash(file_path) + existing_sound_by_hash = sounds_by_hash.get(file_hash) + existing_sound_by_filename = sounds_by_filename.get(filename) + + # Create sync context + sync_context = SyncContext( + file_path=file_path, + sound_type=sound_type, + existing_sound_by_hash=existing_sound_by_hash, + existing_sound_by_filename=existing_sound_by_filename, + file_hash=file_hash, + ) + + await self._sync_audio_file(sync_context, results) + + # Check if this was a rename and mark old filename as processed + if results["files"] and results["files"][-1].get("old_filename"): + old_filename = results["files"][-1]["old_filename"] + processed_filenames.add(old_filename) + logger.debug("Marked old filename as processed: %s", old_filename) + # Remove temporary tracking field from results + del results["files"][-1]["old_filename"] + except Exception as e: + logger.exception("Error processing file %s", file_path) + results["errors"] += 1 + results["files"].append({ + "filename": filename, + "status": "error", + "reason": None, + "name": None, + "duration": None, + "size": None, + "id": None, + "error": str(e), + "changes": None, + }) + + return processed_filenames + + async def _delete_missing_sounds( + self, + sounds_by_filename: dict, + processed_filenames: set[str], + results: ScanResults, + ) -> None: + """Delete sounds that no longer exist in directory.""" + for filename, sound_data in sounds_by_filename.items(): + if filename not in processed_filenames: + # Attributes already captured in sound_data dictionary + sound_name = sound_data["name"] + sound_duration = sound_data["duration"] + sound_size = sound_data["size"] + sound_id = sound_data["id"] + sound_object = sound_data["sound_object"] + sound_type = sound_data["type"] + sound_is_normalized = sound_data["is_normalized"] + sound_normalized_filename = sound_data["normalized_filename"] + + try: + # Delete the sound from database first + await self.sound_repo.delete(sound_object) + logger.info("Deleted sound no longer in directory: %s", filename) + + # If the sound had a normalized file, delete it too + if sound_is_normalized and sound_normalized_filename: + normalized_base = Path(sound_normalized_filename).name + self._delete_normalized_file(sound_type, normalized_base) + + results["deleted"] += 1 + results["files"].append({ + "filename": filename, + "status": "deleted", + "reason": "file no longer exists", + "name": sound_name, + "duration": sound_duration, + "size": sound_size, + "id": sound_id, + "error": None, + "changes": None, + }) + except Exception as e: + logger.exception("Error deleting sound %s", filename) + results["errors"] += 1 + results["files"].append({ + "filename": filename, + "status": "error", + "reason": "failed to delete", + "name": sound_name, + "duration": sound_duration, + "size": sound_size, + "id": sound_id, + "error": str(e), + "changes": None, + }) + async def scan_directory( self, directory_path: str, @@ -138,368 +534,84 @@ class SoundScannerService: logger.info("Starting sync of directory: %s", directory_path) - # Get all existing sounds of this type from database - existing_sounds = await self.sound_repo.get_by_type(sound_type) + # Load existing sounds from database + sounds_by_hash, sounds_by_filename = await self._load_existing_sounds( + sound_type, + ) - # Create lookup dictionaries with immediate attribute access - # to avoid session detachment - sounds_by_hash = {} - sounds_by_filename = {} - - for sound in existing_sounds: - # Capture all attributes immediately while session is valid - sound_data = { - "id": sound.id, - "hash": sound.hash, - "filename": sound.filename, - "name": sound.name, - "duration": sound.duration, - "size": sound.size, - "type": sound.type, - "is_normalized": sound.is_normalized, - "normalized_filename": sound.normalized_filename, - "sound_object": sound, # Keep reference for database operations - } - sounds_by_hash[sound.hash] = sound_data - sounds_by_filename[sound.filename] = sound_data - - # Get all audio files from directory - audio_files = [ - f - for f in scan_path.iterdir() - if f.is_file() and f.suffix.lower() in self.supported_extensions - ] - - # Process each file in directory - processed_filenames = set() - for file_path in audio_files: - results["scanned"] += 1 - filename = file_path.name - processed_filenames.add(filename) - - try: - # Calculate hash first to enable hash-based lookup - file_hash = get_file_hash(file_path) - existing_sound_by_hash = sounds_by_hash.get(file_hash) - existing_sound_by_filename = sounds_by_filename.get(filename) - - await self._sync_audio_file( - file_path, - sound_type, - existing_sound_by_hash, - existing_sound_by_filename, - file_hash, - results, - ) - - # Check if this was a rename operation and mark old filename as processed - if results["files"] and results["files"][-1].get("old_filename"): - old_filename = results["files"][-1]["old_filename"] - processed_filenames.add(old_filename) - logger.debug("Marked old filename as processed: %s", old_filename) - # Remove temporary tracking field from results - del results["files"][-1]["old_filename"] - except Exception as e: - logger.exception("Error processing file %s", file_path) - results["errors"] += 1 - results["files"].append( - { - "filename": filename, - "status": "error", - "reason": None, - "name": None, - "duration": None, - "size": None, - "id": None, - "error": str(e), - "changes": None, - }, - ) + # Process audio files in directory + processed_filenames = await self._process_audio_files( + scan_path, + sound_type, + sounds_by_hash, + sounds_by_filename, + results, + ) # Delete sounds that no longer exist in directory - for filename, sound_data in sounds_by_filename.items(): - if filename not in processed_filenames: - # Attributes already captured in sound_data dictionary - sound_name = sound_data["name"] - sound_duration = sound_data["duration"] - sound_size = sound_data["size"] - sound_id = sound_data["id"] - sound_object = sound_data["sound_object"] - sound_type = sound_data["type"] - sound_is_normalized = sound_data["is_normalized"] - sound_normalized_filename = sound_data["normalized_filename"] - - try: - # Delete the sound from database first - await self.sound_repo.delete(sound_object) - logger.info("Deleted sound no longer in directory: %s", filename) - - # If the sound had a normalized file, delete it too - if sound_is_normalized and sound_normalized_filename: - normalized_base = Path(sound_normalized_filename).name - self._delete_normalized_file(sound_type, normalized_base) - - results["deleted"] += 1 - results["files"].append( - { - "filename": filename, - "status": "deleted", - "reason": "file no longer exists", - "name": sound_name, - "duration": sound_duration, - "size": sound_size, - "id": sound_id, - "error": None, - "changes": None, - }, - ) - except Exception as e: - logger.exception("Error deleting sound %s", filename) - results["errors"] += 1 - results["files"].append( - { - "filename": filename, - "status": "error", - "reason": "failed to delete", - "name": sound_name, - "duration": sound_duration, - "size": sound_size, - "id": sound_id, - "error": str(e), - "changes": None, - }, - ) + await self._delete_missing_sounds( + sounds_by_filename, + processed_filenames, + results, + ) logger.info("Sync completed: %s", results) return results async def _sync_audio_file( self, - file_path: Path, - sound_type: str, - existing_sound_by_hash: dict | Sound | None, - existing_sound_by_filename: dict | Sound | None, - file_hash: str, + sync_context: SyncContext, results: ScanResults, ) -> None: """Sync a single audio file using hash-first identification strategy.""" - filename = file_path.name - duration = get_audio_duration(file_path) - size = get_file_size(file_path) + filename = sync_context.file_path.name + duration = get_audio_duration(sync_context.file_path) + size = get_file_size(sync_context.file_path) name = self.extract_name_from_filename(filename) - # Extract attributes - handle both dict (normal) and Sound object (tests) - existing_hash_filename = None - existing_hash_name = None - existing_hash_duration = None - existing_hash_size = None - existing_hash_id = None - existing_hash_object = None - existing_hash_type = None - existing_hash_is_normalized = None - existing_hash_normalized_filename = None + # Create file info object + file_info = AudioFileInfo( + filename=filename, + name=name, + duration=duration, + size=size, + file_hash=sync_context.file_hash, + ) - if existing_sound_by_hash is not None: - if isinstance(existing_sound_by_hash, dict): - existing_hash_filename = existing_sound_by_hash["filename"] - existing_hash_name = existing_sound_by_hash["name"] - existing_hash_duration = existing_sound_by_hash["duration"] - existing_hash_size = existing_sound_by_hash["size"] - existing_hash_id = existing_sound_by_hash["id"] - existing_hash_object = existing_sound_by_hash["sound_object"] - existing_hash_type = existing_sound_by_hash["type"] - existing_hash_is_normalized = existing_sound_by_hash["is_normalized"] - existing_hash_normalized_filename = existing_sound_by_hash["normalized_filename"] - else: # Sound object (for tests) - existing_hash_filename = existing_sound_by_hash.filename - existing_hash_name = existing_sound_by_hash.name - existing_hash_duration = existing_sound_by_hash.duration - existing_hash_size = existing_sound_by_hash.size - existing_hash_id = existing_sound_by_hash.id - existing_hash_object = existing_sound_by_hash - existing_hash_type = existing_sound_by_hash.type - existing_hash_is_normalized = existing_sound_by_hash.is_normalized - existing_hash_normalized_filename = existing_sound_by_hash.normalized_filename - - existing_filename_id = None - existing_filename_object = None - if existing_sound_by_filename is not None: - if isinstance(existing_sound_by_filename, dict): - existing_filename_id = existing_sound_by_filename["id"] - existing_filename_object = existing_sound_by_filename["sound_object"] - else: # Sound object (for tests) - existing_filename_id = existing_sound_by_filename.id - existing_filename_object = existing_sound_by_filename + # Extract attributes from existing sounds + hash_attrs = self._extract_sound_attributes(sync_context.existing_sound_by_hash) + filename_attrs = self._extract_sound_attributes( + sync_context.existing_sound_by_filename, + ) # Hash-first identification strategy - if existing_sound_by_hash is not None: + if sync_context.existing_sound_by_hash is not None: # Content exists in database (same hash) - if existing_hash_filename == filename: + if hash_attrs["filename"] == filename: # Same hash, same filename - file unchanged - logger.debug("Sound unchanged: %s", filename) - results["skipped"] += 1 - results["files"].append( - { - "filename": filename, - "status": "skipped", - "reason": "file unchanged", - "name": existing_hash_name, - "duration": existing_hash_duration, - "size": existing_hash_size, - "id": existing_hash_id, - "error": None, - "changes": None, - }, - ) + self._handle_unchanged_file(filename, hash_attrs, results) else: # Same hash, different filename - could be rename or duplicate - # Check if both files exist to determine if it's a duplicate - old_file_path = file_path.parent / existing_hash_filename + old_file_path = sync_context.file_path.parent / hash_attrs["filename"] if old_file_path.exists(): # Both files exist with same hash - this is a duplicate - logger.warning( - "Duplicate file detected: '%s' has same content as existing '%s' (hash: %s). " - "Skipping duplicate file.", + self._handle_duplicate_file( filename, - existing_hash_filename, - file_hash[:8] + "...", - ) - - results["skipped"] += 1 - results["duplicates"] += 1 - results["files"].append( - { - "filename": filename, - "status": "skipped", - "reason": "duplicate content", - "name": existing_hash_name, - "duration": existing_hash_duration, - "size": existing_hash_size, - "id": existing_hash_id, - "error": None, - "changes": None, - }, + hash_attrs["filename"], + sync_context.file_hash, + hash_attrs, + results, ) else: # Old file doesn't exist - this is a genuine rename - update_data = { - "filename": filename, - "name": name, - } - - # If the sound has a normalized file, rename it too - if existing_hash_is_normalized and existing_hash_normalized_filename: - # Extract base filename without path for normalized file - old_normalized_base = Path(existing_hash_normalized_filename).name - new_normalized_base = Path(filename).stem + Path(existing_hash_normalized_filename).suffix - - renamed = self._rename_normalized_file( - existing_hash_type, - old_normalized_base, - new_normalized_base - ) - - if renamed: - update_data["normalized_filename"] = new_normalized_base - logger.info( - "Renamed normalized file: %s -> %s", - old_normalized_base, - new_normalized_base - ) + await self._handle_file_rename(file_info, hash_attrs, results) - await self.sound_repo.update(existing_hash_object, update_data) - logger.info( - "Detected rename: %s -> %s (ID: %s)", - existing_hash_filename, - filename, - existing_hash_id, - ) - - # Build changes list - changes = ["filename", "name"] - if "normalized_filename" in update_data: - changes.append("normalized_filename") - - results["updated"] += 1 - results["files"].append( - { - "filename": filename, - "status": "updated", - "reason": "file was renamed", - "name": name, - "duration": existing_hash_duration, - "size": existing_hash_size, - "id": existing_hash_id, - "error": None, - "changes": changes, - # Store old filename to prevent deletion - "old_filename": existing_hash_filename, - }, - ) - - elif existing_sound_by_filename is not None: + elif sync_context.existing_sound_by_filename is not None: # Same filename but different hash - file was modified - update_data = { - "name": name, - "duration": duration, - "size": size, - "hash": file_hash, - } - - await self.sound_repo.update(existing_filename_object, update_data) - logger.info( - "Updated modified sound: %s (ID: %s)", - name, - existing_filename_id, - ) - - results["updated"] += 1 - results["files"].append( - { - "filename": filename, - "status": "updated", - "reason": "file was modified", - "name": name, - "duration": duration, - "size": size, - "id": existing_filename_id, - "error": None, - "changes": ["hash", "duration", "size", "name"], - }, - ) - + await self._handle_file_modification(file_info, filename_attrs, results) else: # New file - neither hash nor filename exists - sound_data = { - "type": sound_type, - "name": name, - "filename": filename, - "duration": duration, - "size": size, - "hash": file_hash, - "is_deletable": False, - "is_music": False, - "is_normalized": False, - "play_count": 0, - } - - sound = await self.sound_repo.create(sound_data) - logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id) - - results["added"] += 1 - results["files"].append( - { - "filename": filename, - "status": "added", - "reason": None, - "name": name, - "duration": duration, - "size": size, - "id": sound.id, - "error": None, - "changes": None, - }, - ) + await self._handle_new_file(file_info, sync_context.sound_type, results) async def scan_soundboard_directory(self) -> ScanResults: """Sync the default soundboard directory.""" diff --git a/tests/api/v1/admin/test_extraction_endpoints.py b/tests/api/v1/admin/test_extraction_endpoints.py new file mode 100644 index 0000000..ccbe9fc --- /dev/null +++ b/tests/api/v1/admin/test_extraction_endpoints.py @@ -0,0 +1,154 @@ +"""Tests for admin extraction API endpoints.""" + +import pytest +from httpx import AsyncClient + +from app.models.extraction import Extraction +from app.models.user import User + + +class TestAdminExtractionEndpoints: + """Test admin extraction endpoints.""" + + @pytest.mark.asyncio + async def test_get_extraction_processor_status(self, authenticated_admin_client): + """Test getting extraction processor status.""" + response = await authenticated_admin_client.get( + "/api/v1/admin/extractions/status", + ) + + assert response.status_code == 200 + data = response.json() + + # Check expected status fields (match actual processor status format) + assert "currently_processing" in data + assert "max_concurrent" in data + assert "available_slots" in data + assert "processing_ids" in data + assert isinstance(data["currently_processing"], int) + assert isinstance(data["max_concurrent"], int) + assert isinstance(data["available_slots"], int) + assert isinstance(data["processing_ids"], list) + + @pytest.mark.asyncio + async def test_admin_delete_extraction_success( + self, + authenticated_admin_client, + test_session, + test_plan, + ): + """Test admin successfully deleting any extraction.""" + # Create a test user + user = User( + name="Test User", + email="test@example.com", + is_active=True, + plan_id=test_plan.id, + ) + test_session.add(user) + await test_session.commit() + await test_session.refresh(user) + + # Create test extraction + extraction = Extraction( + url="https://example.com/video", + user_id=user.id, + status="completed", + ) + test_session.add(extraction) + await test_session.commit() + await test_session.refresh(extraction) + + # Admin delete the extraction + response = await authenticated_admin_client.delete( + f"/api/v1/admin/extractions/{extraction.id}", + ) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == f"Extraction {extraction.id} deleted successfully" + + # Verify extraction was deleted from database + deleted_extraction = await test_session.get(Extraction, extraction.id) + assert deleted_extraction is None + + @pytest.mark.asyncio + async def test_admin_delete_extraction_not_found(self, authenticated_admin_client): + """Test admin deleting non-existent extraction.""" + response = await authenticated_admin_client.delete( + "/api/v1/admin/extractions/999", + ) + + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + @pytest.mark.asyncio + async def test_admin_delete_extraction_any_user( + self, + authenticated_admin_client, + test_session, + test_plan, + ): + """Test admin deleting extraction owned by any user.""" + # Create another user and their extraction + other_user = User( + name="Other User", + email="other@example.com", + is_active=True, + plan_id=test_plan.id, + ) + test_session.add(other_user) + await test_session.commit() + await test_session.refresh(other_user) + + extraction = Extraction( + url="https://example.com/video", + user_id=other_user.id, + status="completed", + ) + test_session.add(extraction) + await test_session.commit() + await test_session.refresh(extraction) + + # Admin can delete any user's extraction + response = await authenticated_admin_client.delete( + f"/api/v1/admin/extractions/{extraction.id}", + ) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == f"Extraction {extraction.id} deleted successfully" + + @pytest.mark.asyncio + async def test_delete_extraction_non_admin(self, authenticated_client, test_user, test_session): + """Test non-admin user cannot access admin deletion endpoint.""" + # Create test extraction + extraction = Extraction( + url="https://example.com/video", + user_id=test_user.id, + status="completed", + ) + test_session.add(extraction) + await test_session.commit() + await test_session.refresh(extraction) + + # Non-admin user cannot access admin endpoint + response = await authenticated_client.delete( + f"/api/v1/admin/extractions/{extraction.id}", + ) + + assert response.status_code == 403 + data = response.json() + assert "permissions" in data["detail"].lower() + + @pytest.mark.asyncio + async def test_admin_endpoints_unauthenticated(self, client: AsyncClient): + """Test admin endpoints require authentication.""" + # Status endpoint + response = await client.get("/api/v1/admin/extractions/status") + assert response.status_code == 401 + + # Delete endpoint + response = await client.delete("/api/v1/admin/extractions/1") + assert response.status_code == 401 diff --git a/tests/api/v1/admin/test_sound_endpoints.py b/tests/api/v1/admin/test_sound_endpoints.py index ea5bbcf..d62119f 100644 --- a/tests/api/v1/admin/test_sound_endpoints.py +++ b/tests/api/v1/admin/test_sound_endpoints.py @@ -31,6 +31,7 @@ class TestAdminSoundEndpoints: "deleted": 1, "skipped": 0, "errors": 0, + "duplicates": 0, "files": [ { "filename": "test1.mp3", @@ -176,6 +177,7 @@ class TestAdminSoundEndpoints: "deleted": 0, "skipped": 0, "errors": 0, + "duplicates": 0, "files": [ { "filename": "custom1.wav", diff --git a/tests/api/v1/test_extraction_endpoints.py b/tests/api/v1/test_extraction_endpoints.py index 71db18b..7087cc5 100644 --- a/tests/api/v1/test_extraction_endpoints.py +++ b/tests/api/v1/test_extraction_endpoints.py @@ -229,3 +229,73 @@ class TestExtractionEndpoints: break assert processing_found, "Processing extraction not found in results" + + @pytest.mark.asyncio + async def test_delete_extraction_success(self, authenticated_client, test_user, test_session): + """Test successful deletion of user's own extraction.""" + # Create test extraction + extraction = Extraction( + url="https://example.com/video", + user_id=test_user.id, + status="completed", + ) + test_session.add(extraction) + await test_session.commit() + await test_session.refresh(extraction) + + # Delete the extraction + response = await authenticated_client.delete(f"/api/v1/extractions/{extraction.id}") + + assert response.status_code == 200 + data = response.json() + assert data["message"] == f"Extraction {extraction.id} deleted successfully" + + # Verify extraction was deleted from database + deleted_extraction = await test_session.get(Extraction, extraction.id) + assert deleted_extraction is None + + @pytest.mark.asyncio + async def test_delete_extraction_not_found(self, authenticated_client): + """Test deleting non-existent extraction.""" + response = await authenticated_client.delete("/api/v1/extractions/999") + + assert response.status_code == 404 + data = response.json() + assert "not found" in data["detail"].lower() + + @pytest.mark.asyncio + async def test_delete_extraction_permission_denied(self, authenticated_client, test_session, test_plan): + """Test deleting another user's extraction.""" + # Create extraction owned by different user + other_user = User( + name="Other User", + email="other@example.com", + is_active=True, + plan_id=test_plan.id, + ) + test_session.add(other_user) + await test_session.commit() + await test_session.refresh(other_user) + + extraction = Extraction( + url="https://example.com/video", + user_id=other_user.id, + status="completed", + ) + test_session.add(extraction) + await test_session.commit() + await test_session.refresh(extraction) + + # Try to delete other user's extraction + response = await authenticated_client.delete(f"/api/v1/extractions/{extraction.id}") + + assert response.status_code == 403 + data = response.json() + assert "permission" in data["detail"].lower() + + @pytest.mark.asyncio + async def test_delete_extraction_unauthenticated(self, client): + """Test deleting extraction without authentication.""" + response = await client.delete("/api/v1/extractions/1") + + assert response.status_code == 401 diff --git a/tests/api/v1/test_favorite_endpoints.py b/tests/api/v1/test_favorite_endpoints.py index 2f0fb7a..b3001e0 100644 --- a/tests/api/v1/test_favorite_endpoints.py +++ b/tests/api/v1/test_favorite_endpoints.py @@ -1,5 +1,7 @@ """Tests for favorite API endpoints.""" +from contextlib import suppress + import pytest import pytest_asyncio from httpx import AsyncClient @@ -129,10 +131,8 @@ class TestFavoriteEndpoints: ) -> None: """Test successfully adding a sound to favorites.""" # Clean up any existing favorite first - try: + with suppress(Exception): await authenticated_client.delete("/api/v1/favorites/sounds/1") - except: - pass # It's ok if it doesn't exist response = await authenticated_client.post("/api/v1/favorites/sounds/1") @@ -176,10 +176,8 @@ class TestFavoriteEndpoints: ) -> None: """Test successfully adding a playlist to favorites.""" # Clean up any existing favorite first - try: + with suppress(Exception): await authenticated_client.delete("/api/v1/favorites/playlists/1") - except: - pass # It's ok if it doesn't exist response = await authenticated_client.post("/api/v1/favorites/playlists/1") @@ -473,10 +471,8 @@ class TestFavoriteEndpoints: ) -> None: """Test checking if a sound is favorited (false case).""" # Make sure sound 1 is not favorited - try: + with suppress(Exception): await authenticated_client.delete("/api/v1/favorites/sounds/1") - except: - pass # It's ok if it doesn't exist response = await authenticated_client.get("/api/v1/favorites/sounds/1/check") @@ -509,10 +505,8 @@ class TestFavoriteEndpoints: ) -> None: """Test checking if a playlist is favorited (false case).""" # Make sure playlist 1 is not favorited - try: + with suppress(Exception): await authenticated_client.delete("/api/v1/favorites/playlists/1") - except: - pass # It's ok if it doesn't exist response = await authenticated_client.get("/api/v1/favorites/playlists/1/check") diff --git a/tests/repositories/test_extraction.py b/tests/repositories/test_extraction.py index 0f74b52..046610d 100644 --- a/tests/repositories/test_extraction.py +++ b/tests/repositories/test_extraction.py @@ -144,7 +144,7 @@ class TestExtractionRepository: ), Extraction( id=2, - service="youtube", + service="youtube", service_id="test456", url="https://www.youtube.com/watch?v=test2", user_id=1, diff --git a/tests/services/test_extraction.py b/tests/services/test_extraction.py index 49637e9..e6a2908 100644 --- a/tests/services/test_extraction.py +++ b/tests/services/test_extraction.py @@ -541,3 +541,143 @@ class TestExtractionService: assert result[0]["id"] == 1 assert result[0]["status"] == "pending" assert result[0]["user_name"] == "Test User" + + @pytest.mark.asyncio + async def test_delete_extraction_with_sound(self, extraction_service, test_user): + """Test deleting extraction with associated sound and files.""" + import tempfile + from pathlib import Path + + # Create temporary directories for testing + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir_path = Path(temp_dir) + + # Set up temporary directories + original_dir = temp_dir_path / "originals" / "extracted" + normalized_dir = temp_dir_path / "normalized" / "extracted" + thumbnail_dir = temp_dir_path / "thumbnails" + + original_dir.mkdir(parents=True) + normalized_dir.mkdir(parents=True) + thumbnail_dir.mkdir(parents=True) + + # Create test files + audio_file = original_dir / "test_audio.mp3" + normalized_file = normalized_dir / "test_audio.mp3" + thumbnail_file = thumbnail_dir / "test_thumb.jpg" + + audio_file.write_text("audio content") + normalized_file.write_text("normalized content") + thumbnail_file.write_text("thumbnail content") + + # Create extraction and sound records + extraction = Extraction( + id=1, + url="https://example.com/video", + user_id=test_user.id, + status="completed", + sound_id=1, + ) + + sound = Sound( + id=1, + type="EXT", + name="Test Audio", + filename="test_audio.mp3", + duration=60000, + size=2048, + hash="test_hash", + is_normalized=True, + normalized_filename="test_audio.mp3", + thumbnail="test_thumb.jpg", + is_deletable=True, + is_music=True, + ) + + # Mock repository methods + extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction) + extraction_service.sound_repo.get_by_id = AsyncMock(return_value=sound) + extraction_service.extraction_repo.delete = AsyncMock() + extraction_service.sound_repo.delete = AsyncMock() + extraction_service.session.commit = AsyncMock() + extraction_service.session.rollback = AsyncMock() + + # Monkey patch the paths in the service method + import app.services.extraction + original_path_class = app.services.extraction.Path + + def mock_path(*args: str): + path_str = str(args[0]) + if path_str == "sounds/originals/extracted": + return original_dir + if path_str == "sounds/normalized/extracted": + return normalized_dir + if path_str.endswith("thumbnails"): + return thumbnail_dir + return original_path_class(*args) + + # Mock the Path constructor and settings + with patch("app.services.extraction.Path", side_effect=mock_path), \ + patch("app.services.extraction.settings") as mock_settings: + mock_settings.EXTRACTION_THUMBNAILS_DIR = str(thumbnail_dir) + + # Test deletion + result = await extraction_service.delete_extraction(1, test_user.id) + + assert result is True + + # Verify repository calls + extraction_service.extraction_repo.get_by_id.assert_called_once_with(1) + extraction_service.sound_repo.get_by_id.assert_called_once_with(1) + extraction_service.extraction_repo.delete.assert_called_once_with(extraction) + extraction_service.sound_repo.delete.assert_called_once_with(sound) + extraction_service.session.commit.assert_called_once() + + # Verify files were deleted + assert not audio_file.exists() + assert not normalized_file.exists() + assert not thumbnail_file.exists() + + @pytest.mark.asyncio + async def test_delete_extraction_not_found(self, extraction_service, test_user): + """Test deleting non-existent extraction.""" + extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=None) + + result = await extraction_service.delete_extraction(999, test_user.id) + + assert result is False + + @pytest.mark.asyncio + async def test_delete_extraction_permission_denied(self, extraction_service, test_user): + """Test deleting extraction owned by another user.""" + extraction = Extraction( + id=1, + url="https://example.com/video", + user_id=999, # Different user ID + status="completed", + ) + + extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction) + + with pytest.raises(ValueError, match="You don't have permission"): + await extraction_service.delete_extraction(1, test_user.id) + + @pytest.mark.asyncio + async def test_delete_extraction_admin(self, extraction_service, test_user): + """Test admin deleting any extraction.""" + extraction = Extraction( + id=1, + url="https://example.com/video", + user_id=999, # Different user ID + status="completed", + ) + + extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction) + extraction_service.extraction_repo.delete = AsyncMock() + extraction_service.session.commit = AsyncMock() + + # Admin deletion (user_id=None) + result = await extraction_service.delete_extraction(1, None) + + assert result is True + extraction_service.extraction_repo.delete.assert_called_once_with(extraction) diff --git a/tests/services/test_favorite.py b/tests/services/test_favorite.py index 4e5f9c8..e633f95 100644 --- a/tests/services/test_favorite.py +++ b/tests/services/test_favorite.py @@ -1,6 +1,7 @@ """Tests for favorite service.""" from collections.abc import AsyncGenerator +from dataclasses import dataclass from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -14,6 +15,31 @@ from app.models.user import User from app.services.favorite import FavoriteService +@dataclass +class MockRepositories: + """Container for all mock repositories.""" + + user_repo: AsyncMock + favorite_repo: AsyncMock + sound_repo: AsyncMock + socket_manager: AsyncMock + + +@dataclass +class MockServiceDependencies: + """Container for all mock service dependencies.""" + + sound_repo_class: AsyncMock + user_repo_class: AsyncMock + favorite_repo_class: AsyncMock + socket_manager: AsyncMock + sound_repo: AsyncMock + user_repo: AsyncMock + favorite_repo: AsyncMock + playlist_repo_class: AsyncMock | None = None + playlist_repo: AsyncMock | None = None + + class TestFavoriteService: """Test favorite service operations.""" @@ -71,34 +97,75 @@ class TestFavoriteService: "playlist_repo": AsyncMock(), } - @patch("app.services.favorite.socket_manager") - @patch("app.services.favorite.FavoriteRepository") - @patch("app.services.favorite.UserRepository") - @patch("app.services.favorite.SoundRepository") + @pytest_asyncio.fixture + async def mock_sound_favorite_dependencies(self) -> MockServiceDependencies: + """Create mock dependencies for sound favorite operations.""" + with ( + patch("app.services.favorite.SoundRepository") as mock_sound_repo_class, + patch("app.services.favorite.UserRepository") as mock_user_repo_class, + patch("app.services.favorite.FavoriteRepository") as mock_favorite_repo_class, + patch("app.services.favorite.socket_manager") as mock_socket_manager, + ): + mock_sound_repo = AsyncMock() + mock_user_repo = AsyncMock() + mock_favorite_repo = AsyncMock() + + mock_sound_repo_class.return_value = mock_sound_repo + mock_user_repo_class.return_value = mock_user_repo + mock_favorite_repo_class.return_value = mock_favorite_repo + + yield MockServiceDependencies( + sound_repo_class=mock_sound_repo_class, + user_repo_class=mock_user_repo_class, + favorite_repo_class=mock_favorite_repo_class, + socket_manager=mock_socket_manager, + sound_repo=mock_sound_repo, + user_repo=mock_user_repo, + favorite_repo=mock_favorite_repo, + ) + + @pytest_asyncio.fixture + async def mock_playlist_favorite_dependencies(self) -> MockServiceDependencies: + """Create mock dependencies for playlist favorite operations.""" + with ( + patch("app.services.favorite.UserRepository") as mock_user_repo_class, + patch("app.services.favorite.PlaylistRepository") as mock_playlist_repo_class, + patch("app.services.favorite.FavoriteRepository") as mock_favorite_repo_class, + ): + mock_user_repo = AsyncMock() + mock_playlist_repo = AsyncMock() + mock_favorite_repo = AsyncMock() + + mock_user_repo_class.return_value = mock_user_repo + mock_playlist_repo_class.return_value = mock_playlist_repo + mock_favorite_repo_class.return_value = mock_favorite_repo + + yield MockServiceDependencies( + sound_repo_class=AsyncMock(), # Not used in playlist tests + user_repo_class=mock_user_repo_class, + favorite_repo_class=mock_favorite_repo_class, + socket_manager=AsyncMock(), # Not used in playlist tests + sound_repo=AsyncMock(), # Not used in playlist tests + user_repo=mock_user_repo, + favorite_repo=mock_favorite_repo, + playlist_repo_class=mock_playlist_repo_class, + playlist_repo=mock_playlist_repo, + ) + @pytest.mark.asyncio async def test_add_sound_favorite_success( self, - mock_sound_repo_class: AsyncMock, - mock_user_repo_class: AsyncMock, - mock_favorite_repo_class: AsyncMock, - mock_socket_manager: AsyncMock, + mock_sound_favorite_dependencies: MockServiceDependencies, favorite_service: FavoriteService, test_user: User, test_sound: Sound, ) -> None: """Test successfully adding a sound favorite.""" # Setup mocks - mock_favorite_repo = AsyncMock() - mock_user_repo = AsyncMock() - mock_sound_repo = AsyncMock() - - mock_favorite_repo_class.return_value = mock_favorite_repo - mock_user_repo_class.return_value = mock_user_repo - mock_sound_repo_class.return_value = mock_sound_repo - - mock_user_repo.get_by_id.return_value = test_user - mock_sound_repo.get_by_id.return_value = test_sound - mock_favorite_repo.get_by_user_and_sound.return_value = None + mocks = mock_sound_favorite_dependencies + mocks.user_repo.get_by_id.return_value = test_user + mocks.sound_repo.get_by_id.return_value = test_sound + mocks.favorite_repo.get_by_user_and_sound.return_value = None expected_favorite = Favorite( id=1, @@ -106,23 +173,23 @@ class TestFavoriteService: sound_id=test_sound.id, playlist_id=None, ) - mock_favorite_repo.create.return_value = expected_favorite - mock_favorite_repo.count_sound_favorites.return_value = 1 + mocks.favorite_repo.create.return_value = expected_favorite + mocks.favorite_repo.count_sound_favorites.return_value = 1 # Execute result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id) # Verify assert result == expected_favorite - mock_user_repo.get_by_id.assert_called_once_with(test_user.id) - mock_sound_repo.get_by_id.assert_called_once_with(test_sound.id) - mock_favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id) - mock_favorite_repo.create.assert_called_once_with({ + mocks.user_repo.get_by_id.assert_called_once_with(test_user.id) + mocks.sound_repo.get_by_id.assert_called_once_with(test_sound.id) + mocks.favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id) + mocks.favorite_repo.create.assert_called_once_with({ "user_id": test_user.id, "sound_id": test_sound.id, "playlist_id": None, }) - mock_socket_manager.broadcast_to_all.assert_called_once() + mocks.socket_manager.broadcast_to_all.assert_called_once() @patch("app.services.favorite.UserRepository") @pytest.mark.asyncio @@ -161,62 +228,38 @@ class TestFavoriteService: with pytest.raises(ValueError, match="Sound with ID 1 not found"): await favorite_service.add_sound_favorite(test_user.id, 1) - @patch("app.services.favorite.FavoriteRepository") - @patch("app.services.favorite.SoundRepository") - @patch("app.services.favorite.UserRepository") @pytest.mark.asyncio async def test_add_sound_favorite_already_exists( self, - mock_user_repo_class: AsyncMock, - mock_sound_repo_class: AsyncMock, - mock_favorite_repo_class: AsyncMock, + mock_sound_favorite_dependencies: MockServiceDependencies, favorite_service: FavoriteService, test_user: User, test_sound: Sound, ) -> None: """Test adding sound favorite that already exists.""" - mock_user_repo = AsyncMock() - mock_sound_repo = AsyncMock() - mock_favorite_repo = AsyncMock() - - mock_user_repo_class.return_value = mock_user_repo - mock_sound_repo_class.return_value = mock_sound_repo - mock_favorite_repo_class.return_value = mock_favorite_repo - - mock_user_repo.get_by_id.return_value = test_user - mock_sound_repo.get_by_id.return_value = test_sound + mocks = mock_sound_favorite_dependencies + mocks.user_repo.get_by_id.return_value = test_user + mocks.sound_repo.get_by_id.return_value = test_sound existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id) - mock_favorite_repo.get_by_user_and_sound.return_value = existing_favorite + mocks.favorite_repo.get_by_user_and_sound.return_value = existing_favorite with pytest.raises(ValueError, match="already favorited"): await favorite_service.add_sound_favorite(test_user.id, test_sound.id) - @patch("app.services.favorite.FavoriteRepository") - @patch("app.services.favorite.PlaylistRepository") - @patch("app.services.favorite.UserRepository") @pytest.mark.asyncio async def test_add_playlist_favorite_success( self, - mock_user_repo_class: AsyncMock, - mock_playlist_repo_class: AsyncMock, - mock_favorite_repo_class: AsyncMock, + mock_playlist_favorite_dependencies: MockServiceDependencies, favorite_service: FavoriteService, test_user: User, test_playlist: Playlist, ) -> None: """Test successfully adding a playlist favorite.""" # Setup mocks - mock_favorite_repo = AsyncMock() - mock_user_repo = AsyncMock() - mock_playlist_repo = AsyncMock() - - mock_favorite_repo_class.return_value = mock_favorite_repo - mock_user_repo_class.return_value = mock_user_repo - mock_playlist_repo_class.return_value = mock_playlist_repo - - mock_user_repo.get_by_id.return_value = test_user - mock_playlist_repo.get_by_id.return_value = test_playlist - mock_favorite_repo.get_by_user_and_playlist.return_value = None + mocks = mock_playlist_favorite_dependencies + mocks.user_repo.get_by_id.return_value = test_user + mocks.playlist_repo.get_by_id.return_value = test_playlist + mocks.favorite_repo.get_by_user_and_playlist.return_value = None expected_favorite = Favorite( id=1, @@ -224,59 +267,45 @@ class TestFavoriteService: sound_id=None, playlist_id=test_playlist.id, ) - mock_favorite_repo.create.return_value = expected_favorite + mocks.favorite_repo.create.return_value = expected_favorite # Execute result = await favorite_service.add_playlist_favorite(test_user.id, test_playlist.id) # Verify assert result == expected_favorite - mock_user_repo.get_by_id.assert_called_once_with(test_user.id) - mock_playlist_repo.get_by_id.assert_called_once_with(test_playlist.id) - mock_favorite_repo.get_by_user_and_playlist.assert_called_once_with(test_user.id, test_playlist.id) - mock_favorite_repo.create.assert_called_once_with({ + mocks.user_repo.get_by_id.assert_called_once_with(test_user.id) + mocks.playlist_repo.get_by_id.assert_called_once_with(test_playlist.id) + mocks.favorite_repo.get_by_user_and_playlist.assert_called_once_with(test_user.id, test_playlist.id) + mocks.favorite_repo.create.assert_called_once_with({ "user_id": test_user.id, "sound_id": None, "playlist_id": test_playlist.id, }) - @patch("app.services.favorite.socket_manager") - @patch("app.services.favorite.FavoriteRepository") - @patch("app.services.favorite.SoundRepository") - @patch("app.services.favorite.UserRepository") @pytest.mark.asyncio async def test_remove_sound_favorite_success( self, - mock_user_repo_class: AsyncMock, - mock_sound_repo_class: AsyncMock, - mock_favorite_repo_class: AsyncMock, - mock_socket_manager: AsyncMock, + mock_sound_favorite_dependencies: MockServiceDependencies, favorite_service: FavoriteService, test_user: User, test_sound: Sound, ) -> None: """Test successfully removing a sound favorite.""" - mock_favorite_repo = AsyncMock() - mock_user_repo = AsyncMock() - mock_sound_repo = AsyncMock() - - mock_favorite_repo_class.return_value = mock_favorite_repo - mock_user_repo_class.return_value = mock_user_repo - mock_sound_repo_class.return_value = mock_sound_repo - + mocks = mock_sound_favorite_dependencies existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id) - mock_favorite_repo.get_by_user_and_sound.return_value = existing_favorite - mock_user_repo.get_by_id.return_value = test_user - mock_sound_repo.get_by_id.return_value = test_sound - mock_favorite_repo.count_sound_favorites.return_value = 0 + mocks.favorite_repo.get_by_user_and_sound.return_value = existing_favorite + mocks.user_repo.get_by_id.return_value = test_user + mocks.sound_repo.get_by_id.return_value = test_sound + mocks.favorite_repo.count_sound_favorites.return_value = 0 # Execute await favorite_service.remove_sound_favorite(test_user.id, test_sound.id) # Verify - mock_favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id) - mock_favorite_repo.delete.assert_called_once_with(existing_favorite) - mock_socket_manager.broadcast_to_all.assert_called_once() + mocks.favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id) + mocks.favorite_repo.delete.assert_called_once_with(existing_favorite) + mocks.socket_manager.broadcast_to_all.assert_called_once() @patch("app.services.favorite.FavoriteRepository") @pytest.mark.asyncio @@ -503,46 +532,31 @@ class TestFavoriteService: assert result == 3 mock_favorite_repo.count_playlist_favorites.assert_called_once_with(1) - @patch("app.services.favorite.socket_manager") - @patch("app.services.favorite.FavoriteRepository") - @patch("app.services.favorite.SoundRepository") - @patch("app.services.favorite.UserRepository") @pytest.mark.asyncio async def test_socket_broadcast_error_handling( self, - mock_user_repo_class: AsyncMock, - mock_sound_repo_class: AsyncMock, - mock_favorite_repo_class: AsyncMock, - mock_socket_manager: AsyncMock, + mock_sound_favorite_dependencies: MockServiceDependencies, favorite_service: FavoriteService, test_user: User, test_sound: Sound, ) -> None: """Test that socket broadcast errors don't affect the operation.""" - # Setup mocks - mock_favorite_repo = AsyncMock() - mock_user_repo = AsyncMock() - mock_sound_repo = AsyncMock() - - mock_favorite_repo_class.return_value = mock_favorite_repo - mock_user_repo_class.return_value = mock_user_repo - mock_sound_repo_class.return_value = mock_sound_repo - - mock_user_repo.get_by_id.return_value = test_user - mock_sound_repo.get_by_id.return_value = test_sound - mock_favorite_repo.get_by_user_and_sound.return_value = None + mocks = mock_sound_favorite_dependencies + mocks.user_repo.get_by_id.return_value = test_user + mocks.sound_repo.get_by_id.return_value = test_sound + mocks.favorite_repo.get_by_user_and_sound.return_value = None expected_favorite = Favorite(id=1, user_id=test_user.id, sound_id=test_sound.id) - mock_favorite_repo.create.return_value = expected_favorite - mock_favorite_repo.count_sound_favorites.return_value = 1 + mocks.favorite_repo.create.return_value = expected_favorite + mocks.favorite_repo.count_sound_favorites.return_value = 1 # Make socket broadcast raise an exception - mock_socket_manager.broadcast_to_all.side_effect = Exception("Socket error") + mocks.socket_manager.broadcast_to_all.side_effect = Exception("Socket error") # Execute - should not raise exception despite socket error result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id) # Verify operation still succeeded assert result == expected_favorite - mock_favorite_repo.create.assert_called_once() + mocks.favorite_repo.create.assert_called_once() diff --git a/tests/services/test_sound_scanner.py b/tests/services/test_sound_scanner.py index 7d5f2a5..9f945d1 100644 --- a/tests/services/test_sound_scanner.py +++ b/tests/services/test_sound_scanner.py @@ -8,7 +8,7 @@ import pytest from sqlmodel.ext.asyncio.session import AsyncSession from app.models.sound import Sound -from app.services.sound_scanner import SoundScannerService +from app.services.sound_scanner import SoundScannerService, SyncContext class TestSoundScannerService: @@ -154,15 +154,15 @@ class TestSoundScannerService: } # Set the existing sound filename to match temp file for "unchanged" test existing_sound.filename = temp_path.name - - await scanner_service._sync_audio_file( - temp_path, - "SDB", - existing_sound, # existing_sound_by_hash (same hash) - None, # existing_sound_by_filename (no conflict) - "same_hash", - results, + + sync_context = SyncContext( + file_path=temp_path, + sound_type="SDB", + existing_sound_by_hash=existing_sound, + existing_sound_by_filename=None, + file_hash="same_hash", ) + await scanner_service._sync_audio_file(sync_context, results) assert results["skipped"] == 1 assert results["added"] == 0 @@ -186,7 +186,7 @@ class TestSoundScannerService: size=1024, hash="same_hash", ) - + scanner_service.sound_repo.update = AsyncMock(return_value=existing_sound) # Mock file operations to return same hash @@ -209,15 +209,15 @@ class TestSoundScannerService: "errors": 0, "files": [], } - - await scanner_service._sync_audio_file( - temp_path, - "SDB", - existing_sound, # existing_sound_by_hash (same hash) - None, # existing_sound_by_filename (different filename) - "same_hash", - results, + + sync_context = SyncContext( + file_path=temp_path, + sound_type="SDB", + existing_sound_by_hash=existing_sound, + existing_sound_by_filename=None, + file_hash="same_hash", ) + await scanner_service._sync_audio_file(sync_context, results) # Should be marked as updated (renamed) assert results["updated"] == 1 @@ -227,12 +227,12 @@ class TestSoundScannerService: assert results["files"][0]["status"] == "updated" assert results["files"][0]["reason"] == "file was renamed" assert results["files"][0]["changes"] == ["filename", "name"] - + # Verify update was called with new filename scanner_service.sound_repo.update.assert_called_once() call_args = scanner_service.sound_repo.update.call_args[0][1] # update_data assert call_args["filename"] == temp_path.name - + finally: temp_path.unlink() @@ -249,22 +249,21 @@ class TestSoundScannerService: size=1024, hash="same_hash", ) - + # Mock the repository to return the existing sound scanner_service.sound_repo.get_by_type = AsyncMock(return_value=[existing_sound]) scanner_service.sound_repo.update = AsyncMock() scanner_service.sound_repo.delete = AsyncMock() - + # Create temporary directory with renamed file import tempfile - import os - + with tempfile.TemporaryDirectory() as temp_dir: # Create the "renamed" file (same hash, different name) - new_file_path = os.path.join(temp_dir, "new_name.mp3") - with open(new_file_path, "wb") as f: + new_file_path = Path(temp_dir) / "new_name.mp3" + with new_file_path.open("wb") as f: f.write(b"test audio content") # This will produce consistent hash - + # Mock file operations to return same hash with ( patch("app.services.sound_scanner.get_file_hash", return_value="same_hash"), @@ -272,18 +271,18 @@ class TestSoundScannerService: patch("app.services.sound_scanner.get_file_size", return_value=1024), ): results = await scanner_service.scan_directory(temp_dir, "SDB") - + # Should have detected one renamed file assert results["updated"] == 1 assert results["deleted"] == 0 # This is the key assertion - no deletion! assert results["added"] == 0 assert len(results["files"]) == 1 - + # Verify it was marked as renamed file_result = results["files"][0] assert file_result["status"] == "updated" assert file_result["reason"] == "file was renamed" - + # Verify update was called but delete was NOT called scanner_service.sound_repo.update.assert_called_once() scanner_service.sound_repo.delete.assert_not_called() @@ -301,25 +300,24 @@ class TestSoundScannerService: size=1024, hash="same_hash", ) - + # Mock the repository scanner_service.sound_repo.get_by_type = AsyncMock(return_value=[existing_sound]) scanner_service.sound_repo.update = AsyncMock() - + # Create temporary directory with both original and duplicate files import tempfile - import os - + with tempfile.TemporaryDirectory() as temp_dir: # Create both files (simulating duplicate content) - original_path = os.path.join(temp_dir, "original.mp3") - duplicate_path = os.path.join(temp_dir, "duplicate.mp3") - - with open(original_path, "wb") as f: + original_path = Path(temp_dir) / "original.mp3" + duplicate_path = Path(temp_dir) / "duplicate.mp3" + + with original_path.open("wb") as f: f.write(b"test audio content") - with open(duplicate_path, "wb") as f: + with duplicate_path.open("wb") as f: f.write(b"test audio content") # Same content = same hash - + # Mock file operations with ( patch("app.services.sound_scanner.get_file_hash", return_value="same_hash"), @@ -327,14 +325,14 @@ class TestSoundScannerService: patch("app.services.sound_scanner.get_file_size", return_value=1024), ): results = await scanner_service.scan_directory(temp_dir, "SDB") - + # Should have 1 unchanged (original) and 1 skipped (duplicate) assert results["skipped"] == 2 # Both files have same hash, both skipped assert results["duplicates"] == 1 # One duplicate detected assert results["updated"] == 0 assert results["added"] == 0 assert results["deleted"] == 0 - + # Check that duplicate was properly detected skipped_files = [f for f in results["files"] if f["status"] == "skipped"] duplicate_file = next((f for f in skipped_files if "duplicate" in f["reason"]), None) @@ -375,14 +373,14 @@ class TestSoundScannerService: "errors": 0, "files": [], } - await scanner_service._sync_audio_file( - temp_path, - "SDB", - None, # existing_sound_by_hash - None, # existing_sound_by_filename - "test_hash", - results, + sync_context = SyncContext( + file_path=temp_path, + sound_type="SDB", + existing_sound_by_hash=None, + existing_sound_by_filename=None, + file_hash="test_hash", ) + await scanner_service._sync_audio_file(sync_context, results) assert results["added"] == 1 assert results["skipped"] == 0 @@ -439,14 +437,14 @@ class TestSoundScannerService: "errors": 0, "files": [], } - await scanner_service._sync_audio_file( - temp_path, - "SDB", - None, # existing_sound_by_hash (different hash) - existing_sound, # existing_sound_by_filename - "new_hash", - results, + sync_context = SyncContext( + file_path=temp_path, + sound_type="SDB", + existing_sound_by_hash=None, + existing_sound_by_filename=existing_sound, + file_hash="new_hash", ) + await scanner_service._sync_audio_file(sync_context, results) assert results["updated"] == 1 assert results["added"] == 0 @@ -504,14 +502,14 @@ class TestSoundScannerService: "errors": 0, "files": [], } - await scanner_service._sync_audio_file( - temp_path, - "CUSTOM", - None, # existing_sound_by_hash - None, # existing_sound_by_filename - "custom_hash", - results, + sync_context = SyncContext( + file_path=temp_path, + sound_type="CUSTOM", + existing_sound_by_hash=None, + existing_sound_by_filename=None, + file_hash="custom_hash", ) + await scanner_service._sync_audio_file(sync_context, results) assert results["added"] == 1 assert results["skipped"] == 0 @@ -533,41 +531,40 @@ class TestSoundScannerService: @pytest.mark.asyncio async def test_sync_audio_file_rename_with_normalized_file( - self, test_session, scanner_service + self, test_session, scanner_service, ): """Test that renaming a sound file also renames its normalized file.""" # Create temporary directories for testing - from pathlib import Path import tempfile - + from pathlib import Path + with tempfile.TemporaryDirectory() as temp_dir: temp_dir_path = Path(temp_dir) - + # Set up the scanner's normalized directories to use temp dir scanner_service.normalized_directories = { - "SDB": str(temp_dir_path / "normalized" / "soundboard") + "SDB": str(temp_dir_path / "normalized" / "soundboard"), } - + # Create the normalized directory normalized_dir = temp_dir_path / "normalized" / "soundboard" normalized_dir.mkdir(parents=True) - + # Create the old normalized file old_normalized_file = normalized_dir / "old_sound.mp3" old_normalized_file.write_text("normalized audio content") - + # Create the audio files (they need to exist for the scanner) - old_path = temp_dir_path / "old_sound.mp3" new_path = temp_dir_path / "new_sound.mp3" - + # Create a dummy audio file for the new path new_path.write_bytes(b"fake audio data for testing") - + # Mock the audio utility functions since we're using fake files from unittest.mock import patch - with patch('app.services.sound_scanner.get_audio_duration', return_value=60000), \ - patch('app.services.sound_scanner.get_file_size', return_value=2048): - + with patch("app.services.sound_scanner.get_audio_duration", return_value=60000), \ + patch("app.services.sound_scanner.get_file_size", return_value=2048): + # Create existing sound with normalized file info existing_sound = Sound( id=1, @@ -584,9 +581,9 @@ class TestSoundScannerService: normalized_hash="normalized_hash", play_count=5, is_deletable=False, - is_music=False + is_music=False, ) - + results = { "scanned": 0, "added": 0, @@ -597,36 +594,36 @@ class TestSoundScannerService: "errors": 0, "files": [], } - + # Mock the sound repository update scanner_service.sound_repo.update = AsyncMock() - + # Simulate rename detection by calling _sync_audio_file - await scanner_service._sync_audio_file( - new_path, - "SDB", - existing_sound, # existing_sound_by_hash (same hash, different filename) - None, # existing_sound_by_filename (no file with new name exists) - "test_hash", - results, + sync_context = SyncContext( + file_path=new_path, + sound_type="SDB", + existing_sound_by_hash=existing_sound, + existing_sound_by_filename=None, + file_hash="test_hash", ) - + await scanner_service._sync_audio_file(sync_context, results) + # Verify the results assert results["updated"] == 1 assert len(results["files"]) == 1 assert results["files"][0]["status"] == "updated" assert results["files"][0]["reason"] == "file was renamed" assert "normalized_filename" in results["files"][0]["changes"] - + # Verify sound_repo.update was called with normalized filename update update_call = scanner_service.sound_repo.update.call_args update_data = update_call[0][1] # Second argument is the update data - + assert "filename" in update_data assert "name" in update_data assert "normalized_filename" in update_data assert update_data["normalized_filename"] == "new_sound.mp3" - + # Verify the normalized file was actually renamed new_normalized_file = normalized_dir / "new_sound.mp3" assert new_normalized_file.exists() @@ -635,29 +632,29 @@ class TestSoundScannerService: @pytest.mark.asyncio async def test_scan_directory_delete_with_normalized_file( - self, test_session, scanner_service + self, test_session, scanner_service, ): """Test that deleting a sound also deletes its normalized file.""" # Create temporary directories for testing - from pathlib import Path import tempfile - + from pathlib import Path + with tempfile.TemporaryDirectory() as temp_dir: temp_dir_path = Path(temp_dir) scan_dir = temp_dir_path / "sounds" scan_dir.mkdir() - + # Set up the scanner's normalized directories to use temp dir scanner_service.normalized_directories = { - "SDB": str(temp_dir_path / "normalized" / "soundboard") + "SDB": str(temp_dir_path / "normalized" / "soundboard"), } - + # Create the normalized directory and file normalized_dir = temp_dir_path / "normalized" / "soundboard" normalized_dir.mkdir(parents=True) normalized_file = normalized_dir / "test_sound.mp3" normalized_file.write_text("normalized audio content") - + # Create existing sound with normalized file info existing_sound = Sound( id=1, @@ -674,21 +671,21 @@ class TestSoundScannerService: normalized_hash="normalized_hash", play_count=5, is_deletable=False, - is_music=False + is_music=False, ) - + # Mock sound repository methods scanner_service.sound_repo.get_by_type = AsyncMock(return_value=[existing_sound]) scanner_service.sound_repo.delete = AsyncMock() - + # Mock audio utility functions from unittest.mock import patch - with patch('app.services.sound_scanner.get_audio_duration'), \ - patch('app.services.sound_scanner.get_file_size'): - + with patch("app.services.sound_scanner.get_audio_duration"), \ + patch("app.services.sound_scanner.get_file_size"): + # Run scan with empty directory (should trigger deletion) results = await scanner_service.scan_directory(str(scan_dir), "SDB") - + # Verify the results assert results["deleted"] == 1 assert results["added"] == 0 @@ -696,9 +693,9 @@ class TestSoundScannerService: assert len(results["files"]) == 1 assert results["files"][0]["status"] == "deleted" assert results["files"][0]["reason"] == "file no longer exists" - + # Verify sound_repo.delete was called scanner_service.sound_repo.delete.assert_called_once_with(existing_sound) - + # Verify the normalized file was actually deleted assert not normalized_file.exists()