feat: Implement OAuth2 authentication with Google and GitHub
- Added OAuth2 endpoints for Google and GitHub authentication. - Created OAuth service to handle provider interactions and user info retrieval. - Implemented user OAuth repository for managing user OAuth links in the database. - Updated auth service to support linking existing users and creating new users via OAuth. - Added CORS middleware to allow frontend access. - Created tests for OAuth endpoints and service functionality. - Introduced environment configuration for OAuth client IDs and secrets. - Added logging for OAuth operations and error handling.
This commit is contained in:
151
tests/api/v1/test_oauth_endpoints.py
Normal file
151
tests/api/v1/test_oauth_endpoints.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for OAuth authentication endpoints."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.services.oauth import OAuthUserInfo
|
||||
|
||||
|
||||
class TestOAuthEndpoints:
|
||||
"""Test OAuth API endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_oauth_providers(self, test_client: AsyncClient) -> None:
|
||||
"""Test getting list of OAuth providers."""
|
||||
response = await test_client.get("/api/v1/oauth/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "providers" in data
|
||||
assert "google" in data["providers"]
|
||||
assert "github" in data["providers"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None:
|
||||
"""Test OAuth authorization URL generation for Google."""
|
||||
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||
mock_state.return_value = "test_state_123"
|
||||
|
||||
response = await test_client.get("/api/v1/oauth/google/authorize")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
assert "state" in data
|
||||
assert data["state"] == "test_state_123"
|
||||
assert "accounts.google.com" in data["authorization_url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None:
|
||||
"""Test OAuth authorization URL generation for GitHub."""
|
||||
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||
mock_state.return_value = "test_state_456"
|
||||
|
||||
response = await test_client.get("/api/v1/oauth/github/authorize")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
assert "state" in data
|
||||
assert data["state"] == "test_state_456"
|
||||
assert "github.com" in data["authorization_url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_invalid_provider(
|
||||
self, test_client: AsyncClient
|
||||
) -> None:
|
||||
"""Test OAuth authorization with invalid provider."""
|
||||
response = await test_client.get("/api/v1/oauth/invalid/authorize")
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert "Unsupported OAuth provider" in data["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_new_user(
|
||||
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any]
|
||||
) -> None:
|
||||
"""Test OAuth callback for new user creation."""
|
||||
# Mock OAuth user info
|
||||
mock_user_info = OAuthUserInfo(
|
||||
provider="google",
|
||||
provider_user_id="google_123",
|
||||
email="newuser@gmail.com",
|
||||
name="New User",
|
||||
picture="https://example.com/avatar.jpg",
|
||||
)
|
||||
|
||||
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||
mock_callback.return_value = mock_user_info
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/google/callback",
|
||||
params={"code": "auth_code_123", "state": "test_state"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
# OAuth callback should successfully process and redirect to frontend
|
||||
assert response.status_code == 302
|
||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||
|
||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||
# Detailed cookie testing can be done in integration tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_existing_user_link(
|
||||
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any]
|
||||
) -> None:
|
||||
"""Test OAuth callback for linking to existing user."""
|
||||
# Mock OAuth user info with same email as test user
|
||||
mock_user_info = OAuthUserInfo(
|
||||
provider="github",
|
||||
provider_user_id="github_456",
|
||||
email=test_user.email, # Same email as existing user
|
||||
name="Test User",
|
||||
picture="https://github.com/avatar.jpg",
|
||||
)
|
||||
|
||||
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||
mock_callback.return_value = mock_user_info
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/github/callback",
|
||||
params={"code": "auth_code_456", "state": "test_state"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
# OAuth callback should successfully process and redirect to frontend
|
||||
assert response.status_code == 302
|
||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||
|
||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||
# Detailed cookie testing can be done in integration tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None:
|
||||
"""Test OAuth callback with missing authorization code."""
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/google/callback",
|
||||
params={"state": "test_state"}, # Missing code parameter
|
||||
)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_invalid_provider(
|
||||
self, test_client: AsyncClient
|
||||
) -> None:
|
||||
"""Test OAuth callback with invalid provider."""
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/invalid/callback",
|
||||
params={"code": "auth_code_123", "state": "test_state"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert "Unsupported OAuth provider" in data["detail"]
|
||||
192
tests/services/test_oauth_service.py
Normal file
192
tests/services/test_oauth_service.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for OAuth service."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.services.oauth import (
|
||||
GitHubOAuthProvider,
|
||||
GoogleOAuthProvider,
|
||||
OAuthService,
|
||||
OAuthUserInfo,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthService:
|
||||
"""Test OAuth service functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_service_initialization(self, test_session: Any) -> None:
|
||||
"""Test OAuth service initialization."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
|
||||
assert "google" in oauth_service.providers
|
||||
assert "github" in oauth_service.providers
|
||||
assert isinstance(oauth_service.providers["google"], GoogleOAuthProvider)
|
||||
assert isinstance(oauth_service.providers["github"], GitHubOAuthProvider)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_state(self, test_session: Any) -> None:
|
||||
"""Test state generation."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
state = oauth_service.generate_state()
|
||||
|
||||
assert isinstance(state, str)
|
||||
assert len(state) > 10 # Should be a reasonable length
|
||||
|
||||
# Generate another to ensure they're different
|
||||
state2 = oauth_service.generate_state()
|
||||
assert state != state2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_valid(self, test_session: Any) -> None:
|
||||
"""Test getting valid OAuth provider."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
|
||||
google_provider = oauth_service.get_provider("google")
|
||||
assert isinstance(google_provider, GoogleOAuthProvider)
|
||||
assert google_provider.provider_name == "google"
|
||||
|
||||
github_provider = oauth_service.get_provider("github")
|
||||
assert isinstance(github_provider, GitHubOAuthProvider)
|
||||
assert github_provider.provider_name == "github"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_invalid(self, test_session: Any) -> None:
|
||||
"""Test getting invalid OAuth provider."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
oauth_service.get_provider("invalid")
|
||||
|
||||
assert "Unsupported OAuth provider" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authorization_url(self, test_session: Any) -> None:
|
||||
"""Test authorization URL generation."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
state = "test_state_123"
|
||||
|
||||
# Test Google
|
||||
google_url = oauth_service.get_authorization_url("google", state)
|
||||
assert "accounts.google.com" in google_url
|
||||
assert "client_id=" in google_url
|
||||
assert f"state={state}" in google_url
|
||||
|
||||
# Test GitHub
|
||||
github_url = oauth_service.get_authorization_url("github", state)
|
||||
assert "github.com" in github_url
|
||||
assert "client_id=" in github_url
|
||||
assert f"state={state}" in github_url
|
||||
|
||||
|
||||
class TestGoogleOAuthProvider:
|
||||
"""Test Google OAuth provider."""
|
||||
|
||||
def test_provider_properties(self) -> None:
|
||||
"""Test Google provider properties."""
|
||||
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
assert provider.provider_name == "google"
|
||||
assert "accounts.google.com" in provider.authorization_url
|
||||
assert "oauth2.googleapis.com" in provider.token_url
|
||||
assert "googleapis.com" in provider.user_info_url
|
||||
assert "openid email profile" in provider.scope
|
||||
|
||||
def test_authorization_url_generation(self) -> None:
|
||||
"""Test authorization URL generation."""
|
||||
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||
state = "test_state"
|
||||
|
||||
auth_url = provider.get_authorization_url(state)
|
||||
|
||||
assert "accounts.google.com" in auth_url
|
||||
assert "client_id=test_client_id" in auth_url
|
||||
assert f"state={state}" in auth_url
|
||||
assert "scope=openid+email+profile" in auth_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info_success(self) -> None:
|
||||
"""Test successful user info retrieval."""
|
||||
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
mock_response_data = {
|
||||
"id": "google_user_123",
|
||||
"email": "test@gmail.com",
|
||||
"name": "Test User",
|
||||
"picture": "https://example.com/avatar.jpg",
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get") as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
user_info = await provider.get_user_info("test_access_token")
|
||||
|
||||
assert user_info.provider == "google"
|
||||
assert user_info.provider_user_id == "google_user_123"
|
||||
assert user_info.email == "test@gmail.com"
|
||||
assert user_info.name == "Test User"
|
||||
assert user_info.picture == "https://example.com/avatar.jpg"
|
||||
|
||||
|
||||
class TestGitHubOAuthProvider:
|
||||
"""Test GitHub OAuth provider."""
|
||||
|
||||
def test_provider_properties(self) -> None:
|
||||
"""Test GitHub provider properties."""
|
||||
provider = GitHubOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
assert provider.provider_name == "github"
|
||||
assert "github.com" in provider.authorization_url
|
||||
assert "github.com" in provider.token_url
|
||||
assert "api.github.com" in provider.user_info_url
|
||||
assert "user:email" in provider.scope
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info_success(self) -> None:
|
||||
"""Test successful user info retrieval."""
|
||||
provider = GitHubOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
mock_user_data = {
|
||||
"id": 123456,
|
||||
"login": "testuser",
|
||||
"name": "Test User",
|
||||
"avatar_url": "https://github.com/avatar.jpg",
|
||||
}
|
||||
|
||||
mock_emails_data = [
|
||||
{"email": "test@example.com", "primary": True, "verified": True},
|
||||
{"email": "secondary@example.com", "primary": False, "verified": True},
|
||||
]
|
||||
|
||||
with patch("httpx.AsyncClient.get") as mock_get:
|
||||
# Mock user profile response
|
||||
mock_user_response = Mock()
|
||||
mock_user_response.status_code = 200
|
||||
mock_user_response.json.return_value = mock_user_data
|
||||
|
||||
# Mock emails response
|
||||
mock_emails_response = Mock()
|
||||
mock_emails_response.status_code = 200
|
||||
mock_emails_response.json.return_value = mock_emails_data
|
||||
|
||||
# Return different responses based on URL
|
||||
def side_effect(url, **kwargs):
|
||||
if "user/emails" in str(url):
|
||||
return mock_emails_response
|
||||
return mock_user_response
|
||||
|
||||
mock_get.side_effect = side_effect
|
||||
|
||||
user_info = await provider.get_user_info("test_access_token")
|
||||
|
||||
assert user_info.provider == "github"
|
||||
assert user_info.provider_user_id == "123456"
|
||||
assert user_info.email == "test@example.com"
|
||||
assert user_info.name == "Test User"
|
||||
assert user_info.picture == "https://github.com/avatar.jpg"
|
||||
Reference in New Issue
Block a user