Add tests for authentication and utilities, and update dependencies

- Created a new test package for services and added tests for AuthService.
- Implemented tests for user registration, login, and token creation.
- Added a new test package for utilities and included tests for password and JWT utilities.
- Updated `uv.lock` to include new dependencies: bcrypt, email-validator, pyjwt, and pytest-asyncio.
This commit is contained in:
JSC
2025-07-25 17:48:43 +02:00
parent af20bc8724
commit e456d34897
23 changed files with 2381 additions and 8 deletions

View File

@@ -2,10 +2,11 @@
from fastapi import APIRouter
from app.api.v1 import main
from app.api.v1 import auth, main
# V1 API router with v1 prefix
api_router = APIRouter(prefix="/v1")
# Include all route modules
api_router.include_router(main.router, tags=["main"])
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])

198
app/api/v1/auth.py Normal file
View File

@@ -0,0 +1,198 @@
"""Authentication endpoints."""
from typing import Annotated
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
from app.core.config import settings
from app.core.dependencies import get_auth_service, get_current_active_user
from app.core.logging import get_logger
from app.models.user import User
from app.schemas.auth import (
UserLoginRequest,
UserRegisterRequest,
UserResponse,
)
from app.services.auth import AuthService
router = APIRouter()
logger = get_logger(__name__)
@router.post(
"/register",
status_code=status.HTTP_201_CREATED,
)
async def register(
request: UserRegisterRequest,
response: Response,
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
"""Register a new user account."""
try:
auth_response = await auth_service.register(request)
# Create and store refresh token - need to get User object from service
user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, # Convert days to seconds
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
# Return only user data, tokens are now in cookies
return auth_response.user
except HTTPException:
raise
except Exception as e:
logger.exception("Registration failed for email: %s", request.email)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Registration failed",
) from e
@router.post("/login")
async def login(
request: UserLoginRequest,
response: Response,
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
"""Authenticate a user and return access token."""
try:
auth_response = await auth_service.login(request)
# Create and store refresh token - need to get User object from service
user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, # Convert days to seconds
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
# Return only user data, tokens are now in cookies
return auth_response.user
except HTTPException:
raise
except Exception as e:
logger.exception("Login failed for email: %s", request.email)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Login failed",
) from e
@router.get("/me")
async def get_current_user_info(
current_user: Annotated[User, Depends(get_current_active_user)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
"""Get current user information."""
try:
return await auth_service.create_user_response(current_user)
except Exception as e:
logger.exception("Failed to get current user info")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve user information",
) from e
@router.post("/refresh")
async def refresh_token(
response: Response,
refresh_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> dict[str, str]:
"""Refresh access token using refresh token."""
try:
if not refresh_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No refresh token provided",
)
# Get new access token
token_response = await auth_service.refresh_access_token(refresh_token)
# Set new access token cookie
response.set_cookie(
key="access_token",
value=token_response.access_token,
max_age=token_response.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
return {"message": "Token refreshed successfully"}
except HTTPException:
raise
except Exception as e:
logger.exception("Token refresh failed")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Token refresh failed",
) from e
@router.post("/logout")
async def logout(
response: Response,
current_user: Annotated[User, Depends(get_current_active_user)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> dict[str, str]:
"""Logout endpoint - clears cookies and revokes refresh token."""
try:
# Revoke refresh token from database
await auth_service.revoke_refresh_token(current_user)
# Clear both cookies
response.delete_cookie(
key="access_token",
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
response.delete_cookie(
key="refresh_token",
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
return {"message": "Successfully logged out"}
except Exception as e:
logger.exception("Logout failed")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Logout failed",
) from e

View File

@@ -13,4 +13,4 @@ logger = get_logger(__name__)
def health() -> dict[str, str]:
"""Health check endpoint."""
logger.info("Health check endpoint accessed")
return {"status": "healthy"}
return {"status": "healthy"}

View File

@@ -1,3 +1,5 @@
from typing import Literal
from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -14,15 +16,27 @@ class Settings(BaseSettings):
HOST: str = "localhost"
PORT: int = 8000
RELOAD: bool = True
LOG_LEVEL: str = "info"
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
DATABASE_ECHO: bool = False
LOG_LEVEL: str = "info"
LOG_FILE: str = "logs/app.log"
LOG_MAX_SIZE: int = 10 * 1024 * 1024
LOG_BACKUP_COUNT: int = 5
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
DATABASE_ECHO: bool = False
# JWT Configuration
JWT_SECRET_KEY: str = (
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
)
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # Shorter-lived access token
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Longer-lived refresh token
# Cookie Configuration
COOKIE_SECURE: bool = True # Set to False for development without HTTPS
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
settings = Settings()

103
app/core/dependencies.py Normal file
View File

@@ -0,0 +1,103 @@
"""FastAPI dependencies."""
from typing import Annotated, NoReturn, cast
from fastapi import Cookie, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.logging import get_logger
from app.models.user import User
from app.services.auth import AuthService
from app.utils.auth import JWTUtils
logger = get_logger(__name__)
def _raise_invalid_token_error() -> NoReturn:
"""Raise an invalid token HTTP exception."""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload",
)
def _raise_auth_error() -> NoReturn:
"""Raise an authentication HTTP exception."""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
)
async def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
"""Get the authentication service."""
return AuthService(session)
async def get_current_user(
access_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> User:
"""Get the current authenticated user from JWT token in HTTP-only cookie."""
try:
# Check if access token cookie exists
if not access_token:
logger.warning("No access token cookie found")
_raise_auth_error()
# Decode the JWT token
payload = JWTUtils.decode_access_token(access_token)
# Extract user ID from token
user_id_str = payload.get("sub")
if not user_id_str:
_raise_invalid_token_error()
# At this point user_id_str is guaranteed to be truthy, safe to cast
user_id_str = cast("str", user_id_str)
try:
user_id = int(user_id_str)
except (ValueError, TypeError) as e:
logger.warning("Invalid user ID in token: %s", user_id_str)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload",
) from e
# Get the user
return await auth_service.get_current_user(user_id)
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception:
logger.exception("Failed to authenticate user")
_raise_auth_error()
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)],
) -> User:
"""Get the current authenticated and active user."""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
return current_user
async def get_admin_user(
current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
"""Get the current authenticated admin user."""
if current_user.role not in ["admin", "superadmin"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions",
)
return current_user

View File

@@ -26,6 +26,8 @@ class User(BaseModel, table=True):
credits: int = Field(default=0, ge=0, nullable=False)
api_token: str | None = Field(unique=True, default=None)
api_token_expires_at: datetime | None = Field(default=None)
refresh_token_hash: str | None = Field(default=None)
refresh_token_expires_at: datetime | None = Field(default=None)
# relationships
oauths: list["UserOauth"] = Relationship(back_populates="user")

134
app/repositories/user.py Normal file
View File

@@ -0,0 +1,134 @@
"""User repository."""
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.plan import Plan
from app.models.user import User
logger = get_logger(__name__)
class UserRepository:
"""Repository for user operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the user repository."""
self.session = session
async def get_by_id(self, user_id: int) -> User | None:
"""Get a user by ID."""
try:
statement = select(User).where(User.id == user_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get user by ID: %s", user_id)
raise
async def get_by_email(self, email: str) -> User | None:
"""Get a user by email address."""
try:
statement = select(User).where(User.email == email)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get user by email: %s", email)
raise
async def get_by_api_token(self, api_token: str) -> User | None:
"""Get a user by API token."""
try:
statement = select(User).where(User.api_token == api_token)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception("Failed to get user by API token")
raise
async def create(self, user_data: dict[str, Any]) -> User:
"""Create a new user."""
def _raise_plan_not_found() -> None:
msg = "Default plan not found"
raise ValueError(msg)
try:
# Check if this is the first user
user_count_statement = select(User)
user_count_result = await self.session.exec(user_count_statement)
is_first_user = user_count_result.first() is None
if is_first_user:
# First user gets admin role and pro plan
plan_statement = select(Plan).where(Plan.code == "pro")
user_data["role"] = "admin"
logger.info("Creating first user with admin role and pro plan")
else:
# Regular users get free plan
plan_statement = select(Plan).where(Plan.code == "free")
plan_result = await self.session.exec(plan_statement)
default_plan = plan_result.first()
if default_plan is None:
_raise_plan_not_found()
# Type assertion to help type checker understand default_plan is not None
assert default_plan is not None # noqa: S101
# Set plan_id and default credits
user_data["plan_id"] = default_plan.id
user_data["credits"] = default_plan.credits
user = User(**user_data)
self.session.add(user)
await self.session.commit()
await self.session.refresh(user)
except Exception:
await self.session.rollback()
logger.exception("Failed to create user")
raise
else:
logger.info("Created new user with email: %s", user.email)
return user
async def update(self, user: User, update_data: dict[str, Any]) -> User:
"""Update a user."""
try:
for field, value in update_data.items():
setattr(user, field, value)
await self.session.commit()
await self.session.refresh(user)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user")
raise
else:
logger.info("Updated user: %s", user.email)
return user
async def delete(self, user: User) -> None:
"""Delete a user."""
try:
await self.session.delete(user)
await self.session.commit()
logger.info("Deleted user: %s", user.email)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user")
raise
async def email_exists(self, email: str) -> bool:
"""Check if an email address is already registered."""
try:
statement = select(User).where(User.email == email)
result = await self.session.exec(statement)
return result.first() is not None
except Exception:
logger.exception("Failed to check if email exists: %s", email)
raise

53
app/schemas/auth.py Normal file
View File

@@ -0,0 +1,53 @@
"""Authentication schemas."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel, EmailStr, Field
class UserRegisterRequest(BaseModel):
"""Schema for user registration request."""
email: EmailStr = Field(..., description="User email address")
password: str = Field(
..., min_length=8, description="User password (minimum 8 characters)",
)
name: str = Field(..., min_length=1, max_length=100, description="User full name")
class UserLoginRequest(BaseModel):
"""Schema for user login request."""
email: EmailStr = Field(..., description="User email address")
password: str = Field(..., description="User password")
class TokenResponse(BaseModel):
"""Schema for authentication token response."""
access_token: str = Field(..., description="JWT access token")
token_type: str = Field(default="bearer", description="Token type")
expires_in: int = Field(..., description="Token expiration time in seconds")
class UserResponse(BaseModel):
"""Schema for user information response."""
id: int = Field(..., description="User ID")
email: str = Field(..., description="User email address")
name: str = Field(..., description="User full name")
picture: str | None = Field(None, description="User profile picture URL")
role: str = Field(..., description="User role")
credits: int = Field(..., description="User credits")
is_active: bool = Field(..., description="Whether user is active")
plan: dict[str, Any] = Field(..., description="User plan information")
created_at: datetime = Field(..., description="User creation timestamp")
updated_at: datetime = Field(..., description="User last update timestamp")
class AuthResponse(BaseModel):
"""Schema for authentication response."""
user: UserResponse = Field(..., description="User information")
token: TokenResponse = Field(..., description="Authentication token")

268
app/services/auth.py Normal file
View File

@@ -0,0 +1,268 @@
"""Authentication service."""
import hashlib
from datetime import UTC, datetime, timedelta
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.models.user import User
from app.repositories.user import UserRepository
from app.schemas.auth import (
AuthResponse,
TokenResponse,
UserLoginRequest,
UserRegisterRequest,
UserResponse,
)
from app.utils.auth import JWTUtils, PasswordUtils
logger = get_logger(__name__)
class AuthService:
"""Service for authentication operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the auth service."""
self.session = session
self.user_repo = UserRepository(session)
async def register(self, request: UserRegisterRequest) -> AuthResponse:
"""Register a new user."""
logger.info("Attempting to register user with email: %s", request.email)
# Check if email already exists
if await self.user_repo.email_exists(request.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email address is already registered",
)
# Hash the password
hashed_password = PasswordUtils.hash_password(request.password)
# Create user data
user_data = {
"email": request.email,
"name": request.name,
"password_hash": hashed_password,
"role": "user",
"is_active": True,
}
# Create the user
user = await self.user_repo.create(user_data)
# Generate access token
token = self._create_access_token(user)
# Create response
user_response = await self.create_user_response(user)
logger.info("Successfully registered user: %s", user.email)
return AuthResponse(user=user_response, token=token)
async def login(self, request: UserLoginRequest) -> AuthResponse:
"""Authenticate a user login."""
logger.info("Attempting to login user with email: %s", request.email)
# Get user by email
user = await self.user_repo.get_by_email(request.email)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
# Check if user is active
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
# Verify password
if not user.password_hash or not PasswordUtils.verify_password(
request.password,
user.password_hash,
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid email or password",
)
# Generate access token
token = self._create_access_token(user)
# Create response
user_response = await self.create_user_response(user)
logger.info("Successfully authenticated user: %s", user.email)
return AuthResponse(user=user_response, token=token)
async def get_current_user(self, user_id: int) -> User:
"""Get the current authenticated user."""
user = await self.user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
return user
def _create_access_token(self, user: User) -> TokenResponse:
"""Create an access token for a user."""
access_token_expires = timedelta(
minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES,
)
token_data = {
"sub": str(user.id),
"email": user.email,
"role": user.role,
}
access_token = JWTUtils.create_access_token(
data=token_data,
expires_delta=access_token_expires,
)
return TokenResponse(
access_token=access_token,
token_type="bearer", # noqa: S106 # This is OAuth2 standard, not a password
expires_in=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
async def create_and_store_refresh_token(self, user: User) -> str:
"""Create and store a refresh token for a user."""
refresh_token_expires = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
token_data = {
"sub": str(user.id),
"email": user.email,
}
refresh_token = JWTUtils.create_refresh_token(
data=token_data,
expires_delta=refresh_token_expires,
)
# Hash the refresh token for storage
refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
# Store hash and expiration in database
user.refresh_token_hash = refresh_token_hash
user.refresh_token_expires_at = datetime.now(UTC) + refresh_token_expires
self.session.add(user)
await self.session.commit()
return refresh_token
async def refresh_access_token(self, refresh_token: str) -> TokenResponse:
"""Create a new access token using a refresh token."""
try:
# Decode the refresh token
payload = JWTUtils.decode_refresh_token(refresh_token)
user_id_str = payload.get("sub")
if not user_id_str:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
user_id = int(user_id_str)
# Get the user
user = await self.user_repo.get_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
# Check if refresh token hash matches stored hash
refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
if (
not user.refresh_token_hash
or user.refresh_token_hash != refresh_token_hash
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
)
# Check if refresh token is expired
if user.refresh_token_expires_at and datetime.now(
UTC
) > user.refresh_token_expires_at.replace(tzinfo=UTC):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has expired",
)
# Check if user is active
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
# Create new access token
return self._create_access_token(user)
except HTTPException:
raise
except Exception as e:
logger.exception("Failed to refresh access token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid refresh token",
) from e
async def revoke_refresh_token(self, user: User) -> None:
"""Revoke a user's refresh token."""
user.refresh_token_hash = None
user.refresh_token_expires_at = None
self.session.add(user)
await self.session.commit()
logger.info("Refresh token revoked for user: %s", user.email)
async def create_user_response(self, user: User) -> UserResponse:
"""Create a user response from a user model."""
# Always refresh to ensure the plan relationship is loaded
await self.session.refresh(user, ["plan"])
# Ensure user has an ID (should always be true for persisted users)
if user.id is None:
msg = "User must have an ID to create response"
raise ValueError(msg)
return UserResponse(
id=user.id,
email=user.email,
name=user.name,
picture=user.picture,
role=user.role,
credits=user.credits,
is_active=user.is_active,
plan={
"id": user.plan.id,
"code": user.plan.code,
"name": user.plan.name,
"description": user.plan.description,
"credits": user.plan.credits,
"max_credits": user.plan.max_credits,
},
created_at=user.created_at,
updated_at=user.updated_at,
)

179
app/utils/auth.py Normal file
View File

@@ -0,0 +1,179 @@
"""Authentication utilities."""
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any
import bcrypt
import jwt
from fastapi import HTTPException, status
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class PasswordUtils:
"""Utility class for password operations."""
@staticmethod
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
salt = bcrypt.gensalt()
hashed = bcrypt.hashpw(password.encode("utf-8"), salt)
return hashed.decode("utf-8")
@staticmethod
def verify_password(password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(password.encode("utf-8"), hashed_password.encode("utf-8"))
class JWTUtils:
"""Utility class for JWT operations."""
@staticmethod
def create_access_token(
data: dict[str, Any],
expires_delta: timedelta | None = None,
) -> str:
"""Create a JWT access token."""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(UTC) + timedelta(
minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES,
)
to_encode.update({"exp": expire})
try:
encoded_jwt = jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM,
)
except Exception as e:
logger.exception("Failed to create JWT token")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not create access token",
) from e
else:
logger.info("JWT token created successfully")
return encoded_jwt
@staticmethod
def create_refresh_token(
data: dict[str, Any],
expires_delta: timedelta | None = None,
) -> str:
"""Create a JWT refresh token."""
to_encode = data.copy()
if expires_delta:
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(UTC) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS,
)
to_encode.update({"exp": expire, "type": "refresh"})
try:
encoded_jwt = jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM,
)
except Exception as e:
logger.exception("Failed to create JWT refresh token")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Could not create refresh token",
) from e
else:
logger.info("JWT refresh token created successfully")
return encoded_jwt
@staticmethod
def decode_access_token(token: str) -> dict[str, Any]:
"""Decode and validate a JWT access token."""
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM],
)
# Ensure this is not a refresh token
if payload.get("type") == "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type",
)
return dict(payload)
except jwt.ExpiredSignatureError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from e
except jwt.PyJWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
) from e
@staticmethod
def decode_refresh_token(token: str) -> dict[str, Any]:
"""Decode and validate a JWT refresh token."""
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM],
)
# Ensure this is a refresh token
if payload.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token type",
)
return dict(payload)
except jwt.ExpiredSignatureError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token has expired",
) from e
except jwt.PyJWTError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate refresh token",
) from e
class TokenUtils:
"""Utility class for API token operations."""
@staticmethod
def generate_api_token() -> str:
"""Generate a secure random API token."""
return secrets.token_urlsafe(32)
@staticmethod
def is_token_expired(expires_at: datetime | None) -> bool:
"""Check if a token is expired."""
if expires_at is None:
return False
# Handle timezone-aware and naive datetimes
if expires_at.tzinfo is None:
# Naive datetime - assume UTC
expires_at = expires_at.replace(tzinfo=UTC)
else:
# Convert to UTC if not already
expires_at = expires_at.astimezone(UTC)
return datetime.now(UTC) > expires_at