Files
sdb2-backend/app/repositories/user.py
JSC 95e166eefb
Some checks failed
Backend CI / lint (push) Failing after 9s
Backend CI / test (push) Failing after 1m36s
feat: Add endpoint and service method to retrieve top users by various metrics
2025-09-27 21:52:00 +02:00

333 lines
12 KiB
Python

"""User repository."""
from datetime import datetime
from enum import Enum
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.plan import Plan
from app.models.user import User
from app.models.sound_played import SoundPlayed
from app.models.credit_transaction import CreditTransaction
from app.models.playlist import Playlist
from app.models.sound import Sound
from app.models.tts import TTS
from app.repositories.base import BaseRepository
# Module logger: the repository methods below use it to record failures via
# logger.exception before re-raising, and to log first-user creation.
logger = get_logger(__name__)
class UserSortField(str, Enum):
    """Fields a user listing can be ordered by.

    Members subclass ``str``, so each compares equal to its raw value
    (e.g. ``UserSortField.NAME == "name"``).
    """

    NAME = "name"  # orders by User.name
    EMAIL = "email"  # orders by User.email
    ROLE = "role"  # orders by User.role
    CREDITS = "credits"  # orders by User.credits
    CREATED_AT = "created_at"  # orders by User.created_at
class SortOrder(str, Enum):
    """Direction in which a sorted listing is returned."""

    ASC = "asc"  # smallest value first
    DESC = "desc"  # largest value first
class UserStatus(str, Enum):
    """Filter that selects users by their ``is_active`` flag."""

    ALL = "all"  # no restriction on is_active
    ACTIVE = "active"  # only users whose is_active is true
    INACTIVE = "inactive"  # only users whose is_active is false
class UserRepository(BaseRepository[User]):
    """Repository for user operations.

    Adds user-specific queries on top of ``BaseRepository[User]``:
    listing with the ``plan`` relationship eagerly loaded, filtered and
    sorted pagination, lookups by id / email / API token, plan-assigning
    creation, and "top users" leaderboards.

    The async query methods log failures with ``logger.exception`` and then
    re-raise the original exception unchanged.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the user repository.

        Args:
            session: Async session handed to ``BaseRepository``; the queries
                in this class run through ``self.session``.
        """
        super().__init__(User, session)

    async def get_all_with_plan(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[User]:
        """Get users with their ``plan`` relationship eagerly loaded.

        Args:
            limit: Maximum number of users to return.
            offset: Number of users to skip before collecting results.

        Returns:
            Users ordered by id.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                # LIMIT/OFFSET without ORDER BY has no defined row order, so
                # consecutive windows could overlap or skip rows.
                .order_by(User.id)
                .limit(limit)
                .offset(offset)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all users with plan")
            raise

    async def get_all_with_plan_paginated(  # noqa: PLR0913
        self,
        page: int = 1,
        limit: int = 50,
        search: str | None = None,
        sort_by: UserSortField = UserSortField.NAME,
        sort_order: SortOrder = SortOrder.ASC,
        status_filter: UserStatus = UserStatus.ALL,
    ) -> tuple[list[User], int]:
        """Get one page of users (with ``plan`` loaded) plus the total count.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size.
            search: Case-insensitive substring matched against name or
                email. ``None`` or whitespace-only input disables the
                filter. ``%``, ``_`` and ``/`` are matched literally.
            sort_by: Column to sort by (falls back to name).
            sort_order: Ascending or descending.
            status_filter: Restrict to active or inactive users, or ``ALL``.

        Returns:
            ``(users, total_count)`` where ``total_count`` is the number of
            users matching the filters, ignoring pagination.
        """
        try:
            # Clamp so an out-of-range page never produces a negative OFFSET.
            offset = (max(page, 1) - 1) * limit

            # Build the filters once and share them between the page query
            # and the count query so the two can never disagree.
            conditions: list[Any] = []
            term = search.strip().lower() if search else ""
            if term:
                # autoescape makes LIKE metacharacters in user input literal.
                conditions.append(
                    func.lower(User.name).contains(term, autoescape=True)
                    | func.lower(User.email).contains(term, autoescape=True),
                )
            if status_filter == UserStatus.ACTIVE:
                conditions.append(User.is_active == True)  # noqa: E712
            elif status_filter == UserStatus.INACTIVE:
                conditions.append(User.is_active == False)  # noqa: E712

            sort_column = {
                UserSortField.NAME: User.name,
                UserSortField.EMAIL: User.email,
                UserSortField.ROLE: User.role,
                UserSortField.CREDITS: User.credits,
                UserSortField.CREATED_AT: User.created_at,
            }.get(sort_by, User.name)
            primary_order = (
                sort_column.desc()
                if sort_order == SortOrder.DESC
                else sort_column.asc()
            )

            count_query = select(func.count(User.id)).where(*conditions)
            count_result = await self.session.exec(count_query)
            total_count = count_result.one()

            page_query = (
                select(User)
                .options(selectinload(User.plan))
                .where(*conditions)
                # Tie-break on id: without a unique key in ORDER BY, rows with
                # equal sort values may shift between pages.
                .order_by(primary_order, User.id)
                .limit(limit)
                .offset(offset)
            )
            result = await self.session.exec(page_query)
            users = list(result.all())
        except Exception:
            logger.exception("Failed to get paginated users with plan")
            raise
        else:
            return users, total_count

    async def get_by_id_with_plan(self, entity_id: int) -> User | None:
        """Get a user by primary key with the ``plan`` relationship loaded.

        Returns:
            The matching user, or ``None`` if no row has that id.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                .where(User.id == entity_id)
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception(
                "Failed to get user by ID with plan: %s",
                entity_id,
            )
            raise

    async def get_by_email(self, email: str) -> User | None:
        """Get the user whose email equals ``email``.

        No normalization is applied to ``email`` before comparing.

        Returns:
            The matching user, or ``None``.
        """
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by email: %s", email)
            raise

    async def get_by_api_token(self, api_token: str) -> User | None:
        """Get the user owning ``api_token``, or ``None`` if there is none.

        The token is deliberately not included in the failure log.
        """
        try:
            statement = select(User).where(User.api_token == api_token)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by API token")
            raise

    async def create(self, entity_data: dict[str, Any]) -> User:
        """Create a user, assigning a plan and starting credits.

        If the users table is empty, the new user gets ``role="admin"`` and
        the ``pro`` plan; otherwise the ``free`` plan. ``plan_id`` and
        ``credits`` are always taken from the chosen plan, overriding any
        values supplied in ``entity_data``.

        Args:
            entity_data: Field values for the new user. The caller's dict is
                not modified; a copy is extended and passed on.

        Returns:
            The user created by ``BaseRepository.create``.

        Raises:
            ValueError: If the plan to assign does not exist.
        """

        def _raise_plan_not_found() -> None:
            msg = "Default plan not found"
            raise ValueError(msg)

        try:
            # Work on a copy so the caller's dict is left untouched.
            data = dict(entity_data)

            # Only existence matters: fetch at most one id, not a full row.
            existing = await self.session.exec(select(User.id).limit(1))
            is_first_user = existing.first() is None

            if is_first_user:
                # First user gets admin role and pro plan
                plan_code = "pro"
                data["role"] = "admin"
                logger.info("Creating first user with admin role and pro plan")
            else:
                # Regular users get free plan
                plan_code = "free"

            plan_result = await self.session.exec(
                select(Plan).where(Plan.code == plan_code),
            )
            default_plan = plan_result.first()
            if default_plan is None:
                _raise_plan_not_found()
            # Type assertion to help type checker understand default_plan is not None
            assert default_plan is not None  # noqa: S101

            data["plan_id"] = default_plan.id
            data["credits"] = default_plan.credits
            return await super().create(data)
        except Exception:
            logger.exception("Failed to create user")
            raise

    async def email_exists(self, email: str) -> bool:
        """Return whether any user has exactly this email address."""
        try:
            # Select only the id and stop at one row; the user itself is unused.
            statement = select(User.id).where(User.email == email).limit(1)
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception("Failed to check if email exists: %s", email)
            raise

    @staticmethod
    def _build_top_users_query(
        metric_type: str,
        date_filter: datetime | None,
    ) -> Any:
        """Build the grouped, ordered (not yet limited) top-users query.

        Raises:
            ValueError: If ``metric_type`` is not supported.
        """
        conditions: list[Any] = []
        if metric_type == "sounds_played":
            model: Any = SoundPlayed
            metric = func.count(SoundPlayed.id)
        elif metric_type == "credits_used":
            # Credits spent are the negative transactions; rank by their total.
            model = CreditTransaction
            metric = func.sum(func.abs(CreditTransaction.amount))
            conditions.append(CreditTransaction.amount < 0)
        elif metric_type == "tracks_added":
            # Only EXT sounds count as added tracks.
            model = Sound
            metric = func.count(Sound.id)
            conditions.append(Sound.type == "EXT")
        elif metric_type == "tts_added":
            model = TTS
            metric = func.count(TTS.id)
        elif metric_type == "playlists_created":
            model = Playlist
            metric = func.count(Playlist.id)
        else:
            msg = f"Unknown metric type: {metric_type}"
            raise ValueError(msg)

        if date_filter is not None:
            conditions.append(model.created_at >= date_filter)

        return (
            select(User.id, User.name, metric.label("count"))
            .join(model, User.id == model.user_id)
            .where(*conditions)
            .group_by(User.id, User.name)
            # Order by the reported metric itself rather than COUNT(*), so
            # credits_used ranks by credits spent, not by transaction count;
            # tie-break on id for deterministic results.
            .order_by(metric.desc(), User.id)
        )

    async def get_top_users(
        self,
        metric_type: str,
        date_filter: datetime | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Get the top users ranked by an activity metric.

        Supported ``metric_type`` values:

        - ``"sounds_played"``: number of ``SoundPlayed`` rows.
        - ``"credits_used"``: sum of ``abs(amount)`` over negative
          ``CreditTransaction`` rows.
        - ``"tracks_added"``: number of ``Sound`` rows with ``type == "EXT"``.
        - ``"tts_added"``: number of ``TTS`` rows.
        - ``"playlists_created"``: number of ``Playlist`` rows.

        Users with no matching rows are omitted (inner join).

        Args:
            metric_type: One of the names above.
            date_filter: If given, only rows whose ``created_at`` is at or
                after this value are considered.
            limit: Maximum number of users to return.

        Returns:
            Dicts with ``id``, ``name`` and ``count`` keys, highest metric
            first. For ``credits_used``, ``count`` is the credit total.

        Raises:
            ValueError: If ``metric_type`` is not supported.
        """
        try:
            query = self._build_top_users_query(metric_type, date_filter)
            result = await self.session.exec(query.limit(limit))
            return [
                {"id": row[0], "name": row[1], "count": int(row[2])}
                for row in result.all()
            ]
        except Exception:
            logger.exception(
                "Failed to get top users for metric=%s, date_filter=%s",
                metric_type,
                date_filter,
            )
            raise