330 lines
10 KiB
Python
330 lines
10 KiB
Python
"""OAuth2 service for external authentication providers."""
|
|
|
|
import secrets
|
|
from abc import ABC, abstractmethod
|
|
from urllib.parse import urlencode
|
|
|
|
import httpx
|
|
from fastapi import HTTPException, status
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.config import settings
|
|
from app.core.logging import get_logger
|
|
|
|
# Module logger; the provider classes use it to record failed provider responses.
logger = get_logger(__name__)
|
|
|
|
|
|
class OAuthUserInfo:
|
|
"""OAuth user information."""
|
|
|
|
def __init__(
|
|
self,
|
|
provider: str,
|
|
provider_user_id: str,
|
|
email: str,
|
|
name: str,
|
|
picture: str | None = None,
|
|
) -> None:
|
|
"""Initialize OAuth user info."""
|
|
self.provider = provider
|
|
self.provider_user_id = provider_user_id
|
|
self.email = email
|
|
self.name = name
|
|
self.picture = picture
|
|
|
|
|
|
class OAuthProvider(ABC):
    """Abstract base class for OAuth2 authorization-code providers.

    Subclasses supply the provider name, endpoint URLs and scopes through the
    abstract properties and implement :meth:`get_user_info`. This class builds
    the authorization URL and performs the code-for-token exchange.
    """

    def __init__(self, client_id: str, client_secret: str) -> None:
        """Initialize OAuth provider with client ID and secret."""
        self.client_id = client_id
        self.client_secret = client_secret

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name (also used in the callback path)."""

    @property
    @abstractmethod
    def authorization_url(self) -> str:
        """Return the authorization URL."""

    @property
    @abstractmethod
    def token_url(self) -> str:
        """Return the token URL."""

    @property
    @abstractmethod
    def user_info_url(self) -> str:
        """Return the user info URL."""

    @property
    @abstractmethod
    def scope(self) -> str:
        """Return the required scopes."""

    @property
    def redirect_uri(self) -> str:
        """Return this provider's OAuth callback URL.

        The authorization request and the token exchange must send exactly
        the same value, so both methods of this class read it from here.
        """
        # NOTE(review): host is hard-coded for local development; confirm how
        # deployed environments are expected to supply the real callback host.
        return f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"

    def get_authorization_url(self, state: str) -> str:
        """Build the URL that sends the user to the provider's consent page.

        Args:
            state: Opaque ``state`` value the provider echoes back on the
                callback.

        Returns:
            ``authorization_url`` with the client ID, redirect URI, scope,
            ``response_type=code`` and ``state`` as URL-encoded query params.
        """
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": self.scope,
            "response_type": "code",
            "state": state,
        }
        return f"{self.authorization_url}?{urlencode(params)}"

    async def exchange_code_for_token(self, code: str) -> str:
        """Exchange an authorization code for an access token.

        Args:
            code: Authorization code received on the callback endpoint.

        Returns:
            The ``access_token`` value from the provider's token response.

        Raises:
            HTTPException: 400 if the token endpoint answers with a non-200
                status, or with a body that carries no ``access_token``.
        """
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            # Required for the authorization-code grant (RFC 6749, 4.1.3).
            "grant_type": "authorization_code",
            # Must match the value sent in the authorization request.
            "redirect_uri": self.redirect_uri,
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=data,
                headers={"Accept": "application/json"},
            )

        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to exchange code for token: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )

        try:
            token_data = response.json()
        except ValueError:  # body was not JSON
            token_data = None
        access_token = (
            token_data.get("access_token") if isinstance(token_data, dict) else None
        )
        if not access_token:
            # Some providers (GitHub, for one) report a bad, expired or reused
            # code with HTTP 200 and an "error" field instead of a token.
            logger.error(
                "Token response contained no access token: %s", response.text
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )
        return str(access_token)

    @abstractmethod
    async def get_user_info(self, access_token: str) -> "OAuthUserInfo":
        """Fetch the user's profile from the provider.

        Args:
            access_token: Token returned by :meth:`exchange_code_for_token`.

        Returns:
            The provider's profile data normalized into an ``OAuthUserInfo``.
        """
|
|
|
|
|
|
class GoogleOAuthProvider(OAuthProvider):
    """Google OAuth provider (OAuth2 v2 userinfo endpoint)."""

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "google"

    @property
    def authorization_url(self) -> str:
        """Return the authorization URL."""
        return "https://accounts.google.com/o/oauth2/v2/auth"

    @property
    def token_url(self) -> str:
        """Return the token URL."""
        return "https://oauth2.googleapis.com/token"

    @property
    def user_info_url(self) -> str:
        """Return the user info URL."""
        return "https://www.googleapis.com/oauth2/v2/userinfo"

    @property
    def scope(self) -> str:
        """Return the required scopes."""
        return "openid email profile"

    async def exchange_code_for_token(self, code: str) -> str:
        """Exchange an authorization code for a Google access token.

        Args:
            code: Authorization code received on the callback endpoint.

        Returns:
            The ``access_token`` value from Google's token response.

        Raises:
            HTTPException: 400 if the token endpoint answers with a non-200
                status, or with a body that carries no ``access_token``.
        """
        # Must match the redirect_uri sent in the authorization request.
        redirect_uri = (
            f"http://localhost:8000/api/v1/auth/{self.provider_name}/callback"
        )

        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=data,
                headers={"Accept": "application/json"},
            )

        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to exchange code for token: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )

        try:
            token_data = response.json()
        except ValueError:  # body was not JSON
            token_data = None
        access_token = (
            token_data.get("access_token") if isinstance(token_data, dict) else None
        )
        if not access_token:
            # Turn a malformed success response into a 400 instead of a
            # KeyError surfacing as a 500.
            logger.error(
                "Token response contained no access token: %s", response.text
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )
        return str(access_token)

    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Fetch the Google profile of the token's owner.

        Args:
            access_token: Token returned by :meth:`exchange_code_for_token`.

        Returns:
            ``OAuthUserInfo`` with Google's user ID, email, display name
            (falling back to the email when absent) and picture URL.

        Raises:
            HTTPException: 400 if the userinfo call fails, or the profile has
                no email or an email Google reports as unverified.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                self.user_info_url,
                headers={"Authorization": f"Bearer {access_token}"},
            )

        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to get user info: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to get user information",
            )

        user_data = response.json()

        email = user_data.get("email")
        # Reject addresses Google explicitly marks unverified: if accounts are
        # matched by email elsewhere, an unverified address could be abused.
        # NOTE(review): a missing "verified_email" key is tolerated here.
        if not email or user_data.get("verified_email") is False:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No verified email found",
            )

        return OAuthUserInfo(
            provider=self.provider_name,
            provider_user_id=str(user_data["id"]),
            email=email,
            # "name" may be absent from the profile; fall back to the email.
            name=user_data.get("name") or email,
            picture=user_data.get("picture"),
        )
|
|
|
|
|
|
class GitHubOAuthProvider(OAuthProvider):
    """GitHub OAuth provider."""

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "github"

    @property
    def authorization_url(self) -> str:
        """Return the authorization URL."""
        return "https://github.com/login/oauth/authorize"

    @property
    def token_url(self) -> str:
        """Return the token URL."""
        return "https://github.com/login/oauth/access_token"

    @property
    def user_info_url(self) -> str:
        """Return the user info URL."""
        return "https://api.github.com/user"

    @property
    def scope(self) -> str:
        """Return the required scopes."""
        return "user:email"

    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Fetch the GitHub profile and verified primary email of the user.

        The basic ``/user`` profile does not include the email, so the
        address is read from ``/user/emails``.

        Args:
            access_token: Token returned by the code exchange.

        Returns:
            ``OAuthUserInfo`` with the GitHub user ID (as a string), the
            verified primary email, the display name (falling back to the
            login) and the avatar URL.

        Raises:
            HTTPException: 400 if either API call fails or the account has no
                verified primary email.
        """
        headers = {"Authorization": f"Bearer {access_token}"}

        async with httpx.AsyncClient() as client:
            # Get user profile
            user_response = await client.get(self.user_info_url, headers=headers)
            if user_response.status_code != status.HTTP_200_OK:
                logger.error("Failed to get user info: %s", user_response.text)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to get user information",
                )
            user_data = user_response.json()

            # Get user emails (not part of the basic profile)
            email_response = await client.get(
                "https://api.github.com/user/emails",
                headers=headers,
            )
            if email_response.status_code != status.HTTP_200_OK:
                logger.error("Failed to get user emails: %s", email_response.text)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to get user email",
                )
            emails = email_response.json()

        # Accept only a primary address GitHub has verified; an unverified
        # address is one the user may not actually control.
        primary_email = next(
            (
                entry.get("email")
                for entry in emails
                if entry.get("primary") and entry.get("verified")
            ),
            None,
        )

        if not primary_email:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No verified primary email found",
            )

        return OAuthUserInfo(
            provider=self.provider_name,
            provider_user_id=str(user_data["id"]),
            email=primary_email,
            # "name" can be unset on GitHub profiles; the login always exists.
            name=user_data.get("name") or user_data["login"],
            picture=user_data.get("avatar_url"),
        )
|
|
|
|
|
|
class OAuthService:
    """Resolve OAuth providers by name and run the login callback flow."""

    def __init__(self, session: AsyncSession) -> None:
        """Keep the database session and register the supported providers.

        Args:
            session: Async database session stored as ``self.session``; the
                methods of this class do not currently use it.
        """
        self.session = session
        google = GoogleOAuthProvider(
            settings.GOOGLE_CLIENT_ID,
            settings.GOOGLE_CLIENT_SECRET,
        )
        github = GitHubOAuthProvider(
            settings.GITHUB_CLIENT_ID,
            settings.GITHUB_CLIENT_SECRET,
        )
        # Keys are the provider names accepted by the public methods below.
        self.providers = {"google": google, "github": github}

    def get_provider(self, provider_name: str) -> OAuthProvider:
        """Return the registered provider called ``provider_name``.

        Raises:
            HTTPException: 400 if no provider is registered under that name.
        """
        provider = self.providers.get(provider_name)
        if provider is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported OAuth provider: {provider_name}",
            )
        return provider

    def generate_state(self) -> str:
        """Return a fresh URL-safe random ``state`` value (32 random bytes)."""
        return secrets.token_urlsafe(nbytes=32)

    def get_authorization_url(self, provider_name: str, state: str) -> str:
        """Return the consent-page URL of ``provider_name`` carrying ``state``."""
        return self.get_provider(provider_name).get_authorization_url(state)

    async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo:
        """Trade the callback ``code`` for a token and return the user's profile."""
        provider = self.get_provider(provider_name)
        token = await provider.exchange_code_for_token(code)
        user_info = await provider.get_user_info(token)
        return user_info
|