"""User repository."""

from enum import Enum
from typing import Any, NoReturn

from sqlalchemy import func
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.plan import Plan
from app.models.user import User
from app.repositories.base import BaseRepository

logger = get_logger(__name__)


class UserSortField(str, Enum):
    """User sort fields."""

    NAME = "name"
    EMAIL = "email"
    ROLE = "role"
    CREDITS = "credits"
    CREATED_AT = "created_at"


class SortOrder(str, Enum):
    """Sort order."""

    ASC = "asc"
    DESC = "desc"


class UserStatus(str, Enum):
    """User status filter."""

    ALL = "all"
    ACTIVE = "active"
    INACTIVE = "inactive"


class UserRepository(BaseRepository[User]):
    """Repository for user operations."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the user repository bound to ``session``."""
        super().__init__(User, session)

    async def get_all_with_plan(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[User]:
        """Get users with the ``plan`` relationship eagerly loaded.

        Rows are ordered by ``User.id`` so that consecutive ``limit``/``offset``
        windows are stable (without ORDER BY the database may return rows in
        any order).

        Args:
            limit: Maximum number of users to return.
            offset: Number of users to skip.

        Returns:
            The selected users.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                .order_by(User.id)
                .limit(limit)
                .offset(offset)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all users with plan")
            raise

    async def get_all_with_plan_paginated(
        self,
        page: int = 1,
        limit: int = 50,
        search: str | None = None,
        sort_by: UserSortField = UserSortField.NAME,
        sort_order: SortOrder = SortOrder.ASC,
        status_filter: UserStatus = UserStatus.ALL,
    ) -> tuple[list[User], int]:
        """Get one page of users (with ``plan`` loaded) plus the total match count.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size.
            search: Optional case-insensitive substring matched against name or
                email. Surrounding whitespace is ignored; ``%`` and ``_`` are
                matched literally.
            sort_by: Column to sort by; anything unrecognized sorts by name.
            sort_order: ``DESC`` sorts descending, anything else ascending.
            status_filter: Restrict to active or inactive users, or ``ALL``.

        Returns:
            ``(users, total_count)`` where ``total_count`` counts every user
            matching the filters, ignoring pagination.
        """
        try:
            # Clamp so an invalid page cannot produce a negative OFFSET, which
            # databases such as PostgreSQL reject.
            offset = (max(page, 1) - 1) * limit

            # The same filters drive both the count and the page query.
            conditions = self._build_filter_conditions(search, status_filter)

            count_query = select(func.count(User.id)).where(*conditions)
            count_result = await self.session.exec(count_query)
            total_count = count_result.one()

            page_query = (
                select(User)
                .options(selectinload(User.plan))
                .where(*conditions)
                .order_by(*self._build_order_by(sort_by, sort_order))
                .limit(limit)
                .offset(offset)
            )
            result = await self.session.exec(page_query)
            return list(result.all()), total_count
        except Exception:
            logger.exception("Failed to get paginated users with plan")
            raise

    @staticmethod
    def _build_filter_conditions(
        search: str | None,
        status_filter: UserStatus,
    ) -> list[Any]:
        """Build the WHERE conditions for the search and status filters."""
        conditions: list[Any] = []

        term = search.strip().lower() if search else ""
        if term:
            # autoescape makes '%' and '_' in user input match literally
            # instead of acting as LIKE wildcards.
            conditions.append(
                func.lower(User.name).contains(term, autoescape=True)
                | func.lower(User.email).contains(term, autoescape=True),
            )

        active_flag_by_status = {
            UserStatus.ACTIVE: True,
            UserStatus.INACTIVE: False,
        }
        if status_filter in active_flag_by_status:
            conditions.append(
                User.is_active == active_flag_by_status[status_filter],
            )

        return conditions

    @staticmethod
    def _build_order_by(
        sort_by: UserSortField,
        sort_order: SortOrder,
    ) -> list[Any]:
        """Build ORDER BY clauses, with ``User.id`` as a stable tiebreaker.

        Without the tiebreaker, rows sharing a sort value (e.g. equal names)
        have no defined order, so LIMIT/OFFSET pages could repeat or skip rows.
        """
        sort_columns = {
            UserSortField.EMAIL: User.email,
            UserSortField.ROLE: User.role,
            UserSortField.CREDITS: User.credits,
            UserSortField.CREATED_AT: User.created_at,
        }
        # Anything not listed above (including NAME) sorts by name.
        column = sort_columns.get(sort_by, User.name)
        if sort_order == SortOrder.DESC:
            return [column.desc(), User.id.desc()]
        return [column.asc(), User.id.asc()]

    async def get_by_id_with_plan(self, entity_id: int) -> User | None:
        """Get a user by ID with the ``plan`` relationship loaded, or ``None``."""
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                .where(User.id == entity_id)
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception(
                "Failed to get user by ID with plan: %s",
                entity_id,
            )
            raise

    async def get_by_email(self, email: str) -> User | None:
        """Get a user by exact email address, or ``None`` if not found."""
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by email: %s", email)
            raise

    async def get_by_api_token(self, api_token: str) -> User | None:
        """Get a user by API token, or ``None`` if not found."""
        try:
            statement = select(User).where(User.api_token == api_token)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            # The token itself is deliberately not logged.
            logger.exception("Failed to get user by API token")
            raise

    async def create(self, user_data: dict[str, Any]) -> User:
        """Create a new user, assigning a plan and the first-user admin role.

        If no user exists yet, the new user gets ``role="admin"`` and the
        ``pro`` plan; otherwise the ``free`` plan. ``plan_id`` and ``credits``
        are set from the chosen plan, overriding any values in ``user_data``.

        Args:
            user_data: Field values for the new user. The dict is not modified;
                an enriched copy is passed to ``BaseRepository.create``.

        Returns:
            The created user as returned by ``BaseRepository.create``.

        Raises:
            ValueError: If the plan to assign does not exist.
        """

        def _raise_plan_not_found() -> NoReturn:
            msg = "Default plan not found"
            raise ValueError(msg)

        try:
            # Copy so the caller's dict (e.g. a reused template) does not
            # accumulate role/plan_id/credits across calls.
            data = dict(user_data)

            # Only existence matters, so fetch at most one id, not a full row.
            existing = await self.session.exec(select(User.id).limit(1))
            is_first_user = existing.first() is None

            if is_first_user:
                # First user gets admin role and pro plan.
                plan_code = "pro"
                data["role"] = "admin"
                logger.info("Creating first user with admin role and pro plan")
            else:
                # Regular users get free plan.
                plan_code = "free"

            plan_result = await self.session.exec(
                select(Plan).where(Plan.code == plan_code),
            )
            default_plan = plan_result.first()
            if default_plan is None:
                # NoReturn lets type checkers narrow default_plan below.
                _raise_plan_not_found()

            data["plan_id"] = default_plan.id
            data["credits"] = default_plan.credits

            return await super().create(data)
        except Exception:
            logger.exception("Failed to create user")
            raise

    async def email_exists(self, email: str) -> bool:
        """Return whether a user with exactly this email address exists."""
        try:
            # Existence check only: select a single id instead of a full row.
            statement = select(User.id).where(User.email == email).limit(1)
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception("Failed to check if email exists: %s", email)
            raise