feat: Enhance OAuth2 flow with temporary code exchange and update cookie handling

This commit is contained in:
JSC
2025-07-26 18:31:40 +02:00
parent 98e36b067d
commit 0f605d7ed1
6 changed files with 127 additions and 24 deletions

View File

@@ -20,10 +20,10 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
# Cookie Configuration # Cookie Configuration
COOKIE_SECURE=false COOKIE_SECURE=false
COOKIE_SAMESITE=lax
# OAuth2 Configuration # OAuth2 Configuration
GOOGLE_CLIENT_ID= GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET= GOOGLE_CLIENT_SECRET=
GITHUB_CLIENT_ID= GITHUB_CLIENT_ID=
GITHUB_CLIENT_SECRET= GITHUB_CLIENT_SECRET=
OAUTH_REDIRECT_URL=http://localhost:8001/auth/callback

View File

@@ -1,6 +1,8 @@
"""Authentication endpoints.""" """Authentication endpoints."""
from typing import Annotated import secrets
import time
from typing import Annotated, Any
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
@@ -21,6 +23,9 @@ from app.utils.auth import JWTUtils
router = APIRouter() router = APIRouter()
logger = get_logger(__name__) logger = get_logger(__name__)
# Global temporary storage for OAuth codes (in production, use Redis with TTL)
_temp_oauth_codes: dict[str, dict[str, Any]] = {}
# Authentication endpoints # Authentication endpoints
@router.post( @router.post(
@@ -48,6 +53,7 @@ async def register(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
) )
response.set_cookie( response.set_cookie(
key="refresh_token", key="refresh_token",
@@ -59,6 +65,7 @@ async def register(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
) )
except HTTPException: except HTTPException:
@@ -95,6 +102,7 @@ async def login(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
) )
response.set_cookie( response.set_cookie(
key="refresh_token", key="refresh_token",
@@ -106,6 +114,7 @@ async def login(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
) )
except HTTPException: except HTTPException:
@@ -139,8 +148,8 @@ async def get_current_user_info(
@router.post("/refresh") @router.post("/refresh")
async def refresh_token( async def refresh_token(
response: Response, response: Response,
refresh_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)], auth_service: Annotated[AuthService, Depends(get_auth_service)],
refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]: ) -> dict[str, str]:
"""Refresh access token using refresh token.""" """Refresh access token using refresh token."""
try: try:
@@ -161,6 +170,7 @@ async def refresh_token(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
) )
except HTTPException: except HTTPException:
@@ -178,9 +188,9 @@ async def refresh_token(
@router.post("/logout") @router.post("/logout")
async def logout( async def logout(
response: Response, response: Response,
access_token: Annotated[str | None, Cookie()],
refresh_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)], auth_service: Annotated[AuthService, Depends(get_auth_service)],
access_token: Annotated[str | None, Cookie()] = None,
refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]: ) -> dict[str, str]:
"""Logout endpoint - clears cookies and revokes refresh token.""" """Logout endpoint - clears cookies and revokes refresh token."""
user = None user = None
@@ -222,12 +232,14 @@ async def logout(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
) )
response.delete_cookie( response.delete_cookie(
key="refresh_token", key="refresh_token",
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Match the domain used when setting cookies
) )
return {"message": "Successfully logged out"} return {"message": "Successfully logged out"}
@@ -272,6 +284,8 @@ async def oauth_callback(
) -> RedirectResponse: ) -> RedirectResponse:
"""Handle OAuth callback.""" """Handle OAuth callback."""
try: try:
logger.info("OAuth callback started for provider: %s", provider)
# Handle OAuth callback and get user info # Handle OAuth callback and get user info
oauth_user_info = await oauth_service.handle_callback(provider, code) oauth_user_info = await oauth_service.handle_callback(provider, code)
@@ -282,7 +296,9 @@ async def oauth_callback(
user = await auth_service.get_current_user(auth_response.user.id) user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user) refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens # Set HTTP-only cookies for both tokens (not used due to cross-port issues)
# These cookies are kept for potential future same-origin scenarios
response.set_cookie( response.set_cookie(
key="access_token", key="access_token",
value=auth_response.token.access_token, value=auth_response.token.access_token,
@@ -290,6 +306,8 @@ async def oauth_callback(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
path="/", # Ensure cookie is available for all paths
) )
response.set_cookie( response.set_cookie(
key="refresh_token", key="refresh_token",
@@ -298,6 +316,8 @@ async def oauth_callback(
httponly=True, httponly=True,
secure=settings.COOKIE_SECURE, secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE, samesite=settings.COOKIE_SAMESITE,
domain="localhost", # Allow cookie across localhost ports
path="/", # Ensure cookie is available for all paths
) )
logger.info( logger.info(
@@ -306,9 +326,26 @@ async def oauth_callback(
provider, provider,
) )
# Redirect back to frontend after successful authentication # Instead of setting cookies that won't work across ports,
# let's redirect to a special frontend endpoint that can make an
# immediate API call. Frontend will call /exchange-oauth-token with this code
temp_code = secrets.token_urlsafe(32)
# Store the mapping temporarily (in production, use Redis with TTL)
# For now, store in memory with the user data
_temp_oauth_codes[temp_code] = {
"user_id": auth_response.user.id,
"access_token": auth_response.token.access_token,
"refresh_token": refresh_token,
"expires_in": auth_response.token.expires_in,
"created_at": time.time(),
}
redirect_url = f"http://localhost:8001/auth/callback?code={temp_code}"
logger.info("Redirecting to: %s", redirect_url)
return RedirectResponse( return RedirectResponse(
url="http://localhost:8001/?auth=success", url=redirect_url,
status_code=302, status_code=302,
) )
@@ -328,3 +365,64 @@ async def get_oauth_providers() -> dict[str, list[str]]:
return { return {
"providers": ["google", "github"], "providers": ["google", "github"],
} }
@router.post("/exchange-oauth-token")
async def exchange_oauth_token(
request: dict[str, str],
response: Response,
) -> dict[str, str]:
"""Exchange temporary OAuth code for proper auth cookies."""
code = request.get("code")
if not code:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Code parameter is required",
)
logger.info("OAuth token exchange requested with code: %s", code[:10] + "...")
# Get the stored token data
if code not in _temp_oauth_codes:
logger.error("Invalid or expired OAuth code: %s", code[:10] + "...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid or expired OAuth code",
)
token_data = _temp_oauth_codes.pop(code) # Remove after use
# Check if code is too old (5 minutes max)
code_expiry_seconds = 300
if time.time() - token_data["created_at"] > code_expiry_seconds:
logger.error("OAuth code expired: %s", code[:10] + "...")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="OAuth code expired",
)
# Set the proper auth cookies
response.set_cookie(
key="access_token",
value=token_data["access_token"],
max_age=token_data["expires_in"],
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
path="/",
)
response.set_cookie(
key="refresh_token",
value=token_data["refresh_token"],
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
domain="localhost",
path="/",
)
user_id = token_data["user_id"]
logger.info("OAuth tokens exchanged successfully for user: %s", user_id)
return {"message": "Tokens set successfully", "user_id": str(user_id)}

View File

@@ -46,7 +46,6 @@ class Settings(BaseSettings):
GOOGLE_CLIENT_SECRET: str = "" GOOGLE_CLIENT_SECRET: str = ""
GITHUB_CLIENT_ID: str = "" GITHUB_CLIENT_ID: str = ""
GITHUB_CLIENT_SECRET: str = "" GITHUB_CLIENT_SECRET: str = ""
OAUTH_REDIRECT_URL: str = "http://localhost:8001/auth/callback"
settings = Settings() settings = Settings()

View File

@@ -30,8 +30,8 @@ async def get_oauth_service(
async def get_current_user( async def get_current_user(
access_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)], auth_service: Annotated[AuthService, Depends(get_auth_service)],
access_token: Annotated[str | None, Cookie()] = None,
) -> User: ) -> User:
"""Get the current authenticated user from JWT token in HTTP-only cookie.""" """Get the current authenticated user from JWT token in HTTP-only cookie."""
try: try:

View File

@@ -70,7 +70,7 @@ class OAuthProvider(ABC):
"""Generate authorization URL with state parameter.""" """Generate authorization URL with state parameter."""
# Construct provider-specific redirect URI # Construct provider-specific redirect URI
redirect_uri = ( redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback" f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
) )
params = { params = {
@@ -86,7 +86,7 @@ class OAuthProvider(ABC):
"""Exchange authorization code for access token.""" """Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request) # Construct provider-specific redirect URI (must match authorization request)
redirect_uri = ( redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback" f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
) )
data = { data = {
@@ -150,7 +150,7 @@ class GoogleOAuthProvider(OAuthProvider):
"""Exchange authorization code for access token.""" """Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request) # Construct provider-specific redirect URI (must match authorization request)
redirect_uri = ( redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback" f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
) )
data = { data = {

View File

@@ -65,9 +65,11 @@ class TestAuthEndpoints:
assert data["credits"] > 0 assert data["credits"] > 0
assert "plan" in data assert "plan" in data
# Check cookies are set # Check cookies are set - HTTPX AsyncClient preserves Set-Cookie headers
assert "access_token" in response.cookies set_cookie_headers = response.headers.get_list("set-cookie")
assert "refresh_token" in response.cookies cookie_names = [header.split("=")[0] for header in set_cookie_headers]
assert "access_token" in cookie_names
assert "refresh_token" in cookie_names
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_duplicate_email( async def test_register_duplicate_email(
@@ -140,9 +142,11 @@ class TestAuthEndpoints:
assert "role" in data assert "role" in data
assert data["is_active"] is True assert data["is_active"] is True
# Check cookies are set # Check cookies are set - HTTPX AsyncClient preserves Set-Cookie headers
assert "access_token" in response.cookies set_cookie_headers = response.headers.get_list("set-cookie")
assert "refresh_token" in response.cookies cookie_names = [header.split("=")[0] for header in set_cookie_headers]
assert "access_token" in cookie_names
assert "refresh_token" in cookie_names
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_invalid_email(self, test_client: AsyncClient) -> None: async def test_login_invalid_email(self, test_client: AsyncClient) -> None:
@@ -202,7 +206,7 @@ class TestAuthEndpoints:
"""Test getting current user without authentication token.""" """Test getting current user without authentication token."""
response = await test_client.get("/api/v1/auth/me") response = await test_client.get("/api/v1/auth/me")
assert response.status_code == 422 # Validation error (no cookie provided) assert response.status_code == 401 # Unauthorized (no cookie provided)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_current_user_invalid_token( async def test_get_current_user_invalid_token(
@@ -386,9 +390,10 @@ class TestAuthEndpoints:
follow_redirects=False, follow_redirects=False,
) )
# OAuth callback should successfully process and redirect to frontend # OAuth callback should successfully process and redirect to frontend with temp code
assert response.status_code == 302 assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success" location = response.headers["location"]
assert location.startswith("http://localhost:8001/auth/callback?code=")
# The fact that we get a 302 redirect means the OAuth login was successful # The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests # Detailed cookie testing can be done in integration tests
@@ -417,9 +422,10 @@ class TestAuthEndpoints:
follow_redirects=False, follow_redirects=False,
) )
# OAuth callback should successfully process and redirect to frontend # OAuth callback should successfully process and redirect to frontend with temp code
assert response.status_code == 302 assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success" location = response.headers["location"]
assert location.startswith("http://localhost:8001/auth/callback?code=")
# The fact that we get a 302 redirect means the OAuth login was successful # The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests # Detailed cookie testing can be done in integration tests