Add tests for authentication and utilities, and update dependencies
- Created a new test package for services and added tests for AuthService. - Implemented tests for user registration, login, and token creation. - Added a new test package for utilities and included tests for password and JWT utilities. - Updated `uv.lock` to include new dependencies: bcrypt, email-validator, pyjwt, and pytest-asyncio.
This commit is contained in:
@@ -2,10 +2,11 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import main
|
||||
from app.api.v1 import auth, main
|
||||
|
||||
# V1 API router with v1 prefix
|
||||
api_router = APIRouter(prefix="/v1")
|
||||
|
||||
# Include all route modules
|
||||
api_router.include_router(main.router, tags=["main"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
||||
|
||||
198
app/api/v1/auth.py
Normal file
198
app/api/v1/auth.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""Authentication endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.dependencies import get_auth_service, get_current_active_user
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.schemas.auth import (
|
||||
UserLoginRequest,
|
||||
UserRegisterRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from app.services.auth import AuthService
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/register",
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def register(
|
||||
request: UserRegisterRequest,
|
||||
response: Response,
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> UserResponse:
|
||||
"""Register a new user account."""
|
||||
try:
|
||||
auth_response = await auth_service.register(request)
|
||||
|
||||
# Create and store refresh token - need to get User object from service
|
||||
user = await auth_service.get_current_user(auth_response.user.id)
|
||||
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||
|
||||
# Set HTTP-only cookies for both tokens
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=auth_response.token.access_token,
|
||||
max_age=auth_response.token.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=refresh_token,
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, # Convert days to seconds
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
# Return only user data, tokens are now in cookies
|
||||
return auth_response.user
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Registration failed for email: %s", request.email)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Registration failed",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
request: UserLoginRequest,
|
||||
response: Response,
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> UserResponse:
|
||||
"""Authenticate a user and return access token."""
|
||||
try:
|
||||
auth_response = await auth_service.login(request)
|
||||
|
||||
# Create and store refresh token - need to get User object from service
|
||||
user = await auth_service.get_current_user(auth_response.user.id)
|
||||
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||
|
||||
# Set HTTP-only cookies for both tokens
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=auth_response.token.access_token,
|
||||
max_age=auth_response.token.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=refresh_token,
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, # Convert days to seconds
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
# Return only user data, tokens are now in cookies
|
||||
return auth_response.user
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Login failed for email: %s", request.email)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_current_user_info(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> UserResponse:
|
||||
"""Get current user information."""
|
||||
try:
|
||||
return await auth_service.create_user_response(current_user)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get current user info")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve user information",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token(
|
||||
response: Response,
|
||||
refresh_token: Annotated[str | None, Cookie()],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Refresh access token using refresh token."""
|
||||
try:
|
||||
if not refresh_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No refresh token provided",
|
||||
)
|
||||
|
||||
# Get new access token
|
||||
token_response = await auth_service.refresh_access_token(refresh_token)
|
||||
|
||||
# Set new access token cookie
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token_response.access_token,
|
||||
max_age=token_response.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
return {"message": "Token refreshed successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Token refresh failed")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Token refresh failed",
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
response: Response,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Logout endpoint - clears cookies and revokes refresh token."""
|
||||
try:
|
||||
# Revoke refresh token from database
|
||||
await auth_service.revoke_refresh_token(current_user)
|
||||
|
||||
# Clear both cookies
|
||||
response.delete_cookie(
|
||||
key="access_token",
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
response.delete_cookie(
|
||||
key="refresh_token",
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
return {"message": "Successfully logged out"}
|
||||
except Exception as e:
|
||||
logger.exception("Logout failed")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Logout failed",
|
||||
) from e
|
||||
@@ -13,4 +13,4 @@ logger = get_logger(__name__)
|
||||
def health() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
logger.info("Health check endpoint accessed")
|
||||
return {"status": "healthy"}
|
||||
return {"status": "healthy"}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
@@ -14,15 +16,27 @@ class Settings(BaseSettings):
|
||||
HOST: str = "localhost"
|
||||
PORT: int = 8000
|
||||
RELOAD: bool = True
|
||||
LOG_LEVEL: str = "info"
|
||||
|
||||
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
|
||||
DATABASE_ECHO: bool = False
|
||||
|
||||
LOG_LEVEL: str = "info"
|
||||
LOG_FILE: str = "logs/app.log"
|
||||
LOG_MAX_SIZE: int = 10 * 1024 * 1024
|
||||
LOG_BACKUP_COUNT: int = 5
|
||||
LOG_FORMAT: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
|
||||
DATABASE_ECHO: bool = False
|
||||
# JWT Configuration
|
||||
JWT_SECRET_KEY: str = (
|
||||
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
|
||||
)
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # Shorter-lived access token
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Longer-lived refresh token
|
||||
|
||||
# Cookie Configuration
|
||||
COOKIE_SECURE: bool = True # Set to False for development without HTTPS
|
||||
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
103
app/core/dependencies.py
Normal file
103
app/core/dependencies.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""FastAPI dependencies."""
|
||||
|
||||
from typing import Annotated, NoReturn, cast
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.services.auth import AuthService
|
||||
from app.utils.auth import JWTUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _raise_invalid_token_error() -> NoReturn:
|
||||
"""Raise an invalid token HTTP exception."""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
)
|
||||
|
||||
|
||||
def _raise_auth_error() -> NoReturn:
|
||||
"""Raise an authentication HTTP exception."""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
|
||||
|
||||
async def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
"""Get the authentication service."""
|
||||
return AuthService(session)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
access_token: Annotated[str | None, Cookie()],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> User:
|
||||
"""Get the current authenticated user from JWT token in HTTP-only cookie."""
|
||||
try:
|
||||
# Check if access token cookie exists
|
||||
if not access_token:
|
||||
logger.warning("No access token cookie found")
|
||||
_raise_auth_error()
|
||||
|
||||
# Decode the JWT token
|
||||
payload = JWTUtils.decode_access_token(access_token)
|
||||
|
||||
# Extract user ID from token
|
||||
user_id_str = payload.get("sub")
|
||||
if not user_id_str:
|
||||
_raise_invalid_token_error()
|
||||
|
||||
# At this point user_id_str is guaranteed to be truthy, safe to cast
|
||||
user_id_str = cast("str", user_id_str)
|
||||
|
||||
try:
|
||||
user_id = int(user_id_str)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning("Invalid user ID in token: %s", user_id_str)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
) from e
|
||||
|
||||
# Get the user
|
||||
return await auth_service.get_current_user(user_id)
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to authenticate user")
|
||||
_raise_auth_error()
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
"""Get the current authenticated and active user."""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_admin_user(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> User:
|
||||
"""Get the current authenticated admin user."""
|
||||
if current_user.role not in ["admin", "superadmin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Not enough permissions",
|
||||
)
|
||||
return current_user
|
||||
@@ -26,6 +26,8 @@ class User(BaseModel, table=True):
|
||||
credits: int = Field(default=0, ge=0, nullable=False)
|
||||
api_token: str | None = Field(unique=True, default=None)
|
||||
api_token_expires_at: datetime | None = Field(default=None)
|
||||
refresh_token_hash: str | None = Field(default=None)
|
||||
refresh_token_expires_at: datetime | None = Field(default=None)
|
||||
|
||||
# relationships
|
||||
oauths: list["UserOauth"] = Relationship(back_populates="user")
|
||||
|
||||
134
app/repositories/user.py
Normal file
134
app/repositories/user.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""User repository."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.plan import Plan
|
||||
from app.models.user import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Repository for user operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the user repository."""
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, user_id: int) -> User | None:
|
||||
"""Get a user by ID."""
|
||||
try:
|
||||
statement = select(User).where(User.id == user_id)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get user by ID: %s", user_id)
|
||||
raise
|
||||
|
||||
async def get_by_email(self, email: str) -> User | None:
|
||||
"""Get a user by email address."""
|
||||
try:
|
||||
statement = select(User).where(User.email == email)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get user by email: %s", email)
|
||||
raise
|
||||
|
||||
async def get_by_api_token(self, api_token: str) -> User | None:
|
||||
"""Get a user by API token."""
|
||||
try:
|
||||
statement = select(User).where(User.api_token == api_token)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception("Failed to get user by API token")
|
||||
raise
|
||||
|
||||
async def create(self, user_data: dict[str, Any]) -> User:
|
||||
"""Create a new user."""
|
||||
def _raise_plan_not_found() -> None:
|
||||
msg = "Default plan not found"
|
||||
raise ValueError(msg)
|
||||
|
||||
try:
|
||||
# Check if this is the first user
|
||||
user_count_statement = select(User)
|
||||
user_count_result = await self.session.exec(user_count_statement)
|
||||
is_first_user = user_count_result.first() is None
|
||||
|
||||
if is_first_user:
|
||||
# First user gets admin role and pro plan
|
||||
plan_statement = select(Plan).where(Plan.code == "pro")
|
||||
user_data["role"] = "admin"
|
||||
logger.info("Creating first user with admin role and pro plan")
|
||||
else:
|
||||
# Regular users get free plan
|
||||
plan_statement = select(Plan).where(Plan.code == "free")
|
||||
|
||||
plan_result = await self.session.exec(plan_statement)
|
||||
default_plan = plan_result.first()
|
||||
|
||||
if default_plan is None:
|
||||
_raise_plan_not_found()
|
||||
|
||||
# Type assertion to help type checker understand default_plan is not None
|
||||
assert default_plan is not None # noqa: S101
|
||||
|
||||
# Set plan_id and default credits
|
||||
user_data["plan_id"] = default_plan.id
|
||||
user_data["credits"] = default_plan.credits
|
||||
|
||||
user = User(**user_data)
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(user)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create user")
|
||||
raise
|
||||
else:
|
||||
logger.info("Created new user with email: %s", user.email)
|
||||
return user
|
||||
|
||||
async def update(self, user: User, update_data: dict[str, Any]) -> User:
|
||||
"""Update a user."""
|
||||
try:
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
await self.session.commit()
|
||||
await self.session.refresh(user)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update user")
|
||||
raise
|
||||
else:
|
||||
logger.info("Updated user: %s", user.email)
|
||||
return user
|
||||
|
||||
async def delete(self, user: User) -> None:
|
||||
"""Delete a user."""
|
||||
try:
|
||||
await self.session.delete(user)
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Deleted user: %s", user.email)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete user")
|
||||
raise
|
||||
|
||||
async def email_exists(self, email: str) -> bool:
|
||||
"""Check if an email address is already registered."""
|
||||
try:
|
||||
statement = select(User).where(User.email == email)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first() is not None
|
||||
except Exception:
|
||||
logger.exception("Failed to check if email exists: %s", email)
|
||||
raise
|
||||
53
app/schemas/auth.py
Normal file
53
app/schemas/auth.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Authentication schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class UserRegisterRequest(BaseModel):
|
||||
"""Schema for user registration request."""
|
||||
|
||||
email: EmailStr = Field(..., description="User email address")
|
||||
password: str = Field(
|
||||
..., min_length=8, description="User password (minimum 8 characters)",
|
||||
)
|
||||
name: str = Field(..., min_length=1, max_length=100, description="User full name")
|
||||
|
||||
|
||||
class UserLoginRequest(BaseModel):
|
||||
"""Schema for user login request."""
|
||||
|
||||
email: EmailStr = Field(..., description="User email address")
|
||||
password: str = Field(..., description="User password")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Schema for authentication token response."""
|
||||
|
||||
access_token: str = Field(..., description="JWT access token")
|
||||
token_type: str = Field(default="bearer", description="Token type")
|
||||
expires_in: int = Field(..., description="Token expiration time in seconds")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Schema for user information response."""
|
||||
|
||||
id: int = Field(..., description="User ID")
|
||||
email: str = Field(..., description="User email address")
|
||||
name: str = Field(..., description="User full name")
|
||||
picture: str | None = Field(None, description="User profile picture URL")
|
||||
role: str = Field(..., description="User role")
|
||||
credits: int = Field(..., description="User credits")
|
||||
is_active: bool = Field(..., description="Whether user is active")
|
||||
plan: dict[str, Any] = Field(..., description="User plan information")
|
||||
created_at: datetime = Field(..., description="User creation timestamp")
|
||||
updated_at: datetime = Field(..., description="User last update timestamp")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""Schema for authentication response."""
|
||||
|
||||
user: UserResponse = Field(..., description="User information")
|
||||
token: TokenResponse = Field(..., description="Authentication token")
|
||||
268
app/services/auth.py
Normal file
268
app/services/auth.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Authentication service."""
|
||||
|
||||
import hashlib
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.auth import (
|
||||
AuthResponse,
|
||||
TokenResponse,
|
||||
UserLoginRequest,
|
||||
UserRegisterRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from app.utils.auth import JWTUtils, PasswordUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for authentication operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the auth service."""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
|
||||
async def register(self, request: UserRegisterRequest) -> AuthResponse:
|
||||
"""Register a new user."""
|
||||
logger.info("Attempting to register user with email: %s", request.email)
|
||||
|
||||
# Check if email already exists
|
||||
if await self.user_repo.email_exists(request.email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email address is already registered",
|
||||
)
|
||||
|
||||
# Hash the password
|
||||
hashed_password = PasswordUtils.hash_password(request.password)
|
||||
|
||||
# Create user data
|
||||
user_data = {
|
||||
"email": request.email,
|
||||
"name": request.name,
|
||||
"password_hash": hashed_password,
|
||||
"role": "user",
|
||||
"is_active": True,
|
||||
}
|
||||
|
||||
# Create the user
|
||||
user = await self.user_repo.create(user_data)
|
||||
|
||||
# Generate access token
|
||||
token = self._create_access_token(user)
|
||||
|
||||
# Create response
|
||||
user_response = await self.create_user_response(user)
|
||||
|
||||
logger.info("Successfully registered user: %s", user.email)
|
||||
return AuthResponse(user=user_response, token=token)
|
||||
|
||||
async def login(self, request: UserLoginRequest) -> AuthResponse:
|
||||
"""Authenticate a user login."""
|
||||
logger.info("Attempting to login user with email: %s", request.email)
|
||||
|
||||
# Get user by email
|
||||
user = await self.user_repo.get_by_email(request.email)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
# Verify password
|
||||
if not user.password_hash or not PasswordUtils.verify_password(
|
||||
request.password,
|
||||
user.password_hash,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
# Generate access token
|
||||
token = self._create_access_token(user)
|
||||
|
||||
# Create response
|
||||
user_response = await self.create_user_response(user)
|
||||
|
||||
logger.info("Successfully authenticated user: %s", user.email)
|
||||
return AuthResponse(user=user_response, token=token)
|
||||
|
||||
async def get_current_user(self, user_id: int) -> User:
|
||||
"""Get the current authenticated user."""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
def _create_access_token(self, user: User) -> TokenResponse:
|
||||
"""Create an access token for a user."""
|
||||
access_token_expires = timedelta(
|
||||
minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
)
|
||||
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
access_token = JWTUtils.create_access_token(
|
||||
data=token_data,
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="bearer", # noqa: S106 # This is OAuth2 standard, not a password
|
||||
expires_in=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
)
|
||||
|
||||
async def create_and_store_refresh_token(self, user: User) -> str:
|
||||
"""Create and store a refresh token for a user."""
|
||||
refresh_token_expires = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"email": user.email,
|
||||
}
|
||||
|
||||
refresh_token = JWTUtils.create_refresh_token(
|
||||
data=token_data,
|
||||
expires_delta=refresh_token_expires,
|
||||
)
|
||||
|
||||
# Hash the refresh token for storage
|
||||
refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||
|
||||
# Store hash and expiration in database
|
||||
user.refresh_token_hash = refresh_token_hash
|
||||
user.refresh_token_expires_at = datetime.now(UTC) + refresh_token_expires
|
||||
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
|
||||
return refresh_token
|
||||
|
||||
async def refresh_access_token(self, refresh_token: str) -> TokenResponse:
|
||||
"""Create a new access token using a refresh token."""
|
||||
try:
|
||||
# Decode the refresh token
|
||||
payload = JWTUtils.decode_refresh_token(refresh_token)
|
||||
user_id_str = payload.get("sub")
|
||||
if not user_id_str:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
user_id = int(user_id_str)
|
||||
|
||||
# Get the user
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
|
||||
# Check if refresh token hash matches stored hash
|
||||
refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest()
|
||||
if (
|
||||
not user.refresh_token_hash
|
||||
or user.refresh_token_hash != refresh_token_hash
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
|
||||
# Check if refresh token is expired
|
||||
if user.refresh_token_expires_at and datetime.now(
|
||||
UTC
|
||||
) > user.refresh_token_expires_at.replace(tzinfo=UTC):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has expired",
|
||||
)
|
||||
|
||||
# Check if user is active
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
return self._create_access_token(user)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to refresh access token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
) from e
|
||||
|
||||
async def revoke_refresh_token(self, user: User) -> None:
|
||||
"""Revoke a user's refresh token."""
|
||||
user.refresh_token_hash = None
|
||||
user.refresh_token_expires_at = None
|
||||
self.session.add(user)
|
||||
await self.session.commit()
|
||||
logger.info("Refresh token revoked for user: %s", user.email)
|
||||
|
||||
async def create_user_response(self, user: User) -> UserResponse:
|
||||
"""Create a user response from a user model."""
|
||||
# Always refresh to ensure the plan relationship is loaded
|
||||
await self.session.refresh(user, ["plan"])
|
||||
|
||||
# Ensure user has an ID (should always be true for persisted users)
|
||||
if user.id is None:
|
||||
msg = "User must have an ID to create response"
|
||||
raise ValueError(msg)
|
||||
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
picture=user.picture,
|
||||
role=user.role,
|
||||
credits=user.credits,
|
||||
is_active=user.is_active,
|
||||
plan={
|
||||
"id": user.plan.id,
|
||||
"code": user.plan.code,
|
||||
"name": user.plan.name,
|
||||
"description": user.plan.description,
|
||||
"credits": user.plan.credits,
|
||||
"max_credits": user.plan.max_credits,
|
||||
},
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
179
app/utils/auth.py
Normal file
179
app/utils/auth.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Authentication utilities."""
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PasswordUtils:
|
||||
"""Utility class for password operations."""
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed = bcrypt.hashpw(password.encode("utf-8"), salt)
|
||||
return hashed.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def verify_password(password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
return bcrypt.checkpw(password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||
|
||||
|
||||
class JWTUtils:
|
||||
"""Utility class for JWT operations."""
|
||||
|
||||
@staticmethod
|
||||
def create_access_token(
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
"""Create a JWT access token."""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(
|
||||
minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
)
|
||||
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
try:
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create JWT token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Could not create access token",
|
||||
) from e
|
||||
else:
|
||||
logger.info("JWT token created successfully")
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
|
||||
def create_refresh_token(
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
"""Create a JWT refresh token."""
|
||||
to_encode = data.copy()
|
||||
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS,
|
||||
)
|
||||
|
||||
to_encode.update({"exp": expire, "type": "refresh"})
|
||||
|
||||
try:
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create JWT refresh token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Could not create refresh token",
|
||||
) from e
|
||||
else:
|
||||
logger.info("JWT refresh token created successfully")
|
||||
return encoded_jwt
|
||||
|
||||
@staticmethod
|
||||
def decode_access_token(token: str) -> dict[str, Any]:
|
||||
"""Decode and validate a JWT access token."""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
)
|
||||
# Ensure this is not a refresh token
|
||||
if payload.get("type") == "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
)
|
||||
return dict(payload)
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token has expired",
|
||||
) from e
|
||||
except jwt.PyJWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
) from e
|
||||
|
||||
@staticmethod
|
||||
def decode_refresh_token(token: str) -> dict[str, Any]:
|
||||
"""Decode and validate a JWT refresh token."""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM],
|
||||
)
|
||||
# Ensure this is a refresh token
|
||||
if payload.get("type") != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token type",
|
||||
)
|
||||
return dict(payload)
|
||||
except jwt.ExpiredSignatureError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Refresh token has expired",
|
||||
) from e
|
||||
except jwt.PyJWTError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate refresh token",
|
||||
) from e
|
||||
|
||||
|
||||
class TokenUtils:
|
||||
"""Utility class for API token operations."""
|
||||
|
||||
@staticmethod
|
||||
def generate_api_token() -> str:
|
||||
"""Generate a secure random API token."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
@staticmethod
|
||||
def is_token_expired(expires_at: datetime | None) -> bool:
|
||||
"""Check if a token is expired."""
|
||||
if expires_at is None:
|
||||
return False
|
||||
|
||||
# Handle timezone-aware and naive datetimes
|
||||
if expires_at.tzinfo is None:
|
||||
# Naive datetime - assume UTC
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
else:
|
||||
# Convert to UTC if not already
|
||||
expires_at = expires_at.astimezone(UTC)
|
||||
|
||||
return datetime.now(UTC) > expires_at
|
||||
Reference in New Issue
Block a user