"""Extraction repository for database operations."""

from sqlalchemy import asc, desc, func, inspect as sa_inspect, or_
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import Select

from app.models.extraction import Extraction
from app.models.user import User
from app.repositories.base import BaseRepository

# Escape character declared to the database for ILIKE patterns. A forward
# slash is used rather than a backslash so the rendered ESCAPE clause does not
# depend on how a backend treats backslashes inside string literals.
_LIKE_ESCAPE = "/"


def _contains_pattern(term: str) -> str:
    """Return an ILIKE pattern that matches *term* literally as a substring.

    ``%`` and ``_`` inside *term* are escaped so they are compared as ordinary
    characters instead of acting as wildcards. The escape character itself is
    escaped first so it is not doubled a second time by the later replacements.
    """
    escaped = (
        term.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", f"{_LIKE_ESCAPE}%")
        .replace("_", f"{_LIKE_ESCAPE}_")
    )
    return f"%{escaped}%"


class ExtractionRepository(BaseRepository[Extraction]):
    """Repository for extraction database operations."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the extraction repository.

        Args:
            session: Async session used for every query issued by this
                repository.
        """
        super().__init__(Extraction, session)

    async def get_by_service_and_id(
        self,
        service: str,
        service_id: str,
    ) -> Extraction | None:
        """Get an extraction by service and service_id.

        Returns:
            The first extraction matching both values, or ``None`` when there
            is no match.
        """
        result = await self.session.exec(
            select(Extraction).where(
                Extraction.service == service,
                Extraction.service_id == service_id,
            ),
        )
        return result.first()

    async def get_by_user(self, user_id: int) -> list[Extraction]:
        """Get all extractions for a user, newest ``created_at`` first."""
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.user_id == user_id)
            .order_by(desc(Extraction.created_at)),
        )
        return list(result.all())

    async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
        """Get all extractions whose status is ``"pending"``.

        Returns:
            ``(extraction, user)`` pairs, where ``user`` is the row joined on
            ``Extraction.user_id``, ordered by ``created_at`` ascending.
        """
        result = await self.session.exec(
            select(Extraction, User)
            .join(User, Extraction.user_id == User.id)
            .where(Extraction.status == "pending")
            .order_by(Extraction.created_at),
        )
        return list(result.all())

    async def get_extractions_by_status(self, status: str) -> list[Extraction]:
        """Get extractions with the given status, newest ``created_at`` first."""
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(desc(Extraction.created_at)),
        )
        return list(result.all())

    async def get_user_extractions_filtered(
        self,
        user_id: int,
        search: str | None = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
        status_filter: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Get extractions for a user with filtering, search, and sorting.

        Args:
            user_id: Only extractions with this ``user_id`` are considered.
            search: Optional text matched case-insensitively as a literal
                substring of the extraction's title, url or service.
            sort_by: Name of an ``Extraction`` column to sort on. Any name that
                is not a mapped column falls back to ``created_at``.
            sort_order: ``"asc"`` (any letter case) for ascending order; every
                other value sorts descending.
            status_filter: Optional exact status to match.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip before the page starts.

        Returns:
            A tuple of the requested page of ``(extraction, user)`` pairs and
            the total number of rows matching the filters before pagination.
        """
        statement = (
            select(Extraction, User)
            .join(User, Extraction.user_id == User.id)
            .where(Extraction.user_id == user_id)
        )
        return await self._fetch_filtered_page(
            statement,
            search=search,
            sort_by=sort_by,
            sort_order=sort_order,
            status_filter=status_filter,
            limit=limit,
            offset=offset,
        )

    async def get_all_extractions_filtered(
        self,
        search: str | None = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
        status_filter: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Get all extractions with filtering, search, and sorting.

        Args:
            search: Optional text matched case-insensitively as a literal
                substring of the extraction's title, url or service.
            sort_by: Name of an ``Extraction`` column to sort on. Any name that
                is not a mapped column falls back to ``created_at``.
            sort_order: ``"asc"`` (any letter case) for ascending order; every
                other value sorts descending.
            status_filter: Optional exact status to match.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip before the page starts.

        Returns:
            A tuple of the requested page of ``(extraction, user)`` pairs and
            the total number of rows matching the filters before pagination.
        """
        statement = select(Extraction, User).join(
            User,
            Extraction.user_id == User.id,
        )
        return await self._fetch_filtered_page(
            statement,
            search=search,
            sort_by=sort_by,
            sort_order=sort_order,
            status_filter=status_filter,
            limit=limit,
            offset=offset,
        )

    async def _fetch_filtered_page(
        self,
        statement: Select[tuple[Extraction, User]],
        *,
        search: str | None,
        sort_by: str,
        sort_order: str,
        status_filter: str | None,
        limit: int,
        offset: int,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Filter, count, sort and paginate an ``(Extraction, User)`` query.

        Shared by the public ``*_filtered`` methods so their search, status,
        sorting and pagination rules stay identical.

        Args:
            statement: Base select of ``(Extraction, User)`` rows, already
                restricted to the rows the caller is allowed to see.
            search: Optional literal substring to look for in title, url or
                service (case-insensitive).
            sort_by: Requested sort column name.
            sort_order: ``"asc"`` for ascending, anything else for descending.
            status_filter: Optional exact status to match.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            The page of rows and the total match count before pagination.
        """
        if search:
            pattern = _contains_pattern(search)
            statement = statement.where(
                or_(
                    Extraction.title.ilike(pattern, escape=_LIKE_ESCAPE),
                    Extraction.url.ilike(pattern, escape=_LIKE_ESCAPE),
                    Extraction.service.ilike(pattern, escape=_LIKE_ESCAPE),
                ),
            )

        if status_filter:
            statement = statement.where(Extraction.status == status_filter)

        # Count before ordering and pagination are applied so the total
        # reflects every matching row, not just the returned page.
        count_statement = select(func.count()).select_from(statement.subquery())
        total_count = (await self.session.exec(count_statement)).one()

        # Only mapped column attributes are valid sort keys. A bare getattr
        # would also resolve methods, relationships and class-level objects
        # such as ``metadata``, none of which can be used in ORDER BY.
        # ``column_attrs`` is keyed by mapped attribute name.
        if sort_by in sa_inspect(Extraction).column_attrs:
            sort_column = getattr(Extraction, sort_by)
        else:
            sort_column = Extraction.created_at

        direction = asc if sort_order.lower() == "asc" else desc
        page_statement = (
            statement.order_by(direction(sort_column)).limit(limit).offset(offset)
        )
        rows = await self.session.exec(page_statement)
        return list(rows.all()), total_count