"""OAuth2 service for external authentication providers."""

import secrets
from abc import ABC, abstractmethod
from urllib.parse import urlencode

import httpx
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)

# Origin used to build provider callback URLs when none is configured. This is
# the value the module previously hard-coded, so local development is unchanged.
DEFAULT_REDIRECT_BASE_URL = "http://localhost:8000"

# Callback path appended to the redirect base URL; ``{provider}`` is replaced
# with the provider name (e.g. "google", "github").
_CALLBACK_PATH = "/api/v1/oauth/{provider}/callback"


class OAuthUserInfo:
    """OAuth user information, normalized across providers."""

    def __init__(
        self,
        provider: str,
        provider_user_id: str,
        email: str,
        name: str,
        picture: str | None = None,
    ) -> None:
        """Initialize OAuth user info.

        Args:
            provider: Provider name that produced this profile ("google", "github").
            provider_user_id: The user's ID at the provider, as a string.
            email: Email address reported by the provider.
            name: Display name.
            picture: Optional avatar URL.
        """
        self.provider = provider
        self.provider_user_id = provider_user_id
        self.email = email
        self.name = name
        self.picture = picture


class OAuthProvider(ABC):
    """Abstract base class for OAuth 2.0 authorization-code providers."""

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_base_url: str = DEFAULT_REDIRECT_BASE_URL,
    ) -> None:
        """Initialize OAuth provider with client credentials.

        Args:
            client_id: OAuth client ID registered with the provider.
            client_secret: OAuth client secret registered with the provider.
            redirect_base_url: Origin (scheme + host [+ port]) that the
                provider redirects back to. Defaults to
                ``DEFAULT_REDIRECT_BASE_URL``.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        # Strip a trailing slash so joining with _CALLBACK_PATH never yields "//".
        self.redirect_base_url = str(redirect_base_url).rstrip("/")

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name."""

    @property
    @abstractmethod
    def authorization_url(self) -> str:
        """Return the authorization URL."""

    @property
    @abstractmethod
    def token_url(self) -> str:
        """Return the token URL."""

    @property
    @abstractmethod
    def user_info_url(self) -> str:
        """Return the user info URL."""

    @property
    @abstractmethod
    def scope(self) -> str:
        """Return the required scopes."""

    @property
    def redirect_uri(self) -> str:
        """Return this provider's callback URL.

        The same value is sent in the authorization request and in the token
        exchange. Providers reject the exchange if the two differ.
        """
        return self.redirect_base_url + _CALLBACK_PATH.format(
            provider=self.provider_name,
        )

    def get_authorization_url(self, state: str) -> str:
        """Generate the authorization URL, carrying ``state`` for CSRF protection."""
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": self.scope,
            "response_type": "code",
            "state": state,
        }
        return f"{self.authorization_url}?{urlencode(params)}"

    async def exchange_code_for_token(self, code: str) -> str:
        """Exchange an authorization code for an access token.

        Args:
            code: Authorization code received on the callback.

        Returns:
            The access token as a string.

        Raises:
            HTTPException: 400 if the token endpoint answers with a non-200
                status, or with a body that carries no ``access_token``.
        """
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            # Required by RFC 6749 §4.1.3.
            "grant_type": "authorization_code",
            "redirect_uri": self.redirect_uri,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=data,
                headers={"Accept": "application/json"},
            )

        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to exchange code for token: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )

        token_data = response.json()
        access_token = token_data.get("access_token")
        if not access_token:
            # Some providers (GitHub) report errors such as
            # "bad_verification_code" in a 200 response instead of a 4xx.
            logger.error(
                "Token response contained no access token: error=%s description=%s",
                token_data.get("error"),
                token_data.get("error_description"),
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )
        return str(access_token)

    @abstractmethod
    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Get user information from the provider."""


class GoogleOAuthProvider(OAuthProvider):
    """Google OAuth provider."""

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "google"

    @property
    def authorization_url(self) -> str:
        """Return the authorization URL."""
        return "https://accounts.google.com/o/oauth2/v2/auth"

    @property
    def token_url(self) -> str:
        """Return the token URL."""
        return "https://oauth2.googleapis.com/token"

    @property
    def user_info_url(self) -> str:
        """Return the user info URL."""
        return "https://www.googleapis.com/oauth2/v2/userinfo"

    @property
    def scope(self) -> str:
        """Return the required scopes."""
        return "openid email profile"

    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Get user information from Google.

        Raises:
            HTTPException: 400 if the userinfo request fails or Google reports
                the email as unverified.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                self.user_info_url,
                headers={"Authorization": f"Bearer {access_token}"},
            )

        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to get user info: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to get user information",
            )

        user_data = response.json()
        # Reject only an explicit False, so a response without the field
        # behaves as before. An unverified email must not be trusted for
        # identity or account linking.
        if user_data.get("verified_email") is False:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email address is not verified",
            )

        email = user_data["email"]
        return OAuthUserInfo(
            provider=self.provider_name,
            provider_user_id=str(user_data["id"]),
            email=email,
            # Some accounts have no display name; fall back to the email.
            name=user_data.get("name") or email,
            picture=user_data.get("picture"),
        )


class GitHubOAuthProvider(OAuthProvider):
    """GitHub OAuth provider."""

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "github"

    @property
    def authorization_url(self) -> str:
        """Return the authorization URL."""
        return "https://github.com/login/oauth/authorize"

    @property
    def token_url(self) -> str:
        """Return the token URL."""
        return "https://github.com/login/oauth/access_token"

    @property
    def user_info_url(self) -> str:
        """Return the user info URL."""
        return "https://api.github.com/user"

    @property
    def scope(self) -> str:
        """Return the required scopes."""
        return "user:email"

    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Get user information from GitHub.

        Makes two requests: the user profile, and the email list (the profile
        does not reliably include an email).

        Raises:
            HTTPException: 400 if either request fails, or if the account has
                no email that is both primary and verified.
        """
        auth_headers = {"Authorization": f"Bearer {access_token}"}
        async with httpx.AsyncClient() as client:
            user_response = await client.get(self.user_info_url, headers=auth_headers)
            if user_response.status_code != status.HTTP_200_OK:
                logger.error("Failed to get user info: %s", user_response.text)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to get user information",
                )
            user_data = user_response.json()

            email_response = await client.get(
                "https://api.github.com/user/emails",
                headers=auth_headers,
            )
            if email_response.status_code != status.HTTP_200_OK:
                logger.error("Failed to get user emails: %s", email_response.text)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to get user email",
                )
            emails = email_response.json()

        # Require a verified address: GitHub lets users add unverified
        # emails, and trusting one enables account takeover via email linking.
        primary_email = next(
            (
                entry["email"]
                for entry in emails
                if entry.get("primary") and entry.get("verified")
            ),
            None,
        )
        if not primary_email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No verified primary email found",
            )

        return OAuthUserInfo(
            provider=self.provider_name,
            provider_user_id=str(user_data["id"]),
            email=primary_email,
            name=user_data.get("name") or user_data["login"],
            picture=user_data.get("avatar_url"),
        )


class OAuthService:
    """Service for handling OAuth authentication."""

    def __init__(
        self,
        session: AsyncSession,
        redirect_base_url: str | None = None,
    ) -> None:
        """Initialize OAuth service with database session.

        Args:
            session: Database session; stored on the instance (the methods
                shown here do not use it).
            redirect_base_url: Origin used to build callback URLs. If omitted,
                ``settings.OAUTH_REDIRECT_BASE_URL`` is used when present, then
                ``DEFAULT_REDIRECT_BASE_URL``.
        """
        self.session = session
        # NOTE(review): OAUTH_REDIRECT_BASE_URL is an optional setting; getattr
        # keeps this working with Settings classes that do not define it yet.
        base_url = (
            redirect_base_url
            or getattr(settings, "OAUTH_REDIRECT_BASE_URL", None)
            or DEFAULT_REDIRECT_BASE_URL
        )
        self.providers: dict[str, OAuthProvider] = {
            "google": GoogleOAuthProvider(
                settings.GOOGLE_CLIENT_ID,
                settings.GOOGLE_CLIENT_SECRET,
                redirect_base_url=base_url,
            ),
            "github": GitHubOAuthProvider(
                settings.GITHUB_CLIENT_ID,
                settings.GITHUB_CLIENT_SECRET,
                redirect_base_url=base_url,
            ),
        }

    def get_provider(self, provider_name: str) -> OAuthProvider:
        """Get OAuth provider by name.

        Raises:
            HTTPException: 400 if ``provider_name`` is not a configured provider.
        """
        provider = self.providers.get(provider_name)
        if provider is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported OAuth provider: {provider_name}",
            )
        return provider

    def generate_state(self) -> str:
        """Generate a secure, URL-safe state parameter for the OAuth flow."""
        return secrets.token_urlsafe(32)

    def get_authorization_url(self, provider_name: str, state: str) -> str:
        """Get authorization URL for the specified provider."""
        provider = self.get_provider(provider_name)
        return provider.get_authorization_url(state)

    async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo:
        """Handle OAuth callback: exchange ``code`` for a token, then fetch the user."""
        provider = self.get_provider(provider_name)
        access_token = await provider.exchange_code_for_token(code)
        return await provider.get_user_info(access_token)