- Created a new test package for services and added tests for AuthService.
- Implemented tests for user registration, login, and token creation.
- Added a new test package for utilities, including tests for the password and JWT utilities.
- Updated `uv.lock` to include the new dependencies: bcrypt, email-validator, pyjwt, and pytest-asyncio.
180 lines
5.7 KiB
Python
180 lines
5.7 KiB
Python
"""Authentication utilities."""
|
|
|
|
import secrets
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
import bcrypt
|
|
import jwt
|
|
from fastapi import HTTPException, status
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
|
|
# Module-level logger named after this module, obtained from the project's
# get_logger helper; JWTUtils uses it to record token-creation outcomes.
logger = get_logger(__name__)
|
|
|
|
|
|
class PasswordUtils:
    """Utility class for password operations.

    Passwords are hashed with bcrypt; hashes are handled as UTF-8 text so they
    can be stored in ordinary string columns.
    """

    @staticmethod
    def hash_password(password: str) -> str:
        """Hash a password using bcrypt with a freshly generated salt.

        Args:
            password: Plain-text password; encoded as UTF-8 before hashing.

        Returns:
            The bcrypt hash decoded to ``str``. The salt is embedded in the
            hash, so nothing else needs to be stored.

        NOTE(review): bcrypt only consumes the first 72 bytes of its input.
        Depending on the installed bcrypt version, longer passwords are either
        silently truncated or rejected with ``ValueError`` -- confirm the
        behaviour of the pinned version and whether registration should
        enforce a byte-length limit up front.
        """
        salt = bcrypt.gensalt()
        hashed = bcrypt.hashpw(password.encode("utf-8"), salt)
        return hashed.decode("utf-8")

    @staticmethod
    def verify_password(password: str, hashed_password: str) -> bool:
        """Verify a plain-text password against a stored bcrypt hash.

        Args:
            password: Candidate plain-text password.
            hashed_password: Hash previously produced by ``hash_password``.

        Returns:
            ``True`` if the password matches the hash. ``False`` if it does not
            match, or if ``hashed_password`` is not a usable bcrypt hash (for
            example an empty or corrupted stored value).
        """
        try:
            return bcrypt.checkpw(
                password.encode("utf-8"),
                hashed_password.encode("utf-8"),
            )
        except ValueError:
            # bcrypt raises ValueError (e.g. "Invalid salt") for malformed
            # hashes. Nothing can match such a value, so report a failed
            # verification instead of surfacing an unhandled server error.
            return False
|
|
|
|
|
|
class JWTUtils:
    """Utility class for JWT operations.

    All tokens are signed and verified with ``settings.JWT_SECRET_KEY`` using
    ``settings.JWT_ALGORITHM``. Refresh tokens carry a ``"type": "refresh"``
    claim. Access tokens are issued without a ``type`` claim, and the access
    decoder only rejects tokens explicitly marked as refresh tokens.
    """

    @staticmethod
    def _encode(
        claims: dict[str, Any],
        *,
        failure_log: str,
        failure_detail: str,
        success_log: str,
    ) -> str:
        """Sign ``claims`` and return the encoded JWT string.

        Args:
            claims: Complete claim set to sign (including ``exp``).
            failure_log: Message logged, with traceback, if signing fails.
            failure_detail: ``detail`` of the HTTPException raised on failure.
            success_log: Message logged at INFO level after signing succeeds.

        Raises:
            HTTPException: 500 if ``jwt.encode`` raises anything. The client
                only sees ``failure_detail``; the original error is logged and
                chained as the cause.
        """
        try:
            encoded_jwt = jwt.encode(
                claims,
                settings.JWT_SECRET_KEY,
                algorithm=settings.JWT_ALGORITHM,
            )
        except Exception as e:
            logger.exception(failure_log)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=failure_detail,
            ) from e
        logger.info(success_log)
        return encoded_jwt

    @staticmethod
    def _decode(
        token: str,
        *,
        expired_detail: str,
        invalid_detail: str,
    ) -> dict[str, Any]:
        """Verify ``token`` with the configured key/algorithm and return its claims.

        Raises:
            HTTPException: 401 with ``expired_detail`` if PyJWT reports an
                expired signature, or 401 with ``invalid_detail`` for any other
                PyJWT validation error.
        """
        try:
            payload = jwt.decode(
                token,
                settings.JWT_SECRET_KEY,
                algorithms=[settings.JWT_ALGORITHM],
            )
        # ExpiredSignatureError is a PyJWTError subclass, so it must be
        # handled first to get its more specific message.
        except jwt.ExpiredSignatureError as e:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=expired_detail,
            ) from e
        except jwt.PyJWTError as e:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=invalid_detail,
            ) from e
        return dict(payload)

    @staticmethod
    def create_access_token(
        data: dict[str, Any],
        expires_delta: timedelta | None = None,
    ) -> str:
        """Create a JWT access token.

        Args:
            data: Claims to embed. The mapping is copied, not mutated; an
                ``exp`` key in it is overwritten.
            expires_delta: Token lifetime. ``None`` uses
                ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any other
                value, including ``timedelta(0)`` or a negative delta, is
                applied as given.

        Returns:
            The encoded JWT string.

        Raises:
            HTTPException: 500 if the token cannot be signed.
        """
        # Compare against None explicitly: timedelta(0) is falsy, so a
        # truthiness test would silently swap a zero lifetime for the default.
        if expires_delta is None:
            lifetime = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
        else:
            lifetime = expires_delta

        claims = {**data, "exp": datetime.now(UTC) + lifetime}
        return JWTUtils._encode(
            claims,
            failure_log="Failed to create JWT token",
            failure_detail="Could not create access token",
            success_log="JWT token created successfully",
        )

    @staticmethod
    def create_refresh_token(
        data: dict[str, Any],
        expires_delta: timedelta | None = None,
    ) -> str:
        """Create a JWT refresh token (marked with ``"type": "refresh"``).

        Args:
            data: Claims to embed. The mapping is copied, not mutated; ``exp``
                and ``type`` keys in it are overwritten.
            expires_delta: Token lifetime. ``None`` uses
                ``settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS`` days. Any other
                value, including ``timedelta(0)`` or a negative delta, is
                applied as given.

        Returns:
            The encoded JWT string.

        Raises:
            HTTPException: 500 if the token cannot be signed.
        """
        # See create_access_token: None, not falsiness, selects the default.
        if expires_delta is None:
            lifetime = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
        else:
            lifetime = expires_delta

        claims = {**data, "exp": datetime.now(UTC) + lifetime, "type": "refresh"}
        return JWTUtils._encode(
            claims,
            failure_log="Failed to create JWT refresh token",
            failure_detail="Could not create refresh token",
            success_log="JWT refresh token created successfully",
        )

    @staticmethod
    def decode_access_token(token: str) -> dict[str, Any]:
        """Decode and validate a JWT access token.

        Returns:
            The token's claims as a plain ``dict``.

        Raises:
            HTTPException: 401 if the token is expired ("Token has expired"),
                fails validation ("Could not validate credentials"), or is a
                refresh token ("Invalid token type").
        """
        payload = JWTUtils._decode(
            token,
            expired_detail="Token has expired",
            invalid_detail="Could not validate credentials",
        )
        # Refresh tokens are signed with the same key, so they pass signature
        # verification; reject them explicitly here.
        if payload.get("type") == "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token type",
            )
        return payload

    @staticmethod
    def decode_refresh_token(token: str) -> dict[str, Any]:
        """Decode and validate a JWT refresh token.

        Returns:
            The token's claims as a plain ``dict``.

        Raises:
            HTTPException: 401 if the token is expired ("Refresh token has
                expired"), fails validation ("Could not validate refresh
                token"), or is not marked as a refresh token ("Invalid token
                type").
        """
        payload = JWTUtils._decode(
            token,
            expired_detail="Refresh token has expired",
            invalid_detail="Could not validate refresh token",
        )
        # Only tokens issued by create_refresh_token carry type == "refresh".
        if payload.get("type") != "refresh":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token type",
            )
        return payload
|
|
|
|
|
|
class TokenUtils:
|
|
"""Utility class for API token operations."""
|
|
|
|
@staticmethod
|
|
def generate_api_token() -> str:
|
|
"""Generate a secure random API token."""
|
|
return secrets.token_urlsafe(32)
|
|
|
|
@staticmethod
|
|
def is_token_expired(expires_at: datetime | None) -> bool:
|
|
"""Check if a token is expired."""
|
|
if expires_at is None:
|
|
return False
|
|
|
|
# Handle timezone-aware and naive datetimes
|
|
if expires_at.tzinfo is None:
|
|
# Naive datetime - assume UTC
|
|
expires_at = expires_at.replace(tzinfo=UTC)
|
|
else:
|
|
# Convert to UTC if not already
|
|
expires_at = expires_at.astimezone(UTC)
|
|
|
|
return datetime.now(UTC) > expires_at
|