"""FastAPI dependencies.""" from typing import Annotated, NoReturn, cast from fastapi import Cookie, Depends, HTTPException, status from sqlmodel.ext.asyncio.session import AsyncSession from app.core.database import get_db from app.core.logging import get_logger from app.models.user import User from app.services.auth import AuthService from app.utils.auth import JWTUtils logger = get_logger(__name__) def _raise_invalid_token_error() -> NoReturn: """Raise an invalid token HTTP exception.""" raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload", ) def _raise_auth_error() -> NoReturn: """Raise an authentication HTTP exception.""" raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", ) async def get_auth_service( session: Annotated[AsyncSession, Depends(get_db)], ) -> AuthService: """Get the authentication service.""" return AuthService(session) async def get_current_user( access_token: Annotated[str | None, Cookie()], auth_service: Annotated[AuthService, Depends(get_auth_service)], ) -> User: """Get the current authenticated user from JWT token in HTTP-only cookie.""" try: # Check if access token cookie exists if not access_token: logger.warning("No access token cookie found") _raise_auth_error() # Decode the JWT token payload = JWTUtils.decode_access_token(access_token) # Extract user ID from token user_id_str = payload.get("sub") if not user_id_str: _raise_invalid_token_error() # At this point user_id_str is guaranteed to be truthy, safe to cast user_id_str = cast("str", user_id_str) try: user_id = int(user_id_str) except (ValueError, TypeError) as e: logger.warning("Invalid user ID in token: %s", user_id_str) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload", ) from e # Get the user return await auth_service.get_current_user(user_id) except HTTPException: # Re-raise HTTPExceptions without wrapping them raise except Exception: logger.exception("Failed to authenticate user") _raise_auth_error() async def get_current_active_user( current_user: Annotated[User, Depends(get_current_user)], ) -> User: """Get the current authenticated and active user.""" if not current_user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) return current_user async def get_admin_user( current_user: Annotated[User, Depends(get_current_active_user)], ) -> User: """Get the current authenticated admin user.""" if current_user.role not in ["admin", "superadmin"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions", ) return current_user