Refactor user endpoint tests to include pagination and response structure validation
- Updated tests for listing users to validate pagination and response format. - Changed mock return values to include total count and pagination details. - Refactored user creation mocks for clarity and consistency. - Enhanced assertions to check for presence of pagination fields in responses. - Adjusted test cases for user retrieval and updates to ensure proper handling of user data. - Improved readability by restructuring mock definitions and assertions across various test files.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""User repository."""
|
||||
|
||||
from typing import Any
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import selectinload
|
||||
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
|
||||
|
||||
class UserSortField(str, Enum):
|
||||
"""User sort fields."""
|
||||
|
||||
NAME = "name"
|
||||
EMAIL = "email"
|
||||
ROLE = "role"
|
||||
@@ -27,12 +28,14 @@ class UserSortField(str, Enum):
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
"""Sort order."""
|
||||
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
"""User status filter."""
|
||||
|
||||
ALL = "all"
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
@@ -64,7 +67,7 @@ class UserRepository(BaseRepository[User]):
|
||||
logger.exception("Failed to get all users with plan")
|
||||
raise
|
||||
|
||||
async def get_all_with_plan_paginated(
|
||||
async def get_all_with_plan_paginated( # noqa: PLR0913
|
||||
self,
|
||||
page: int = 1,
|
||||
limit: int = 50,
|
||||
@@ -77,21 +80,20 @@ class UserRepository(BaseRepository[User]):
|
||||
try:
|
||||
# Calculate offset
|
||||
offset = (page - 1) * limit
|
||||
|
||||
|
||||
# Build base query
|
||||
base_query = select(User).options(selectinload(User.plan))
|
||||
count_query = select(func.count(User.id))
|
||||
|
||||
|
||||
# Apply search filter
|
||||
if search and search.strip():
|
||||
search_pattern = f"%{search.strip().lower()}%"
|
||||
search_condition = (
|
||||
func.lower(User.name).like(search_pattern) |
|
||||
func.lower(User.email).like(search_pattern)
|
||||
)
|
||||
search_condition = func.lower(User.name).like(
|
||||
search_pattern,
|
||||
) | func.lower(User.email).like(search_pattern)
|
||||
base_query = base_query.where(search_condition)
|
||||
count_query = count_query.where(search_condition)
|
||||
|
||||
|
||||
# Apply status filter
|
||||
if status_filter == UserStatus.ACTIVE:
|
||||
base_query = base_query.where(User.is_active == True) # noqa: E712
|
||||
@@ -99,47 +101,34 @@ class UserRepository(BaseRepository[User]):
|
||||
elif status_filter == UserStatus.INACTIVE:
|
||||
base_query = base_query.where(User.is_active == False) # noqa: E712
|
||||
count_query = count_query.where(User.is_active == False) # noqa: E712
|
||||
|
||||
|
||||
# Apply sorting
|
||||
if sort_by == UserSortField.EMAIL:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.email.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.email.asc())
|
||||
elif sort_by == UserSortField.ROLE:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.role.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.role.asc())
|
||||
elif sort_by == UserSortField.CREDITS:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.credits.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.credits.asc())
|
||||
elif sort_by == UserSortField.CREATED_AT:
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.created_at.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.created_at.asc())
|
||||
else: # Default to name
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(User.name.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(User.name.asc())
|
||||
|
||||
sort_column = {
|
||||
UserSortField.NAME: User.name,
|
||||
UserSortField.EMAIL: User.email,
|
||||
UserSortField.ROLE: User.role,
|
||||
UserSortField.CREDITS: User.credits,
|
||||
UserSortField.CREATED_AT: User.created_at,
|
||||
}.get(sort_by, User.name)
|
||||
|
||||
if sort_order == SortOrder.DESC:
|
||||
base_query = base_query.order_by(sort_column.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(sort_column.asc())
|
||||
|
||||
# Get total count
|
||||
count_result = await self.session.exec(count_query)
|
||||
total_count = count_result.one()
|
||||
|
||||
|
||||
# Apply pagination and get results
|
||||
paginated_query = base_query.limit(limit).offset(offset)
|
||||
result = await self.session.exec(paginated_query)
|
||||
users = list(result.all())
|
||||
|
||||
return users, total_count
|
||||
except Exception:
|
||||
logger.exception("Failed to get paginated users with plan")
|
||||
raise
|
||||
else:
|
||||
return users, total_count
|
||||
|
||||
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
|
||||
"""Get a user by ID with plan relationship loaded."""
|
||||
@@ -178,7 +167,7 @@ class UserRepository(BaseRepository[User]):
|
||||
logger.exception("Failed to get user by API token")
|
||||
raise
|
||||
|
||||
async def create(self, user_data: dict[str, Any]) -> User:
|
||||
async def create(self, entity_data: dict[str, Any]) -> User:
|
||||
"""Create a new user with plan assignment and first user admin logic."""
|
||||
|
||||
def _raise_plan_not_found() -> None:
|
||||
@@ -194,7 +183,7 @@ class UserRepository(BaseRepository[User]):
|
||||
if is_first_user:
|
||||
# First user gets admin role and pro plan
|
||||
plan_statement = select(Plan).where(Plan.code == "pro")
|
||||
user_data["role"] = "admin"
|
||||
entity_data["role"] = "admin"
|
||||
logger.info("Creating first user with admin role and pro plan")
|
||||
else:
|
||||
# Regular users get free plan
|
||||
@@ -210,11 +199,11 @@ class UserRepository(BaseRepository[User]):
|
||||
assert default_plan is not None # noqa: S101
|
||||
|
||||
# Set plan_id and default credits
|
||||
user_data["plan_id"] = default_plan.id
|
||||
user_data["credits"] = default_plan.credits
|
||||
entity_data["plan_id"] = default_plan.id
|
||||
entity_data["credits"] = default_plan.credits
|
||||
|
||||
# Use BaseRepository's create method
|
||||
return await super().create(user_data)
|
||||
return await super().create(entity_data)
|
||||
except Exception:
|
||||
logger.exception("Failed to create user")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user