Files
sdb2-backend/app/services/oauth.py
JSC 734521c5c3
Some checks failed
Backend CI / lint (push) Failing after 5m0s
Backend CI / test (push) Successful in 3m39s
feat: Add environment configuration files and update settings for production and development
2025-08-09 14:43:20 +02:00

330 lines
10 KiB
Python

"""OAuth2 service for external authentication providers."""
import secrets
from abc import ABC, abstractmethod
from urllib.parse import urlencode
import httpx
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
# Module-level logger for this service, obtained from app.core.logging.get_logger.
logger = get_logger(__name__)
class OAuthUserInfo:
"""OAuth user information."""
def __init__(
self,
provider: str,
provider_user_id: str,
email: str,
name: str,
picture: str | None = None,
) -> None:
"""Initialize OAuth user info."""
self.provider = provider
self.provider_user_id = provider_user_id
self.email = email
self.name = name
self.picture = picture
class OAuthProvider(ABC):
    """Abstract base class for OAuth providers.

    Subclasses supply the provider-specific endpoints and scopes and implement
    :meth:`get_user_info`. The authorization-code flow shared by all providers
    (building the consent URL and exchanging the returned code for an access
    token) lives here.
    """

    def __init__(self, client_id: str, client_secret: str) -> None:
        """Initialize OAuth provider with client ID and secret."""
        self.client_id = client_id
        self.client_secret = client_secret

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name (also used in the callback URL path)."""

    @property
    @abstractmethod
    def authorization_url(self) -> str:
        """Return the authorization URL."""

    @property
    @abstractmethod
    def token_url(self) -> str:
        """Return the token URL."""

    @property
    @abstractmethod
    def user_info_url(self) -> str:
        """Return the user info URL."""

    @property
    @abstractmethod
    def scope(self) -> str:
        """Return the required scopes."""

    def _build_redirect_uri(self) -> str:
        """Return this provider's backend callback URL.

        The authorization request and the token exchange must send the exact
        same ``redirect_uri``, so both build it here.
        """
        return f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"

    def get_authorization_url(self, state: str) -> str:
        """Generate the provider consent URL carrying the given ``state``.

        Args:
            state: Opaque anti-CSRF value echoed back on the callback.

        Returns:
            The authorization URL with client ID, redirect URI, scope,
            ``response_type=code`` and ``state`` as query parameters.
        """
        params = {
            "client_id": self.client_id,
            "redirect_uri": self._build_redirect_uri(),
            "scope": self.scope,
            "response_type": "code",
            "state": state,
        }
        return f"{self.authorization_url}?{urlencode(params)}"

    async def exchange_code_for_token(self, code: str) -> str:
        """Exchange an authorization code for an access token.

        Args:
            code: Authorization code received on the callback.

        Returns:
            The access token string.

        Raises:
            HTTPException: 400 if the token endpoint returns a non-200 status,
                a non-JSON body, or a body without an ``access_token``.
        """
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            # Required by RFC 6749 section 4.1.3 for the authorization-code grant.
            "grant_type": "authorization_code",
            "redirect_uri": self._build_redirect_uri(),
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=data,
                headers={"Accept": "application/json"},
            )
        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to exchange code for token: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )
        try:
            token_data = response.json()
        except ValueError:  # body is not JSON
            token_data = None
        # Some providers (notably GitHub) reply HTTP 200 with an error payload
        # such as {"error": "bad_verification_code"} instead of a token.
        access_token = (
            token_data.get("access_token") if isinstance(token_data, dict) else None
        )
        if not access_token:
            logger.error("Token response contained no access token: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )
        return str(access_token)

    @abstractmethod
    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Get user information from the provider.

        Implementations raise ``HTTPException`` (400) when the provider's
        profile data cannot be obtained.
        """
class GoogleOAuthProvider(OAuthProvider):
    """Google OAuth provider."""

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "google"

    @property
    def authorization_url(self) -> str:
        """Return the authorization URL."""
        return "https://accounts.google.com/o/oauth2/v2/auth"

    @property
    def token_url(self) -> str:
        """Return the token URL."""
        return "https://oauth2.googleapis.com/token"

    @property
    def user_info_url(self) -> str:
        """Return the user info URL."""
        return "https://www.googleapis.com/oauth2/v2/userinfo"

    @property
    def scope(self) -> str:
        """Return the required scopes."""
        return "openid email profile"

    async def exchange_code_for_token(self, code: str) -> str:
        """Exchange an authorization code for a Google access token.

        Always sends ``grant_type=authorization_code``, which Google's token
        endpoint requires.

        Args:
            code: Authorization code received on the callback.

        Returns:
            The access token string.

        Raises:
            HTTPException: 400 if the token endpoint returns a non-200 status,
                a non-JSON body, or a body without an ``access_token``.
        """
        # Must match the redirect_uri sent in the authorization request.
        redirect_uri = (
            f"{settings.BACKEND_URL}/api/v1/auth/{self.provider_name}/callback"
        )
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.token_url,
                data=data,
                headers={"Accept": "application/json"},
            )
        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to exchange code for token: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )
        try:
            token_data = response.json()
        except ValueError:  # body is not JSON
            token_data = None
        # Guard against a 200 response that carries no token instead of
        # surfacing a KeyError as a 500.
        access_token = (
            token_data.get("access_token") if isinstance(token_data, dict) else None
        )
        if not access_token:
            logger.error("Token response contained no access token: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to exchange authorization code",
            )
        return str(access_token)

    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Fetch the Google profile belonging to ``access_token``.

        Args:
            access_token: Token obtained from :meth:`exchange_code_for_token`.

        Returns:
            The normalized profile; ``name`` falls back to the email address
            when Google supplies no name.

        Raises:
            HTTPException: 400 if the profile request fails, the profile has
                no email address, or Google marks the email as unverified.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                self.user_info_url,
                headers={"Authorization": f"Bearer {access_token}"},
            )
        if response.status_code != status.HTTP_200_OK:
            logger.error("Failed to get user info: %s", response.text)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to get user information",
            )
        user_data = response.json()
        email = user_data.get("email")
        if not email:
            logger.error("Google user info did not include an email address")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Failed to get user email",
            )
        # The v2 userinfo payload includes ``verified_email``; an address Google
        # explicitly reports as unverified must not be trusted as an identity.
        if user_data.get("verified_email") is False:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email address is not verified",
            )
        return OAuthUserInfo(
            provider=self.provider_name,
            provider_user_id=str(user_data["id"]),
            email=email,
            name=user_data.get("name") or email,
            picture=user_data.get("picture"),
        )
class GitHubOAuthProvider(OAuthProvider):
    """GitHub OAuth provider."""

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "github"

    @property
    def authorization_url(self) -> str:
        """Return the authorization URL."""
        return "https://github.com/login/oauth/authorize"

    @property
    def token_url(self) -> str:
        """Return the token URL."""
        return "https://github.com/login/oauth/access_token"

    @property
    def user_info_url(self) -> str:
        """Return the user info URL."""
        return "https://api.github.com/user"

    @property
    def scope(self) -> str:
        """Return the required scopes."""
        return "user:email"

    async def get_user_info(self, access_token: str) -> OAuthUserInfo:
        """Fetch the GitHub profile and verified primary email for a token.

        GitHub's basic ``/user`` profile does not reliably include the email,
        so the address is read from ``/user/emails`` (granted by the
        ``user:email`` scope).

        Args:
            access_token: Token obtained from :meth:`exchange_code_for_token`.

        Returns:
            The normalized profile; ``name`` falls back to the GitHub login.

        Raises:
            HTTPException: 400 if either request fails or the account has no
                primary email address that GitHub reports as verified.
        """
        auth_headers = {"Authorization": f"Bearer {access_token}"}
        async with httpx.AsyncClient() as client:
            # Get user profile
            user_response = await client.get(self.user_info_url, headers=auth_headers)
            if user_response.status_code != status.HTTP_200_OK:
                logger.error("Failed to get user info: %s", user_response.text)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to get user information",
                )
            user_data = user_response.json()

            # Get user email addresses
            email_response = await client.get(
                "https://api.github.com/user/emails",
                headers=auth_headers,
            )
            if email_response.status_code != status.HTTP_200_OK:
                logger.error("Failed to get user emails: %s", email_response.text)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Failed to get user email",
                )
            emails = email_response.json()

        # Only accept an address GitHub reports as both primary and verified:
        # an unverified primary address has not been proven to belong to this
        # user (presumably callers match local accounts by email — confirm).
        primary_email = next(
            (
                entry.get("email")
                for entry in emails
                if entry.get("primary") and entry.get("verified")
            ),
            None,
        )
        if not primary_email:
            logger.warning("GitHub account has no verified primary email")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No verified primary email found",
            )

        return OAuthUserInfo(
            provider=self.provider_name,
            provider_user_id=str(user_data["id"]),
            email=primary_email,
            name=user_data.get("name") or user_data["login"],
            picture=user_data.get("avatar_url"),
        )
class OAuthService:
    """Coordinate OAuth logins across the supported external providers."""

    def __init__(self, session: AsyncSession) -> None:
        """Keep the database session and build the provider registry.

        Args:
            session: Async database session stored on the service.
        """
        self.session = session
        google = GoogleOAuthProvider(
            settings.GOOGLE_CLIENT_ID,
            settings.GOOGLE_CLIENT_SECRET,
        )
        github = GitHubOAuthProvider(
            settings.GITHUB_CLIENT_ID,
            settings.GITHUB_CLIENT_SECRET,
        )
        self.providers = {"google": google, "github": github}

    def get_provider(self, provider_name: str) -> OAuthProvider:
        """Look up a registered provider.

        Raises:
            HTTPException: 400 when ``provider_name`` is not registered.
        """
        if provider_name in self.providers:
            return self.providers[provider_name]
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported OAuth provider: {provider_name}",
        )

    def generate_state(self) -> str:
        """Return a fresh URL-safe random token for the OAuth ``state`` value."""
        state_token = secrets.token_urlsafe(32)
        return state_token

    def get_authorization_url(self, provider_name: str, state: str) -> str:
        """Build the consent URL of ``provider_name`` carrying ``state``."""
        return self.get_provider(provider_name).get_authorization_url(state)

    async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo:
        """Trade the callback ``code`` for a token, then fetch the user's profile."""
        provider = self.get_provider(provider_name)
        token = await provider.exchange_code_for_token(code)
        return await provider.get_user_info(token)