187 lines
6.1 KiB
Python
187 lines
6.1 KiB
Python
"""FastAPI dependencies.

Provides service factories (``AuthService``, ``OAuthService``) that wrap the
database session from ``get_db``, and authentication dependencies that
resolve the current ``User`` from either an ``access_token`` JWT cookie or an
``API-TOKEN`` request header, plus active-account and admin-role checks
layered on top of them.
"""
|
|
|
|
from typing import Annotated, cast
|
|
|
|
from fastapi import Cookie, Depends, Header, HTTPException, status
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.database import get_db
|
|
from app.core.logging import get_logger
|
|
from app.models.user import User
|
|
from app.services.auth import AuthService
|
|
from app.services.oauth import OAuthService
|
|
from app.utils.auth import JWTUtils, TokenUtils
|
|
|
|
# Module-level logger used by the authentication dependencies to record
# warnings (missing cookie, malformed token subject) and exception traces.
logger = get_logger(__name__)
|
|
|
|
|
|
async def get_auth_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
    """Build an ``AuthService`` around the session supplied by ``get_db``.

    Args:
        session: Async database session injected via ``Depends(get_db)``.

    Returns:
        A new ``AuthService`` instance bound to ``session``.
    """
    service = AuthService(session)
    return service
|
|
|
|
|
|
async def get_oauth_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> OAuthService:
    """Build an ``OAuthService`` around the session supplied by ``get_db``.

    Args:
        session: Async database session injected via ``Depends(get_db)``.

    Returns:
        A new ``OAuthService`` instance bound to ``session``.
    """
    service = OAuthService(session)
    return service
|
|
|
|
|
|
async def get_current_user(
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> User:
    """Resolve the authenticated user from the JWT in the ``access_token`` cookie.

    Args:
        auth_service: Service used to load the user by numeric ID.
        access_token: Raw JWT read from the HTTP-only ``access_token`` cookie,
            or ``None`` when the cookie is absent.

    Returns:
        The ``User`` whose ID is carried in the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the cookie is missing or empty, the ``sub``
            claim is missing/falsy or not convertible to ``int``, the lookup
            yields no user, or any other non-HTTP exception is raised along
            the way. Note that this includes errors from decoding the token
            and from loading the user. An ``HTTPException`` raised by a
            helper is re-raised unchanged.
    """
    try:
        if not access_token:
            logger.warning("No access token cookie found")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
            )

        payload = JWTUtils.decode_access_token(access_token)

        # The user ID is stored as the token subject.
        user_id_str = payload.get("sub")
        if not user_id_str:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token payload",
            )

        # cast() only informs the type checker; it performs no runtime check.
        # A non-str subject is still rejected by the int() conversion below.
        user_id_str = cast("str", user_id_str)

        try:
            user_id = int(user_id_str)
        except (ValueError, TypeError) as e:
            logger.warning("Invalid user ID in token: %s", user_id_str)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token payload",
            ) from e

        user = await auth_service.get_current_user(user_id)
        # Never hand ``None`` to a route that expects a ``User``: a lookup
        # with no result means the token no longer maps to an account.
        if user is None:
            logger.warning("No user found for token subject: %s", user_id)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
            )

    except HTTPException:
        # Deliberate auth failures already carry the right status/detail.
        raise

    except Exception as e:
        logger.exception("Failed to authenticate user")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
        ) from e

    else:
        return user
|
|
|
|
|
|
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
) -> User:
    """Return the cookie-authenticated user, rejecting deactivated accounts.

    Args:
        current_user: User resolved by ``get_current_user``.

    Returns:
        ``current_user`` unchanged when ``is_active`` is truthy.

    Raises:
        HTTPException: 401 with detail "Account is deactivated" otherwise.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Account is deactivated",
    )
|
|
|
|
|
|
async def get_current_user_api_token(
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None,
) -> User:
    """Authenticate a user from the token sent in the ``API-TOKEN`` header.

    The header value is stripped of surrounding whitespace, looked up via
    ``auth_service.get_user_by_api_token``, then checked for expiry (using
    ``user.api_token_expires_at``) and for an active account.

    Args:
        auth_service: Service used to find the user owning the token.
        api_token_header: Raw ``API-TOKEN`` header value, or ``None``.

    Returns:
        The active user owning a non-expired token.

    Raises:
        HTTPException: 401 with a detail describing the first failed check
            (missing header, blank token, unknown token, expired token,
            deactivated account), or "Could not validate API token" when an
            unexpected non-HTTP exception occurs (logged with traceback).
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every rejection from this dependency is reported as a 401.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
        )

    try:
        if not api_token_header:
            raise _unauthorized("API-TOKEN header required")

        token = api_token_header.strip()
        if not token:
            raise _unauthorized("API token required")

        user = await auth_service.get_user_by_api_token(token)
        if not user:
            raise _unauthorized("Invalid API token")

        if TokenUtils.is_token_expired(user.api_token_expires_at):
            raise _unauthorized("API token has expired")

        if not user.is_active:
            raise _unauthorized("Account is deactivated")
    except HTTPException:
        # Pass our own (and any helper's) HTTP errors through untouched.
        raise
    except Exception as exc:
        logger.exception("Failed to authenticate user with API token")
        raise _unauthorized("Could not validate API token") from exc

    return user
|
|
|
|
|
|
async def get_current_user_flexible(
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    access_token: Annotated[str | None, Cookie()] = None,
    api_token_header: Annotated[str | None, Header(alias="API-TOKEN")] = None,
) -> User:
    """Authenticate via the ``API-TOKEN`` header if sent, else the JWT cookie.

    A non-empty ``API-TOKEN`` header takes precedence. The cookie is then
    not consulted at all, even if the header's token is rejected.

    Args:
        auth_service: Service passed through to the chosen authenticator.
        access_token: Raw JWT from the ``access_token`` cookie, or ``None``.
        api_token_header: Raw ``API-TOKEN`` header value, or ``None``.

    Returns:
        The authenticated user from whichever mechanism was used.
    """
    if not api_token_header:
        # No API token supplied: use the cookie-based JWT flow.
        return await get_current_user(
            auth_service=auth_service,
            access_token=access_token,
        )
    return await get_current_user_api_token(
        auth_service=auth_service,
        api_token_header=api_token_header,
    )
|
|
|
|
|
|
async def get_current_active_user_flexible(
    current_user: Annotated[User, Depends(get_current_user_flexible)],
) -> User:
    """Return the header- or cookie-authenticated user if their account is active.

    Args:
        current_user: User resolved by ``get_current_user_flexible``.

    Returns:
        ``current_user`` unchanged when ``is_active`` is truthy.

    Raises:
        HTTPException: 401 with detail "Account is deactivated" otherwise.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Account is deactivated",
    )
|
|
|
|
|
|
async def get_admin_user(
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
    """Return the active cookie-authenticated user if they hold an admin role.

    Users whose ``role`` equals ``"admin"`` or ``"superadmin"`` are accepted.

    Args:
        current_user: User resolved by ``get_current_active_user``.

    Returns:
        ``current_user`` unchanged when their role is permitted.

    Raises:
        HTTPException: 403 with detail "Not enough permissions" for any
            other role.
    """
    allowed_roles = ("admin", "superadmin")
    if current_user.role in allowed_roles:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Not enough permissions",
    )
|