130 lines
4.5 KiB
Python
130 lines
4.5 KiB
Python
"""User repository."""
|
|
|
|
from typing import Any
|
|
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.models.plan import Plan
|
|
from app.models.user import User
|
|
from app.repositories.base import BaseRepository
|
|
|
|
# Module-level logger; UserRepository uses it to record failures before re-raising.
logger = get_logger(__name__)
|
|
|
|
|
|
class UserRepository(BaseRepository[User]):
    """Repository for user operations.

    Adds user-specific lookups (by email, by API token, with the ``plan``
    relationship eagerly loaded) and a ``create`` override that assigns a plan
    and, for the very first user, the admin role.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the user repository.

        Args:
            session: Async database session handed to ``BaseRepository``.
        """
        super().__init__(User, session)

    async def get_all_with_plan(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[User]:
        """Get a page of users with the ``plan`` relationship eagerly loaded.

        Rows are ordered by primary key so that ``limit``/``offset`` pages are
        stable: without an ORDER BY the database may return rows in any order,
        so consecutive pages could skip or repeat users.

        Args:
            limit: Maximum number of users to return.
            offset: Number of users to skip before the page starts.

        Returns:
            The users on the requested page.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                .order_by(User.id)
                .limit(limit)
                .offset(offset)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all users with plan")
            raise

    async def get_by_id_with_plan(self, entity_id: int) -> User | None:
        """Get a user by ID with the ``plan`` relationship eagerly loaded.

        Args:
            entity_id: Primary key of the user.

        Returns:
            The matching user, or ``None`` if no user has that ID.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                .where(User.id == entity_id)
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception(
                "Failed to get user by ID with plan: %s",
                entity_id,
            )
            raise

    async def get_by_email(self, email: str) -> User | None:
        """Get a user by email address.

        Args:
            email: Email address to match exactly.

        Returns:
            The first matching user, or ``None`` if none matches.
        """
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            # NOTE(review): the email (likely PII) is written to the log here;
            # confirm this is acceptable under the project's logging policy.
            logger.exception("Failed to get user by email: %s", email)
            raise

    async def get_by_api_token(self, api_token: str) -> User | None:
        """Get a user by API token.

        Args:
            api_token: Token value compared directly against ``User.api_token``.

        Returns:
            The first matching user, or ``None`` if none matches.
        """
        try:
            statement = select(User).where(User.api_token == api_token)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            # The token itself is deliberately not included in the log message.
            logger.exception("Failed to get user by API token")
            raise

    async def _has_any_user(self) -> bool:
        """Return ``True`` if at least one user row exists.

        Selects only the primary key of at most one row instead of loading a
        full ``User`` entity.
        """
        statement = select(User.id).limit(1)
        result = await self.session.exec(statement)
        return result.first() is not None

    async def _get_plan_by_code(self, code: str) -> Plan:
        """Return the plan whose ``code`` matches, for default-plan assignment.

        Args:
            code: Plan code to look up (``create`` uses ``"pro"``/``"free"``).

        Returns:
            The matching plan.

        Raises:
            ValueError: If no plan with ``code`` exists.
        """
        result = await self.session.exec(select(Plan).where(Plan.code == code))
        plan = result.first()
        if plan is None:
            msg = "Default plan not found"
            raise ValueError(msg)
        return plan

    async def create(self, user_data: dict[str, Any]) -> User:
        """Create a new user with plan assignment and first-user admin logic.

        The first user ever created gets the ``"admin"`` role and the ``"pro"``
        plan; every later user gets the ``"free"`` plan. In both cases
        ``plan_id`` and ``credits`` are taken from the chosen plan, overriding
        any values already present in ``user_data``.

        Args:
            user_data: Field values for the new user. The caller's dict is not
                modified; a copy is extended with ``role``/``plan_id``/
                ``credits`` and passed to ``BaseRepository.create``.

        Returns:
            The user returned by ``BaseRepository.create``.

        Raises:
            ValueError: If the plan to assign does not exist.
            Exception: Any other error from the session or base ``create`` is
                logged and re-raised unchanged.
        """
        try:
            # Work on a copy so the caller's dict is not silently mutated.
            data = dict(user_data)

            # NOTE(review): this check-then-insert is not guarded by any lock
            # visible here, so concurrent first sign-ups could presumably both
            # be treated as the first user; confirm whether that matters.
            if await self._has_any_user():
                # Regular users get free plan
                plan_code = "free"
            else:
                # First user gets admin role and pro plan
                plan_code = "pro"
                data["role"] = "admin"
                logger.info("Creating first user with admin role and pro plan")

            default_plan = await self._get_plan_by_code(plan_code)

            # Set plan_id and default credits from the assigned plan
            data["plan_id"] = default_plan.id
            data["credits"] = default_plan.credits

            return await super().create(data)
        except Exception:
            logger.exception("Failed to create user")
            raise

    async def email_exists(self, email: str) -> bool:
        """Check if an email address is already registered.

        Only the primary key of at most one row is selected, so the check does
        not load a full ``User`` entity.

        Args:
            email: Email address to match exactly.

        Returns:
            ``True`` if some user has this email, else ``False``.
        """
        try:
            statement = select(User.id).where(User.email == email).limit(1)
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception("Failed to check if email exists: %s", email)
            raise
|