"""User repository.""" from datetime import datetime from enum import Enum from typing import Any from sqlalchemy import func from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.logging import get_logger from app.models.plan import Plan from app.models.user import User from app.models.sound_played import SoundPlayed from app.models.credit_transaction import CreditTransaction from app.models.playlist import Playlist from app.models.sound import Sound from app.models.tts import TTS from app.repositories.base import BaseRepository logger = get_logger(__name__) class UserSortField(str, Enum): """User sort fields.""" NAME = "name" EMAIL = "email" ROLE = "role" CREDITS = "credits" CREATED_AT = "created_at" class SortOrder(str, Enum): """Sort order.""" ASC = "asc" DESC = "desc" class UserStatus(str, Enum): """User status filter.""" ALL = "all" ACTIVE = "active" INACTIVE = "inactive" class UserRepository(BaseRepository[User]): """Repository for user operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize the user repository.""" super().__init__(User, session) async def get_all_with_plan( self, limit: int = 100, offset: int = 0, ) -> list[User]: """Get all users with plan relationship loaded.""" try: statement = ( select(User) .options(selectinload(User.plan)) .limit(limit) .offset(offset) ) result = await self.session.exec(statement) return list(result.all()) except Exception: logger.exception("Failed to get all users with plan") raise async def get_all_with_plan_paginated( # noqa: PLR0913 self, page: int = 1, limit: int = 50, search: str | None = None, sort_by: UserSortField = UserSortField.NAME, sort_order: SortOrder = SortOrder.ASC, status_filter: UserStatus = UserStatus.ALL, ) -> tuple[list[User], int]: """Get all users with plan relationship loaded and return total count.""" try: # Calculate offset offset = (page - 1) * limit # Build base query base_query = select(User).options(selectinload(User.plan)) count_query = select(func.count(User.id)) # Apply search filter if search and search.strip(): search_pattern = f"%{search.strip().lower()}%" search_condition = func.lower(User.name).like( search_pattern, ) | func.lower(User.email).like(search_pattern) base_query = base_query.where(search_condition) count_query = count_query.where(search_condition) # Apply status filter if status_filter == UserStatus.ACTIVE: base_query = base_query.where(User.is_active == True) # noqa: E712 count_query = count_query.where(User.is_active == True) # noqa: E712 elif status_filter == UserStatus.INACTIVE: base_query = base_query.where(User.is_active == False) # noqa: E712 count_query = count_query.where(User.is_active == False) # noqa: E712 # Apply sorting sort_column = { UserSortField.NAME: User.name, UserSortField.EMAIL: User.email, UserSortField.ROLE: User.role, UserSortField.CREDITS: User.credits, UserSortField.CREATED_AT: User.created_at, }.get(sort_by, User.name) if sort_order == SortOrder.DESC: base_query = base_query.order_by(sort_column.desc()) else: base_query = base_query.order_by(sort_column.asc()) # Get total count count_result = await self.session.exec(count_query) total_count = count_result.one() # Apply pagination and get results paginated_query = base_query.limit(limit).offset(offset) result = await self.session.exec(paginated_query) users = list(result.all()) except Exception: logger.exception("Failed to get paginated users with plan") raise else: return users, total_count async def get_by_id_with_plan(self, entity_id: int) -> User | None: """Get a user by ID with plan relationship loaded.""" try: statement = ( select(User) .options(selectinload(User.plan)) .where(User.id == entity_id) ) result = await self.session.exec(statement) return result.first() except Exception: logger.exception( "Failed to get user by ID with plan: %s", entity_id, ) raise async def get_by_email(self, email: str) -> User | None: """Get a user by email address.""" try: statement = select(User).where(User.email == email) result = await self.session.exec(statement) return result.first() except Exception: logger.exception("Failed to get user by email: %s", email) raise async def get_by_api_token(self, api_token: str) -> User | None: """Get a user by API token.""" try: statement = select(User).where(User.api_token == api_token) result = await self.session.exec(statement) return result.first() except Exception: logger.exception("Failed to get user by API token") raise async def create(self, entity_data: dict[str, Any]) -> User: """Create a new user with plan assignment and first user admin logic.""" def _raise_plan_not_found() -> None: msg = "Default plan not found" raise ValueError(msg) try: # Check if this is the first user user_count_statement = select(User) user_count_result = await self.session.exec(user_count_statement) is_first_user = user_count_result.first() is None if is_first_user: # First user gets admin role and pro plan plan_statement = select(Plan).where(Plan.code == "pro") entity_data["role"] = "admin" logger.info("Creating first user with admin role and pro plan") else: # Regular users get free plan plan_statement = select(Plan).where(Plan.code == "free") plan_result = await self.session.exec(plan_statement) default_plan = plan_result.first() if default_plan is None: _raise_plan_not_found() # Type assertion to help type checker understand default_plan is not None assert default_plan is not None # noqa: S101 # Set plan_id and default credits entity_data["plan_id"] = default_plan.id entity_data["credits"] = default_plan.credits # Use BaseRepository's create method return await super().create(entity_data) except Exception: logger.exception("Failed to create user") raise async def email_exists(self, email: str) -> bool: """Check if an email address is already registered.""" try: statement = select(User).where(User.email == email) result = await self.session.exec(statement) return result.first() is not None except Exception: logger.exception("Failed to check if email exists: %s", email) raise async def get_top_users( self, metric_type: str, date_filter: datetime | None = None, limit: int = 10, ) -> list[dict[str, Any]]: """Get top users by different metrics.""" try: if metric_type == "sounds_played": # Get users with most sounds played query = ( select( User.id, User.name, func.count(SoundPlayed.id).label("count") ) .join(SoundPlayed, User.id == SoundPlayed.user_id) .group_by(User.id, User.name) ) if date_filter: query = query.where(SoundPlayed.created_at >= date_filter) elif metric_type == "credits_used": # Get users with most credits used (negative transactions) query = ( select( User.id, User.name, func.sum(func.abs(CreditTransaction.amount)).label("count") ) .join(CreditTransaction, User.id == CreditTransaction.user_id) .where(CreditTransaction.amount < 0) .group_by(User.id, User.name) ) if date_filter: query = query.where(CreditTransaction.created_at >= date_filter) elif metric_type == "tracks_added": # Get users with most EXT sounds added query = ( select( User.id, User.name, func.count(Sound.id).label("count") ) .join(Sound, User.id == Sound.user_id) .where(Sound.type == "EXT") .group_by(User.id, User.name) ) if date_filter: query = query.where(Sound.created_at >= date_filter) elif metric_type == "tts_added": # Get users with most TTS sounds added query = ( select( User.id, User.name, func.count(TTS.id).label("count") ) .join(TTS, User.id == TTS.user_id) .group_by(User.id, User.name) ) if date_filter: query = query.where(TTS.created_at >= date_filter) elif metric_type == "playlists_created": # Get users with most playlists created query = ( select( User.id, User.name, func.count(Playlist.id).label("count") ) .join(Playlist, User.id == Playlist.user_id) .group_by(User.id, User.name) ) if date_filter: query = query.where(Playlist.created_at >= date_filter) else: msg = f"Unknown metric type: {metric_type}" raise ValueError(msg) # Add ordering and limit query = query.order_by(func.count().desc()).limit(limit) result = await self.session.exec(query) rows = result.all() return [ { "id": row[0], "name": row[1], "count": int(row[2]), } for row in rows ] except Exception: logger.exception( "Failed to get top users for metric=%s, date_filter=%s", metric_type, date_filter, ) raise