Files
sdb2-backend/app/repositories/user.py

231 lines
8.3 KiB
Python

"""User repository."""
from typing import Any
from enum import Enum
from sqlalchemy import func
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.plan import Plan
from app.models.user import User
from app.repositories.base import BaseRepository
# Module-level logger; the repository methods in this file use it to report
# failures (logger.exception) before re-raising, and for informational events.
logger = get_logger(__name__)
class UserSortField(str, Enum):
    """Fields a user listing can be sorted by.

    Members are also plain strings (``str`` mixin), so each one compares
    equal to its lowercase value.
    """

    NAME = "name"
    EMAIL = "email"
    ROLE = "role"
    CREDITS = "credits"
    CREATED_AT = "created_at"
class SortOrder(str, Enum):
    """Direction of a sorted listing: ascending or descending.

    Members are also plain strings (``str`` mixin), so each one compares
    equal to its lowercase value.
    """

    ASC = "asc"
    DESC = "desc"
class UserStatus(str, Enum):
    """Status filter values for a user listing.

    Members are also plain strings (``str`` mixin), so each one compares
    equal to its lowercase value.
    """

    ALL = "all"
    ACTIVE = "active"
    INACTIVE = "inactive"
class UserRepository(BaseRepository[User]):
    """Repository for user operations.

    Extends ``BaseRepository[User]`` with user-specific lookups (by email, by
    API token, by ID with the ``plan`` relationship loaded), a filtered and
    sorted pagination query, and a ``create`` override that assigns a plan
    and starting credits.

    Every public method logs unexpected failures with ``logger.exception``
    and then re-raises the original exception unchanged.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the user repository.

        Args:
            session: Async session handed to ``BaseRepository`` and used for
                every query issued by this repository.
        """
        super().__init__(User, session)

    async def get_all_with_plan(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[User]:
        """Get all users with plan relationship loaded.

        Args:
            limit: Maximum number of users to return.
            offset: Number of users to skip before the first returned one.

        Returns:
            Users ordered by primary key, so consecutive ``offset`` windows
            neither overlap nor skip rows.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                # LIMIT/OFFSET without ORDER BY leaves the row order
                # unspecified, which makes paging unreliable.
                .order_by(User.id)
                .limit(limit)
                .offset(offset)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all users with plan")
            raise

    async def get_all_with_plan_paginated(
        self,
        page: int = 1,
        limit: int = 50,
        search: str | None = None,
        sort_by: UserSortField = UserSortField.NAME,
        sort_order: SortOrder = SortOrder.ASC,
        status_filter: UserStatus = UserStatus.ALL,
    ) -> tuple[list[User], int]:
        """Get one page of users (plan loaded) plus the total match count.

        Args:
            page: 1-based page number. Values below 1 are treated as 1 so the
                computed OFFSET can never be negative.
            limit: Page size.
            search: Optional text matched case-insensitively as a literal
                substring of the user's name or email. Blank or
                whitespace-only input disables the filter.
            sort_by: Field to order by; unrecognised values fall back to name.
            sort_order: Ascending unless ``SortOrder.DESC`` is given.
            status_filter: Restrict to active or inactive users; ``ALL``
                applies no restriction.

        Returns:
            ``(users, total_count)`` where ``total_count`` is the number of
            users matching the filters across all pages.
        """
        try:
            # Collect the filters once and apply them to both the row query
            # and the count query so the two can never drift apart.
            conditions: list[Any] = []

            term = search.strip().lower() if search else ""
            if term:
                # autoescape makes '%' and '_' typed by the user match
                # literally instead of acting as LIKE wildcards.
                conditions.append(
                    func.lower(User.name).contains(term, autoescape=True)
                    | func.lower(User.email).contains(term, autoescape=True)
                )

            if status_filter == UserStatus.ACTIVE:
                conditions.append(User.is_active == True)  # noqa: E712
            elif status_filter == UserStatus.INACTIVE:
                conditions.append(User.is_active == False)  # noqa: E712

            base_query = select(User).options(selectinload(User.plan))
            count_query = select(func.count(User.id))
            for condition in conditions:
                base_query = base_query.where(condition)
                count_query = count_query.where(condition)

            # Dispatch table instead of an if/elif ladder; anything not
            # listed (including NAME) sorts by name.
            sort_columns = {
                UserSortField.EMAIL: User.email,
                UserSortField.ROLE: User.role,
                UserSortField.CREDITS: User.credits,
                UserSortField.CREATED_AT: User.created_at,
            }
            sort_column = sort_columns.get(sort_by, User.name)
            if sort_order == SortOrder.DESC:
                primary_ordering = sort_column.desc()
            else:
                primary_ordering = sort_column.asc()
            # The primary key is a tiebreaker: without it, rows sharing the
            # same sort value may repeat or disappear between pages.
            base_query = base_query.order_by(primary_ordering, User.id)

            # Get total count
            count_result = await self.session.exec(count_query)
            total_count = count_result.one()

            # Apply pagination and get results
            offset = max(page - 1, 0) * limit
            result = await self.session.exec(
                base_query.limit(limit).offset(offset),
            )
            users = list(result.all())
            return users, total_count
        except Exception:
            logger.exception("Failed to get paginated users with plan")
            raise

    async def get_by_id_with_plan(self, entity_id: int) -> User | None:
        """Get a user by ID with plan relationship loaded.

        Args:
            entity_id: Primary key of the user.

        Returns:
            The matching user, or ``None`` when no row has that ID.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                .where(User.id == entity_id)
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception(
                "Failed to get user by ID with plan: %s",
                entity_id,
            )
            raise

    async def get_by_email(self, email: str) -> User | None:
        """Get a user by email address.

        The comparison is an exact equality match on ``User.email``.

        Returns:
            The first matching user, or ``None`` when there is none.
        """
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by email: %s", email)
            raise

    async def get_by_api_token(self, api_token: str) -> User | None:
        """Get a user by API token.

        Returns:
            The first user whose ``api_token`` equals the given value, or
            ``None`` when there is none.
        """
        try:
            statement = select(User).where(User.api_token == api_token)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            # The token is deliberately kept out of the log message.
            logger.exception("Failed to get user by API token")
            raise

    async def _get_plan_or_raise(self, code: str) -> Plan:
        """Return the plan with the given code or raise ``ValueError``."""
        result = await self.session.exec(select(Plan).where(Plan.code == code))
        plan = result.first()
        if plan is None:
            msg = "Default plan not found"
            raise ValueError(msg)
        return plan

    async def create(self, user_data: dict[str, Any]) -> User:
        """Create a new user with plan assignment and first user admin logic.

        When no user exists yet, the new user gets the ``"admin"`` role and
        the plan with code ``"pro"``; otherwise the plan with code ``"free"``.
        ``plan_id`` and ``credits`` are taken from the selected plan and
        override any values supplied in ``user_data``.

        Args:
            user_data: Field values for the new user. The dict is copied and
                never modified.

        Returns:
            The user returned by ``BaseRepository.create``.

        Raises:
            ValueError: If the required plan row does not exist.
        """
        try:
            # Work on a copy so the caller's dict is left untouched.
            payload = dict(user_data)

            # Only existence matters here, so fetch at most one primary key
            # instead of selecting full user rows.
            # NOTE(review): this check and the later insert are not atomic as
            # far as this file shows; confirm whether concurrent sign-ups
            # need guarding.
            existing = await self.session.exec(select(User.id).limit(1))
            if existing.first() is None:
                # First user gets admin role and pro plan
                plan_code = "pro"
                payload["role"] = "admin"
                logger.info("Creating first user with admin role and pro plan")
            else:
                # Regular users get free plan
                plan_code = "free"

            default_plan = await self._get_plan_or_raise(plan_code)

            # Set plan_id and default credits
            payload["plan_id"] = default_plan.id
            payload["credits"] = default_plan.credits

            # Use BaseRepository's create method
            return await super().create(payload)
        except Exception:
            logger.exception("Failed to create user")
            raise

    async def email_exists(self, email: str) -> bool:
        """Check if an email address is already registered.

        Returns:
            ``True`` when at least one user has exactly this email.
        """
        try:
            # Existence check only: select the primary key and stop at the
            # first hit rather than loading a full user row.
            statement = select(User.id).where(User.email == email).limit(1)
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception("Failed to check if email exists: %s", email)
            raise