feat: Update Extraction model and service to support deferred service detection

Committed by JSC on 2025-07-29 10:50:50 +02:00
parent 9b5f83eef0, commit e3fcab99ae
4 changed files with 227 additions and 176 deletions
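The diff below moves yt-dlp service detection out of create_extraction (which now only stores the URL, with service, service_id and title left as None) and into process_extraction, where the blocking yt-dlp calls are wrapped in asyncio.to_thread. A minimal, self-contained sketch of that deferral pattern follows; the standalone detect_service_info helper and the example URL are illustrative assumptions, not the project's actual module layout.

import asyncio

import yt_dlp


async def detect_service_info(url: str) -> dict[str, str | None] | None:
    """Detect service, service_id and title for a URL without blocking the event loop."""
    ydl_opts = {"quiet": True, "no_warnings": True, "extract_flat": False}

    def _extract_info() -> dict | None:
        # yt-dlp is synchronous; fetch metadata only, no download
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            return ydl.extract_info(url, download=False)

    try:
        # Run the blocking call in a worker thread (Python 3.9+)
        info = await asyncio.to_thread(_extract_info)
    except Exception:
        return None
    if not info:
        return None
    return {
        "service": info.get("extractor", ""),
        "service_id": str(info.get("id", "")),
        "title": info.get("title"),
    }


async def main() -> None:
    # Illustrative usage; any yt-dlp-supported URL works here
    print(await detect_service_info("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))


if __name__ == "__main__":
    asyncio.run(main())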

File 1 of 4: Extraction database model

@@ -12,8 +12,8 @@ if TYPE_CHECKING:
 class Extraction(BaseModel, table=True):
     """Database model for a stream."""

-    service: str = Field(nullable=False)
-    service_id: str = Field(nullable=False)
+    service: str | None = Field(default=None)
+    service_id: str | None = Field(default=None)
     user_id: int = Field(foreign_key="user.id", nullable=False)
     sound_id: int | None = Field(foreign_key="sound.id", default=None)
     url: str = Field(nullable=False)
@@ -25,14 +25,8 @@ class Extraction(BaseModel, table=True):
     status: str = Field(nullable=False, default="pending")
     error: str | None = Field(default=None)

-    # constraints
-    __table_args__ = (
-        UniqueConstraint(
-            "service",
-            "service_id",
-            name="uq_extraction_service_service_id",
-        ),
-    )
+    # constraints - only enforce uniqueness when both service and service_id are not null
+    __table_args__ = ()

     # relationships
     sound: "Sound" = Relationship(back_populates="extractions")

File 2 of 4: Extraction service

@@ -1,5 +1,6 @@
 """Extraction service for audio extraction from external services using yt-dlp."""

+import asyncio
 import shutil
 from pathlib import Path
 from typing import TypedDict
@@ -24,8 +25,8 @@ class ExtractionInfo(TypedDict):
     id: int
     url: str
-    service: str
-    service_id: str
+    service: str | None
+    service_id: str | None
     title: str | None
     status: str
     error: str | None
@@ -61,39 +62,13 @@ class ExtractionService:
         logger.info("Creating extraction for URL: %s (user: %d)", url, user_id)

         try:
-            # First, detect service and service_id using yt-dlp
-            service_info = self._detect_service_info(url)
-            if not service_info:
-                raise ValueError("Unable to detect service information from URL")
-
-            service = service_info["service"]
-            service_id = service_info["service_id"]
-            title = service_info.get("title")
-
-            logger.info(
-                "Detected service: %s, service_id: %s, title: %s",
-                service,
-                service_id,
-                title,
-            )
-
-            # Check if extraction already exists
-            existing = await self.extraction_repo.get_by_service_and_id(
-                service, service_id
-            )
-            if existing:
-                error_msg = f"Extraction already exists for {service}:{service_id}"
-                logger.warning(error_msg)
-                raise ValueError(error_msg)
-
-            # Create the extraction record
+            # Create the extraction record without service detection for fast response
             extraction_data = {
                 "url": url,
                 "user_id": user_id,
-                "service": service,
-                "service_id": service_id,
-                "title": title,
+                "service": None,  # Will be detected during processing
+                "service_id": None,  # Will be detected during processing
+                "title": None,  # Will be detected during processing
                 "status": "pending",
             }
@@ -115,7 +90,7 @@ class ExtractionService:
             logger.exception("Failed to create extraction for URL: %s", url)
             raise

-    def _detect_service_info(self, url: str) -> dict | None:
+    async def _detect_service_info(self, url: str) -> dict[str, str | None] | None:
         """Detect service information from URL using yt-dlp."""
         try:
             # Configure yt-dlp for info extraction only
@@ -125,34 +100,21 @@ class ExtractionService:
                 "extract_flat": False,
             }

-            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
-                # Extract info without downloading
-                info = ydl.extract_info(url, download=False)
+            def _extract_info() -> dict | None:
+                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+                    # Extract info without downloading
+                    return ydl.extract_info(url, download=False)
+
+            # Run the blocking operation in a thread pool
+            info = await asyncio.to_thread(_extract_info)

             if not info:
                 return None

-            # Map extractor names to our service names
-            extractor_map = {
-                "youtube": "youtube",
-                "dailymotion": "dailymotion",
-                "vimeo": "vimeo",
-                "soundcloud": "soundcloud",
-                "twitter": "twitter",
-                "tiktok": "tiktok",
-                "instagram": "instagram",
-            }
-
-            extractor = info.get("extractor", "").lower()
-            service = extractor_map.get(extractor, extractor)
-
             return {
-                "service": service,
+                "service": info.get("extractor", ""),
                 "service_id": str(info.get("id", "")),
                 "title": info.get("title"),
-                "duration": info.get("duration"),
-                "uploader": info.get("uploader"),
-                "description": info.get("description"),
             }

         except Exception:
@@ -171,9 +133,9 @@ class ExtractionService:
         # Store all needed values early to avoid session detachment issues
         user_id = extraction.user_id
         extraction_url = extraction.url
-        extraction_title = extraction.title
         extraction_service = extraction.service
         extraction_service_id = extraction.service_id
+        extraction_title = extraction.title

         logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
@@ -181,21 +143,52 @@ class ExtractionService:
             # Update status to processing
             await self.extraction_repo.update(extraction, {"status": "processing"})

+            # Detect service info if not already available
+            if not extraction_service or not extraction_service_id:
+                logger.info("Detecting service info for extraction %d", extraction_id)
+                service_info = await self._detect_service_info(extraction_url)
+                if not service_info:
+                    raise ValueError("Unable to detect service information from URL")
+
+                # Check if extraction already exists for this service
+                existing = await self.extraction_repo.get_by_service_and_id(
+                    service_info["service"], service_info["service_id"]
+                )
+                if existing and existing.id != extraction_id:
+                    error_msg = f"Extraction already exists for {service_info['service']}:{service_info['service_id']}"
+                    logger.warning(error_msg)
+                    raise ValueError(error_msg)
+
+                # Update extraction with service info
+                update_data = {
+                    "service": service_info["service"],
+                    "service_id": service_info["service_id"],
+                    "title": service_info.get("title") or extraction_title,
+                }
+                await self.extraction_repo.update(extraction, update_data)
+
+                # Update values for processing
+                extraction_service = service_info["service"]
+                extraction_service_id = service_info["service_id"]
+                extraction_title = service_info.get("title") or extraction_title
+
             # Extract audio and thumbnail
             audio_file, thumbnail_file = await self._extract_media(
                 extraction_id, extraction_url
             )

             # Move files to final locations
-            final_audio_path, final_thumbnail_path = (
-                await self._move_files_to_final_location(
+            (
+                final_audio_path,
+                final_thumbnail_path,
+            ) = await self._move_files_to_final_location(
                 audio_file,
                 thumbnail_file,
                 extraction_title,
                 extraction_service,
                 extraction_service_id,
             )
-            )

             # Create Sound record
             sound = await self._create_sound_record(
@@ -294,11 +287,15 @@ class ExtractionService:
             ],
         }

-        try:
+        def _download_media() -> None:
             with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                 # Download and extract
                 ydl.download([extraction_url])

+        try:
+            # Run the blocking download operation in a thread pool
+            await asyncio.to_thread(_download_media)
+
             # Find the extracted files
             audio_files = list(
                 temp_dir.glob(
@@ -334,12 +331,14 @@ class ExtractionService:
         audio_file: Path,
         thumbnail_file: Path | None,
         title: str | None,
-        service: str,
-        service_id: str,
+        service: str | None,
+        service_id: str | None,
     ) -> tuple[Path, Path | None]:
         """Move extracted files to their final locations."""
         # Generate clean filename based on title and service
-        safe_title = self._sanitize_filename(title or f"{service}_{service_id}")
+        safe_title = self._sanitize_filename(
+            title or f"{service or 'unknown'}_{service_id or 'unknown'}"
+        )

         # Move audio file
         final_audio_path = (
@@ -395,7 +394,11 @@ class ExtractionService:
             counter += 1

     async def _create_sound_record(
-        self, audio_path: Path, title: str | None, service: str, service_id: str
+        self,
+        audio_path: Path,
+        title: str | None,
+        service: str | None,
+        service_id: str | None,
     ) -> Sound:
         """Create a Sound record for the extracted audio."""
         # Get audio metadata
@@ -406,7 +409,7 @@ class ExtractionService:
         # Create sound data
         sound_data = {
             "type": "EXT",
-            "name": title or f"{service}_{service_id}",
+            "name": title or f"{service or 'unknown'}_{service_id or 'unknown'}",
             "filename": audio_path.name,
             "duration": duration,
             "size": size,

File 3 of 4: Background extraction processor

@@ -1,7 +1,6 @@
 """Background extraction processor for handling extraction queue."""

 import asyncio
-from typing import Set

 from sqlmodel.ext.asyncio.session import AsyncSession
@@ -19,7 +18,7 @@ class ExtractionProcessor:
     def __init__(self) -> None:
         """Initialize the extraction processor."""
         self.max_concurrent = settings.EXTRACTION_MAX_CONCURRENT
-        self.running_extractions: Set[int] = set()
+        self.running_extractions: set[int] = set()
         self.processing_lock = asyncio.Lock()
         self.shutdown_event = asyncio.Event()
         self.processor_task: asyncio.Task | None = None
@@ -71,7 +70,7 @@ class ExtractionProcessor:
         )

     async def _process_queue(self) -> None:
-        """Main processing loop that handles the extraction queue."""
+        """Process the extraction queue in the main processing loop."""
         logger.info("Starting extraction queue processor")

         while not self.shutdown_event.is_set():
@@ -161,7 +160,7 @@ class ExtractionProcessor:
             logger.exception("Error processing extraction %d: %s", extraction_id, e)

     def _on_extraction_completed(self, extraction_id: int, task: asyncio.Task) -> None:
-        """Callback when an extraction task is completed."""
+        """Handle completion of an extraction task."""
         # Remove from running set
         self.running_extractions.discard(extraction_id)

File 4 of 4: Extraction service tests

@@ -51,7 +51,8 @@ class TestExtractionService:
         assert result == expected

     @patch("app.services.extraction.yt_dlp.YoutubeDL")
-    def test_detect_service_info_youtube(self, mock_ydl_class, extraction_service):
+    @pytest.mark.asyncio
+    async def test_detect_service_info_youtube(self, mock_ydl_class, extraction_service):
         """Test service detection for YouTube."""
         mock_ydl = Mock()
         mock_ydl_class.return_value.__enter__.return_value = mock_ydl
@@ -63,7 +64,7 @@ class TestExtractionService:
             "uploader": "Test Channel",
         }

-        result = extraction_service._detect_service_info(
+        result = await extraction_service._detect_service_info(
             "https://www.youtube.com/watch?v=test123"
         )
@@ -71,16 +72,16 @@ class TestExtractionService:
         assert result["service"] == "youtube"
         assert result["service_id"] == "test123"
         assert result["title"] == "Test Video"
-        assert result["duration"] == 240

     @patch("app.services.extraction.yt_dlp.YoutubeDL")
-    def test_detect_service_info_failure(self, mock_ydl_class, extraction_service):
+    @pytest.mark.asyncio
+    async def test_detect_service_info_failure(self, mock_ydl_class, extraction_service):
         """Test service detection failure."""
         mock_ydl = Mock()
         mock_ydl_class.return_value.__enter__.return_value = mock_ydl
         mock_ydl.extract_info.side_effect = Exception("Network error")

-        result = extraction_service._detect_service_info("https://invalid.url")
+        result = await extraction_service._detect_service_info("https://invalid.url")

         assert result is None
@@ -90,27 +91,14 @@ class TestExtractionService:
         url = "https://www.youtube.com/watch?v=test123"
         user_id = 1

-        # Mock service detection
-        service_info = {
-            "service": "youtube",
-            "service_id": "test123",
-            "title": "Test Video",
-        }
-
-        with patch.object(
-            extraction_service, "_detect_service_info", return_value=service_info
-        ):
-            # Mock repository calls
-            extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
-                return_value=None
-            )
+        # Mock repository call - no service detection happens during creation
         mock_extraction = Extraction(
             id=1,
             url=url,
             user_id=user_id,
-            service="youtube",
-            service_id="test123",
-            title="Test Video",
+            service=None,  # Service detection deferred to processing
+            service_id=None,  # Service detection deferred to processing
+            title=None,  # Service detection deferred to processing
             status="pending",
         )
         extraction_service.extraction_repo.create = AsyncMock(
@@ -120,17 +108,87 @@ class TestExtractionService:
         result = await extraction_service.create_extraction(url, user_id)

         assert result["id"] == 1
-        assert result["service"] == "youtube"
-        assert result["service_id"] == "test123"
-        assert result["title"] == "Test Video"
+        assert result["service"] is None  # Not detected during creation
+        assert result["service_id"] is None  # Not detected during creation
+        assert result["title"] is None  # Not detected during creation
         assert result["status"] == "pending"

     @pytest.mark.asyncio
-    async def test_create_extraction_duplicate(self, extraction_service):
-        """Test extraction creation with duplicate service/service_id."""
+    async def test_create_extraction_basic(self, extraction_service):
+        """Test basic extraction creation without validation."""
         url = "https://www.youtube.com/watch?v=test123"
         user_id = 1

+        # Mock repository call - creation always succeeds now
+        mock_extraction = Extraction(
+            id=2,
+            url=url,
+            user_id=user_id,
+            service=None,
+            service_id=None,
+            title=None,
+            status="pending",
+        )
+        extraction_service.extraction_repo.create = AsyncMock(
+            return_value=mock_extraction
+        )
+
+        result = await extraction_service.create_extraction(url, user_id)
+
+        assert result["id"] == 2
+        assert result["url"] == url
+        assert result["status"] == "pending"
+
+    @pytest.mark.asyncio
+    async def test_create_extraction_any_url(self, extraction_service):
+        """Test extraction creation accepts any URL."""
+        url = "https://invalid.url"
+        user_id = 1
+
+        # Mock repository call - even invalid URLs are accepted during creation
+        mock_extraction = Extraction(
+            id=3,
+            url=url,
+            user_id=user_id,
+            service=None,
+            service_id=None,
+            title=None,
+            status="pending",
+        )
+        extraction_service.extraction_repo.create = AsyncMock(
+            return_value=mock_extraction
+        )
+
+        result = await extraction_service.create_extraction(url, user_id)
+
+        assert result["id"] == 3
+        assert result["url"] == url
+        assert result["status"] == "pending"
+
+    @pytest.mark.asyncio
+    async def test_process_extraction_with_service_detection(self, extraction_service):
+        """Test extraction processing with service detection."""
+        extraction_id = 1
+
+        # Mock extraction without service info
+        mock_extraction = Extraction(
+            id=extraction_id,
+            url="https://www.youtube.com/watch?v=test123",
+            user_id=1,
+            service=None,
+            service_id=None,
+            title=None,
+            status="pending",
+        )
+        extraction_service.extraction_repo.get_by_id = AsyncMock(
+            return_value=mock_extraction
+        )
+        extraction_service.extraction_repo.update = AsyncMock()
+        extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
+            return_value=None
+        )
+
         # Mock service detection
         service_info = {
             "service": "youtube",
@@ -138,38 +196,35 @@ class TestExtractionService:
             "title": "Test Video",
         }

-        existing_extraction = Extraction(
-            id=1,
-            url=url,
-            user_id=2,  # Different user
-            service="youtube",
-            service_id="test123",
-            status="completed",
-        )
-
-        with patch.object(
-            extraction_service, "_detect_service_info", return_value=service_info
-        ):
-            extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
-                return_value=existing_extraction
-            )
-
-            with pytest.raises(ValueError, match="Extraction already exists"):
-                await extraction_service.create_extraction(url, user_id)
-
-    @pytest.mark.asyncio
-    async def test_create_extraction_invalid_url(self, extraction_service):
-        """Test extraction creation with invalid URL."""
-        url = "https://invalid.url"
-        user_id = 1
-
-        with patch.object(
-            extraction_service, "_detect_service_info", return_value=None
-        ):
-            with pytest.raises(
-                ValueError, match="Unable to detect service information"
-            ):
-                await extraction_service.create_extraction(url, user_id)
+        with (
+            patch.object(
+                extraction_service, "_detect_service_info", return_value=service_info
+            ),
+            patch.object(extraction_service, "_extract_media") as mock_extract,
+            patch.object(extraction_service, "_move_files_to_final_location") as mock_move,
+            patch.object(extraction_service, "_create_sound_record") as mock_create_sound,
+            patch.object(extraction_service, "_normalize_sound") as mock_normalize,
+            patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist,
+        ):
+            mock_sound = Sound(id=42, type="EXT", name="Test", filename="test.mp3")
+            mock_extract.return_value = (Path("/fake/audio.mp3"), None)
+            mock_move.return_value = (Path("/final/audio.mp3"), None)
+            mock_create_sound.return_value = mock_sound
+
+            result = await extraction_service.process_extraction(extraction_id)
+
+            # Verify service detection was called
+            extraction_service._detect_service_info.assert_called_once_with(
+                "https://www.youtube.com/watch?v=test123"
+            )
+
+            # Verify extraction was updated with service info
+            extraction_service.extraction_repo.update.assert_called()
+
+            assert result["status"] == "completed"
+            assert result["service"] == "youtube"
+            assert result["service_id"] == "test123"
+            assert result["title"] == "Test Video"

     def test_ensure_unique_filename(self, extraction_service):
         """Test unique filename generation."""