Add tests for extraction API endpoints and enhance existing tests
- Implement tests for admin extraction API endpoints including status retrieval, deletion of extractions, and permission checks.
- Add tests for user extraction deletion, ensuring proper handling of permissions and non-existent extractions.
- Enhance sound endpoint tests to include duplicate handling in responses.
- Refactor favorite service tests to utilize mock dependencies for better maintainability and clarity.
- Update sound scanner tests to improve file handling and ensure proper deletion of associated files.
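The favorite-service refactor mentioned above follows the usual mock-dependency pattern. A minimal sketch of what such a test can look like, assuming a FavoriteService that mirrors ExtractionService(session) and exposes a favorite_repo attribute plus an add_favorite method (these names and the module path are assumptions, not taken from this commit):

from unittest.mock import AsyncMock

import pytest

from app.services.favorite import FavoriteService  # assumed module path


@pytest.mark.asyncio
async def test_add_favorite_uses_mocked_repository():
    # Both the session and the repository are mocks; no database is touched.
    service = FavoriteService(AsyncMock())
    service.favorite_repo = AsyncMock()  # assumed attribute name
    service.favorite_repo.create.return_value = {"user_id": 1, "sound_id": 2}

    result = await service.add_favorite(user_id=1, sound_id=2)  # assumed method

    service.favorite_repo.create.assert_awaited_once()
    assert result["sound_id"] == 2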
@@ -2,18 +2,58 @@

from typing import Annotated

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.database import get_db
from app.core.dependencies import get_admin_user
from app.models.user import User
from app.services.extraction import ExtractionService
from app.services.extraction_processor import extraction_processor

router = APIRouter(prefix="/extractions", tags=["admin-extractions"])


async def get_extraction_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
    """Get the extraction service."""
    return ExtractionService(session)


@router.get("/status")
async def get_extraction_processor_status(
    current_user: Annotated[User, Depends(get_admin_user)],  # noqa: ARG001
) -> dict:
    """Get the status of the extraction processor. Admin only."""
    return extraction_processor.get_status()


@router.delete("/{extraction_id}")
async def delete_extraction(
    extraction_id: int,
    current_user: Annotated[User, Depends(get_admin_user)],  # noqa: ARG001
    extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, str]:
    """Delete any extraction and its associated sound and files. Admin only."""
    try:
        deleted = await extraction_service.delete_extraction(extraction_id, None)

        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Extraction {extraction_id} not found",
            )

    except HTTPException:
        # Re-raise HTTPExceptions without wrapping them
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete extraction: {e!s}",
        ) from e
    else:
        return {
            "message": f"Extraction {extraction_id} deleted successfully",
        }

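The admin route above maps unexpected service failures to HTTP 500. A hedged sketch of how that branch could be exercised in the new test suite, assuming the authenticated_admin_client fixture used by the tests added in this commit and pytest's monkeypatch (this sketch is not part of the commit itself):

import pytest

from app.services.extraction import ExtractionService


@pytest.mark.asyncio
async def test_admin_delete_extraction_internal_error(
    authenticated_admin_client, monkeypatch,
):
    async def boom(self, extraction_id, user_id=None):
        msg = "simulated storage failure"
        raise RuntimeError(msg)

    # Force the service call inside the endpoint to fail with a non-HTTP error.
    monkeypatch.setattr(ExtractionService, "delete_extraction", boom)

    response = await authenticated_admin_client.delete("/api/v1/admin/extractions/1")

    assert response.status_code == 500
    assert "Failed to delete extraction" in response.json()["detail"]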
@@ -170,7 +170,7 @@ async def get_processing_extractions(
    try:
        # Get all extractions with processing status
        processing_extractions = await extraction_service.extraction_repo.get_by_status(
            "processing"
            "processing",
        )

        # Convert to ExtractionInfo format
@@ -196,10 +196,53 @@ async def get_processing_extractions(
            }
            result.append(extraction_info)

        return result

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get processing extractions: {e!s}",
        ) from e
    else:
        return result


@router.delete("/{extraction_id}")
async def delete_extraction(
    extraction_id: int,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, str]:
    """Delete extraction and associated sound/files. Users can only delete their own."""
    try:
        if current_user.id is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User ID not available",
            )

        deleted = await extraction_service.delete_extraction(
            extraction_id, current_user.id,
        )

        if not deleted:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Extraction {extraction_id} not found",
            )

    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=str(e),
        ) from e
    except HTTPException:
        # Re-raise HTTPExceptions without wrapping them
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete extraction: {e!s}",
        ) from e
    else:
        return {
            "message": f"Extraction {extraction_id} deleted successfully",
        }

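The ValueError branch above is what turns an ownership violation into HTTP 403. A minimal sketch of the corresponding permission-check test; the /api/v1/extractions prefix and the non-admin authenticated_client fixture are assumptions, while test_session and test_plan are the fixtures used by the admin tests later in this commit:

import pytest

from app.models.extraction import Extraction
from app.models.user import User


@pytest.mark.asyncio
async def test_delete_other_users_extraction_forbidden(
    authenticated_client, test_session, test_plan,
):
    # Create another user and an extraction owned by them.
    other_user = User(
        name="Other User",
        email="other@example.com",
        is_active=True,
        plan_id=test_plan.id,
    )
    test_session.add(other_user)
    await test_session.commit()
    await test_session.refresh(other_user)

    extraction = Extraction(
        url="https://example.com/video",
        user_id=other_user.id,
        status="completed",
    )
    test_session.add(extraction)
    await test_session.commit()
    await test_session.refresh(extraction)

    # The service raises ValueError for foreign extractions; the endpoint
    # maps it to 403 Forbidden.
    response = await authenticated_client.delete(
        f"/api/v1/extractions/{extraction.id}",
    )

    assert response.status_code == 403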
@@ -35,10 +35,10 @@ class Sound(BaseModel, table=True):

    # relationships
    playlist_sounds: list["PlaylistSound"] = Relationship(
        back_populates="sound", cascade_delete=True
        back_populates="sound", cascade_delete=True,
    )
    extractions: list["Extraction"] = Relationship(back_populates="sound")
    play_history: list["SoundPlayed"] = Relationship(
        back_populates="sound", cascade_delete=True
        back_populates="sound", cascade_delete=True,
    )
    favorites: list["Favorite"] = Relationship(back_populates="sound")

@@ -44,7 +44,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(Extraction.created_at)
            .order_by(Extraction.created_at),
        )
        return list(result.all())


@@ -2,14 +2,16 @@

import asyncio
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import TypedDict
from typing import Any, TypedDict

import yt_dlp
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.config import settings
from app.core.logging import get_logger
from app.models.extraction import Extraction
from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository
@@ -21,6 +23,18 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
logger = get_logger(__name__)


@dataclass
class ExtractionContext:
    """Context data for extraction processing."""

    extraction_id: int
    extraction_url: str
    extraction_service: str | None
    extraction_service_id: str | None
    extraction_title: str | None
    user_id: int


class ExtractionInfo(TypedDict):
    """Type definition for extraction information."""

@@ -150,8 +164,8 @@ class ExtractionService:
|
||||
logger.exception("Failed to detect service info for URL: %s", url)
|
||||
return None
|
||||
|
||||
async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
|
||||
"""Process an extraction job."""
|
||||
async def _validate_extraction(self, extraction_id: int) -> tuple:
|
||||
"""Validate extraction and return extraction data."""
|
||||
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||
if not extraction:
|
||||
msg = f"Extraction {extraction_id} not found"
|
||||
@@ -173,9 +187,183 @@ class ExtractionService:
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
user_name = user.name if user else None
|
||||
except Exception:
|
||||
logger.warning("Failed to get user %d for extraction", user_id)
|
||||
logger.exception("Failed to get user %d for extraction", user_id)
|
||||
user_name = None
|
||||
|
||||
return (
|
||||
extraction,
|
||||
user_id,
|
||||
extraction_url,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
extraction_title,
|
||||
user_name,
|
||||
)
|
||||
|
||||
async def _handle_service_detection(
|
||||
self,
|
||||
extraction: Extraction,
|
||||
context: ExtractionContext,
|
||||
) -> tuple:
|
||||
"""Handle service detection and duplicate checking."""
|
||||
if context.extraction_service and context.extraction_service_id:
|
||||
return (
|
||||
context.extraction_service,
|
||||
context.extraction_service_id,
|
||||
context.extraction_title,
|
||||
)
|
||||
|
||||
logger.info("Detecting service info for extraction %d", context.extraction_id)
|
||||
service_info = await self._detect_service_info(context.extraction_url)
|
||||
|
||||
if not service_info:
|
||||
msg = "Unable to detect service information from URL"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if extraction already exists for this service
|
||||
service_name = service_info["service"]
|
||||
service_id_val = service_info["service_id"]
|
||||
|
||||
if not service_name or not service_id_val:
|
||||
msg = "Service info is incomplete"
|
||||
raise ValueError(msg)
|
||||
|
||||
existing = await self.extraction_repo.get_by_service_and_id(
|
||||
service_name,
|
||||
service_id_val,
|
||||
)
|
||||
if existing and existing.id != context.extraction_id:
|
||||
error_msg = (
|
||||
f"Extraction already exists for "
|
||||
f"{service_info['service']}:{service_info['service_id']}"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Update extraction with service info
|
||||
update_data = {
|
||||
"service": service_info["service"],
|
||||
"service_id": service_info["service_id"],
|
||||
"title": service_info.get("title") or context.extraction_title,
|
||||
}
|
||||
await self.extraction_repo.update(extraction, update_data)
|
||||
|
||||
# Update values for processing
|
||||
new_service = service_info["service"]
|
||||
new_service_id = service_info["service_id"]
|
||||
new_title = service_info.get("title") or context.extraction_title
|
||||
|
||||
await self._emit_extraction_event(
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "processing",
|
||||
"title": new_title,
|
||||
"url": context.extraction_url,
|
||||
},
|
||||
)
|
||||
|
||||
return new_service, new_service_id, new_title
|
||||
|
||||
async def _process_media_files(
|
||||
self,
|
||||
extraction_id: int,
|
||||
extraction_url: str,
|
||||
extraction_title: str | None,
|
||||
extraction_service: str,
|
||||
extraction_service_id: str,
|
||||
) -> int:
|
||||
"""Process media files and create sound record."""
|
||||
# Extract audio and thumbnail
|
||||
audio_file, thumbnail_file = await self._extract_media(
|
||||
extraction_id,
|
||||
extraction_url,
|
||||
)
|
||||
|
||||
# Move files to final locations
|
||||
final_audio_path, final_thumbnail_path = (
|
||||
await self._move_files_to_final_location(
|
||||
audio_file,
|
||||
thumbnail_file,
|
||||
extraction_title,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Create Sound record
|
||||
sound = await self._create_sound_record(
|
||||
final_audio_path,
|
||||
final_thumbnail_path,
|
||||
extraction_title,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
)
|
||||
|
||||
if not sound.id:
|
||||
msg = "Sound creation failed - no ID returned"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return sound.id
|
||||
|
||||
async def _complete_extraction(
|
||||
self,
|
||||
extraction: Extraction,
|
||||
context: ExtractionContext,
|
||||
sound_id: int,
|
||||
) -> None:
|
||||
"""Complete extraction processing."""
|
||||
# Normalize the sound
|
||||
await self._normalize_sound(sound_id)
|
||||
|
||||
# Add to main playlist
|
||||
await self._add_to_main_playlist(sound_id, context.user_id)
|
||||
|
||||
# Update extraction with success
|
||||
await self.extraction_repo.update(
|
||||
extraction,
|
||||
{
|
||||
"status": "completed",
|
||||
"sound_id": sound_id,
|
||||
"error": None,
|
||||
},
|
||||
)
|
||||
|
||||
# Emit WebSocket event for completion
|
||||
await self._emit_extraction_event(
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "completed",
|
||||
"title": context.extraction_title,
|
||||
"url": context.extraction_url,
|
||||
"sound_id": sound_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
|
||||
"""Process an extraction job."""
|
||||
# Validate extraction and get context data
|
||||
(
|
||||
extraction,
|
||||
user_id,
|
||||
extraction_url,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
extraction_title,
|
||||
user_name,
|
||||
) = await self._validate_extraction(extraction_id)
|
||||
|
||||
# Create context object for helper methods
|
||||
context = ExtractionContext(
|
||||
extraction_id=extraction_id,
|
||||
extraction_url=extraction_url,
|
||||
extraction_service=extraction_service,
|
||||
extraction_service_id=extraction_service_id,
|
||||
extraction_title=extraction_title,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
|
||||
|
||||
try:
|
||||
@@ -184,142 +372,53 @@ class ExtractionService:
|
||||
|
||||
# Emit WebSocket event for processing start
|
||||
await self._emit_extraction_event(
|
||||
user_id,
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": extraction_id,
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "processing",
|
||||
"title": extraction_title or "Processing extraction...",
|
||||
"url": extraction_url,
|
||||
"title": context.extraction_title or "Processing extraction...",
|
||||
"url": context.extraction_url,
|
||||
},
|
||||
)
|
||||
|
||||
# Detect service info if not already available
|
||||
if not extraction_service or not extraction_service_id:
|
||||
logger.info("Detecting service info for extraction %d", extraction_id)
|
||||
service_info = await self._detect_service_info(extraction_url)
|
||||
|
||||
if not service_info:
|
||||
msg = "Unable to detect service information from URL"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check if extraction already exists for this service
|
||||
service_name = service_info["service"]
|
||||
service_id_val = service_info["service_id"]
|
||||
|
||||
if not service_name or not service_id_val:
|
||||
msg = "Service info is incomplete"
|
||||
raise ValueError(msg)
|
||||
|
||||
existing = await self.extraction_repo.get_by_service_and_id(
|
||||
service_name,
|
||||
service_id_val,
|
||||
)
|
||||
if existing and existing.id != extraction_id:
|
||||
error_msg = (
|
||||
f"Extraction already exists for "
|
||||
f"{service_info['service']}:{service_info['service_id']}"
|
||||
)
|
||||
logger.warning(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Update extraction with service info
|
||||
update_data = {
|
||||
"service": service_info["service"],
|
||||
"service_id": service_info["service_id"],
|
||||
"title": service_info.get("title") or extraction_title,
|
||||
}
|
||||
await self.extraction_repo.update(extraction, update_data)
|
||||
|
||||
# Update values for processing
|
||||
extraction_service = service_info["service"]
|
||||
extraction_service_id = service_info["service_id"]
|
||||
extraction_title = service_info.get("title") or extraction_title
|
||||
|
||||
await self._emit_extraction_event(
|
||||
user_id,
|
||||
{
|
||||
"extraction_id": extraction_id,
|
||||
"status": "processing",
|
||||
"title": extraction_title,
|
||||
"url": extraction_url,
|
||||
},
|
||||
)
|
||||
|
||||
# Extract audio and thumbnail
|
||||
audio_file, thumbnail_file = await self._extract_media(
|
||||
extraction_id,
|
||||
extraction_url,
|
||||
# Handle service detection and duplicate checking
|
||||
extraction_service, extraction_service_id, extraction_title = (
|
||||
await self._handle_service_detection(extraction, context)
|
||||
)
|
||||
|
||||
# Move files to final locations
|
||||
(
|
||||
final_audio_path,
|
||||
final_thumbnail_path,
|
||||
) = await self._move_files_to_final_location(
|
||||
audio_file,
|
||||
thumbnail_file,
|
||||
extraction_title,
|
||||
# Update context with potentially new values
|
||||
context.extraction_service = extraction_service
|
||||
context.extraction_service_id = extraction_service_id
|
||||
context.extraction_title = extraction_title
|
||||
|
||||
# Process media files and create sound record
|
||||
sound_id = await self._process_media_files(
|
||||
context.extraction_id,
|
||||
context.extraction_url,
|
||||
context.extraction_title,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
)
|
||||
|
||||
# Create Sound record
|
||||
sound = await self._create_sound_record(
|
||||
final_audio_path,
|
||||
final_thumbnail_path,
|
||||
extraction_title,
|
||||
extraction_service,
|
||||
extraction_service_id,
|
||||
)
|
||||
# Complete extraction processing
|
||||
await self._complete_extraction(extraction, context, sound_id)
|
||||
|
||||
# Store sound_id early to avoid session detachment issues
|
||||
sound_id = sound.id
|
||||
if not sound_id:
|
||||
msg = "Sound creation failed - no ID returned"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
# Normalize the sound
|
||||
await self._normalize_sound(sound_id)
|
||||
|
||||
# Add to main playlist
|
||||
await self._add_to_main_playlist(sound_id, user_id)
|
||||
|
||||
# Update extraction with success
|
||||
await self.extraction_repo.update(
|
||||
extraction,
|
||||
{
|
||||
"status": "completed",
|
||||
"sound_id": sound_id,
|
||||
"error": None,
|
||||
},
|
||||
)
|
||||
|
||||
# Emit WebSocket event for completion
|
||||
await self._emit_extraction_event(
|
||||
user_id,
|
||||
{
|
||||
"extraction_id": extraction_id,
|
||||
"status": "completed",
|
||||
"title": extraction_title,
|
||||
"url": extraction_url,
|
||||
"sound_id": sound_id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Successfully processed extraction %d", extraction_id)
|
||||
logger.info("Successfully processed extraction %d", context.extraction_id)
|
||||
|
||||
# Get updated extraction to get latest timestamps
|
||||
updated_extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||
updated_extraction = await self.extraction_repo.get_by_id(
|
||||
context.extraction_id,
|
||||
)
|
||||
return {
|
||||
"id": extraction_id,
|
||||
"url": extraction_url,
|
||||
"id": context.extraction_id,
|
||||
"url": context.extraction_url,
|
||||
"service": extraction_service,
|
||||
"service_id": extraction_service_id,
|
||||
"title": extraction_title,
|
||||
"status": "completed",
|
||||
"error": None,
|
||||
"sound_id": sound_id,
|
||||
"user_id": user_id,
|
||||
"user_id": context.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": (
|
||||
updated_extraction.created_at.isoformat()
|
||||
@@ -337,18 +436,18 @@ class ExtractionService:
|
||||
error_msg = str(e)
|
||||
logger.exception(
|
||||
"Failed to process extraction %d: %s",
|
||||
extraction_id,
|
||||
context.extraction_id,
|
||||
error_msg,
|
||||
)
|
||||
|
||||
# Emit WebSocket event for failure
|
||||
await self._emit_extraction_event(
|
||||
user_id,
|
||||
context.user_id,
|
||||
{
|
||||
"extraction_id": extraction_id,
|
||||
"extraction_id": context.extraction_id,
|
||||
"status": "failed",
|
||||
"title": extraction_title or "Extraction failed",
|
||||
"url": extraction_url,
|
||||
"title": context.extraction_title or "Extraction failed",
|
||||
"url": context.extraction_url,
|
||||
"error": error_msg,
|
||||
},
|
||||
)
|
||||
@@ -363,17 +462,19 @@ class ExtractionService:
|
||||
)
|
||||
|
||||
# Get updated extraction to get latest timestamps
|
||||
updated_extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||
updated_extraction = await self.extraction_repo.get_by_id(
|
||||
context.extraction_id,
|
||||
)
|
||||
return {
|
||||
"id": extraction_id,
|
||||
"url": extraction_url,
|
||||
"service": extraction_service,
|
||||
"service_id": extraction_service_id,
|
||||
"title": extraction_title,
|
||||
"id": context.extraction_id,
|
||||
"url": context.extraction_url,
|
||||
"service": context.extraction_service,
|
||||
"service_id": context.extraction_service_id,
|
||||
"title": context.extraction_title,
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"sound_id": None,
|
||||
"user_id": user_id,
|
||||
"user_id": context.user_id,
|
||||
"user_name": user_name,
|
||||
"created_at": (
|
||||
updated_extraction.created_at.isoformat()
|
||||
@@ -780,3 +881,174 @@ class ExtractionService:
|
||||
}
|
||||
for extraction, user in extraction_user_tuples
|
||||
]
|
||||
|
||||
async def delete_extraction(
|
||||
self,
|
||||
extraction_id: int,
|
||||
user_id: int | None = None,
|
||||
) -> bool:
|
||||
"""Delete an extraction and its associated sound and files.
|
||||
|
||||
Args:
|
||||
extraction_id: The ID of the extraction to delete
|
||||
user_id: Optional user ID for ownership verification (None for admin)
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False if extraction not found
|
||||
|
||||
Raises:
|
||||
ValueError: If user doesn't own the extraction (when user_id is provided)
|
||||
|
||||
"""
|
||||
logger.info(
|
||||
"Deleting extraction: %d (user: %s)",
|
||||
extraction_id,
|
||||
user_id or "admin",
|
||||
)
|
||||
|
||||
# Get the extraction record
|
||||
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||
if not extraction:
|
||||
logger.warning("Extraction %d not found", extraction_id)
|
||||
return False
|
||||
|
||||
# Check ownership if user_id is provided (non-admin request)
|
||||
if user_id is not None and extraction.user_id != user_id:
|
||||
msg = "You don't have permission to delete this extraction"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Get associated sound if it exists and capture its attributes immediately
|
||||
sound_data = None
|
||||
sound_object = None
|
||||
if extraction.sound_id:
|
||||
sound_object = await self.sound_repo.get_by_id(extraction.sound_id)
|
||||
if sound_object:
|
||||
# Capture attributes immediately while session is valid
|
||||
sound_data = {
|
||||
"id": sound_object.id,
|
||||
"type": sound_object.type,
|
||||
"filename": sound_object.filename,
|
||||
"is_normalized": sound_object.is_normalized,
|
||||
"normalized_filename": sound_object.normalized_filename,
|
||||
"thumbnail": sound_object.thumbnail,
|
||||
}
|
||||
|
||||
try:
|
||||
# Delete the extraction record first
|
||||
await self.extraction_repo.delete(extraction)
|
||||
logger.info("Deleted extraction record: %d", extraction_id)
|
||||
|
||||
# Check if sound was in current playlist before deletion
|
||||
sound_was_in_current_playlist = False
|
||||
if sound_object and sound_data:
|
||||
sound_was_in_current_playlist = (
|
||||
await self._check_sound_in_current_playlist(sound_data["id"])
|
||||
)
|
||||
|
||||
# If there's an associated sound, delete it and its files
|
||||
if sound_object and sound_data:
|
||||
await self._delete_sound_and_files(sound_object, sound_data)
|
||||
logger.info(
|
||||
"Deleted associated sound: %d (%s)",
|
||||
sound_data["id"],
|
||||
sound_data["filename"],
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
await self.session.commit()
|
||||
|
||||
# Reload player playlist if deleted sound was in current playlist
|
||||
if sound_was_in_current_playlist and sound_data:
|
||||
await self._reload_player_playlist()
|
||||
logger.info(
|
||||
"Reloaded player playlist after deleting sound %d "
|
||||
"from current playlist",
|
||||
sound_data["id"],
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Rollback on any error
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete extraction %d", extraction_id)
|
||||
raise
|
||||
else:
|
||||
return True
|
||||
|
||||
async def _delete_sound_and_files(
|
||||
self,
|
||||
sound: Sound,
|
||||
sound_data: dict[str, Any],
|
||||
) -> None:
|
||||
"""Delete a sound record and all its associated files."""
|
||||
# Collect all file paths to delete using captured attributes
|
||||
files_to_delete = []
|
||||
|
||||
# Original audio file
|
||||
if sound_data["type"] == "EXT": # Extracted sounds
|
||||
original_path = Path("sounds/originals/extracted") / sound_data["filename"]
|
||||
if original_path.exists():
|
||||
files_to_delete.append(original_path)
|
||||
|
||||
# Normalized file
|
||||
if sound_data["is_normalized"] and sound_data["normalized_filename"]:
|
||||
normalized_path = (
|
||||
Path("sounds/normalized/extracted") / sound_data["normalized_filename"]
|
||||
)
|
||||
if normalized_path.exists():
|
||||
files_to_delete.append(normalized_path)
|
||||
|
||||
# Thumbnail file
|
||||
if sound_data["thumbnail"]:
|
||||
thumbnail_path = (
|
||||
Path(settings.EXTRACTION_THUMBNAILS_DIR) / sound_data["thumbnail"]
|
||||
)
|
||||
if thumbnail_path.exists():
|
||||
files_to_delete.append(thumbnail_path)
|
||||
|
||||
# Delete the sound from database first
|
||||
await self.sound_repo.delete(sound)
|
||||
|
||||
# Delete all associated files
|
||||
for file_path in files_to_delete:
|
||||
try:
|
||||
file_path.unlink()
|
||||
logger.info("Deleted file: %s", file_path)
|
||||
except OSError:
|
||||
logger.exception("Failed to delete file %s", file_path)
|
||||
# Continue with other files even if one fails
|
||||
|
||||
async def _check_sound_in_current_playlist(self, sound_id: int) -> bool:
|
||||
"""Check if a sound is in the current playlist."""
|
||||
try:
|
||||
from app.repositories.playlist import PlaylistRepository # noqa: PLC0415
|
||||
|
||||
playlist_repo = PlaylistRepository(self.session)
|
||||
current_playlist = await playlist_repo.get_current_playlist()
|
||||
|
||||
if not current_playlist or not current_playlist.id:
|
||||
return False
|
||||
|
||||
return await playlist_repo.is_sound_in_playlist(
|
||||
current_playlist.id, sound_id,
|
||||
)
|
||||
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
|
||||
logger.warning(
|
||||
"Failed to check if sound %s is in current playlist: %s",
|
||||
sound_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return False
|
||||
|
||||
async def _reload_player_playlist(self) -> None:
|
||||
"""Reload the player playlist after a sound is deleted."""
|
||||
try:
|
||||
# Import here to avoid circular import issues
|
||||
from app.services.player import get_player_service # noqa: PLC0415
|
||||
|
||||
player = get_player_service()
|
||||
await player.reload_playlist()
|
||||
logger.debug("Player playlist reloaded after sound deletion")
|
||||
except (ImportError, AttributeError, ValueError, RuntimeError) as e:
|
||||
# Don't fail the deletion operation if player reload fails
|
||||
logger.warning("Failed to reload player playlist: %s", e, exc_info=True)
|
||||
|
||||
@@ -201,7 +201,7 @@ class ExtractionProcessor:
            for extraction in stuck_extractions:
                try:
                    await extraction_service.extraction_repo.update(
                        extraction, {"status": "pending", "error": None}
                        extraction, {"status": "pending", "error": None},
                    )
                    reset_count += 1
                    logger.info(
@@ -210,12 +210,13 @@ class ExtractionProcessor:
                    )
                except Exception:
                    logger.exception(
                        "Failed to reset extraction %d", extraction.id
                        "Failed to reset extraction %d", extraction.id,
                    )

            await session.commit()
            logger.info(
                "Successfully reset %d stuck extractions from processing to pending",
                "Successfully reset %d stuck extractions from processing to "
                "pending",
                reset_count,
            )


@@ -1,5 +1,6 @@
|
||||
"""Sound scanner service for scanning and importing audio files."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -13,6 +14,28 @@ from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioFileInfo:
|
||||
"""Data class for audio file metadata."""
|
||||
|
||||
filename: str
|
||||
name: str
|
||||
duration: int
|
||||
size: int
|
||||
file_hash: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncContext:
|
||||
"""Context data for audio file synchronization."""
|
||||
|
||||
file_path: Path
|
||||
sound_type: str
|
||||
existing_sound_by_hash: dict | Sound | None
|
||||
existing_sound_by_filename: dict | Sound | None
|
||||
file_hash: str
|
||||
|
||||
|
||||
class FileInfo(TypedDict):
|
||||
"""Type definition for file information in scan results."""
|
||||
|
||||
@@ -75,11 +98,15 @@ class SoundScannerService:
|
||||
|
||||
def _get_normalized_path(self, sound_type: str, filename: str) -> Path:
|
||||
"""Get the normalized file path for a sound."""
|
||||
directory = self.normalized_directories.get(sound_type, "sounds/normalized/other")
|
||||
directory = self.normalized_directories.get(
|
||||
sound_type, "sounds/normalized/other",
|
||||
)
|
||||
return Path(directory) / filename
|
||||
|
||||
def _rename_normalized_file(self, sound_type: str, old_filename: str, new_filename: str) -> bool:
|
||||
"""Rename a normalized file if it exists. Returns True if renamed, False if not found."""
|
||||
def _rename_normalized_file(
|
||||
self, sound_type: str, old_filename: str, new_filename: str,
|
||||
) -> bool:
|
||||
"""Rename normalized file if exists. Returns True if renamed, else False."""
|
||||
old_path = self._get_normalized_path(sound_type, old_filename)
|
||||
new_path = self._get_normalized_path(sound_type, new_filename)
|
||||
|
||||
@@ -89,26 +116,395 @@ class SoundScannerService:
|
||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
old_path.rename(new_path)
|
||||
logger.info("Renamed normalized file: %s -> %s", old_path, new_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to rename normalized file %s -> %s: %s", old_path, new_path, e)
|
||||
except OSError:
|
||||
logger.exception(
|
||||
"Failed to rename normalized file %s -> %s",
|
||||
old_path,
|
||||
new_path,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _delete_normalized_file(self, sound_type: str, filename: str) -> bool:
|
||||
"""Delete a normalized file if it exists. Returns True if deleted, False if not found."""
|
||||
"""Delete normalized file if exists. Returns True if deleted, else False."""
|
||||
normalized_path = self._get_normalized_path(sound_type, filename)
|
||||
|
||||
if normalized_path.exists():
|
||||
try:
|
||||
normalized_path.unlink()
|
||||
logger.info("Deleted normalized file: %s", normalized_path)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete normalized file %s: %s", normalized_path, e)
|
||||
except OSError:
|
||||
logger.exception(
|
||||
"Failed to delete normalized file %s", normalized_path,
|
||||
)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_sound_attributes(self, sound_data: dict | Sound | None) -> dict:
|
||||
"""Extract attributes from sound data (dict or Sound object)."""
|
||||
if sound_data is None:
|
||||
return {}
|
||||
|
||||
if isinstance(sound_data, dict):
|
||||
return {
|
||||
"filename": sound_data.get("filename"),
|
||||
"name": sound_data.get("name"),
|
||||
"duration": sound_data.get("duration"),
|
||||
"size": sound_data.get("size"),
|
||||
"id": sound_data.get("id"),
|
||||
"object": sound_data.get("sound_object"),
|
||||
"type": sound_data.get("type"),
|
||||
"is_normalized": sound_data.get("is_normalized"),
|
||||
"normalized_filename": sound_data.get("normalized_filename"),
|
||||
}
|
||||
# Sound object (for tests)
|
||||
return {
|
||||
"filename": sound_data.filename,
|
||||
"name": sound_data.name,
|
||||
"duration": sound_data.duration,
|
||||
"size": sound_data.size,
|
||||
"id": sound_data.id,
|
||||
"object": sound_data,
|
||||
"type": sound_data.type,
|
||||
"is_normalized": sound_data.is_normalized,
|
||||
"normalized_filename": sound_data.normalized_filename,
|
||||
}
|
||||
|
||||
def _handle_unchanged_file(
|
||||
self,
|
||||
filename: str,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle unchanged file (same hash, same filename)."""
|
||||
logger.debug("Sound unchanged: %s", filename)
|
||||
results["skipped"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "skipped",
|
||||
"reason": "file unchanged",
|
||||
"name": existing_attrs["name"],
|
||||
"duration": existing_attrs["duration"],
|
||||
"size": existing_attrs["size"],
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
def _handle_duplicate_file(
|
||||
self,
|
||||
filename: str,
|
||||
existing_filename: str,
|
||||
file_hash: str,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle duplicate file (same hash, different filename)."""
|
||||
logger.warning(
|
||||
"Duplicate file detected: '%s' has same content as existing "
|
||||
"'%s' (hash: %s). Skipping duplicate file.",
|
||||
filename,
|
||||
existing_filename,
|
||||
file_hash[:8] + "...",
|
||||
)
|
||||
results["skipped"] += 1
|
||||
results["duplicates"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "skipped",
|
||||
"reason": "duplicate content",
|
||||
"name": existing_attrs["name"],
|
||||
"duration": existing_attrs["duration"],
|
||||
"size": existing_attrs["size"],
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
async def _handle_file_rename(
|
||||
self,
|
||||
file_info: AudioFileInfo,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle file rename (same hash, different filename)."""
|
||||
update_data = {
|
||||
"filename": file_info.filename,
|
||||
"name": file_info.name,
|
||||
}
|
||||
|
||||
# If the sound has a normalized file, rename it too
|
||||
if existing_attrs["is_normalized"] and existing_attrs["normalized_filename"]:
|
||||
old_normalized_base = Path(existing_attrs["normalized_filename"]).name
|
||||
new_normalized_base = (
|
||||
Path(file_info.filename).stem
|
||||
+ Path(existing_attrs["normalized_filename"]).suffix
|
||||
)
|
||||
|
||||
renamed = self._rename_normalized_file(
|
||||
existing_attrs["type"],
|
||||
old_normalized_base,
|
||||
new_normalized_base,
|
||||
)
|
||||
|
||||
if renamed:
|
||||
update_data["normalized_filename"] = new_normalized_base
|
||||
logger.info(
|
||||
"Renamed normalized file: %s -> %s",
|
||||
old_normalized_base,
|
||||
new_normalized_base,
|
||||
)
|
||||
|
||||
await self.sound_repo.update(existing_attrs["object"], update_data)
|
||||
logger.info(
|
||||
"Detected rename: %s -> %s (ID: %s)",
|
||||
existing_attrs["filename"],
|
||||
file_info.filename,
|
||||
existing_attrs["id"],
|
||||
)
|
||||
|
||||
# Build changes list
|
||||
changes = ["filename", "name"]
|
||||
if "normalized_filename" in update_data:
|
||||
changes.append("normalized_filename")
|
||||
|
||||
results["updated"] += 1
|
||||
results["files"].append({
|
||||
"filename": file_info.filename,
|
||||
"status": "updated",
|
||||
"reason": "file was renamed",
|
||||
"name": file_info.name,
|
||||
"duration": existing_attrs["duration"],
|
||||
"size": existing_attrs["size"],
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": changes,
|
||||
# Store old filename to prevent deletion
|
||||
"old_filename": existing_attrs["filename"],
|
||||
})
|
||||
|
||||
async def _handle_file_modification(
|
||||
self,
|
||||
file_info: AudioFileInfo,
|
||||
existing_attrs: dict,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle file modification (same filename, different hash)."""
|
||||
update_data = {
|
||||
"name": file_info.name,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"hash": file_info.file_hash,
|
||||
}
|
||||
|
||||
await self.sound_repo.update(existing_attrs["object"], update_data)
|
||||
logger.info(
|
||||
"Updated modified sound: %s (ID: %s)",
|
||||
file_info.name,
|
||||
existing_attrs["id"],
|
||||
)
|
||||
|
||||
results["updated"] += 1
|
||||
results["files"].append({
|
||||
"filename": file_info.filename,
|
||||
"status": "updated",
|
||||
"reason": "file was modified",
|
||||
"name": file_info.name,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"id": existing_attrs["id"],
|
||||
"error": None,
|
||||
"changes": ["hash", "duration", "size", "name"],
|
||||
})
|
||||
|
||||
async def _handle_new_file(
|
||||
self,
|
||||
file_info: AudioFileInfo,
|
||||
sound_type: str,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Handle new file (neither hash nor filename exists)."""
|
||||
sound_data = {
|
||||
"type": sound_type,
|
||||
"name": file_info.name,
|
||||
"filename": file_info.filename,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"hash": file_info.file_hash,
|
||||
"is_deletable": False,
|
||||
"is_music": False,
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id)
|
||||
|
||||
results["added"] += 1
|
||||
results["files"].append({
|
||||
"filename": file_info.filename,
|
||||
"status": "added",
|
||||
"reason": None,
|
||||
"name": file_info.name,
|
||||
"duration": file_info.duration,
|
||||
"size": file_info.size,
|
||||
"id": sound.id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
async def _load_existing_sounds(self, sound_type: str) -> tuple[dict, dict]:
|
||||
"""Load existing sounds and create lookup dictionaries."""
|
||||
existing_sounds = await self.sound_repo.get_by_type(sound_type)
|
||||
|
||||
# Create lookup dictionaries with immediate attribute access
|
||||
# to avoid session detachment
|
||||
sounds_by_hash = {}
|
||||
sounds_by_filename = {}
|
||||
|
||||
for sound in existing_sounds:
|
||||
# Capture all attributes immediately while session is valid
|
||||
sound_data = {
|
||||
"id": sound.id,
|
||||
"hash": sound.hash,
|
||||
"filename": sound.filename,
|
||||
"name": sound.name,
|
||||
"duration": sound.duration,
|
||||
"size": sound.size,
|
||||
"type": sound.type,
|
||||
"is_normalized": sound.is_normalized,
|
||||
"normalized_filename": sound.normalized_filename,
|
||||
"sound_object": sound, # Keep reference for database operations
|
||||
}
|
||||
sounds_by_hash[sound.hash] = sound_data
|
||||
sounds_by_filename[sound.filename] = sound_data
|
||||
|
||||
return sounds_by_hash, sounds_by_filename
|
||||
|
||||
async def _process_audio_files(
|
||||
self,
|
||||
scan_path: Path,
|
||||
sound_type: str,
|
||||
sounds_by_hash: dict,
|
||||
sounds_by_filename: dict,
|
||||
results: ScanResults,
|
||||
) -> set[str]:
|
||||
"""Process all audio files in directory and return processed filenames."""
|
||||
# Get all audio files from directory
|
||||
audio_files = [
|
||||
f
|
||||
for f in scan_path.iterdir()
|
||||
if f.is_file() and f.suffix.lower() in self.supported_extensions
|
||||
]
|
||||
|
||||
# Process each file in directory
|
||||
processed_filenames = set()
|
||||
for file_path in audio_files:
|
||||
results["scanned"] += 1
|
||||
filename = file_path.name
|
||||
processed_filenames.add(filename)
|
||||
|
||||
try:
|
||||
# Calculate hash first to enable hash-based lookup
|
||||
file_hash = get_file_hash(file_path)
|
||||
existing_sound_by_hash = sounds_by_hash.get(file_hash)
|
||||
existing_sound_by_filename = sounds_by_filename.get(filename)
|
||||
|
||||
# Create sync context
|
||||
sync_context = SyncContext(
|
||||
file_path=file_path,
|
||||
sound_type=sound_type,
|
||||
existing_sound_by_hash=existing_sound_by_hash,
|
||||
existing_sound_by_filename=existing_sound_by_filename,
|
||||
file_hash=file_hash,
|
||||
)
|
||||
|
||||
await self._sync_audio_file(sync_context, results)
|
||||
|
||||
# Check if this was a rename and mark old filename as processed
|
||||
if results["files"] and results["files"][-1].get("old_filename"):
|
||||
old_filename = results["files"][-1]["old_filename"]
|
||||
processed_filenames.add(old_filename)
|
||||
logger.debug("Marked old filename as processed: %s", old_filename)
|
||||
# Remove temporary tracking field from results
|
||||
del results["files"][-1]["old_filename"]
|
||||
except Exception as e:
|
||||
logger.exception("Error processing file %s", file_path)
|
||||
results["errors"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": None,
|
||||
"name": None,
|
||||
"duration": None,
|
||||
"size": None,
|
||||
"id": None,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
return processed_filenames
|
||||
|
||||
async def _delete_missing_sounds(
|
||||
self,
|
||||
sounds_by_filename: dict,
|
||||
processed_filenames: set[str],
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Delete sounds that no longer exist in directory."""
|
||||
for filename, sound_data in sounds_by_filename.items():
|
||||
if filename not in processed_filenames:
|
||||
# Attributes already captured in sound_data dictionary
|
||||
sound_name = sound_data["name"]
|
||||
sound_duration = sound_data["duration"]
|
||||
sound_size = sound_data["size"]
|
||||
sound_id = sound_data["id"]
|
||||
sound_object = sound_data["sound_object"]
|
||||
sound_type = sound_data["type"]
|
||||
sound_is_normalized = sound_data["is_normalized"]
|
||||
sound_normalized_filename = sound_data["normalized_filename"]
|
||||
|
||||
try:
|
||||
# Delete the sound from database first
|
||||
await self.sound_repo.delete(sound_object)
|
||||
logger.info("Deleted sound no longer in directory: %s", filename)
|
||||
|
||||
# If the sound had a normalized file, delete it too
|
||||
if sound_is_normalized and sound_normalized_filename:
|
||||
normalized_base = Path(sound_normalized_filename).name
|
||||
self._delete_normalized_file(sound_type, normalized_base)
|
||||
|
||||
results["deleted"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "deleted",
|
||||
"reason": "file no longer exists",
|
||||
"name": sound_name,
|
||||
"duration": sound_duration,
|
||||
"size": sound_size,
|
||||
"id": sound_id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting sound %s", filename)
|
||||
results["errors"] += 1
|
||||
results["files"].append({
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": "failed to delete",
|
||||
"name": sound_name,
|
||||
"duration": sound_duration,
|
||||
"size": sound_size,
|
||||
"id": sound_id,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
})
|
||||
|
||||
async def scan_directory(
|
||||
self,
|
||||
directory_path: str,
|
||||
@@ -138,368 +534,84 @@ class SoundScannerService:
|
||||
|
||||
logger.info("Starting sync of directory: %s", directory_path)
|
||||
|
||||
# Get all existing sounds of this type from database
|
||||
existing_sounds = await self.sound_repo.get_by_type(sound_type)
|
||||
# Load existing sounds from database
|
||||
sounds_by_hash, sounds_by_filename = await self._load_existing_sounds(
|
||||
sound_type,
|
||||
)
|
||||
|
||||
# Create lookup dictionaries with immediate attribute access
|
||||
# to avoid session detachment
|
||||
sounds_by_hash = {}
|
||||
sounds_by_filename = {}
|
||||
|
||||
for sound in existing_sounds:
|
||||
# Capture all attributes immediately while session is valid
|
||||
sound_data = {
|
||||
"id": sound.id,
|
||||
"hash": sound.hash,
|
||||
"filename": sound.filename,
|
||||
"name": sound.name,
|
||||
"duration": sound.duration,
|
||||
"size": sound.size,
|
||||
"type": sound.type,
|
||||
"is_normalized": sound.is_normalized,
|
||||
"normalized_filename": sound.normalized_filename,
|
||||
"sound_object": sound, # Keep reference for database operations
|
||||
}
|
||||
sounds_by_hash[sound.hash] = sound_data
|
||||
sounds_by_filename[sound.filename] = sound_data
|
||||
|
||||
# Get all audio files from directory
|
||||
audio_files = [
|
||||
f
|
||||
for f in scan_path.iterdir()
|
||||
if f.is_file() and f.suffix.lower() in self.supported_extensions
|
||||
]
|
||||
|
||||
# Process each file in directory
|
||||
processed_filenames = set()
|
||||
for file_path in audio_files:
|
||||
results["scanned"] += 1
|
||||
filename = file_path.name
|
||||
processed_filenames.add(filename)
|
||||
|
||||
try:
|
||||
# Calculate hash first to enable hash-based lookup
|
||||
file_hash = get_file_hash(file_path)
|
||||
existing_sound_by_hash = sounds_by_hash.get(file_hash)
|
||||
existing_sound_by_filename = sounds_by_filename.get(filename)
|
||||
|
||||
await self._sync_audio_file(
|
||||
file_path,
|
||||
sound_type,
|
||||
existing_sound_by_hash,
|
||||
existing_sound_by_filename,
|
||||
file_hash,
|
||||
results,
|
||||
)
|
||||
|
||||
# Check if this was a rename operation and mark old filename as processed
|
||||
if results["files"] and results["files"][-1].get("old_filename"):
|
||||
old_filename = results["files"][-1]["old_filename"]
|
||||
processed_filenames.add(old_filename)
|
||||
logger.debug("Marked old filename as processed: %s", old_filename)
|
||||
# Remove temporary tracking field from results
|
||||
del results["files"][-1]["old_filename"]
|
||||
except Exception as e:
|
||||
logger.exception("Error processing file %s", file_path)
|
||||
results["errors"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": None,
|
||||
"name": None,
|
||||
"duration": None,
|
||||
"size": None,
|
||||
"id": None,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
},
|
||||
)
|
||||
# Process audio files in directory
|
||||
processed_filenames = await self._process_audio_files(
|
||||
scan_path,
|
||||
sound_type,
|
||||
sounds_by_hash,
|
||||
sounds_by_filename,
|
||||
results,
|
||||
)
|
||||
|
||||
# Delete sounds that no longer exist in directory
|
||||
for filename, sound_data in sounds_by_filename.items():
|
||||
if filename not in processed_filenames:
|
||||
# Attributes already captured in sound_data dictionary
|
||||
sound_name = sound_data["name"]
|
||||
sound_duration = sound_data["duration"]
|
||||
sound_size = sound_data["size"]
|
||||
sound_id = sound_data["id"]
|
||||
sound_object = sound_data["sound_object"]
|
||||
sound_type = sound_data["type"]
|
||||
sound_is_normalized = sound_data["is_normalized"]
|
||||
sound_normalized_filename = sound_data["normalized_filename"]
|
||||
|
||||
try:
|
||||
# Delete the sound from database first
|
||||
await self.sound_repo.delete(sound_object)
|
||||
logger.info("Deleted sound no longer in directory: %s", filename)
|
||||
|
||||
# If the sound had a normalized file, delete it too
|
||||
if sound_is_normalized and sound_normalized_filename:
|
||||
normalized_base = Path(sound_normalized_filename).name
|
||||
self._delete_normalized_file(sound_type, normalized_base)
|
||||
|
||||
results["deleted"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "deleted",
|
||||
"reason": "file no longer exists",
|
||||
"name": sound_name,
|
||||
"duration": sound_duration,
|
||||
"size": sound_size,
|
||||
"id": sound_id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error deleting sound %s", filename)
|
||||
results["errors"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "error",
|
||||
"reason": "failed to delete",
|
||||
"name": sound_name,
|
||||
"duration": sound_duration,
|
||||
"size": sound_size,
|
||||
"id": sound_id,
|
||||
"error": str(e),
|
||||
"changes": None,
|
||||
},
|
||||
)
|
||||
await self._delete_missing_sounds(
|
||||
sounds_by_filename,
|
||||
processed_filenames,
|
||||
results,
|
||||
)
|
||||
|
||||
logger.info("Sync completed: %s", results)
|
||||
return results
|
||||
|
||||
async def _sync_audio_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
sound_type: str,
|
||||
existing_sound_by_hash: dict | Sound | None,
|
||||
existing_sound_by_filename: dict | Sound | None,
|
||||
file_hash: str,
|
||||
sync_context: SyncContext,
|
||||
results: ScanResults,
|
||||
) -> None:
|
||||
"""Sync a single audio file using hash-first identification strategy."""
|
||||
filename = file_path.name
|
||||
duration = get_audio_duration(file_path)
|
||||
size = get_file_size(file_path)
|
||||
filename = sync_context.file_path.name
|
||||
duration = get_audio_duration(sync_context.file_path)
|
||||
size = get_file_size(sync_context.file_path)
|
||||
name = self.extract_name_from_filename(filename)
|
||||
|
||||
# Extract attributes - handle both dict (normal) and Sound object (tests)
|
||||
existing_hash_filename = None
|
||||
existing_hash_name = None
|
||||
existing_hash_duration = None
|
||||
existing_hash_size = None
|
||||
existing_hash_id = None
|
||||
existing_hash_object = None
|
||||
existing_hash_type = None
|
||||
existing_hash_is_normalized = None
|
||||
existing_hash_normalized_filename = None
|
||||
# Create file info object
|
||||
file_info = AudioFileInfo(
|
||||
filename=filename,
|
||||
name=name,
|
||||
duration=duration,
|
||||
size=size,
|
||||
file_hash=sync_context.file_hash,
|
||||
)
|
||||
|
||||
if existing_sound_by_hash is not None:
|
||||
if isinstance(existing_sound_by_hash, dict):
|
||||
existing_hash_filename = existing_sound_by_hash["filename"]
|
||||
existing_hash_name = existing_sound_by_hash["name"]
|
||||
existing_hash_duration = existing_sound_by_hash["duration"]
|
||||
existing_hash_size = existing_sound_by_hash["size"]
|
||||
existing_hash_id = existing_sound_by_hash["id"]
|
||||
existing_hash_object = existing_sound_by_hash["sound_object"]
|
||||
existing_hash_type = existing_sound_by_hash["type"]
|
||||
existing_hash_is_normalized = existing_sound_by_hash["is_normalized"]
|
||||
existing_hash_normalized_filename = existing_sound_by_hash["normalized_filename"]
|
||||
else: # Sound object (for tests)
|
||||
existing_hash_filename = existing_sound_by_hash.filename
|
||||
existing_hash_name = existing_sound_by_hash.name
|
||||
existing_hash_duration = existing_sound_by_hash.duration
|
||||
existing_hash_size = existing_sound_by_hash.size
|
||||
existing_hash_id = existing_sound_by_hash.id
|
||||
existing_hash_object = existing_sound_by_hash
|
||||
existing_hash_type = existing_sound_by_hash.type
|
||||
existing_hash_is_normalized = existing_sound_by_hash.is_normalized
|
||||
existing_hash_normalized_filename = existing_sound_by_hash.normalized_filename
|
||||
|
||||
existing_filename_id = None
|
||||
existing_filename_object = None
|
||||
if existing_sound_by_filename is not None:
|
||||
if isinstance(existing_sound_by_filename, dict):
|
||||
existing_filename_id = existing_sound_by_filename["id"]
|
||||
existing_filename_object = existing_sound_by_filename["sound_object"]
|
||||
else: # Sound object (for tests)
|
||||
existing_filename_id = existing_sound_by_filename.id
|
||||
existing_filename_object = existing_sound_by_filename
|
||||
# Extract attributes from existing sounds
|
||||
hash_attrs = self._extract_sound_attributes(sync_context.existing_sound_by_hash)
|
||||
filename_attrs = self._extract_sound_attributes(
|
||||
sync_context.existing_sound_by_filename,
|
||||
)
|
||||
|
||||
# Hash-first identification strategy
|
||||
if existing_sound_by_hash is not None:
|
||||
if sync_context.existing_sound_by_hash is not None:
|
||||
# Content exists in database (same hash)
|
||||
if existing_hash_filename == filename:
|
||||
if hash_attrs["filename"] == filename:
|
||||
# Same hash, same filename - file unchanged
|
||||
logger.debug("Sound unchanged: %s", filename)
|
||||
results["skipped"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "skipped",
|
||||
"reason": "file unchanged",
|
||||
"name": existing_hash_name,
|
||||
"duration": existing_hash_duration,
|
||||
"size": existing_hash_size,
|
||||
"id": existing_hash_id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
},
|
||||
)
|
||||
self._handle_unchanged_file(filename, hash_attrs, results)
|
||||
else:
|
||||
# Same hash, different filename - could be rename or duplicate
|
||||
# Check if both files exist to determine if it's a duplicate
|
||||
old_file_path = file_path.parent / existing_hash_filename
|
||||
old_file_path = sync_context.file_path.parent / hash_attrs["filename"]
|
||||
if old_file_path.exists():
|
||||
# Both files exist with same hash - this is a duplicate
|
||||
logger.warning(
|
||||
"Duplicate file detected: '%s' has same content as existing '%s' (hash: %s). "
|
||||
"Skipping duplicate file.",
|
||||
self._handle_duplicate_file(
|
||||
filename,
|
||||
existing_hash_filename,
|
||||
file_hash[:8] + "...",
|
||||
)
|
||||
|
||||
results["skipped"] += 1
|
||||
results["duplicates"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "skipped",
|
||||
"reason": "duplicate content",
|
||||
"name": existing_hash_name,
|
||||
"duration": existing_hash_duration,
|
||||
"size": existing_hash_size,
|
||||
"id": existing_hash_id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
},
|
||||
hash_attrs["filename"],
|
||||
sync_context.file_hash,
|
||||
hash_attrs,
|
||||
results,
|
||||
)
|
||||
else:
|
||||
# Old file doesn't exist - this is a genuine rename
|
||||
update_data = {
|
||||
"filename": filename,
|
||||
"name": name,
|
||||
}
|
||||
await self._handle_file_rename(file_info, hash_attrs, results)
|
||||
|
||||
# If the sound has a normalized file, rename it too
|
||||
if existing_hash_is_normalized and existing_hash_normalized_filename:
|
||||
# Extract base filename without path for normalized file
|
||||
old_normalized_base = Path(existing_hash_normalized_filename).name
|
||||
new_normalized_base = Path(filename).stem + Path(existing_hash_normalized_filename).suffix
|
||||
|
||||
renamed = self._rename_normalized_file(
|
||||
existing_hash_type,
|
||||
old_normalized_base,
|
||||
new_normalized_base
|
||||
)
|
||||
|
||||
if renamed:
|
||||
update_data["normalized_filename"] = new_normalized_base
|
||||
logger.info(
|
||||
"Renamed normalized file: %s -> %s",
|
||||
old_normalized_base,
|
||||
new_normalized_base
|
||||
)
|
||||
|
||||
await self.sound_repo.update(existing_hash_object, update_data)
|
||||
logger.info(
|
||||
"Detected rename: %s -> %s (ID: %s)",
|
||||
existing_hash_filename,
|
||||
filename,
|
||||
existing_hash_id,
|
||||
)
|
||||
|
||||
# Build changes list
|
||||
changes = ["filename", "name"]
|
||||
if "normalized_filename" in update_data:
|
||||
changes.append("normalized_filename")
|
||||
|
||||
results["updated"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "updated",
|
||||
"reason": "file was renamed",
|
||||
"name": name,
|
||||
"duration": existing_hash_duration,
|
||||
"size": existing_hash_size,
|
||||
"id": existing_hash_id,
|
||||
"error": None,
|
||||
"changes": changes,
|
||||
# Store old filename to prevent deletion
|
||||
"old_filename": existing_hash_filename,
|
||||
},
|
||||
)
|
||||
|
||||
elif existing_sound_by_filename is not None:
|
||||
elif sync_context.existing_sound_by_filename is not None:
|
||||
# Same filename but different hash - file was modified
|
||||
update_data = {
|
||||
"name": name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
}
|
||||
|
||||
await self.sound_repo.update(existing_filename_object, update_data)
|
||||
logger.info(
|
||||
"Updated modified sound: %s (ID: %s)",
|
||||
name,
|
||||
existing_filename_id,
|
||||
)
|
||||
|
||||
results["updated"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "updated",
|
||||
"reason": "file was modified",
|
||||
"name": name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"id": existing_filename_id,
|
||||
"error": None,
|
||||
"changes": ["hash", "duration", "size", "name"],
|
||||
},
|
||||
)
|
||||
|
||||
await self._handle_file_modification(file_info, filename_attrs, results)
|
||||
else:
|
||||
# New file - neither hash nor filename exists
|
||||
sound_data = {
|
||||
"type": sound_type,
|
||||
"name": name,
|
||||
"filename": filename,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"hash": file_hash,
|
||||
"is_deletable": False,
|
||||
"is_music": False,
|
||||
"is_normalized": False,
|
||||
"play_count": 0,
|
||||
}
|
||||
|
||||
sound = await self.sound_repo.create(sound_data)
|
||||
logger.info("Added new sound: %s (ID: %s)", sound.name, sound.id)
|
||||
|
||||
results["added"] += 1
|
||||
results["files"].append(
|
||||
{
|
||||
"filename": filename,
|
||||
"status": "added",
|
||||
"reason": None,
|
||||
"name": name,
|
||||
"duration": duration,
|
||||
"size": size,
|
||||
"id": sound.id,
|
||||
"error": None,
|
||||
"changes": None,
|
||||
},
|
||||
)
|
||||
await self._handle_new_file(file_info, sync_context.sound_type, results)
|
||||
|
||||
async def scan_soundboard_directory(self) -> ScanResults:
|
||||
"""Sync the default soundboard directory."""
|
||||
|
||||
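The refactored _sync_audio_file above now reads its inputs from a single sync_context argument, and the scanner tests later in this commit import SyncContext from app.services.sound_scanner and build it with keyword fields. A minimal sketch of what that dataclass presumably looks like, inferred only from how it is constructed and read in this commit; the field types are assumptions:

from dataclasses import dataclass
from pathlib import Path

from app.models.sound import Sound


@dataclass
class SyncContext:
    """Inputs for syncing one scanned audio file (sketch inferred from test usage)."""

    file_path: Path  # path of the file being scanned
    sound_type: str  # e.g. "SDB" or "CUSTOM"
    existing_sound_by_hash: Sound | None  # sound matched by content hash, if any
    existing_sound_by_filename: Sound | None  # sound matched by filename, if any
    file_hash: str  # hash computed for the scanned file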
154 tests/api/v1/admin/test_extraction_endpoints.py Normal file
@@ -0,0 +1,154 @@
"""Tests for admin extraction API endpoints."""

import pytest
from httpx import AsyncClient

from app.models.extraction import Extraction
from app.models.user import User


class TestAdminExtractionEndpoints:
"""Test admin extraction endpoints."""

@pytest.mark.asyncio
async def test_get_extraction_processor_status(self, authenticated_admin_client):
"""Test getting extraction processor status."""
response = await authenticated_admin_client.get(
"/api/v1/admin/extractions/status",
)

assert response.status_code == 200
data = response.json()

# Check expected status fields (match actual processor status format)
assert "currently_processing" in data
assert "max_concurrent" in data
assert "available_slots" in data
assert "processing_ids" in data
assert isinstance(data["currently_processing"], int)
assert isinstance(data["max_concurrent"], int)
assert isinstance(data["available_slots"], int)
assert isinstance(data["processing_ids"], list)

@pytest.mark.asyncio
async def test_admin_delete_extraction_success(
self,
authenticated_admin_client,
test_session,
test_plan,
):
"""Test admin successfully deleting any extraction."""
# Create a test user
user = User(
name="Test User",
email="test@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(user)
await test_session.commit()
await test_session.refresh(user)

# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)

# Admin delete the extraction
response = await authenticated_admin_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)

assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"

# Verify extraction was deleted from database
deleted_extraction = await test_session.get(Extraction, extraction.id)
assert deleted_extraction is None

@pytest.mark.asyncio
async def test_admin_delete_extraction_not_found(self, authenticated_admin_client):
"""Test admin deleting non-existent extraction."""
response = await authenticated_admin_client.delete(
"/api/v1/admin/extractions/999",
)

assert response.status_code == 404
data = response.json()
assert "not found" in data["detail"].lower()

@pytest.mark.asyncio
async def test_admin_delete_extraction_any_user(
self,
authenticated_admin_client,
test_session,
test_plan,
):
"""Test admin deleting extraction owned by any user."""
# Create another user and their extraction
other_user = User(
name="Other User",
email="other@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(other_user)
await test_session.commit()
await test_session.refresh(other_user)

extraction = Extraction(
url="https://example.com/video",
user_id=other_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)

# Admin can delete any user's extraction
response = await authenticated_admin_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)

assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"

@pytest.mark.asyncio
async def test_delete_extraction_non_admin(self, authenticated_client, test_user, test_session):
"""Test non-admin user cannot access admin deletion endpoint."""
# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=test_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)

# Non-admin user cannot access admin endpoint
response = await authenticated_client.delete(
f"/api/v1/admin/extractions/{extraction.id}",
)

assert response.status_code == 403
data = response.json()
assert "permissions" in data["detail"].lower()

@pytest.mark.asyncio
async def test_admin_endpoints_unauthenticated(self, client: AsyncClient):
"""Test admin endpoints require authentication."""
# Status endpoint
response = await client.get("/api/v1/admin/extractions/status")
assert response.status_code == 401

# Delete endpoint
response = await client.delete("/api/v1/admin/extractions/1")
assert response.status_code == 401
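The status test above only pins down field names and value types, so the payload returned by extraction_processor.get_status() is expected to look roughly like the dict below; the concrete values are illustrative, not taken from the implementation:

# Illustrative shape of the processor status asserted above; values are made up.
example_status = {
    "currently_processing": 1,  # int: extractions currently being worked on
    "max_concurrent": 3,        # int: processing slot limit
    "available_slots": 2,       # int: presumably max_concurrent minus currently_processing
    "processing_ids": [42],     # list: IDs of in-flight extractions
}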
@@ -31,6 +31,7 @@ class TestAdminSoundEndpoints:
"deleted": 1,
"skipped": 0,
"errors": 0,
"duplicates": 0,
"files": [
{
"filename": "test1.mp3",
@@ -176,6 +177,7 @@ class TestAdminSoundEndpoints:
"deleted": 0,
"skipped": 0,
"errors": 0,
"duplicates": 0,
"files": [
{
"filename": "custom1.wav",
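Both hunks above extend the expected scan responses with a duplicates counter alongside the existing counts. A hedged example of the overall result shape these endpoint tests now assert; the values and the truncated file entry are illustrative only:

# Illustrative scan response with the newly asserted "duplicates" counter.
example_scan_result = {
    "added": 0,
    "updated": 0,
    "deleted": 1,
    "skipped": 0,
    "errors": 0,
    "duplicates": 0,  # new field checked by the updated tests
    "files": [
        {"filename": "test1.mp3", "status": "deleted"},  # truncated; real entries carry more keys
    ],
}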
@@ -229,3 +229,73 @@ class TestExtractionEndpoints:
break

assert processing_found, "Processing extraction not found in results"

@pytest.mark.asyncio
async def test_delete_extraction_success(self, authenticated_client, test_user, test_session):
"""Test successful deletion of user's own extraction."""
# Create test extraction
extraction = Extraction(
url="https://example.com/video",
user_id=test_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)

# Delete the extraction
response = await authenticated_client.delete(f"/api/v1/extractions/{extraction.id}")

assert response.status_code == 200
data = response.json()
assert data["message"] == f"Extraction {extraction.id} deleted successfully"

# Verify extraction was deleted from database
deleted_extraction = await test_session.get(Extraction, extraction.id)
assert deleted_extraction is None

@pytest.mark.asyncio
async def test_delete_extraction_not_found(self, authenticated_client):
"""Test deleting non-existent extraction."""
response = await authenticated_client.delete("/api/v1/extractions/999")

assert response.status_code == 404
data = response.json()
assert "not found" in data["detail"].lower()

@pytest.mark.asyncio
async def test_delete_extraction_permission_denied(self, authenticated_client, test_session, test_plan):
"""Test deleting another user's extraction."""
# Create extraction owned by different user
other_user = User(
name="Other User",
email="other@example.com",
is_active=True,
plan_id=test_plan.id,
)
test_session.add(other_user)
await test_session.commit()
await test_session.refresh(other_user)

extraction = Extraction(
url="https://example.com/video",
user_id=other_user.id,
status="completed",
)
test_session.add(extraction)
await test_session.commit()
await test_session.refresh(extraction)

# Try to delete other user's extraction
response = await authenticated_client.delete(f"/api/v1/extractions/{extraction.id}")

assert response.status_code == 403
data = response.json()
assert "permission" in data["detail"].lower()

@pytest.mark.asyncio
async def test_delete_extraction_unauthenticated(self, client):
"""Test deleting extraction without authentication."""
response = await client.delete("/api/v1/extractions/1")

assert response.status_code == 401
@@ -1,5 +1,7 @@
"""Tests for favorite API endpoints."""

from contextlib import suppress

import pytest
import pytest_asyncio
from httpx import AsyncClient
@@ -129,10 +131,8 @@ class TestFavoriteEndpoints:
) -> None:
"""Test successfully adding a sound to favorites."""
# Clean up any existing favorite first
try:
with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/sounds/1")
except:
pass # It's ok if it doesn't exist

response = await authenticated_client.post("/api/v1/favorites/sounds/1")

@@ -176,10 +176,8 @@ class TestFavoriteEndpoints:
) -> None:
"""Test successfully adding a playlist to favorites."""
# Clean up any existing favorite first
try:
with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/playlists/1")
except:
pass # It's ok if it doesn't exist

response = await authenticated_client.post("/api/v1/favorites/playlists/1")

@@ -473,10 +471,8 @@ class TestFavoriteEndpoints:
) -> None:
"""Test checking if a sound is favorited (false case)."""
# Make sure sound 1 is not favorited
try:
with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/sounds/1")
except:
pass # It's ok if it doesn't exist

response = await authenticated_client.get("/api/v1/favorites/sounds/1/check")

@@ -509,10 +505,8 @@ class TestFavoriteEndpoints:
) -> None:
"""Test checking if a playlist is favorited (false case)."""
# Make sure playlist 1 is not favorited
try:
with suppress(Exception):
await authenticated_client.delete("/api/v1/favorites/playlists/1")
except:
pass # It's ok if it doesn't exist

response = await authenticated_client.get("/api/v1/favorites/playlists/1/check")
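The cleanup calls in these fixtures swap a bare try/except/pass for contextlib.suppress, the idiomatic one-liner for best-effort calls whose failures should be ignored. A small standalone illustration of the equivalence; the cleanup function here is hypothetical:

from contextlib import suppress


def cleanup() -> None:
    """Hypothetical best-effort cleanup that may raise."""
    raise RuntimeError("nothing to clean up")


# Before: swallow any failure from the best-effort call.
try:
    cleanup()
except Exception:
    pass

# After: same effect for these cleanup calls, without a bare except.
with suppress(Exception):
    cleanup()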
@@ -541,3 +541,143 @@ class TestExtractionService:
assert result[0]["id"] == 1
assert result[0]["status"] == "pending"
assert result[0]["user_name"] == "Test User"

@pytest.mark.asyncio
async def test_delete_extraction_with_sound(self, extraction_service, test_user):
"""Test deleting extraction with associated sound and files."""
import tempfile
from pathlib import Path

# Create temporary directories for testing
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)

# Set up temporary directories
original_dir = temp_dir_path / "originals" / "extracted"
normalized_dir = temp_dir_path / "normalized" / "extracted"
thumbnail_dir = temp_dir_path / "thumbnails"

original_dir.mkdir(parents=True)
normalized_dir.mkdir(parents=True)
thumbnail_dir.mkdir(parents=True)

# Create test files
audio_file = original_dir / "test_audio.mp3"
normalized_file = normalized_dir / "test_audio.mp3"
thumbnail_file = thumbnail_dir / "test_thumb.jpg"

audio_file.write_text("audio content")
normalized_file.write_text("normalized content")
thumbnail_file.write_text("thumbnail content")

# Create extraction and sound records
extraction = Extraction(
id=1,
url="https://example.com/video",
user_id=test_user.id,
status="completed",
sound_id=1,
)

sound = Sound(
id=1,
type="EXT",
name="Test Audio",
filename="test_audio.mp3",
duration=60000,
size=2048,
hash="test_hash",
is_normalized=True,
normalized_filename="test_audio.mp3",
thumbnail="test_thumb.jpg",
is_deletable=True,
is_music=True,
)

# Mock repository methods
extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction)
extraction_service.sound_repo.get_by_id = AsyncMock(return_value=sound)
extraction_service.extraction_repo.delete = AsyncMock()
extraction_service.sound_repo.delete = AsyncMock()
extraction_service.session.commit = AsyncMock()
extraction_service.session.rollback = AsyncMock()

# Monkey patch the paths in the service method
import app.services.extraction
original_path_class = app.services.extraction.Path

def mock_path(*args: str):
path_str = str(args[0])
if path_str == "sounds/originals/extracted":
return original_dir
if path_str == "sounds/normalized/extracted":
return normalized_dir
if path_str.endswith("thumbnails"):
return thumbnail_dir
return original_path_class(*args)

# Mock the Path constructor and settings
with patch("app.services.extraction.Path", side_effect=mock_path), \
patch("app.services.extraction.settings") as mock_settings:
mock_settings.EXTRACTION_THUMBNAILS_DIR = str(thumbnail_dir)

# Test deletion
result = await extraction_service.delete_extraction(1, test_user.id)

assert result is True

# Verify repository calls
extraction_service.extraction_repo.get_by_id.assert_called_once_with(1)
extraction_service.sound_repo.get_by_id.assert_called_once_with(1)
extraction_service.extraction_repo.delete.assert_called_once_with(extraction)
extraction_service.sound_repo.delete.assert_called_once_with(sound)
extraction_service.session.commit.assert_called_once()

# Verify files were deleted
assert not audio_file.exists()
assert not normalized_file.exists()
assert not thumbnail_file.exists()

@pytest.mark.asyncio
async def test_delete_extraction_not_found(self, extraction_service, test_user):
"""Test deleting non-existent extraction."""
extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=None)

result = await extraction_service.delete_extraction(999, test_user.id)

assert result is False

@pytest.mark.asyncio
async def test_delete_extraction_permission_denied(self, extraction_service, test_user):
"""Test deleting extraction owned by another user."""
extraction = Extraction(
id=1,
url="https://example.com/video",
user_id=999, # Different user ID
status="completed",
)

extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction)

with pytest.raises(ValueError, match="You don't have permission"):
await extraction_service.delete_extraction(1, test_user.id)

@pytest.mark.asyncio
async def test_delete_extraction_admin(self, extraction_service, test_user):
"""Test admin deleting any extraction."""
extraction = Extraction(
id=1,
url="https://example.com/video",
user_id=999, # Different user ID
status="completed",
)

extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=extraction)
extraction_service.extraction_repo.delete = AsyncMock()
extraction_service.session.commit = AsyncMock()

# Admin deletion (user_id=None)
result = await extraction_service.delete_extraction(1, None)

assert result is True
extraction_service.extraction_repo.delete.assert_called_once_with(extraction)
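Taken together, these service tests pin down the contract of ExtractionService.delete_extraction(extraction_id, user_id): True on success, False when the extraction does not exist, a ValueError when a non-owner tries to delete, and user_id=None acting as the admin bypass. A hedged sketch of the ownership rules those expectations imply, not the actual implementation; the error message only needs to contain the matched substring:

async def delete_extraction_contract(extraction, requester_id: int | None) -> bool:
    """Sketch of the ownership rules the tests above assert.

    extraction is the record looked up by ID (None if missing);
    requester_id is the calling user's ID, or None for admin callers.
    """
    if extraction is None:
        return False  # test_delete_extraction_not_found expects False
    if requester_id is not None and extraction.user_id != requester_id:
        # test_delete_extraction_permission_denied matches on this prefix
        raise ValueError("You don't have permission to delete this extraction")
    # ...delete the associated sound, files and the extraction row, then commit...
    return True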
@@ -1,6 +1,7 @@
"""Tests for favorite service."""

from collections.abc import AsyncGenerator
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
@@ -14,6 +15,31 @@ from app.models.user import User
from app.services.favorite import FavoriteService


@dataclass
class MockRepositories:
"""Container for all mock repositories."""

user_repo: AsyncMock
favorite_repo: AsyncMock
sound_repo: AsyncMock
socket_manager: AsyncMock


@dataclass
class MockServiceDependencies:
"""Container for all mock service dependencies."""

sound_repo_class: AsyncMock
user_repo_class: AsyncMock
favorite_repo_class: AsyncMock
socket_manager: AsyncMock
sound_repo: AsyncMock
user_repo: AsyncMock
favorite_repo: AsyncMock
playlist_repo_class: AsyncMock | None = None
playlist_repo: AsyncMock | None = None


class TestFavoriteService:
"""Test favorite service operations."""

@@ -71,34 +97,75 @@ class TestFavoriteService:
"playlist_repo": AsyncMock(),
}

@patch("app.services.favorite.socket_manager")
@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.UserRepository")
@patch("app.services.favorite.SoundRepository")
@pytest_asyncio.fixture
async def mock_sound_favorite_dependencies(self) -> MockServiceDependencies:
"""Create mock dependencies for sound favorite operations."""
with (
patch("app.services.favorite.SoundRepository") as mock_sound_repo_class,
patch("app.services.favorite.UserRepository") as mock_user_repo_class,
patch("app.services.favorite.FavoriteRepository") as mock_favorite_repo_class,
patch("app.services.favorite.socket_manager") as mock_socket_manager,
):
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
mock_favorite_repo = AsyncMock()

mock_sound_repo_class.return_value = mock_sound_repo
mock_user_repo_class.return_value = mock_user_repo
mock_favorite_repo_class.return_value = mock_favorite_repo

yield MockServiceDependencies(
sound_repo_class=mock_sound_repo_class,
user_repo_class=mock_user_repo_class,
favorite_repo_class=mock_favorite_repo_class,
socket_manager=mock_socket_manager,
sound_repo=mock_sound_repo,
user_repo=mock_user_repo,
favorite_repo=mock_favorite_repo,
)

@pytest_asyncio.fixture
async def mock_playlist_favorite_dependencies(self) -> MockServiceDependencies:
"""Create mock dependencies for playlist favorite operations."""
with (
patch("app.services.favorite.UserRepository") as mock_user_repo_class,
patch("app.services.favorite.PlaylistRepository") as mock_playlist_repo_class,
patch("app.services.favorite.FavoriteRepository") as mock_favorite_repo_class,
):
mock_user_repo = AsyncMock()
mock_playlist_repo = AsyncMock()
mock_favorite_repo = AsyncMock()

mock_user_repo_class.return_value = mock_user_repo
mock_playlist_repo_class.return_value = mock_playlist_repo
mock_favorite_repo_class.return_value = mock_favorite_repo

yield MockServiceDependencies(
sound_repo_class=AsyncMock(), # Not used in playlist tests
user_repo_class=mock_user_repo_class,
favorite_repo_class=mock_favorite_repo_class,
socket_manager=AsyncMock(), # Not used in playlist tests
sound_repo=AsyncMock(), # Not used in playlist tests
user_repo=mock_user_repo,
favorite_repo=mock_favorite_repo,
playlist_repo_class=mock_playlist_repo_class,
playlist_repo=mock_playlist_repo,
)

@pytest.mark.asyncio
async def test_add_sound_favorite_success(
self,
mock_sound_repo_class: AsyncMock,
mock_user_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_socket_manager: AsyncMock,
mock_sound_favorite_dependencies: MockServiceDependencies,
favorite_service: FavoriteService,
test_user: User,
test_sound: Sound,
) -> None:
"""Test successfully adding a sound favorite."""
# Setup mocks
mock_favorite_repo = AsyncMock()
mock_user_repo = AsyncMock()
mock_sound_repo = AsyncMock()

mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo

mock_user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound
mock_favorite_repo.get_by_user_and_sound.return_value = None
mocks = mock_sound_favorite_dependencies
mocks.user_repo.get_by_id.return_value = test_user
mocks.sound_repo.get_by_id.return_value = test_sound
mocks.favorite_repo.get_by_user_and_sound.return_value = None

expected_favorite = Favorite(
id=1,
@@ -106,23 +173,23 @@ class TestFavoriteService:
sound_id=test_sound.id,
playlist_id=None,
)
mock_favorite_repo.create.return_value = expected_favorite
mock_favorite_repo.count_sound_favorites.return_value = 1
mocks.favorite_repo.create.return_value = expected_favorite
mocks.favorite_repo.count_sound_favorites.return_value = 1

# Execute
result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id)

# Verify
assert result == expected_favorite
mock_user_repo.get_by_id.assert_called_once_with(test_user.id)
mock_sound_repo.get_by_id.assert_called_once_with(test_sound.id)
mock_favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id)
mock_favorite_repo.create.assert_called_once_with({
mocks.user_repo.get_by_id.assert_called_once_with(test_user.id)
mocks.sound_repo.get_by_id.assert_called_once_with(test_sound.id)
mocks.favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id)
mocks.favorite_repo.create.assert_called_once_with({
"user_id": test_user.id,
"sound_id": test_sound.id,
"playlist_id": None,
})
mock_socket_manager.broadcast_to_all.assert_called_once()
mocks.socket_manager.broadcast_to_all.assert_called_once()

@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio
@@ -161,62 +228,38 @@ class TestFavoriteService:
with pytest.raises(ValueError, match="Sound with ID 1 not found"):
await favorite_service.add_sound_favorite(test_user.id, 1)

@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.SoundRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio
async def test_add_sound_favorite_already_exists(
self,
mock_user_repo_class: AsyncMock,
mock_sound_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_sound_favorite_dependencies: MockServiceDependencies,
favorite_service: FavoriteService,
test_user: User,
test_sound: Sound,
) -> None:
"""Test adding sound favorite that already exists."""
mock_user_repo = AsyncMock()
mock_sound_repo = AsyncMock()
mock_favorite_repo = AsyncMock()

mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo
mock_favorite_repo_class.return_value = mock_favorite_repo

mock_user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound
mocks = mock_sound_favorite_dependencies
mocks.user_repo.get_by_id.return_value = test_user
mocks.sound_repo.get_by_id.return_value = test_sound
existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id)
mock_favorite_repo.get_by_user_and_sound.return_value = existing_favorite
mocks.favorite_repo.get_by_user_and_sound.return_value = existing_favorite

with pytest.raises(ValueError, match="already favorited"):
await favorite_service.add_sound_favorite(test_user.id, test_sound.id)

@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.PlaylistRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio
async def test_add_playlist_favorite_success(
self,
mock_user_repo_class: AsyncMock,
mock_playlist_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_playlist_favorite_dependencies: MockServiceDependencies,
favorite_service: FavoriteService,
test_user: User,
test_playlist: Playlist,
) -> None:
"""Test successfully adding a playlist favorite."""
# Setup mocks
mock_favorite_repo = AsyncMock()
mock_user_repo = AsyncMock()
mock_playlist_repo = AsyncMock()

mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_playlist_repo_class.return_value = mock_playlist_repo

mock_user_repo.get_by_id.return_value = test_user
mock_playlist_repo.get_by_id.return_value = test_playlist
mock_favorite_repo.get_by_user_and_playlist.return_value = None
mocks = mock_playlist_favorite_dependencies
mocks.user_repo.get_by_id.return_value = test_user
mocks.playlist_repo.get_by_id.return_value = test_playlist
mocks.favorite_repo.get_by_user_and_playlist.return_value = None

expected_favorite = Favorite(
id=1,
@@ -224,59 +267,45 @@ class TestFavoriteService:
sound_id=None,
playlist_id=test_playlist.id,
)
mock_favorite_repo.create.return_value = expected_favorite
mocks.favorite_repo.create.return_value = expected_favorite

# Execute
result = await favorite_service.add_playlist_favorite(test_user.id, test_playlist.id)

# Verify
assert result == expected_favorite
mock_user_repo.get_by_id.assert_called_once_with(test_user.id)
mock_playlist_repo.get_by_id.assert_called_once_with(test_playlist.id)
mock_favorite_repo.get_by_user_and_playlist.assert_called_once_with(test_user.id, test_playlist.id)
mock_favorite_repo.create.assert_called_once_with({
mocks.user_repo.get_by_id.assert_called_once_with(test_user.id)
mocks.playlist_repo.get_by_id.assert_called_once_with(test_playlist.id)
mocks.favorite_repo.get_by_user_and_playlist.assert_called_once_with(test_user.id, test_playlist.id)
mocks.favorite_repo.create.assert_called_once_with({
"user_id": test_user.id,
"sound_id": None,
"playlist_id": test_playlist.id,
})

@patch("app.services.favorite.socket_manager")
@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.SoundRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio
async def test_remove_sound_favorite_success(
self,
mock_user_repo_class: AsyncMock,
mock_sound_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_socket_manager: AsyncMock,
mock_sound_favorite_dependencies: MockServiceDependencies,
favorite_service: FavoriteService,
test_user: User,
test_sound: Sound,
) -> None:
"""Test successfully removing a sound favorite."""
mock_favorite_repo = AsyncMock()
mock_user_repo = AsyncMock()
mock_sound_repo = AsyncMock()

mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo

mocks = mock_sound_favorite_dependencies
existing_favorite = Favorite(user_id=test_user.id, sound_id=test_sound.id)
mock_favorite_repo.get_by_user_and_sound.return_value = existing_favorite
mock_user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound
mock_favorite_repo.count_sound_favorites.return_value = 0
mocks.favorite_repo.get_by_user_and_sound.return_value = existing_favorite
mocks.user_repo.get_by_id.return_value = test_user
mocks.sound_repo.get_by_id.return_value = test_sound
mocks.favorite_repo.count_sound_favorites.return_value = 0

# Execute
await favorite_service.remove_sound_favorite(test_user.id, test_sound.id)

# Verify
mock_favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id)
mock_favorite_repo.delete.assert_called_once_with(existing_favorite)
mock_socket_manager.broadcast_to_all.assert_called_once()
mocks.favorite_repo.get_by_user_and_sound.assert_called_once_with(test_user.id, test_sound.id)
mocks.favorite_repo.delete.assert_called_once_with(existing_favorite)
mocks.socket_manager.broadcast_to_all.assert_called_once()

@patch("app.services.favorite.FavoriteRepository")
@pytest.mark.asyncio
@@ -503,46 +532,31 @@ class TestFavoriteService:
assert result == 3
mock_favorite_repo.count_playlist_favorites.assert_called_once_with(1)

@patch("app.services.favorite.socket_manager")
@patch("app.services.favorite.FavoriteRepository")
@patch("app.services.favorite.SoundRepository")
@patch("app.services.favorite.UserRepository")
@pytest.mark.asyncio
async def test_socket_broadcast_error_handling(
self,
mock_user_repo_class: AsyncMock,
mock_sound_repo_class: AsyncMock,
mock_favorite_repo_class: AsyncMock,
mock_socket_manager: AsyncMock,
mock_sound_favorite_dependencies: MockServiceDependencies,
favorite_service: FavoriteService,
test_user: User,
test_sound: Sound,
) -> None:
"""Test that socket broadcast errors don't affect the operation."""
# Setup mocks
mock_favorite_repo = AsyncMock()
mock_user_repo = AsyncMock()
mock_sound_repo = AsyncMock()

mock_favorite_repo_class.return_value = mock_favorite_repo
mock_user_repo_class.return_value = mock_user_repo
mock_sound_repo_class.return_value = mock_sound_repo

mock_user_repo.get_by_id.return_value = test_user
mock_sound_repo.get_by_id.return_value = test_sound
mock_favorite_repo.get_by_user_and_sound.return_value = None
mocks = mock_sound_favorite_dependencies
mocks.user_repo.get_by_id.return_value = test_user
mocks.sound_repo.get_by_id.return_value = test_sound
mocks.favorite_repo.get_by_user_and_sound.return_value = None

expected_favorite = Favorite(id=1, user_id=test_user.id, sound_id=test_sound.id)
mock_favorite_repo.create.return_value = expected_favorite
mock_favorite_repo.count_sound_favorites.return_value = 1
mocks.favorite_repo.create.return_value = expected_favorite
mocks.favorite_repo.count_sound_favorites.return_value = 1

# Make socket broadcast raise an exception
mock_socket_manager.broadcast_to_all.side_effect = Exception("Socket error")
mocks.socket_manager.broadcast_to_all.side_effect = Exception("Socket error")

# Execute - should not raise exception despite socket error
result = await favorite_service.add_sound_favorite(test_user.id, test_sound.id)

# Verify operation still succeeded
assert result == expected_favorite
mock_favorite_repo.create.assert_called_once()
mocks.favorite_repo.create.assert_called_once()
@@ -8,7 +8,7 @@ import pytest
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.sound import Sound
from app.services.sound_scanner import SoundScannerService
from app.services.sound_scanner import SoundScannerService, SyncContext


class TestSoundScannerService:
@@ -155,14 +155,14 @@ class TestSoundScannerService:
# Set the existing sound filename to match temp file for "unchanged" test
existing_sound.filename = temp_path.name

await scanner_service._sync_audio_file(
temp_path,
"SDB",
existing_sound, # existing_sound_by_hash (same hash)
None, # existing_sound_by_filename (no conflict)
"same_hash",
results,
sync_context = SyncContext(
file_path=temp_path,
sound_type="SDB",
existing_sound_by_hash=existing_sound,
existing_sound_by_filename=None,
file_hash="same_hash",
)
await scanner_service._sync_audio_file(sync_context, results)

assert results["skipped"] == 1
assert results["added"] == 0
@@ -210,14 +210,14 @@ class TestSoundScannerService:
"files": [],
}

await scanner_service._sync_audio_file(
temp_path,
"SDB",
existing_sound, # existing_sound_by_hash (same hash)
None, # existing_sound_by_filename (different filename)
"same_hash",
results,
sync_context = SyncContext(
file_path=temp_path,
sound_type="SDB",
existing_sound_by_hash=existing_sound,
existing_sound_by_filename=None,
file_hash="same_hash",
)
await scanner_service._sync_audio_file(sync_context, results)

# Should be marked as updated (renamed)
assert results["updated"] == 1
@@ -257,12 +257,11 @@ class TestSoundScannerService:

# Create temporary directory with renamed file
import tempfile
import os

with tempfile.TemporaryDirectory() as temp_dir:
# Create the "renamed" file (same hash, different name)
new_file_path = os.path.join(temp_dir, "new_name.mp3")
with open(new_file_path, "wb") as f:
new_file_path = Path(temp_dir) / "new_name.mp3"
with new_file_path.open("wb") as f:
f.write(b"test audio content") # This will produce consistent hash

# Mock file operations to return same hash
@@ -308,16 +307,15 @@ class TestSoundScannerService:

# Create temporary directory with both original and duplicate files
import tempfile
import os

with tempfile.TemporaryDirectory() as temp_dir:
# Create both files (simulating duplicate content)
original_path = os.path.join(temp_dir, "original.mp3")
duplicate_path = os.path.join(temp_dir, "duplicate.mp3")
original_path = Path(temp_dir) / "original.mp3"
duplicate_path = Path(temp_dir) / "duplicate.mp3"

with open(original_path, "wb") as f:
with original_path.open("wb") as f:
f.write(b"test audio content")
with open(duplicate_path, "wb") as f:
with duplicate_path.open("wb") as f:
f.write(b"test audio content") # Same content = same hash

# Mock file operations
@@ -375,14 +373,14 @@ class TestSoundScannerService:
"errors": 0,
"files": [],
}
await scanner_service._sync_audio_file(
temp_path,
"SDB",
None, # existing_sound_by_hash
None, # existing_sound_by_filename
"test_hash",
results,
sync_context = SyncContext(
file_path=temp_path,
sound_type="SDB",
existing_sound_by_hash=None,
existing_sound_by_filename=None,
file_hash="test_hash",
)
await scanner_service._sync_audio_file(sync_context, results)

assert results["added"] == 1
assert results["skipped"] == 0
@@ -439,14 +437,14 @@ class TestSoundScannerService:
"errors": 0,
"files": [],
}
await scanner_service._sync_audio_file(
temp_path,
"SDB",
None, # existing_sound_by_hash (different hash)
existing_sound, # existing_sound_by_filename
"new_hash",
results,
sync_context = SyncContext(
file_path=temp_path,
sound_type="SDB",
existing_sound_by_hash=None,
existing_sound_by_filename=existing_sound,
file_hash="new_hash",
)
await scanner_service._sync_audio_file(sync_context, results)

assert results["updated"] == 1
assert results["added"] == 0
@@ -504,14 +502,14 @@ class TestSoundScannerService:
"errors": 0,
"files": [],
}
await scanner_service._sync_audio_file(
temp_path,
"CUSTOM",
None, # existing_sound_by_hash
None, # existing_sound_by_filename
"custom_hash",
results,
sync_context = SyncContext(
file_path=temp_path,
sound_type="CUSTOM",
existing_sound_by_hash=None,
existing_sound_by_filename=None,
file_hash="custom_hash",
)
await scanner_service._sync_audio_file(sync_context, results)

assert results["added"] == 1
assert results["skipped"] == 0
@@ -533,19 +531,19 @@ class TestSoundScannerService:

@pytest.mark.asyncio
async def test_sync_audio_file_rename_with_normalized_file(
self, test_session, scanner_service
self, test_session, scanner_service,
):
"""Test that renaming a sound file also renames its normalized file."""
# Create temporary directories for testing
from pathlib import Path
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)

# Set up the scanner's normalized directories to use temp dir
scanner_service.normalized_directories = {
"SDB": str(temp_dir_path / "normalized" / "soundboard")
"SDB": str(temp_dir_path / "normalized" / "soundboard"),
}

# Create the normalized directory
@@ -557,7 +555,6 @@ class TestSoundScannerService:
old_normalized_file.write_text("normalized audio content")

# Create the audio files (they need to exist for the scanner)
old_path = temp_dir_path / "old_sound.mp3"
new_path = temp_dir_path / "new_sound.mp3"

# Create a dummy audio file for the new path
@@ -565,8 +562,8 @@ class TestSoundScannerService:

# Mock the audio utility functions since we're using fake files
from unittest.mock import patch
with patch('app.services.sound_scanner.get_audio_duration', return_value=60000), \
patch('app.services.sound_scanner.get_file_size', return_value=2048):
with patch("app.services.sound_scanner.get_audio_duration", return_value=60000), \
patch("app.services.sound_scanner.get_file_size", return_value=2048):

# Create existing sound with normalized file info
existing_sound = Sound(
@@ -584,7 +581,7 @@ class TestSoundScannerService:
normalized_hash="normalized_hash",
play_count=5,
is_deletable=False,
is_music=False
is_music=False,
)

results = {
@@ -602,14 +599,14 @@ class TestSoundScannerService:
scanner_service.sound_repo.update = AsyncMock()

# Simulate rename detection by calling _sync_audio_file
await scanner_service._sync_audio_file(
new_path,
"SDB",
existing_sound, # existing_sound_by_hash (same hash, different filename)
None, # existing_sound_by_filename (no file with new name exists)
"test_hash",
results,
sync_context = SyncContext(
file_path=new_path,
sound_type="SDB",
existing_sound_by_hash=existing_sound,
existing_sound_by_filename=None,
file_hash="test_hash",
)
await scanner_service._sync_audio_file(sync_context, results)

# Verify the results
assert results["updated"] == 1
@@ -635,12 +632,12 @@ class TestSoundScannerService:

@pytest.mark.asyncio
async def test_scan_directory_delete_with_normalized_file(
self, test_session, scanner_service
self, test_session, scanner_service,
):
"""Test that deleting a sound also deletes its normalized file."""
# Create temporary directories for testing
from pathlib import Path
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
@@ -649,7 +646,7 @@ class TestSoundScannerService:

# Set up the scanner's normalized directories to use temp dir
scanner_service.normalized_directories = {
"SDB": str(temp_dir_path / "normalized" / "soundboard")
"SDB": str(temp_dir_path / "normalized" / "soundboard"),
}

# Create the normalized directory and file
@@ -674,7 +671,7 @@ class TestSoundScannerService:
normalized_hash="normalized_hash",
play_count=5,
is_deletable=False,
is_music=False
is_music=False,
)

# Mock sound repository methods
@@ -683,8 +680,8 @@ class TestSoundScannerService:

# Mock audio utility functions
from unittest.mock import patch
with patch('app.services.sound_scanner.get_audio_duration'), \
patch('app.services.sound_scanner.get_file_size'):
with patch("app.services.sound_scanner.get_audio_duration"), \
patch("app.services.sound_scanner.get_file_size"):

# Run scan with empty directory (should trigger deletion)
results = await scanner_service.scan_directory(str(scan_dir), "SDB")