369 lines
12 KiB
Python
369 lines
12 KiB
Python
"""User repository."""
|
|
|
|
from datetime import datetime
from enum import Enum
from typing import Any, NoReturn

from sqlalchemy import Select, desc, func
from sqlalchemy.orm import selectinload
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.credit_transaction import CreditTransaction
from app.models.extraction import Extraction
from app.models.plan import Plan
from app.models.playlist import Playlist
from app.models.sound_played import SoundPlayed
from app.models.tts import TTS
from app.models.user import User
from app.repositories.base import BaseRepository
|
|
|
|
# Module-level logger, requested from the app logging helper under this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class UserSortField(str, Enum):
    """User sort fields.

    Each member's value is the string a caller supplies to pick the column
    used for ordering a user listing. Because the enum also derives from
    ``str``, members compare equal to their plain string values.
    """

    NAME = "name"
    EMAIL = "email"
    ROLE = "role"
    CREDITS = "credits"
    CREATED_AT = "created_at"
|
|
|
|
|
|
class SortOrder(str, Enum):
    """Sort order.

    ``ASC`` requests ascending order and ``DESC`` descending order. Members
    also derive from ``str`` and compare equal to their string values.
    """

    ASC = "asc"
    DESC = "desc"
|
|
|
|
|
|
class UserStatus(str, Enum):
    """User status filter.

    ``ALL`` applies no status restriction; ``ACTIVE`` and ``INACTIVE`` select
    users by their active flag. Members also derive from ``str`` and compare
    equal to their string values.
    """

    ALL = "all"
    ACTIVE = "active"
    INACTIVE = "inactive"
|
|
|
|
|
|
class UserRepository(BaseRepository[User]):
    """Repository for user operations.

    Provides user lookups (optionally with the ``plan`` relationship eagerly
    loaded), user creation with automatic plan assignment, and "top users"
    aggregate queries. Every public method logs an unexpected exception with
    ``logger.exception`` and then re-raises it unchanged.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the user repository.

        Args:
            session: Async session handed to ``BaseRepository`` together with
                the ``User`` model.
        """
        super().__init__(User, session)

    async def get_all_with_plan(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[User]:
        """Get all users with plan relationship loaded.

        Args:
            limit: Maximum number of users to return.
            offset: Number of users to skip before the first returned row.

        Returns:
            The selected users, ordered by id.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                # LIMIT/OFFSET without ORDER BY has no guaranteed row order,
                # so consecutive windows could overlap or skip users.
                .order_by(User.id)
                .limit(limit)
                .offset(offset)
            )
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all users with plan")
            raise

    async def get_all_with_plan_paginated(  # noqa: PLR0913
        self,
        page: int = 1,
        limit: int = 50,
        search: str | None = None,
        sort_by: UserSortField = UserSortField.NAME,
        sort_order: SortOrder = SortOrder.ASC,
        status_filter: UserStatus = UserStatus.ALL,
    ) -> tuple[list[User], int]:
        """Get all users with plan relationship loaded and return total count.

        Args:
            page: 1-based page number; values below 1 are treated as page 1.
            limit: Page size.
            search: Optional text matched case-insensitively as a substring of
                the user's name or email. Blank or whitespace-only text is
                ignored. ``%`` and ``_`` in the text are matched literally.
            sort_by: Column to order by; unknown values fall back to name.
            sort_order: Ascending or descending order for ``sort_by``.
            status_filter: Restrict to active or inactive users, or no
                restriction for ``UserStatus.ALL``.

        Returns:
            A ``(users, total_count)`` tuple where ``users`` is the requested
            page and ``total_count`` is the number of users matching the
            filters across all pages.
        """
        try:
            # Clamp so that page <= 0 can never produce a negative OFFSET.
            offset = (max(page, 1) - 1) * limit

            # Collect the filters once and apply them to both queries so the
            # page and the total can never disagree.
            conditions: list[Any] = []

            term = search.strip().lower() if search else ""
            if term:
                # autoescape makes user-typed "%" / "_" literal characters
                # instead of LIKE wildcards.
                conditions.append(
                    func.lower(User.name).contains(term, autoescape=True)
                    | func.lower(User.email).contains(term, autoescape=True),
                )

            if status_filter == UserStatus.ACTIVE:
                conditions.append(User.is_active == True)  # noqa: E712
            elif status_filter == UserStatus.INACTIVE:
                conditions.append(User.is_active == False)  # noqa: E712

            base_query = select(User).options(selectinload(User.plan))
            count_query = select(func.count(User.id))
            for condition in conditions:
                base_query = base_query.where(condition)
                count_query = count_query.where(condition)

            sort_column = {
                UserSortField.NAME: User.name,
                UserSortField.EMAIL: User.email,
                UserSortField.ROLE: User.role,
                UserSortField.CREDITS: User.credits,
                UserSortField.CREATED_AT: User.created_at,
            }.get(sort_by, User.name)
            primary_order = (
                sort_column.desc()
                if sort_order == SortOrder.DESC
                else sort_column.asc()
            )
            # Sort columns such as role or credits are not unique; the id
            # tie-breaker keeps page boundaries stable between requests.
            base_query = base_query.order_by(primary_order, User.id)

            count_result = await self.session.exec(count_query)
            total_count = count_result.one()

            page_result = await self.session.exec(
                base_query.limit(limit).offset(offset),
            )
            users = list(page_result.all())
        except Exception:
            logger.exception("Failed to get paginated users with plan")
            raise
        else:
            return users, total_count

    async def get_by_id_with_plan(self, entity_id: int) -> User | None:
        """Get a user by ID with plan relationship loaded.

        Args:
            entity_id: Value compared against ``User.id``.

        Returns:
            The matching user, or ``None`` when no row matches.
        """
        try:
            statement = (
                select(User)
                .options(selectinload(User.plan))
                .where(User.id == entity_id)
            )
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception(
                "Failed to get user by ID with plan: %s",
                entity_id,
            )
            raise

    async def get_by_email(self, email: str) -> User | None:
        """Get a user by email address.

        Args:
            email: Value compared for equality against ``User.email``.

        Returns:
            The first matching user, or ``None`` when no row matches.
        """
        try:
            statement = select(User).where(User.email == email)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get user by email: %s", email)
            raise

    async def get_by_api_token(self, api_token: str) -> User | None:
        """Get a user by API token.

        Args:
            api_token: Value compared for equality against ``User.api_token``.

        Returns:
            The first matching user, or ``None`` when no row matches.
        """
        try:
            statement = select(User).where(User.api_token == api_token)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            # The token is deliberately left out of the log message.
            logger.exception("Failed to get user by API token")
            raise

    async def create(self, entity_data: dict[str, Any]) -> User:
        """Create a new user with plan assignment and first user admin logic.

        When no user exists yet, the new user gets the ``"admin"`` role and
        the plan whose code is ``"pro"``; otherwise the plan whose code is
        ``"free"``. ``plan_id`` and ``credits`` are taken from that plan.

        Args:
            entity_data: Field values for the new user. The mapping is copied;
                the caller's dict is not modified.

        Returns:
            The user returned by ``BaseRepository.create``.

        Raises:
            ValueError: If the required plan row does not exist.
        """

        def _raise_plan_not_found() -> NoReturn:
            msg = "Default plan not found"
            raise ValueError(msg)

        try:
            # Only existence matters here, so fetch at most one id instead of
            # selecting every user row.
            any_user_result = await self.session.exec(select(User.id).limit(1))
            is_first_user = any_user_result.first() is None

            # Copy so role/plan_id/credits are not written into the caller's dict.
            user_data = dict(entity_data)

            if is_first_user:
                plan_code = "pro"
                user_data["role"] = "admin"
                logger.info("Creating first user with admin role and pro plan")
            else:
                plan_code = "free"

            plan_result = await self.session.exec(
                select(Plan).where(Plan.code == plan_code),
            )
            default_plan = plan_result.first()
            if default_plan is None:
                # Declared NoReturn, so type checkers narrow default_plan below
                # without an assert (which would be stripped under -O).
                _raise_plan_not_found()

            user_data["plan_id"] = default_plan.id
            user_data["credits"] = default_plan.credits

            return await super().create(user_data)
        except Exception:
            logger.exception("Failed to create user")
            raise

    async def email_exists(self, email: str) -> bool:
        """Check if an email address is already registered.

        Args:
            email: Value compared for equality against ``User.email``.

        Returns:
            ``True`` when at least one user row has this email.
        """
        try:
            # An id and a single row are enough to answer an existence check.
            statement = select(User.id).where(User.email == email).limit(1)
            result = await self.session.exec(statement)
            return result.first() is not None
        except Exception:
            logger.exception("Failed to check if email exists: %s", email)
            raise

    async def get_top_users(
        self,
        metric_type: str,
        date_filter: datetime | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Get top users by different metrics.

        Args:
            metric_type: One of ``"sounds_played"``, ``"credits_used"``,
                ``"tracks_added"``, ``"tts_added"`` or ``"playlists_created"``.
            date_filter: When given, only activity rows whose ``created_at``
                is at or after this moment are counted.
            limit: Maximum number of users to return.

        Returns:
            Dicts with keys ``"id"``, ``"name"`` and ``"count"``, ordered by
            the metric value, highest first.

        Raises:
            ValueError: If ``metric_type`` is not a known metric.
        """
        try:
            query = self._build_top_users_query(metric_type, date_filter)

            # Rank by the labelled metric column. Ordering by COUNT(*) would
            # rank "credits_used" by number of transactions rather than by
            # the summed amount that is actually reported.
            query = query.order_by(desc("count")).limit(limit)

            result = await self.session.exec(query)
            return [
                {"id": row[0], "name": row[1], "count": int(row[2])}
                for row in result.all()
            ]
        except Exception:
            logger.exception(
                "Failed to get top users for metric=%s, date_filter=%s",
                metric_type,
                date_filter,
            )
            raise

    def _build_top_users_query(
        self,
        metric_type: str,
        date_filter: datetime | None,
    ) -> Select:
        """Build query for top users based on metric type.

        Args:
            metric_type: Name of the metric; selects the builder method.
            date_filter: Optional lower bound forwarded to
                ``_apply_date_filter``.

        Returns:
            A grouped select of ``(User.id, User.name, count)`` without
            ordering or limit.

        Raises:
            ValueError: If ``metric_type`` is not a known metric.
        """
        builders = {
            "sounds_played": self._build_sounds_played_query,
            "credits_used": self._build_credits_used_query,
            "tracks_added": self._build_tracks_added_query,
            "tts_added": self._build_tts_added_query,
            "playlists_created": self._build_playlists_created_query,
        }
        builder = builders.get(metric_type)
        if builder is None:
            msg = f"Unknown metric type: {metric_type}"
            raise ValueError(msg)

        query = builder()
        if date_filter is not None:
            query = self._apply_date_filter(query, metric_type, date_filter)
        return query

    def _build_sounds_played_query(self) -> Select:
        """Build query for sounds played metric.

        Counts ``SoundPlayed`` rows per user.
        """
        return (
            select(
                User.id,
                User.name,
                func.count(SoundPlayed.id).label("count"),
            )
            .join(SoundPlayed, User.id == SoundPlayed.user_id)
            .group_by(User.id, User.name)
        )

    def _build_credits_used_query(self) -> Select:
        """Build query for credits used metric.

        Sums the absolute value of negative ``CreditTransaction.amount``
        values per user; the sum is still labelled ``"count"`` so every
        metric query has the same column layout.
        """
        return (
            select(
                User.id,
                User.name,
                func.sum(func.abs(CreditTransaction.amount)).label("count"),
            )
            .join(CreditTransaction, User.id == CreditTransaction.user_id)
            .where(CreditTransaction.amount < 0)
            .group_by(User.id, User.name)
        )

    def _build_tracks_added_query(self) -> Select:
        """Build query for tracks added metric.

        Counts ``Extraction`` rows per user, restricted to rows whose
        ``sound_id`` is not NULL.
        """
        return (
            select(
                User.id,
                User.name,
                func.count(Extraction.id).label("count"),
            )
            .join(Extraction, User.id == Extraction.user_id)
            .where(Extraction.sound_id.is_not(None))
            .group_by(User.id, User.name)
        )

    def _build_tts_added_query(self) -> Select:
        """Build query for TTS added metric.

        Counts ``TTS`` rows per user.
        """
        return (
            select(
                User.id,
                User.name,
                func.count(TTS.id).label("count"),
            )
            .join(TTS, User.id == TTS.user_id)
            .group_by(User.id, User.name)
        )

    def _build_playlists_created_query(self) -> Select:
        """Build query for playlists created metric.

        Counts ``Playlist`` rows per user.
        """
        return (
            select(
                User.id,
                User.name,
                func.count(Playlist.id).label("count"),
            )
            .join(Playlist, User.id == Playlist.user_id)
            .group_by(User.id, User.name)
        )

    def _apply_date_filter(
        self,
        query: Select,
        metric_type: str,
        date_filter: datetime,
    ) -> Select:
        """Apply date filter to query based on metric type.

        Args:
            query: Metric query produced by one of the ``_build_*`` methods.
            metric_type: Name of the metric; selects which model's
                ``created_at`` column is filtered.
            date_filter: Inclusive lower bound for ``created_at``.

        Returns:
            The query with the extra condition, or the query unchanged when
            ``metric_type`` is not a known metric.
        """
        created_at_columns = {
            "sounds_played": SoundPlayed.created_at,
            "credits_used": CreditTransaction.created_at,
            "tracks_added": Extraction.created_at,
            "tts_added": TTS.created_at,
            "playlists_created": Playlist.created_at,
        }
        created_at = created_at_columns.get(metric_type)
        if created_at is None:
            return query
        return query.where(created_at >= date_filter)
|