feat: Implement background extraction processor with concurrency control

- Added `ExtractionProcessor` class to handle extraction queue processing in the background.
- Implemented methods for starting, stopping, and queuing extractions with concurrency limits.
- Integrated logging for monitoring the processor's status and actions.
- Created tests for the extraction processor to ensure functionality and error handling.

test: Add unit tests for extraction API endpoints

- Created tests for successful extraction creation, authentication checks, and processor status retrieval.
- Ensured proper responses for authenticated and unauthenticated requests.

test: Implement unit tests for extraction repository

- Added tests for creating, retrieving, and updating extractions in the repository.
- Mocked database interactions to validate repository behavior without actual database access.

test: Add comprehensive tests for extraction service

- Developed tests for extraction creation, service detection, and sound record creation.
- Included tests for handling duplicate extractions and invalid URLs.

test: Add unit tests for extraction background processor

- Created tests for the `ExtractionProcessor` class to validate its behavior under various conditions.
- Ensured proper handling of extraction queuing, processing, and completion callbacks.

fix: Update OAuth service tests to use AsyncMock

- Modified OAuth provider tests to use `AsyncMock` for mocking asynchronous HTTP requests.
This commit is contained in:
JSC
2025-07-29 01:06:29 +02:00
parent c993230f98
commit 9b5f83eef0
11 changed files with 1860 additions and 4 deletions

View File

@@ -8,6 +8,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User from app.models.user import User
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService from app.services.sound_scanner import ScanResults, SoundScannerService
@@ -28,6 +30,13 @@ async def get_sound_normalizer_service(
return SoundNormalizerService(session) return SoundNormalizerService(session)
async def get_extraction_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
    """FastAPI dependency: build an ExtractionService bound to the request's DB session."""
    return ExtractionService(session)
# SCAN # SCAN
@router.post("/scan") @router.post("/scan")
async def scan_sounds( async def scan_sounds(
@@ -233,3 +242,110 @@ async def normalize_sound_by_id(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sound: {e!s}", detail=f"Failed to normalize sound: {e!s}",
) from e ) from e
# EXTRACT


@router.post("/extract")
async def create_extraction(
    url: str,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, ExtractionInfo | str]:
    """Create a new extraction job for a URL and queue it for background processing.

    Returns a confirmation message plus the created extraction's info.
    Raises 401 when the user has no ID, 400 for invalid/duplicate URLs,
    and 500 for unexpected failures.
    """
    try:
        if current_user.id is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User ID not available",
            )

        extraction_info = await extraction_service.create_extraction(
            url, current_user.id
        )

        # Queue the extraction for background processing
        await extraction_processor.queue_extraction(extraction_info["id"])

        return {
            "message": "Extraction queued successfully",
            "extraction": extraction_info,
        }
    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 401 above) unchanged; without
        # this, the generic handler below would convert them into a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create extraction: {e!s}",
        ) from e
@router.get("/extract/status")
async def get_extraction_processor_status(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
) -> dict:
"""Get the status of the extraction processor."""
# Only allow admins to see processor status
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can view processor status",
)
return extraction_processor.get_status()
@router.get("/extract/{extraction_id}")
async def get_extraction(
extraction_id: int,
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
"""Get extraction information by ID."""
try:
extraction_info = await extraction_service.get_extraction_by_id(extraction_id)
if not extraction_info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Extraction {extraction_id} not found",
)
return extraction_info
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extraction: {e!s}",
) from e
@router.get("/extract")
async def get_user_extractions(
current_user: Annotated[User, Depends(get_current_active_user_flexible)],
extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, list[ExtractionInfo]]:
"""Get all extractions for the current user."""
try:
if current_user.id is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User ID not available",
)
extractions = await extraction_service.get_user_extractions(current_user.id)
return {
"extractions": extractions,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get extractions: {e!s}",
) from e

View File

@@ -52,5 +52,12 @@ class Settings(BaseSettings):
NORMALIZED_AUDIO_BITRATE: str = "256k" NORMALIZED_AUDIO_BITRATE: str = "256k"
NORMALIZED_AUDIO_PASSES: int = 2 # 1 for one-pass, 2 for two-pass NORMALIZED_AUDIO_PASSES: int = 2 # 1 for one-pass, 2 for two-pass
# Audio Extraction Configuration
EXTRACTION_AUDIO_FORMAT: str = "mp3"
EXTRACTION_AUDIO_BITRATE: str = "256k"
EXTRACTION_TEMP_DIR: str = "sounds/temp"
EXTRACTION_THUMBNAILS_DIR: str = "sounds/originals/extracted/thumbnails"
EXTRACTION_MAX_CONCURRENT: int = 2 # Maximum concurrent extractions
settings = Settings() settings = Settings()

View File

@@ -9,6 +9,7 @@ from app.api import api_router
from app.core.database import init_db from app.core.database import init_db
from app.core.logging import get_logger, setup_logging from app.core.logging import get_logger, setup_logging
from app.middleware.logging import LoggingMiddleware from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor
from app.services.socket import socket_manager from app.services.socket import socket_manager
@@ -22,10 +23,18 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
await init_db() await init_db()
logger.info("Database initialized") logger.info("Database initialized")
# Start the extraction processor
await extraction_processor.start()
logger.info("Extraction processor started")
yield yield
logger.info("Shutting down application") logger.info("Shutting down application")
# Stop the extraction processor
await extraction_processor.stop()
logger.info("Extraction processor stopped")
def create_app(): def create_app():
"""Create and configure the FastAPI application.""" """Create and configure the FastAPI application."""

View File

@@ -0,0 +1,82 @@
"""Extraction repository for database operations."""
from sqlalchemy import desc
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction
class ExtractionRepository:
    """Data-access layer for Extraction rows.

    Thin wrapper around an async session; every method commits or queries
    through the session handed in at construction time.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Store the async session used by all repository operations."""
        self.session = session

    async def create(self, extraction_data: dict) -> Extraction:
        """Persist a new extraction built from *extraction_data* and return it."""
        record = Extraction(**extraction_data)
        self.session.add(record)
        await self.session.commit()
        # Refresh so DB-generated fields (id, timestamps) are populated.
        await self.session.refresh(record)
        return record

    async def get_by_id(self, extraction_id: int) -> Extraction | None:
        """Return the extraction with the given primary key, or None."""
        statement = select(Extraction).where(Extraction.id == extraction_id)
        rows = await self.session.exec(statement)
        return rows.first()

    async def get_by_service_and_id(
        self, service: str, service_id: str
    ) -> Extraction | None:
        """Return the extraction matching (service, service_id), or None."""
        statement = select(Extraction).where(
            Extraction.service == service, Extraction.service_id == service_id
        )
        rows = await self.session.exec(statement)
        return rows.first()

    async def get_by_user(self, user_id: int) -> list[Extraction]:
        """Return every extraction owned by *user_id*, newest first."""
        statement = (
            select(Extraction)
            .where(Extraction.user_id == user_id)
            .order_by(desc(Extraction.created_at))
        )
        rows = await self.session.exec(statement)
        return list(rows.all())

    async def get_pending_extractions(self) -> list[Extraction]:
        """Return all extractions with status 'pending', oldest first (FIFO)."""
        statement = (
            select(Extraction)
            .where(Extraction.status == "pending")
            .order_by(Extraction.created_at)
        )
        rows = await self.session.exec(statement)
        return list(rows.all())

    async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
        """Apply *update_data* as attribute assignments, commit, and return the row."""
        for field, value in update_data.items():
            setattr(extraction, field, value)
        await self.session.commit()
        await self.session.refresh(extraction)
        return extraction

    async def delete(self, extraction: Extraction) -> None:
        """Delete the given extraction and commit."""
        await self.session.delete(extraction)
        await self.session.commit()

    async def get_extractions_by_status(self, status: str) -> list[Extraction]:
        """Return all extractions with the given status, newest first."""
        statement = (
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(desc(Extraction.created_at))
        )
        rows = await self.session.exec(statement)
        return list(rows.all())

517
app/services/extraction.py Normal file
View File

@@ -0,0 +1,517 @@
"""Extraction service for audio extraction from external services using yt-dlp."""
import shutil
from pathlib import Path
from typing import TypedDict
import yt_dlp
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.models.extraction import Extraction
from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
logger = get_logger(__name__)
class ExtractionInfo(TypedDict):
    """Serializable snapshot of an Extraction row, as returned by the service/API."""

    id: int  # database primary key (falls back to 0 defensively when unset)
    url: str  # original URL submitted by the user
    service: str  # detected source service name, e.g. "youtube"
    service_id: str  # the source service's own ID for the media item
    title: str | None  # media title reported by yt-dlp, if available
    status: str  # lifecycle state: "pending" / "processing" / "completed" / "failed"
    error: str | None  # failure message when status == "failed"
    sound_id: int | None  # linked Sound row once extraction completes
class ExtractionService:
    """Service for extracting audio from external services using yt-dlp.

    Pipeline: detect the source service from a URL, record an Extraction row,
    download audio (+thumbnail) into a temp dir, move the files to their final
    locations, create a Sound row, normalize it, and mark the extraction done.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the extraction service.

        Args:
            session: Async DB session shared with the underlying repositories.
        """
        self.session = session
        self.extraction_repo = ExtractionRepository(session)
        self.sound_repo = SoundRepository(session)

        # Ensure required directories exist
        self._ensure_directories()

    def _ensure_directories(self) -> None:
        """Ensure all required directories exist (temp, output, thumbnails)."""
        directories = [
            settings.EXTRACTION_TEMP_DIR,
            "sounds/originals/extracted",
            settings.EXTRACTION_THUMBNAILS_DIR,
        ]

        for directory in directories:
            Path(directory).mkdir(parents=True, exist_ok=True)
            logger.debug("Ensured directory exists: %s", directory)

    async def create_extraction(self, url: str, user_id: int) -> ExtractionInfo:
        """Create a new extraction job in 'pending' state.

        Raises:
            ValueError: when service detection fails or an extraction for the
                same (service, service_id) already exists.
        """
        logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)

        try:
            # First, detect service and service_id using yt-dlp
            service_info = self._detect_service_info(url)
            if not service_info:
                raise ValueError("Unable to detect service information from URL")

            service = service_info["service"]
            service_id = service_info["service_id"]
            title = service_info.get("title")

            logger.info(
                "Detected service: %s, service_id: %s, title: %s",
                service,
                service_id,
                title,
            )

            # Check if extraction already exists (dedupe by service + service_id)
            existing = await self.extraction_repo.get_by_service_and_id(
                service, service_id
            )
            if existing:
                error_msg = f"Extraction already exists for {service}:{service_id}"
                logger.warning(error_msg)
                raise ValueError(error_msg)

            # Create the extraction record
            extraction_data = {
                "url": url,
                "user_id": user_id,
                "service": service,
                "service_id": service_id,
                "title": title,
                "status": "pending",
            }

            extraction = await self.extraction_repo.create(extraction_data)
            logger.info("Created extraction with ID: %d", extraction.id)

            return {
                "id": extraction.id or 0,  # Should never be None for created extraction
                "url": extraction.url,
                "service": extraction.service,
                "service_id": extraction.service_id,
                "title": extraction.title,
                "status": extraction.status,
                "error": extraction.error,
                "sound_id": extraction.sound_id,
            }

        except Exception:
            logger.exception("Failed to create extraction for URL: %s", url)
            raise

    def _detect_service_info(self, url: str) -> dict | None:
        """Detect service information from URL using yt-dlp (no download).

        Returns a dict with service/service_id/title/duration/uploader/description,
        or None when yt-dlp cannot extract info.
        NOTE(review): this is a blocking network call inside a sync method —
        callers run it from async code; confirm this is acceptable latency-wise.
        """
        try:
            # Configure yt-dlp for info extraction only
            ydl_opts = {
                "quiet": True,
                "no_warnings": True,
                "extract_flat": False,
            }

            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                # Extract info without downloading
                info = ydl.extract_info(url, download=False)

                if not info:
                    return None

                # Map extractor names to our service names; unknown extractors
                # fall through with their lowercase yt-dlp extractor name.
                extractor_map = {
                    "youtube": "youtube",
                    "dailymotion": "dailymotion",
                    "vimeo": "vimeo",
                    "soundcloud": "soundcloud",
                    "twitter": "twitter",
                    "tiktok": "tiktok",
                    "instagram": "instagram",
                }

                extractor = info.get("extractor", "").lower()
                service = extractor_map.get(extractor, extractor)

                return {
                    "service": service,
                    "service_id": str(info.get("id", "")),
                    "title": info.get("title"),
                    "duration": info.get("duration"),
                    "uploader": info.get("uploader"),
                    "description": info.get("description"),
                }

        except Exception:
            logger.exception("Failed to detect service info for URL: %s", url)
            return None

    async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
        """Run the full extraction pipeline for a pending extraction.

        Returns an ExtractionInfo with status "completed" or "failed"; the
        failure path is reported in the return value, not raised.

        Raises:
            ValueError: when the extraction does not exist or is not pending.
        """
        extraction = await self.extraction_repo.get_by_id(extraction_id)
        if not extraction:
            raise ValueError(f"Extraction {extraction_id} not found")

        if extraction.status != "pending":
            raise ValueError(f"Extraction {extraction_id} is not pending")

        # Store all needed values early to avoid session detachment issues
        user_id = extraction.user_id
        extraction_url = extraction.url
        extraction_title = extraction.title
        extraction_service = extraction.service
        extraction_service_id = extraction.service_id

        logger.info("Processing extraction %d: %s", extraction_id, extraction_url)

        try:
            # Update status to processing
            await self.extraction_repo.update(extraction, {"status": "processing"})

            # Extract audio and thumbnail
            audio_file, thumbnail_file = await self._extract_media(
                extraction_id, extraction_url
            )

            # Move files to final locations
            # NOTE(review): final_thumbnail_path is not used after this call;
            # the thumbnail is only moved/logged, never linked to the Sound.
            final_audio_path, final_thumbnail_path = (
                await self._move_files_to_final_location(
                    audio_file,
                    thumbnail_file,
                    extraction_title,
                    extraction_service,
                    extraction_service_id,
                )
            )

            # Create Sound record
            sound = await self._create_sound_record(
                final_audio_path,
                extraction_title,
                extraction_service,
                extraction_service_id,
            )

            # Store sound_id early to avoid session detachment issues
            sound_id = sound.id

            # Normalize the sound (best-effort; failures logged, not raised)
            await self._normalize_sound(sound)

            # Add to main playlist (currently a logged placeholder)
            await self._add_to_main_playlist(sound, user_id)

            # Update extraction with success
            await self.extraction_repo.update(
                extraction,
                {
                    "status": "completed",
                    "sound_id": sound_id,
                    "error": None,
                },
            )

            logger.info("Successfully processed extraction %d", extraction_id)

            return {
                "id": extraction_id,
                "url": extraction_url,
                "service": extraction_service,
                "service_id": extraction_service_id,
                "title": extraction_title,
                "status": "completed",
                "error": None,
                "sound_id": sound_id,
            }

        except Exception as e:
            error_msg = str(e)
            logger.exception(
                "Failed to process extraction %d: %s", extraction_id, error_msg
            )

            # Update extraction with error
            await self.extraction_repo.update(
                extraction,
                {
                    "status": "failed",
                    "error": error_msg,
                },
            )

            return {
                "id": extraction_id,
                "url": extraction_url,
                "service": extraction_service,
                "service_id": extraction_service_id,
                "title": extraction_title,
                "status": "failed",
                "error": error_msg,
                "sound_id": None,
            }

    async def _extract_media(
        self, extraction_id: int, extraction_url: str
    ) -> tuple[Path, Path | None]:
        """Extract audio and thumbnail using yt-dlp into the temp directory.

        Returns (audio_file, thumbnail_file_or_None). Raises RuntimeError when
        yt-dlp fails or produces no audio file.
        """
        temp_dir = Path(settings.EXTRACTION_TEMP_DIR)

        # Create unique filename based on extraction ID
        output_template = str(
            temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s"
        )

        # Configure yt-dlp options
        # NOTE(review): "extractaudio"/"audioformat"/"audioquality" look like
        # CLI-style option names; the FFmpegExtractAudio postprocessor below is
        # what actually drives the conversion — confirm the extra keys are needed.
        ydl_opts = {
            "format": "bestaudio/best",
            "outtmpl": output_template,
            "extractaudio": True,
            "audioformat": settings.EXTRACTION_AUDIO_FORMAT,
            "audioquality": settings.EXTRACTION_AUDIO_BITRATE,
            "writethumbnail": True,
            "writeinfojson": False,
            "writeautomaticsub": False,
            "writesubtitles": False,
            "postprocessors": [
                {
                    "key": "FFmpegExtractAudio",
                    "preferredcodec": settings.EXTRACTION_AUDIO_FORMAT,
                    "preferredquality": settings.EXTRACTION_AUDIO_BITRATE.rstrip("k"),
                },
            ],
        }

        try:
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                # Download and extract
                ydl.download([extraction_url])

            # Find the extracted files by the extraction-ID prefix
            audio_files = list(
                temp_dir.glob(
                    f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}"
                )
            )
            thumbnail_files = (
                list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
                + list(temp_dir.glob(f"extraction_{extraction_id}_*.jpg"))
                + list(temp_dir.glob(f"extraction_{extraction_id}_*.png"))
            )

            if not audio_files:
                raise RuntimeError("No audio file was created during extraction")

            audio_file = audio_files[0]
            thumbnail_file = thumbnail_files[0] if thumbnail_files else None

            logger.info(
                "Extracted audio: %s, thumbnail: %s",
                audio_file,
                thumbnail_file or "None",
            )

            return audio_file, thumbnail_file

        except Exception as e:
            logger.exception("yt-dlp extraction failed for %s", extraction_url)
            raise RuntimeError(f"Audio extraction failed: {e}") from e

    async def _move_files_to_final_location(
        self,
        audio_file: Path,
        thumbnail_file: Path | None,
        title: str | None,
        service: str,
        service_id: str,
    ) -> tuple[Path, Path | None]:
        """Move extracted files from the temp dir to their final locations.

        Filenames are sanitized from the title (or service_service_id) and
        de-duplicated with a numeric suffix when a file already exists.
        """
        # Generate clean filename based on title and service
        safe_title = self._sanitize_filename(title or f"{service}_{service_id}")

        # Move audio file
        final_audio_path = (
            Path("sounds/originals/extracted")
            / f"{safe_title}.{settings.EXTRACTION_AUDIO_FORMAT}"
        )
        final_audio_path = self._ensure_unique_filename(final_audio_path)
        shutil.move(str(audio_file), str(final_audio_path))
        logger.info("Moved audio file to: %s", final_audio_path)

        # Move thumbnail file if it exists
        final_thumbnail_path = None
        if thumbnail_file:
            thumbnail_ext = thumbnail_file.suffix
            final_thumbnail_path = (
                Path(settings.EXTRACTION_THUMBNAILS_DIR)
                / f"{safe_title}{thumbnail_ext}"
            )
            final_thumbnail_path = self._ensure_unique_filename(final_thumbnail_path)
            shutil.move(str(thumbnail_file), str(final_thumbnail_path))
            logger.info("Moved thumbnail file to: %s", final_thumbnail_path)

        return final_audio_path, final_thumbnail_path

    def _sanitize_filename(self, filename: str) -> str:
        """Sanitize filename for filesystem use.

        Replaces characters invalid on common filesystems with underscores,
        trims whitespace, caps length at 100 chars, and never returns empty.
        """
        # Remove or replace problematic characters
        invalid_chars = '<>:"/\\|?*'
        for char in invalid_chars:
            filename = filename.replace(char, "_")

        # Limit length and remove leading/trailing spaces
        filename = filename.strip()[:100]

        return filename or "untitled"

    def _ensure_unique_filename(self, filepath: Path) -> Path:
        """Ensure filename is unique by appending _1, _2, ... if needed."""
        if not filepath.exists():
            return filepath

        stem = filepath.stem
        suffix = filepath.suffix
        parent = filepath.parent

        counter = 1
        while True:
            new_path = parent / f"{stem}_{counter}{suffix}"
            if not new_path.exists():
                return new_path
            counter += 1

    async def _create_sound_record(
        self, audio_path: Path, title: str | None, service: str, service_id: str
    ) -> Sound:
        """Create a Sound record for the extracted audio file."""
        # Get audio metadata
        duration = get_audio_duration(audio_path)
        size = get_file_size(audio_path)
        file_hash = get_file_hash(audio_path)

        # Create sound data
        sound_data = {
            "type": "EXT",
            "name": title or f"{service}_{service_id}",
            "filename": audio_path.name,
            "duration": duration,
            "size": size,
            "hash": file_hash,
            "is_deletable": True,  # Extracted sounds can be deleted
            "is_music": True,  # Assume extracted content is music
            "is_normalized": False,
            "play_count": 0,
        }

        sound = await self.sound_repo.create(sound_data)
        # NOTE(review): %d would raise if sound.id were None here — confirm the
        # repository's refresh always populates the id after commit.
        logger.info("Created sound record with ID: %d", sound.id)

        return sound

    async def _normalize_sound(self, sound: Sound) -> None:
        """Normalize the extracted sound (best-effort).

        Normalization failures are logged but never fail the extraction.
        """
        try:
            normalizer_service = SoundNormalizerService(self.session)
            result = await normalizer_service.normalize_sound(sound)

            if result["status"] == "error":
                logger.warning(
                    "Failed to normalize sound %d: %s",
                    sound.id,
                    result.get("error"),
                )
            else:
                logger.info("Successfully normalized sound %d", sound.id)

        except Exception as e:
            logger.exception("Error normalizing sound %d: %s", sound.id, e)
            # Don't fail the extraction if normalization fails

    async def _add_to_main_playlist(self, sound: Sound, user_id: int) -> None:
        """Add the sound to the user's main playlist.

        Currently a placeholder that only logs the intended action.
        """
        try:
            # This is a placeholder - implement based on your playlist logic
            # For now, we'll just log that we would add it to the main playlist
            logger.info(
                "Would add sound %d to main playlist for user %d",
                sound.id,
                user_id,
            )

        except Exception as e:
            logger.exception(
                "Error adding sound %d to main playlist for user %d: %s",
                sound.id,
                user_id,
                e,
            )
            # Don't fail the extraction if playlist addition fails

    async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
        """Return extraction info by ID, or None when not found."""
        extraction = await self.extraction_repo.get_by_id(extraction_id)
        if not extraction:
            return None

        return {
            "id": extraction.id or 0,  # Should never be None for existing extraction
            "url": extraction.url,
            "service": extraction.service,
            "service_id": extraction.service_id,
            "title": extraction.title,
            "status": extraction.status,
            "error": extraction.error,
            "sound_id": extraction.sound_id,
        }

    async def get_user_extractions(self, user_id: int) -> list[ExtractionInfo]:
        """Return all extractions for a user as plain ExtractionInfo dicts."""
        extractions = await self.extraction_repo.get_by_user(user_id)
        return [
            {
                "id": extraction.id
                or 0,  # Should never be None for existing extraction
                "url": extraction.url,
                "service": extraction.service,
                "service_id": extraction.service_id,
                "title": extraction.title,
                "status": extraction.status,
                "error": extraction.error,
                "sound_id": extraction.sound_id,
            }
            for extraction in extractions
        ]

    async def get_pending_extractions(self) -> list[ExtractionInfo]:
        """Return all pending extractions (used by the background processor)."""
        extractions = await self.extraction_repo.get_pending_extractions()
        return [
            {
                "id": extraction.id
                or 0,  # Should never be None for existing extraction
                "url": extraction.url,
                "service": extraction.service,
                "service_id": extraction.service_id,
                "title": extraction.title,
                "status": extraction.status,
                "error": extraction.error,
                "sound_id": extraction.sound_id,
            }
            for extraction in extractions
        ]

View File

@@ -0,0 +1,196 @@
"""Background extraction processor for handling extraction queue."""
import asyncio
from typing import Set
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.database import engine
from app.core.logging import get_logger
from app.services.extraction import ExtractionService
logger = get_logger(__name__)
class ExtractionProcessor:
    """Background processor for handling extraction queue with concurrency control.

    A single polling task (started via start()) periodically loads pending
    extractions from the database and runs up to ``max_concurrent`` of them
    in parallel, each in its own asyncio task with its own DB session.
    """

    def __init__(self) -> None:
        """Initialize the extraction processor."""
        # Maximum number of extractions processed simultaneously.
        self.max_concurrent = settings.EXTRACTION_MAX_CONCURRENT
        # IDs currently being processed; mutated under processing_lock.
        self.running_extractions: set[int] = set()
        self.processing_lock = asyncio.Lock()
        self.shutdown_event = asyncio.Event()
        self.processor_task: asyncio.Task | None = None

        logger.info(
            "Initialized extraction processor with max concurrent: %d",
            self.max_concurrent,
        )

    async def start(self) -> None:
        """Start the background extraction processor (idempotent)."""
        if self.processor_task and not self.processor_task.done():
            logger.warning("Extraction processor is already running")
            return

        self.shutdown_event.clear()
        self.processor_task = asyncio.create_task(self._process_queue())
        logger.info("Started extraction processor")

    async def stop(self) -> None:
        """Stop the background extraction processor.

        Signals shutdown, waits up to 30s for a graceful stop, then cancels.
        """
        logger.info("Stopping extraction processor...")
        self.shutdown_event.set()

        if self.processor_task and not self.processor_task.done():
            try:
                await asyncio.wait_for(self.processor_task, timeout=30.0)
            except asyncio.TimeoutError:
                logger.warning(
                    "Extraction processor did not stop gracefully, cancelling..."
                )
                self.processor_task.cancel()
                try:
                    await self.processor_task
                except asyncio.CancelledError:
                    pass

        logger.info("Extraction processor stopped")

    async def queue_extraction(self, extraction_id: int) -> None:
        """Queue an extraction for processing.

        No explicit queue exists: pending extractions are discovered from the
        database by the polling loop, so this only logs and de-dupes.
        """
        async with self.processing_lock:
            if extraction_id not in self.running_extractions:
                logger.info("Queued extraction %d for processing", extraction_id)
                # The processor will pick it up on the next cycle
            else:
                logger.warning(
                    "Extraction %d is already being processed", extraction_id
                )

    async def _process_queue(self) -> None:
        """Main processing loop that handles the extraction queue.

        Polls every 5s; on unexpected errors backs off 10s to avoid tight loops.
        """
        logger.info("Starting extraction queue processor")

        while not self.shutdown_event.is_set():
            try:
                await self._process_pending_extractions()

                # Wait before checking for new extractions
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
                    break  # Shutdown requested
                except asyncio.TimeoutError:
                    continue  # Continue processing

            except Exception as e:
                logger.exception("Error in extraction queue processor: %s", e)
                # Wait a bit before retrying to avoid tight error loops
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
                    break  # Shutdown requested
                except asyncio.TimeoutError:
                    continue

        logger.info("Extraction queue processor stopped")

    async def _process_pending_extractions(self) -> None:
        """Process pending extractions up to the concurrency limit."""
        async with self.processing_lock:
            # Check how many slots are available
            available_slots = self.max_concurrent - len(self.running_extractions)
            if available_slots <= 0:
                return  # No available slots

            # Get pending extractions from database (fresh session per poll)
            async with AsyncSession(engine) as session:
                extraction_service = ExtractionService(session)
                pending_extractions = await extraction_service.get_pending_extractions()

            # Filter out extractions that are already being processed
            available_extractions = [
                ext
                for ext in pending_extractions
                if ext["id"] not in self.running_extractions
            ]

            # Start processing up to available slots
            extractions_to_start = available_extractions[:available_slots]

            for extraction_info in extractions_to_start:
                extraction_id = extraction_info["id"]
                self.running_extractions.add(extraction_id)

                # Start processing this extraction in the background; the
                # eid=extraction_id default binds the loop variable eagerly.
                task = asyncio.create_task(
                    self._process_single_extraction(extraction_id)
                )
                task.add_done_callback(
                    lambda t, eid=extraction_id: self._on_extraction_completed(
                        eid,
                        t,
                    )
                )

                logger.info(
                    "Started processing extraction %d (%d/%d slots used)",
                    extraction_id,
                    len(self.running_extractions),
                    self.max_concurrent,
                )

    async def _process_single_extraction(self, extraction_id: int) -> None:
        """Process a single extraction in its own DB session."""
        try:
            logger.info("Processing extraction %d", extraction_id)

            async with AsyncSession(engine) as session:
                extraction_service = ExtractionService(session)
                result = await extraction_service.process_extraction(extraction_id)

                logger.info(
                    "Completed extraction %d with status: %s",
                    extraction_id,
                    result["status"],
                )

        except Exception as e:
            logger.exception("Error processing extraction %d: %s", extraction_id, e)

    def _on_extraction_completed(self, extraction_id: int, task: asyncio.Task) -> None:
        """Callback when an extraction task is completed."""
        # Remove from running set so the slot becomes available again
        self.running_extractions.discard(extraction_id)

        # task.exception() raises CancelledError when the task was cancelled
        # (e.g. during shutdown), so cancellation must be checked first.
        if task.cancelled():
            logger.warning("Extraction %d was cancelled", extraction_id)
        elif task.exception():
            logger.error(
                "Extraction %d completed with exception: %s",
                extraction_id,
                task.exception(),
            )
        else:
            logger.info(
                "Extraction %d completed successfully (%d/%d slots used)",
                extraction_id,
                len(self.running_extractions),
                self.max_concurrent,
            )

    def get_status(self) -> dict:
        """Get the current status of the extraction processor."""
        return {
            "running": self.processor_task is not None
            and not self.processor_task.done(),
            "max_concurrent": self.max_concurrent,
            "currently_processing": len(self.running_extractions),
            "processing_ids": list(self.running_extractions),
            "available_slots": self.max_concurrent - len(self.running_extractions),
        }
# Global extraction processor instance
# Module-level singleton shared by the API layer (queue_extraction) and the
# application lifespan hooks (start/stop); created at import time.
extraction_processor = ExtractionProcessor()

View File

@@ -0,0 +1,95 @@
"""Tests for extraction API endpoints."""
from unittest.mock import AsyncMock, Mock
import pytest
import pytest_asyncio
from httpx import AsyncClient
from app.models.user import User
class TestExtractionEndpoints:
    """HTTP-level tests for the /sounds/extract API endpoints."""

    @pytest.mark.asyncio
    async def test_create_extraction_success(
        self, test_client: AsyncClient, auth_cookies: dict[str, str]
    ):
        """Authenticated POST /extract reaches the handler (service not mocked)."""
        # Attach cookies on the client itself to avoid the per-request
        # cookies deprecation warning.
        test_client.cookies.update(auth_cookies)

        resp = await test_client.post(
            "/api/v1/sounds/extract",
            params={"url": "https://www.youtube.com/watch?v=test"},
        )

        # The extraction service is not mocked, so any non-auth outcome is
        # acceptable — the point is getting past authentication.
        assert resp.status_code in {200, 400, 500}

    @pytest.mark.asyncio
    async def test_create_extraction_unauthenticated(self, test_client: AsyncClient):
        """POST /extract without credentials is rejected with 401."""
        resp = await test_client.post(
            "/api/v1/sounds/extract",
            params={"url": "https://www.youtube.com/watch?v=test"},
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_get_extraction_unauthenticated(self, test_client: AsyncClient):
        """GET /extract/{id} without credentials is rejected with 401."""
        resp = await test_client.get("/api/v1/sounds/extract/1")
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_get_processor_status_admin(
        self, test_client: AsyncClient, admin_cookies: dict[str, str]
    ):
        """Admins may read the processor status payload."""
        # Attach admin cookies on the client itself (per-request is deprecated).
        test_client.cookies.update(admin_cookies)

        resp = await test_client.get("/api/v1/sounds/extract/status")

        assert resp.status_code == 200
        payload = resp.json()
        assert "running" in payload
        assert "max_concurrent" in payload

    @pytest.mark.asyncio
    async def test_get_processor_status_non_admin(
        self, test_client: AsyncClient, auth_cookies: dict[str, str]
    ):
        """Non-admin users get a 403 from the status endpoint."""
        test_client.cookies.update(auth_cookies)

        resp = await test_client.get("/api/v1/sounds/extract/status")

        assert resp.status_code == 403
        assert "Only administrators" in resp.json()["detail"]

    @pytest.mark.asyncio
    async def test_get_user_extractions(
        self, test_client: AsyncClient, auth_cookies: dict[str, str]
    ):
        """GET /extract lists the (empty) extraction set for the test user."""
        test_client.cookies.update(auth_cookies)

        resp = await test_client.get("/api/v1/sounds/extract")

        assert resp.status_code == 200
        payload = resp.json()
        assert "extractions" in payload
        assert isinstance(payload["extractions"], list)

View File

@@ -0,0 +1,128 @@
"""Tests for extraction repository."""
from unittest.mock import AsyncMock, Mock
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction
from app.repositories.extraction import ExtractionRepository
class TestExtractionRepository:
    """Unit tests for ExtractionRepository.

    All database access goes through a ``Mock(spec=AsyncSession)``, so these
    tests validate the repository's call pattern (add/commit/refresh, exec)
    without a real database.
    """

    @pytest.fixture
    def mock_session(self):
        """Create a mock session standing in for an AsyncSession."""
        return Mock(spec=AsyncSession)

    @pytest.fixture
    def extraction_repo(self, mock_session):
        """Create an extraction repository with mock session."""
        return ExtractionRepository(mock_session)

    def test_init(self, extraction_repo):
        """Test repository initialization."""
        # The repository must retain the session it was constructed with.
        assert extraction_repo.session is not None

    @pytest.mark.asyncio
    async def test_create_extraction(self, extraction_repo):
        """Test creating an extraction."""
        extraction_data = {
            "url": "https://www.youtube.com/watch?v=test",
            "user_id": 1,
            "service": "youtube",
            "service_id": "test123",
            "title": "Test Video",
            "status": "pending",
        }
        # Mock the session operations: add() is synchronous on AsyncSession,
        # while commit()/refresh() are awaited by the repository.
        mock_extraction = Extraction(**extraction_data, id=1)
        extraction_repo.session.add = Mock()
        extraction_repo.session.commit = AsyncMock()
        extraction_repo.session.refresh = AsyncMock()
        # Mock the Extraction constructor inside the repository module so the
        # created instance is a predictable object we can assert against.
        with pytest.MonkeyPatch().context() as m:
            m.setattr(
                "app.repositories.extraction.Extraction",
                lambda **kwargs: mock_extraction,
            )
            result = await extraction_repo.create(extraction_data)
            assert result == mock_extraction
            # The repository must persist via add -> commit -> refresh.
            extraction_repo.session.add.assert_called_once()
            extraction_repo.session.commit.assert_called_once()
            extraction_repo.session.refresh.assert_called_once_with(mock_extraction)

    @pytest.mark.asyncio
    async def test_get_by_service_and_id(self, extraction_repo):
        """Test getting extraction by service and service_id."""
        # session.exec returns a result object whose first() yields the row.
        mock_result = Mock()
        mock_result.first.return_value = Extraction(
            id=1,
            service="youtube",
            service_id="test123",
            url="https://www.youtube.com/watch?v=test",
            user_id=1,
            status="pending",
        )
        extraction_repo.session.exec = AsyncMock(return_value=mock_result)
        result = await extraction_repo.get_by_service_and_id("youtube", "test123")
        assert result is not None
        assert result.service == "youtube"
        assert result.service_id == "test123"
        # Exactly one query should have been issued.
        extraction_repo.session.exec.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_pending_extractions(self, extraction_repo):
        """Test getting pending extractions."""
        mock_extraction = Extraction(
            id=1,
            service="youtube",
            service_id="test123",
            url="https://www.youtube.com/watch?v=test",
            user_id=1,
            status="pending",
        )
        # session.exec returns a result object whose all() yields the rows.
        mock_result = Mock()
        mock_result.all.return_value = [mock_extraction]
        extraction_repo.session.exec = AsyncMock(return_value=mock_result)
        result = await extraction_repo.get_pending_extractions()
        assert len(result) == 1
        assert result[0].status == "pending"
        extraction_repo.session.exec.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_extraction(self, extraction_repo):
        """Test updating an extraction."""
        extraction = Extraction(
            id=1,
            service="youtube",
            service_id="test123",
            url="https://www.youtube.com/watch?v=test",
            user_id=1,
            status="pending",
        )
        update_data = {"status": "completed", "sound_id": 42}
        extraction_repo.session.commit = AsyncMock()
        extraction_repo.session.refresh = AsyncMock()
        result = await extraction_repo.update(extraction, update_data)
        # The passed-in instance is mutated in place with the update payload.
        assert result.status == "completed"
        assert result.sound_id == 42
        extraction_repo.session.commit.assert_called_once()
        extraction_repo.session.refresh.assert_called_once_with(extraction)

View File

@@ -0,0 +1,408 @@
"""Tests for extraction service."""
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction
from app.models.sound import Sound
from app.services.extraction import ExtractionService
class TestExtractionService:
    """Unit tests for ExtractionService.

    External collaborators — yt-dlp, the extraction/sound repositories, the
    normalizer service and the audio file helpers — are all mocked so the
    service logic is exercised in isolation.
    """

    @pytest.fixture
    def mock_session(self):
        """Create a mock session."""
        return Mock(spec=AsyncSession)

    @pytest.fixture
    def extraction_service(self, mock_session):
        """Create an extraction service with mock session."""
        # Patch mkdir so constructing the service never touches the filesystem.
        with patch("app.services.extraction.Path.mkdir"):
            return ExtractionService(mock_session)

    def test_init(self, extraction_service):
        """Test service initialization."""
        assert extraction_service.session is not None
        assert extraction_service.extraction_repo is not None
        assert extraction_service.sound_repo is not None

    def test_sanitize_filename(self, extraction_service):
        """Test filename sanitization."""
        # (input, expected) pairs covering invalid-character replacement,
        # whitespace trimming, truncation to 100 chars, and empty fallback.
        test_cases = [
            ("Hello World", "Hello World"),
            ("Test<>Video", "Test__Video"),
            ("Bad/File\\Name", "Bad_File_Name"),
            (" Spaces ", "Spaces"),
            (
                "Very long filename that exceeds the maximum length limit and should be truncated to 100 characters maximum",
                "Very long filename that exceeds the maximum length limit and should be truncated to 100 characters m",
            ),
            ("", "untitled"),
        ]
        for input_name, expected in test_cases:
            result = extraction_service._sanitize_filename(input_name)
            assert result == expected

    @patch("app.services.extraction.yt_dlp.YoutubeDL")
    def test_detect_service_info_youtube(self, mock_ydl_class, extraction_service):
        """Test service detection for YouTube."""
        # YoutubeDL is used as a context manager; mock the __enter__ result.
        mock_ydl = Mock()
        mock_ydl_class.return_value.__enter__.return_value = mock_ydl
        mock_ydl.extract_info.return_value = {
            "extractor": "youtube",
            "id": "test123",
            "title": "Test Video",
            "duration": 240,
            "uploader": "Test Channel",
        }
        result = extraction_service._detect_service_info(
            "https://www.youtube.com/watch?v=test123"
        )
        assert result is not None
        assert result["service"] == "youtube"
        assert result["service_id"] == "test123"
        assert result["title"] == "Test Video"
        assert result["duration"] == 240

    @patch("app.services.extraction.yt_dlp.YoutubeDL")
    def test_detect_service_info_failure(self, mock_ydl_class, extraction_service):
        """Test service detection failure."""
        mock_ydl = Mock()
        mock_ydl_class.return_value.__enter__.return_value = mock_ydl
        mock_ydl.extract_info.side_effect = Exception("Network error")
        # Detection errors are swallowed and reported as None, not raised.
        result = extraction_service._detect_service_info("https://invalid.url")
        assert result is None

    @pytest.mark.asyncio
    async def test_create_extraction_success(self, extraction_service):
        """Test successful extraction creation."""
        url = "https://www.youtube.com/watch?v=test123"
        user_id = 1
        # Mock service detection
        service_info = {
            "service": "youtube",
            "service_id": "test123",
            "title": "Test Video",
        }
        with patch.object(
            extraction_service, "_detect_service_info", return_value=service_info
        ):
            # Mock repository calls: no duplicate exists, create succeeds.
            extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
                return_value=None
            )
            mock_extraction = Extraction(
                id=1,
                url=url,
                user_id=user_id,
                service="youtube",
                service_id="test123",
                title="Test Video",
                status="pending",
            )
            extraction_service.extraction_repo.create = AsyncMock(
                return_value=mock_extraction
            )
            result = await extraction_service.create_extraction(url, user_id)
            # The service returns a serialized dict, not the model instance.
            assert result["id"] == 1
            assert result["service"] == "youtube"
            assert result["service_id"] == "test123"
            assert result["title"] == "Test Video"
            assert result["status"] == "pending"

    @pytest.mark.asyncio
    async def test_create_extraction_duplicate(self, extraction_service):
        """Test extraction creation with duplicate service/service_id."""
        url = "https://www.youtube.com/watch?v=test123"
        user_id = 1
        # Mock service detection
        service_info = {
            "service": "youtube",
            "service_id": "test123",
            "title": "Test Video",
        }
        existing_extraction = Extraction(
            id=1,
            url=url,
            user_id=2,  # Different user
            service="youtube",
            service_id="test123",
            status="completed",
        )
        with patch.object(
            extraction_service, "_detect_service_info", return_value=service_info
        ):
            extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
                return_value=existing_extraction
            )
            # Duplicates are rejected even when owned by a different user.
            with pytest.raises(ValueError, match="Extraction already exists"):
                await extraction_service.create_extraction(url, user_id)

    @pytest.mark.asyncio
    async def test_create_extraction_invalid_url(self, extraction_service):
        """Test extraction creation with invalid URL."""
        url = "https://invalid.url"
        user_id = 1
        # Detection returning None means the URL is unsupported.
        with patch.object(
            extraction_service, "_detect_service_info", return_value=None
        ):
            with pytest.raises(
                ValueError, match="Unable to detect service information"
            ):
                await extraction_service.create_extraction(url, user_id)

    def test_ensure_unique_filename(self, extraction_service):
        """Test unique filename generation."""
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            # Create original file
            original_file = temp_path / "test.mp3"
            original_file.touch()
            # A name collision should produce a "_1" suffixed variant.
            result = extraction_service._ensure_unique_filename(original_file)
            expected = temp_path / "test_1.mp3"
            assert result == expected
            # Create the first duplicate and test again: the counter advances.
            expected.touch()
            result = extraction_service._ensure_unique_filename(original_file)
            expected_2 = temp_path / "test_2.mp3"
            assert result == expected_2

    @pytest.mark.asyncio
    async def test_create_sound_record(self, extraction_service):
        """Test sound record creation."""
        # Create a temporary audio file; delete=False so it survives the
        # `with` block and can be cleaned up manually in the finally clause.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            audio_path = Path(f.name)
            f.write(b"fake audio data")
        try:
            extraction = Extraction(
                id=1,
                service="youtube",
                service_id="test123",
                title="Test Video",
                url="https://www.youtube.com/watch?v=test123",
                user_id=1,
                status="processing",
            )
            mock_sound = Sound(
                id=1,
                type="EXT",
                name="Test Video",
                filename=audio_path.name,
                duration=240000,
                size=1024,
                hash="test_hash",
                is_deletable=True,
                is_music=True,
                is_normalized=False,
                play_count=0,
            )
            # Stub out the file-metadata helpers so no real probing happens.
            with (
                patch(
                    "app.services.extraction.get_audio_duration", return_value=240000
                ),
                patch("app.services.extraction.get_file_size", return_value=1024),
                patch(
                    "app.services.extraction.get_file_hash", return_value="test_hash"
                ),
            ):
                extraction_service.sound_repo.create = AsyncMock(
                    return_value=mock_sound
                )
                result = await extraction_service._create_sound_record(
                    audio_path,
                    extraction.title,
                    extraction.service,
                    extraction.service_id,
                )
                # Extracted sounds are deletable music that starts un-normalized.
                assert result.type == "EXT"
                assert result.name == "Test Video"
                assert result.is_deletable is True
                assert result.is_music is True
                assert result.is_normalized is False
        finally:
            audio_path.unlink()

    @pytest.mark.asyncio
    async def test_normalize_sound_success(self, extraction_service):
        """Test sound normalization."""
        sound = Sound(
            id=1,
            type="EXT",
            name="Test Sound",
            filename="test.mp3",
            duration=240000,
            size=1024,
            hash="test_hash",
            is_normalized=False,
        )
        mock_normalizer = Mock()
        mock_normalizer.normalize_sound = AsyncMock(
            return_value={"status": "normalized"}
        )
        with patch(
            "app.services.extraction.SoundNormalizerService",
            return_value=mock_normalizer,
        ):
            # Should not raise exception
            await extraction_service._normalize_sound(sound)
            mock_normalizer.normalize_sound.assert_called_once_with(sound)

    @pytest.mark.asyncio
    async def test_normalize_sound_failure(self, extraction_service):
        """Test sound normalization failure."""
        sound = Sound(
            id=1,
            type="EXT",
            name="Test Sound",
            filename="test.mp3",
            duration=240000,
            size=1024,
            hash="test_hash",
            is_normalized=False,
        )
        mock_normalizer = Mock()
        mock_normalizer.normalize_sound = AsyncMock(
            return_value={"status": "error", "error": "Test error"}
        )
        with patch(
            "app.services.extraction.SoundNormalizerService",
            return_value=mock_normalizer,
        ):
            # Normalization failure is best-effort: logged, never raised.
            await extraction_service._normalize_sound(sound)
            mock_normalizer.normalize_sound.assert_called_once_with(sound)

    @pytest.mark.asyncio
    async def test_get_extraction_by_id(self, extraction_service):
        """Test getting extraction by ID."""
        extraction = Extraction(
            id=1,
            service="youtube",
            service_id="test123",
            url="https://www.youtube.com/watch?v=test123",
            user_id=1,
            title="Test Video",
            status="completed",
            sound_id=42,
        )
        extraction_service.extraction_repo.get_by_id = AsyncMock(
            return_value=extraction
        )
        result = await extraction_service.get_extraction_by_id(1)
        # The model is serialized to a plain dict for API consumption.
        assert result is not None
        assert result["id"] == 1
        assert result["service"] == "youtube"
        assert result["service_id"] == "test123"
        assert result["title"] == "Test Video"
        assert result["status"] == "completed"
        assert result["sound_id"] == 42

    @pytest.mark.asyncio
    async def test_get_extraction_by_id_not_found(self, extraction_service):
        """Test getting extraction by ID when not found."""
        extraction_service.extraction_repo.get_by_id = AsyncMock(return_value=None)
        result = await extraction_service.get_extraction_by_id(999)
        assert result is None

    @pytest.mark.asyncio
    async def test_get_user_extractions(self, extraction_service):
        """Test getting user extractions."""
        extractions = [
            Extraction(
                id=1,
                service="youtube",
                service_id="test123",
                url="https://www.youtube.com/watch?v=test123",
                user_id=1,
                title="Test Video 1",
                status="completed",
                sound_id=42,
            ),
            Extraction(
                id=2,
                service="youtube",
                service_id="test456",
                url="https://www.youtube.com/watch?v=test456",
                user_id=1,
                title="Test Video 2",
                status="pending",
            ),
        ]
        extraction_service.extraction_repo.get_by_user = AsyncMock(
            return_value=extractions
        )
        result = await extraction_service.get_user_extractions(1)
        # Order and content of the serialized list mirror the repository rows.
        assert len(result) == 2
        assert result[0]["id"] == 1
        assert result[0]["title"] == "Test Video 1"
        assert result[1]["id"] == 2
        assert result[1]["title"] == "Test Video 2"

    @pytest.mark.asyncio
    async def test_get_pending_extractions(self, extraction_service):
        """Test getting pending extractions."""
        pending_extractions = [
            Extraction(
                id=1,
                service="youtube",
                service_id="test123",
                url="https://www.youtube.com/watch?v=test123",
                user_id=1,
                title="Pending Video",
                status="pending",
            ),
        ]
        extraction_service.extraction_repo.get_pending_extractions = AsyncMock(
            return_value=pending_extractions
        )
        result = await extraction_service.get_pending_extractions()
        assert len(result) == 1
        assert result[0]["id"] == 1
        assert result[0]["status"] == "pending"

View File

@@ -0,0 +1,298 @@
"""Tests for extraction background processor."""
import asyncio
from unittest.mock import AsyncMock, Mock, patch
import pytest
from app.services.extraction_processor import ExtractionProcessor
class TestExtractionProcessor:
    """Unit tests for the ExtractionProcessor background worker.

    Each test gets a fresh processor instance so the module-level singleton
    is never mutated; the database session and extraction service are fully
    mocked, so no real extraction work is performed.
    """

    @pytest.fixture
    def processor(self):
        """Create an extraction processor instance.

        A dedicated instance is used to avoid affecting the global processor.
        """
        return ExtractionProcessor()

    def test_init(self, processor):
        """A new processor starts idle with all bookkeeping in place."""
        assert processor.max_concurrent > 0
        assert len(processor.running_extractions) == 0
        assert processor.processing_lock is not None
        assert processor.shutdown_event is not None
        assert processor.processor_task is None

    @pytest.mark.asyncio
    async def test_start_and_stop(self, processor):
        """Starting spawns the queue task; stopping finishes it."""
        # Mock the queue loop to avoid doing any actual processing.
        with patch.object(processor, "_process_queue", new_callable=AsyncMock):
            await processor.start()
            assert processor.processor_task is not None
            assert not processor.processor_task.done()
            await processor.stop()
            assert processor.processor_task.done()

    @pytest.mark.asyncio
    async def test_start_already_running(self, processor):
        """Calling start() twice must not replace the running task."""
        with patch.object(processor, "_process_queue", new_callable=AsyncMock):
            await processor.start()
            first_task = processor.processor_task
            # Second start is a no-op: the same task object remains.
            await processor.start()
            assert processor.processor_task is first_task
            # Clean up
            await processor.stop()

    @pytest.mark.asyncio
    async def test_queue_extraction(self, processor):
        """Queuing does not immediately mark the extraction as running."""
        extraction_id = 123
        await processor.queue_extraction(extraction_id)
        # It only enters running_extractions once the processor picks it up.
        assert extraction_id not in processor.running_extractions

    @pytest.mark.asyncio
    async def test_queue_extraction_already_running(self, processor):
        """Queuing an already-running extraction leaves it running."""
        extraction_id = 123
        processor.running_extractions.add(extraction_id)
        await processor.queue_extraction(extraction_id)
        assert extraction_id in processor.running_extractions

    def test_get_status(self, processor):
        """Status report exposes capacity and current load for an idle processor."""
        status = processor.get_status()
        assert "running" in status
        assert "max_concurrent" in status
        assert "currently_processing" in status
        assert "processing_ids" in status
        assert "available_slots" in status
        assert status["max_concurrent"] == processor.max_concurrent
        assert status["currently_processing"] == 0
        assert status["available_slots"] == processor.max_concurrent

    def test_get_status_with_running_extractions(self, processor):
        """Status reflects in-flight extractions and reduced free slots."""
        processor.running_extractions.add(123)
        processor.running_extractions.add(456)
        status = processor.get_status()
        assert status["currently_processing"] == 2
        assert status["available_slots"] == processor.max_concurrent - 2
        assert 123 in status["processing_ids"]
        assert 456 in status["processing_ids"]

    def test_on_extraction_completed(self, processor):
        """Completion callback releases the slot on success."""
        extraction_id = 123
        processor.running_extractions.add(extraction_id)
        # A completed task with no exception.
        mock_task = Mock()
        mock_task.exception.return_value = None
        processor._on_extraction_completed(extraction_id, mock_task)
        assert extraction_id not in processor.running_extractions

    def test_on_extraction_completed_with_exception(self, processor):
        """Completion callback releases the slot even when the task failed."""
        extraction_id = 123
        processor.running_extractions.add(extraction_id)
        # A completed task that raised.
        mock_task = Mock()
        mock_task.exception.return_value = Exception("Test error")
        processor._on_extraction_completed(extraction_id, mock_task)
        assert extraction_id not in processor.running_extractions

    @pytest.mark.asyncio
    async def test_process_single_extraction_success(self, processor):
        """A single extraction is delegated to the service once."""
        extraction_id = 123
        mock_service = Mock()
        mock_service.process_extraction = AsyncMock(
            return_value={"status": "completed", "id": extraction_id}
        )
        with (
            patch(
                "app.services.extraction_processor.AsyncSession"
            ) as mock_session_class,
            patch(
                "app.services.extraction_processor.ExtractionService",
                return_value=mock_service,
            ),
        ):
            # The processor opens its own session via `async with`.
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            await processor._process_single_extraction(extraction_id)
            mock_service.process_extraction.assert_called_once_with(extraction_id)

    @pytest.mark.asyncio
    async def test_process_single_extraction_failure(self, processor):
        """Service errors are logged, never propagated to the caller."""
        extraction_id = 123
        mock_service = Mock()
        mock_service.process_extraction = AsyncMock(side_effect=Exception("Test error"))
        with (
            patch(
                "app.services.extraction_processor.AsyncSession"
            ) as mock_session_class,
            patch(
                "app.services.extraction_processor.ExtractionService",
                return_value=mock_service,
            ),
        ):
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            # Must not raise: errors are handled internally.
            await processor._process_single_extraction(extraction_id)
            mock_service.process_extraction.assert_called_once_with(extraction_id)

    @pytest.mark.asyncio
    async def test_process_pending_extractions_no_slots(self, processor):
        """Nothing starts when the concurrency budget is exhausted."""
        # Fill all slots
        for i in range(processor.max_concurrent):
            processor.running_extractions.add(i)
        mock_service = Mock()
        mock_service.get_pending_extractions = AsyncMock(
            return_value=[{"id": 100, "status": "pending"}]
        )
        with (
            patch(
                "app.services.extraction_processor.AsyncSession"
            ) as mock_session_class,
            patch(
                "app.services.extraction_processor.ExtractionService",
                return_value=mock_service,
            ),
        ):
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            await processor._process_pending_extractions()
            # The pending extraction must not have been started.
            assert 100 not in processor.running_extractions

    @pytest.mark.asyncio
    async def test_process_pending_extractions_with_slots(self, processor):
        """All pending extractions start when slots are available."""
        mock_service = Mock()
        mock_service.get_pending_extractions = AsyncMock(
            return_value=[
                {"id": 100, "status": "pending"},
                {"id": 101, "status": "pending"},
            ]
        )
        # NOTE: the binding for the patched _process_single_extraction was
        # unused; the patch only needs to neutralize the coroutine.
        with (
            patch(
                "app.services.extraction_processor.AsyncSession"
            ) as mock_session_class,
            patch.object(processor, "_process_single_extraction", new_callable=AsyncMock),
            patch(
                "app.services.extraction_processor.ExtractionService",
                return_value=mock_service,
            ),
            patch("asyncio.create_task") as mock_create_task,
        ):
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            # Task creation is mocked so no real coroutines are scheduled.
            mock_task = Mock()
            mock_create_task.return_value = mock_task
            await processor._process_pending_extractions()
            # Both extractions claim a slot and get a task each.
            assert 100 in processor.running_extractions
            assert 101 in processor.running_extractions
            assert mock_create_task.call_count == 2

    @pytest.mark.asyncio
    async def test_process_pending_extractions_respect_limit(self, processor):
        """With max_concurrent=1, only one of three pending items may start."""
        processor.max_concurrent = 1
        mock_service = Mock()
        mock_service.get_pending_extractions = AsyncMock(
            return_value=[
                {"id": 100, "status": "pending"},
                {"id": 101, "status": "pending"},
                {"id": 102, "status": "pending"},
            ]
        )
        with (
            patch(
                "app.services.extraction_processor.AsyncSession"
            ) as mock_session_class,
            patch.object(processor, "_process_single_extraction", new_callable=AsyncMock),
            patch(
                "app.services.extraction_processor.ExtractionService",
                return_value=mock_service,
            ),
            patch("asyncio.create_task") as mock_create_task,
        ):
            mock_session = AsyncMock()
            mock_session_class.return_value.__aenter__.return_value = mock_session
            mock_task = Mock()
            mock_create_task.return_value = mock_task
            await processor._process_pending_extractions()
            # Only one extraction is started; the rest wait for a free slot.
            assert len(processor.running_extractions) == 1
            assert mock_create_task.call_count == 1

View File

@@ -1,7 +1,7 @@
 """Tests for OAuth service."""
 from typing import Any
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 import pytest
@@ -117,7 +117,7 @@ class TestGoogleOAuthProvider:
             "picture": "https://example.com/avatar.jpg",
         }
-        with patch("httpx.AsyncClient.get") as mock_get:
+        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
             mock_response = Mock()
             mock_response.status_code = 200
             mock_response.json.return_value = mock_response_data
@@ -162,7 +162,7 @@ class TestGitHubOAuthProvider:
             {"email": "secondary@example.com", "primary": False, "verified": True},
         ]
-        with patch("httpx.AsyncClient.get") as mock_get:
+        with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
             # Mock user profile response
             mock_user_response = Mock()
             mock_user_response.status_code = 200
@@ -174,7 +174,7 @@ class TestGitHubOAuthProvider:
             mock_emails_response.json.return_value = mock_emails_data
             # Return different responses based on URL
-        def side_effect(url, **kwargs):
+        async def side_effect(url, **kwargs):
             if "user/emails" in str(url):
                 return mock_emails_response
             return mock_user_response