feat: Enhance OAuth2 flow with temporary code exchange and update cookie handling
This commit is contained in:
@@ -20,10 +20,10 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
|||||||
|
|
||||||
# Cookie Configuration
|
# Cookie Configuration
|
||||||
COOKIE_SECURE=false
|
COOKIE_SECURE=false
|
||||||
|
COOKIE_SAMESITE=lax
|
||||||
|
|
||||||
# OAuth2 Configuration
|
# OAuth2 Configuration
|
||||||
GOOGLE_CLIENT_ID=
|
GOOGLE_CLIENT_ID=
|
||||||
GOOGLE_CLIENT_SECRET=
|
GOOGLE_CLIENT_SECRET=
|
||||||
GITHUB_CLIENT_ID=
|
GITHUB_CLIENT_ID=
|
||||||
GITHUB_CLIENT_SECRET=
|
GITHUB_CLIENT_SECRET=
|
||||||
OAUTH_REDIRECT_URL=http://localhost:8001/auth/callback
|
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Authentication endpoints."""
|
"""Authentication endpoints."""
|
||||||
|
|
||||||
from typing import Annotated
|
import secrets
|
||||||
|
import time
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
|
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
@@ -21,6 +23,9 @@ from app.utils.auth import JWTUtils
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Global temporary storage for OAuth codes (in production, use Redis with TTL)
|
||||||
|
_temp_oauth_codes: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
# Authentication endpoints
|
# Authentication endpoints
|
||||||
@router.post(
|
@router.post(
|
||||||
@@ -48,6 +53,7 @@ async def register(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Allow cookie across localhost ports
|
||||||
)
|
)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="refresh_token",
|
key="refresh_token",
|
||||||
@@ -59,6 +65,7 @@ async def register(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Allow cookie across localhost ports
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -95,6 +102,7 @@ async def login(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Allow cookie across localhost ports
|
||||||
)
|
)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="refresh_token",
|
key="refresh_token",
|
||||||
@@ -106,6 +114,7 @@ async def login(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Allow cookie across localhost ports
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -139,8 +148,8 @@ async def get_current_user_info(
|
|||||||
@router.post("/refresh")
|
@router.post("/refresh")
|
||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
response: Response,
|
response: Response,
|
||||||
refresh_token: Annotated[str | None, Cookie()],
|
|
||||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
refresh_token: Annotated[str | None, Cookie()] = None,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Refresh access token using refresh token."""
|
"""Refresh access token using refresh token."""
|
||||||
try:
|
try:
|
||||||
@@ -161,6 +170,7 @@ async def refresh_token(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Allow cookie across localhost ports
|
||||||
)
|
)
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -178,9 +188,9 @@ async def refresh_token(
|
|||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
async def logout(
|
async def logout(
|
||||||
response: Response,
|
response: Response,
|
||||||
access_token: Annotated[str | None, Cookie()],
|
|
||||||
refresh_token: Annotated[str | None, Cookie()],
|
|
||||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
access_token: Annotated[str | None, Cookie()] = None,
|
||||||
|
refresh_token: Annotated[str | None, Cookie()] = None,
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""Logout endpoint - clears cookies and revokes refresh token."""
|
"""Logout endpoint - clears cookies and revokes refresh token."""
|
||||||
user = None
|
user = None
|
||||||
@@ -222,12 +232,14 @@ async def logout(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Match the domain used when setting cookies
|
||||||
)
|
)
|
||||||
response.delete_cookie(
|
response.delete_cookie(
|
||||||
key="refresh_token",
|
key="refresh_token",
|
||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Match the domain used when setting cookies
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"message": "Successfully logged out"}
|
return {"message": "Successfully logged out"}
|
||||||
@@ -272,6 +284,8 @@ async def oauth_callback(
|
|||||||
) -> RedirectResponse:
|
) -> RedirectResponse:
|
||||||
"""Handle OAuth callback."""
|
"""Handle OAuth callback."""
|
||||||
try:
|
try:
|
||||||
|
logger.info("OAuth callback started for provider: %s", provider)
|
||||||
|
|
||||||
# Handle OAuth callback and get user info
|
# Handle OAuth callback and get user info
|
||||||
oauth_user_info = await oauth_service.handle_callback(provider, code)
|
oauth_user_info = await oauth_service.handle_callback(provider, code)
|
||||||
|
|
||||||
@@ -282,7 +296,9 @@ async def oauth_callback(
|
|||||||
user = await auth_service.get_current_user(auth_response.user.id)
|
user = await auth_service.get_current_user(auth_response.user.id)
|
||||||
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||||
|
|
||||||
# Set HTTP-only cookies for both tokens
|
# Set HTTP-only cookies for both tokens (not used due to cross-port issues)
|
||||||
|
# These cookies are kept for potential future same-origin scenarios
|
||||||
|
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="access_token",
|
key="access_token",
|
||||||
value=auth_response.token.access_token,
|
value=auth_response.token.access_token,
|
||||||
@@ -290,6 +306,8 @@ async def oauth_callback(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Allow cookie across localhost ports
|
||||||
|
path="/", # Ensure cookie is available for all paths
|
||||||
)
|
)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="refresh_token",
|
key="refresh_token",
|
||||||
@@ -298,6 +316,8 @@ async def oauth_callback(
|
|||||||
httponly=True,
|
httponly=True,
|
||||||
secure=settings.COOKIE_SECURE,
|
secure=settings.COOKIE_SECURE,
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost", # Allow cookie across localhost ports
|
||||||
|
path="/", # Ensure cookie is available for all paths
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -306,9 +326,26 @@ async def oauth_callback(
|
|||||||
provider,
|
provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Redirect back to frontend after successful authentication
|
# Instead of setting cookies that won't work across ports,
|
||||||
|
# let's redirect to a special frontend endpoint that can make an
|
||||||
|
# immediate API call. Frontend will call /exchange-oauth-token with this code
|
||||||
|
temp_code = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
# Store the mapping temporarily (in production, use Redis with TTL)
|
||||||
|
# For now, store in memory with the user data
|
||||||
|
_temp_oauth_codes[temp_code] = {
|
||||||
|
"user_id": auth_response.user.id,
|
||||||
|
"access_token": auth_response.token.access_token,
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"expires_in": auth_response.token.expires_in,
|
||||||
|
"created_at": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
redirect_url = f"http://localhost:8001/auth/callback?code={temp_code}"
|
||||||
|
logger.info("Redirecting to: %s", redirect_url)
|
||||||
|
|
||||||
return RedirectResponse(
|
return RedirectResponse(
|
||||||
url="http://localhost:8001/?auth=success",
|
url=redirect_url,
|
||||||
status_code=302,
|
status_code=302,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -328,3 +365,64 @@ async def get_oauth_providers() -> dict[str, list[str]]:
|
|||||||
return {
|
return {
|
||||||
"providers": ["google", "github"],
|
"providers": ["google", "github"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/exchange-oauth-token")
|
||||||
|
async def exchange_oauth_token(
|
||||||
|
request: dict[str, str],
|
||||||
|
response: Response,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Exchange temporary OAuth code for proper auth cookies."""
|
||||||
|
code = request.get("code")
|
||||||
|
if not code:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Code parameter is required",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("OAuth token exchange requested with code: %s", code[:10] + "...")
|
||||||
|
|
||||||
|
# Get the stored token data
|
||||||
|
if code not in _temp_oauth_codes:
|
||||||
|
logger.error("Invalid or expired OAuth code: %s", code[:10] + "...")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid or expired OAuth code",
|
||||||
|
)
|
||||||
|
|
||||||
|
token_data = _temp_oauth_codes.pop(code) # Remove after use
|
||||||
|
|
||||||
|
# Check if code is too old (5 minutes max)
|
||||||
|
code_expiry_seconds = 300
|
||||||
|
if time.time() - token_data["created_at"] > code_expiry_seconds:
|
||||||
|
logger.error("OAuth code expired: %s", code[:10] + "...")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="OAuth code expired",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the proper auth cookies
|
||||||
|
response.set_cookie(
|
||||||
|
key="access_token",
|
||||||
|
value=token_data["access_token"],
|
||||||
|
max_age=token_data["expires_in"],
|
||||||
|
httponly=True,
|
||||||
|
secure=settings.COOKIE_SECURE,
|
||||||
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost",
|
||||||
|
path="/",
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
key="refresh_token",
|
||||||
|
value=token_data["refresh_token"],
|
||||||
|
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
||||||
|
httponly=True,
|
||||||
|
secure=settings.COOKIE_SECURE,
|
||||||
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
domain="localhost",
|
||||||
|
path="/",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = token_data["user_id"]
|
||||||
|
logger.info("OAuth tokens exchanged successfully for user: %s", user_id)
|
||||||
|
return {"message": "Tokens set successfully", "user_id": str(user_id)}
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ class Settings(BaseSettings):
|
|||||||
GOOGLE_CLIENT_SECRET: str = ""
|
GOOGLE_CLIENT_SECRET: str = ""
|
||||||
GITHUB_CLIENT_ID: str = ""
|
GITHUB_CLIENT_ID: str = ""
|
||||||
GITHUB_CLIENT_SECRET: str = ""
|
GITHUB_CLIENT_SECRET: str = ""
|
||||||
OAUTH_REDIRECT_URL: str = "http://localhost:8001/auth/callback"
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -30,8 +30,8 @@ async def get_oauth_service(
|
|||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
access_token: Annotated[str | None, Cookie()],
|
|
||||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
access_token: Annotated[str | None, Cookie()] = None,
|
||||||
) -> User:
|
) -> User:
|
||||||
"""Get the current authenticated user from JWT token in HTTP-only cookie."""
|
"""Get the current authenticated user from JWT token in HTTP-only cookie."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class OAuthProvider(ABC):
|
|||||||
"""Generate authorization URL with state parameter."""
|
"""Generate authorization URL with state parameter."""
|
||||||
# Construct provider-specific redirect URI
|
# Construct provider-specific redirect URI
|
||||||
redirect_uri = (
|
redirect_uri = (
|
||||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
@@ -86,7 +86,7 @@ class OAuthProvider(ABC):
|
|||||||
"""Exchange authorization code for access token."""
|
"""Exchange authorization code for access token."""
|
||||||
# Construct provider-specific redirect URI (must match authorization request)
|
# Construct provider-specific redirect URI (must match authorization request)
|
||||||
redirect_uri = (
|
redirect_uri = (
|
||||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
@@ -150,7 +150,7 @@ class GoogleOAuthProvider(OAuthProvider):
|
|||||||
"""Exchange authorization code for access token."""
|
"""Exchange authorization code for access token."""
|
||||||
# Construct provider-specific redirect URI (must match authorization request)
|
# Construct provider-specific redirect URI (must match authorization request)
|
||||||
redirect_uri = (
|
redirect_uri = (
|
||||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
|
|||||||
@@ -65,9 +65,11 @@ class TestAuthEndpoints:
|
|||||||
assert data["credits"] > 0
|
assert data["credits"] > 0
|
||||||
assert "plan" in data
|
assert "plan" in data
|
||||||
|
|
||||||
# Check cookies are set
|
# Check cookies are set - HTTPX AsyncClient preserves Set-Cookie headers
|
||||||
assert "access_token" in response.cookies
|
set_cookie_headers = response.headers.get_list("set-cookie")
|
||||||
assert "refresh_token" in response.cookies
|
cookie_names = [header.split("=")[0] for header in set_cookie_headers]
|
||||||
|
assert "access_token" in cookie_names
|
||||||
|
assert "refresh_token" in cookie_names
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_duplicate_email(
|
async def test_register_duplicate_email(
|
||||||
@@ -140,9 +142,11 @@ class TestAuthEndpoints:
|
|||||||
assert "role" in data
|
assert "role" in data
|
||||||
assert data["is_active"] is True
|
assert data["is_active"] is True
|
||||||
|
|
||||||
# Check cookies are set
|
# Check cookies are set - HTTPX AsyncClient preserves Set-Cookie headers
|
||||||
assert "access_token" in response.cookies
|
set_cookie_headers = response.headers.get_list("set-cookie")
|
||||||
assert "refresh_token" in response.cookies
|
cookie_names = [header.split("=")[0] for header in set_cookie_headers]
|
||||||
|
assert "access_token" in cookie_names
|
||||||
|
assert "refresh_token" in cookie_names
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_login_invalid_email(self, test_client: AsyncClient) -> None:
|
async def test_login_invalid_email(self, test_client: AsyncClient) -> None:
|
||||||
@@ -202,7 +206,7 @@ class TestAuthEndpoints:
|
|||||||
"""Test getting current user without authentication token."""
|
"""Test getting current user without authentication token."""
|
||||||
response = await test_client.get("/api/v1/auth/me")
|
response = await test_client.get("/api/v1/auth/me")
|
||||||
|
|
||||||
assert response.status_code == 422 # Validation error (no cookie provided)
|
assert response.status_code == 401 # Unauthorized (no cookie provided)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_current_user_invalid_token(
|
async def test_get_current_user_invalid_token(
|
||||||
@@ -386,9 +390,10 @@ class TestAuthEndpoints:
|
|||||||
follow_redirects=False,
|
follow_redirects=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# OAuth callback should successfully process and redirect to frontend
|
# OAuth callback should successfully process and redirect to frontend with temp code
|
||||||
assert response.status_code == 302
|
assert response.status_code == 302
|
||||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
location = response.headers["location"]
|
||||||
|
assert location.startswith("http://localhost:8001/auth/callback?code=")
|
||||||
|
|
||||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||||
# Detailed cookie testing can be done in integration tests
|
# Detailed cookie testing can be done in integration tests
|
||||||
@@ -417,9 +422,10 @@ class TestAuthEndpoints:
|
|||||||
follow_redirects=False,
|
follow_redirects=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# OAuth callback should successfully process and redirect to frontend
|
# OAuth callback should successfully process and redirect to frontend with temp code
|
||||||
assert response.status_code == 302
|
assert response.status_code == 302
|
||||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
location = response.headers["location"]
|
||||||
|
assert location.startswith("http://localhost:8001/auth/callback?code=")
|
||||||
|
|
||||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||||
# Detailed cookie testing can be done in integration tests
|
# Detailed cookie testing can be done in integration tests
|
||||||
|
|||||||
Reference in New Issue
Block a user