- Updated tests for listing users to validate pagination and response format. - Changed mock return values to include total count and pagination details. - Refactored user creation mocks for clarity and consistency. - Enhanced assertions to check for presence of pagination fields in responses. - Adjusted test cases for user retrieval and updates to ensure proper handling of user data. - Improved readability by restructuring mock definitions and assertions across various test files.
157 lines
5.5 KiB
Python
157 lines
5.5 KiB
Python
"""Extraction repository for database operations."""
|
|
|
|
from sqlalchemy import asc, desc, func, inspect, or_
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.extraction import Extraction
from app.models.user import User
from app.repositories.base import BaseRepository
|
|
|
|
|
|
class ExtractionRepository(BaseRepository[Extraction]):
    """Repository for extraction database operations."""

    # Escape character used when turning free-text search terms into LIKE
    # patterns, so literal ``%`` / ``_`` in the term are not treated as
    # wildcards.
    _LIKE_ESCAPE = "\\"

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the extraction repository.

        Args:
            session: Async session used for every query this repository runs.
        """
        super().__init__(Extraction, session)

    async def get_by_service_and_id(
        self,
        service: str,
        service_id: str,
    ) -> Extraction | None:
        """Get an extraction by its ``service`` and ``service_id`` pair.

        Returns:
            The first matching extraction, or ``None`` if there is none.
        """
        result = await self.session.exec(
            select(Extraction).where(
                Extraction.service == service,
                Extraction.service_id == service_id,
            ),
        )
        return result.first()

    async def get_by_user(self, user_id: int) -> list[Extraction]:
        """Get all extractions owned by ``user_id``, newest first."""
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.user_id == user_id)
            .order_by(desc(Extraction.created_at)),
        )
        return list(result.all())

    async def get_pending_extractions(self) -> list[tuple[Extraction, User]]:
        """Get all extractions whose status is ``"pending"``, oldest first.

        Returns:
            ``(extraction, user)`` rows. The query inner-joins each extraction
            to its owning user, so extractions without a matching user row
            are not returned.
        """
        result = await self.session.exec(
            select(Extraction, User)
            .join(User, Extraction.user_id == User.id)
            .where(Extraction.status == "pending")
            .order_by(Extraction.created_at),
        )
        return list(result.all())

    async def get_extractions_by_status(self, status: str) -> list[Extraction]:
        """Get all extractions with the given ``status``, newest first."""
        result = await self.session.exec(
            select(Extraction)
            .where(Extraction.status == status)
            .order_by(desc(Extraction.created_at)),
        )
        return list(result.all())

    async def get_user_extractions_filtered(  # noqa: PLR0913
        self,
        user_id: int,
        search: str | None = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
        status_filter: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Get one user's extractions with filtering, search, and sorting.

        Args:
            user_id: Only extractions owned by this user are considered.
            search: Case-insensitive substring matched against title, URL,
                and service. Wildcard characters in it are matched literally.
            sort_by: Extraction column name to sort by. Unknown or non-column
                names fall back to ``created_at``.
            sort_order: ``"asc"`` (case-insensitive) for ascending; any other
                value sorts descending.
            status_filter: If given, only extractions with this status.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip before the page starts.

        Returns:
            A tuple of the requested page of ``(extraction, user)`` rows and
            the total number of matching rows, ignoring ``limit``/``offset``.
        """
        base_query = (
            select(Extraction, User)
            .join(User, Extraction.user_id == User.id)
            .where(Extraction.user_id == user_id)
        )
        return await self._filter_sort_paginate(
            base_query,
            search=search,
            sort_by=sort_by,
            sort_order=sort_order,
            status_filter=status_filter,
            limit=limit,
            offset=offset,
        )

    async def get_all_extractions_filtered(  # noqa: PLR0913
        self,
        search: str | None = None,
        sort_by: str = "created_at",
        sort_order: str = "desc",
        status_filter: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Get all users' extractions with filtering, search, and sorting.

        Args:
            search: Case-insensitive substring matched against title, URL,
                and service. Wildcard characters in it are matched literally.
            sort_by: Extraction column name to sort by. Unknown or non-column
                names fall back to ``created_at``.
            sort_order: ``"asc"`` (case-insensitive) for ascending; any other
                value sorts descending.
            status_filter: If given, only extractions with this status.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip before the page starts.

        Returns:
            A tuple of the requested page of ``(extraction, user)`` rows and
            the total number of matching rows, ignoring ``limit``/``offset``.
        """
        base_query = select(Extraction, User).join(
            User,
            Extraction.user_id == User.id,
        )
        return await self._filter_sort_paginate(
            base_query,
            search=search,
            sort_by=sort_by,
            sort_order=sort_order,
            status_filter=status_filter,
            limit=limit,
            offset=offset,
        )

    @classmethod
    def _contains_pattern(cls, term: str) -> str:
        """Build a ``%term%`` LIKE pattern with LIKE metacharacters escaped.

        The escape character itself is escaped first so that it is not
        re-interpreted by the following replacements.
        """
        esc = cls._LIKE_ESCAPE
        escaped = (
            term.replace(esc, esc * 2)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )
        return f"%{escaped}%"

    @staticmethod
    def _resolve_sort_column(sort_by: str):
        """Map ``sort_by`` to a mapped Extraction column attribute.

        Only names that are mapped column attributes are accepted, so a
        method, relationship, dunder, or other arbitrary attribute name
        cannot reach ``ORDER BY``. Anything else falls back to
        ``Extraction.created_at``.
        """
        if sort_by in inspect(Extraction).column_attrs:
            return getattr(Extraction, sort_by)
        return Extraction.created_at

    async def _filter_sort_paginate(  # noqa: PLR0913
        self,
        base_query,
        *,
        search: str | None,
        sort_by: str,
        sort_order: str,
        status_filter: str | None,
        limit: int,
        offset: int,
    ) -> tuple[list[tuple[Extraction, User]], int]:
        """Apply search/status filters, count matches, then sort and paginate.

        Shared implementation of the public ``*_extractions_filtered``
        methods.

        Args:
            base_query: A ``select(Extraction, User)`` joined on the owning
                user, possibly already narrowed (e.g. to one user).
            search: See the public methods.
            sort_by: See the public methods.
            sort_order: See the public methods.
            status_filter: See the public methods.
            limit: See the public methods.
            offset: See the public methods.

        Returns:
            The requested page of ``(extraction, user)`` rows and the total
            number of rows matching the filters.
        """
        if search:
            pattern = self._contains_pattern(search)
            esc = self._LIKE_ESCAPE
            base_query = base_query.where(
                or_(
                    Extraction.title.ilike(pattern, escape=esc),
                    Extraction.url.ilike(pattern, escape=esc),
                    Extraction.service.ilike(pattern, escape=esc),
                ),
            )

        if status_filter:
            base_query = base_query.where(Extraction.status == status_filter)

        # Count before ORDER BY / LIMIT / OFFSET are applied so the total
        # reflects every matching row, not just the current page.
        count_result = await self.session.exec(
            select(func.count()).select_from(base_query.subquery()),
        )
        total_count = count_result.one()

        sort_column = self._resolve_sort_column(sort_by)
        direction = asc if sort_order.lower() == "asc" else desc
        page_query = (
            base_query.order_by(direction(sort_column))
            .limit(limit)
            .offset(offset)
        )
        result = await self.session.exec(page_query)
        return list(result.all()), total_count
|