"""Authentication utilities.""" import secrets from datetime import UTC, datetime, timedelta from typing import Any import bcrypt import jwt from fastapi import HTTPException, status from app.core.config import settings from app.core.logging import get_logger logger = get_logger(__name__) class PasswordUtils: """Utility class for password operations.""" @staticmethod def hash_password(password: str) -> str: """Hash a password using bcrypt.""" salt = bcrypt.gensalt() hashed = bcrypt.hashpw(password.encode("utf-8"), salt) return hashed.decode("utf-8") @staticmethod def verify_password(password: str, hashed_password: str) -> bool: """Verify a password against its hash.""" return bcrypt.checkpw(password.encode("utf-8"), hashed_password.encode("utf-8")) class JWTUtils: """Utility class for JWT operations.""" @staticmethod def create_access_token( data: dict[str, Any], expires_delta: timedelta | None = None, ) -> str: """Create a JWT access token.""" to_encode = data.copy() if expires_delta: expire = datetime.now(UTC) + expires_delta else: expire = datetime.now(UTC) + timedelta( minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES, ) to_encode.update({"exp": expire}) try: encoded_jwt = jwt.encode( to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM, ) except Exception as e: logger.exception("Failed to create JWT token") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not create access token", ) from e else: logger.info("JWT token created successfully") return encoded_jwt @staticmethod def create_refresh_token( data: dict[str, Any], expires_delta: timedelta | None = None, ) -> str: """Create a JWT refresh token.""" to_encode = data.copy() if expires_delta: expire = datetime.now(UTC) + expires_delta else: expire = datetime.now(UTC) + timedelta( days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS, ) to_encode.update({"exp": expire, "type": "refresh"}) try: encoded_jwt = jwt.encode( to_encode, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM, ) except Exception as e: logger.exception("Failed to create JWT refresh token") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not create refresh token", ) from e else: logger.info("JWT refresh token created successfully") return encoded_jwt @staticmethod def decode_access_token(token: str) -> dict[str, Any]: """Decode and validate a JWT access token.""" try: payload = jwt.decode( token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM], ) # Ensure this is not a refresh token if payload.get("type") == "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type", ) return dict(payload) except jwt.ExpiredSignatureError as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired", ) from e except jwt.PyJWTError as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", ) from e @staticmethod def decode_refresh_token(token: str) -> dict[str, Any]: """Decode and validate a JWT refresh token.""" try: payload = jwt.decode( token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM], ) # Ensure this is a refresh token if payload.get("type") != "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type", ) return dict(payload) except jwt.ExpiredSignatureError as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token has expired", ) from e except jwt.PyJWTError as e: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate refresh token", ) from e class TokenUtils: """Utility class for API token operations.""" @staticmethod def generate_api_token() -> str: """Generate a secure random API token.""" return secrets.token_urlsafe(32) @staticmethod def is_token_expired(expires_at: datetime | None) -> bool: """Check if a token is expired.""" if expires_at is None: return False # Handle timezone-aware and naive datetimes if expires_at.tzinfo is None: # Naive datetime - assume UTC expires_at = expires_at.replace(tzinfo=UTC) else: # Convert to UTC if not already expires_at = expires_at.astimezone(UTC) return datetime.now(UTC) > expires_at