Files
sdb2-backend/app/repositories/extraction.py
JSC 16eb789539
Some checks failed
Backend CI / lint (push) Failing after 4m53s
Backend CI / test (push) Failing after 4m31s
feat: Add method to get extractions by status and implement user info retrieval in extraction service
2025-08-24 13:24:48 +02:00

166 lines
5.8 KiB
Python

"""Extraction repository for database operations."""
from sqlalchemy import asc, desc, func, or_
from sqlalchemy import inspect as sa_inspect
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.extraction import Extraction
from app.models.user import User
from app.repositories.base import BaseRepository
class ExtractionRepository(BaseRepository[Extraction]):
    """Repository for extraction database operations.

    Besides simple lookups by service, user, and status, this repository
    offers two paginated listing methods that support free-text search,
    status filtering, and sorting. Both delegate to one private helper so
    that filtering, counting, and sorting rules cannot drift apart.
    """

    # Escape character used in LIKE patterns built from user search text.
    # A forward slash avoids the backslash-quoting differences between
    # database dialects.
    _LIKE_ESCAPE = "/"

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the extraction repository."""
        super().__init__(Extraction, session)

    async def get_by_service_and_id(
        self,
        service: str,
        service_id: str,
    ) -> Extraction | None:
        """Get an extraction by service and service_id.

        Returns the first matching row, or ``None`` when nothing matches.
        """
        result = await self.session.exec(
            select(Extraction).where(
                Extraction.service == service,
                Extraction.service_id == service_id,
            ),
        )
        return result.first()

    async def get_by_user(self, user_id: int) -> list[Extraction]:
        """Get all extractions for a user, ordered by created_at descending."""
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.user_id == user_id)
            .order_by(desc(Extraction.created_at)),
        )
        return list(result.all())

    async def get_by_status(self, status: str) -> list[Extraction]:
        """Get all extractions by status, ordered by created_at ascending.

        See ``get_extractions_by_status`` for the descending variant.
        """
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(Extraction.created_at),
        )
        return list(result.all())

    async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
        """Get all pending extractions together with their users.

        Returns ``(Extraction, User)`` pairs whose status is ``"pending"``,
        ordered by created_at ascending.
        """
        result = await self.session.exec(
            select(Extraction, User)
            .join(User, Extraction.user_id == User.id)
            .where(Extraction.status == "pending")
            .order_by(Extraction.created_at),
        )
        return list(result.all())

    async def get_extractions_by_status(self, status: str) -> list[Extraction]:
        """Get extractions by status, ordered by created_at descending.

        See ``get_by_status`` for the ascending variant.
        """
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(desc(Extraction.created_at)),
        )
        return list(result.all())

    async def get_user_extractions_filtered(  # noqa: PLR0913
        self,
        user_id: int,
        search: str | None = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
        status_filter: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Get extractions for a user with filtering, search, and sorting.

        Args:
            user_id: Only extractions with this user id are returned.
            search: Optional text matched literally and case-insensitively
                as a substring of title, url, or service.
            sort_by: Name of an Extraction column to sort on; any other
                value falls back to ``created_at``.
            sort_order: ``"asc"`` (case-insensitive) for ascending; any
                other value sorts descending.
            status_filter: Optional exact status to match.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip.

        Returns:
            The page of ``(Extraction, User)`` pairs and the total number
            of rows matching the filters before pagination.
        """
        return await self._fetch_filtered_page(
            search=search,
            sort_by=sort_by,
            sort_order=sort_order,
            status_filter=status_filter,
            limit=limit,
            offset=offset,
            user_id=user_id,
            scope_to_user=True,
        )

    async def get_all_extractions_filtered(  # noqa: PLR0913
        self,
        search: str | None = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
        status_filter: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Get all extractions with filtering, search, and sorting.

        Args:
            search: Optional text matched literally and case-insensitively
                as a substring of title, url, or service.
            sort_by: Name of an Extraction column to sort on; any other
                value falls back to ``created_at``.
            sort_order: ``"asc"`` (case-insensitive) for ascending; any
                other value sorts descending.
            status_filter: Optional exact status to match.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip.

        Returns:
            The page of ``(Extraction, User)`` pairs and the total number
            of rows matching the filters before pagination.
        """
        return await self._fetch_filtered_page(
            search=search,
            sort_by=sort_by,
            sort_order=sort_order,
            status_filter=status_filter,
            limit=limit,
            offset=offset,
        )

    @classmethod
    def _escape_like(cls, term: str) -> str:
        """Escape LIKE wildcards in *term* so it is matched literally."""
        esc = cls._LIKE_ESCAPE
        # The escape character itself must be doubled first, otherwise the
        # escapes added for "%" and "_" would be escaped again.
        return (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    async def _fetch_filtered_page(  # noqa: PLR0913
        self,
        *,
        search: str | None,
        sort_by: str,
        sort_order: str,
        status_filter: str | None,
        limit: int,
        offset: int,
        user_id: int | None = None,
        scope_to_user: bool = False,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Run the shared filtered, counted, sorted, and paginated query.

        ``scope_to_user`` is a separate flag instead of treating
        ``user_id=None`` as "all users", so an accidental ``None`` user id
        can never widen the result set to every user's extractions.
        """
        query = select(Extraction, User).join(User, Extraction.user_id == User.id)
        if scope_to_user:
            query = query.where(Extraction.user_id == user_id)

        # Apply search filter; "%" and "_" typed by the user are escaped so
        # they match themselves instead of acting as wildcards.
        if search:
            pattern = f"%{self._escape_like(search)}%"
            escape = self._LIKE_ESCAPE
            query = query.where(
                or_(
                    Extraction.title.ilike(pattern, escape=escape),
                    Extraction.url.ilike(pattern, escape=escape),
                    Extraction.service.ilike(pattern, escape=escape),
                ),
            )

        # Apply status filter
        if status_filter:
            query = query.where(Extraction.status == status_filter)

        # Get total count before sorting and pagination
        count_result = await self.session.exec(
            select(func.count()).select_from(query.subquery()),
        )
        total_count = count_result.one()

        # Only mapped columns are valid sort keys. A bare getattr() on a
        # caller-supplied name would also resolve methods, relationships,
        # and other class attributes that cannot be passed to ORDER BY.
        if sort_by in sa_inspect(Extraction).column_attrs:
            sort_column = getattr(Extraction, sort_by)
        else:
            sort_column = Extraction.created_at
        direction = asc if sort_order.lower() == "asc" else desc

        page_query = (
            query.order_by(direction(sort_column)).limit(limit).offset(offset)
        )
        result = await self.session.exec(page_query)
        return list(result.all()), total_count