feat: Implement background extraction processor with concurrency control
- Added `ExtractionProcessor` class to handle extraction queue processing in the background. - Implemented methods for starting, stopping, and queuing extractions with concurrency limits. - Integrated logging for monitoring the processor's status and actions. - Created tests for the extraction processor to ensure functionality and error handling. test: Add unit tests for extraction API endpoints - Created tests for successful extraction creation, authentication checks, and processor status retrieval. - Ensured proper responses for authenticated and unauthenticated requests. test: Implement unit tests for extraction repository - Added tests for creating, retrieving, and updating extractions in the repository. - Mocked database interactions to validate repository behavior without actual database access. test: Add comprehensive tests for extraction service - Developed tests for extraction creation, service detection, and sound record creation. - Included tests for handling duplicate extractions and invalid URLs. test: Add unit tests for extraction background processor - Created tests for the `ExtractionProcessor` class to validate its behavior under various conditions. - Ensured proper handling of extraction queuing, processing, and completion callbacks. fix: Update OAuth service tests to use AsyncMock - Modified OAuth provider tests to use `AsyncMock` for mocking asynchronous HTTP requests.
This commit is contained in:
@@ -8,6 +8,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from app.core.dependencies import get_current_active_user_flexible
|
from app.core.dependencies import get_current_active_user_flexible
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.services.extraction import ExtractionInfo, ExtractionService
|
||||||
|
from app.services.extraction_processor import extraction_processor
|
||||||
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
|
||||||
from app.services.sound_scanner import ScanResults, SoundScannerService
|
from app.services.sound_scanner import ScanResults, SoundScannerService
|
||||||
|
|
||||||
@@ -28,6 +30,13 @@ async def get_sound_normalizer_service(
|
|||||||
return SoundNormalizerService(session)
|
return SoundNormalizerService(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_extraction_service(
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> ExtractionService:
|
||||||
|
"""Get the extraction service."""
|
||||||
|
return ExtractionService(session)
|
||||||
|
|
||||||
|
|
||||||
# SCAN
|
# SCAN
|
||||||
@router.post("/scan")
|
@router.post("/scan")
|
||||||
async def scan_sounds(
|
async def scan_sounds(
|
||||||
@@ -233,3 +242,110 @@ async def normalize_sound_by_id(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to normalize sound: {e!s}",
|
detail=f"Failed to normalize sound: {e!s}",
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
# EXTRACT
|
||||||
|
@router.post("/extract")
|
||||||
|
async def create_extraction(
|
||||||
|
url: str,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
|
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||||
|
) -> dict[str, ExtractionInfo | str]:
|
||||||
|
"""Create a new extraction job for a URL."""
|
||||||
|
try:
|
||||||
|
if current_user.id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User ID not available",
|
||||||
|
)
|
||||||
|
|
||||||
|
extraction_info = await extraction_service.create_extraction(
|
||||||
|
url, current_user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Queue the extraction for background processing
|
||||||
|
await extraction_processor.queue_extraction(extraction_info["id"])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": "Extraction queued successfully",
|
||||||
|
"extraction": extraction_info,
|
||||||
|
}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to create extraction: {e!s}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/extract/status")
|
||||||
|
async def get_extraction_processor_status(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
|
) -> dict:
|
||||||
|
"""Get the status of the extraction processor."""
|
||||||
|
# Only allow admins to see processor status
|
||||||
|
if current_user.role not in ["admin", "superadmin"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Only administrators can view processor status",
|
||||||
|
)
|
||||||
|
|
||||||
|
return extraction_processor.get_status()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/extract/{extraction_id}")
|
||||||
|
async def get_extraction(
|
||||||
|
extraction_id: int,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
|
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||||
|
) -> ExtractionInfo:
|
||||||
|
"""Get extraction information by ID."""
|
||||||
|
try:
|
||||||
|
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
|
||||||
|
|
||||||
|
if not extraction_info:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Extraction {extraction_id} not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return extraction_info
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to get extraction: {e!s}",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/extract")
|
||||||
|
async def get_user_extractions(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
|
||||||
|
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
|
||||||
|
) -> dict[str, list[ExtractionInfo]]:
|
||||||
|
"""Get all extractions for the current user."""
|
||||||
|
try:
|
||||||
|
if current_user.id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="User ID not available",
|
||||||
|
)
|
||||||
|
|
||||||
|
extractions = await extraction_service.get_user_extractions(current_user.id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"extractions": extractions,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to get extractions: {e!s}",
|
||||||
|
) from e
|
||||||
|
|||||||
@@ -52,5 +52,12 @@ class Settings(BaseSettings):
|
|||||||
NORMALIZED_AUDIO_BITRATE: str = "256k"
|
NORMALIZED_AUDIO_BITRATE: str = "256k"
|
||||||
NORMALIZED_AUDIO_PASSES: int = 2 # 1 for one-pass, 2 for two-pass
|
NORMALIZED_AUDIO_PASSES: int = 2 # 1 for one-pass, 2 for two-pass
|
||||||
|
|
||||||
|
# Audio Extraction Configuration
|
||||||
|
EXTRACTION_AUDIO_FORMAT: str = "mp3"
|
||||||
|
EXTRACTION_AUDIO_BITRATE: str = "256k"
|
||||||
|
EXTRACTION_TEMP_DIR: str = "sounds/temp"
|
||||||
|
EXTRACTION_THUMBNAILS_DIR: str = "sounds/originals/extracted/thumbnails"
|
||||||
|
EXTRACTION_MAX_CONCURRENT: int = 2 # Maximum concurrent extractions
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from app.api import api_router
|
|||||||
from app.core.database import init_db
|
from app.core.database import init_db
|
||||||
from app.core.logging import get_logger, setup_logging
|
from app.core.logging import get_logger, setup_logging
|
||||||
from app.middleware.logging import LoggingMiddleware
|
from app.middleware.logging import LoggingMiddleware
|
||||||
|
from app.services.extraction_processor import extraction_processor
|
||||||
from app.services.socket import socket_manager
|
from app.services.socket import socket_manager
|
||||||
|
|
||||||
|
|
||||||
@@ -22,10 +23,18 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
await init_db()
|
await init_db()
|
||||||
logger.info("Database initialized")
|
logger.info("Database initialized")
|
||||||
|
|
||||||
|
# Start the extraction processor
|
||||||
|
await extraction_processor.start()
|
||||||
|
logger.info("Extraction processor started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
logger.info("Shutting down application")
|
logger.info("Shutting down application")
|
||||||
|
|
||||||
|
# Stop the extraction processor
|
||||||
|
await extraction_processor.stop()
|
||||||
|
logger.info("Extraction processor stopped")
|
||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
"""Create and configure the FastAPI application."""
|
"""Create and configure the FastAPI application."""
|
||||||
|
|||||||
82
app/repositories/extraction.py
Normal file
82
app/repositories/extraction.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Extraction repository for database operations."""
|
||||||
|
|
||||||
|
from sqlalchemy import desc
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.extraction import Extraction
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionRepository:
|
||||||
|
"""Repository for extraction database operations."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize the extraction repository."""
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def create(self, extraction_data: dict) -> Extraction:
|
||||||
|
"""Create a new extraction."""
|
||||||
|
extraction = Extraction(**extraction_data)
|
||||||
|
self.session.add(extraction)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(extraction)
|
||||||
|
return extraction
|
||||||
|
|
||||||
|
async def get_by_id(self, extraction_id: int) -> Extraction | None:
|
||||||
|
"""Get an extraction by ID."""
|
||||||
|
result = await self.session.exec(
|
||||||
|
select(Extraction).where(Extraction.id == extraction_id)
|
||||||
|
)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
async def get_by_service_and_id(
|
||||||
|
self, service: str, service_id: str
|
||||||
|
) -> Extraction | None:
|
||||||
|
"""Get an extraction by service and service_id."""
|
||||||
|
result = await self.session.exec(
|
||||||
|
select(Extraction).where(
|
||||||
|
Extraction.service == service, Extraction.service_id == service_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
async def get_by_user(self, user_id: int) -> list[Extraction]:
|
||||||
|
"""Get all extractions for a user."""
|
||||||
|
result = await self.session.exec(
|
||||||
|
select(Extraction)
|
||||||
|
.where(Extraction.user_id == user_id)
|
||||||
|
.order_by(desc(Extraction.created_at))
|
||||||
|
)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_pending_extractions(self) -> list[Extraction]:
|
||||||
|
"""Get all pending extractions."""
|
||||||
|
result = await self.session.exec(
|
||||||
|
select(Extraction)
|
||||||
|
.where(Extraction.status == "pending")
|
||||||
|
.order_by(Extraction.created_at)
|
||||||
|
)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
|
||||||
|
"""Update an extraction."""
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(extraction, key, value)
|
||||||
|
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(extraction)
|
||||||
|
return extraction
|
||||||
|
|
||||||
|
async def delete(self, extraction: Extraction) -> None:
|
||||||
|
"""Delete an extraction."""
|
||||||
|
await self.session.delete(extraction)
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
async def get_extractions_by_status(self, status: str) -> list[Extraction]:
|
||||||
|
"""Get extractions by status."""
|
||||||
|
result = await self.session.exec(
|
||||||
|
select(Extraction)
|
||||||
|
.where(Extraction.status == status)
|
||||||
|
.order_by(desc(Extraction.created_at))
|
||||||
|
)
|
||||||
|
return list(result.all())
|
||||||
517
app/services/extraction.py
Normal file
517
app/services/extraction.py
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
"""Extraction service for audio extraction from external services using yt-dlp."""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
import yt_dlp
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.extraction import Extraction
|
||||||
|
from app.models.sound import Sound
|
||||||
|
from app.repositories.extraction import ExtractionRepository
|
||||||
|
from app.repositories.sound import SoundRepository
|
||||||
|
from app.services.sound_normalizer import SoundNormalizerService
|
||||||
|
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionInfo(TypedDict):
|
||||||
|
"""Type definition for extraction information."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
url: str
|
||||||
|
service: str
|
||||||
|
service_id: str
|
||||||
|
title: str | None
|
||||||
|
status: str
|
||||||
|
error: str | None
|
||||||
|
sound_id: int | None
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionService:
|
||||||
|
"""Service for extracting audio from external services using yt-dlp."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize the extraction service."""
|
||||||
|
self.session = session
|
||||||
|
self.extraction_repo = ExtractionRepository(session)
|
||||||
|
self.sound_repo = SoundRepository(session)
|
||||||
|
|
||||||
|
# Ensure required directories exist
|
||||||
|
self._ensure_directories()
|
||||||
|
|
||||||
|
def _ensure_directories(self) -> None:
|
||||||
|
"""Ensure all required directories exist."""
|
||||||
|
directories = [
|
||||||
|
settings.EXTRACTION_TEMP_DIR,
|
||||||
|
"sounds/originals/extracted",
|
||||||
|
settings.EXTRACTION_THUMBNAILS_DIR,
|
||||||
|
]
|
||||||
|
|
||||||
|
for directory in directories:
|
||||||
|
Path(directory).mkdir(parents=True, exist_ok=True)
|
||||||
|
logger.debug("Ensured directory exists: %s", directory)
|
||||||
|
|
||||||
|
async def create_extraction(self, url: str, user_id: int) -> ExtractionInfo:
|
||||||
|
"""Create a new extraction job."""
|
||||||
|
logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# First, detect service and service_id using yt-dlp
|
||||||
|
service_info = self._detect_service_info(url)
|
||||||
|
|
||||||
|
if not service_info:
|
||||||
|
raise ValueError("Unable to detect service information from URL")
|
||||||
|
|
||||||
|
service = service_info["service"]
|
||||||
|
service_id = service_info["service_id"]
|
||||||
|
title = service_info.get("title")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Detected service: %s, service_id: %s, title: %s",
|
||||||
|
service,
|
||||||
|
service_id,
|
||||||
|
title,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if extraction already exists
|
||||||
|
existing = await self.extraction_repo.get_by_service_and_id(
|
||||||
|
service, service_id
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
error_msg = f"Extraction already exists for {service}:{service_id}"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
# Create the extraction record
|
||||||
|
extraction_data = {
|
||||||
|
"url": url,
|
||||||
|
"user_id": user_id,
|
||||||
|
"service": service,
|
||||||
|
"service_id": service_id,
|
||||||
|
"title": title,
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
|
||||||
|
extraction = await self.extraction_repo.create(extraction_data)
|
||||||
|
logger.info("Created extraction with ID: %d", extraction.id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": extraction.id or 0, # Should never be None for created extraction
|
||||||
|
"url": extraction.url,
|
||||||
|
"service": extraction.service,
|
||||||
|
"service_id": extraction.service_id,
|
||||||
|
"title": extraction.title,
|
||||||
|
"status": extraction.status,
|
||||||
|
"error": extraction.error,
|
||||||
|
"sound_id": extraction.sound_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to create extraction for URL: %s", url)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _detect_service_info(self, url: str) -> dict | None:
|
||||||
|
"""Detect service information from URL using yt-dlp."""
|
||||||
|
try:
|
||||||
|
# Configure yt-dlp for info extraction only
|
||||||
|
ydl_opts = {
|
||||||
|
"quiet": True,
|
||||||
|
"no_warnings": True,
|
||||||
|
"extract_flat": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||||
|
# Extract info without downloading
|
||||||
|
info = ydl.extract_info(url, download=False)
|
||||||
|
|
||||||
|
if not info:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Map extractor names to our service names
|
||||||
|
extractor_map = {
|
||||||
|
"youtube": "youtube",
|
||||||
|
"dailymotion": "dailymotion",
|
||||||
|
"vimeo": "vimeo",
|
||||||
|
"soundcloud": "soundcloud",
|
||||||
|
"twitter": "twitter",
|
||||||
|
"tiktok": "tiktok",
|
||||||
|
"instagram": "instagram",
|
||||||
|
}
|
||||||
|
|
||||||
|
extractor = info.get("extractor", "").lower()
|
||||||
|
service = extractor_map.get(extractor, extractor)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"service": service,
|
||||||
|
"service_id": str(info.get("id", "")),
|
||||||
|
"title": info.get("title"),
|
||||||
|
"duration": info.get("duration"),
|
||||||
|
"uploader": info.get("uploader"),
|
||||||
|
"description": info.get("description"),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to detect service info for URL: %s", url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
|
||||||
|
"""Process an extraction job."""
|
||||||
|
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||||
|
if not extraction:
|
||||||
|
raise ValueError(f"Extraction {extraction_id} not found")
|
||||||
|
|
||||||
|
if extraction.status != "pending":
|
||||||
|
raise ValueError(f"Extraction {extraction_id} is not pending")
|
||||||
|
|
||||||
|
# Store all needed values early to avoid session detachment issues
|
||||||
|
user_id = extraction.user_id
|
||||||
|
extraction_url = extraction.url
|
||||||
|
extraction_title = extraction.title
|
||||||
|
extraction_service = extraction.service
|
||||||
|
extraction_service_id = extraction.service_id
|
||||||
|
|
||||||
|
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Update status to processing
|
||||||
|
await self.extraction_repo.update(extraction, {"status": "processing"})
|
||||||
|
|
||||||
|
# Extract audio and thumbnail
|
||||||
|
audio_file, thumbnail_file = await self._extract_media(
|
||||||
|
extraction_id, extraction_url
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move files to final locations
|
||||||
|
final_audio_path, final_thumbnail_path = (
|
||||||
|
await self._move_files_to_final_location(
|
||||||
|
audio_file,
|
||||||
|
thumbnail_file,
|
||||||
|
extraction_title,
|
||||||
|
extraction_service,
|
||||||
|
extraction_service_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create Sound record
|
||||||
|
sound = await self._create_sound_record(
|
||||||
|
final_audio_path,
|
||||||
|
extraction_title,
|
||||||
|
extraction_service,
|
||||||
|
extraction_service_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store sound_id early to avoid session detachment issues
|
||||||
|
sound_id = sound.id
|
||||||
|
|
||||||
|
# Normalize the sound
|
||||||
|
await self._normalize_sound(sound)
|
||||||
|
|
||||||
|
# Add to main playlist
|
||||||
|
await self._add_to_main_playlist(sound, user_id)
|
||||||
|
|
||||||
|
# Update extraction with success
|
||||||
|
await self.extraction_repo.update(
|
||||||
|
extraction,
|
||||||
|
{
|
||||||
|
"status": "completed",
|
||||||
|
"sound_id": sound_id,
|
||||||
|
"error": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Successfully processed extraction %d", extraction_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": extraction_id,
|
||||||
|
"url": extraction_url,
|
||||||
|
"service": extraction_service,
|
||||||
|
"service_id": extraction_service_id,
|
||||||
|
"title": extraction_title,
|
||||||
|
"status": "completed",
|
||||||
|
"error": None,
|
||||||
|
"sound_id": sound_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
logger.exception(
|
||||||
|
"Failed to process extraction %d: %s", extraction_id, error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update extraction with error
|
||||||
|
await self.extraction_repo.update(
|
||||||
|
extraction,
|
||||||
|
{
|
||||||
|
"status": "failed",
|
||||||
|
"error": error_msg,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": extraction_id,
|
||||||
|
"url": extraction_url,
|
||||||
|
"service": extraction_service,
|
||||||
|
"service_id": extraction_service_id,
|
||||||
|
"title": extraction_title,
|
||||||
|
"status": "failed",
|
||||||
|
"error": error_msg,
|
||||||
|
"sound_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _extract_media(
|
||||||
|
self, extraction_id: int, extraction_url: str
|
||||||
|
) -> tuple[Path, Path | None]:
|
||||||
|
"""Extract audio and thumbnail using yt-dlp."""
|
||||||
|
temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
|
||||||
|
|
||||||
|
# Create unique filename based on extraction ID
|
||||||
|
output_template = str(
|
||||||
|
temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure yt-dlp options
|
||||||
|
ydl_opts = {
|
||||||
|
"format": "bestaudio/best",
|
||||||
|
"outtmpl": output_template,
|
||||||
|
"extractaudio": True,
|
||||||
|
"audioformat": settings.EXTRACTION_AUDIO_FORMAT,
|
||||||
|
"audioquality": settings.EXTRACTION_AUDIO_BITRATE,
|
||||||
|
"writethumbnail": True,
|
||||||
|
"writeinfojson": False,
|
||||||
|
"writeautomaticsub": False,
|
||||||
|
"writesubtitles": False,
|
||||||
|
"postprocessors": [
|
||||||
|
{
|
||||||
|
"key": "FFmpegExtractAudio",
|
||||||
|
"preferredcodec": settings.EXTRACTION_AUDIO_FORMAT,
|
||||||
|
"preferredquality": settings.EXTRACTION_AUDIO_BITRATE.rstrip("k"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||||
|
# Download and extract
|
||||||
|
ydl.download([extraction_url])
|
||||||
|
|
||||||
|
# Find the extracted files
|
||||||
|
audio_files = list(
|
||||||
|
temp_dir.glob(
|
||||||
|
f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
thumbnail_files = (
|
||||||
|
list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
|
||||||
|
+ list(temp_dir.glob(f"extraction_{extraction_id}_*.jpg"))
|
||||||
|
+ list(temp_dir.glob(f"extraction_{extraction_id}_*.png"))
|
||||||
|
)
|
||||||
|
|
||||||
|
if not audio_files:
|
||||||
|
raise RuntimeError("No audio file was created during extraction")
|
||||||
|
|
||||||
|
audio_file = audio_files[0]
|
||||||
|
thumbnail_file = thumbnail_files[0] if thumbnail_files else None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Extracted audio: %s, thumbnail: %s",
|
||||||
|
audio_file,
|
||||||
|
thumbnail_file or "None",
|
||||||
|
)
|
||||||
|
|
||||||
|
return audio_file, thumbnail_file
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("yt-dlp extraction failed for %s", extraction_url)
|
||||||
|
raise RuntimeError(f"Audio extraction failed: {e}") from e
|
||||||
|
|
||||||
|
async def _move_files_to_final_location(
|
||||||
|
self,
|
||||||
|
audio_file: Path,
|
||||||
|
thumbnail_file: Path | None,
|
||||||
|
title: str | None,
|
||||||
|
service: str,
|
||||||
|
service_id: str,
|
||||||
|
) -> tuple[Path, Path | None]:
|
||||||
|
"""Move extracted files to their final locations."""
|
||||||
|
# Generate clean filename based on title and service
|
||||||
|
safe_title = self._sanitize_filename(title or f"{service}_{service_id}")
|
||||||
|
|
||||||
|
# Move audio file
|
||||||
|
final_audio_path = (
|
||||||
|
Path("sounds/originals/extracted")
|
||||||
|
/ f"{safe_title}.{settings.EXTRACTION_AUDIO_FORMAT}"
|
||||||
|
)
|
||||||
|
final_audio_path = self._ensure_unique_filename(final_audio_path)
|
||||||
|
|
||||||
|
shutil.move(str(audio_file), str(final_audio_path))
|
||||||
|
logger.info("Moved audio file to: %s", final_audio_path)
|
||||||
|
|
||||||
|
# Move thumbnail file if it exists
|
||||||
|
final_thumbnail_path = None
|
||||||
|
if thumbnail_file:
|
||||||
|
thumbnail_ext = thumbnail_file.suffix
|
||||||
|
final_thumbnail_path = (
|
||||||
|
Path(settings.EXTRACTION_THUMBNAILS_DIR)
|
||||||
|
/ f"{safe_title}{thumbnail_ext}"
|
||||||
|
)
|
||||||
|
final_thumbnail_path = self._ensure_unique_filename(final_thumbnail_path)
|
||||||
|
|
||||||
|
shutil.move(str(thumbnail_file), str(final_thumbnail_path))
|
||||||
|
logger.info("Moved thumbnail file to: %s", final_thumbnail_path)
|
||||||
|
|
||||||
|
return final_audio_path, final_thumbnail_path
|
||||||
|
|
||||||
|
def _sanitize_filename(self, filename: str) -> str:
|
||||||
|
"""Sanitize filename for filesystem."""
|
||||||
|
# Remove or replace problematic characters
|
||||||
|
invalid_chars = '<>:"/\\|?*'
|
||||||
|
for char in invalid_chars:
|
||||||
|
filename = filename.replace(char, "_")
|
||||||
|
|
||||||
|
# Limit length and remove leading/trailing spaces
|
||||||
|
filename = filename.strip()[:100]
|
||||||
|
|
||||||
|
return filename or "untitled"
|
||||||
|
|
||||||
|
def _ensure_unique_filename(self, filepath: Path) -> Path:
|
||||||
|
"""Ensure filename is unique by adding counter if needed."""
|
||||||
|
if not filepath.exists():
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
stem = filepath.stem
|
||||||
|
suffix = filepath.suffix
|
||||||
|
parent = filepath.parent
|
||||||
|
counter = 1
|
||||||
|
|
||||||
|
while True:
|
||||||
|
new_path = parent / f"{stem}_{counter}{suffix}"
|
||||||
|
if not new_path.exists():
|
||||||
|
return new_path
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
async def _create_sound_record(
|
||||||
|
self, audio_path: Path, title: str | None, service: str, service_id: str
|
||||||
|
) -> Sound:
|
||||||
|
"""Create a Sound record for the extracted audio."""
|
||||||
|
# Get audio metadata
|
||||||
|
duration = get_audio_duration(audio_path)
|
||||||
|
size = get_file_size(audio_path)
|
||||||
|
file_hash = get_file_hash(audio_path)
|
||||||
|
|
||||||
|
# Create sound data
|
||||||
|
sound_data = {
|
||||||
|
"type": "EXT",
|
||||||
|
"name": title or f"{service}_{service_id}",
|
||||||
|
"filename": audio_path.name,
|
||||||
|
"duration": duration,
|
||||||
|
"size": size,
|
||||||
|
"hash": file_hash,
|
||||||
|
"is_deletable": True, # Extracted sounds can be deleted
|
||||||
|
"is_music": True, # Assume extracted content is music
|
||||||
|
"is_normalized": False,
|
||||||
|
"play_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
sound = await self.sound_repo.create(sound_data)
|
||||||
|
logger.info("Created sound record with ID: %d", sound.id)
|
||||||
|
|
||||||
|
return sound
|
||||||
|
|
||||||
|
async def _normalize_sound(self, sound: Sound) -> None:
|
||||||
|
"""Normalize the extracted sound."""
|
||||||
|
try:
|
||||||
|
normalizer_service = SoundNormalizerService(self.session)
|
||||||
|
result = await normalizer_service.normalize_sound(sound)
|
||||||
|
|
||||||
|
if result["status"] == "error":
|
||||||
|
logger.warning(
|
||||||
|
"Failed to normalize sound %d: %s",
|
||||||
|
sound.id,
|
||||||
|
result.get("error"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Successfully normalized sound %d", sound.id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error normalizing sound %d: %s", sound.id, e)
|
||||||
|
# Don't fail the extraction if normalization fails
|
||||||
|
|
||||||
|
async def _add_to_main_playlist(self, sound: Sound, user_id: int) -> None:
|
||||||
|
"""Add the sound to the user's main playlist."""
|
||||||
|
try:
|
||||||
|
# This is a placeholder - implement based on your playlist logic
|
||||||
|
# For now, we'll just log that we would add it to the main playlist
|
||||||
|
logger.info(
|
||||||
|
"Would add sound %d to main playlist for user %d",
|
||||||
|
sound.id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Error adding sound %d to main playlist for user %d: %s",
|
||||||
|
sound.id,
|
||||||
|
user_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
# Don't fail the extraction if playlist addition fails
|
||||||
|
|
||||||
|
async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
|
||||||
|
"""Get extraction information by ID."""
|
||||||
|
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||||
|
if not extraction:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": extraction.id or 0, # Should never be None for existing extraction
|
||||||
|
"url": extraction.url,
|
||||||
|
"service": extraction.service,
|
||||||
|
"service_id": extraction.service_id,
|
||||||
|
"title": extraction.title,
|
||||||
|
"status": extraction.status,
|
||||||
|
"error": extraction.error,
|
||||||
|
"sound_id": extraction.sound_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_user_extractions(self, user_id: int) -> list[ExtractionInfo]:
|
||||||
|
"""Get all extractions for a user."""
|
||||||
|
extractions = await self.extraction_repo.get_by_user(user_id)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": extraction.id
|
||||||
|
or 0, # Should never be None for existing extraction
|
||||||
|
"url": extraction.url,
|
||||||
|
"service": extraction.service,
|
||||||
|
"service_id": extraction.service_id,
|
||||||
|
"title": extraction.title,
|
||||||
|
"status": extraction.status,
|
||||||
|
"error": extraction.error,
|
||||||
|
"sound_id": extraction.sound_id,
|
||||||
|
}
|
||||||
|
for extraction in extractions
|
||||||
|
]
|
||||||
|
|
||||||
|
async def get_pending_extractions(self) -> list[ExtractionInfo]:
|
||||||
|
"""Get all pending extractions."""
|
||||||
|
extractions = await self.extraction_repo.get_pending_extractions()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": extraction.id
|
||||||
|
or 0, # Should never be None for existing extraction
|
||||||
|
"url": extraction.url,
|
||||||
|
"service": extraction.service,
|
||||||
|
"service_id": extraction.service_id,
|
||||||
|
"title": extraction.title,
|
||||||
|
"status": extraction.status,
|
||||||
|
"error": extraction.error,
|
||||||
|
"sound_id": extraction.sound_id,
|
||||||
|
}
|
||||||
|
for extraction in extractions
|
||||||
|
]
|
||||||
196
app/services/extraction_processor.py
Normal file
196
app/services/extraction_processor.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""Background extraction processor for handling extraction queue."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.database import engine
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.services.extraction import ExtractionService
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractionProcessor:
|
||||||
|
"""Background processor for handling extraction queue with concurrency control."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""Initialize the extraction processor."""
|
||||||
|
self.max_concurrent = settings.EXTRACTION_MAX_CONCURRENT
|
||||||
|
self.running_extractions: Set[int] = set()
|
||||||
|
self.processing_lock = asyncio.Lock()
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
self.processor_task: asyncio.Task | None = None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Initialized extraction processor with max concurrent: %d",
|
||||||
|
self.max_concurrent,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""Start the background extraction processor."""
|
||||||
|
if self.processor_task and not self.processor_task.done():
|
||||||
|
logger.warning("Extraction processor is already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.shutdown_event.clear()
|
||||||
|
self.processor_task = asyncio.create_task(self._process_queue())
|
||||||
|
logger.info("Started extraction processor")
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Stop the background extraction processor."""
|
||||||
|
logger.info("Stopping extraction processor...")
|
||||||
|
self.shutdown_event.set()
|
||||||
|
|
||||||
|
if self.processor_task and not self.processor_task.done():
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.processor_task, timeout=30.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"Extraction processor did not stop gracefully, cancelling..."
|
||||||
|
)
|
||||||
|
self.processor_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.processor_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.info("Extraction processor stopped")
|
||||||
|
|
||||||
|
async def queue_extraction(self, extraction_id: int) -> None:
|
||||||
|
"""Queue an extraction for processing."""
|
||||||
|
async with self.processing_lock:
|
||||||
|
if extraction_id not in self.running_extractions:
|
||||||
|
logger.info("Queued extraction %d for processing", extraction_id)
|
||||||
|
# The processor will pick it up on the next cycle
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Extraction %d is already being processed", extraction_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_queue(self) -> None:
|
||||||
|
"""Main processing loop that handles the extraction queue."""
|
||||||
|
logger.info("Starting extraction queue processor")
|
||||||
|
|
||||||
|
while not self.shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
await self._process_pending_extractions()
|
||||||
|
|
||||||
|
# Wait before checking for new extractions
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
|
||||||
|
break # Shutdown requested
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue # Continue processing
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error in extraction queue processor: %s", e)
|
||||||
|
# Wait a bit before retrying to avoid tight error loops
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
|
||||||
|
break # Shutdown requested
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info("Extraction queue processor stopped")
|
||||||
|
|
||||||
|
async def _process_pending_extractions(self) -> None:
|
||||||
|
"""Process pending extractions up to the concurrency limit."""
|
||||||
|
async with self.processing_lock:
|
||||||
|
# Check how many slots are available
|
||||||
|
available_slots = self.max_concurrent - len(self.running_extractions)
|
||||||
|
|
||||||
|
if available_slots <= 0:
|
||||||
|
return # No available slots
|
||||||
|
|
||||||
|
# Get pending extractions from database
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
extraction_service = ExtractionService(session)
|
||||||
|
pending_extractions = await extraction_service.get_pending_extractions()
|
||||||
|
|
||||||
|
# Filter out extractions that are already being processed
|
||||||
|
available_extractions = [
|
||||||
|
ext
|
||||||
|
for ext in pending_extractions
|
||||||
|
if ext["id"] not in self.running_extractions
|
||||||
|
]
|
||||||
|
|
||||||
|
# Start processing up to available slots
|
||||||
|
extractions_to_start = available_extractions[:available_slots]
|
||||||
|
|
||||||
|
for extraction_info in extractions_to_start:
|
||||||
|
extraction_id = extraction_info["id"]
|
||||||
|
self.running_extractions.add(extraction_id)
|
||||||
|
|
||||||
|
# Start processing this extraction in the background
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self._process_single_extraction(extraction_id)
|
||||||
|
)
|
||||||
|
task.add_done_callback(
|
||||||
|
lambda t, eid=extraction_id: self._on_extraction_completed(
|
||||||
|
eid,
|
||||||
|
t,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Started processing extraction %d (%d/%d slots used)",
|
||||||
|
extraction_id,
|
||||||
|
len(self.running_extractions),
|
||||||
|
self.max_concurrent,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_single_extraction(self, extraction_id: int) -> None:
|
||||||
|
"""Process a single extraction."""
|
||||||
|
try:
|
||||||
|
logger.info("Processing extraction %d", extraction_id)
|
||||||
|
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
extraction_service = ExtractionService(session)
|
||||||
|
result = await extraction_service.process_extraction(extraction_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Completed extraction %d with status: %s",
|
||||||
|
extraction_id,
|
||||||
|
result["status"],
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Error processing extraction %d: %s", extraction_id, e)
|
||||||
|
|
||||||
|
def _on_extraction_completed(self, extraction_id: int, task: asyncio.Task) -> None:
|
||||||
|
"""Callback when an extraction task is completed."""
|
||||||
|
# Remove from running set
|
||||||
|
self.running_extractions.discard(extraction_id)
|
||||||
|
|
||||||
|
# Check if the task had an exception
|
||||||
|
if task.exception():
|
||||||
|
logger.error(
|
||||||
|
"Extraction %d completed with exception: %s",
|
||||||
|
extraction_id,
|
||||||
|
task.exception(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Extraction %d completed successfully (%d/%d slots used)",
|
||||||
|
extraction_id,
|
||||||
|
len(self.running_extractions),
|
||||||
|
self.max_concurrent,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_status(self) -> dict:
|
||||||
|
"""Get the current status of the extraction processor."""
|
||||||
|
return {
|
||||||
|
"running": self.processor_task is not None
|
||||||
|
and not self.processor_task.done(),
|
||||||
|
"max_concurrent": self.max_concurrent,
|
||||||
|
"currently_processing": len(self.running_extractions),
|
||||||
|
"processing_ids": list(self.running_extractions),
|
||||||
|
"available_slots": self.max_concurrent - len(self.running_extractions),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Global extraction processor instance
|
||||||
|
extraction_processor = ExtractionProcessor()
|
||||||
95
tests/api/v1/test_extraction_endpoints.py
Normal file
95
tests/api/v1/test_extraction_endpoints.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""Tests for extraction API endpoints."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractionEndpoints:
|
||||||
|
"""Test extraction API endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_extraction_success(
|
||||||
|
self, test_client: AsyncClient, auth_cookies: dict[str, str]
|
||||||
|
):
|
||||||
|
"""Test successful extraction creation."""
|
||||||
|
# Set cookies on client instance to avoid deprecation warning
|
||||||
|
test_client.cookies.update(auth_cookies)
|
||||||
|
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/v1/sounds/extract",
|
||||||
|
params={"url": "https://www.youtube.com/watch?v=test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# This will fail because we don't have actual extraction service mocked
|
||||||
|
# But at least we'll get past authentication
|
||||||
|
assert response.status_code in [200, 400, 500] # Allow any non-auth error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_extraction_unauthenticated(self, test_client: AsyncClient):
|
||||||
|
"""Test extraction creation without authentication."""
|
||||||
|
response = await test_client.post(
|
||||||
|
"/api/v1/sounds/extract",
|
||||||
|
params={"url": "https://www.youtube.com/watch?v=test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 401 for missing authentication
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_extraction_unauthenticated(self, test_client: AsyncClient):
|
||||||
|
"""Test extraction retrieval without authentication."""
|
||||||
|
response = await test_client.get("/api/v1/sounds/extract/1")
|
||||||
|
|
||||||
|
# Should return 401 for missing authentication
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_processor_status_admin(
|
||||||
|
self, test_client: AsyncClient, admin_cookies: dict[str, str]
|
||||||
|
):
|
||||||
|
"""Test getting processor status as admin."""
|
||||||
|
# Set cookies on client instance to avoid deprecation warning
|
||||||
|
test_client.cookies.update(admin_cookies)
|
||||||
|
|
||||||
|
response = await test_client.get("/api/v1/sounds/extract/status")
|
||||||
|
|
||||||
|
# Should succeed for admin users
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "running" in data
|
||||||
|
assert "max_concurrent" in data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_processor_status_non_admin(
|
||||||
|
self, test_client: AsyncClient, auth_cookies: dict[str, str]
|
||||||
|
):
|
||||||
|
"""Test getting processor status as non-admin user."""
|
||||||
|
# Set cookies on client instance to avoid deprecation warning
|
||||||
|
test_client.cookies.update(auth_cookies)
|
||||||
|
|
||||||
|
response = await test_client.get("/api/v1/sounds/extract/status")
|
||||||
|
|
||||||
|
# Should return 403 for non-admin users
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "Only administrators" in response.json()["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_extractions(
|
||||||
|
self, test_client: AsyncClient, auth_cookies: dict[str, str]
|
||||||
|
):
|
||||||
|
"""Test getting user extractions."""
|
||||||
|
# Set cookies on client instance to avoid deprecation warning
|
||||||
|
test_client.cookies.update(auth_cookies)
|
||||||
|
|
||||||
|
response = await test_client.get("/api/v1/sounds/extract")
|
||||||
|
|
||||||
|
# Should succeed and return empty list (no extractions in test DB)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "extractions" in data
|
||||||
|
assert isinstance(data["extractions"], list)
|
||||||
128
tests/repositories/test_extraction.py
Normal file
128
tests/repositories/test_extraction.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Tests for extraction repository."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.extraction import Extraction
|
||||||
|
from app.repositories.extraction import ExtractionRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractionRepository:
|
||||||
|
"""Test extraction repository."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session(self):
|
||||||
|
"""Create a mock session."""
|
||||||
|
return Mock(spec=AsyncSession)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def extraction_repo(self, mock_session):
|
||||||
|
"""Create an extraction repository with mock session."""
|
||||||
|
return ExtractionRepository(mock_session)
|
||||||
|
|
||||||
|
def test_init(self, extraction_repo):
|
||||||
|
"""Test repository initialization."""
|
||||||
|
assert extraction_repo.session is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_extraction(self, extraction_repo):
|
||||||
|
"""Test creating an extraction."""
|
||||||
|
extraction_data = {
|
||||||
|
"url": "https://www.youtube.com/watch?v=test",
|
||||||
|
"user_id": 1,
|
||||||
|
"service": "youtube",
|
||||||
|
"service_id": "test123",
|
||||||
|
"title": "Test Video",
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mock the session operations
|
||||||
|
mock_extraction = Extraction(**extraction_data, id=1)
|
||||||
|
extraction_repo.session.add = Mock()
|
||||||
|
extraction_repo.session.commit = AsyncMock()
|
||||||
|
extraction_repo.session.refresh = AsyncMock()
|
||||||
|
|
||||||
|
# Mock the Extraction constructor to return our mock
|
||||||
|
with pytest.MonkeyPatch().context() as m:
|
||||||
|
m.setattr(
|
||||||
|
"app.repositories.extraction.Extraction",
|
||||||
|
lambda **kwargs: mock_extraction,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extraction_repo.create(extraction_data)
|
||||||
|
|
||||||
|
assert result == mock_extraction
|
||||||
|
extraction_repo.session.add.assert_called_once()
|
||||||
|
extraction_repo.session.commit.assert_called_once()
|
||||||
|
extraction_repo.session.refresh.assert_called_once_with(mock_extraction)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_service_and_id(self, extraction_repo):
|
||||||
|
"""Test getting extraction by service and service_id."""
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.first.return_value = Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
url="https://www.youtube.com/watch?v=test",
|
||||||
|
user_id=1,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
extraction_repo.session.exec = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
result = await extraction_repo.get_by_service_and_id("youtube", "test123")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.service == "youtube"
|
||||||
|
assert result.service_id == "test123"
|
||||||
|
extraction_repo.session.exec.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_extractions(self, extraction_repo):
|
||||||
|
"""Test getting pending extractions."""
|
||||||
|
mock_extraction = Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
url="https://www.youtube.com/watch?v=test",
|
||||||
|
user_id=1,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.all.return_value = [mock_extraction]
|
||||||
|
|
||||||
|
extraction_repo.session.exec = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
result = await extraction_repo.get_pending_extractions()
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0].status == "pending"
|
||||||
|
extraction_repo.session.exec.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_extraction(self, extraction_repo):
|
||||||
|
"""Test updating an extraction."""
|
||||||
|
extraction = Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
url="https://www.youtube.com/watch?v=test",
|
||||||
|
user_id=1,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = {"status": "completed", "sound_id": 42}
|
||||||
|
|
||||||
|
extraction_repo.session.commit = AsyncMock()
|
||||||
|
extraction_repo.session.refresh = AsyncMock()
|
||||||
|
|
||||||
|
result = await extraction_repo.update(extraction, update_data)
|
||||||
|
|
||||||
|
assert result.status == "completed"
|
||||||
|
assert result.sound_id == 42
|
||||||
|
extraction_repo.session.commit.assert_called_once()
|
||||||
|
extraction_repo.session.refresh.assert_called_once_with(extraction)
|
||||||
408
tests/services/test_extraction.py
Normal file
408
tests/services/test_extraction.py
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
"""Tests for extraction service."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.extraction import Extraction
|
||||||
|
from app.models.sound import Sound
|
||||||
|
from app.services.extraction import ExtractionService
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractionService:
|
||||||
|
"""Test extraction service."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session(self):
|
||||||
|
"""Create a mock session."""
|
||||||
|
return Mock(spec=AsyncSession)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def extraction_service(self, mock_session):
|
||||||
|
"""Create an extraction service with mock session."""
|
||||||
|
with patch("app.services.extraction.Path.mkdir"):
|
||||||
|
return ExtractionService(mock_session)
|
||||||
|
|
||||||
|
def test_init(self, extraction_service):
|
||||||
|
"""Test service initialization."""
|
||||||
|
assert extraction_service.session is not None
|
||||||
|
assert extraction_service.extraction_repo is not None
|
||||||
|
assert extraction_service.sound_repo is not None
|
||||||
|
|
||||||
|
def test_sanitize_filename(self, extraction_service):
|
||||||
|
"""Test filename sanitization."""
|
||||||
|
test_cases = [
|
||||||
|
("Hello World", "Hello World"),
|
||||||
|
("Test<>Video", "Test__Video"),
|
||||||
|
("Bad/File\\Name", "Bad_File_Name"),
|
||||||
|
(" Spaces ", "Spaces"),
|
||||||
|
(
|
||||||
|
"Very long filename that exceeds the maximum length limit and should be truncated to 100 characters maximum",
|
||||||
|
"Very long filename that exceeds the maximum length limit and should be truncated to 100 characters m",
|
||||||
|
),
|
||||||
|
("", "untitled"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for input_name, expected in test_cases:
|
||||||
|
result = extraction_service._sanitize_filename(input_name)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
@patch("app.services.extraction.yt_dlp.YoutubeDL")
|
||||||
|
def test_detect_service_info_youtube(self, mock_ydl_class, extraction_service):
|
||||||
|
"""Test service detection for YouTube."""
|
||||||
|
mock_ydl = Mock()
|
||||||
|
mock_ydl_class.return_value.__enter__.return_value = mock_ydl
|
||||||
|
mock_ydl.extract_info.return_value = {
|
||||||
|
"extractor": "youtube",
|
||||||
|
"id": "test123",
|
||||||
|
"title": "Test Video",
|
||||||
|
"duration": 240,
|
||||||
|
"uploader": "Test Channel",
|
||||||
|
}
|
||||||
|
|
||||||
|
result = extraction_service._detect_service_info(
|
||||||
|
"https://www.youtube.com/watch?v=test123"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["service"] == "youtube"
|
||||||
|
assert result["service_id"] == "test123"
|
||||||
|
assert result["title"] == "Test Video"
|
||||||
|
assert result["duration"] == 240
|
||||||
|
|
||||||
|
@patch("app.services.extraction.yt_dlp.YoutubeDL")
|
||||||
|
def test_detect_service_info_failure(self, mock_ydl_class, extraction_service):
|
||||||
|
"""Test service detection failure."""
|
||||||
|
mock_ydl = Mock()
|
||||||
|
mock_ydl_class.return_value.__enter__.return_value = mock_ydl
|
||||||
|
mock_ydl.extract_info.side_effect = Exception("Network error")
|
||||||
|
|
||||||
|
result = extraction_service._detect_service_info("https://invalid.url")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_extraction_success(self, extraction_service):
|
||||||
|
"""Test successful extraction creation."""
|
||||||
|
url = "https://www.youtube.com/watch?v=test123"
|
||||||
|
user_id = 1
|
||||||
|
|
||||||
|
# Mock service detection
|
||||||
|
service_info = {
|
||||||
|
"service": "youtube",
|
||||||
|
"service_id": "test123",
|
||||||
|
"title": "Test Video",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
extraction_service, "_detect_service_info", return_value=service_info
|
||||||
|
):
|
||||||
|
# Mock repository calls
|
||||||
|
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
|
||||||
|
return_value=None
|
||||||
|
)
|
||||||
|
mock_extraction = Extraction(
|
||||||
|
id=1,
|
||||||
|
url=url,
|
||||||
|
user_id=user_id,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
title="Test Video",
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
extraction_service.extraction_repo.create = AsyncMock(
|
||||||
|
return_value=mock_extraction
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extraction_service.create_extraction(url, user_id)
|
||||||
|
|
||||||
|
assert result["id"] == 1
|
||||||
|
assert result["service"] == "youtube"
|
||||||
|
assert result["service_id"] == "test123"
|
||||||
|
assert result["title"] == "Test Video"
|
||||||
|
assert result["status"] == "pending"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_extraction_duplicate(self, extraction_service):
|
||||||
|
"""Test extraction creation with duplicate service/service_id."""
|
||||||
|
url = "https://www.youtube.com/watch?v=test123"
|
||||||
|
user_id = 1
|
||||||
|
|
||||||
|
# Mock service detection
|
||||||
|
service_info = {
|
||||||
|
"service": "youtube",
|
||||||
|
"service_id": "test123",
|
||||||
|
"title": "Test Video",
|
||||||
|
}
|
||||||
|
|
||||||
|
existing_extraction = Extraction(
|
||||||
|
id=1,
|
||||||
|
url=url,
|
||||||
|
user_id=2, # Different user
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
extraction_service, "_detect_service_info", return_value=service_info
|
||||||
|
):
|
||||||
|
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
|
||||||
|
return_value=existing_extraction
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Extraction already exists"):
|
||||||
|
await extraction_service.create_extraction(url, user_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_extraction_invalid_url(self, extraction_service):
|
||||||
|
"""Test extraction creation with invalid URL."""
|
||||||
|
url = "https://invalid.url"
|
||||||
|
user_id = 1
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
extraction_service, "_detect_service_info", return_value=None
|
||||||
|
):
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError, match="Unable to detect service information"
|
||||||
|
):
|
||||||
|
await extraction_service.create_extraction(url, user_id)
|
||||||
|
|
||||||
|
def test_ensure_unique_filename(self, extraction_service):
|
||||||
|
"""Test unique filename generation."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
temp_path = Path(temp_dir)
|
||||||
|
|
||||||
|
# Create original file
|
||||||
|
original_file = temp_path / "test.mp3"
|
||||||
|
original_file.touch()
|
||||||
|
|
||||||
|
# Test unique filename generation
|
||||||
|
result = extraction_service._ensure_unique_filename(original_file)
|
||||||
|
expected = temp_path / "test_1.mp3"
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
# Create the first duplicate and test again
|
||||||
|
expected.touch()
|
||||||
|
result = extraction_service._ensure_unique_filename(original_file)
|
||||||
|
expected_2 = temp_path / "test_2.mp3"
|
||||||
|
assert result == expected_2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_sound_record(self, extraction_service):
|
||||||
|
"""Test sound record creation."""
|
||||||
|
# Create temporary audio file
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
|
||||||
|
audio_path = Path(f.name)
|
||||||
|
f.write(b"fake audio data")
|
||||||
|
|
||||||
|
try:
|
||||||
|
extraction = Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
title="Test Video",
|
||||||
|
url="https://www.youtube.com/watch?v=test123",
|
||||||
|
user_id=1,
|
||||||
|
status="processing",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_sound = Sound(
|
||||||
|
id=1,
|
||||||
|
type="EXT",
|
||||||
|
name="Test Video",
|
||||||
|
filename=audio_path.name,
|
||||||
|
duration=240000,
|
||||||
|
size=1024,
|
||||||
|
hash="test_hash",
|
||||||
|
is_deletable=True,
|
||||||
|
is_music=True,
|
||||||
|
is_normalized=False,
|
||||||
|
play_count=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.extraction.get_audio_duration", return_value=240000
|
||||||
|
),
|
||||||
|
patch("app.services.extraction.get_file_size", return_value=1024),
|
||||||
|
patch(
|
||||||
|
"app.services.extraction.get_file_hash", return_value="test_hash"
|
||||||
|
),
|
||||||
|
):
|
||||||
|
|
||||||
|
extraction_service.sound_repo.create = AsyncMock(
|
||||||
|
return_value=mock_sound
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extraction_service._create_sound_record(
|
||||||
|
audio_path,
|
||||||
|
extraction.title,
|
||||||
|
extraction.service,
|
||||||
|
extraction.service_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.type == "EXT"
|
||||||
|
assert result.name == "Test Video"
|
||||||
|
assert result.is_deletable is True
|
||||||
|
assert result.is_music is True
|
||||||
|
assert result.is_normalized is False
|
||||||
|
|
||||||
|
finally:
|
||||||
|
audio_path.unlink()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_normalize_sound_success(self, extraction_service):
|
||||||
|
"""Test sound normalization."""
|
||||||
|
sound = Sound(
|
||||||
|
id=1,
|
||||||
|
type="EXT",
|
||||||
|
name="Test Sound",
|
||||||
|
filename="test.mp3",
|
||||||
|
duration=240000,
|
||||||
|
size=1024,
|
||||||
|
hash="test_hash",
|
||||||
|
is_normalized=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_normalizer = Mock()
|
||||||
|
mock_normalizer.normalize_sound = AsyncMock(
|
||||||
|
return_value={"status": "normalized"}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.extraction.SoundNormalizerService",
|
||||||
|
return_value=mock_normalizer,
|
||||||
|
):
|
||||||
|
# Should not raise exception
|
||||||
|
await extraction_service._normalize_sound(sound)
|
||||||
|
mock_normalizer.normalize_sound.assert_called_once_with(sound)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_normalize_sound_failure(self, extraction_service):
|
||||||
|
"""Test sound normalization failure."""
|
||||||
|
sound = Sound(
|
||||||
|
id=1,
|
||||||
|
type="EXT",
|
||||||
|
name="Test Sound",
|
||||||
|
filename="test.mp3",
|
||||||
|
duration=240000,
|
||||||
|
size=1024,
|
||||||
|
hash="test_hash",
|
||||||
|
is_normalized=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_normalizer = Mock()
|
||||||
|
mock_normalizer.normalize_sound = AsyncMock(
|
||||||
|
return_value={"status": "error", "error": "Test error"}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.extraction.SoundNormalizerService",
|
||||||
|
return_value=mock_normalizer,
|
||||||
|
):
|
||||||
|
# Should not raise exception even on failure
|
||||||
|
await extraction_service._normalize_sound(sound)
|
||||||
|
mock_normalizer.normalize_sound.assert_called_once_with(sound)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_extraction_by_id(self, extraction_service):
|
||||||
|
"""Test getting extraction by ID."""
|
||||||
|
extraction = Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
url="https://www.youtube.com/watch?v=test123",
|
||||||
|
user_id=1,
|
||||||
|
title="Test Video",
|
||||||
|
status="completed",
|
||||||
|
sound_id=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
extraction_service.extraction_repo.get_by_id = AsyncMock(
|
||||||
|
return_value=extraction
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extraction_service.get_extraction_by_id(1)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["id"] == 1
|
||||||
|
assert result["service"] == "youtube"
|
||||||
|
assert result["service_id"] == "test123"
|
||||||
|
assert result["title"] == "Test Video"
|
||||||
|
assert result["status"] == "completed"
|
||||||
|
assert result["sound_id"] == 42
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_extraction_by_id_not_found(self, extraction_service):
|
||||||
|
"""Test getting extraction by ID when not found."""
|
||||||
|
extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
result = await extraction_service.get_extraction_by_id(999)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_extractions(self, extraction_service):
|
||||||
|
"""Test getting user extractions."""
|
||||||
|
extractions = [
|
||||||
|
Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
url="https://www.youtube.com/watch?v=test123",
|
||||||
|
user_id=1,
|
||||||
|
title="Test Video 1",
|
||||||
|
status="completed",
|
||||||
|
sound_id=42,
|
||||||
|
),
|
||||||
|
Extraction(
|
||||||
|
id=2,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test456",
|
||||||
|
url="https://www.youtube.com/watch?v=test456",
|
||||||
|
user_id=1,
|
||||||
|
title="Test Video 2",
|
||||||
|
status="pending",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
extraction_service.extraction_repo.get_by_user = AsyncMock(
|
||||||
|
return_value=extractions
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extraction_service.get_user_extractions(1)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert result[0]["id"] == 1
|
||||||
|
assert result[0]["title"] == "Test Video 1"
|
||||||
|
assert result[1]["id"] == 2
|
||||||
|
assert result[1]["title"] == "Test Video 2"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_extractions(self, extraction_service):
|
||||||
|
"""Test getting pending extractions."""
|
||||||
|
pending_extractions = [
|
||||||
|
Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
url="https://www.youtube.com/watch?v=test123",
|
||||||
|
user_id=1,
|
||||||
|
title="Pending Video",
|
||||||
|
status="pending",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
extraction_service.extraction_repo.get_pending_extractions = AsyncMock(
|
||||||
|
return_value=pending_extractions
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await extraction_service.get_pending_extractions()
|
||||||
|
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["id"] == 1
|
||||||
|
assert result[0]["status"] == "pending"
|
||||||
298
tests/services/test_extraction_processor.py
Normal file
298
tests/services/test_extraction_processor.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
"""Tests for extraction background processor."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.extraction_processor import ExtractionProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractionProcessor:
|
||||||
|
"""Test extraction background processor."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def processor(self):
|
||||||
|
"""Create an extraction processor instance."""
|
||||||
|
# Use a custom processor instance to avoid affecting the global one
|
||||||
|
return ExtractionProcessor()
|
||||||
|
|
||||||
|
def test_init(self, processor):
|
||||||
|
"""Test processor initialization."""
|
||||||
|
assert processor.max_concurrent > 0
|
||||||
|
assert len(processor.running_extractions) == 0
|
||||||
|
assert processor.processing_lock is not None
|
||||||
|
assert processor.shutdown_event is not None
|
||||||
|
assert processor.processor_task is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_and_stop(self, processor):
|
||||||
|
"""Test starting and stopping the processor."""
|
||||||
|
# Mock the _process_queue method to avoid actual processing
|
||||||
|
with patch.object(processor, "_process_queue", new_callable=AsyncMock) as mock_process:
|
||||||
|
|
||||||
|
# Start the processor
|
||||||
|
await processor.start()
|
||||||
|
assert processor.processor_task is not None
|
||||||
|
assert not processor.processor_task.done()
|
||||||
|
|
||||||
|
# Stop the processor
|
||||||
|
await processor.stop()
|
||||||
|
assert processor.processor_task.done()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_already_running(self, processor):
|
||||||
|
"""Test starting processor when already running."""
|
||||||
|
with patch.object(processor, "_process_queue", new_callable=AsyncMock):
|
||||||
|
|
||||||
|
# Start first time
|
||||||
|
await processor.start()
|
||||||
|
first_task = processor.processor_task
|
||||||
|
|
||||||
|
# Start second time (should not create new task)
|
||||||
|
await processor.start()
|
||||||
|
assert processor.processor_task is first_task
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
await processor.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_queue_extraction(self, processor):
|
||||||
|
"""Test queuing an extraction."""
|
||||||
|
extraction_id = 123
|
||||||
|
|
||||||
|
await processor.queue_extraction(extraction_id)
|
||||||
|
# The extraction should not be in running_extractions yet
|
||||||
|
# (it gets added when actually started by the processor)
|
||||||
|
assert extraction_id not in processor.running_extractions
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_queue_extraction_already_running(self, processor):
|
||||||
|
"""Test queuing an extraction that's already running."""
|
||||||
|
extraction_id = 123
|
||||||
|
processor.running_extractions.add(extraction_id)
|
||||||
|
|
||||||
|
await processor.queue_extraction(extraction_id)
|
||||||
|
# Should still be in running extractions
|
||||||
|
assert extraction_id in processor.running_extractions
|
||||||
|
|
||||||
|
def test_get_status(self, processor):
|
||||||
|
"""Test getting processor status."""
|
||||||
|
status = processor.get_status()
|
||||||
|
|
||||||
|
assert "running" in status
|
||||||
|
assert "max_concurrent" in status
|
||||||
|
assert "currently_processing" in status
|
||||||
|
assert "processing_ids" in status
|
||||||
|
assert "available_slots" in status
|
||||||
|
|
||||||
|
assert status["max_concurrent"] == processor.max_concurrent
|
||||||
|
assert status["currently_processing"] == 0
|
||||||
|
assert status["available_slots"] == processor.max_concurrent
|
||||||
|
|
||||||
|
def test_get_status_with_running_extractions(self, processor):
|
||||||
|
"""Test getting processor status with running extractions."""
|
||||||
|
processor.running_extractions.add(123)
|
||||||
|
processor.running_extractions.add(456)
|
||||||
|
|
||||||
|
status = processor.get_status()
|
||||||
|
|
||||||
|
assert status["currently_processing"] == 2
|
||||||
|
assert status["available_slots"] == processor.max_concurrent - 2
|
||||||
|
assert 123 in status["processing_ids"]
|
||||||
|
assert 456 in status["processing_ids"]
|
||||||
|
|
||||||
|
def test_on_extraction_completed(self, processor):
|
||||||
|
"""Test extraction completion callback."""
|
||||||
|
extraction_id = 123
|
||||||
|
processor.running_extractions.add(extraction_id)
|
||||||
|
|
||||||
|
# Create a mock completed task
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_task.exception.return_value = None
|
||||||
|
|
||||||
|
processor._on_extraction_completed(extraction_id, mock_task)
|
||||||
|
|
||||||
|
# Should be removed from running extractions
|
||||||
|
assert extraction_id not in processor.running_extractions
|
||||||
|
|
||||||
|
def test_on_extraction_completed_with_exception(self, processor):
|
||||||
|
"""Test extraction completion callback with exception."""
|
||||||
|
extraction_id = 123
|
||||||
|
processor.running_extractions.add(extraction_id)
|
||||||
|
|
||||||
|
# Create a mock task with exception
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_task.exception.return_value = Exception("Test error")
|
||||||
|
|
||||||
|
processor._on_extraction_completed(extraction_id, mock_task)
|
||||||
|
|
||||||
|
# Should still be removed from running extractions
|
||||||
|
assert extraction_id not in processor.running_extractions
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_single_extraction_success(self, processor):
|
||||||
|
"""Test processing a single extraction successfully."""
|
||||||
|
extraction_id = 123
|
||||||
|
|
||||||
|
# Mock the extraction service
|
||||||
|
mock_service = Mock()
|
||||||
|
mock_service.process_extraction = AsyncMock(
|
||||||
|
return_value={"status": "completed", "id": extraction_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.AsyncSession"
|
||||||
|
) as mock_session_class,
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.ExtractionService",
|
||||||
|
return_value=mock_service,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session_class.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
await processor._process_single_extraction(extraction_id)
|
||||||
|
|
||||||
|
mock_service.process_extraction.assert_called_once_with(extraction_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_single_extraction_failure(self, processor):
|
||||||
|
"""Test processing a single extraction with failure."""
|
||||||
|
extraction_id = 123
|
||||||
|
|
||||||
|
# Mock the extraction service to raise an exception
|
||||||
|
mock_service = Mock()
|
||||||
|
mock_service.process_extraction = AsyncMock(side_effect=Exception("Test error"))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.AsyncSession"
|
||||||
|
) as mock_session_class,
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.ExtractionService",
|
||||||
|
return_value=mock_service,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session_class.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
# Should not raise exception (errors are logged)
|
||||||
|
await processor._process_single_extraction(extraction_id)
|
||||||
|
|
||||||
|
mock_service.process_extraction.assert_called_once_with(extraction_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_pending_extractions_no_slots(self, processor):
|
||||||
|
"""Test processing when no slots are available."""
|
||||||
|
# Fill all slots
|
||||||
|
for i in range(processor.max_concurrent):
|
||||||
|
processor.running_extractions.add(i)
|
||||||
|
|
||||||
|
# Mock extraction service
|
||||||
|
mock_service = Mock()
|
||||||
|
mock_service.get_pending_extractions = AsyncMock(
|
||||||
|
return_value=[{"id": 100, "status": "pending"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.AsyncSession"
|
||||||
|
) as mock_session_class,
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.ExtractionService",
|
||||||
|
return_value=mock_service,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session_class.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
await processor._process_pending_extractions()
|
||||||
|
|
||||||
|
# Should not have started any new extractions
|
||||||
|
assert 100 not in processor.running_extractions
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_pending_extractions_with_slots(self, processor):
|
||||||
|
"""Test processing when slots are available."""
|
||||||
|
# Mock extraction service
|
||||||
|
mock_service = Mock()
|
||||||
|
mock_service.get_pending_extractions = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"id": 100, "status": "pending"},
|
||||||
|
{"id": 101, "status": "pending"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.AsyncSession"
|
||||||
|
) as mock_session_class,
|
||||||
|
patch.object(processor, "_process_single_extraction", new_callable=AsyncMock) as mock_process,
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.ExtractionService",
|
||||||
|
return_value=mock_service,
|
||||||
|
),
|
||||||
|
patch("asyncio.create_task") as mock_create_task,
|
||||||
|
):
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session_class.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
# Mock task creation
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_create_task.return_value = mock_task
|
||||||
|
|
||||||
|
await processor._process_pending_extractions()
|
||||||
|
|
||||||
|
# Should have added extractions to running set
|
||||||
|
assert 100 in processor.running_extractions
|
||||||
|
assert 101 in processor.running_extractions
|
||||||
|
|
||||||
|
# Should have created tasks for both
|
||||||
|
assert mock_create_task.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_pending_extractions_respect_limit(self, processor):
|
||||||
|
"""Test that processing respects concurrency limit."""
|
||||||
|
# Set max concurrent to 1 for this test
|
||||||
|
processor.max_concurrent = 1
|
||||||
|
|
||||||
|
# Mock extraction service with multiple pending extractions
|
||||||
|
mock_service = Mock()
|
||||||
|
mock_service.get_pending_extractions = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"id": 100, "status": "pending"},
|
||||||
|
{"id": 101, "status": "pending"},
|
||||||
|
{"id": 102, "status": "pending"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.AsyncSession"
|
||||||
|
) as mock_session_class,
|
||||||
|
patch.object(processor, "_process_single_extraction", new_callable=AsyncMock) as mock_process,
|
||||||
|
patch(
|
||||||
|
"app.services.extraction_processor.ExtractionService",
|
||||||
|
return_value=mock_service,
|
||||||
|
),
|
||||||
|
patch("asyncio.create_task") as mock_create_task,
|
||||||
|
):
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session_class.return_value.__aenter__.return_value = mock_session
|
||||||
|
|
||||||
|
# Mock task creation
|
||||||
|
mock_task = Mock()
|
||||||
|
mock_create_task.return_value = mock_task
|
||||||
|
|
||||||
|
await processor._process_pending_extractions()
|
||||||
|
|
||||||
|
# Should only have started one extraction (due to limit)
|
||||||
|
assert len(processor.running_extractions) == 1
|
||||||
|
assert mock_create_task.call_count == 1
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Tests for OAuth service."""
|
"""Tests for OAuth service."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -117,7 +117,7 @@ class TestGoogleOAuthProvider:
|
|||||||
"picture": "https://example.com/avatar.jpg",
|
"picture": "https://example.com/avatar.jpg",
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch("httpx.AsyncClient.get") as mock_get:
|
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.json.return_value = mock_response_data
|
mock_response.json.return_value = mock_response_data
|
||||||
@@ -162,7 +162,7 @@ class TestGitHubOAuthProvider:
|
|||||||
{"email": "secondary@example.com", "primary": False, "verified": True},
|
{"email": "secondary@example.com", "primary": False, "verified": True},
|
||||||
]
|
]
|
||||||
|
|
||||||
with patch("httpx.AsyncClient.get") as mock_get:
|
with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
|
||||||
# Mock user profile response
|
# Mock user profile response
|
||||||
mock_user_response = Mock()
|
mock_user_response = Mock()
|
||||||
mock_user_response.status_code = 200
|
mock_user_response.status_code = 200
|
||||||
@@ -174,7 +174,7 @@ class TestGitHubOAuthProvider:
|
|||||||
mock_emails_response.json.return_value = mock_emails_data
|
mock_emails_response.json.return_value = mock_emails_data
|
||||||
|
|
||||||
# Return different responses based on URL
|
# Return different responses based on URL
|
||||||
def side_effect(url, **kwargs):
|
async def side_effect(url, **kwargs):
|
||||||
if "user/emails" in str(url):
|
if "user/emails" in str(url):
|
||||||
return mock_emails_response
|
return mock_emails_response
|
||||||
return mock_user_response
|
return mock_user_response
|
||||||
|
|||||||
Reference in New Issue
Block a user