feat: Add method to get extractions by status and implement user info retrieval in extraction service
Some checks failed
Backend CI / lint (push) Failing after 4m53s
Backend CI / test (push) Failing after 4m31s

This commit is contained in:
JSC
2025-08-24 13:24:48 +02:00
parent 28faca55bc
commit 16eb789539
5 changed files with 177 additions and 10 deletions

View File

@@ -39,6 +39,15 @@ class ExtractionRepository(BaseRepository[Extraction]):
) )
return list(result.all()) return list(result.all())
async def get_by_status(self, status: str) -> list[Extraction]:
"""Get all extractions by status."""
result = await self.session.exec(
select(Extraction)
.where(Extraction.status == status)
.order_by(Extraction.created_at)
)
return list(result.all())
async def get_pending_extractions(self) -> list[tuple[Extraction, User]]: async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
"""Get all pending extractions.""" """Get all pending extractions."""
result = await self.session.exec( result = await self.session.exec(

View File

@@ -168,12 +168,31 @@ class ExtractionService:
extraction_service_id = extraction.service_id extraction_service_id = extraction.service_id
extraction_title = extraction.title extraction_title = extraction.title
# Get user information for return value
try:
user = await self.user_repo.get_by_id(user_id)
user_name = user.name if user else None
except Exception:
logger.warning("Failed to get user %d for extraction", user_id)
user_name = None
logger.info("Processing extraction %d: %s", extraction_id, extraction_url) logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
try: try:
# Update status to processing # Update status to processing
await self.extraction_repo.update(extraction, {"status": "processing"}) await self.extraction_repo.update(extraction, {"status": "processing"})
# Emit WebSocket event for processing start
await self._emit_extraction_event(
user_id,
{
"extraction_id": extraction_id,
"status": "processing",
"title": extraction_title or "Processing extraction...",
"url": extraction_url,
},
)
# Detect service info if not already available # Detect service info if not already available
if not extraction_service or not extraction_service_id: if not extraction_service or not extraction_service_id:
logger.info("Detecting service info for extraction %d", extraction_id) logger.info("Detecting service info for extraction %d", extraction_id)
@@ -184,9 +203,16 @@ class ExtractionService:
raise ValueError(msg) raise ValueError(msg)
# Check if extraction already exists for this service # Check if extraction already exists for this service
service_name = service_info["service"]
service_id_val = service_info["service_id"]
if not service_name or not service_id_val:
msg = "Service info is incomplete"
raise ValueError(msg)
existing = await self.extraction_repo.get_by_service_and_id( existing = await self.extraction_repo.get_by_service_and_id(
service_info["service"], service_name,
service_info["service_id"], service_id_val,
) )
if existing and existing.id != extraction_id: if existing and existing.id != extraction_id:
error_msg = ( error_msg = (
@@ -209,6 +235,16 @@ class ExtractionService:
extraction_service_id = service_info["service_id"] extraction_service_id = service_info["service_id"]
extraction_title = service_info.get("title") or extraction_title extraction_title = service_info.get("title") or extraction_title
await self._emit_extraction_event(
user_id,
{
"extraction_id": extraction_id,
"status": "processing",
"title": extraction_title,
"url": extraction_url,
},
)
# Extract audio and thumbnail # Extract audio and thumbnail
audio_file, thumbnail_file = await self._extract_media( audio_file, thumbnail_file = await self._extract_media(
extraction_id, extraction_id,
@@ -258,15 +294,20 @@ class ExtractionService:
}, },
) )
logger.info("Successfully processed extraction %d", extraction_id) # Emit WebSocket event for completion
except Exception as e: await self._emit_extraction_event(
error_msg = str(e) user_id,
logger.exception( {
"Failed to process extraction %d: %s", "extraction_id": extraction_id,
extraction_id, "status": "completed",
error_msg, "title": extraction_title,
"url": extraction_url,
"sound_id": sound_id,
},
) )
else:
logger.info("Successfully processed extraction %d", extraction_id)
# Get updated extraction to get latest timestamps # Get updated extraction to get latest timestamps
updated_extraction = await self.extraction_repo.get_by_id(extraction_id) updated_extraction = await self.extraction_repo.get_by_id(extraction_id)
return { return {
@@ -279,6 +320,7 @@ class ExtractionService:
"error": None, "error": None,
"sound_id": sound_id, "sound_id": sound_id,
"user_id": user_id, "user_id": user_id,
"user_name": user_name,
"created_at": ( "created_at": (
updated_extraction.created_at.isoformat() updated_extraction.created_at.isoformat()
if updated_extraction if updated_extraction
@@ -291,6 +333,26 @@ class ExtractionService:
), ),
} }
except Exception as e:
error_msg = str(e)
logger.exception(
"Failed to process extraction %d: %s",
extraction_id,
error_msg,
)
# Emit WebSocket event for failure
await self._emit_extraction_event(
user_id,
{
"extraction_id": extraction_id,
"status": "failed",
"title": extraction_title or "Extraction failed",
"url": extraction_url,
"error": error_msg,
},
)
# Update extraction with error # Update extraction with error
await self.extraction_repo.update( await self.extraction_repo.update(
extraction, extraction,
@@ -312,6 +374,7 @@ class ExtractionService:
"error": error_msg, "error": error_msg,
"sound_id": None, "sound_id": None,
"user_id": user_id, "user_id": user_id,
"user_name": user_name,
"created_at": ( "created_at": (
updated_extraction.created_at.isoformat() updated_extraction.created_at.isoformat()
if updated_extraction if updated_extraction
@@ -549,6 +612,21 @@ class ExtractionService:
) )
# Don't fail the extraction if playlist addition fails # Don't fail the extraction if playlist addition fails
async def _emit_extraction_event(self, user_id: int, data: dict) -> None:
"""Emit WebSocket event for extraction status updates to all users."""
try:
# Import here to avoid circular imports
from app.services.socket import socket_manager # noqa: PLC0415
await socket_manager.broadcast_to_all("extraction_status_update", data)
logger.debug(
"Broadcasted extraction event (initiated by user %d): %s",
user_id,
data["status"],
)
except Exception:
logger.exception("Failed to emit extraction event")
async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None: async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
"""Get extraction information by ID.""" """Get extraction information by ID."""
extraction = await self.extraction_repo.get_by_id(extraction_id) extraction = await self.extraction_repo.get_by_id(extraction_id)

View File

@@ -35,6 +35,9 @@ class ExtractionProcessor:
logger.warning("Extraction processor is already running") logger.warning("Extraction processor is already running")
return return
# Reset any stuck extractions from previous runs
await self._reset_stuck_extractions()
self.shutdown_event.clear() self.shutdown_event.clear()
self.processor_task = asyncio.create_task(self._process_queue()) self.processor_task = asyncio.create_task(self._process_queue())
logger.info("Started extraction processor") logger.info("Started extraction processor")
@@ -179,6 +182,46 @@ class ExtractionProcessor:
self.max_concurrent, self.max_concurrent,
) )
async def _reset_stuck_extractions(self) -> None:
"""Reset any extractions stuck in 'processing' status back to 'pending'."""
try:
async with AsyncSession(engine) as session:
extraction_service = ExtractionService(session)
# Get all extractions stuck in processing status
stuck_extractions = (
await extraction_service.extraction_repo.get_by_status("processing")
)
if not stuck_extractions:
logger.info("No stuck extractions found to reset")
return
reset_count = 0
for extraction in stuck_extractions:
try:
await extraction_service.extraction_repo.update(
extraction, {"status": "pending", "error": None}
)
reset_count += 1
logger.info(
"Reset stuck extraction %d from processing to pending",
extraction.id,
)
except Exception:
logger.exception(
"Failed to reset extraction %d", extraction.id
)
await session.commit()
logger.info(
"Successfully reset %d stuck extractions from processing to pending",
reset_count,
)
except Exception:
logger.exception("Failed to reset stuck extractions")
def get_status(self) -> dict: def get_status(self) -> dict:
"""Get the current status of the extraction processor.""" """Get the current status of the extraction processor."""
return { return {

View File

@@ -129,3 +129,36 @@ class TestExtractionRepository:
assert result.sound_id == TEST_SOUND_ID assert result.sound_id == TEST_SOUND_ID
extraction_repo.session.commit.assert_called_once() extraction_repo.session.commit.assert_called_once()
extraction_repo.session.refresh.assert_called_once_with(extraction) extraction_repo.session.refresh.assert_called_once_with(extraction)
@pytest.mark.asyncio
async def test_get_by_status(self, extraction_repo):
"""Test getting extractions by status."""
mock_extractions = [
Extraction(
id=1,
service="youtube",
service_id="test123",
url="https://www.youtube.com/watch?v=test1",
user_id=1,
status="processing",
),
Extraction(
id=2,
service="youtube",
service_id="test456",
url="https://www.youtube.com/watch?v=test2",
user_id=1,
status="processing",
),
]
mock_result = Mock()
mock_result.all.return_value = mock_extractions
extraction_repo.session.exec = AsyncMock(return_value=mock_result)
result = await extraction_repo.get_by_status("processing")
assert len(result) == 2
assert all(extraction.status == "processing" for extraction in result)
extraction_repo.session.exec.assert_called_once()

View File

@@ -217,6 +217,10 @@ class TestExtractionService:
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock( extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
return_value=None, return_value=None,
) )
# Mock user repository
from app.models.user import User
mock_user = User(id=1, name="Test User")
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
# Mock service detection # Mock service detection
service_info = { service_info = {