- Added `ExtractionProcessor` class to handle extraction queue processing in the background. - Implemented methods for starting, stopping, and queuing extractions with concurrency limits. - Integrated logging for monitoring the processor's status and actions. - Created tests for the extraction processor to ensure functionality and error handling. test: Add unit tests for extraction API endpoints - Created tests for successful extraction creation, authentication checks, and processor status retrieval. - Ensured proper responses for authenticated and unauthenticated requests. test: Implement unit tests for extraction repository - Added tests for creating, retrieving, and updating extractions in the repository. - Mocked database interactions to validate repository behavior without actual database access. test: Add comprehensive tests for extraction service - Developed tests for extraction creation, service detection, and sound record creation. - Included tests for handling duplicate extractions and invalid URLs. test: Add unit tests for extraction background processor - Created tests for the `ExtractionProcessor` class to validate its behavior under various conditions. - Ensured proper handling of extraction queuing, processing, and completion callbacks. fix: Update OAuth service tests to use AsyncMock - Modified OAuth provider tests to use `AsyncMock` for mocking asynchronous HTTP requests.
83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
"""Extraction repository for database operations."""
|
|
|
|
from sqlalchemy import desc
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.models.extraction import Extraction
|
|
|
|
|
|
class ExtractionRepository:
    """Repository for extraction database operations.

    Every mutating method (``create``, ``update``, ``delete``) commits the
    session it was given, so each call is its own transaction from the
    caller's point of view.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the extraction repository.

        Args:
            session: Async SQLModel session used for every query and commit.
        """
        self.session = session

    async def _fetch_all(self, statement) -> list[Extraction]:
        """Execute ``statement`` and materialize all resulting rows as a list."""
        result = await self.session.exec(statement)
        return list(result.all())

    async def create(self, extraction_data: dict) -> Extraction:
        """Create and persist a new extraction.

        Args:
            extraction_data: Keyword arguments forwarded to the ``Extraction``
                constructor.

        Returns:
            The committed extraction, refreshed from the database so that
            database-generated values are loaded onto it.
        """
        extraction = Extraction(**extraction_data)
        self.session.add(extraction)
        await self.session.commit()
        await self.session.refresh(extraction)
        return extraction

    async def get_by_id(self, extraction_id: int) -> Extraction | None:
        """Get an extraction by primary key, or ``None`` if it does not exist."""
        result = await self.session.exec(
            select(Extraction).where(Extraction.id == extraction_id)
        )
        return result.first()

    async def get_by_service_and_id(
        self, service: str, service_id: str
    ) -> Extraction | None:
        """Get an extraction by ``service`` and ``service_id``.

        Returns:
            The first matching extraction, or ``None`` when nothing matches.
        """
        result = await self.session.exec(
            select(Extraction).where(
                Extraction.service == service,
                Extraction.service_id == service_id,
            )
        )
        return result.first()

    async def get_by_user(self, user_id: int) -> list[Extraction]:
        """Get all extractions owned by ``user_id``, newest first.

        Rows sharing the same ``created_at`` are ordered by descending ``id``
        so the result order is deterministic.
        """
        return await self._fetch_all(
            select(Extraction)
            .where(Extraction.user_id == user_id)
            .order_by(desc(Extraction.created_at), desc(Extraction.id))
        )

    async def get_pending_extractions(self) -> list[Extraction]:
        """Get all extractions with status ``"pending"``, oldest first."""
        return await self.get_extractions_by_status("pending", oldest_first=True)

    async def update(self, extraction: Extraction, update_data: dict) -> Extraction:
        """Apply ``update_data`` to ``extraction`` and commit the change.

        Args:
            extraction: The extraction to modify.
            update_data: Mapping of attribute name to new value; each pair is
                applied with ``setattr``.

        Returns:
            The same instance, refreshed from the database after the commit.
        """
        for key, value in update_data.items():
            setattr(extraction, key, value)

        # Re-attach detached instances (e.g. ones loaded by a session that has
        # since closed); otherwise commit() would persist nothing and
        # refresh() would raise. A no-op for instances already in this session.
        self.session.add(extraction)
        await self.session.commit()
        await self.session.refresh(extraction)
        return extraction

    async def delete(self, extraction: Extraction) -> None:
        """Delete ``extraction`` and commit the removal."""
        await self.session.delete(extraction)
        await self.session.commit()

    async def get_extractions_by_status(
        self, status: str, *, oldest_first: bool = False
    ) -> list[Extraction]:
        """Get extractions whose status equals ``status``.

        Args:
            status: Status value to match exactly.
            oldest_first: When ``False`` (the default) results are newest
                first; when ``True`` they are oldest first.

        Returns:
            Matching extractions ordered by ``created_at``, with ``id`` as a
            tie-breaker so the order is deterministic.
        """
        if oldest_first:
            ordering = (Extraction.created_at, Extraction.id)
        else:
            ordering = (desc(Extraction.created_at), desc(Extraction.id))
        return await self._fetch_all(
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(*ordering)
        )
|