"""Extraction repository for database operations."""

from typing import Any

from sqlalchemy import desc
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.extraction import Extraction


class ExtractionRepository:
    """Repository for extraction database operations.

    Every mutating method (``create``, ``update``, ``delete``) commits its own
    transaction. If that commit fails, the transaction is rolled back before
    the error is re-raised, so the shared session stays usable.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the extraction repository.

        Args:
            session: Async session used for all queries and commits.
        """
        self.session = session

    async def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails."""
        try:
            await self.session.commit()
        except Exception:
            # After a failed commit/flush the session refuses further work
            # until it is rolled back. Roll back here so callers get the
            # original exception and a session they can keep using.
            await self.session.rollback()
            raise

    async def create(self, extraction_data: dict[str, Any]) -> Extraction:
        """Create and persist a new extraction.

        Args:
            extraction_data: Keyword arguments passed to the ``Extraction``
                constructor.

        Returns:
            The persisted extraction, refreshed from the database.
        """
        extraction = Extraction(**extraction_data)
        self.session.add(extraction)
        await self._commit()
        await self.session.refresh(extraction)
        return extraction

    async def get_by_id(self, extraction_id: int) -> Extraction | None:
        """Get an extraction by ID, or ``None`` if no row matches."""
        result = await self.session.exec(
            select(Extraction).where(Extraction.id == extraction_id)
        )
        return result.first()

    async def get_by_service_and_id(
        self, service: str, service_id: str
    ) -> Extraction | None:
        """Get an extraction by ``service`` and ``service_id``.

        Returns:
            The first matching extraction, or ``None`` if no row matches.
        """
        result = await self.session.exec(
            select(Extraction).where(
                Extraction.service == service,
                Extraction.service_id == service_id,
            )
        )
        return result.first()

    async def get_by_user(self, user_id: int) -> list[Extraction]:
        """Get all extractions for a user, newest ``created_at`` first."""
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.user_id == user_id)
            .order_by(desc(Extraction.created_at))
        )
        return list(result.all())

    async def get_pending_extractions(self) -> list[Extraction]:
        """Get all extractions with status ``"pending"``, oldest first."""
        return await self.get_extractions_by_status("pending", oldest_first=True)

    async def update(
        self, extraction: Extraction, update_data: dict[str, Any]
    ) -> Extraction:
        """Apply ``update_data`` to ``extraction`` and persist the changes.

        Args:
            extraction: The extraction instance to modify.
            update_data: Mapping of attribute name to new value. Each entry is
                applied with ``setattr``.

        Returns:
            The same instance, refreshed from the database after commit.
        """
        for key, value in update_data.items():
            setattr(extraction, key, value)
        await self._commit()
        await self.session.refresh(extraction)
        return extraction

    async def delete(self, extraction: Extraction) -> None:
        """Delete an extraction and commit."""
        await self.session.delete(extraction)
        await self._commit()

    async def get_extractions_by_status(
        self, status: str, *, oldest_first: bool = False
    ) -> list[Extraction]:
        """Get extractions with the given status, ordered by ``created_at``.

        Args:
            status: Status value to match exactly.
            oldest_first: If True, sort ascending by ``created_at``. The
                default (False) sorts newest first, matching the original
                behavior.

        Returns:
            The matching extractions as a list.
        """
        ordering = (
            Extraction.created_at if oldest_first else desc(Extraction.created_at)
        )
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(ordering)
        )
        return list(result.all())