diff --git a/app/repositories/extraction.py b/app/repositories/extraction.py index 2b791e2..da7beb6 100644 --- a/app/repositories/extraction.py +++ b/app/repositories/extraction.py @@ -39,6 +39,15 @@ class ExtractionRepository(BaseRepository[Extraction]): ) return list(result.all()) + async def get_by_status(self, status: str) -> list[Extraction]: + """Get all extractions by status.""" + result = await self.session.exec( + select(Extraction) + .where(Extraction.status == status) + .order_by(Extraction.created_at) + ) + return list(result.all()) + async def get_pending_extractions(self) -> list[tuple[Extraction, User]]: """Get all pending extractions.""" result = await self.session.exec( diff --git a/app/services/extraction.py b/app/services/extraction.py index fd7ce52..742e0b1 100644 --- a/app/services/extraction.py +++ b/app/services/extraction.py @@ -168,12 +168,31 @@ class ExtractionService: extraction_service_id = extraction.service_id extraction_title = extraction.title + # Get user information for return value + try: + user = await self.user_repo.get_by_id(user_id) + user_name = user.name if user else None + except Exception: + logger.warning("Failed to get user %d for extraction", user_id) + user_name = None + logger.info("Processing extraction %d: %s", extraction_id, extraction_url) try: # Update status to processing await self.extraction_repo.update(extraction, {"status": "processing"}) + # Emit WebSocket event for processing start + await self._emit_extraction_event( + user_id, + { + "extraction_id": extraction_id, + "status": "processing", + "title": extraction_title or "Processing extraction...", + "url": extraction_url, + }, + ) + # Detect service info if not already available if not extraction_service or not extraction_service_id: logger.info("Detecting service info for extraction %d", extraction_id) @@ -184,9 +203,16 @@ class ExtractionService: raise ValueError(msg) # Check if extraction already exists for this service + service_name = service_info["service"] + service_id_val = service_info["service_id"] + + if not service_name or not service_id_val: + msg = "Service info is incomplete" + raise ValueError(msg) + existing = await self.extraction_repo.get_by_service_and_id( - service_info["service"], - service_info["service_id"], + service_name, + service_id_val, ) if existing and existing.id != extraction_id: error_msg = ( @@ -209,6 +235,16 @@ class ExtractionService: extraction_service_id = service_info["service_id"] extraction_title = service_info.get("title") or extraction_title + await self._emit_extraction_event( + user_id, + { + "extraction_id": extraction_id, + "status": "processing", + "title": extraction_title, + "url": extraction_url, + }, + ) + # Extract audio and thumbnail audio_file, thumbnail_file = await self._extract_media( extraction_id, @@ -258,15 +294,20 @@ class ExtractionService: }, ) - logger.info("Successfully processed extraction %d", extraction_id) - except Exception as e: - error_msg = str(e) - logger.exception( - "Failed to process extraction %d: %s", - extraction_id, - error_msg, + # Emit WebSocket event for completion + await self._emit_extraction_event( + user_id, + { + "extraction_id": extraction_id, + "status": "completed", + "title": extraction_title, + "url": extraction_url, + "sound_id": sound_id, + }, ) - else: + + logger.info("Successfully processed extraction %d", extraction_id) + # Get updated extraction to get latest timestamps updated_extraction = await self.extraction_repo.get_by_id(extraction_id) return { @@ -279,6 +320,7 @@ class ExtractionService: "error": None, "sound_id": sound_id, "user_id": user_id, + "user_name": user_name, "created_at": ( updated_extraction.created_at.isoformat() if updated_extraction @@ -291,6 +333,26 @@ class ExtractionService: ), } + except Exception as e: + error_msg = str(e) + logger.exception( + "Failed to process extraction %d: %s", + extraction_id, + error_msg, + ) + + # Emit WebSocket event for failure + await self._emit_extraction_event( + user_id, + { + "extraction_id": extraction_id, + "status": "failed", + "title": extraction_title or "Extraction failed", + "url": extraction_url, + "error": error_msg, + }, + ) + # Update extraction with error await self.extraction_repo.update( extraction, @@ -312,6 +374,7 @@ class ExtractionService: "error": error_msg, "sound_id": None, "user_id": user_id, + "user_name": user_name, "created_at": ( updated_extraction.created_at.isoformat() if updated_extraction @@ -549,6 +612,21 @@ class ExtractionService: ) # Don't fail the extraction if playlist addition fails + async def _emit_extraction_event(self, user_id: int, data: dict) -> None: + """Emit WebSocket event for extraction status updates to all users.""" + try: + # Import here to avoid circular imports + from app.services.socket import socket_manager # noqa: PLC0415 + + await socket_manager.broadcast_to_all("extraction_status_update", data) + logger.debug( + "Broadcasted extraction event (initiated by user %d): %s", + user_id, + data["status"], + ) + except Exception: + logger.exception("Failed to emit extraction event") + async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None: """Get extraction information by ID.""" extraction = await self.extraction_repo.get_by_id(extraction_id) diff --git a/app/services/extraction_processor.py b/app/services/extraction_processor.py index 1835ac6..31dcccf 100644 --- a/app/services/extraction_processor.py +++ b/app/services/extraction_processor.py @@ -35,6 +35,9 @@ class ExtractionProcessor: logger.warning("Extraction processor is already running") return + # Reset any stuck extractions from previous runs + await self._reset_stuck_extractions() + self.shutdown_event.clear() self.processor_task = asyncio.create_task(self._process_queue()) logger.info("Started extraction processor") @@ -179,6 +182,46 @@ class ExtractionProcessor: self.max_concurrent, ) + async def _reset_stuck_extractions(self) -> None: + """Reset any extractions stuck in 'processing' status back to 'pending'.""" + try: + async with AsyncSession(engine) as session: + extraction_service = ExtractionService(session) + + # Get all extractions stuck in processing status + stuck_extractions = ( + await extraction_service.extraction_repo.get_by_status("processing") + ) + + if not stuck_extractions: + logger.info("No stuck extractions found to reset") + return + + reset_count = 0 + for extraction in stuck_extractions: + try: + await extraction_service.extraction_repo.update( + extraction, {"status": "pending", "error": None} + ) + reset_count += 1 + logger.info( + "Reset stuck extraction %d from processing to pending", + extraction.id, + ) + except Exception: + logger.exception( + "Failed to reset extraction %d", extraction.id + ) + + await session.commit() + logger.info( + "Successfully reset %d stuck extractions from processing to pending", + reset_count, + ) + + except Exception: + logger.exception("Failed to reset stuck extractions") + def get_status(self) -> dict: """Get the current status of the extraction processor.""" return { diff --git a/tests/repositories/test_extraction.py b/tests/repositories/test_extraction.py index f6f1c37..0f74b52 100644 --- a/tests/repositories/test_extraction.py +++ b/tests/repositories/test_extraction.py @@ -129,3 +129,36 @@ class TestExtractionRepository: assert result.sound_id == TEST_SOUND_ID extraction_repo.session.commit.assert_called_once() extraction_repo.session.refresh.assert_called_once_with(extraction) + + @pytest.mark.asyncio + async def test_get_by_status(self, extraction_repo): + """Test getting extractions by status.""" + mock_extractions = [ + Extraction( + id=1, + service="youtube", + service_id="test123", + url="https://www.youtube.com/watch?v=test1", + user_id=1, + status="processing", + ), + Extraction( + id=2, + service="youtube", + service_id="test456", + url="https://www.youtube.com/watch?v=test2", + user_id=1, + status="processing", + ), + ] + + mock_result = Mock() + mock_result.all.return_value = mock_extractions + + extraction_repo.session.exec = AsyncMock(return_value=mock_result) + + result = await extraction_repo.get_by_status("processing") + + assert len(result) == 2 + assert all(extraction.status == "processing" for extraction in result) + extraction_repo.session.exec.assert_called_once() diff --git a/tests/services/test_extraction.py b/tests/services/test_extraction.py index 8640c0d..49637e9 100644 --- a/tests/services/test_extraction.py +++ b/tests/services/test_extraction.py @@ -217,6 +217,10 @@ class TestExtractionService: extraction_service.extraction_repo.get_by_service_and_id = AsyncMock( return_value=None, ) + # Mock user repository + from app.models.user import User + mock_user = User(id=1, name="Test User") + extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user) # Mock service detection service_info = {