feat: Enhance OAuth2 flow with temporary code exchange and update cookie handling

This commit is contained in:
JSC
2025-07-26 18:31:40 +02:00
parent 98e36b067d
commit 0f605d7ed1
6 changed files with 127 additions and 24 deletions

View File

@@ -20,10 +20,10 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
# Cookie Configuration
COOKIE_SECURE=false
COOKIE_SAMESITE=lax
# OAuth2 Configuration
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=
GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET=
OAUTH_REDIRECT_URL=http://localhost:8001/auth/callback

View File

@@ -1,6 +1,8 @@
"""Authentication endpoints."""
from typing import Annotated
import secrets
import time
from typing import Annotated, Any
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
from fastapi.responses import RedirectResponse
@@ -21,6 +23,9 @@ from app.utils.auth import JWTUtils
router = APIRouter()
logger = get_logger(__name__)
# Global temporary storage for OAuth codes (in production, use Redis with TTL)
_temp_oauth_codes: dict[str, dict[str, Any]] = {}
# Authentication endpoints
@router.post(
@@ -48,6 +53,7 @@ async def register(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
response.set_cookie(
key="refresh_token",
@@ -59,6 +65,7 @@ async def register(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
except HTTPException:
@@ -95,6 +102,7 @@ async def login(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
response.set_cookie(
key="refresh_token",
@@ -106,6 +114,7 @@ async def login(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
except HTTPException:
@@ -139,8 +148,8 @@ async def get_current_user_info(
@router.post("/refresh")
async def refresh_token(
response: Response,
refresh_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
"""Refresh access token using refresh token."""
try:
@@ -161,6 +170,7 @@ async def refresh_token(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
)
except HTTPException:
@@ -178,9 +188,9 @@ async def refresh_token(
@router.post("/logout")
async def logout(
response: Response,
access_token: Annotated[str | None, Cookie()],
refresh_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
access_token: Annotated[str | None, Cookie()] = None,
refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
"""Logout endpoint - clears cookies and revokes refresh token."""
user = None
@@ -222,12 +232,14 @@ async def logout(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
)
response.delete_cookie(
key="refresh_token",
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
)
return {"message": "Successfully logged out"}
@@ -272,6 +284,8 @@ async def oauth_callback(
) -> RedirectResponse:
"""Handle OAuth callback."""
try:
logger.info("OAuth callback started for provider: %s", provider)
# Handle OAuth callback and get user info
oauth_user_info = await oauth_service.handle_callback(provider, code)
@@ -282,7 +296,9 @@ async def oauth_callback(
user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
# Set HTTP-only cookies for both tokens (not used due to cross-port issues)
# These cookies are kept for potential future same-origin scenarios
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
@@ -290,6 +306,8 @@ async def oauth_callback(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
path="/", # Ensure cookie is available for all paths
)
response.set_cookie(
key="refresh_token",
@@ -298,6 +316,8 @@ async def oauth_callback(
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
path="/", # Ensure cookie is available for all paths
)
logger.info(
@@ -306,9 +326,26 @@ async def oauth_callback(
provider,
)
# Redirect back to frontend after successful authentication
# Instead of setting cookies that won't work across ports,
# let's redirect to a special frontend endpoint that can make an
# immediate API call. Frontend will call /exchange-oauth-token with this code
temp_code = secrets.token_urlsafe(32)
# Store the mapping temporarily (in production, use Redis with TTL)
# For now, store in memory with the user data
_temp_oauth_codes[temp_code] = {
"user_id": auth_response.user.id,
"access_token": auth_response.token.access_token,
"refresh_token": refresh_token,
"expires_in": auth_response.token.expires_in,
"created_at": time.time(),
}
redirect_url = f"http://localhost:8001/auth/callback?code={temp_code}"
logger.info("Redirecting to: %s", redirect_url)
return RedirectResponse(
url="http://localhost:8001/?auth=success",
url=redirect_url,
status_code=302,
)
@@ -328,3 +365,64 @@ async def get_oauth_providers() -> dict[str, list[str]]:
return {
"providers": ["google", "github"],
}
@router.post("/exchange-oauth-token")
async def exchange_oauth_token(
request: dict[str, str],
response: Response,
) -> dict[str, str]:
"""Exchange temporary OAuth code for proper auth cookies."""
code = request.get("code")
if not code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Code parameter is required",
)
logger.info("OAuth token exchange requested with code: %s", code[:10] + "...")
# Get the stored token data
if code not in _temp_oauth_codes:
logger.error("Invalid or expired OAuth code: %s", code[:10] + "...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired OAuth code",
)
token_data = _temp_oauth_codes.pop(code) # Remove after use
# Check if code is too old (5 minutes max)
code_expiry_seconds = 300
if time.time() - token_data["created_at"] > code_expiry_seconds:
logger.error("OAuth code expired: %s", code[:10] + "...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth code expired",
)
# Set the proper auth cookies
response.set_cookie(
key="access_token",
value=token_data["access_token"],
max_age=token_data["expires_in"],
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
path="/",
)
response.set_cookie(
key="refresh_token",
value=token_data["refresh_token"],
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
path="/",
)
user_id = token_data["user_id"]
logger.info("OAuth tokens exchanged successfully for user: %s", user_id)
return {"message": "Tokens set successfully", "user_id": str(user_id)}

View File

@@ -46,7 +46,6 @@ class Settings(BaseSettings):
GOOGLE_CLIENT_SECRET: str = ""
GITHUB_CLIENT_ID: str = ""
GITHUB_CLIENT_SECRET: str = ""
OAUTH_REDIRECT_URL: str = "http://localhost:8001/auth/callback"
settings = Settings()

View File

@@ -30,8 +30,8 @@ async def get_oauth_service(
async def get_current_user(
access_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
access_token: Annotated[str | None, Cookie()] = None,
) -> User:
"""Get the current authenticated user from JWT token in HTTP-only cookie."""
try:

View File

@@ -70,7 +70,7 @@ class OAuthProvider(ABC):
"""Generate authorization URL with state parameter."""
# Construct provider-specific redirect URI
redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
)
params = {
@@ -86,7 +86,7 @@ class OAuthProvider(ABC):
"""Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request)
redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
)
data = {
@@ -150,7 +150,7 @@ class GoogleOAuthProvider(OAuthProvider):
"""Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request)
redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
)
data = {

View File

@@ -65,9 +65,11 @@ class TestAuthEndpoints:
assert data["credits"] > 0
assert "plan" in data
# Check cookies are set
assert "access_token" in response.cookies
assert "refresh_token" in response.cookies
# Check cookies are set - HTTPX AsyncClient preserves Set-Cookie headers
set_cookie_headers = response.headers.get_list("set-cookie")
cookie_names = [header.split("=")[0] for header in set_cookie_headers]
assert "access_token" in cookie_names
assert "refresh_token" in cookie_names
@pytest.mark.asyncio
async def test_register_duplicate_email(
@@ -140,9 +142,11 @@ class TestAuthEndpoints:
assert "role" in data
assert data["is_active"] is True
# Check cookies are set
assert "access_token" in response.cookies
assert "refresh_token" in response.cookies
# Check cookies are set - HTTPX AsyncClient preserves Set-Cookie headers
set_cookie_headers = response.headers.get_list("set-cookie")
cookie_names = [header.split("=")[0] for header in set_cookie_headers]
assert "access_token" in cookie_names
assert "refresh_token" in cookie_names
@pytest.mark.asyncio
async def test_login_invalid_email(self, test_client: AsyncClient) -> None:
@@ -202,7 +206,7 @@ class TestAuthEndpoints:
"""Test getting current user without authentication token."""
response = await test_client.get("/api/v1/auth/me")
assert response.status_code == 422 # Validation error (no cookie provided)
assert response.status_code == 401 # Unauthorized (no cookie provided)
@pytest.mark.asyncio
async def test_get_current_user_invalid_token(
@@ -386,9 +390,10 @@ class TestAuthEndpoints:
follow_redirects=False,
)
# OAuth callback should successfully process and redirect to frontend
# OAuth callback should successfully process and redirect to frontend with temp code
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success"
location = response.headers["location"]
assert location.startswith("http://localhost:8001/auth/callback?code=")
# The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests
@@ -417,9 +422,10 @@ class TestAuthEndpoints:
follow_redirects=False,
)
# OAuth callback should successfully process and redirect to frontend
# OAuth callback should successfully process and redirect to frontend with temp code
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success"
location = response.headers["location"]
assert location.startswith("http://localhost:8001/auth/callback?code=")
# The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests