feat: Consolidate OAuth2 endpoints into auth module and remove redundant oauth file

This commit is contained in:
JSC
2025-07-26 15:15:17 +02:00
parent 51423779a8
commit 98e36b067d
6 changed files with 255 additions and 274 deletions

View File

@@ -2,7 +2,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1 import auth, main, oauth from app.api.v1 import auth, main
# V1 API router with v1 prefix # V1 API router with v1 prefix
api_router = APIRouter(prefix="/v1") api_router = APIRouter(prefix="/v1")
@@ -10,4 +10,3 @@ api_router = APIRouter(prefix="/v1")
# Include all route modules # Include all route modules
api_router.include_router(main.router, tags=["main"]) api_router.include_router(main.router, tags=["main"])
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
api_router.include_router(oauth.router, prefix="/oauth", tags=["oauth"])

View File

@@ -2,20 +2,27 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
from fastapi.responses import RedirectResponse
from app.core.config import settings from app.core.config import settings
from app.core.dependencies import get_auth_service, get_current_active_user from app.core.dependencies import (
get_auth_service,
get_current_active_user,
get_oauth_service,
)
from app.core.logging import get_logger from app.core.logging import get_logger
from app.models.user import User from app.models.user import User
from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse
from app.services.auth import AuthService from app.services.auth import AuthService
from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils from app.utils.auth import JWTUtils
router = APIRouter() router = APIRouter()
logger = get_logger(__name__) logger = get_logger(__name__)
# Authentication endpoints
@router.post( @router.post(
"/register", "/register",
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
@@ -224,3 +231,100 @@ async def logout(
) )
return {"message": "Successfully logged out"} return {"message": "Successfully logged out"}
# OAuth2 endpoints
@router.get("/{provider}/authorize")
async def oauth_authorize(
provider: str,
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
) -> dict[str, str]:
"""Get OAuth authorization URL."""
try:
# Generate secure state parameter
state = oauth_service.generate_state()
# Get authorization URL
auth_url = oauth_service.get_authorization_url(provider, state)
except HTTPException:
raise
except Exception as e:
logger.exception("OAuth authorization failed for provider: %s", provider)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth authorization failed",
) from e
else:
return {
"authorization_url": auth_url,
"state": state,
}
@router.get("/{provider}/callback")
async def oauth_callback(
provider: str,
response: Response,
code: Annotated[str, Query()],
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> RedirectResponse:
"""Handle OAuth callback."""
try:
# Handle OAuth callback and get user info
oauth_user_info = await oauth_service.handle_callback(provider, code)
# Perform OAuth login (link or create user)
auth_response = await auth_service.oauth_login(oauth_user_info)
# Create and store refresh token
user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
logger.info(
"OAuth login successful for user: %s via %s",
auth_response.user.email,
provider,
)
# Redirect back to frontend after successful authentication
return RedirectResponse(
url="http://localhost:8001/?auth=success",
status_code=302,
)
except HTTPException:
raise
except Exception as e:
logger.exception("OAuth callback failed for provider: %s", provider)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth callback failed",
) from e
@router.get("/providers")
async def get_oauth_providers() -> dict[str, list[str]]:
"""Get list of available OAuth providers."""
return {
"providers": ["google", "github"],
}

View File

@@ -1,111 +0,0 @@
"""OAuth2 authentication endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from fastapi.responses import RedirectResponse
from app.core.config import settings
from app.core.dependencies import get_auth_service, get_oauth_service
from app.core.logging import get_logger
from app.services.auth import AuthService
from app.services.oauth import OAuthService
router = APIRouter()
logger = get_logger(__name__)
@router.get("/{provider}/authorize")
async def oauth_authorize(
provider: str,
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
) -> dict[str, str]:
"""Get OAuth authorization URL."""
try:
# Generate secure state parameter
state = oauth_service.generate_state()
# Get authorization URL
auth_url = oauth_service.get_authorization_url(provider, state)
except HTTPException:
raise
except Exception as e:
logger.exception("OAuth authorization failed for provider: %s", provider)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth authorization failed",
) from e
else:
return {
"authorization_url": auth_url,
"state": state,
}
@router.get("/{provider}/callback")
async def oauth_callback(
provider: str,
response: Response,
code: Annotated[str, Query()],
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> RedirectResponse:
"""Handle OAuth callback."""
try:
# Handle OAuth callback and get user info
oauth_user_info = await oauth_service.handle_callback(provider, code)
# Perform OAuth login (link or create user)
auth_response = await auth_service.oauth_login(oauth_user_info)
# Create and store refresh token
user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
logger.info(
"OAuth login successful for user: %s via %s",
auth_response.user.email,
provider,
)
# Redirect back to frontend after successful authentication
return RedirectResponse(
url="http://localhost:8001/?auth=success",
status_code=302,
)
except HTTPException:
raise
except Exception as e:
logger.exception("OAuth callback failed for provider: %s", provider)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth callback failed",
) from e
@router.get("/providers")
async def get_oauth_providers() -> dict[str, list[str]]:
"""Get list of available OAuth providers."""
return {
"providers": ["google", "github"],
}

View File

@@ -1,6 +1,7 @@
"""Tests for authentication endpoints.""" """Tests for authentication endpoints."""
from typing import Any from typing import Any
from unittest.mock import patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@@ -8,6 +9,7 @@ from httpx import AsyncClient
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user import User from app.models.user import User
from app.services.auth import OAuthUserInfo
from app.utils.auth import JWTUtils from app.utils.auth import JWTUtils
@@ -307,3 +309,141 @@ class TestAuthEndpoints:
# Test that get_admin_user passes for admin user # Test that get_admin_user passes for admin user
result = await get_admin_user(admin_user) result = await get_admin_user(admin_user)
assert result == admin_user assert result == admin_user
@pytest.mark.asyncio
async def test_get_oauth_providers(self, test_client: AsyncClient) -> None:
"""Test getting list of OAuth providers."""
response = await test_client.get("/api/v1/auth/providers")
assert response.status_code == 200
data = response.json()
assert "providers" in data
assert "google" in data["providers"]
assert "github" in data["providers"]
@pytest.mark.asyncio
async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None:
"""Test OAuth authorization URL generation for Google."""
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
mock_state.return_value = "test_state_123"
response = await test_client.get("/api/v1/auth/google/authorize")
assert response.status_code == 200
data = response.json()
assert "authorization_url" in data
assert "state" in data
assert data["state"] == "test_state_123"
assert "accounts.google.com" in data["authorization_url"]
@pytest.mark.asyncio
async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None:
"""Test OAuth authorization URL generation for GitHub."""
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
mock_state.return_value = "test_state_456"
response = await test_client.get("/api/v1/auth/github/authorize")
assert response.status_code == 200
data = response.json()
assert "authorization_url" in data
assert "state" in data
assert data["state"] == "test_state_456"
assert "github.com" in data["authorization_url"]
@pytest.mark.asyncio
async def test_oauth_authorize_invalid_provider(
self, test_client: AsyncClient
) -> None:
"""Test OAuth authorization with invalid provider."""
response = await test_client.get("/api/v1/auth/invalid/authorize")
assert response.status_code == 400
data = response.json()
assert "Unsupported OAuth provider" in data["detail"]
@pytest.mark.asyncio
async def test_oauth_callback_new_user(
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any]
) -> None:
"""Test OAuth callback for new user creation."""
# Mock OAuth user info
mock_user_info = OAuthUserInfo(
provider="google",
provider_user_id="google_123",
email="newuser@gmail.com",
name="New User",
picture="https://example.com/avatar.jpg",
)
# Mock the entire handle_callback method to avoid actual OAuth API calls
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
mock_callback.return_value = mock_user_info
response = await test_client.get(
"/api/v1/auth/google/callback",
params={"code": "auth_code_123", "state": "test_state"},
follow_redirects=False,
)
# OAuth callback should successfully process and redirect to frontend
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success"
# The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests
@pytest.mark.asyncio
async def test_oauth_callback_existing_user_link(
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any]
) -> None:
"""Test OAuth callback for linking to existing user."""
# Mock OAuth user info with same email as test user
mock_user_info = OAuthUserInfo(
provider="github",
provider_user_id="github_456",
email=test_user.email, # Same email as existing user
name="Test User",
picture="https://github.com/avatar.jpg",
)
# Mock the entire handle_callback method to avoid actual OAuth API calls
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
mock_callback.return_value = mock_user_info
response = await test_client.get(
"/api/v1/auth/github/callback",
params={"code": "auth_code_456", "state": "test_state"},
follow_redirects=False,
)
# OAuth callback should successfully process and redirect to frontend
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success"
# The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests
@pytest.mark.asyncio
async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None:
"""Test OAuth callback with missing authorization code."""
response = await test_client.get(
"/api/v1/auth/google/callback",
params={"state": "test_state"}, # Missing code parameter
)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_oauth_callback_invalid_provider(
self, test_client: AsyncClient
) -> None:
"""Test OAuth callback with invalid provider."""
response = await test_client.get(
"/api/v1/auth/invalid/callback",
params={"code": "auth_code_123", "state": "test_state"},
)
assert response.status_code == 400
data = response.json()
assert "Unsupported OAuth provider" in data["detail"]

View File

@@ -1,151 +0,0 @@
"""Tests for OAuth authentication endpoints."""
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from httpx import AsyncClient
from app.services.oauth import OAuthUserInfo
class TestOAuthEndpoints:
"""Test OAuth API endpoints."""
@pytest.mark.asyncio
async def test_get_oauth_providers(self, test_client: AsyncClient) -> None:
"""Test getting list of OAuth providers."""
response = await test_client.get("/api/v1/oauth/providers")
assert response.status_code == 200
data = response.json()
assert "providers" in data
assert "google" in data["providers"]
assert "github" in data["providers"]
@pytest.mark.asyncio
async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None:
"""Test OAuth authorization URL generation for Google."""
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
mock_state.return_value = "test_state_123"
response = await test_client.get("/api/v1/oauth/google/authorize")
assert response.status_code == 200
data = response.json()
assert "authorization_url" in data
assert "state" in data
assert data["state"] == "test_state_123"
assert "accounts.google.com" in data["authorization_url"]
@pytest.mark.asyncio
async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None:
"""Test OAuth authorization URL generation for GitHub."""
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
mock_state.return_value = "test_state_456"
response = await test_client.get("/api/v1/oauth/github/authorize")
assert response.status_code == 200
data = response.json()
assert "authorization_url" in data
assert "state" in data
assert data["state"] == "test_state_456"
assert "github.com" in data["authorization_url"]
@pytest.mark.asyncio
async def test_oauth_authorize_invalid_provider(
self, test_client: AsyncClient
) -> None:
"""Test OAuth authorization with invalid provider."""
response = await test_client.get("/api/v1/oauth/invalid/authorize")
assert response.status_code == 400
data = response.json()
assert "Unsupported OAuth provider" in data["detail"]
@pytest.mark.asyncio
async def test_oauth_callback_new_user(
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any]
) -> None:
"""Test OAuth callback for new user creation."""
# Mock OAuth user info
mock_user_info = OAuthUserInfo(
provider="google",
provider_user_id="google_123",
email="newuser@gmail.com",
name="New User",
picture="https://example.com/avatar.jpg",
)
# Mock the entire handle_callback method to avoid actual OAuth API calls
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
mock_callback.return_value = mock_user_info
response = await test_client.get(
"/api/v1/oauth/google/callback",
params={"code": "auth_code_123", "state": "test_state"},
follow_redirects=False,
)
# OAuth callback should successfully process and redirect to frontend
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success"
# The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests
@pytest.mark.asyncio
async def test_oauth_callback_existing_user_link(
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any]
) -> None:
"""Test OAuth callback for linking to existing user."""
# Mock OAuth user info with same email as test user
mock_user_info = OAuthUserInfo(
provider="github",
provider_user_id="github_456",
email=test_user.email, # Same email as existing user
name="Test User",
picture="https://github.com/avatar.jpg",
)
# Mock the entire handle_callback method to avoid actual OAuth API calls
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
mock_callback.return_value = mock_user_info
response = await test_client.get(
"/api/v1/oauth/github/callback",
params={"code": "auth_code_456", "state": "test_state"},
follow_redirects=False,
)
# OAuth callback should successfully process and redirect to frontend
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:8001/?auth=success"
# The fact that we get a 302 redirect means the OAuth login was successful
# Detailed cookie testing can be done in integration tests
@pytest.mark.asyncio
async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None:
"""Test OAuth callback with missing authorization code."""
response = await test_client.get(
"/api/v1/oauth/google/callback",
params={"state": "test_state"}, # Missing code parameter
)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_oauth_callback_invalid_provider(
self, test_client: AsyncClient
) -> None:
"""Test OAuth callback with invalid provider."""
response = await test_client.get(
"/api/v1/oauth/invalid/callback",
params={"code": "auth_code_123", "state": "test_state"},
)
assert response.status_code == 400
data = response.json()
assert "Unsupported OAuth provider" in data["detail"]