feat: Implement background extraction processor with concurrency control

- Added `ExtractionProcessor` class to handle extraction queue processing in the background.
- Implemented methods for starting, stopping, and queuing extractions with concurrency limits.
- Integrated logging for monitoring the processor's status and actions.
- Created tests for the extraction processor to ensure functionality and error handling.

test: Add unit tests for extraction API endpoints

- Created tests for successful extraction creation, authentication checks, and processor status retrieval.
- Ensured proper responses for authenticated and unauthenticated requests.

test: Implement unit tests for extraction repository

- Added tests for creating, retrieving, and updating extractions in the repository.
- Mocked database interactions to validate repository behavior without actual database access.

test: Add comprehensive tests for extraction service

- Developed tests for extraction creation, service detection, and sound record creation.
- Included tests for handling duplicate extractions and invalid URLs.

test: Add unit tests for extraction background processor

- Created tests for the `ExtractionProcessor` class to validate its behavior under various conditions.
- Ensured proper handling of extraction queuing, processing, and completion callbacks.

fix: Update OAuth service tests to use AsyncMock

- Modified OAuth provider tests to use `AsyncMock` for mocking asynchronous HTTP requests.
This commit is contained in:
JSC
2025-07-29 01:06:29 +02:00
parent c993230f98
commit 9b5f83eef0
11 changed files with 1860 additions and 4 deletions

View File

@@ -8,6 +8,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import get_current_active_user_flexible
from app.models.user import User
from app.services.extraction import ExtractionInfo, ExtractionService
from app.services.extraction_processor import extraction_processor
from app.services.sound_normalizer import NormalizationResults, SoundNormalizerService
from app.services.sound_scanner import ScanResults, SoundScannerService
@@ -28,6 +30,13 @@ async def get_sound_normalizer_service(
return SoundNormalizerService(session)
async def get_extraction_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> ExtractionService:
    """Provide an ExtractionService bound to the request's database session."""
    service = ExtractionService(session)
    return service
# SCAN
@router.post("/scan")
async def scan_sounds(
@@ -233,3 +242,110 @@ async def normalize_sound_by_id(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to normalize sound: {e!s}",
) from e
# EXTRACT
@router.post("/extract")
async def create_extraction(
    url: str,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, ExtractionInfo | str]:
    """Create a new extraction job for a URL and queue it for processing.

    Returns:
        A message plus the newly created extraction info.

    Raises:
        HTTPException: 401 when the user record has no ID, 400 for invalid or
            duplicate URLs (ValueError from the service), 500 otherwise.
    """
    try:
        if current_user.id is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User ID not available",
            )
        extraction_info = await extraction_service.create_extraction(
            url, current_user.id
        )
        # Queue the extraction for background processing
        await extraction_processor.queue_extraction(extraction_info["id"])
        return {
            "message": "Extraction queued successfully",
            "extraction": extraction_info,
        }
    except HTTPException:
        # Re-raise as-is: without this, the 401 above would be caught by the
        # generic handler below and surfaced to the client as a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create extraction: {e!s}",
        ) from e
@router.get("/extract/status")
async def get_extraction_processor_status(
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
) -> dict:
    """Get the status of the extraction processor.

    Raises:
        HTTPException: 403 when the caller is not an administrator.
    """
    # Only allow admins to see processor status
    is_admin = current_user.role in ("admin", "superadmin")
    if not is_admin:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Only administrators can view processor status",
        )
    return extraction_processor.get_status()
@router.get("/extract/{extraction_id}")
async def get_extraction(
    extraction_id: int,
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> ExtractionInfo:
    """Get extraction information by ID.

    Raises:
        HTTPException: 404 when no extraction has this ID, 500 on any
            unexpected service failure.
    """
    try:
        info = await extraction_service.get_extraction_by_id(extraction_id)
        if info is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Extraction {extraction_id} not found",
            )
    except HTTPException:
        # Propagate the 404 unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get extraction: {e!s}",
        ) from e
    return info
@router.get("/extract")
async def get_user_extractions(
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    extraction_service: Annotated[ExtractionService, Depends(get_extraction_service)],
) -> dict[str, list[ExtractionInfo]]:
    """Get all extractions for the current user.

    Returns:
        ``{"extractions": [...]}`` for the authenticated user.

    Raises:
        HTTPException: 401 when the user record has no ID, 500 on any
            unexpected service failure.
    """
    try:
        if current_user.id is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User ID not available",
            )
        extractions = await extraction_service.get_user_extractions(current_user.id)
        return {
            "extractions": extractions,
        }
    except HTTPException:
        # Re-raise as-is: without this, the 401 above would be caught by the
        # generic handler below and surfaced to the client as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get extractions: {e!s}",
        ) from e

View File

@@ -52,5 +52,12 @@ class Settings(BaseSettings):
NORMALIZED_AUDIO_BITRATE: str = "256k"
NORMALIZED_AUDIO_PASSES: int = 2 # 1 for one-pass, 2 for two-pass
# Audio Extraction Configuration
EXTRACTION_AUDIO_FORMAT: str = "mp3"
EXTRACTION_AUDIO_BITRATE: str = "256k"
EXTRACTION_TEMP_DIR: str = "sounds/temp"
EXTRACTION_THUMBNAILS_DIR: str = "sounds/originals/extracted/thumbnails"
EXTRACTION_MAX_CONCURRENT: int = 2 # Maximum concurrent extractions
settings = Settings()

View File

@@ -9,6 +9,7 @@ from app.api import api_router
from app.core.database import init_db
from app.core.logging import get_logger, setup_logging
from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor
from app.services.socket import socket_manager
@@ -22,9 +23,17 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
await init_db()
logger.info("Database initialized")
# Start the extraction processor
await extraction_processor.start()
logger.info("Extraction processor started")
yield
logger.info("Shutting down application")
# Stop the extraction processor
await extraction_processor.stop()
logger.info("Extraction processor stopped")
def create_app():

View File

@@ -0,0 +1,82 @@
"""Extraction repository for database operations."""
from sqlalchemy import desc
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.extraction import Extraction
class ExtractionRepository:
    """Data-access layer for ``Extraction`` rows.

    Thin wrapper over an async SQLModel session: every mutator commits and
    refreshes so callers always receive up-to-date ORM objects.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Store the async session used for all queries."""
        self.session = session

    async def create(self, extraction_data: dict) -> Extraction:
        """Insert a new extraction row and return it fully refreshed."""
        record = Extraction(**extraction_data)
        self.session.add(record)
        await self.session.commit()
        await self.session.refresh(record)
        return record

    async def get_by_id(self, extraction_id: int) -> Extraction | None:
        """Fetch one extraction by primary key, or None when absent."""
        statement = select(Extraction).where(Extraction.id == extraction_id)
        rows = await self.session.exec(statement)
        return rows.first()

    async def get_by_service_and_id(
        self, service: str, service_id: str
    ) -> Extraction | None:
        """Fetch the extraction matching a (service, service_id) pair, if any."""
        statement = select(Extraction).where(
            Extraction.service == service, Extraction.service_id == service_id
        )
        rows = await self.session.exec(statement)
        return rows.first()

    async def get_by_user(self, user_id: int) -> list[Extraction]:
        """Return all of a user's extractions, newest first."""
        statement = (
            select(Extraction)
            .where(Extraction.user_id == user_id)
            .order_by(desc(Extraction.created_at))
        )
        rows = await self.session.exec(statement)
        return list(rows.all())

    async def get_pending_extractions(self) -> list[Extraction]:
        """Return extractions still waiting to be processed, oldest first."""
        statement = (
            select(Extraction)
            .where(Extraction.status == "pending")
            .order_by(Extraction.created_at)
        )
        rows = await self.session.exec(statement)
        return list(rows.all())

    async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
        """Apply *update_data* field-by-field onto *extraction* and persist it."""
        for field, value in update_data.items():
            setattr(extraction, field, value)
        await self.session.commit()
        await self.session.refresh(extraction)
        return extraction

    async def delete(self, extraction: Extraction) -> None:
        """Remove an extraction row permanently."""
        await self.session.delete(extraction)
        await self.session.commit()

    async def get_extractions_by_status(self, status: str) -> list[Extraction]:
        """Return all extractions with the given status, newest first."""
        statement = (
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(desc(Extraction.created_at))
        )
        rows = await self.session.exec(statement)
        return list(rows.all())

517
app/services/extraction.py Normal file
View File

@@ -0,0 +1,517 @@
"""Extraction service for audio extraction from external services using yt-dlp."""
import shutil
from pathlib import Path
from typing import TypedDict
import yt_dlp
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.models.extraction import Extraction
from app.models.sound import Sound
from app.repositories.extraction import ExtractionRepository
from app.repositories.sound import SoundRepository
from app.services.sound_normalizer import SoundNormalizerService
from app.utils.audio import get_audio_duration, get_file_hash, get_file_size
logger = get_logger(__name__)
class ExtractionInfo(TypedDict):
    """Type definition for extraction information.

    Serializable snapshot of an Extraction row as exposed to API callers.
    """

    # Primary key of the extraction row (0 is only a defensive fallback).
    id: int
    # Original URL the extraction was requested for.
    url: str
    # Source service name derived from the yt-dlp extractor (e.g. "youtube").
    service: str
    # Identifier of the media within the source service.
    service_id: str
    # Media title reported by yt-dlp, if any.
    title: str | None
    # Lifecycle state: observed values are "pending", "processing",
    # "completed", "failed".
    status: str
    # Error message recorded when processing failed, else None.
    error: str | None
    # ID of the Sound record created on success, else None.
    sound_id: int | None
class ExtractionService:
    """Service for extracting audio from external services using yt-dlp."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the extraction service.

        Args:
            session: Async database session shared with the repositories.
        """
        self.session = session
        self.extraction_repo = ExtractionRepository(session)
        self.sound_repo = SoundRepository(session)
        # Ensure required directories exist
        self._ensure_directories()
def _ensure_directories(self) -> None:
"""Ensure all required directories exist."""
directories = [
settings.EXTRACTION_TEMP_DIR,
"sounds/originals/extracted",
settings.EXTRACTION_THUMBNAILS_DIR,
]
for directory in directories:
Path(directory).mkdir(parents=True, exist_ok=True)
logger.debug("Ensured directory exists: %s", directory)
    async def create_extraction(self, url: str, user_id: int) -> ExtractionInfo:
        """Create a new extraction job.

        Resolves the URL to a (service, service_id) pair via yt-dlp, rejects
        duplicates, and stores a new "pending" extraction row.

        Args:
            url: Media URL to extract audio from.
            user_id: ID of the requesting user.

        Returns:
            Snapshot of the newly created extraction.

        Raises:
            ValueError: If the URL cannot be resolved to service information,
                or an extraction already exists for the same service/service_id.
        """
        logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)
        try:
            # First, detect service and service_id using yt-dlp
            service_info = self._detect_service_info(url)
            if not service_info:
                raise ValueError("Unable to detect service information from URL")
            service = service_info["service"]
            service_id = service_info["service_id"]
            title = service_info.get("title")
            logger.info(
                "Detected service: %s, service_id: %s, title: %s",
                service,
                service_id,
                title,
            )
            # Check if extraction already exists
            existing = await self.extraction_repo.get_by_service_and_id(
                service, service_id
            )
            if existing:
                error_msg = f"Extraction already exists for {service}:{service_id}"
                logger.warning(error_msg)
                raise ValueError(error_msg)
            # Create the extraction record
            extraction_data = {
                "url": url,
                "user_id": user_id,
                "service": service,
                "service_id": service_id,
                "title": title,
                "status": "pending",
            }
            extraction = await self.extraction_repo.create(extraction_data)
            logger.info("Created extraction with ID: %d", extraction.id)
            return {
                "id": extraction.id or 0,  # Should never be None for created extraction
                "url": extraction.url,
                "service": extraction.service,
                "service_id": extraction.service_id,
                "title": extraction.title,
                "status": extraction.status,
                "error": extraction.error,
                "sound_id": extraction.sound_id,
            }
        except Exception:
            logger.exception("Failed to create extraction for URL: %s", url)
            raise
    def _detect_service_info(self, url: str) -> dict | None:
        """Detect service information from URL using yt-dlp.

        Returns:
            A dict with service, service_id, title, duration, uploader and
            description, or None when yt-dlp cannot resolve the URL (all
            failures are logged and swallowed).
        """
        try:
            # Configure yt-dlp for info extraction only
            ydl_opts = {
                "quiet": True,
                "no_warnings": True,
                "extract_flat": False,
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                # Extract info without downloading
                info = ydl.extract_info(url, download=False)
                if not info:
                    return None
                # Map extractor names to our service names.
                # NOTE(review): every entry currently maps a name to itself
                # and the .get() default below also returns the raw extractor,
                # so this table is effectively a no-op placeholder — confirm
                # whether real renames are planned or it can be dropped.
                extractor_map = {
                    "youtube": "youtube",
                    "dailymotion": "dailymotion",
                    "vimeo": "vimeo",
                    "soundcloud": "soundcloud",
                    "twitter": "twitter",
                    "tiktok": "tiktok",
                    "instagram": "instagram",
                }
                extractor = info.get("extractor", "").lower()
                service = extractor_map.get(extractor, extractor)
                return {
                    "service": service,
                    "service_id": str(info.get("id", "")),
                    "title": info.get("title"),
                    "duration": info.get("duration"),
                    "uploader": info.get("uploader"),
                    "description": info.get("description"),
                }
        except Exception:
            logger.exception("Failed to detect service info for URL: %s", url)
            return None
    async def process_extraction(self, extraction_id: int) -> ExtractionInfo:
        """Process an extraction job end to end.

        Downloads the audio, moves the files into permanent storage, creates
        and normalizes a Sound record, then marks the extraction "completed".
        Processing failures are recorded on the row and returned as a "failed"
        snapshot rather than raised.

        Args:
            extraction_id: ID of a "pending" extraction.

        Returns:
            The final extraction snapshot ("completed" or "failed").

        Raises:
            ValueError: If the extraction does not exist or is not pending.
        """
        extraction = await self.extraction_repo.get_by_id(extraction_id)
        if not extraction:
            raise ValueError(f"Extraction {extraction_id} not found")
        if extraction.status != "pending":
            raise ValueError(f"Extraction {extraction_id} is not pending")
        # Store all needed values early to avoid session detachment issues
        user_id = extraction.user_id
        extraction_url = extraction.url
        extraction_title = extraction.title
        extraction_service = extraction.service
        extraction_service_id = extraction.service_id
        logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
        try:
            # Update status to processing
            await self.extraction_repo.update(extraction, {"status": "processing"})
            # Extract audio and thumbnail
            audio_file, thumbnail_file = await self._extract_media(
                extraction_id, extraction_url
            )
            # Move files to final locations
            final_audio_path, final_thumbnail_path = (
                await self._move_files_to_final_location(
                    audio_file,
                    thumbnail_file,
                    extraction_title,
                    extraction_service,
                    extraction_service_id,
                )
            )
            # Create Sound record
            sound = await self._create_sound_record(
                final_audio_path,
                extraction_title,
                extraction_service,
                extraction_service_id,
            )
            # Store sound_id early to avoid session detachment issues
            sound_id = sound.id
            # Normalize the sound (best-effort; failures are logged, not raised)
            await self._normalize_sound(sound)
            # Add to main playlist (best-effort placeholder)
            await self._add_to_main_playlist(sound, user_id)
            # Update extraction with success
            await self.extraction_repo.update(
                extraction,
                {
                    "status": "completed",
                    "sound_id": sound_id,
                    "error": None,
                },
            )
            logger.info("Successfully processed extraction %d", extraction_id)
            return {
                "id": extraction_id,
                "url": extraction_url,
                "service": extraction_service,
                "service_id": extraction_service_id,
                "title": extraction_title,
                "status": "completed",
                "error": None,
                "sound_id": sound_id,
            }
        except Exception as e:
            error_msg = str(e)
            logger.exception(
                "Failed to process extraction %d: %s", extraction_id, error_msg
            )
            # Update extraction with error; returning (not re-raising) keeps
            # the background processor loop alive for the next job.
            await self.extraction_repo.update(
                extraction,
                {
                    "status": "failed",
                    "error": error_msg,
                },
            )
            return {
                "id": extraction_id,
                "url": extraction_url,
                "service": extraction_service,
                "service_id": extraction_service_id,
                "title": extraction_title,
                "status": "failed",
                "error": error_msg,
                "sound_id": None,
            }
    async def _extract_media(
        self, extraction_id: int, extraction_url: str
    ) -> tuple[Path, Path | None]:
        """Extract audio and thumbnail using yt-dlp.

        Downloads into the temp directory using an `extraction_<id>_` filename
        prefix, then locates the produced files by globbing that prefix.

        Returns:
            (audio_file, thumbnail_file) paths in the temp directory;
            thumbnail_file is None when no thumbnail was written.

        Raises:
            RuntimeError: If yt-dlp fails or no audio file was produced.
        """
        temp_dir = Path(settings.EXTRACTION_TEMP_DIR)
        # Create unique filename based on extraction ID
        output_template = str(
            temp_dir / f"extraction_{extraction_id}_%(title)s.%(ext)s"
        )
        # Configure yt-dlp options
        ydl_opts = {
            "format": "bestaudio/best",
            "outtmpl": output_template,
            "extractaudio": True,
            "audioformat": settings.EXTRACTION_AUDIO_FORMAT,
            "audioquality": settings.EXTRACTION_AUDIO_BITRATE,
            "writethumbnail": True,
            "writeinfojson": False,
            "writeautomaticsub": False,
            "writesubtitles": False,
            "postprocessors": [
                {
                    # FFmpeg converts the download to the configured format;
                    # preferredquality appears to expect a bare number, hence
                    # the rstrip("k") — TODO confirm against yt-dlp docs.
                    "key": "FFmpegExtractAudio",
                    "preferredcodec": settings.EXTRACTION_AUDIO_FORMAT,
                    "preferredquality": settings.EXTRACTION_AUDIO_BITRATE.rstrip("k"),
                },
            ],
        }
        try:
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                # Download and extract
                ydl.download([extraction_url])
                # Find the extracted files by their extraction-ID prefix
                audio_files = list(
                    temp_dir.glob(
                        f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}"
                    )
                )
                thumbnail_files = (
                    list(temp_dir.glob(f"extraction_{extraction_id}_*.webp"))
                    + list(temp_dir.glob(f"extraction_{extraction_id}_*.jpg"))
                    + list(temp_dir.glob(f"extraction_{extraction_id}_*.png"))
                )
                if not audio_files:
                    raise RuntimeError("No audio file was created during extraction")
                audio_file = audio_files[0]
                thumbnail_file = thumbnail_files[0] if thumbnail_files else None
                logger.info(
                    "Extracted audio: %s, thumbnail: %s",
                    audio_file,
                    thumbnail_file or "None",
                )
                return audio_file, thumbnail_file
        except Exception as e:
            # Note: the RuntimeError above is also funneled through here and
            # re-wrapped, which preserves a uniform error message for callers.
            logger.exception("yt-dlp extraction failed for %s", extraction_url)
            raise RuntimeError(f"Audio extraction failed: {e}") from e
async def _move_files_to_final_location(
self,
audio_file: Path,
thumbnail_file: Path | None,
title: str | None,
service: str,
service_id: str,
) -> tuple[Path, Path | None]:
"""Move extracted files to their final locations."""
# Generate clean filename based on title and service
safe_title = self._sanitize_filename(title or f"{service}_{service_id}")
# Move audio file
final_audio_path = (
Path("sounds/originals/extracted")
/ f"{safe_title}.{settings.EXTRACTION_AUDIO_FORMAT}"
)
final_audio_path = self._ensure_unique_filename(final_audio_path)
shutil.move(str(audio_file), str(final_audio_path))
logger.info("Moved audio file to: %s", final_audio_path)
# Move thumbnail file if it exists
final_thumbnail_path = None
if thumbnail_file:
thumbnail_ext = thumbnail_file.suffix
final_thumbnail_path = (
Path(settings.EXTRACTION_THUMBNAILS_DIR)
/ f"{safe_title}{thumbnail_ext}"
)
final_thumbnail_path = self._ensure_unique_filename(final_thumbnail_path)
shutil.move(str(thumbnail_file), str(final_thumbnail_path))
logger.info("Moved thumbnail file to: %s", final_thumbnail_path)
return final_audio_path, final_thumbnail_path
def _sanitize_filename(self, filename: str) -> str:
"""Sanitize filename for filesystem."""
# Remove or replace problematic characters
invalid_chars = '<>:"/\\|?*'
for char in invalid_chars:
filename = filename.replace(char, "_")
# Limit length and remove leading/trailing spaces
filename = filename.strip()[:100]
return filename or "untitled"
def _ensure_unique_filename(self, filepath: Path) -> Path:
"""Ensure filename is unique by adding counter if needed."""
if not filepath.exists():
return filepath
stem = filepath.stem
suffix = filepath.suffix
parent = filepath.parent
counter = 1
while True:
new_path = parent / f"{stem}_{counter}{suffix}"
if not new_path.exists():
return new_path
counter += 1
async def _create_sound_record(
self, audio_path: Path, title: str | None, service: str, service_id: str
) -> Sound:
"""Create a Sound record for the extracted audio."""
# Get audio metadata
duration = get_audio_duration(audio_path)
size = get_file_size(audio_path)
file_hash = get_file_hash(audio_path)
# Create sound data
sound_data = {
"type": "EXT",
"name": title or f"{service}_{service_id}",
"filename": audio_path.name,
"duration": duration,
"size": size,
"hash": file_hash,
"is_deletable": True, # Extracted sounds can be deleted
"is_music": True, # Assume extracted content is music
"is_normalized": False,
"play_count": 0,
}
sound = await self.sound_repo.create(sound_data)
logger.info("Created sound record with ID: %d", sound.id)
return sound
async def _normalize_sound(self, sound: Sound) -> None:
"""Normalize the extracted sound."""
try:
normalizer_service = SoundNormalizerService(self.session)
result = await normalizer_service.normalize_sound(sound)
if result["status"] == "error":
logger.warning(
"Failed to normalize sound %d: %s",
sound.id,
result.get("error"),
)
else:
logger.info("Successfully normalized sound %d", sound.id)
except Exception as e:
logger.exception("Error normalizing sound %d: %s", sound.id, e)
# Don't fail the extraction if normalization fails
    async def _add_to_main_playlist(self, sound: Sound, user_id: int) -> None:
        """Add the sound to the user's main playlist.

        Currently a stub: it only logs the intent. Errors are swallowed so
        playlist bookkeeping can never fail the extraction.
        """
        try:
            # This is a placeholder - implement based on your playlist logic
            # For now, we'll just log that we would add it to the main playlist
            logger.info(
                "Would add sound %d to main playlist for user %d",
                sound.id,
                user_id,
            )
        except Exception as e:
            logger.exception(
                "Error adding sound %d to main playlist for user %d: %s",
                sound.id,
                user_id,
                e,
            )
            # Don't fail the extraction if playlist addition fails
async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
"""Get extraction information by ID."""
extraction = await self.extraction_repo.get_by_id(extraction_id)
if not extraction:
return None
return {
"id": extraction.id or 0, # Should never be None for existing extraction
"url": extraction.url,
"service": extraction.service,
"service_id": extraction.service_id,
"title": extraction.title,
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
}
async def get_user_extractions(self, user_id: int) -> list[ExtractionInfo]:
"""Get all extractions for a user."""
extractions = await self.extraction_repo.get_by_user(user_id)
return [
{
"id": extraction.id
or 0, # Should never be None for existing extraction
"url": extraction.url,
"service": extraction.service,
"service_id": extraction.service_id,
"title": extraction.title,
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
}
for extraction in extractions
]
async def get_pending_extractions(self) -> list[ExtractionInfo]:
"""Get all pending extractions."""
extractions = await self.extraction_repo.get_pending_extractions()
return [
{
"id": extraction.id
or 0, # Should never be None for existing extraction
"url": extraction.url,
"service": extraction.service,
"service_id": extraction.service_id,
"title": extraction.title,
"status": extraction.status,
"error": extraction.error,
"sound_id": extraction.sound_id,
}
for extraction in extractions
]

View File

@@ -0,0 +1,196 @@
"""Background extraction processor for handling extraction queue."""
import asyncio
from typing import Set
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.database import engine
from app.core.logging import get_logger
from app.services.extraction import ExtractionService
logger = get_logger(__name__)
class ExtractionProcessor:
    """Background processor for handling extraction queue with concurrency control."""

    def __init__(self) -> None:
        """Initialize the extraction processor.

        Concurrency is capped by ``settings.EXTRACTION_MAX_CONCURRENT``; the
        processor discovers pending work by polling the database rather than
        consuming an in-memory queue.
        """
        self.max_concurrent = settings.EXTRACTION_MAX_CONCURRENT
        # IDs currently in flight; mutations happen under processing_lock.
        # Builtin generic `set[int]` matches the modern `X | None` syntax the
        # file already uses (typing.Set is deprecated since 3.9).
        self.running_extractions: set[int] = set()
        self.processing_lock = asyncio.Lock()
        self.shutdown_event = asyncio.Event()
        self.processor_task: asyncio.Task | None = None
        logger.info(
            "Initialized extraction processor with max concurrent: %d",
            self.max_concurrent,
        )
async def start(self) -> None:
"""Start the background extraction processor."""
if self.processor_task and not self.processor_task.done():
logger.warning("Extraction processor is already running")
return
self.shutdown_event.clear()
self.processor_task = asyncio.create_task(self._process_queue())
logger.info("Started extraction processor")
    async def stop(self) -> None:
        """Stop the background extraction processor.

        Signals shutdown, waits up to 30 seconds for the loop to exit on its
        own, then cancels the task outright if it did not stop in time.
        """
        logger.info("Stopping extraction processor...")
        self.shutdown_event.set()
        if self.processor_task and not self.processor_task.done():
            try:
                await asyncio.wait_for(self.processor_task, timeout=30.0)
            except asyncio.TimeoutError:
                logger.warning(
                    "Extraction processor did not stop gracefully, cancelling..."
                )
                self.processor_task.cancel()
                try:
                    await self.processor_task
                except asyncio.CancelledError:
                    # Expected after cancel(); swallow so shutdown completes.
                    pass
        logger.info("Extraction processor stopped")
    async def queue_extraction(self, extraction_id: int) -> None:
        """Queue an extraction for processing.

        There is no in-memory queue: pending work is discovered by polling
        the database in _process_queue, so this method only logs intent and
        flags IDs that are already in flight.
        """
        async with self.processing_lock:
            if extraction_id not in self.running_extractions:
                logger.info("Queued extraction %d for processing", extraction_id)
                # The processor will pick it up on the next cycle
            else:
                logger.warning(
                    "Extraction %d is already being processed", extraction_id
                )
    async def _process_queue(self) -> None:
        """Main processing loop that handles the extraction queue.

        Polls for pending extractions every 5 seconds until shutdown_event is
        set. Waiting on the event (with a timeout) instead of sleeping lets
        shutdown interrupt the pause immediately; errors back off 10 seconds.
        """
        logger.info("Starting extraction queue processor")
        while not self.shutdown_event.is_set():
            try:
                await self._process_pending_extractions()
                # Wait before checking for new extractions
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=5.0)
                    break  # Shutdown requested
                except asyncio.TimeoutError:
                    continue  # Continue processing
            except Exception as e:
                logger.exception("Error in extraction queue processor: %s", e)
                # Wait a bit before retrying to avoid tight error loops
                try:
                    await asyncio.wait_for(self.shutdown_event.wait(), timeout=10.0)
                    break  # Shutdown requested
                except asyncio.TimeoutError:
                    continue
        logger.info("Extraction queue processor stopped")
    async def _process_pending_extractions(self) -> None:
        """Process pending extractions up to the concurrency limit.

        Holds processing_lock for the whole scan so running_extractions stays
        consistent while new tasks are registered; each started task removes
        itself from the set via its done-callback.
        """
        async with self.processing_lock:
            # Check how many slots are available
            available_slots = self.max_concurrent - len(self.running_extractions)
            if available_slots <= 0:
                return  # No available slots
            # Get pending extractions from database (fresh short-lived session)
            async with AsyncSession(engine) as session:
                extraction_service = ExtractionService(session)
                pending_extractions = await extraction_service.get_pending_extractions()
            # Filter out extractions that are already being processed
            available_extractions = [
                ext
                for ext in pending_extractions
                if ext["id"] not in self.running_extractions
            ]
            # Start processing up to available slots
            extractions_to_start = available_extractions[:available_slots]
            for extraction_info in extractions_to_start:
                extraction_id = extraction_info["id"]
                self.running_extractions.add(extraction_id)
                # Start processing this extraction in the background
                task = asyncio.create_task(
                    self._process_single_extraction(extraction_id)
                )
                # eid=extraction_id binds the current id as a default,
                # avoiding the late-binding-closure pitfall in this loop.
                task.add_done_callback(
                    lambda t, eid=extraction_id: self._on_extraction_completed(
                        eid,
                        t,
                    )
                )
                logger.info(
                    "Started processing extraction %d (%d/%d slots used)",
                    extraction_id,
                    len(self.running_extractions),
                    self.max_concurrent,
                )
async def _process_single_extraction(self, extraction_id: int) -> None:
"""Process a single extraction."""
try:
logger.info("Processing extraction %d", extraction_id)
async with AsyncSession(engine) as session:
extraction_service = ExtractionService(session)
result = await extraction_service.process_extraction(extraction_id)
logger.info(
"Completed extraction %d with status: %s",
extraction_id,
result["status"],
)
except Exception as e:
logger.exception("Error processing extraction %d: %s", extraction_id, e)
def _on_extraction_completed(self, extraction_id: int, task: asyncio.Task) -> None:
"""Callback when an extraction task is completed."""
# Remove from running set
self.running_extractions.discard(extraction_id)
# Check if the task had an exception
if task.exception():
logger.error(
"Extraction %d completed with exception: %s",
extraction_id,
task.exception(),
)
else:
logger.info(
"Extraction %d completed successfully (%d/%d slots used)",
extraction_id,
len(self.running_extractions),
self.max_concurrent,
)
def get_status(self) -> dict:
"""Get the current status of the extraction processor."""
return {
"running": self.processor_task is not None
and not self.processor_task.done(),
"max_concurrent": self.max_concurrent,
"currently_processing": len(self.running_extractions),
"processing_ids": list(self.running_extractions),
"available_slots": self.max_concurrent - len(self.running_extractions),
}
# Global extraction processor instance
# Module-level singleton shared by the API layer and the app lifespan hooks.
extraction_processor = ExtractionProcessor()