"""User repository."""

from typing import Any, NoReturn

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.plan import Plan
from app.models.user import User
from app.repositories.base import BaseRepository

logger = get_logger(__name__)


class UserRepository(BaseRepository[User]):
    """Repository for user operations."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the user repository.

        Args:
            session: Async database session, passed to ``BaseRepository``
                together with the ``User`` model.
        """
        super().__init__(User, session)

    async def get_by_email(self, email: str) -> User | None:
        """Get a user by email address.

        Args:
            email: Address compared for equality with ``User.email``.

        Returns:
            The first matching user, or ``None`` if there is no match.

        Raises:
            Exception: Any error from the query, re-raised after logging.
        """
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by email: %s", email)
            raise

    async def get_by_api_token(self, api_token: str) -> User | None:
        """Get a user by API token.

        Args:
            api_token: Token compared for equality with ``User.api_token``.

        Returns:
            The first matching user, or ``None`` if there is no match.

        Raises:
            Exception: Any error from the query, re-raised after logging.
        """
        try:
            statement = select(User).where(User.api_token == api_token)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            # The token itself is deliberately not logged.
            logger.exception("Failed to get user by API token")
            raise

    async def create(self, user_data: dict[str, Any]) -> User:
        """Create a new user with plan assignment and first user admin logic.

        If no user exists yet, the new user gets ``role="admin"`` and the
        ``"pro"`` plan. Every later user gets the ``"free"`` plan. In both
        cases ``plan_id`` and ``credits`` are taken from the chosen plan.

        Args:
            user_data: Field values for the new user. The mapping is copied
                before the plan/role fields are added, so the caller's dict
                is left unchanged.

        Returns:
            The user returned by ``BaseRepository.create``.

        Raises:
            ValueError: If the plan to assign is not found.
            Exception: Any other error, re-raised after logging.
        """

        def _raise_plan_not_found() -> NoReturn:
            # Raising from a helper keeps ``raise`` out of the ``try`` body
            # (ruff TRY301). The ``NoReturn`` annotation lets type checkers
            # narrow the plan to non-None afterwards without an ``assert``.
            msg = "Default plan not found"
            raise ValueError(msg)

        # Work on a copy so the caller's mapping is not mutated.
        data = dict(user_data)
        try:
            # Only existence matters, so fetch at most one row. Without the
            # LIMIT, the (buffered) async result would load the whole table.
            # NOTE(review): this check-then-insert is not atomic; two
            # concurrent first sign-ups could both be treated as the first
            # user — confirm whether a DB-level guard is needed.
            first_user_statement = select(User).limit(1)
            first_user_result = await self.session.exec(first_user_statement)
            is_first_user = first_user_result.first() is None

            if is_first_user:
                # First user gets admin role and pro plan
                plan_code = "pro"
                data["role"] = "admin"
                logger.info("Creating first user with admin role and pro plan")
            else:
                # Regular users get free plan
                plan_code = "free"

            plan_statement = select(Plan).where(Plan.code == plan_code)
            plan_result = await self.session.exec(plan_statement)
            default_plan = plan_result.first()
            if default_plan is None:
                _raise_plan_not_found()

            # Set plan_id and default credits
            data["plan_id"] = default_plan.id
            data["credits"] = default_plan.credits

            # Use BaseRepository's create method
            return await super().create(data)
        except Exception:
            logger.exception("Failed to create user")
            raise

    async def email_exists(self, email: str) -> bool:
        """Check if an email address is already registered.

        Args:
            email: Address compared for equality with ``User.email``.

        Returns:
            ``True`` if at least one user has this email, else ``False``.

        Raises:
            Exception: Any error from the query, re-raised after logging.
        """
        try:
            # A single row is enough to answer the existence question.
            statement = select(User).where(User.email == email).limit(1)
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception("Failed to check if email exists: %s", email)
            raise