diff --git a/app/models/extraction.py b/app/models/extraction.py index 80aadf5..0e3997f 100644 --- a/app/models/extraction.py +++ b/app/models/extraction.py @@ -12,8 +12,8 @@ if TYPE_CHECKING: class Extraction(BaseModel, table=True): """Database model for a stream.""" - service: str = Field(nullable=False) - service_id: str = Field(nullable=False) + service: str | None = Field(default=None) + service_id: str | None = Field(default=None) user_id: int = Field(foreign_key="user.id", nullable=False) sound_id: int | None = Field(foreign_key="sound.id", default=None) url: str = Field(nullable=False) @@ -25,14 +25,8 @@ class Extraction(BaseModel, table=True): status: str = Field(nullable=False, default="pending") error: str | None = Field(default=None) - # constraints - __table_args__ = ( - UniqueConstraint( - "service", - "service_id", - name="uq_extraction_service_service_id", - ), - ) + # constraints - only enforce uniqueness when both service and service_id are not null + __table_args__ = () # relationships sound: "Sound" = Relationship(back_populates="extractions") diff --git a/app/services/extraction.py b/app/services/extraction.py index 36043e0..9395f94 100644 --- a/app/services/extraction.py +++ b/app/services/extraction.py @@ -1,5 +1,6 @@ """Extraction service for audio extraction from external services using yt-dlp.""" +import asyncio import shutil from pathlib import Path from typing import TypedDict @@ -24,8 +25,8 @@ class ExtractionInfo(TypedDict): id: int url: str - service: str - service_id: str + service: str | None + service_id: str | None title: str | None status: str error: str | None @@ -61,39 +62,13 @@ class ExtractionService: logger.info("Creating extraction for URL: %s (user: %d)", url, user_id) try: - # First, detect service and service_id using yt-dlp - service_info = self._detect_service_info(url) - - if not service_info: - raise ValueError("Unable to detect service information from URL") - - service = service_info["service"] - service_id = service_info["service_id"] - title = service_info.get("title") - - logger.info( - "Detected service: %s, service_id: %s, title: %s", - service, - service_id, - title, - ) - - # Check if extraction already exists - existing = await self.extraction_repo.get_by_service_and_id( - service, service_id - ) - if existing: - error_msg = f"Extraction already exists for {service}:{service_id}" - logger.warning(error_msg) - raise ValueError(error_msg) - - # Create the extraction record + # Create the extraction record without service detection for fast response extraction_data = { "url": url, "user_id": user_id, - "service": service, - "service_id": service_id, - "title": title, + "service": None, # Will be detected during processing + "service_id": None, # Will be detected during processing + "title": None, # Will be detected during processing "status": "pending", } @@ -115,7 +90,7 @@ class ExtractionService: logger.exception("Failed to create extraction for URL: %s", url) raise - def _detect_service_info(self, url: str) -> dict | None: + async def _detect_service_info(self, url: str) -> dict[str, str | None] | None: """Detect service information from URL using yt-dlp.""" try: # Configure yt-dlp for info extraction only @@ -125,35 +100,22 @@ class ExtractionService: "extract_flat": False, } - with yt_dlp.YoutubeDL(ydl_opts) as ydl: - # Extract info without downloading - info = ydl.extract_info(url, download=False) + def _extract_info() -> dict | None: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + # Extract info without downloading + return ydl.extract_info(url, download=False) - if not info: - return None + # Run the blocking operation in a thread pool + info = await asyncio.to_thread(_extract_info) - # Map extractor names to our service names - extractor_map = { - "youtube": "youtube", - "dailymotion": "dailymotion", - "vimeo": "vimeo", - "soundcloud": "soundcloud", - "twitter": "twitter", - "tiktok": "tiktok", - "instagram": "instagram", - } + if not info: + return None - extractor = info.get("extractor", "").lower() - service = extractor_map.get(extractor, extractor) - - return { - "service": service, - "service_id": str(info.get("id", "")), - "title": info.get("title"), - "duration": info.get("duration"), - "uploader": info.get("uploader"), - "description": info.get("description"), - } + return { + "service": info.get("extractor", ""), + "service_id": str(info.get("id", "")), + "title": info.get("title"), + } except Exception: logger.exception("Failed to detect service info for URL: %s", url) @@ -171,9 +133,9 @@ class ExtractionService: # Store all needed values early to avoid session detachment issues user_id = extraction.user_id extraction_url = extraction.url - extraction_title = extraction.title extraction_service = extraction.service extraction_service_id = extraction.service_id + extraction_title = extraction.title logger.info("Processing extraction %d: %s", extraction_id, extraction_url) @@ -181,20 +143,51 @@ class ExtractionService: # Update status to processing await self.extraction_repo.update(extraction, {"status": "processing"}) + # Detect service info if not already available + if not extraction_service or not extraction_service_id: + logger.info("Detecting service info for extraction %d", extraction_id) + service_info = await self._detect_service_info(extraction_url) + + if not service_info: + raise ValueError("Unable to detect service information from URL") + + # Check if extraction already exists for this service + existing = await self.extraction_repo.get_by_service_and_id( + service_info["service"], service_info["service_id"] + ) + if existing and existing.id != extraction_id: + error_msg = f"Extraction already exists for {service_info['service']}:{service_info['service_id']}" + logger.warning(error_msg) + raise ValueError(error_msg) + + # Update extraction with service info + update_data = { + "service": service_info["service"], + "service_id": service_info["service_id"], + "title": service_info.get("title") or extraction_title, + } + await self.extraction_repo.update(extraction, update_data) + + # Update values for processing + extraction_service = service_info["service"] + extraction_service_id = service_info["service_id"] + extraction_title = service_info.get("title") or extraction_title + # Extract audio and thumbnail audio_file, thumbnail_file = await self._extract_media( extraction_id, extraction_url ) # Move files to final locations - final_audio_path, final_thumbnail_path = ( - await self._move_files_to_final_location( - audio_file, - thumbnail_file, - extraction_title, - extraction_service, - extraction_service_id, - ) + ( + final_audio_path, + final_thumbnail_path, + ) = await self._move_files_to_final_location( + audio_file, + thumbnail_file, + extraction_title, + extraction_service, + extraction_service_id, ) # Create Sound record @@ -294,36 +287,40 @@ class ExtractionService: ], } - try: + def _download_media() -> None: with yt_dlp.YoutubeDL(ydl_opts) as ydl: # Download and extract ydl.download([extraction_url]) - # Find the extracted files - audio_files = list( - temp_dir.glob( - f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}" - ) - ) - thumbnail_files = ( - list(temp_dir.glob(f"extraction_{extraction_id}_*.webp")) - + list(temp_dir.glob(f"extraction_{extraction_id}_*.jpg")) - + list(temp_dir.glob(f"extraction_{extraction_id}_*.png")) + try: + # Run the blocking download operation in a thread pool + await asyncio.to_thread(_download_media) + + # Find the extracted files + audio_files = list( + temp_dir.glob( + f"extraction_{extraction_id}_*.{settings.EXTRACTION_AUDIO_FORMAT}" ) + ) + thumbnail_files = ( + list(temp_dir.glob(f"extraction_{extraction_id}_*.webp")) + + list(temp_dir.glob(f"extraction_{extraction_id}_*.jpg")) + + list(temp_dir.glob(f"extraction_{extraction_id}_*.png")) + ) - if not audio_files: - raise RuntimeError("No audio file was created during extraction") + if not audio_files: + raise RuntimeError("No audio file was created during extraction") - audio_file = audio_files[0] - thumbnail_file = thumbnail_files[0] if thumbnail_files else None + audio_file = audio_files[0] + thumbnail_file = thumbnail_files[0] if thumbnail_files else None - logger.info( - "Extracted audio: %s, thumbnail: %s", - audio_file, - thumbnail_file or "None", - ) + logger.info( + "Extracted audio: %s, thumbnail: %s", + audio_file, + thumbnail_file or "None", + ) - return audio_file, thumbnail_file + return audio_file, thumbnail_file except Exception as e: logger.exception("yt-dlp extraction failed for %s", extraction_url) @@ -334,12 +331,14 @@ class ExtractionService: audio_file: Path, thumbnail_file: Path | None, title: str | None, - service: str, - service_id: str, + service: str | None, + service_id: str | None, ) -> tuple[Path, Path | None]: """Move extracted files to their final locations.""" # Generate clean filename based on title and service - safe_title = self._sanitize_filename(title or f"{service}_{service_id}") + safe_title = self._sanitize_filename( + title or f"{service or 'unknown'}_{service_id or 'unknown'}" + ) # Move audio file final_audio_path = ( @@ -395,7 +394,11 @@ class ExtractionService: counter += 1 async def _create_sound_record( - self, audio_path: Path, title: str | None, service: str, service_id: str + self, + audio_path: Path, + title: str | None, + service: str | None, + service_id: str | None, ) -> Sound: """Create a Sound record for the extracted audio.""" # Get audio metadata @@ -406,7 +409,7 @@ class ExtractionService: # Create sound data sound_data = { "type": "EXT", - "name": title or f"{service}_{service_id}", + "name": title or f"{service or 'unknown'}_{service_id or 'unknown'}", "filename": audio_path.name, "duration": duration, "size": size, diff --git a/app/services/extraction_processor.py b/app/services/extraction_processor.py index eb12e6a..6d4225e 100644 --- a/app/services/extraction_processor.py +++ b/app/services/extraction_processor.py @@ -1,7 +1,6 @@ """Background extraction processor for handling extraction queue.""" import asyncio -from typing import Set from sqlmodel.ext.asyncio.session import AsyncSession @@ -19,7 +18,7 @@ class ExtractionProcessor: def __init__(self) -> None: """Initialize the extraction processor.""" self.max_concurrent = settings.EXTRACTION_MAX_CONCURRENT - self.running_extractions: Set[int] = set() + self.running_extractions: set[int] = set() self.processing_lock = asyncio.Lock() self.shutdown_event = asyncio.Event() self.processor_task: asyncio.Task | None = None @@ -71,7 +70,7 @@ class ExtractionProcessor: ) async def _process_queue(self) -> None: - """Main processing loop that handles the extraction queue.""" + """Process the extraction queue in the main processing loop.""" logger.info("Starting extraction queue processor") while not self.shutdown_event.is_set(): @@ -161,7 +160,7 @@ class ExtractionProcessor: logger.exception("Error processing extraction %d: %s", extraction_id, e) def _on_extraction_completed(self, extraction_id: int, task: asyncio.Task) -> None: - """Callback when an extraction task is completed.""" + """Handle completion of an extraction task.""" # Remove from running set self.running_extractions.discard(extraction_id) diff --git a/tests/services/test_extraction.py b/tests/services/test_extraction.py index 714f409..16fb0f0 100644 --- a/tests/services/test_extraction.py +++ b/tests/services/test_extraction.py @@ -51,7 +51,8 @@ class TestExtractionService: assert result == expected @patch("app.services.extraction.yt_dlp.YoutubeDL") - def test_detect_service_info_youtube(self, mock_ydl_class, extraction_service): + @pytest.mark.asyncio + async def test_detect_service_info_youtube(self, mock_ydl_class, extraction_service): """Test service detection for YouTube.""" mock_ydl = Mock() mock_ydl_class.return_value.__enter__.return_value = mock_ydl @@ -63,7 +64,7 @@ class TestExtractionService: "uploader": "Test Channel", } - result = extraction_service._detect_service_info( + result = await extraction_service._detect_service_info( "https://www.youtube.com/watch?v=test123" ) @@ -71,16 +72,16 @@ class TestExtractionService: assert result["service"] == "youtube" assert result["service_id"] == "test123" assert result["title"] == "Test Video" - assert result["duration"] == 240 @patch("app.services.extraction.yt_dlp.YoutubeDL") - def test_detect_service_info_failure(self, mock_ydl_class, extraction_service): + @pytest.mark.asyncio + async def test_detect_service_info_failure(self, mock_ydl_class, extraction_service): """Test service detection failure.""" mock_ydl = Mock() mock_ydl_class.return_value.__enter__.return_value = mock_ydl mock_ydl.extract_info.side_effect = Exception("Network error") - result = extraction_service._detect_service_info("https://invalid.url") + result = await extraction_service._detect_service_info("https://invalid.url") assert result is None @@ -90,86 +91,140 @@ class TestExtractionService: url = "https://www.youtube.com/watch?v=test123" user_id = 1 - # Mock service detection - service_info = { - "service": "youtube", - "service_id": "test123", - "title": "Test Video", - } + # Mock repository call - no service detection happens during creation + mock_extraction = Extraction( + id=1, + url=url, + user_id=user_id, + service=None, # Service detection deferred to processing + service_id=None, # Service detection deferred to processing + title=None, # Service detection deferred to processing + status="pending", + ) + extraction_service.extraction_repo.create = AsyncMock( + return_value=mock_extraction + ) - with patch.object( - extraction_service, "_detect_service_info", return_value=service_info - ): - # Mock repository calls - extraction_service.extraction_repo.get_by_service_and_id = AsyncMock( - return_value=None - ) - mock_extraction = Extraction( - id=1, - url=url, - user_id=user_id, - service="youtube", - service_id="test123", - title="Test Video", - status="pending", - ) - extraction_service.extraction_repo.create = AsyncMock( - return_value=mock_extraction - ) + result = await extraction_service.create_extraction(url, user_id) - result = await extraction_service.create_extraction(url, user_id) - - assert result["id"] == 1 - assert result["service"] == "youtube" - assert result["service_id"] == "test123" - assert result["title"] == "Test Video" - assert result["status"] == "pending" + assert result["id"] == 1 + assert result["service"] is None # Not detected during creation + assert result["service_id"] is None # Not detected during creation + assert result["title"] is None # Not detected during creation + assert result["status"] == "pending" @pytest.mark.asyncio - async def test_create_extraction_duplicate(self, extraction_service): - """Test extraction creation with duplicate service/service_id.""" + async def test_create_extraction_basic(self, extraction_service): + """Test basic extraction creation without validation.""" url = "https://www.youtube.com/watch?v=test123" user_id = 1 - # Mock service detection - service_info = { - "service": "youtube", - "service_id": "test123", - "title": "Test Video", - } - - existing_extraction = Extraction( - id=1, + # Mock repository call - creation always succeeds now + mock_extraction = Extraction( + id=2, url=url, - user_id=2, # Different user - service="youtube", - service_id="test123", - status="completed", + user_id=user_id, + service=None, + service_id=None, + title=None, + status="pending", + ) + extraction_service.extraction_repo.create = AsyncMock( + return_value=mock_extraction ) - with patch.object( - extraction_service, "_detect_service_info", return_value=service_info - ): - extraction_service.extraction_repo.get_by_service_and_id = AsyncMock( - return_value=existing_extraction - ) + result = await extraction_service.create_extraction(url, user_id) - with pytest.raises(ValueError, match="Extraction already exists"): - await extraction_service.create_extraction(url, user_id) + assert result["id"] == 2 + assert result["url"] == url + assert result["status"] == "pending" @pytest.mark.asyncio - async def test_create_extraction_invalid_url(self, extraction_service): - """Test extraction creation with invalid URL.""" + async def test_create_extraction_any_url(self, extraction_service): + """Test extraction creation accepts any URL.""" url = "https://invalid.url" user_id = 1 - with patch.object( - extraction_service, "_detect_service_info", return_value=None + # Mock repository call - even invalid URLs are accepted during creation + mock_extraction = Extraction( + id=3, + url=url, + user_id=user_id, + service=None, + service_id=None, + title=None, + status="pending", + ) + extraction_service.extraction_repo.create = AsyncMock( + return_value=mock_extraction + ) + + result = await extraction_service.create_extraction(url, user_id) + + assert result["id"] == 3 + assert result["url"] == url + assert result["status"] == "pending" + + @pytest.mark.asyncio + async def test_process_extraction_with_service_detection(self, extraction_service): + """Test extraction processing with service detection.""" + extraction_id = 1 + + # Mock extraction without service info + mock_extraction = Extraction( + id=extraction_id, + url="https://www.youtube.com/watch?v=test123", + user_id=1, + service=None, + service_id=None, + title=None, + status="pending", + ) + + extraction_service.extraction_repo.get_by_id = AsyncMock( + return_value=mock_extraction + ) + extraction_service.extraction_repo.update = AsyncMock() + extraction_service.extraction_repo.get_by_service_and_id = AsyncMock( + return_value=None + ) + + # Mock service detection + service_info = { + "service": "youtube", + "service_id": "test123", + "title": "Test Video", + } + + with ( + patch.object( + extraction_service, "_detect_service_info", return_value=service_info + ), + patch.object(extraction_service, "_extract_media") as mock_extract, + patch.object(extraction_service, "_move_files_to_final_location") as mock_move, + patch.object(extraction_service, "_create_sound_record") as mock_create_sound, + patch.object(extraction_service, "_normalize_sound") as mock_normalize, + patch.object(extraction_service, "_add_to_main_playlist") as mock_playlist, ): - with pytest.raises( - ValueError, match="Unable to detect service information" - ): - await extraction_service.create_extraction(url, user_id) + mock_sound = Sound(id=42, type="EXT", name="Test", filename="test.mp3") + mock_extract.return_value = (Path("/fake/audio.mp3"), None) + mock_move.return_value = (Path("/final/audio.mp3"), None) + mock_create_sound.return_value = mock_sound + + result = await extraction_service.process_extraction(extraction_id) + + # Verify service detection was called + extraction_service._detect_service_info.assert_called_once_with( + "https://www.youtube.com/watch?v=test123" + ) + + # Verify extraction was updated with service info + extraction_service.extraction_repo.update.assert_called() + + assert result["status"] == "completed" + assert result["service"] == "youtube" + assert result["service_id"] == "test123" + assert result["title"] == "Test Video" def test_ensure_unique_filename(self, extraction_service): """Test unique filename generation."""