feat: Consolidate OAuth2 endpoints into auth module and remove redundant oauth file
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api.v1 import auth, main, oauth
|
from app.api.v1 import auth, main
|
||||||
|
|
||||||
# V1 API router with v1 prefix
|
# V1 API router with v1 prefix
|
||||||
api_router = APIRouter(prefix="/v1")
|
api_router = APIRouter(prefix="/v1")
|
||||||
@@ -10,4 +10,3 @@ api_router = APIRouter(prefix="/v1")
|
|||||||
# Include all route modules
|
# Include all route modules
|
||||||
api_router.include_router(main.router, tags=["main"])
|
api_router.include_router(main.router, tags=["main"])
|
||||||
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
||||||
api_router.include_router(oauth.router, prefix="/oauth", tags=["oauth"])
|
|
||||||
|
|||||||
@@ -2,20 +2,27 @@
|
|||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status
|
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.dependencies import get_auth_service, get_current_active_user
|
from app.core.dependencies import (
|
||||||
|
get_auth_service,
|
||||||
|
get_current_active_user,
|
||||||
|
get_oauth_service,
|
||||||
|
)
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse
|
from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse
|
||||||
from app.services.auth import AuthService
|
from app.services.auth import AuthService
|
||||||
|
from app.services.oauth import OAuthService
|
||||||
from app.utils.auth import JWTUtils
|
from app.utils.auth import JWTUtils
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Authentication endpoints
|
||||||
@router.post(
|
@router.post(
|
||||||
"/register",
|
"/register",
|
||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
@@ -224,3 +231,100 @@ async def logout(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {"message": "Successfully logged out"}
|
return {"message": "Successfully logged out"}
|
||||||
|
|
||||||
|
|
||||||
|
# OAuth2 endpoints
|
||||||
|
@router.get("/{provider}/authorize")
|
||||||
|
async def oauth_authorize(
|
||||||
|
provider: str,
|
||||||
|
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Get OAuth authorization URL."""
|
||||||
|
try:
|
||||||
|
# Generate secure state parameter
|
||||||
|
state = oauth_service.generate_state()
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = oauth_service.get_authorization_url(provider, state)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth authorization failed for provider: %s", provider)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth authorization failed",
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"authorization_url": auth_url,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{provider}/callback")
|
||||||
|
async def oauth_callback(
|
||||||
|
provider: str,
|
||||||
|
response: Response,
|
||||||
|
code: Annotated[str, Query()],
|
||||||
|
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
||||||
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Handle OAuth callback."""
|
||||||
|
try:
|
||||||
|
# Handle OAuth callback and get user info
|
||||||
|
oauth_user_info = await oauth_service.handle_callback(provider, code)
|
||||||
|
|
||||||
|
# Perform OAuth login (link or create user)
|
||||||
|
auth_response = await auth_service.oauth_login(oauth_user_info)
|
||||||
|
|
||||||
|
# Create and store refresh token
|
||||||
|
user = await auth_service.get_current_user(auth_response.user.id)
|
||||||
|
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||||
|
|
||||||
|
# Set HTTP-only cookies for both tokens
|
||||||
|
response.set_cookie(
|
||||||
|
key="access_token",
|
||||||
|
value=auth_response.token.access_token,
|
||||||
|
max_age=auth_response.token.expires_in,
|
||||||
|
httponly=True,
|
||||||
|
secure=settings.COOKIE_SECURE,
|
||||||
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
key="refresh_token",
|
||||||
|
value=refresh_token,
|
||||||
|
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
||||||
|
httponly=True,
|
||||||
|
secure=settings.COOKIE_SECURE,
|
||||||
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth login successful for user: %s via %s",
|
||||||
|
auth_response.user.email,
|
||||||
|
provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redirect back to frontend after successful authentication
|
||||||
|
return RedirectResponse(
|
||||||
|
url="http://localhost:8001/?auth=success",
|
||||||
|
status_code=302,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth callback failed for provider: %s", provider)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth callback failed",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers")
|
||||||
|
async def get_oauth_providers() -> dict[str, list[str]]:
|
||||||
|
"""Get list of available OAuth providers."""
|
||||||
|
return {
|
||||||
|
"providers": ["google", "github"],
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,111 +0,0 @@
|
|||||||
"""OAuth2 authentication endpoints."""
|
|
||||||
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
|
||||||
from fastapi.responses import RedirectResponse
|
|
||||||
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.core.dependencies import get_auth_service, get_oauth_service
|
|
||||||
from app.core.logging import get_logger
|
|
||||||
from app.services.auth import AuthService
|
|
||||||
from app.services.oauth import OAuthService
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{provider}/authorize")
|
|
||||||
async def oauth_authorize(
|
|
||||||
provider: str,
|
|
||||||
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Get OAuth authorization URL."""
|
|
||||||
try:
|
|
||||||
# Generate secure state parameter
|
|
||||||
state = oauth_service.generate_state()
|
|
||||||
|
|
||||||
# Get authorization URL
|
|
||||||
auth_url = oauth_service.get_authorization_url(provider, state)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("OAuth authorization failed for provider: %s", provider)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="OAuth authorization failed",
|
|
||||||
) from e
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"authorization_url": auth_url,
|
|
||||||
"state": state,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{provider}/callback")
|
|
||||||
async def oauth_callback(
|
|
||||||
provider: str,
|
|
||||||
response: Response,
|
|
||||||
code: Annotated[str, Query()],
|
|
||||||
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
|
||||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
|
||||||
) -> RedirectResponse:
|
|
||||||
"""Handle OAuth callback."""
|
|
||||||
try:
|
|
||||||
# Handle OAuth callback and get user info
|
|
||||||
oauth_user_info = await oauth_service.handle_callback(provider, code)
|
|
||||||
|
|
||||||
# Perform OAuth login (link or create user)
|
|
||||||
auth_response = await auth_service.oauth_login(oauth_user_info)
|
|
||||||
|
|
||||||
# Create and store refresh token
|
|
||||||
user = await auth_service.get_current_user(auth_response.user.id)
|
|
||||||
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
|
||||||
|
|
||||||
# Set HTTP-only cookies for both tokens
|
|
||||||
response.set_cookie(
|
|
||||||
key="access_token",
|
|
||||||
value=auth_response.token.access_token,
|
|
||||||
max_age=auth_response.token.expires_in,
|
|
||||||
httponly=True,
|
|
||||||
secure=settings.COOKIE_SECURE,
|
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
|
||||||
)
|
|
||||||
response.set_cookie(
|
|
||||||
key="refresh_token",
|
|
||||||
value=refresh_token,
|
|
||||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
|
||||||
httponly=True,
|
|
||||||
secure=settings.COOKIE_SECURE,
|
|
||||||
samesite=settings.COOKIE_SAMESITE,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"OAuth login successful for user: %s via %s",
|
|
||||||
auth_response.user.email,
|
|
||||||
provider,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redirect back to frontend after successful authentication
|
|
||||||
return RedirectResponse(
|
|
||||||
url="http://localhost:8001/?auth=success",
|
|
||||||
status_code=302,
|
|
||||||
)
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("OAuth callback failed for provider: %s", provider)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="OAuth callback failed",
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/providers")
|
|
||||||
async def get_oauth_providers() -> dict[str, list[str]]:
|
|
||||||
"""Get list of available OAuth providers."""
|
|
||||||
return {
|
|
||||||
"providers": ["google", "github"],
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tests for authentication endpoints."""
|
"""Tests for authentication endpoints."""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
@@ -8,6 +9,7 @@ from httpx import AsyncClient
|
|||||||
|
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
from app.services.auth import OAuthUserInfo
|
||||||
from app.utils.auth import JWTUtils
|
from app.utils.auth import JWTUtils
|
||||||
|
|
||||||
|
|
||||||
@@ -307,3 +309,141 @@ class TestAuthEndpoints:
|
|||||||
# Test that get_admin_user passes for admin user
|
# Test that get_admin_user passes for admin user
|
||||||
result = await get_admin_user(admin_user)
|
result = await get_admin_user(admin_user)
|
||||||
assert result == admin_user
|
assert result == admin_user
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_oauth_providers(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test getting list of OAuth providers."""
|
||||||
|
response = await test_client.get("/api/v1/auth/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "providers" in data
|
||||||
|
assert "google" in data["providers"]
|
||||||
|
assert "github" in data["providers"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test OAuth authorization URL generation for Google."""
|
||||||
|
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||||
|
mock_state.return_value = "test_state_123"
|
||||||
|
|
||||||
|
response = await test_client.get("/api/v1/auth/google/authorize")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "authorization_url" in data
|
||||||
|
assert "state" in data
|
||||||
|
assert data["state"] == "test_state_123"
|
||||||
|
assert "accounts.google.com" in data["authorization_url"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test OAuth authorization URL generation for GitHub."""
|
||||||
|
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||||
|
mock_state.return_value = "test_state_456"
|
||||||
|
|
||||||
|
response = await test_client.get("/api/v1/auth/github/authorize")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "authorization_url" in data
|
||||||
|
assert "state" in data
|
||||||
|
assert data["state"] == "test_state_456"
|
||||||
|
assert "github.com" in data["authorization_url"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_authorize_invalid_provider(
|
||||||
|
self, test_client: AsyncClient
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth authorization with invalid provider."""
|
||||||
|
response = await test_client.get("/api/v1/auth/invalid/authorize")
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
data = response.json()
|
||||||
|
assert "Unsupported OAuth provider" in data["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_new_user(
|
||||||
|
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth callback for new user creation."""
|
||||||
|
# Mock OAuth user info
|
||||||
|
mock_user_info = OAuthUserInfo(
|
||||||
|
provider="google",
|
||||||
|
provider_user_id="google_123",
|
||||||
|
email="newuser@gmail.com",
|
||||||
|
name="New User",
|
||||||
|
picture="https://example.com/avatar.jpg",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||||
|
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||||
|
mock_callback.return_value = mock_user_info
|
||||||
|
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/auth/google/callback",
|
||||||
|
params={"code": "auth_code_123", "state": "test_state"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# OAuth callback should successfully process and redirect to frontend
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||||
|
|
||||||
|
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||||
|
# Detailed cookie testing can be done in integration tests
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_existing_user_link(
|
||||||
|
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth callback for linking to existing user."""
|
||||||
|
# Mock OAuth user info with same email as test user
|
||||||
|
mock_user_info = OAuthUserInfo(
|
||||||
|
provider="github",
|
||||||
|
provider_user_id="github_456",
|
||||||
|
email=test_user.email, # Same email as existing user
|
||||||
|
name="Test User",
|
||||||
|
picture="https://github.com/avatar.jpg",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||||
|
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||||
|
mock_callback.return_value = mock_user_info
|
||||||
|
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/auth/github/callback",
|
||||||
|
params={"code": "auth_code_456", "state": "test_state"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# OAuth callback should successfully process and redirect to frontend
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||||
|
|
||||||
|
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||||
|
# Detailed cookie testing can be done in integration tests
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test OAuth callback with missing authorization code."""
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/auth/google/callback",
|
||||||
|
params={"state": "test_state"}, # Missing code parameter
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422 # Validation error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_invalid_provider(
|
||||||
|
self, test_client: AsyncClient
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth callback with invalid provider."""
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/auth/invalid/callback",
|
||||||
|
params={"code": "auth_code_123", "state": "test_state"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
data = response.json()
|
||||||
|
assert "Unsupported OAuth provider" in data["detail"]
|
||||||
|
|||||||
@@ -1,151 +0,0 @@
|
|||||||
"""Tests for OAuth authentication endpoints."""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from httpx import AsyncClient
|
|
||||||
|
|
||||||
from app.services.oauth import OAuthUserInfo
|
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthEndpoints:
|
|
||||||
"""Test OAuth API endpoints."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_oauth_providers(self, test_client: AsyncClient) -> None:
|
|
||||||
"""Test getting list of OAuth providers."""
|
|
||||||
response = await test_client.get("/api/v1/oauth/providers")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "providers" in data
|
|
||||||
assert "google" in data["providers"]
|
|
||||||
assert "github" in data["providers"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None:
|
|
||||||
"""Test OAuth authorization URL generation for Google."""
|
|
||||||
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
|
||||||
mock_state.return_value = "test_state_123"
|
|
||||||
|
|
||||||
response = await test_client.get("/api/v1/oauth/google/authorize")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "authorization_url" in data
|
|
||||||
assert "state" in data
|
|
||||||
assert data["state"] == "test_state_123"
|
|
||||||
assert "accounts.google.com" in data["authorization_url"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None:
|
|
||||||
"""Test OAuth authorization URL generation for GitHub."""
|
|
||||||
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
|
||||||
mock_state.return_value = "test_state_456"
|
|
||||||
|
|
||||||
response = await test_client.get("/api/v1/oauth/github/authorize")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "authorization_url" in data
|
|
||||||
assert "state" in data
|
|
||||||
assert data["state"] == "test_state_456"
|
|
||||||
assert "github.com" in data["authorization_url"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_authorize_invalid_provider(
|
|
||||||
self, test_client: AsyncClient
|
|
||||||
) -> None:
|
|
||||||
"""Test OAuth authorization with invalid provider."""
|
|
||||||
response = await test_client.get("/api/v1/oauth/invalid/authorize")
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
data = response.json()
|
|
||||||
assert "Unsupported OAuth provider" in data["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_callback_new_user(
|
|
||||||
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any]
|
|
||||||
) -> None:
|
|
||||||
"""Test OAuth callback for new user creation."""
|
|
||||||
# Mock OAuth user info
|
|
||||||
mock_user_info = OAuthUserInfo(
|
|
||||||
provider="google",
|
|
||||||
provider_user_id="google_123",
|
|
||||||
email="newuser@gmail.com",
|
|
||||||
name="New User",
|
|
||||||
picture="https://example.com/avatar.jpg",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
|
||||||
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
|
||||||
mock_callback.return_value = mock_user_info
|
|
||||||
|
|
||||||
response = await test_client.get(
|
|
||||||
"/api/v1/oauth/google/callback",
|
|
||||||
params={"code": "auth_code_123", "state": "test_state"},
|
|
||||||
follow_redirects=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# OAuth callback should successfully process and redirect to frontend
|
|
||||||
assert response.status_code == 302
|
|
||||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
|
||||||
|
|
||||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
|
||||||
# Detailed cookie testing can be done in integration tests
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_callback_existing_user_link(
|
|
||||||
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any]
|
|
||||||
) -> None:
|
|
||||||
"""Test OAuth callback for linking to existing user."""
|
|
||||||
# Mock OAuth user info with same email as test user
|
|
||||||
mock_user_info = OAuthUserInfo(
|
|
||||||
provider="github",
|
|
||||||
provider_user_id="github_456",
|
|
||||||
email=test_user.email, # Same email as existing user
|
|
||||||
name="Test User",
|
|
||||||
picture="https://github.com/avatar.jpg",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
|
||||||
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
|
||||||
mock_callback.return_value = mock_user_info
|
|
||||||
|
|
||||||
response = await test_client.get(
|
|
||||||
"/api/v1/oauth/github/callback",
|
|
||||||
params={"code": "auth_code_456", "state": "test_state"},
|
|
||||||
follow_redirects=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# OAuth callback should successfully process and redirect to frontend
|
|
||||||
assert response.status_code == 302
|
|
||||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
|
||||||
|
|
||||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
|
||||||
# Detailed cookie testing can be done in integration tests
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None:
|
|
||||||
"""Test OAuth callback with missing authorization code."""
|
|
||||||
response = await test_client.get(
|
|
||||||
"/api/v1/oauth/google/callback",
|
|
||||||
params={"state": "test_state"}, # Missing code parameter
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 422 # Validation error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_oauth_callback_invalid_provider(
|
|
||||||
self, test_client: AsyncClient
|
|
||||||
) -> None:
|
|
||||||
"""Test OAuth callback with invalid provider."""
|
|
||||||
response = await test_client.get(
|
|
||||||
"/api/v1/oauth/invalid/callback",
|
|
||||||
params={"code": "auth_code_123", "state": "test_state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
data = response.json()
|
|
||||||
assert "Unsupported OAuth provider" in data["detail"]
|
|
||||||
Reference in New Issue
Block a user