feat: Implement pagination for extractions and playlists with total count in responses
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""Extraction repository for database operations."""
|
||||
|
||||
from sqlalchemy import asc, desc, or_
|
||||
from sqlalchemy import asc, desc, func, or_
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -65,9 +65,11 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
status_filter: str | None = None,
|
||||
) -> list[tuple[Extraction, User]]:
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[tuple[Extraction, User]], int]:
|
||||
"""Get extractions for a user with filtering, search, and sorting."""
|
||||
query = (
|
||||
base_query = (
|
||||
select(Extraction, User)
|
||||
.join(User, Extraction.user_id == User.id)
|
||||
.where(Extraction.user_id == user_id)
|
||||
@@ -76,7 +78,7 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
# Apply search filter
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.where(
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
Extraction.title.ilike(search_pattern),
|
||||
Extraction.url.ilike(search_pattern),
|
||||
@@ -86,17 +88,26 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
|
||||
# Apply status filter
|
||||
if status_filter:
|
||||
query = query.where(Extraction.status == status_filter)
|
||||
base_query = base_query.where(Extraction.status == status_filter)
|
||||
|
||||
# Apply sorting
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(
|
||||
base_query.subquery()
|
||||
)
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
# Apply sorting and pagination
|
||||
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
|
||||
if sort_order.lower() == "asc":
|
||||
query = query.order_by(asc(sort_column))
|
||||
base_query = base_query.order_by(asc(sort_column))
|
||||
else:
|
||||
query = query.order_by(desc(sort_column))
|
||||
base_query = base_query.order_by(desc(sort_column))
|
||||
|
||||
result = await self.session.exec(query)
|
||||
return list(result.all())
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
|
||||
return list(result.all()), total_count
|
||||
|
||||
async def get_all_extractions_filtered(
|
||||
self,
|
||||
@@ -104,14 +115,16 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
status_filter: str | None = None,
|
||||
) -> list[tuple[Extraction, User]]:
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[tuple[Extraction, User]], int]:
|
||||
"""Get all extractions with filtering, search, and sorting."""
|
||||
query = select(Extraction, User).join(User, Extraction.user_id == User.id)
|
||||
base_query = select(Extraction, User).join(User, Extraction.user_id == User.id)
|
||||
|
||||
# Apply search filter
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
query = query.where(
|
||||
base_query = base_query.where(
|
||||
or_(
|
||||
Extraction.title.ilike(search_pattern),
|
||||
Extraction.url.ilike(search_pattern),
|
||||
@@ -121,14 +134,23 @@ class ExtractionRepository(BaseRepository[Extraction]):
|
||||
|
||||
# Apply status filter
|
||||
if status_filter:
|
||||
query = query.where(Extraction.status == status_filter)
|
||||
base_query = base_query.where(Extraction.status == status_filter)
|
||||
|
||||
# Apply sorting
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(
|
||||
base_query.subquery()
|
||||
)
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
# Apply sorting and pagination
|
||||
sort_column = getattr(Extraction, sort_by, Extraction.created_at)
|
||||
if sort_order.lower() == "asc":
|
||||
query = query.order_by(asc(sort_column))
|
||||
base_query = base_query.order_by(asc(sort_column))
|
||||
else:
|
||||
query = query.order_by(desc(sort_column))
|
||||
base_query = base_query.order_by(desc(sort_column))
|
||||
|
||||
result = await self.session.exec(query)
|
||||
return list(result.all())
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
|
||||
return list(result.all()), total_count
|
||||
|
||||
@@ -343,7 +343,9 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
offset: int = 0,
|
||||
favorites_only: bool = False,
|
||||
current_user_id: int | None = None,
|
||||
) -> list[dict]:
|
||||
*,
|
||||
return_count: bool = False,
|
||||
) -> list[dict] | tuple[list[dict], int]:
|
||||
"""Search and sort playlists with optional statistics."""
|
||||
try:
|
||||
if include_stats and sort_by in (
|
||||
@@ -491,6 +493,14 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
# Default sorting by name ascending
|
||||
subquery = subquery.order_by(Playlist.name.asc())
|
||||
|
||||
# Get total count if requested
|
||||
total_count = 0
|
||||
if return_count:
|
||||
# Create count query from the subquery before pagination
|
||||
count_query = select(func.count()).select_from(subquery.subquery())
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
# Apply pagination
|
||||
if offset > 0:
|
||||
subquery = subquery.offset(offset)
|
||||
@@ -532,4 +542,6 @@ class PlaylistRepository(BaseRepository[Playlist]):
|
||||
)
|
||||
raise
|
||||
else:
|
||||
if return_count:
|
||||
return playlists, total_count
|
||||
return playlists
|
||||
|
||||
Reference in New Issue
Block a user