feat: Add method to get extractions by status and implement user info retrieval in extraction service
This commit is contained in:
@@ -39,6 +39,15 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
|||||||
)
|
)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_by_status(self, status: str) -> list[Extraction]:
|
||||||
|
"""Get all extractions by status."""
|
||||||
|
result = await self.session.exec(
|
||||||
|
select(Extraction)
|
||||||
|
.where(Extraction.status == status)
|
||||||
|
.order_by(Extraction.created_at)
|
||||||
|
)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
|
async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
|
||||||
"""Get all pending extractions."""
|
"""Get all pending extractions."""
|
||||||
result = await self.session.exec(
|
result = await self.session.exec(
|
||||||
|
|||||||
@@ -168,12 +168,31 @@ class ExtractionService:
|
|||||||
extraction_service_id = extraction.service_id
|
extraction_service_id = extraction.service_id
|
||||||
extraction_title = extraction.title
|
extraction_title = extraction.title
|
||||||
|
|
||||||
|
# Get user information for return value
|
||||||
|
try:
|
||||||
|
user = await self.user_repo.get_by_id(user_id)
|
||||||
|
user_name = user.name if user else None
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to get user %d for extraction", user_id)
|
||||||
|
user_name = None
|
||||||
|
|
||||||
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
|
logger.info("Processing extraction %d: %s", extraction_id, extraction_url)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Update status to processing
|
# Update status to processing
|
||||||
await self.extraction_repo.update(extraction, {"status": "processing"})
|
await self.extraction_repo.update(extraction, {"status": "processing"})
|
||||||
|
|
||||||
|
# Emit WebSocket event for processing start
|
||||||
|
await self._emit_extraction_event(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"extraction_id": extraction_id,
|
||||||
|
"status": "processing",
|
||||||
|
"title": extraction_title or "Processing extraction...",
|
||||||
|
"url": extraction_url,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Detect service info if not already available
|
# Detect service info if not already available
|
||||||
if not extraction_service or not extraction_service_id:
|
if not extraction_service or not extraction_service_id:
|
||||||
logger.info("Detecting service info for extraction %d", extraction_id)
|
logger.info("Detecting service info for extraction %d", extraction_id)
|
||||||
@@ -184,9 +203,16 @@ class ExtractionService:
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Check if extraction already exists for this service
|
# Check if extraction already exists for this service
|
||||||
|
service_name = service_info["service"]
|
||||||
|
service_id_val = service_info["service_id"]
|
||||||
|
|
||||||
|
if not service_name or not service_id_val:
|
||||||
|
msg = "Service info is incomplete"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
existing = await self.extraction_repo.get_by_service_and_id(
|
existing = await self.extraction_repo.get_by_service_and_id(
|
||||||
service_info["service"],
|
service_name,
|
||||||
service_info["service_id"],
|
service_id_val,
|
||||||
)
|
)
|
||||||
if existing and existing.id != extraction_id:
|
if existing and existing.id != extraction_id:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
@@ -209,6 +235,16 @@ class ExtractionService:
|
|||||||
extraction_service_id = service_info["service_id"]
|
extraction_service_id = service_info["service_id"]
|
||||||
extraction_title = service_info.get("title") or extraction_title
|
extraction_title = service_info.get("title") or extraction_title
|
||||||
|
|
||||||
|
await self._emit_extraction_event(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"extraction_id": extraction_id,
|
||||||
|
"status": "processing",
|
||||||
|
"title": extraction_title,
|
||||||
|
"url": extraction_url,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Extract audio and thumbnail
|
# Extract audio and thumbnail
|
||||||
audio_file, thumbnail_file = await self._extract_media(
|
audio_file, thumbnail_file = await self._extract_media(
|
||||||
extraction_id,
|
extraction_id,
|
||||||
@@ -258,15 +294,20 @@ class ExtractionService:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Successfully processed extraction %d", extraction_id)
|
# Emit WebSocket event for completion
|
||||||
except Exception as e:
|
await self._emit_extraction_event(
|
||||||
error_msg = str(e)
|
user_id,
|
||||||
logger.exception(
|
{
|
||||||
"Failed to process extraction %d: %s",
|
"extraction_id": extraction_id,
|
||||||
extraction_id,
|
"status": "completed",
|
||||||
error_msg,
|
"title": extraction_title,
|
||||||
|
"url": extraction_url,
|
||||||
|
"sound_id": sound_id,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
|
logger.info("Successfully processed extraction %d", extraction_id)
|
||||||
|
|
||||||
# Get updated extraction to get latest timestamps
|
# Get updated extraction to get latest timestamps
|
||||||
updated_extraction = await self.extraction_repo.get_by_id(extraction_id)
|
updated_extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||||
return {
|
return {
|
||||||
@@ -279,6 +320,7 @@ class ExtractionService:
|
|||||||
"error": None,
|
"error": None,
|
||||||
"sound_id": sound_id,
|
"sound_id": sound_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
"user_name": user_name,
|
||||||
"created_at": (
|
"created_at": (
|
||||||
updated_extraction.created_at.isoformat()
|
updated_extraction.created_at.isoformat()
|
||||||
if updated_extraction
|
if updated_extraction
|
||||||
@@ -291,6 +333,26 @@ class ExtractionService:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
logger.exception(
|
||||||
|
"Failed to process extraction %d: %s",
|
||||||
|
extraction_id,
|
||||||
|
error_msg,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit WebSocket event for failure
|
||||||
|
await self._emit_extraction_event(
|
||||||
|
user_id,
|
||||||
|
{
|
||||||
|
"extraction_id": extraction_id,
|
||||||
|
"status": "failed",
|
||||||
|
"title": extraction_title or "Extraction failed",
|
||||||
|
"url": extraction_url,
|
||||||
|
"error": error_msg,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Update extraction with error
|
# Update extraction with error
|
||||||
await self.extraction_repo.update(
|
await self.extraction_repo.update(
|
||||||
extraction,
|
extraction,
|
||||||
@@ -312,6 +374,7 @@ class ExtractionService:
|
|||||||
"error": error_msg,
|
"error": error_msg,
|
||||||
"sound_id": None,
|
"sound_id": None,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
"user_name": user_name,
|
||||||
"created_at": (
|
"created_at": (
|
||||||
updated_extraction.created_at.isoformat()
|
updated_extraction.created_at.isoformat()
|
||||||
if updated_extraction
|
if updated_extraction
|
||||||
@@ -549,6 +612,21 @@ class ExtractionService:
|
|||||||
)
|
)
|
||||||
# Don't fail the extraction if playlist addition fails
|
# Don't fail the extraction if playlist addition fails
|
||||||
|
|
||||||
|
async def _emit_extraction_event(self, user_id: int, data: dict) -> None:
|
||||||
|
"""Emit WebSocket event for extraction status updates to all users."""
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from app.services.socket import socket_manager # noqa: PLC0415
|
||||||
|
|
||||||
|
await socket_manager.broadcast_to_all("extraction_status_update", data)
|
||||||
|
logger.debug(
|
||||||
|
"Broadcasted extraction event (initiated by user %d): %s",
|
||||||
|
user_id,
|
||||||
|
data["status"],
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to emit extraction event")
|
||||||
|
|
||||||
async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
|
async def get_extraction_by_id(self, extraction_id: int) -> ExtractionInfo | None:
|
||||||
"""Get extraction information by ID."""
|
"""Get extraction information by ID."""
|
||||||
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
extraction = await self.extraction_repo.get_by_id(extraction_id)
|
||||||
|
|||||||
@@ -35,6 +35,9 @@ class ExtractionProcessor:
|
|||||||
logger.warning("Extraction processor is already running")
|
logger.warning("Extraction processor is already running")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Reset any stuck extractions from previous runs
|
||||||
|
await self._reset_stuck_extractions()
|
||||||
|
|
||||||
self.shutdown_event.clear()
|
self.shutdown_event.clear()
|
||||||
self.processor_task = asyncio.create_task(self._process_queue())
|
self.processor_task = asyncio.create_task(self._process_queue())
|
||||||
logger.info("Started extraction processor")
|
logger.info("Started extraction processor")
|
||||||
@@ -179,6 +182,46 @@ class ExtractionProcessor:
|
|||||||
self.max_concurrent,
|
self.max_concurrent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _reset_stuck_extractions(self) -> None:
|
||||||
|
"""Reset any extractions stuck in 'processing' status back to 'pending'."""
|
||||||
|
try:
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
extraction_service = ExtractionService(session)
|
||||||
|
|
||||||
|
# Get all extractions stuck in processing status
|
||||||
|
stuck_extractions = (
|
||||||
|
await extraction_service.extraction_repo.get_by_status("processing")
|
||||||
|
)
|
||||||
|
|
||||||
|
if not stuck_extractions:
|
||||||
|
logger.info("No stuck extractions found to reset")
|
||||||
|
return
|
||||||
|
|
||||||
|
reset_count = 0
|
||||||
|
for extraction in stuck_extractions:
|
||||||
|
try:
|
||||||
|
await extraction_service.extraction_repo.update(
|
||||||
|
extraction, {"status": "pending", "error": None}
|
||||||
|
)
|
||||||
|
reset_count += 1
|
||||||
|
logger.info(
|
||||||
|
"Reset stuck extraction %d from processing to pending",
|
||||||
|
extraction.id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to reset extraction %d", extraction.id
|
||||||
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Successfully reset %d stuck extractions from processing to pending",
|
||||||
|
reset_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to reset stuck extractions")
|
||||||
|
|
||||||
def get_status(self) -> dict:
|
def get_status(self) -> dict:
|
||||||
"""Get the current status of the extraction processor."""
|
"""Get the current status of the extraction processor."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -129,3 +129,36 @@ class TestExtractionRepository:
|
|||||||
assert result.sound_id == TEST_SOUND_ID
|
assert result.sound_id == TEST_SOUND_ID
|
||||||
extraction_repo.session.commit.assert_called_once()
|
extraction_repo.session.commit.assert_called_once()
|
||||||
extraction_repo.session.refresh.assert_called_once_with(extraction)
|
extraction_repo.session.refresh.assert_called_once_with(extraction)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_status(self, extraction_repo):
|
||||||
|
"""Test getting extractions by status."""
|
||||||
|
mock_extractions = [
|
||||||
|
Extraction(
|
||||||
|
id=1,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test123",
|
||||||
|
url="https://www.youtube.com/watch?v=test1",
|
||||||
|
user_id=1,
|
||||||
|
status="processing",
|
||||||
|
),
|
||||||
|
Extraction(
|
||||||
|
id=2,
|
||||||
|
service="youtube",
|
||||||
|
service_id="test456",
|
||||||
|
url="https://www.youtube.com/watch?v=test2",
|
||||||
|
user_id=1,
|
||||||
|
status="processing",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_result = Mock()
|
||||||
|
mock_result.all.return_value = mock_extractions
|
||||||
|
|
||||||
|
extraction_repo.session.exec = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
result = await extraction_repo.get_by_status("processing")
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert all(extraction.status == "processing" for extraction in result)
|
||||||
|
extraction_repo.session.exec.assert_called_once()
|
||||||
|
|||||||
@@ -217,6 +217,10 @@ class TestExtractionService:
|
|||||||
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
|
extraction_service.extraction_repo.get_by_service_and_id = AsyncMock(
|
||||||
return_value=None,
|
return_value=None,
|
||||||
)
|
)
|
||||||
|
# Mock user repository
|
||||||
|
from app.models.user import User
|
||||||
|
mock_user = User(id=1, name="Test User")
|
||||||
|
extraction_service.user_repo.get_by_id = AsyncMock(return_value=mock_user)
|
||||||
|
|
||||||
# Mock service detection
|
# Mock service detection
|
||||||
service_info = {
|
service_info = {
|
||||||
|
|||||||
Reference in New Issue
Block a user