Compare commits
2 Commits
734521c5c3
...
0a8b50a0be
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a8b50a0be | ||
|
|
9e07ce393f |
@@ -2,10 +2,11 @@
|
|||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api.v1.admin import extractions, sounds
|
from app.api.v1.admin import extractions, sounds, users
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin")
|
router = APIRouter(prefix="/admin")
|
||||||
|
|
||||||
# Include all admin sub-routers
|
# Include all admin sub-routers
|
||||||
router.include_router(extractions.router)
|
router.include_router(extractions.router)
|
||||||
router.include_router(sounds.router)
|
router.include_router(sounds.router)
|
||||||
|
router.include_router(users.router)
|
||||||
|
|||||||
148
app/api/v1/admin/users.py
Normal file
148
app/api/v1/admin/users.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""Admin users endpoints."""
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.dependencies import get_admin_user
|
||||||
|
from app.models.plan import Plan
|
||||||
|
from app.models.user import User
|
||||||
|
from app.repositories.plan import PlanRepository
|
||||||
|
from app.repositories.user import UserRepository
|
||||||
|
from app.schemas.auth import UserResponse
|
||||||
|
from app.schemas.user import UserUpdate
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/users",
|
||||||
|
tags=["admin-users"],
|
||||||
|
dependencies=[Depends(get_admin_user)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _user_to_response(user: User) -> UserResponse:
|
||||||
|
"""Convert User model to UserResponse."""
|
||||||
|
return UserResponse(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
picture=user.picture,
|
||||||
|
role=user.role,
|
||||||
|
credits=user.credits,
|
||||||
|
is_active=user.is_active,
|
||||||
|
plan={
|
||||||
|
"id": user.plan.id,
|
||||||
|
"name": user.plan.name,
|
||||||
|
"max_credits": user.plan.max_credits,
|
||||||
|
"features": [], # Add features if needed
|
||||||
|
} if user.plan else {},
|
||||||
|
created_at=user.created_at,
|
||||||
|
updated_at=user.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/")
|
||||||
|
async def list_users(
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[UserResponse]:
|
||||||
|
"""Get all users (admin only)."""
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
users = await user_repo.get_all_with_plan(limit=limit, offset=offset)
|
||||||
|
return [_user_to_response(user) for user in users]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{user_id}")
|
||||||
|
async def get_user(
|
||||||
|
user_id: int,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Get a specific user by ID (admin only)."""
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id_with_plan(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
return _user_to_response(user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{user_id}")
|
||||||
|
async def update_user(
|
||||||
|
user_id: int,
|
||||||
|
user_update: UserUpdate,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Update a user (admin only)."""
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id_with_plan(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = user_update.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# If plan_id is being updated, validate it exists
|
||||||
|
if "plan_id" in update_data:
|
||||||
|
plan_repo = PlanRepository(session)
|
||||||
|
plan = await plan_repo.get_by_id(update_data["plan_id"])
|
||||||
|
if not plan:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Plan not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_user = await user_repo.update(user, update_data)
|
||||||
|
# Need to refresh the plan relationship after update
|
||||||
|
await session.refresh(updated_user, ["plan"])
|
||||||
|
return _user_to_response(updated_user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{user_id}/disable")
|
||||||
|
async def disable_user(
|
||||||
|
user_id: int,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Disable a user (admin only)."""
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id_with_plan(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
await user_repo.update(user, {"is_active": False})
|
||||||
|
return {"message": "User disabled successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{user_id}/enable")
|
||||||
|
async def enable_user(
|
||||||
|
user_id: int,
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Enable a user (admin only)."""
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
user = await user_repo.get_by_id_with_plan(user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
await user_repo.update(user, {"is_active": True})
|
||||||
|
return {"message": "User enabled successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/plans/list")
|
||||||
|
async def list_plans(
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> list[Plan]:
|
||||||
|
"""Get all plans for user editing (admin only)."""
|
||||||
|
plan_repo = PlanRepository(session)
|
||||||
|
return await plan_repo.get_all()
|
||||||
@@ -20,6 +20,8 @@ from app.schemas.auth import (
|
|||||||
ApiTokenRequest,
|
ApiTokenRequest,
|
||||||
ApiTokenResponse,
|
ApiTokenResponse,
|
||||||
ApiTokenStatusResponse,
|
ApiTokenStatusResponse,
|
||||||
|
ChangePasswordRequest,
|
||||||
|
UpdateProfileRequest,
|
||||||
UserLoginRequest,
|
UserLoginRequest,
|
||||||
UserRegisterRequest,
|
UserRegisterRequest,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
@@ -446,3 +448,85 @@ async def revoke_api_token(
|
|||||||
) from e
|
) from e
|
||||||
else:
|
else:
|
||||||
return {"message": "API token revoked successfully"}
|
return {"message": "API token revoked successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
# Profile management endpoints
|
||||||
|
@router.patch("/me")
|
||||||
|
async def update_profile(
|
||||||
|
request: UpdateProfileRequest,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Update the current user's profile."""
|
||||||
|
try:
|
||||||
|
updated_user = await auth_service.update_user_profile(
|
||||||
|
current_user, request.model_dump(exclude_unset=True)
|
||||||
|
)
|
||||||
|
return await auth_service.user_to_response(updated_user)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to update profile for user: %s", current_user.email)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to update profile",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/change-password")
|
||||||
|
async def change_password(
|
||||||
|
request: ChangePasswordRequest,
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Change the current user's password."""
|
||||||
|
# Store user email before operations to avoid session detachment issues
|
||||||
|
user_email = current_user.email
|
||||||
|
try:
|
||||||
|
await auth_service.change_user_password(
|
||||||
|
current_user, request.current_password, request.new_password
|
||||||
|
)
|
||||||
|
return {"message": "Password changed successfully"}
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
|
) from e
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Failed to change password for user: %s", user_email)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to change password",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user-providers")
|
||||||
|
async def get_user_providers(
|
||||||
|
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||||
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Get the current user's connected authentication providers."""
|
||||||
|
providers = []
|
||||||
|
|
||||||
|
# Add password provider if user has password
|
||||||
|
if current_user.password_hash:
|
||||||
|
providers.append({
|
||||||
|
"provider": "password",
|
||||||
|
"display_name": "Password",
|
||||||
|
"connected_at": current_user.created_at.isoformat(),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Get OAuth providers from the database
|
||||||
|
oauth_providers = await auth_service.get_user_oauth_providers(current_user)
|
||||||
|
for oauth in oauth_providers:
|
||||||
|
display_name = oauth.provider.title() # Capitalize first letter
|
||||||
|
if oauth.provider == "github":
|
||||||
|
display_name = "GitHub"
|
||||||
|
elif oauth.provider == "google":
|
||||||
|
display_name = "Google"
|
||||||
|
|
||||||
|
providers.append({
|
||||||
|
"provider": oauth.provider,
|
||||||
|
"display_name": display_name,
|
||||||
|
"connected_at": oauth.created_at.isoformat(),
|
||||||
|
})
|
||||||
|
|
||||||
|
return providers
|
||||||
|
|||||||
@@ -55,8 +55,8 @@ def create_app() -> FastAPI:
|
|||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
# Configure docs URLs for reverse proxy setup
|
# Configure docs URLs for reverse proxy setup
|
||||||
docs_url="/api/docs", # Swagger UI at /api/docs
|
docs_url="/api/docs", # Swagger UI at /api/docs
|
||||||
redoc_url="/api/redoc", # ReDoc at /api/redoc
|
redoc_url="/api/redoc", # ReDoc at /api/redoc
|
||||||
openapi_url="/api/openapi.json" # OpenAPI schema at /api/openapi.json
|
openapi_url="/api/openapi.json", # OpenAPI schema at /api/openapi.json
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add CORS middleware
|
# Add CORS middleware
|
||||||
|
|||||||
17
app/repositories/plan.py
Normal file
17
app/repositories/plan.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Plan repository."""
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.plan import Plan
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PlanRepository(BaseRepository[Plan]):
|
||||||
|
"""Repository for plan operations."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize the plan repository."""
|
||||||
|
super().__init__(Plan, session)
|
||||||
@@ -4,6 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
@@ -20,6 +21,42 @@ class UserRepository(BaseRepository[User]):
|
|||||||
"""Initialize the user repository."""
|
"""Initialize the user repository."""
|
||||||
super().__init__(User, session)
|
super().__init__(User, session)
|
||||||
|
|
||||||
|
async def get_all_with_plan(
|
||||||
|
self,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[User]:
|
||||||
|
"""Get all users with plan relationship loaded."""
|
||||||
|
try:
|
||||||
|
statement = (
|
||||||
|
select(User)
|
||||||
|
.options(selectinload(User.plan))
|
||||||
|
.limit(limit)
|
||||||
|
.offset(offset)
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get all users with plan")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_id_with_plan(self, entity_id: int) -> User | None:
|
||||||
|
"""Get a user by ID with plan relationship loaded."""
|
||||||
|
try:
|
||||||
|
statement = (
|
||||||
|
select(User)
|
||||||
|
.options(selectinload(User.plan))
|
||||||
|
.where(User.id == entity_id)
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return result.first()
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to get user by ID with plan: %s",
|
||||||
|
entity_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> User | None:
|
async def get_by_email(self, email: str) -> User | None:
|
||||||
"""Get a user by email address."""
|
"""Get a user by email address."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -59,3 +59,13 @@ class UserOauthRepository(BaseRepository[UserOauth]):
|
|||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
return result.first()
|
return result.first()
|
||||||
|
|
||||||
|
async def get_by_user_id(self, user_id: int) -> list[UserOauth]:
|
||||||
|
"""Get all OAuth providers for a user."""
|
||||||
|
try:
|
||||||
|
statement = select(UserOauth).where(UserOauth.user_id == user_id)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to get OAuth providers for user ID: %s", user_id)
|
||||||
|
raise
|
||||||
|
|||||||
@@ -79,3 +79,22 @@ class ApiTokenStatusResponse(BaseModel):
|
|||||||
has_token: bool = Field(..., description="Whether user has an active API token")
|
has_token: bool = Field(..., description="Whether user has an active API token")
|
||||||
expires_at: datetime | None = Field(None, description="Token expiration timestamp")
|
expires_at: datetime | None = Field(None, description="Token expiration timestamp")
|
||||||
is_expired: bool = Field(..., description="Whether the token is expired")
|
is_expired: bool = Field(..., description="Whether the token is expired")
|
||||||
|
|
||||||
|
|
||||||
|
class ChangePasswordRequest(BaseModel):
|
||||||
|
"""Schema for password change request."""
|
||||||
|
|
||||||
|
current_password: str | None = Field(None, description="Current password (required if user has existing password)")
|
||||||
|
new_password: str = Field(
|
||||||
|
...,
|
||||||
|
min_length=8,
|
||||||
|
description="New password (minimum 8 characters)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateProfileRequest(BaseModel):
|
||||||
|
"""Schema for profile update request."""
|
||||||
|
|
||||||
|
name: str | None = Field(
|
||||||
|
None, min_length=1, max_length=100, description="User display name"
|
||||||
|
)
|
||||||
|
|||||||
14
app/schemas/user.py
Normal file
14
app/schemas/user.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""User schemas."""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class UserUpdate(BaseModel):
|
||||||
|
"""Schema for updating a user."""
|
||||||
|
|
||||||
|
name: str | None = Field(
|
||||||
|
None, min_length=1, max_length=100, description="User full name",
|
||||||
|
)
|
||||||
|
plan_id: int | None = Field(None, description="User plan ID")
|
||||||
|
credits: int | None = Field(None, ge=0, description="User credits")
|
||||||
|
is_active: bool | None = Field(None, description="Whether user is active")
|
||||||
@@ -430,3 +430,85 @@ class AuthService:
|
|||||||
oauth_user_info.email,
|
oauth_user_info.email,
|
||||||
)
|
)
|
||||||
return AuthResponse(user=user_response, token=token)
|
return AuthResponse(user=user_response, token=token)
|
||||||
|
|
||||||
|
async def update_user_profile(self, user: User, data: dict) -> User:
|
||||||
|
"""Update user profile information."""
|
||||||
|
logger.info("Updating profile for user: %s", user.email)
|
||||||
|
|
||||||
|
# Only allow updating specific fields
|
||||||
|
allowed_fields = {"name"}
|
||||||
|
update_data = {k: v for k, v in data.items() if k in allowed_fields}
|
||||||
|
|
||||||
|
if not update_data:
|
||||||
|
return user
|
||||||
|
|
||||||
|
# Update user
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(user, field, value)
|
||||||
|
|
||||||
|
self.session.add(user)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(user, ["plan"])
|
||||||
|
|
||||||
|
logger.info("Profile updated successfully for user: %s", user.email)
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def change_user_password(
|
||||||
|
self, user: User, current_password: str | None, new_password: str
|
||||||
|
) -> None:
|
||||||
|
"""Change user's password."""
|
||||||
|
# Store user email before any operations to avoid session detachment issues
|
||||||
|
user_email = user.email
|
||||||
|
logger.info("Changing password for user: %s", user_email)
|
||||||
|
|
||||||
|
# Store whether user had existing password before we modify it
|
||||||
|
had_existing_password = user.password_hash is not None
|
||||||
|
|
||||||
|
# If user has existing password, verify it
|
||||||
|
if had_existing_password:
|
||||||
|
if not current_password:
|
||||||
|
raise ValueError("Current password is required when changing existing password")
|
||||||
|
if not PasswordUtils.verify_password(current_password, user.password_hash):
|
||||||
|
raise ValueError("Current password is incorrect")
|
||||||
|
else:
|
||||||
|
# User doesn't have a password (OAuth-only user), so we're setting their first password
|
||||||
|
logger.info("Setting first password for OAuth user: %s", user_email)
|
||||||
|
|
||||||
|
# Hash new password
|
||||||
|
new_password_hash = PasswordUtils.hash_password(new_password)
|
||||||
|
|
||||||
|
# Update user
|
||||||
|
user.password_hash = new_password_hash
|
||||||
|
self.session.add(user)
|
||||||
|
await self.session.commit()
|
||||||
|
|
||||||
|
logger.info("Password %s successfully for user: %s",
|
||||||
|
"changed" if had_existing_password else "set", user_email)
|
||||||
|
|
||||||
|
async def user_to_response(self, user: User) -> UserResponse:
|
||||||
|
"""Convert User model to UserResponse with plan information."""
|
||||||
|
# Load plan relationship if not already loaded
|
||||||
|
if not hasattr(user, 'plan') or not user.plan:
|
||||||
|
await self.session.refresh(user, ["plan"])
|
||||||
|
|
||||||
|
return UserResponse(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
picture=user.picture,
|
||||||
|
role=user.role,
|
||||||
|
credits=user.credits,
|
||||||
|
is_active=user.is_active,
|
||||||
|
plan={
|
||||||
|
"id": user.plan.id,
|
||||||
|
"name": user.plan.name,
|
||||||
|
"max_credits": user.plan.max_credits,
|
||||||
|
"features": [], # Add features if needed
|
||||||
|
},
|
||||||
|
created_at=user.created_at,
|
||||||
|
updated_at=user.updated_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_oauth_providers(self, user: User):
|
||||||
|
"""Get OAuth providers connected to the user."""
|
||||||
|
return await self.oauth_repo.get_by_user_id(user.id)
|
||||||
|
|||||||
Reference in New Issue
Block a user