diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py index c3d5e78..1d42966 100644 --- a/app/api/v1/__init__.py +++ b/app/api/v1/__init__.py @@ -2,7 +2,7 @@ from fastapi import APIRouter -from app.api.v1 import auth, main, oauth +from app.api.v1 import auth, main # V1 API router with v1 prefix api_router = APIRouter(prefix="/v1") @@ -10,4 +10,3 @@ api_router = APIRouter(prefix="/v1") # Include all route modules api_router.include_router(main.router, tags=["main"]) api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) -api_router.include_router(oauth.router, prefix="/oauth", tags=["oauth"]) diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index fd21528..8a81151 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -2,20 +2,27 @@ from typing import Annotated -from fastapi import APIRouter, Cookie, Depends, HTTPException, Response, status +from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status +from fastapi.responses import RedirectResponse from app.core.config import settings -from app.core.dependencies import get_auth_service, get_current_active_user +from app.core.dependencies import ( + get_auth_service, + get_current_active_user, + get_oauth_service, +) from app.core.logging import get_logger from app.models.user import User from app.schemas.auth import UserLoginRequest, UserRegisterRequest, UserResponse from app.services.auth import AuthService +from app.services.oauth import OAuthService from app.utils.auth import JWTUtils router = APIRouter() logger = get_logger(__name__) +# Authentication endpoints @router.post( "/register", status_code=status.HTTP_201_CREATED, @@ -224,3 +231,100 @@ async def logout( ) return {"message": "Successfully logged out"} + + +# OAuth2 endpoints +@router.get("/{provider}/authorize") +async def oauth_authorize( + provider: str, + oauth_service: Annotated[OAuthService, Depends(get_oauth_service)], +) -> dict[str, str]: + """Get OAuth authorization URL.""" + try: + # Generate secure state parameter + state = oauth_service.generate_state() + + # Get authorization URL + auth_url = oauth_service.get_authorization_url(provider, state) + + except HTTPException: + raise + except Exception as e: + logger.exception("OAuth authorization failed for provider: %s", provider) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth authorization failed", + ) from e + else: + return { + "authorization_url": auth_url, + "state": state, + } + + +@router.get("/{provider}/callback") +async def oauth_callback( + provider: str, + response: Response, + code: Annotated[str, Query()], + oauth_service: Annotated[OAuthService, Depends(get_oauth_service)], + auth_service: Annotated[AuthService, Depends(get_auth_service)], +) -> RedirectResponse: + """Handle OAuth callback.""" + try: + # Handle OAuth callback and get user info + oauth_user_info = await oauth_service.handle_callback(provider, code) + + # Perform OAuth login (link or create user) + auth_response = await auth_service.oauth_login(oauth_user_info) + + # Create and store refresh token + user = await auth_service.get_current_user(auth_response.user.id) + refresh_token = await auth_service.create_and_store_refresh_token(user) + + # Set HTTP-only cookies for both tokens + response.set_cookie( + key="access_token", + value=auth_response.token.access_token, + max_age=auth_response.token.expires_in, + httponly=True, + secure=settings.COOKIE_SECURE, + samesite=settings.COOKIE_SAMESITE, + ) + response.set_cookie( + key="refresh_token", + value=refresh_token, + max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, + httponly=True, + secure=settings.COOKIE_SECURE, + samesite=settings.COOKIE_SAMESITE, + ) + + logger.info( + "OAuth login successful for user: %s via %s", + auth_response.user.email, + provider, + ) + + # Redirect back to frontend after successful authentication + return RedirectResponse( + url="http://localhost:8001/?auth=success", + status_code=302, + ) + + except HTTPException: + raise + except Exception as e: + logger.exception("OAuth callback failed for provider: %s", provider) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth callback failed", + ) from e + + +@router.get("/providers") +async def get_oauth_providers() -> dict[str, list[str]]: + """Get list of available OAuth providers.""" + return { + "providers": ["google", "github"], + } diff --git a/app/api/v1/oauth.py b/app/api/v1/oauth.py deleted file mode 100644 index 8fe7508..0000000 --- a/app/api/v1/oauth.py +++ /dev/null @@ -1,111 +0,0 @@ -"""OAuth2 authentication endpoints.""" - -from typing import Annotated - -from fastapi import APIRouter, Depends, HTTPException, Query, Response, status -from fastapi.responses import RedirectResponse - -from app.core.config import settings -from app.core.dependencies import get_auth_service, get_oauth_service -from app.core.logging import get_logger -from app.services.auth import AuthService -from app.services.oauth import OAuthService - -router = APIRouter() -logger = get_logger(__name__) - - -@router.get("/{provider}/authorize") -async def oauth_authorize( - provider: str, - oauth_service: Annotated[OAuthService, Depends(get_oauth_service)], -) -> dict[str, str]: - """Get OAuth authorization URL.""" - try: - # Generate secure state parameter - state = oauth_service.generate_state() - - # Get authorization URL - auth_url = oauth_service.get_authorization_url(provider, state) - - except HTTPException: - raise - except Exception as e: - logger.exception("OAuth authorization failed for provider: %s", provider) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth authorization failed", - ) from e - else: - return { - "authorization_url": auth_url, - "state": state, - } - - -@router.get("/{provider}/callback") -async def oauth_callback( - provider: str, - response: Response, - code: Annotated[str, Query()], - oauth_service: Annotated[OAuthService, Depends(get_oauth_service)], - auth_service: Annotated[AuthService, Depends(get_auth_service)], -) -> RedirectResponse: - """Handle OAuth callback.""" - try: - # Handle OAuth callback and get user info - oauth_user_info = await oauth_service.handle_callback(provider, code) - - # Perform OAuth login (link or create user) - auth_response = await auth_service.oauth_login(oauth_user_info) - - # Create and store refresh token - user = await auth_service.get_current_user(auth_response.user.id) - refresh_token = await auth_service.create_and_store_refresh_token(user) - - # Set HTTP-only cookies for both tokens - response.set_cookie( - key="access_token", - value=auth_response.token.access_token, - max_age=auth_response.token.expires_in, - httponly=True, - secure=settings.COOKIE_SECURE, - samesite=settings.COOKIE_SAMESITE, - ) - response.set_cookie( - key="refresh_token", - value=refresh_token, - max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, - httponly=True, - secure=settings.COOKIE_SECURE, - samesite=settings.COOKIE_SAMESITE, - ) - - logger.info( - "OAuth login successful for user: %s via %s", - auth_response.user.email, - provider, - ) - - # Redirect back to frontend after successful authentication - return RedirectResponse( - url="http://localhost:8001/?auth=success", - status_code=302, - ) - - except HTTPException: - raise - except Exception as e: - logger.exception("OAuth callback failed for provider: %s", provider) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="OAuth callback failed", - ) from e - - -@router.get("/providers") -async def get_oauth_providers() -> dict[str, list[str]]: - """Get list of available OAuth providers.""" - return { - "providers": ["google", "github"], - } diff --git a/tests/api/v1/test_auth_endpoints.py b/tests/api/v1/test_auth_endpoints.py index 8e32642..dbe5f7b 100644 --- a/tests/api/v1/test_auth_endpoints.py +++ b/tests/api/v1/test_auth_endpoints.py @@ -1,6 +1,7 @@ """Tests for authentication endpoints.""" from typing import Any +from unittest.mock import patch import pytest import pytest_asyncio @@ -8,6 +9,7 @@ from httpx import AsyncClient from app.models.plan import Plan from app.models.user import User +from app.services.auth import OAuthUserInfo from app.utils.auth import JWTUtils @@ -307,3 +309,141 @@ class TestAuthEndpoints: # Test that get_admin_user passes for admin user result = await get_admin_user(admin_user) assert result == admin_user + + @pytest.mark.asyncio + async def test_get_oauth_providers(self, test_client: AsyncClient) -> None: + """Test getting list of OAuth providers.""" + response = await test_client.get("/api/v1/auth/providers") + + assert response.status_code == 200 + data = response.json() + assert "providers" in data + assert "google" in data["providers"] + assert "github" in data["providers"] + + @pytest.mark.asyncio + async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None: + """Test OAuth authorization URL generation for Google.""" + with patch("app.services.oauth.OAuthService.generate_state") as mock_state: + mock_state.return_value = "test_state_123" + + response = await test_client.get("/api/v1/auth/google/authorize") + + assert response.status_code == 200 + data = response.json() + assert "authorization_url" in data + assert "state" in data + assert data["state"] == "test_state_123" + assert "accounts.google.com" in data["authorization_url"] + + @pytest.mark.asyncio + async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None: + """Test OAuth authorization URL generation for GitHub.""" + with patch("app.services.oauth.OAuthService.generate_state") as mock_state: + mock_state.return_value = "test_state_456" + + response = await test_client.get("/api/v1/auth/github/authorize") + + assert response.status_code == 200 + data = response.json() + assert "authorization_url" in data + assert "state" in data + assert data["state"] == "test_state_456" + assert "github.com" in data["authorization_url"] + + @pytest.mark.asyncio + async def test_oauth_authorize_invalid_provider( + self, test_client: AsyncClient + ) -> None: + """Test OAuth authorization with invalid provider.""" + response = await test_client.get("/api/v1/auth/invalid/authorize") + + assert response.status_code == 400 + data = response.json() + assert "Unsupported OAuth provider" in data["detail"] + + @pytest.mark.asyncio + async def test_oauth_callback_new_user( + self, test_client: AsyncClient, ensure_plans: tuple[Any, Any] + ) -> None: + """Test OAuth callback for new user creation.""" + # Mock OAuth user info + mock_user_info = OAuthUserInfo( + provider="google", + provider_user_id="google_123", + email="newuser@gmail.com", + name="New User", + picture="https://example.com/avatar.jpg", + ) + + # Mock the entire handle_callback method to avoid actual OAuth API calls + with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback: + mock_callback.return_value = mock_user_info + + response = await test_client.get( + "/api/v1/auth/google/callback", + params={"code": "auth_code_123", "state": "test_state"}, + follow_redirects=False, + ) + + # OAuth callback should successfully process and redirect to frontend + assert response.status_code == 302 + assert response.headers["location"] == "http://localhost:8001/?auth=success" + + # The fact that we get a 302 redirect means the OAuth login was successful + # Detailed cookie testing can be done in integration tests + + @pytest.mark.asyncio + async def test_oauth_callback_existing_user_link( + self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any] + ) -> None: + """Test OAuth callback for linking to existing user.""" + # Mock OAuth user info with same email as test user + mock_user_info = OAuthUserInfo( + provider="github", + provider_user_id="github_456", + email=test_user.email, # Same email as existing user + name="Test User", + picture="https://github.com/avatar.jpg", + ) + + # Mock the entire handle_callback method to avoid actual OAuth API calls + with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback: + mock_callback.return_value = mock_user_info + + response = await test_client.get( + "/api/v1/auth/github/callback", + params={"code": "auth_code_456", "state": "test_state"}, + follow_redirects=False, + ) + + # OAuth callback should successfully process and redirect to frontend + assert response.status_code == 302 + assert response.headers["location"] == "http://localhost:8001/?auth=success" + + # The fact that we get a 302 redirect means the OAuth login was successful + # Detailed cookie testing can be done in integration tests + + @pytest.mark.asyncio + async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None: + """Test OAuth callback with missing authorization code.""" + response = await test_client.get( + "/api/v1/auth/google/callback", + params={"state": "test_state"}, # Missing code parameter + ) + + assert response.status_code == 422 # Validation error + + @pytest.mark.asyncio + async def test_oauth_callback_invalid_provider( + self, test_client: AsyncClient + ) -> None: + """Test OAuth callback with invalid provider.""" + response = await test_client.get( + "/api/v1/auth/invalid/callback", + params={"code": "auth_code_123", "state": "test_state"}, + ) + + assert response.status_code == 400 + data = response.json() + assert "Unsupported OAuth provider" in data["detail"] diff --git a/tests/api/v1/test_oauth_endpoints.py b/tests/api/v1/test_oauth_endpoints.py deleted file mode 100644 index 96301fc..0000000 --- a/tests/api/v1/test_oauth_endpoints.py +++ /dev/null @@ -1,151 +0,0 @@ -"""Tests for OAuth authentication endpoints.""" - -from typing import Any -from unittest.mock import AsyncMock, patch - -import pytest -from httpx import AsyncClient - -from app.services.oauth import OAuthUserInfo - - -class TestOAuthEndpoints: - """Test OAuth API endpoints.""" - - @pytest.mark.asyncio - async def test_get_oauth_providers(self, test_client: AsyncClient) -> None: - """Test getting list of OAuth providers.""" - response = await test_client.get("/api/v1/oauth/providers") - - assert response.status_code == 200 - data = response.json() - assert "providers" in data - assert "google" in data["providers"] - assert "github" in data["providers"] - - @pytest.mark.asyncio - async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None: - """Test OAuth authorization URL generation for Google.""" - with patch("app.services.oauth.OAuthService.generate_state") as mock_state: - mock_state.return_value = "test_state_123" - - response = await test_client.get("/api/v1/oauth/google/authorize") - - assert response.status_code == 200 - data = response.json() - assert "authorization_url" in data - assert "state" in data - assert data["state"] == "test_state_123" - assert "accounts.google.com" in data["authorization_url"] - - @pytest.mark.asyncio - async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None: - """Test OAuth authorization URL generation for GitHub.""" - with patch("app.services.oauth.OAuthService.generate_state") as mock_state: - mock_state.return_value = "test_state_456" - - response = await test_client.get("/api/v1/oauth/github/authorize") - - assert response.status_code == 200 - data = response.json() - assert "authorization_url" in data - assert "state" in data - assert data["state"] == "test_state_456" - assert "github.com" in data["authorization_url"] - - @pytest.mark.asyncio - async def test_oauth_authorize_invalid_provider( - self, test_client: AsyncClient - ) -> None: - """Test OAuth authorization with invalid provider.""" - response = await test_client.get("/api/v1/oauth/invalid/authorize") - - assert response.status_code == 400 - data = response.json() - assert "Unsupported OAuth provider" in data["detail"] - - @pytest.mark.asyncio - async def test_oauth_callback_new_user( - self, test_client: AsyncClient, ensure_plans: tuple[Any, Any] - ) -> None: - """Test OAuth callback for new user creation.""" - # Mock OAuth user info - mock_user_info = OAuthUserInfo( - provider="google", - provider_user_id="google_123", - email="newuser@gmail.com", - name="New User", - picture="https://example.com/avatar.jpg", - ) - - # Mock the entire handle_callback method to avoid actual OAuth API calls - with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback: - mock_callback.return_value = mock_user_info - - response = await test_client.get( - "/api/v1/oauth/google/callback", - params={"code": "auth_code_123", "state": "test_state"}, - follow_redirects=False, - ) - - # OAuth callback should successfully process and redirect to frontend - assert response.status_code == 302 - assert response.headers["location"] == "http://localhost:8001/?auth=success" - - # The fact that we get a 302 redirect means the OAuth login was successful - # Detailed cookie testing can be done in integration tests - - @pytest.mark.asyncio - async def test_oauth_callback_existing_user_link( - self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any] - ) -> None: - """Test OAuth callback for linking to existing user.""" - # Mock OAuth user info with same email as test user - mock_user_info = OAuthUserInfo( - provider="github", - provider_user_id="github_456", - email=test_user.email, # Same email as existing user - name="Test User", - picture="https://github.com/avatar.jpg", - ) - - # Mock the entire handle_callback method to avoid actual OAuth API calls - with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback: - mock_callback.return_value = mock_user_info - - response = await test_client.get( - "/api/v1/oauth/github/callback", - params={"code": "auth_code_456", "state": "test_state"}, - follow_redirects=False, - ) - - # OAuth callback should successfully process and redirect to frontend - assert response.status_code == 302 - assert response.headers["location"] == "http://localhost:8001/?auth=success" - - # The fact that we get a 302 redirect means the OAuth login was successful - # Detailed cookie testing can be done in integration tests - - @pytest.mark.asyncio - async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None: - """Test OAuth callback with missing authorization code.""" - response = await test_client.get( - "/api/v1/oauth/google/callback", - params={"state": "test_state"}, # Missing code parameter - ) - - assert response.status_code == 422 # Validation error - - @pytest.mark.asyncio - async def test_oauth_callback_invalid_provider( - self, test_client: AsyncClient - ) -> None: - """Test OAuth callback with invalid provider.""" - response = await test_client.get( - "/api/v1/oauth/invalid/callback", - params={"code": "auth_code_123", "state": "test_state"}, - ) - - assert response.status_code == 400 - data = response.json() - assert "Unsupported OAuth provider" in data["detail"] diff --git a/tests/services/test_oauth_service.py b/tests/services/test_oauth_service.py index 53b792e..be5c3bd 100644 --- a/tests/services/test_oauth_service.py +++ b/tests/services/test_oauth_service.py @@ -35,7 +35,7 @@ class TestOAuthService: assert isinstance(state, str) assert len(state) > 10 # Should be a reasonable length - + # Generate another to ensure they're different state2 = oauth_service.generate_state() assert state != state2 @@ -60,7 +60,7 @@ class TestOAuthService: with pytest.raises(Exception) as exc_info: oauth_service.get_provider("invalid") - + assert "Unsupported OAuth provider" in str(exc_info.value) @pytest.mark.asyncio @@ -101,7 +101,7 @@ class TestGoogleOAuthProvider: state = "test_state" auth_url = provider.get_authorization_url(state) - + assert "accounts.google.com" in auth_url assert "client_id=test_client_id" in auth_url assert f"state={state}" in auth_url @@ -111,7 +111,7 @@ class TestGoogleOAuthProvider: async def test_get_user_info_success(self) -> None: """Test successful user info retrieval.""" provider = GoogleOAuthProvider("test_client_id", "test_secret") - + mock_response_data = { "id": "google_user_123", "email": "test@gmail.com", @@ -151,14 +151,14 @@ class TestGitHubOAuthProvider: async def test_get_user_info_success(self) -> None: """Test successful user info retrieval.""" provider = GitHubOAuthProvider("test_client_id", "test_secret") - + mock_user_data = { "id": 123456, "login": "testuser", "name": "Test User", "avatar_url": "https://github.com/avatar.jpg", } - + mock_emails_data = [ {"email": "test@example.com", "primary": True, "verified": True}, {"email": "secondary@example.com", "primary": False, "verified": True}, @@ -170,7 +170,7 @@ class TestGitHubOAuthProvider: mock_user_response.status_code = 200 mock_user_response.json.return_value = mock_user_data - # Mock emails response + # Mock emails response mock_emails_response = Mock() mock_emails_response.status_code = 200 mock_emails_response.json.return_value = mock_emails_data @@ -189,4 +189,4 @@ class TestGitHubOAuthProvider: assert user_info.provider_user_id == "123456" assert user_info.email == "test@example.com" assert user_info.name == "Test User" - assert user_info.picture == "https://github.com/avatar.jpg" \ No newline at end of file + assert user_info.picture == "https://github.com/avatar.jpg"