feat: Implement OAuth2 authentication with Google and GitHub
- Added OAuth2 endpoints for Google and GitHub authentication. - Created OAuth service to handle provider interactions and user info retrieval. - Implemented user OAuth repository for managing user OAuth links in the database. - Updated auth service to support linking existing users and creating new users via OAuth. - Added CORS middleware to allow frontend access. - Created tests for OAuth endpoints and service functionality. - Introduced environment configuration for OAuth client IDs and secrets. - Added logging for OAuth operations and error handling.
This commit is contained in:
@@ -10,6 +10,7 @@ from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.repositories.user_oauth import UserOauthRepository
|
||||
from app.schemas.auth import (
|
||||
AuthResponse,
|
||||
TokenResponse,
|
||||
@@ -17,6 +18,7 @@ from app.schemas.auth import (
|
||||
UserRegisterRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from app.services.oauth import OAuthUserInfo
|
||||
from app.utils.auth import JWTUtils, PasswordUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -29,6 +31,7 @@ class AuthService:
|
||||
"""Initialize the auth service."""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
self.oauth_repo = UserOauthRepository(session)
|
||||
|
||||
async def register(self, request: UserRegisterRequest) -> AuthResponse:
|
||||
"""Register a new user."""
|
||||
@@ -203,7 +206,7 @@ class AuthService:
|
||||
|
||||
# Check if refresh token is expired
|
||||
if user.refresh_token_expires_at and datetime.now(
|
||||
UTC
|
||||
UTC,
|
||||
) > user.refresh_token_expires_at.replace(tzinfo=UTC):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -272,3 +275,127 @@ class AuthService:
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
async def oauth_login(self, oauth_user_info: OAuthUserInfo) -> AuthResponse:
|
||||
"""Handle OAuth login - link or create user."""
|
||||
logger.info(
|
||||
"OAuth login attempt for %s with provider %s",
|
||||
oauth_user_info.email,
|
||||
oauth_user_info.provider,
|
||||
)
|
||||
|
||||
# Check if user already has OAuth link for this provider
|
||||
existing_oauth = await self.oauth_repo.get_by_provider_user_id(
|
||||
oauth_user_info.provider,
|
||||
oauth_user_info.provider_user_id,
|
||||
)
|
||||
|
||||
if existing_oauth:
|
||||
# User exists with this OAuth provider, get the user
|
||||
user = await self.user_repo.get_by_id(existing_oauth.user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Linked user not found",
|
||||
)
|
||||
|
||||
# Refresh user to avoid greenlet issues
|
||||
await self.session.refresh(user)
|
||||
|
||||
# Update OAuth record with latest info
|
||||
oauth_update_data = {
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
}
|
||||
await self.oauth_repo.update(existing_oauth, oauth_update_data)
|
||||
|
||||
logger.info(
|
||||
"OAuth login successful for existing user: %s",
|
||||
oauth_user_info.email,
|
||||
)
|
||||
else:
|
||||
# Check if user exists by email
|
||||
user = await self.user_repo.get_by_email(oauth_user_info.email)
|
||||
|
||||
if user:
|
||||
# Refresh user to avoid greenlet issues
|
||||
await self.session.refresh(user)
|
||||
|
||||
# Store user picture value to avoid greenlet issues later
|
||||
current_user_picture = user.picture
|
||||
|
||||
# Link existing user to OAuth provider
|
||||
oauth_data = {
|
||||
"user_id": user.id,
|
||||
"provider": oauth_user_info.provider,
|
||||
"provider_user_id": oauth_user_info.provider_user_id,
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
}
|
||||
await self.oauth_repo.create(oauth_data)
|
||||
|
||||
# Update user profile with OAuth info if needed
|
||||
user_update_data = {}
|
||||
if not current_user_picture and oauth_user_info.picture:
|
||||
user_update_data["picture"] = oauth_user_info.picture
|
||||
if user_update_data:
|
||||
await self.user_repo.update(user, user_update_data)
|
||||
# Refresh user after update to avoid greenlet issues
|
||||
await self.session.refresh(user)
|
||||
|
||||
logger.info(
|
||||
"Linked existing user %s to OAuth provider %s",
|
||||
oauth_user_info.email,
|
||||
oauth_user_info.provider,
|
||||
)
|
||||
else:
|
||||
# Create new user
|
||||
user_data = {
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
"is_active": True,
|
||||
# No password for OAuth users
|
||||
"password_hash": None,
|
||||
}
|
||||
|
||||
user = await self.user_repo.create(user_data)
|
||||
|
||||
# Create OAuth link
|
||||
oauth_data = {
|
||||
"user_id": user.id,
|
||||
"provider": oauth_user_info.provider,
|
||||
"provider_user_id": oauth_user_info.provider_user_id,
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
}
|
||||
await self.oauth_repo.create(oauth_data)
|
||||
|
||||
logger.info(
|
||||
"Created new user %s from OAuth provider %s",
|
||||
oauth_user_info.email,
|
||||
oauth_user_info.provider,
|
||||
)
|
||||
|
||||
# Refresh user to avoid greenlet issues and check if user is active
|
||||
await self.session.refresh(user)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
# Generate access token
|
||||
token = self._create_access_token(user)
|
||||
|
||||
# Create response
|
||||
user_response = await self.create_user_response(user)
|
||||
|
||||
logger.info(
|
||||
"OAuth login completed successfully for user: %s",
|
||||
oauth_user_info.email,
|
||||
)
|
||||
return AuthResponse(user=user_response, token=token)
|
||||
|
||||
329
app/services/oauth.py
Normal file
329
app/services/oauth.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""OAuth2 service for external authentication providers."""
|
||||
|
||||
import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthUserInfo:
|
||||
"""OAuth user information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize OAuth user info."""
|
||||
self.provider = provider
|
||||
self.provider_user_id = provider_user_id
|
||||
self.email = email
|
||||
self.name = name
|
||||
self.picture = picture
|
||||
|
||||
|
||||
class OAuthProvider(ABC):
|
||||
"""Abstract base class for OAuth providers."""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||
"""Initialize OAuth provider with client ID and secret."""
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def provider_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def authorization_url(self) -> str:
|
||||
"""Return the authorization URL."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def token_url(self) -> str:
|
||||
"""Return the token URL."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def user_info_url(self) -> str:
|
||||
"""Return the user info URL."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def scope(self) -> str:
|
||||
"""Return the required scopes."""
|
||||
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
"""Generate authorization URL with state parameter."""
|
||||
# Construct provider-specific redirect URI
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": self.scope,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
}
|
||||
return f"{self.authorization_url}?{urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_token(self, code: str) -> str:
|
||||
"""Exchange authorization code for access token."""
|
||||
# Construct provider-specific redirect URI (must match authorization request)
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to exchange code for token: %s", response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to exchange authorization code",
|
||||
)
|
||||
|
||||
token_data = response.json()
|
||||
return str(token_data["access_token"])
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Get user information from the provider."""
|
||||
|
||||
|
||||
class GoogleOAuthProvider(OAuthProvider):
|
||||
"""Google OAuth provider."""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "google"
|
||||
|
||||
@property
|
||||
def authorization_url(self) -> str:
|
||||
"""Return the authorization URL."""
|
||||
return "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
"""Return the token URL."""
|
||||
return "https://oauth2.googleapis.com/token"
|
||||
|
||||
@property
|
||||
def user_info_url(self) -> str:
|
||||
"""Return the user info URL."""
|
||||
return "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
@property
|
||||
def scope(self) -> str:
|
||||
"""Return the required scopes."""
|
||||
return "openid email profile"
|
||||
|
||||
async def exchange_code_for_token(self, code: str) -> str:
|
||||
"""Exchange authorization code for access token."""
|
||||
# Construct provider-specific redirect URI (must match authorization request)
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to exchange code for token: %s", response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to exchange authorization code",
|
||||
)
|
||||
|
||||
token_data = response.json()
|
||||
return str(token_data["access_token"])
|
||||
|
||||
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Get user information from Google."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.user_info_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to get user info: %s", response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to get user information",
|
||||
)
|
||||
|
||||
user_data = response.json()
|
||||
return OAuthUserInfo(
|
||||
provider=self.provider_name,
|
||||
provider_user_id=user_data["id"],
|
||||
email=user_data["email"],
|
||||
name=user_data["name"],
|
||||
picture=user_data.get("picture"),
|
||||
)
|
||||
|
||||
|
||||
class GitHubOAuthProvider(OAuthProvider):
|
||||
"""GitHub OAuth provider."""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "github"
|
||||
|
||||
@property
|
||||
def authorization_url(self) -> str:
|
||||
"""Return the authorization URL."""
|
||||
return "https://github.com/login/oauth/authorize"
|
||||
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
"""Return the token URL."""
|
||||
return "https://github.com/login/oauth/access_token"
|
||||
|
||||
@property
|
||||
def user_info_url(self) -> str:
|
||||
"""Return the user info URL."""
|
||||
return "https://api.github.com/user"
|
||||
|
||||
@property
|
||||
def scope(self) -> str:
|
||||
"""Return the required scopes."""
|
||||
return "user:email"
|
||||
|
||||
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Get user information from GitHub."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Get user profile
|
||||
user_response = await client.get(
|
||||
self.user_info_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if user_response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to get user info: %s", user_response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to get user information",
|
||||
)
|
||||
|
||||
user_data = user_response.json()
|
||||
|
||||
# Get user email (GitHub doesn't include email in basic profile)
|
||||
email_response = await client.get(
|
||||
"https://api.github.com/user/emails",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if email_response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to get user emails: %s", email_response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to get user email",
|
||||
)
|
||||
|
||||
emails = email_response.json()
|
||||
primary_email = next(
|
||||
(email["email"] for email in emails if email["primary"]),
|
||||
None,
|
||||
)
|
||||
|
||||
if not primary_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No primary email found",
|
||||
)
|
||||
|
||||
return OAuthUserInfo(
|
||||
provider=self.provider_name,
|
||||
provider_user_id=str(user_data["id"]),
|
||||
email=primary_email,
|
||||
name=user_data.get("name") or user_data["login"],
|
||||
picture=user_data.get("avatar_url"),
|
||||
)
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""Service for handling OAuth authentication."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize OAuth service with database session."""
|
||||
self.session = session
|
||||
self.providers = {
|
||||
"google": GoogleOAuthProvider(
|
||||
settings.GOOGLE_CLIENT_ID,
|
||||
settings.GOOGLE_CLIENT_SECRET,
|
||||
),
|
||||
"github": GitHubOAuthProvider(
|
||||
settings.GITHUB_CLIENT_ID,
|
||||
settings.GITHUB_CLIENT_SECRET,
|
||||
),
|
||||
}
|
||||
|
||||
def get_provider(self, provider_name: str) -> OAuthProvider:
|
||||
"""Get OAuth provider by name."""
|
||||
if provider_name not in self.providers:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported OAuth provider: {provider_name}",
|
||||
)
|
||||
return self.providers[provider_name]
|
||||
|
||||
def generate_state(self) -> str:
|
||||
"""Generate a secure state parameter for OAuth flow."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def get_authorization_url(self, provider_name: str, state: str) -> str:
|
||||
"""Get authorization URL for the specified provider."""
|
||||
provider = self.get_provider(provider_name)
|
||||
return provider.get_authorization_url(state)
|
||||
|
||||
async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo:
|
||||
"""Handle OAuth callback and return user info."""
|
||||
provider = self.get_provider(provider_name)
|
||||
|
||||
# Exchange code for access token
|
||||
access_token = await provider.exchange_code_for_token(code)
|
||||
|
||||
# Get user information
|
||||
return await provider.get_user_info(access_token)
|
||||
Reference in New Issue
Block a user