diff --git a/.env.template b/.env.template new file mode 100644 index 0000000..a32b612 --- /dev/null +++ b/.env.template @@ -0,0 +1,29 @@ +# Application Configuration +HOST=localhost +PORT=8000 +RELOAD=true + +# Database Configuration +DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db +DATABASE_ECHO=false + +# Logging Configuration +LOG_LEVEL=info +LOG_FILE=logs/app.log +LOG_MAX_SIZE=10485760 +LOG_BACKUP_COUNT=5 + +# JWT Configuration +JWT_SECRET_KEY=your-secret-key-change-in-production +JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15 +JWT_REFRESH_TOKEN_EXPIRE_DAYS=7 + +# Cookie Configuration +COOKIE_SECURE=false + +# OAuth2 Configuration +GOOGLE_CLIENT_ID= +GOOGLE_CLIENT_SECRET= +GITHUB_CLIENT_ID= +GITHUB_CLIENT_SECRET= +OAUTH_REDIRECT_URL=http://localhost:8001/auth/callback \ No newline at end of file diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py index 1d42966..c3d5e78 100644 --- a/app/api/v1/__init__.py +++ b/app/api/v1/__init__.py @@ -2,7 +2,7 @@ from fastapi import APIRouter -from app.api.v1 import auth, main +from app.api.v1 import auth, main, oauth # V1 API router with v1 prefix api_router = APIRouter(prefix="/v1") @@ -10,3 +10,4 @@ api_router = APIRouter(prefix="/v1") # Include all route modules api_router.include_router(main.router, tags=["main"]) api_router.include_router(auth.router, prefix="/auth", tags=["authentication"]) +api_router.include_router(oauth.router, prefix="/oauth", tags=["oauth"]) diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index f26da94..fd21528 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -54,8 +54,6 @@ async def register( samesite=settings.COOKIE_SAMESITE, ) - # Return only user data, tokens are now in cookies - return auth_response.user except HTTPException: raise except Exception as e: @@ -64,6 +62,8 @@ async def register( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Registration failed", ) from e + else: + return auth_response.user @router.post("/login") @@ -101,8 +101,6 @@ async def login( samesite=settings.COOKIE_SAMESITE, ) - # Return only user data, tokens are now in cookies - return auth_response.user except HTTPException: raise except Exception as e: @@ -111,6 +109,8 @@ async def login( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Login failed", ) from e + else: + return auth_response.user @router.get("/me") @@ -156,7 +156,6 @@ async def refresh_token( samesite=settings.COOKIE_SAMESITE, ) - return {"message": "Token refreshed successfully"} except HTTPException: raise except Exception as e: @@ -165,6 +164,8 @@ async def refresh_token( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Token refresh failed", ) from e + else: + return {"message": "Token refreshed successfully"} @router.post("/logout") @@ -176,7 +177,7 @@ async def logout( ) -> dict[str, str]: """Logout endpoint - clears cookies and revokes refresh token.""" user = None - + # Try to get user from access token first if access_token: try: @@ -188,7 +189,7 @@ async def logout( logger.info("Found user from access token: %s", user.email) except (HTTPException, Exception) as e: logger.info("Access token validation failed: %s", str(e)) - + # If no user found, try refresh token if not user and refresh_token: try: @@ -200,14 +201,14 @@ async def logout( logger.info("Found user from refresh token: %s", user.email) except (HTTPException, Exception) as e: logger.info("Refresh token validation failed: %s", str(e)) - + # If we found a user, revoke their refresh token if user: await auth_service.revoke_refresh_token(user) logger.info("Successfully revoked refresh token for user: %s", user.email) else: logger.info("No user found, skipping token revocation") - + # Always clear both cookies regardless of token validity response.delete_cookie( key="access_token", @@ -221,5 +222,5 @@ async def logout( secure=settings.COOKIE_SECURE, samesite=settings.COOKIE_SAMESITE, ) - + return {"message": "Successfully logged out"} diff --git a/app/api/v1/oauth.py b/app/api/v1/oauth.py new file mode 100644 index 0000000..8fe7508 --- /dev/null +++ b/app/api/v1/oauth.py @@ -0,0 +1,111 @@ +"""OAuth2 authentication endpoints.""" + +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, Query, Response, status +from fastapi.responses import RedirectResponse + +from app.core.config import settings +from app.core.dependencies import get_auth_service, get_oauth_service +from app.core.logging import get_logger +from app.services.auth import AuthService +from app.services.oauth import OAuthService + +router = APIRouter() +logger = get_logger(__name__) + + +@router.get("/{provider}/authorize") +async def oauth_authorize( + provider: str, + oauth_service: Annotated[OAuthService, Depends(get_oauth_service)], +) -> dict[str, str]: + """Get OAuth authorization URL.""" + try: + # Generate secure state parameter + state = oauth_service.generate_state() + + # Get authorization URL + auth_url = oauth_service.get_authorization_url(provider, state) + + except HTTPException: + raise + except Exception as e: + logger.exception("OAuth authorization failed for provider: %s", provider) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth authorization failed", + ) from e + else: + return { + "authorization_url": auth_url, + "state": state, + } + + +@router.get("/{provider}/callback") +async def oauth_callback( + provider: str, + response: Response, + code: Annotated[str, Query()], + oauth_service: Annotated[OAuthService, Depends(get_oauth_service)], + auth_service: Annotated[AuthService, Depends(get_auth_service)], +) -> RedirectResponse: + """Handle OAuth callback.""" + try: + # Handle OAuth callback and get user info + oauth_user_info = await oauth_service.handle_callback(provider, code) + + # Perform OAuth login (link or create user) + auth_response = await auth_service.oauth_login(oauth_user_info) + + # Create and store refresh token + user = await auth_service.get_current_user(auth_response.user.id) + refresh_token = await auth_service.create_and_store_refresh_token(user) + + # Set HTTP-only cookies for both tokens + response.set_cookie( + key="access_token", + value=auth_response.token.access_token, + max_age=auth_response.token.expires_in, + httponly=True, + secure=settings.COOKIE_SECURE, + samesite=settings.COOKIE_SAMESITE, + ) + response.set_cookie( + key="refresh_token", + value=refresh_token, + max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60, + httponly=True, + secure=settings.COOKIE_SECURE, + samesite=settings.COOKIE_SAMESITE, + ) + + logger.info( + "OAuth login successful for user: %s via %s", + auth_response.user.email, + provider, + ) + + # Redirect back to frontend after successful authentication + return RedirectResponse( + url="http://localhost:8001/?auth=success", + status_code=302, + ) + + except HTTPException: + raise + except Exception as e: + logger.exception("OAuth callback failed for provider: %s", provider) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="OAuth callback failed", + ) from e + + +@router.get("/providers") +async def get_oauth_providers() -> dict[str, list[str]]: + """Get list of available OAuth providers.""" + return { + "providers": ["google", "github"], + } diff --git a/app/core/config.py b/app/core/config.py index 378ec26..de03ce1 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -13,13 +13,16 @@ class Settings(BaseSettings): extra="ignore", ) + # Application Configuration HOST: str = "localhost" PORT: int = 8000 RELOAD: bool = True + # Database Configuration DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db" DATABASE_ECHO: bool = False + # Logging Configuration LOG_LEVEL: str = "info" LOG_FILE: str = "logs/app.log" LOG_MAX_SIZE: int = 10 * 1024 * 1024 @@ -31,12 +34,19 @@ class Settings(BaseSettings): "your-secret-key-change-in-production" # noqa: S105 default value if none set in .env ) JWT_ALGORITHM: str = "HS256" - JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # Shorter-lived access token - JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Longer-lived refresh token + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 + JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Cookie Configuration - COOKIE_SECURE: bool = True # Set to False for development without HTTPS + COOKIE_SECURE: bool = True COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax" + # OAuth2 Configuration + GOOGLE_CLIENT_ID: str = "" + GOOGLE_CLIENT_SECRET: str = "" + GITHUB_CLIENT_ID: str = "" + GITHUB_CLIENT_SECRET: str = "" + OAUTH_REDIRECT_URL: str = "http://localhost:8001/auth/callback" + settings = Settings() diff --git a/app/core/dependencies.py b/app/core/dependencies.py index e7c3781..4890839 100644 --- a/app/core/dependencies.py +++ b/app/core/dependencies.py @@ -1,6 +1,6 @@ """FastAPI dependencies.""" -from typing import Annotated, NoReturn, cast +from typing import Annotated, cast from fastapi import Cookie, Depends, HTTPException, status from sqlmodel.ext.asyncio.session import AsyncSession @@ -9,27 +9,12 @@ from app.core.database import get_db from app.core.logging import get_logger from app.models.user import User from app.services.auth import AuthService +from app.services.oauth import OAuthService from app.utils.auth import JWTUtils logger = get_logger(__name__) -def _raise_invalid_token_error() -> NoReturn: - """Raise an invalid token HTTP exception.""" - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token payload", - ) - - -def _raise_auth_error() -> NoReturn: - """Raise an authentication HTTP exception.""" - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - ) - - async def get_auth_service( session: Annotated[AsyncSession, Depends(get_db)], ) -> AuthService: @@ -37,6 +22,13 @@ async def get_auth_service( return AuthService(session) +async def get_oauth_service( + session: Annotated[AsyncSession, Depends(get_db)], +) -> OAuthService: + """Get the OAuth service.""" + return OAuthService(session) + + async def get_current_user( access_token: Annotated[str | None, Cookie()], auth_service: Annotated[AuthService, Depends(get_auth_service)], @@ -46,7 +38,10 @@ async def get_current_user( # Check if access token cookie exists if not access_token: logger.warning("No access token cookie found") - _raise_auth_error() + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + ) # Decode the JWT token payload = JWTUtils.decode_access_token(access_token) @@ -54,7 +49,10 @@ async def get_current_user( # Extract user ID from token user_id_str = payload.get("sub") if not user_id_str: - _raise_invalid_token_error() + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token payload", + ) # At this point user_id_str is guaranteed to be truthy, safe to cast user_id_str = cast("str", user_id_str) @@ -74,9 +72,12 @@ async def get_current_user( except HTTPException: # Re-raise HTTPExceptions without wrapping them raise - except Exception: + except Exception as e: logger.exception("Failed to authenticate user") - _raise_auth_error() + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + ) from e async def get_current_active_user( diff --git a/app/main.py b/app/main.py index 627bbe4..886a605 100644 --- a/app/main.py +++ b/app/main.py @@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from app.api import api_router from app.core.database import init_db @@ -28,6 +29,15 @@ def create_app() -> FastAPI: """Create and configure the FastAPI application.""" app = FastAPI(lifespan=lifespan) + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:8001"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + app.add_middleware(LoggingMiddleware) # Include API routes diff --git a/app/repositories/user_oauth.py b/app/repositories/user_oauth.py new file mode 100644 index 0000000..7bf76b6 --- /dev/null +++ b/app/repositories/user_oauth.py @@ -0,0 +1,117 @@ +"""Repository for user OAuth operations.""" + +from typing import Any + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.logging import get_logger +from app.models.user_oauth import UserOauth + +logger = get_logger(__name__) + + +class UserOauthRepository: + """Repository for user OAuth operations.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize repository with database session.""" + self.session = session + + async def get_by_provider_user_id( + self, + provider: str, + provider_user_id: str, + ) -> UserOauth | None: + """Get user OAuth by provider and provider user ID.""" + try: + statement = select(UserOauth).where( + UserOauth.provider == provider, + UserOauth.provider_user_id == provider_user_id, + ) + result = await self.session.exec(statement) + return result.first() + except Exception: + logger.exception( + "Failed to get user OAuth by provider user ID: %s:%s", + provider, + provider_user_id, + ) + raise + + async def get_by_user_id_and_provider( + self, + user_id: int, + provider: str, + ) -> UserOauth | None: + """Get user OAuth by user ID and provider.""" + try: + statement = select(UserOauth).where( + UserOauth.user_id == user_id, + UserOauth.provider == provider, + ) + result = await self.session.exec(statement) + except Exception: + logger.exception( + "Failed to get user OAuth by user ID and provider: %s:%s", + user_id, + provider, + ) + raise + else: + return result.first() + + async def create(self, oauth_data: dict[str, Any]) -> UserOauth: + """Create a new user OAuth record.""" + try: + oauth = UserOauth(**oauth_data) + self.session.add(oauth) + await self.session.commit() + await self.session.refresh(oauth) + logger.info( + "Created OAuth link for user %s with provider %s", + oauth.user_id, + oauth.provider, + ) + except Exception: + await self.session.rollback() + logger.exception("Failed to create user OAuth") + raise + else: + return oauth + + async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth: + """Update a user OAuth record.""" + try: + for key, value in update_data.items(): + setattr(oauth, key, value) + + self.session.add(oauth) + await self.session.commit() + await self.session.refresh(oauth) + logger.info( + "Updated OAuth link for user %s with provider %s", + oauth.user_id, + oauth.provider, + ) + except Exception: + await self.session.rollback() + logger.exception("Failed to update user OAuth") + raise + else: + return oauth + + async def delete(self, oauth: UserOauth) -> None: + """Delete a user OAuth record.""" + try: + await self.session.delete(oauth) + await self.session.commit() + logger.info( + "Deleted OAuth link for user %s with provider %s", + oauth.user_id, + oauth.provider, + ) + except Exception: + await self.session.rollback() + logger.exception("Failed to delete user OAuth") + raise diff --git a/app/services/auth.py b/app/services/auth.py index 14485e3..e003f19 100644 --- a/app/services/auth.py +++ b/app/services/auth.py @@ -10,6 +10,7 @@ from app.core.config import settings from app.core.logging import get_logger from app.models.user import User from app.repositories.user import UserRepository +from app.repositories.user_oauth import UserOauthRepository from app.schemas.auth import ( AuthResponse, TokenResponse, @@ -17,6 +18,7 @@ from app.schemas.auth import ( UserRegisterRequest, UserResponse, ) +from app.services.oauth import OAuthUserInfo from app.utils.auth import JWTUtils, PasswordUtils logger = get_logger(__name__) @@ -29,6 +31,7 @@ class AuthService: """Initialize the auth service.""" self.session = session self.user_repo = UserRepository(session) + self.oauth_repo = UserOauthRepository(session) async def register(self, request: UserRegisterRequest) -> AuthResponse: """Register a new user.""" @@ -203,7 +206,7 @@ class AuthService: # Check if refresh token is expired if user.refresh_token_expires_at and datetime.now( - UTC + UTC, ) > user.refresh_token_expires_at.replace(tzinfo=UTC): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -272,3 +275,127 @@ class AuthService: created_at=user.created_at, updated_at=user.updated_at, ) + + async def oauth_login(self, oauth_user_info: OAuthUserInfo) -> AuthResponse: + """Handle OAuth login - link or create user.""" + logger.info( + "OAuth login attempt for %s with provider %s", + oauth_user_info.email, + oauth_user_info.provider, + ) + + # Check if user already has OAuth link for this provider + existing_oauth = await self.oauth_repo.get_by_provider_user_id( + oauth_user_info.provider, + oauth_user_info.provider_user_id, + ) + + if existing_oauth: + # User exists with this OAuth provider, get the user + user = await self.user_repo.get_by_id(existing_oauth.user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Linked user not found", + ) + + # Refresh user to avoid greenlet issues + await self.session.refresh(user) + + # Update OAuth record with latest info + oauth_update_data = { + "email": oauth_user_info.email, + "name": oauth_user_info.name, + "picture": oauth_user_info.picture, + } + await self.oauth_repo.update(existing_oauth, oauth_update_data) + + logger.info( + "OAuth login successful for existing user: %s", + oauth_user_info.email, + ) + else: + # Check if user exists by email + user = await self.user_repo.get_by_email(oauth_user_info.email) + + if user: + # Refresh user to avoid greenlet issues + await self.session.refresh(user) + + # Store user picture value to avoid greenlet issues later + current_user_picture = user.picture + + # Link existing user to OAuth provider + oauth_data = { + "user_id": user.id, + "provider": oauth_user_info.provider, + "provider_user_id": oauth_user_info.provider_user_id, + "email": oauth_user_info.email, + "name": oauth_user_info.name, + "picture": oauth_user_info.picture, + } + await self.oauth_repo.create(oauth_data) + + # Update user profile with OAuth info if needed + user_update_data = {} + if not current_user_picture and oauth_user_info.picture: + user_update_data["picture"] = oauth_user_info.picture + if user_update_data: + await self.user_repo.update(user, user_update_data) + # Refresh user after update to avoid greenlet issues + await self.session.refresh(user) + + logger.info( + "Linked existing user %s to OAuth provider %s", + oauth_user_info.email, + oauth_user_info.provider, + ) + else: + # Create new user + user_data = { + "email": oauth_user_info.email, + "name": oauth_user_info.name, + "picture": oauth_user_info.picture, + "is_active": True, + # No password for OAuth users + "password_hash": None, + } + + user = await self.user_repo.create(user_data) + + # Create OAuth link + oauth_data = { + "user_id": user.id, + "provider": oauth_user_info.provider, + "provider_user_id": oauth_user_info.provider_user_id, + "email": oauth_user_info.email, + "name": oauth_user_info.name, + "picture": oauth_user_info.picture, + } + await self.oauth_repo.create(oauth_data) + + logger.info( + "Created new user %s from OAuth provider %s", + oauth_user_info.email, + oauth_user_info.provider, + ) + + # Refresh user to avoid greenlet issues and check if user is active + await self.session.refresh(user) + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Account is deactivated", + ) + + # Generate access token + token = self._create_access_token(user) + + # Create response + user_response = await self.create_user_response(user) + + logger.info( + "OAuth login completed successfully for user: %s", + oauth_user_info.email, + ) + return AuthResponse(user=user_response, token=token) diff --git a/app/services/oauth.py b/app/services/oauth.py new file mode 100644 index 0000000..74bbc6f --- /dev/null +++ b/app/services/oauth.py @@ -0,0 +1,329 @@ +"""OAuth2 service for external authentication providers.""" + +import secrets +from abc import ABC, abstractmethod +from urllib.parse import urlencode + +import httpx +from fastapi import HTTPException, status +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.core.config import settings +from app.core.logging import get_logger + +logger = get_logger(__name__) + + +class OAuthUserInfo: + """OAuth user information.""" + + def __init__( + self, + provider: str, + provider_user_id: str, + email: str, + name: str, + picture: str | None = None, + ) -> None: + """Initialize OAuth user info.""" + self.provider = provider + self.provider_user_id = provider_user_id + self.email = email + self.name = name + self.picture = picture + + +class OAuthProvider(ABC): + """Abstract base class for OAuth providers.""" + + def __init__(self, client_id: str, client_secret: str) -> None: + """Initialize OAuth provider with client ID and secret.""" + self.client_id = client_id + self.client_secret = client_secret + + @property + @abstractmethod + def provider_name(self) -> str: + """Return the provider name.""" + + @property + @abstractmethod + def authorization_url(self) -> str: + """Return the authorization URL.""" + + @property + @abstractmethod + def token_url(self) -> str: + """Return the token URL.""" + + @property + @abstractmethod + def user_info_url(self) -> str: + """Return the user info URL.""" + + @property + @abstractmethod + def scope(self) -> str: + """Return the required scopes.""" + + def get_authorization_url(self, state: str) -> str: + """Generate authorization URL with state parameter.""" + # Construct provider-specific redirect URI + redirect_uri = ( + f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback" + ) + + params = { + "client_id": self.client_id, + "redirect_uri": redirect_uri, + "scope": self.scope, + "response_type": "code", + "state": state, + } + return f"{self.authorization_url}?{urlencode(params)}" + + async def exchange_code_for_token(self, code: str) -> str: + """Exchange authorization code for access token.""" + # Construct provider-specific redirect URI (must match authorization request) + redirect_uri = ( + f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback" + ) + + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": redirect_uri, + } + + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=data, + headers={"Accept": "application/json"}, + ) + + if response.status_code != status.HTTP_200_OK: + logger.error("Failed to exchange code for token: %s", response.text) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to exchange authorization code", + ) + + token_data = response.json() + return str(token_data["access_token"]) + + @abstractmethod + async def get_user_info(self, access_token: str) -> OAuthUserInfo: + """Get user information from the provider.""" + + +class GoogleOAuthProvider(OAuthProvider): + """Google OAuth provider.""" + + @property + def provider_name(self) -> str: + """Return the provider name.""" + return "google" + + @property + def authorization_url(self) -> str: + """Return the authorization URL.""" + return "https://accounts.google.com/o/oauth2/v2/auth" + + @property + def token_url(self) -> str: + """Return the token URL.""" + return "https://oauth2.googleapis.com/token" + + @property + def user_info_url(self) -> str: + """Return the user info URL.""" + return "https://www.googleapis.com/oauth2/v2/userinfo" + + @property + def scope(self) -> str: + """Return the required scopes.""" + return "openid email profile" + + async def exchange_code_for_token(self, code: str) -> str: + """Exchange authorization code for access token.""" + # Construct provider-specific redirect URI (must match authorization request) + redirect_uri = ( + f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback" + ) + + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + } + + async with httpx.AsyncClient() as client: + response = await client.post( + self.token_url, + data=data, + headers={"Accept": "application/json"}, + ) + + if response.status_code != status.HTTP_200_OK: + logger.error("Failed to exchange code for token: %s", response.text) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to exchange authorization code", + ) + + token_data = response.json() + return str(token_data["access_token"]) + + async def get_user_info(self, access_token: str) -> OAuthUserInfo: + """Get user information from Google.""" + async with httpx.AsyncClient() as client: + response = await client.get( + self.user_info_url, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + if response.status_code != status.HTTP_200_OK: + logger.error("Failed to get user info: %s", response.text) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to get user information", + ) + + user_data = response.json() + return OAuthUserInfo( + provider=self.provider_name, + provider_user_id=user_data["id"], + email=user_data["email"], + name=user_data["name"], + picture=user_data.get("picture"), + ) + + +class GitHubOAuthProvider(OAuthProvider): + """GitHub OAuth provider.""" + + @property + def provider_name(self) -> str: + """Return the provider name.""" + return "github" + + @property + def authorization_url(self) -> str: + """Return the authorization URL.""" + return "https://github.com/login/oauth/authorize" + + @property + def token_url(self) -> str: + """Return the token URL.""" + return "https://github.com/login/oauth/access_token" + + @property + def user_info_url(self) -> str: + """Return the user info URL.""" + return "https://api.github.com/user" + + @property + def scope(self) -> str: + """Return the required scopes.""" + return "user:email" + + async def get_user_info(self, access_token: str) -> OAuthUserInfo: + """Get user information from GitHub.""" + async with httpx.AsyncClient() as client: + # Get user profile + user_response = await client.get( + self.user_info_url, + headers={"Authorization": f"Bearer {access_token}"}, + ) + + if user_response.status_code != status.HTTP_200_OK: + logger.error("Failed to get user info: %s", user_response.text) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to get user information", + ) + + user_data = user_response.json() + + # Get user email (GitHub doesn't include email in basic profile) + email_response = await client.get( + "https://api.github.com/user/emails", + headers={"Authorization": f"Bearer {access_token}"}, + ) + + if email_response.status_code != status.HTTP_200_OK: + logger.error("Failed to get user emails: %s", email_response.text) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Failed to get user email", + ) + + emails = email_response.json() + primary_email = next( + (email["email"] for email in emails if email["primary"]), + None, + ) + + if not primary_email: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No primary email found", + ) + + return OAuthUserInfo( + provider=self.provider_name, + provider_user_id=str(user_data["id"]), + email=primary_email, + name=user_data.get("name") or user_data["login"], + picture=user_data.get("avatar_url"), + ) + + +class OAuthService: + """Service for handling OAuth authentication.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize OAuth service with database session.""" + self.session = session + self.providers = { + "google": GoogleOAuthProvider( + settings.GOOGLE_CLIENT_ID, + settings.GOOGLE_CLIENT_SECRET, + ), + "github": GitHubOAuthProvider( + settings.GITHUB_CLIENT_ID, + settings.GITHUB_CLIENT_SECRET, + ), + } + + def get_provider(self, provider_name: str) -> OAuthProvider: + """Get OAuth provider by name.""" + if provider_name not in self.providers: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported OAuth provider: {provider_name}", + ) + return self.providers[provider_name] + + def generate_state(self) -> str: + """Generate a secure state parameter for OAuth flow.""" + return secrets.token_urlsafe(32) + + def get_authorization_url(self, provider_name: str, state: str) -> str: + """Get authorization URL for the specified provider.""" + provider = self.get_provider(provider_name) + return provider.get_authorization_url(state) + + async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo: + """Handle OAuth callback and return user info.""" + provider = self.get_provider(provider_name) + + # Exchange code for access token + access_token = await provider.exchange_code_for_token(code) + + # Get user information + return await provider.get_user_info(access_token) diff --git a/pyproject.toml b/pyproject.toml index dea2581..42a00f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ dependencies = [ "bcrypt==4.3.0", "email-validator==2.2.0", "fastapi[standard]==0.116.1", + "httpx==0.28.1", "pydantic-settings==2.10.1", "pyjwt==2.10.1", "sqlmodel==0.0.24", @@ -36,7 +37,7 @@ exclude = ["alembic"] [tool.ruff.lint] select = ["ALL"] -ignore = ["D100", "D103"] +ignore = ["D100", "D103", "TRY301"] [tool.ruff.per-file-ignores] "tests/**/*.py" = ["S101", "S105"] diff --git a/tests/api/v1/test_oauth_endpoints.py b/tests/api/v1/test_oauth_endpoints.py new file mode 100644 index 0000000..96301fc --- /dev/null +++ b/tests/api/v1/test_oauth_endpoints.py @@ -0,0 +1,151 @@ +"""Tests for OAuth authentication endpoints.""" + +from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import AsyncClient + +from app.services.oauth import OAuthUserInfo + + +class TestOAuthEndpoints: + """Test OAuth API endpoints.""" + + @pytest.mark.asyncio + async def test_get_oauth_providers(self, test_client: AsyncClient) -> None: + """Test getting list of OAuth providers.""" + response = await test_client.get("/api/v1/oauth/providers") + + assert response.status_code == 200 + data = response.json() + assert "providers" in data + assert "google" in data["providers"] + assert "github" in data["providers"] + + @pytest.mark.asyncio + async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None: + """Test OAuth authorization URL generation for Google.""" + with patch("app.services.oauth.OAuthService.generate_state") as mock_state: + mock_state.return_value = "test_state_123" + + response = await test_client.get("/api/v1/oauth/google/authorize") + + assert response.status_code == 200 + data = response.json() + assert "authorization_url" in data + assert "state" in data + assert data["state"] == "test_state_123" + assert "accounts.google.com" in data["authorization_url"] + + @pytest.mark.asyncio + async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None: + """Test OAuth authorization URL generation for GitHub.""" + with patch("app.services.oauth.OAuthService.generate_state") as mock_state: + mock_state.return_value = "test_state_456" + + response = await test_client.get("/api/v1/oauth/github/authorize") + + assert response.status_code == 200 + data = response.json() + assert "authorization_url" in data + assert "state" in data + assert data["state"] == "test_state_456" + assert "github.com" in data["authorization_url"] + + @pytest.mark.asyncio + async def test_oauth_authorize_invalid_provider( + self, test_client: AsyncClient + ) -> None: + """Test OAuth authorization with invalid provider.""" + response = await test_client.get("/api/v1/oauth/invalid/authorize") + + assert response.status_code == 400 + data = response.json() + assert "Unsupported OAuth provider" in data["detail"] + + @pytest.mark.asyncio + async def test_oauth_callback_new_user( + self, test_client: AsyncClient, ensure_plans: tuple[Any, Any] + ) -> None: + """Test OAuth callback for new user creation.""" + # Mock OAuth user info + mock_user_info = OAuthUserInfo( + provider="google", + provider_user_id="google_123", + email="newuser@gmail.com", + name="New User", + picture="https://example.com/avatar.jpg", + ) + + # Mock the entire handle_callback method to avoid actual OAuth API calls + with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback: + mock_callback.return_value = mock_user_info + + response = await test_client.get( + "/api/v1/oauth/google/callback", + params={"code": "auth_code_123", "state": "test_state"}, + follow_redirects=False, + ) + + # OAuth callback should successfully process and redirect to frontend + assert response.status_code == 302 + assert response.headers["location"] == "http://localhost:8001/?auth=success" + + # The fact that we get a 302 redirect means the OAuth login was successful + # Detailed cookie testing can be done in integration tests + + @pytest.mark.asyncio + async def test_oauth_callback_existing_user_link( + self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any] + ) -> None: + """Test OAuth callback for linking to existing user.""" + # Mock OAuth user info with same email as test user + mock_user_info = OAuthUserInfo( + provider="github", + provider_user_id="github_456", + email=test_user.email, # Same email as existing user + name="Test User", + picture="https://github.com/avatar.jpg", + ) + + # Mock the entire handle_callback method to avoid actual OAuth API calls + with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback: + mock_callback.return_value = mock_user_info + + response = await test_client.get( + "/api/v1/oauth/github/callback", + params={"code": "auth_code_456", "state": "test_state"}, + follow_redirects=False, + ) + + # OAuth callback should successfully process and redirect to frontend + assert response.status_code == 302 + assert response.headers["location"] == "http://localhost:8001/?auth=success" + + # The fact that we get a 302 redirect means the OAuth login was successful + # Detailed cookie testing can be done in integration tests + + @pytest.mark.asyncio + async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None: + """Test OAuth callback with missing authorization code.""" + response = await test_client.get( + "/api/v1/oauth/google/callback", + params={"state": "test_state"}, # Missing code parameter + ) + + assert response.status_code == 422 # Validation error + + @pytest.mark.asyncio + async def test_oauth_callback_invalid_provider( + self, test_client: AsyncClient + ) -> None: + """Test OAuth callback with invalid provider.""" + response = await test_client.get( + "/api/v1/oauth/invalid/callback", + params={"code": "auth_code_123", "state": "test_state"}, + ) + + assert response.status_code == 400 + data = response.json() + assert "Unsupported OAuth provider" in data["detail"] diff --git a/tests/services/test_oauth_service.py b/tests/services/test_oauth_service.py new file mode 100644 index 0000000..53b792e --- /dev/null +++ b/tests/services/test_oauth_service.py @@ -0,0 +1,192 @@ +"""Tests for OAuth service.""" + +from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from httpx import AsyncClient + +from app.services.oauth import ( + GitHubOAuthProvider, + GoogleOAuthProvider, + OAuthService, + OAuthUserInfo, +) + + +class TestOAuthService: + """Test OAuth service functionality.""" + + @pytest.mark.asyncio + async def test_oauth_service_initialization(self, test_session: Any) -> None: + """Test OAuth service initialization.""" + oauth_service = OAuthService(test_session) + + assert "google" in oauth_service.providers + assert "github" in oauth_service.providers + assert isinstance(oauth_service.providers["google"], GoogleOAuthProvider) + assert isinstance(oauth_service.providers["github"], GitHubOAuthProvider) + + @pytest.mark.asyncio + async def test_generate_state(self, test_session: Any) -> None: + """Test state generation.""" + oauth_service = OAuthService(test_session) + state = oauth_service.generate_state() + + assert isinstance(state, str) + assert len(state) > 10 # Should be a reasonable length + + # Generate another to ensure they're different + state2 = oauth_service.generate_state() + assert state != state2 + + @pytest.mark.asyncio + async def test_get_provider_valid(self, test_session: Any) -> None: + """Test getting valid OAuth provider.""" + oauth_service = OAuthService(test_session) + + google_provider = oauth_service.get_provider("google") + assert isinstance(google_provider, GoogleOAuthProvider) + assert google_provider.provider_name == "google" + + github_provider = oauth_service.get_provider("github") + assert isinstance(github_provider, GitHubOAuthProvider) + assert github_provider.provider_name == "github" + + @pytest.mark.asyncio + async def test_get_provider_invalid(self, test_session: Any) -> None: + """Test getting invalid OAuth provider.""" + oauth_service = OAuthService(test_session) + + with pytest.raises(Exception) as exc_info: + oauth_service.get_provider("invalid") + + assert "Unsupported OAuth provider" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_get_authorization_url(self, test_session: Any) -> None: + """Test authorization URL generation.""" + oauth_service = OAuthService(test_session) + state = "test_state_123" + + # Test Google + google_url = oauth_service.get_authorization_url("google", state) + assert "accounts.google.com" in google_url + assert "client_id=" in google_url + assert f"state={state}" in google_url + + # Test GitHub + github_url = oauth_service.get_authorization_url("github", state) + assert "github.com" in github_url + assert "client_id=" in github_url + assert f"state={state}" in github_url + + +class TestGoogleOAuthProvider: + """Test Google OAuth provider.""" + + def test_provider_properties(self) -> None: + """Test Google provider properties.""" + provider = GoogleOAuthProvider("test_client_id", "test_secret") + + assert provider.provider_name == "google" + assert "accounts.google.com" in provider.authorization_url + assert "oauth2.googleapis.com" in provider.token_url + assert "googleapis.com" in provider.user_info_url + assert "openid email profile" in provider.scope + + def test_authorization_url_generation(self) -> None: + """Test authorization URL generation.""" + provider = GoogleOAuthProvider("test_client_id", "test_secret") + state = "test_state" + + auth_url = provider.get_authorization_url(state) + + assert "accounts.google.com" in auth_url + assert "client_id=test_client_id" in auth_url + assert f"state={state}" in auth_url + assert "scope=openid+email+profile" in auth_url + + @pytest.mark.asyncio + async def test_get_user_info_success(self) -> None: + """Test successful user info retrieval.""" + provider = GoogleOAuthProvider("test_client_id", "test_secret") + + mock_response_data = { + "id": "google_user_123", + "email": "test@gmail.com", + "name": "Test User", + "picture": "https://example.com/avatar.jpg", + } + + with patch("httpx.AsyncClient.get") as mock_get: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + mock_get.return_value = mock_response + + user_info = await provider.get_user_info("test_access_token") + + assert user_info.provider == "google" + assert user_info.provider_user_id == "google_user_123" + assert user_info.email == "test@gmail.com" + assert user_info.name == "Test User" + assert user_info.picture == "https://example.com/avatar.jpg" + + +class TestGitHubOAuthProvider: + """Test GitHub OAuth provider.""" + + def test_provider_properties(self) -> None: + """Test GitHub provider properties.""" + provider = GitHubOAuthProvider("test_client_id", "test_secret") + + assert provider.provider_name == "github" + assert "github.com" in provider.authorization_url + assert "github.com" in provider.token_url + assert "api.github.com" in provider.user_info_url + assert "user:email" in provider.scope + + @pytest.mark.asyncio + async def test_get_user_info_success(self) -> None: + """Test successful user info retrieval.""" + provider = GitHubOAuthProvider("test_client_id", "test_secret") + + mock_user_data = { + "id": 123456, + "login": "testuser", + "name": "Test User", + "avatar_url": "https://github.com/avatar.jpg", + } + + mock_emails_data = [ + {"email": "test@example.com", "primary": True, "verified": True}, + {"email": "secondary@example.com", "primary": False, "verified": True}, + ] + + with patch("httpx.AsyncClient.get") as mock_get: + # Mock user profile response + mock_user_response = Mock() + mock_user_response.status_code = 200 + mock_user_response.json.return_value = mock_user_data + + # Mock emails response + mock_emails_response = Mock() + mock_emails_response.status_code = 200 + mock_emails_response.json.return_value = mock_emails_data + + # Return different responses based on URL + def side_effect(url, **kwargs): + if "user/emails" in str(url): + return mock_emails_response + return mock_user_response + + mock_get.side_effect = side_effect + + user_info = await provider.get_user_info("test_access_token") + + assert user_info.provider == "github" + assert user_info.provider_user_id == "123456" + assert user_info.email == "test@example.com" + assert user_info.name == "Test User" + assert user_info.picture == "https://github.com/avatar.jpg" \ No newline at end of file diff --git a/uv.lock b/uv.lock index f541059..ffbd421 100644 --- a/uv.lock +++ b/uv.lock @@ -46,6 +46,7 @@ dependencies = [ { name = "bcrypt" }, { name = "email-validator" }, { name = "fastapi", extra = ["standard"] }, + { name = "httpx" }, { name = "pydantic-settings" }, { name = "pyjwt" }, { name = "sqlmodel" }, @@ -69,6 +70,7 @@ requires-dist = [ { name = "bcrypt", specifier = "==4.3.0" }, { name = "email-validator", specifier = "==2.2.0" }, { name = "fastapi", extras = ["standard"], specifier = "==0.116.1" }, + { name = "httpx", specifier = "==0.28.1" }, { name = "pydantic-settings", specifier = "==2.10.1" }, { name = "pyjwt", specifier = "==2.10.1" }, { name = "sqlmodel", specifier = "==0.0.24" },