331 lines
11 KiB
Python
331 lines
11 KiB
Python
"""Authentication endpoints."""
|
|
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
|
|
from fastapi.responses import RedirectResponse
|
|
|
|
from app.core.config import settings
|
|
from app.core.dependencies import (
|
|
get_auth_service,
|
|
get_current_active_user,
|
|
get_oauth_service,
|
|
)
|
|
from app.core.logging import get_logger
|
|
from app.models.user import User
|
|
from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse
|
|
from app.services.auth import AuthService
|
|
from app.services.oauth import OAuthService
|
|
from app.utils.auth import JWTUtils
|
|
|
|
# Module-level logger for this endpoints module.
logger = get_logger(__name__)
# Router collecting this module's authentication and OAuth endpoints.
router = APIRouter()
|
|
|
|
|
|
# Authentication endpoints
|
|
@router.post(
    "/register",
    status_code=status.HTTP_201_CREATED,
)
async def register(
    request: UserRegisterRequest,
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
    """Create a new account and sign the caller in.

    On success the access token returned by ``auth_service.register`` and a
    newly created-and-stored refresh token are attached to ``response`` as
    HTTP-only cookies, and the registered user's data is returned (HTTP 201).

    Raises:
        HTTPException: re-raised unchanged when the auth service raises one;
            any other failure is logged and surfaced as a 500
            ("Registration failed").
    """

    def _store_cookie(name: str, token_value: str, lifetime: int) -> None:
        # Both auth cookies share the same security flags.
        response.set_cookie(
            key=name,
            value=token_value,
            max_age=lifetime,
            httponly=True,
            secure=settings.COOKIE_SECURE,
            samesite=settings.COOKIE_SAMESITE,
        )

    try:
        registration = await auth_service.register(request)
        new_user = registration.user

        # The refresh-token helper takes the User model rather than the
        # response schema, so load it back through the service by id.
        db_user = await auth_service.get_current_user(new_user.id)
        issued_refresh = await auth_service.create_and_store_refresh_token(db_user)

        _store_cookie(
            "access_token",
            registration.token.access_token,
            registration.token.expires_in,
        )
        # Refresh lifetime is configured in days; cookies expect seconds.
        _store_cookie(
            "refresh_token",
            issued_refresh,
            settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Registration failed for email: %s", request.email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Registration failed",
        ) from exc
    return new_user
|
|
|
|
|
|
@router.post("/login")
async def login(
    request: UserLoginRequest,
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
    """Sign in with credentials and set the session cookies.

    The access token returned by ``auth_service.login`` and a newly
    created-and-stored refresh token are written to ``response`` as
    HTTP-only cookies; the authenticated user's data is returned.

    Raises:
        HTTPException: propagated as-is from the auth service; any other
            error is logged and converted to a 500 ("Login failed").
    """
    try:
        login_result = await auth_service.login(request)

        # Refresh tokens are issued for the User model, so fetch it by id.
        orm_user = await auth_service.get_current_user(login_result.user.id)
        refresh_value = await auth_service.create_and_store_refresh_token(orm_user)

        # Refresh lifetime is configured in days; cookies expect seconds.
        refresh_seconds = settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
        cookie_specs = (
            (
                "access_token",
                login_result.token.access_token,
                login_result.token.expires_in,
            ),
            ("refresh_token", refresh_value, refresh_seconds),
        )
        for cookie_name, cookie_value, lifetime in cookie_specs:
            response.set_cookie(
                key=cookie_name,
                value=cookie_value,
                max_age=lifetime,
                httponly=True,
                secure=settings.COOKIE_SECURE,
                samesite=settings.COOKIE_SAMESITE,
            )
    except HTTPException:
        raise
    except Exception as err:
        logger.exception("Login failed for email: %s", request.email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Login failed",
        ) from err
    return login_result.user
|
|
|
|
|
|
@router.get("/me")
async def get_current_user_info(
    current_user: Annotated[User, Depends(get_current_active_user)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
    """Return the public representation of the authenticated user.

    Args:
        current_user: The user resolved by ``get_current_active_user``.
        auth_service: Service used to build the ``UserResponse``.

    Raises:
        HTTPException: re-raised unchanged if the service raises one; any
            other failure is logged and surfaced as a 500 ("Failed to
            retrieve user information").
    """
    try:
        return await auth_service.create_user_response(current_user)
    except HTTPException:
        # Keep deliberate HTTP errors (status + detail) intact instead of
        # masking them as 500s, consistent with the other endpoints here.
        raise
    except Exception as e:
        logger.exception("Failed to get current user info")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve user information",
        ) from e
|
|
|
|
|
|
@router.post("/refresh")
async def refresh_token(
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
    """Issue a new access token from the ``refresh_token`` cookie.

    The new access token is written back to ``response`` as an HTTP-only
    ``access_token`` cookie.

    Args:
        response: Response whose cookies are updated.
        auth_service: Service that validates the refresh token and mints the
            new access token.
        refresh_token: Value of the ``refresh_token`` cookie. It defaults to
            ``None`` so that a missing cookie reaches the 401 branch below.
            Without a default, FastAPI treats an ``Annotated`` parameter as
            required and rejects the request with a 422 first. The parameter
            is placed last because Python requires defaulted parameters to
            follow non-defaulted ones; FastAPI binds by name.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 401 when no refresh token cookie is present; errors
            raised by the auth service pass through unchanged; any other
            failure is logged and surfaced as a 500 ("Token refresh failed").
    """
    if not refresh_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="No refresh token provided",
        )

    try:
        token_response = await auth_service.refresh_access_token(refresh_token)

        response.set_cookie(
            key="access_token",
            value=token_response.access_token,
            max_age=token_response.expires_in,
            httponly=True,
            secure=settings.COOKIE_SECURE,
            samesite=settings.COOKIE_SAMESITE,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Token refresh failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Token refresh failed",
        ) from e

    return {"message": "Token refreshed successfully"}
|
|
|
|
|
|
@router.post("/logout")
async def logout(
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    access_token: Annotated[str | None, Cookie()] = None,
    refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
    """Log the caller out: revoke their refresh token if identifiable, clear cookies.

    The user is identified from the ``access_token`` cookie first, falling
    back to the ``refresh_token`` cookie. Token problems (missing, expired,
    malformed, unknown user) are logged at info level and do not stop the
    logout. Both auth cookies are always deleted on the response.

    Args:
        response: Response whose auth cookies are deleted.
        auth_service: Service used to load the user and revoke the token.
        access_token: ``access_token`` cookie value, or ``None`` when absent.
        refresh_token: ``refresh_token`` cookie value, or ``None`` when absent.

    The cookie parameters default to ``None`` so a request without cookies
    still logs out. Without a default, FastAPI treats them as required and
    answers 422. They are placed after ``auth_service`` because defaulted
    parameters must come last; FastAPI binds them by name.

    Returns:
        A confirmation message.
    """
    user = None

    # Try to get user from access token first.
    if access_token:
        try:
            payload = JWTUtils.decode_access_token(access_token)
            user_id_str = payload.get("sub")
            if user_id_str:
                user = await auth_service.get_current_user(int(user_id_str))
                logger.info("Found user from access token: %s", user.email)
        except Exception as e:  # includes HTTPException (an Exception subclass)
            logger.info("Access token validation failed: %s", str(e))

    # If no user found, try refresh token.
    if not user and refresh_token:
        try:
            payload = JWTUtils.decode_refresh_token(refresh_token)
            user_id_str = payload.get("sub")
            if user_id_str:
                user = await auth_service.get_current_user(int(user_id_str))
                logger.info("Found user from refresh token: %s", user.email)
        except Exception as e:  # includes HTTPException (an Exception subclass)
            logger.info("Refresh token validation failed: %s", str(e))

    # If we found a user, revoke their refresh token.
    if user:
        await auth_service.revoke_refresh_token(user)
        logger.info("Successfully revoked refresh token for user: %s", user.email)
    else:
        logger.info("No user found, skipping token revocation")

    # Always clear both cookies regardless of token validity. The flags
    # mirror those used when the cookies were set.
    response.delete_cookie(
        key="access_token",
        httponly=True,
        secure=settings.COOKIE_SECURE,
        samesite=settings.COOKIE_SAMESITE,
    )
    response.delete_cookie(
        key="refresh_token",
        httponly=True,
        secure=settings.COOKIE_SECURE,
        samesite=settings.COOKIE_SAMESITE,
    )

    return {"message": "Successfully logged out"}
|
|
|
|
|
|
# OAuth2 endpoints
|
|
@router.get("/{provider}/authorize")
async def oauth_authorize(
    provider: str,
    oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
) -> dict[str, str]:
    """Start an OAuth login by building the provider's authorization URL.

    Returns:
        ``{"authorization_url": ..., "state": ...}``. ``state`` is the value
        produced by ``oauth_service.generate_state()`` and passed to
        ``get_authorization_url``.

    Raises:
        HTTPException: passed through from the OAuth service; any other
            failure is logged and surfaced as a 500
            ("OAuth authorization failed").

    NOTE(review): ``state`` is returned to the client but not persisted by
    this endpoint. ``oauth_callback`` in this module neither accepts nor
    verifies it, so OAuth CSRF protection does not appear to be enforced
    here. Confirm whether it is checked elsewhere.
    """
    try:
        csrf_state = oauth_service.generate_state()
        result = {
            "authorization_url": oauth_service.get_authorization_url(
                provider, csrf_state
            ),
            "state": csrf_state,
        }
    except HTTPException:
        raise
    except Exception as err:
        logger.exception("OAuth authorization failed for provider: %s", provider)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="OAuth authorization failed",
        ) from err
    return result
|
|
|
|
|
|
@router.get("/{provider}/callback")
async def oauth_callback(
    provider: str,
    response: Response,
    code: Annotated[str, Query()],
    oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> RedirectResponse:
    """Finish an OAuth login and redirect the browser back to the frontend.

    Exchanges ``code`` via the OAuth service, links or creates the user via
    ``auth_service.oauth_login``, issues a refresh token, and returns a 302
    redirect that carries both tokens as HTTP-only cookies.

    Args:
        provider: OAuth provider name from the path.
        response: Unused; kept only for signature compatibility. When an
            endpoint returns a ``Response`` object directly, FastAPI ignores
            cookies and headers set on the injected ``Response``. The cookies
            must therefore be set on the ``RedirectResponse`` that is returned.
        code: Authorization code from the provider's redirect.
        oauth_service: Service that completes the provider exchange.
        auth_service: Service that logs the user in and issues tokens.

    Raises:
        HTTPException: passed through from the services; any other failure
            is logged and surfaced as a 500 ("OAuth callback failed").
    """
    try:
        # Handle OAuth callback and get user info.
        oauth_user_info = await oauth_service.handle_callback(provider, code)

        # Perform OAuth login (link or create user).
        auth_response = await auth_service.oauth_login(oauth_user_info)

        # Create and store refresh token (needs the User model, loaded by id).
        user = await auth_service.get_current_user(auth_response.user.id)
        refresh_token = await auth_service.create_and_store_refresh_token(user)

        # NOTE(review): hard-coded development frontend URL; consider making
        # this configurable via settings.
        redirect = RedirectResponse(
            url="http://localhost:8001/?auth=success",
            status_code=302,
        )

        # Set the cookies on the response that is actually returned.
        redirect.set_cookie(
            key="access_token",
            value=auth_response.token.access_token,
            max_age=auth_response.token.expires_in,
            httponly=True,
            secure=settings.COOKIE_SECURE,
            samesite=settings.COOKIE_SAMESITE,
        )
        redirect.set_cookie(
            key="refresh_token",
            value=refresh_token,
            max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
            httponly=True,
            secure=settings.COOKIE_SECURE,
            samesite=settings.COOKIE_SAMESITE,
        )

        logger.info(
            "OAuth login successful for user: %s via %s",
            auth_response.user.email,
            provider,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("OAuth callback failed for provider: %s", provider)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="OAuth callback failed",
        ) from e

    return redirect
|
|
|
|
|
@router.get("/providers")
async def get_oauth_providers() -> dict[str, list[str]]:
    """List the OAuth provider names this API advertises."""
    supported = ["google", "github"]
    return {"providers": supported}
|