feat: Enhance OAuth2 flow with temporary code exchange and update cookie handling

This commit is contained in:
JSC
2025-07-26 18:31:40 +02:00
parent 98e36b067d
commit 0f605d7ed1
6 changed files with 127 additions and 24 deletions

View File

@@ -1,6 +1,8 @@
"""Authentication endpoints."""
from typing import Annotated
import secrets
import time
from typing import Annotated, Any
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
from fastapi.responses import RedirectResponse
@@ -21,6 +23,9 @@ from app.utils.auth import JWTUtils
router = APIRouter()
logger = get_logger(__name__)
# Global temporary storage for OAuth codes (in production, use Redis with TTL)
_temp_oauth_codes: dict[str, dict[str, Any]] = {}
# Authentication endpoints
@router.post(
@@ -48,6 +53,7 @@ async def register(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
response.set_cookie(
key="refresh_token",
@@ -59,6 +65,7 @@ async def register(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
except HTTPException:
@@ -95,6 +102,7 @@ async def login(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
response.set_cookie(
key="refresh_token",
@@ -106,6 +114,7 @@ async def login(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
except HTTPException:
@@ -139,8 +148,8 @@ async def get_current_user_info(
@router.post("/refresh")
async def refresh_token(
response: Response,
refresh_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
"""Refresh access token using refresh token."""
try:
@@ -161,6 +170,7 @@ async def refresh_token(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
except HTTPException:
@@ -178,9 +188,9 @@ async def refresh_token(
@router.post("/logout")
async def logout(
response: Response,
access_token: Annotated[str | None, Cookie()],
refresh_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
access_token: Annotated[str | None, Cookie()] = None,
refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
"""Logout endpoint - clears cookies and revokes refresh token."""
user = None
@@ -222,12 +232,14 @@ async def logout(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
)
response.delete_cookie(
key="refresh_token",
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
)
return {"message": "Successfully logged out"}
@@ -272,6 +284,8 @@ async def oauth_callback(
) -> RedirectResponse:
"""Handle OAuth callback."""
try:
logger.info("OAuth callback started for provider: %s", provider)
# Handle OAuth callback and get user info
oauth_user_info = await oauth_service.handle_callback(provider, code)
@@ -282,7 +296,9 @@ async def oauth_callback(
user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
# Set HTTP-only cookies for both tokens (not used due to cross-port issues)
# These cookies are kept for potential future same-origin scenarios
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
@@ -290,6 +306,8 @@ async def oauth_callback(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
path="/", # Ensure cookie is available for all paths
)
response.set_cookie(
key="refresh_token",
@@ -298,6 +316,8 @@ async def oauth_callback(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
path="/", # Ensure cookie is available for all paths
)
logger.info(
@@ -306,9 +326,26 @@ async def oauth_callback(
provider,
)
# Redirect back to frontend after successful authentication
# Instead of setting cookies that won't work across ports,
# let's redirect to a special frontend endpoint that can make an
# immediate API call. Frontend will call /exchange-oauth-token with this code
temp_code = secrets.token_urlsafe(32)
# Store the mapping temporarily (in production, use Redis with TTL)
# For now, store in memory with the user data
_temp_oauth_codes[temp_code] = {
"user_id": auth_response.user.id,
"access_token": auth_response.token.access_token,
"refresh_token": refresh_token,
"expires_in": auth_response.token.expires_in,
"created_at": time.time(),
}
redirect_url = f"http://localhost:8001/auth/callback?code={temp_code}"
logger.info("Redirecting to: %s", redirect_url)
return RedirectResponse(
url="http://localhost:8001/?auth=success",
url=redirect_url,
status_code=302,
)
@@ -328,3 +365,64 @@ async def get_oauth_providers() -> dict[str, list[str]]:
return {
"providers": ["google", "github"],
}
@router.post("/exchange-oauth-token")
async def exchange_oauth_token(
request: dict[str, str],
response: Response,
) -> dict[str, str]:
"""Exchange temporary OAuth code for proper auth cookies."""
code = request.get("code")
if not code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Code parameter is required",
)
logger.info("OAuth token exchange requested with code: %s", code[:10] + "...")
# Get the stored token data
if code not in _temp_oauth_codes:
logger.error("Invalid or expired OAuth code: %s", code[:10] + "...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired OAuth code",
)
token_data = _temp_oauth_codes.pop(code) # Remove after use
# Check if code is too old (5 minutes max)
code_expiry_seconds = 300
if time.time() - token_data["created_at"] > code_expiry_seconds:
logger.error("OAuth code expired: %s", code[:10] + "...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth code expired",
)
# Set the proper auth cookies
response.set_cookie(
key="access_token",
value=token_data["access_token"],
max_age=token_data["expires_in"],
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
path="/",
)
response.set_cookie(
key="refresh_token",
value=token_data["refresh_token"],
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
path="/",
)
user_id = token_data["user_id"]
logger.info("OAuth tokens exchanged successfully for user: %s", user_id)
return {"message": "Tokens set successfully", "user_id": str(user_id)}