Files
sdb2-backend/app/core/dependencies.py
JSC 3dc21337f9 Refactor tests for improved consistency and readability
- Updated test cases in `test_auth_endpoints.py` to ensure consistent formatting and style.
- Enhanced `test_socket_endpoints.py` with consistent parameter formatting and improved readability.
- Cleaned up `conftest.py` by ensuring consistent parameter formatting in fixtures.
- Added comprehensive tests for API token dependencies in `test_api_token_dependencies.py`.
- Refactored `test_auth_service.py` to maintain consistent parameter formatting.
- Cleaned up `test_oauth_service.py` by removing unnecessary imports.
- Improved `test_socket_service.py` with consistent formatting and readability.
- Enhanced `test_cookies.py` by ensuring consistent formatting and readability.
- Introduced new tests for token utilities in `test_token_utils.py` to validate token generation and expiration logic.
2025-07-27 15:11:47 +02:00

194 lines
6.3 KiB
Python

"""FastAPI dependencies."""
from typing import Annotated, cast
from fastapi import Cookie, Depends, Header, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.logging import get_logger
from app.models.user import User
from app.services.auth import AuthService
from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils, TokenUtils
# Module-level logger for the auth dependencies below, named after this module.
logger = get_logger(__name__)
async def get_auth_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
    """Build an ``AuthService`` bound to the request's database session.

    Args:
        session: Async database session injected through ``get_db``.

    Returns:
        A fresh ``AuthService`` wrapping ``session``.
    """
    service = AuthService(session)
    return service
async def get_oauth_service(
    session: Annotated[AsyncSession, Depends(get_db)],
) -> OAuthService:
    """Build an ``OAuthService`` bound to the request's database session.

    Args:
        session: Async database session injected through ``get_db``.

    Returns:
        A fresh ``OAuthService`` wrapping ``session``.
    """
    service = OAuthService(session)
    return service
async def get_current_user(
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    access_token: Annotated[str | None, Cookie()] = None,
) -> User:
    """Resolve the authenticated user from the JWT in the ``access_token`` cookie.

    The token is decoded with ``JWTUtils.decode_access_token``. Its ``sub``
    claim must convert to an ``int`` user ID, which is then loaded through
    ``auth_service.get_current_user``.

    Args:
        auth_service: Service used to load the user record.
        access_token: Raw JWT from the HTTP-only ``access_token`` cookie, if sent.

    Returns:
        The user identified by the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the cookie is missing, the ``sub`` claim is
            missing/empty or not convertible to ``int``, or any non-HTTP error
            is raised while decoding the token or loading the user.
            HTTPExceptions raised by ``auth_service`` are re-raised unchanged.
    """

    def unauthorized(detail: str) -> HTTPException:
        # Factory for the 401 errors this dependency raises itself.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
        )

    try:
        if not access_token:
            logger.warning("No access token cookie found")
            raise unauthorized("Could not validate credentials")

        claims = JWTUtils.decode_access_token(access_token)
        subject = claims.get("sub")
        if not subject:
            raise unauthorized("Invalid token payload")

        # The truthiness check above rules out a missing claim; the cast only
        # narrows the type for static checkers and has no runtime effect.
        subject = cast("str", subject)
        try:
            user_id = int(subject)
        except (ValueError, TypeError) as conversion_error:
            logger.warning("Invalid user ID in token: %s", subject)
            raise unauthorized("Invalid token payload") from conversion_error

        return await auth_service.get_current_user(user_id)
    except HTTPException:
        # HTTPExceptions (ours or from auth_service) pass through untouched.
        raise
    except Exception as unexpected:
        logger.exception("Failed to authenticate user")
        raise unauthorized("Could not validate credentials") from unexpected
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
) -> User:
    """Return the cookie-authenticated user, rejecting deactivated accounts.

    Args:
        current_user: User resolved by ``get_current_user``.

    Returns:
        ``current_user`` when its ``is_active`` flag is truthy.

    Raises:
        HTTPException: 401 "Account is deactivated" when ``is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Account is deactivated",
    )
async def get_current_user_api_token(
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    authorization: Annotated[str | None, Header()] = None,
) -> User:
    """Get the current authenticated user from API token in Authorization header.

    The header must have the form ``Bearer <token>``. The scheme name is
    matched case-insensitively (RFC 7235 §2.1 defines auth-schemes as
    case-insensitive). Whitespace around the token is ignored, because the
    grammar allows one or more spaces between scheme and credentials.

    Args:
        auth_service: Service used to look up the user owning the API token.
        authorization: Raw ``Authorization`` request header value, if sent.

    Returns:
        The user owning the API token, after the expiry and active checks pass.

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer
            credential, the token is empty, unknown or expired, the account is
            deactivated, or any unexpected error occurs during validation.
    """
    try:
        # Check if Authorization header exists
        if not authorization:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authorization header required",
            )

        # Auth-scheme names are case-insensitive, so "bearer"/"BEARER" are as
        # valid as "Bearer". Requiring the trailing space keeps a bare
        # "Bearer" (or "Bearerxyz") rejected as a malformed header.
        bearer_prefix = "bearer "
        if authorization[: len(bearer_prefix)].lower() != bearer_prefix:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid authorization header format",
            )

        # Drop the scheme and any padding around the credential, so that
        # "Bearer   <token>" resolves to the same token as "Bearer <token>".
        api_token = authorization[len(bearer_prefix) :].strip()
        if not api_token:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API token required",
            )

        # Get the user by API token
        user = await auth_service.get_user_by_api_token(api_token)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid API token",
            )

        # Check if token is expired
        if TokenUtils.is_token_expired(user.api_token_expires_at):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="API token has expired",
            )

        # Check if user is active
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Account is deactivated",
            )

        return user
    except HTTPException:
        # Re-raise HTTPExceptions without wrapping them
        raise
    except Exception as e:
        logger.exception("Failed to authenticate user with API token")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate API token",
        ) from e
async def get_current_user_flexible(
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    access_token: Annotated[str | None, Cookie()] = None,
    authorization: Annotated[str | None, Header()] = None,
) -> User:
    """Authenticate through an API token or the JWT cookie, whichever applies.

    A non-empty ``Authorization`` header takes precedence and is validated by
    ``get_current_user_api_token``. If that validation fails, the request is
    not retried with the cookie. Without the header, ``get_current_user``
    validates the ``access_token`` cookie.

    Args:
        auth_service: Service used to resolve the user.
        access_token: JWT from the ``access_token`` cookie, if sent.
        authorization: Raw ``Authorization`` header value, if sent.

    Returns:
        The authenticated user.
    """
    if not authorization:
        # No Authorization header sent: rely on the session cookie instead.
        return await get_current_user(auth_service, access_token)
    return await get_current_user_api_token(auth_service, authorization)
async def get_current_active_user_flexible(
    current_user: Annotated[User, Depends(get_current_user_flexible)],
) -> User:
    """Return the user from flexible auth, rejecting deactivated accounts.

    The API-token path already rejects inactive accounts; this check also
    covers users authenticated through the JWT cookie.

    Args:
        current_user: User resolved by ``get_current_user_flexible``.

    Returns:
        ``current_user`` when its ``is_active`` flag is truthy.

    Raises:
        HTTPException: 401 "Account is deactivated" when ``is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Account is deactivated",
    )
async def get_admin_user(
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
    """Return the active cookie-authenticated user if they hold an admin role.

    Both the ``"admin"`` and ``"superadmin"`` roles are accepted.

    Args:
        current_user: Active user resolved by ``get_current_active_user``.

    Returns:
        ``current_user`` when their role is privileged.

    Raises:
        HTTPException: 403 "Not enough permissions" for any other role.
    """
    privileged_roles = ("admin", "superadmin")
    if current_user.role in privileged_roles:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Not enough permissions",
    )