71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Dict, Any, Optional
|
|
from authlib.integrations.flask_client import OAuth
|
|
|
|
|
|
class OAuthProvider(ABC):
    """Abstract base class for OAuth providers.

    Subclasses supply the provider identity (``name`` / ``display_name``),
    the extra authlib registration settings (``get_client_config``) and the
    provider-specific user-info extraction (``get_user_info``). This base
    class lazily registers the client with the shared authlib ``OAuth``
    registry and wraps the common authorization-code flow steps.
    """

    def __init__(self, oauth: "OAuth", client_id: str, client_secret: str):
        """Store the registry and credentials; the client is built lazily.

        Args:
            oauth: The authlib Flask ``OAuth`` registry to register with.
            client_id: OAuth client ID issued by the provider.
            client_secret: OAuth client secret issued by the provider.
        """
        self.oauth = oauth
        self.client_id = client_id
        self.client_secret = client_secret
        # Filled in by the first get_client() call and reused afterwards.
        self._client = None

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name (e.g., 'google', 'github').

        Used as the registration name passed to ``oauth.register`` and as
        the ``provider`` field of normalized user data.
        """

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Human-readable provider name (e.g., 'Google', 'GitHub')."""

    @abstractmethod
    def get_client_config(self) -> Dict[str, Any]:
        """Return extra keyword arguments for ``oauth.register``.

        Must not contain ``name``, ``client_id`` or ``client_secret``; those
        are supplied by ``get_client`` and duplicating them raises
        ``TypeError``.
        """

    @abstractmethod
    def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
        """Extract user information from an OAuth token response."""

    def get_client(self) -> Any:
        """Return the authlib client for this provider, registering it once.

        The first call registers the provider with ``self.oauth`` under
        ``self.name`` using the stored credentials plus the settings from
        ``get_client_config()``. Later calls return the cached client.
        """
        if self._client is None:
            config = self.get_client_config()
            self._client = self.oauth.register(
                name=self.name,
                client_id=self.client_id,
                client_secret=self.client_secret,
                **config,
            )
        return self._client

    def get_authorization_url(self, redirect_uri: str) -> str:
        """Return the provider authorization URL for the OAuth flow.

        Calls the client's ``authorize_redirect(redirect_uri)`` and returns
        the ``location`` of the resulting redirect response.

        NOTE(review): per authlib's Flask integration, ``authorize_redirect``
        also stores OAuth state in the session, so this presumably must run
        inside a Flask request context — confirm at call sites.
        """
        client = self.get_client()
        return client.authorize_redirect(redirect_uri).location

    def exchange_code_for_token(
        self, code: Optional[str] = None, redirect_uri: Optional[str] = None
    ) -> Dict[str, Any]:
        """Exchange the authorization code for an access token.

        Args:
            code: Accepted for interface compatibility only; not used.
            redirect_uri: Accepted for interface compatibility only; not used.

        Returns:
            Whatever the client's ``authorize_access_token()`` returns.
        """
        # The arguments are deliberately not forwarded. authlib's Flask client
        # takes the code/state from the current callback request and the
        # redirect URI from the state it saved earlier. Passing them again
        # could clash with those values.
        client = self.get_client()
        return client.authorize_access_token()

    def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize user data to a common format.

        Returns a dict with keys ``id``, ``email``, ``name``, ``picture``
        (copied from ``user_info``; ``None`` when absent) and ``provider``
        (this provider's ``name``).
        """
        return {
            "id": user_info.get("id"),
            "email": user_info.get("email"),
            "name": user_info.get("name"),
            "picture": user_info.get("picture"),
            "provider": self.name,
        }