diff --git a/app/api/v1/dashboard.py b/app/api/v1/dashboard.py index fc46da6..8ad14d1 100644 --- a/app/api/v1/dashboard.py +++ b/app/api/v1/dashboard.py @@ -63,7 +63,10 @@ async def get_top_users( metric_type: Annotated[ str, Query( - description="Metric type: sounds_played, credits_used, tracks_added, tts_added, playlists_created", + description=( + "Metric type: sounds_played, credits_used, tracks_added, " + "tts_added, playlists_created" + ), ), ], period: Annotated[ diff --git a/app/repositories/sound.py b/app/repositories/sound.py index 6e28bdd..10498a0 100644 --- a/app/repositories/sound.py +++ b/app/repositories/sound.py @@ -201,7 +201,10 @@ class SoundRepository(BaseRepository[Sound]): ) raise - async def get_soundboard_statistics(self, sound_type: str = "SDB") -> dict[str, int | float]: + async def get_soundboard_statistics( + self, + sound_type: str = "SDB", + ) -> dict[str, int | float]: """Get statistics for sounds of a specific type.""" try: statement = select( diff --git a/app/repositories/user.py b/app/repositories/user.py index 5645218..2b8b045 100644 --- a/app/repositories/user.py +++ b/app/repositories/user.py @@ -4,20 +4,19 @@ from datetime import datetime from enum import Enum from typing import Any -from sqlalchemy import func +from sqlalchemy import Select, func from sqlalchemy.orm import selectinload from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.core.logging import get_logger -from app.models.plan import Plan -from app.models.user import User -from app.models.sound_played import SoundPlayed from app.models.credit_transaction import CreditTransaction -from app.models.playlist import Playlist -from app.models.sound import Sound -from app.models.tts import TTS from app.models.extraction import Extraction +from app.models.plan import Plan +from app.models.playlist import Playlist +from app.models.sound_played import SoundPlayed +from app.models.tts import TTS +from app.models.user import User from app.repositories.base import BaseRepository logger = get_logger(__name__) @@ -233,81 +232,7 @@ class UserRepository(BaseRepository[User]): ) -> list[dict[str, Any]]: """Get top users by different metrics.""" try: - if metric_type == "sounds_played": - # Get users with most sounds played - query = ( - select( - User.id, - User.name, - func.count(SoundPlayed.id).label("count") - ) - .join(SoundPlayed, User.id == SoundPlayed.user_id) - .group_by(User.id, User.name) - ) - if date_filter: - query = query.where(SoundPlayed.created_at >= date_filter) - - elif metric_type == "credits_used": - # Get users with most credits used (negative transactions) - query = ( - select( - User.id, - User.name, - func.sum(func.abs(CreditTransaction.amount)).label("count") - ) - .join(CreditTransaction, User.id == CreditTransaction.user_id) - .where(CreditTransaction.amount < 0) - .group_by(User.id, User.name) - ) - if date_filter: - query = query.where(CreditTransaction.created_at >= date_filter) - - elif metric_type == "tracks_added": - # Get users with most EXT sounds added (via extractions) - query = ( - select( - User.id, - User.name, - func.count(Extraction.id).label("count") - ) - .join(Extraction, User.id == Extraction.user_id) - .where(Extraction.sound_id.is_not(None)) # Only count successful extractions - .group_by(User.id, User.name) - ) - if date_filter: - query = query.where(Extraction.created_at >= date_filter) - - elif metric_type == "tts_added": - # Get users with most TTS sounds added - query = ( - select( - User.id, - User.name, - func.count(TTS.id).label("count") - ) - .join(TTS, User.id == TTS.user_id) - .group_by(User.id, User.name) - ) - if date_filter: - query = query.where(TTS.created_at >= date_filter) - - elif metric_type == "playlists_created": - # Get users with most playlists created - query = ( - select( - User.id, - User.name, - func.count(Playlist.id).label("count") - ) - .join(Playlist, User.id == Playlist.user_id) - .group_by(User.id, User.name) - ) - if date_filter: - query = query.where(Playlist.created_at >= date_filter) - - else: - msg = f"Unknown metric type: {metric_type}" - raise ValueError(msg) + query = self._build_top_users_query(metric_type, date_filter) # Add ordering and limit query = query.order_by(func.count().desc()).limit(limit) @@ -331,3 +256,113 @@ class UserRepository(BaseRepository[User]): date_filter, ) raise + + def _build_top_users_query( + self, + metric_type: str, + date_filter: datetime | None, + ) -> Select: + """Build query for top users based on metric type.""" + match metric_type: + case "sounds_played": + query = self._build_sounds_played_query() + case "credits_used": + query = self._build_credits_used_query() + case "tracks_added": + query = self._build_tracks_added_query() + case "tts_added": + query = self._build_tts_added_query() + case "playlists_created": + query = self._build_playlists_created_query() + case _: + msg = f"Unknown metric type: {metric_type}" + raise ValueError(msg) + + # Apply date filter if provided + if date_filter: + query = self._apply_date_filter(query, metric_type, date_filter) + + return query + + def _build_sounds_played_query(self) -> Select: + """Build query for sounds played metric.""" + return ( + select( + User.id, + User.name, + func.count(SoundPlayed.id).label("count"), + ) + .join(SoundPlayed, User.id == SoundPlayed.user_id) + .group_by(User.id, User.name) + ) + + def _build_credits_used_query(self) -> Select: + """Build query for credits used metric.""" + return ( + select( + User.id, + User.name, + func.sum(func.abs(CreditTransaction.amount)).label("count"), + ) + .join(CreditTransaction, User.id == CreditTransaction.user_id) + .where(CreditTransaction.amount < 0) + .group_by(User.id, User.name) + ) + + def _build_tracks_added_query(self) -> Select: + """Build query for tracks added metric.""" + return ( + select( + User.id, + User.name, + func.count(Extraction.id).label("count"), + ) + .join(Extraction, User.id == Extraction.user_id) + .where(Extraction.sound_id.is_not(None)) + .group_by(User.id, User.name) + ) + + def _build_tts_added_query(self) -> Select: + """Build query for TTS added metric.""" + return ( + select( + User.id, + User.name, + func.count(TTS.id).label("count"), + ) + .join(TTS, User.id == TTS.user_id) + .group_by(User.id, User.name) + ) + + def _build_playlists_created_query(self) -> Select: + """Build query for playlists created metric.""" + return ( + select( + User.id, + User.name, + func.count(Playlist.id).label("count"), + ) + .join(Playlist, User.id == Playlist.user_id) + .group_by(User.id, User.name) + ) + + def _apply_date_filter( + self, + query: Select, + metric_type: str, + date_filter: datetime, + ) -> Select: + """Apply date filter to query based on metric type.""" + match metric_type: + case "sounds_played": + return query.where(SoundPlayed.created_at >= date_filter) + case "credits_used": + return query.where(CreditTransaction.created_at >= date_filter) + case "tracks_added": + return query.where(Extraction.created_at >= date_filter) + case "tts_added": + return query.where(TTS.created_at >= date_filter) + case "playlists_created": + return query.where(Playlist.created_at >= date_filter) + case _: + return query diff --git a/app/services/dashboard.py b/app/services/dashboard.py index b2845b7..101966f 100644 --- a/app/services/dashboard.py +++ b/app/services/dashboard.py @@ -13,7 +13,11 @@ logger = get_logger(__name__) class DashboardService: """Service for dashboard statistics and analytics.""" - def __init__(self, sound_repository: SoundRepository, user_repository: UserRepository) -> None: + def __init__( + self, + sound_repository: SoundRepository, + user_repository: UserRepository, + ) -> None: """Initialize the dashboard service.""" self.sound_repository = sound_repository self.user_repository = user_repository