feat: Implement OAuth2 authentication with Google and GitHub

- Added OAuth2 endpoints for Google and GitHub authentication.
- Created OAuth service to handle provider interactions and user info retrieval.
- Implemented user OAuth repository for managing user OAuth links in the database.
- Updated auth service to support linking existing users and creating new users via OAuth.
- Added CORS middleware to allow frontend access.
- Created tests for OAuth endpoints and service functionality.
- Introduced environment configuration for OAuth client IDs and secrets.
- Added logging for OAuth operations and error handling.
This commit is contained in:
JSC
2025-07-26 14:38:13 +02:00
parent 52ebc59293
commit 51423779a8
14 changed files with 1119 additions and 37 deletions

View File

@@ -2,7 +2,7 @@
from fastapi import APIRouter
from app.api.v1 import auth, main
from app.api.v1 import auth, main, oauth
# V1 API router with v1 prefix
api_router = APIRouter(prefix="/v1")
@@ -10,3 +10,4 @@ api_router = APIRouter(prefix="/v1")
# Include all route modules
api_router.include_router(main.router, tags=["main"])
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
api_router.include_router(oauth.router, prefix="/oauth", tags=["oauth"])

View File

@@ -54,8 +54,6 @@ async def register(
samesite=settings.COOKIE_SAMESITE,
)
# Return only user data, tokens are now in cookies
return auth_response.user
except HTTPException:
raise
except Exception as e:
@@ -64,6 +62,8 @@ async def register(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Registration failed",
) from e
else:
return auth_response.user
@router.post("/login")
@@ -101,8 +101,6 @@ async def login(
samesite=settings.COOKIE_SAMESITE,
)
# Return only user data, tokens are now in cookies
return auth_response.user
except HTTPException:
raise
except Exception as e:
@@ -111,6 +109,8 @@ async def login(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Login failed",
) from e
else:
return auth_response.user
@router.get("/me")
@@ -156,7 +156,6 @@ async def refresh_token(
samesite=settings.COOKIE_SAMESITE,
)
return {"message": "Token refreshed successfully"}
except HTTPException:
raise
except Exception as e:
@@ -165,6 +164,8 @@ async def refresh_token(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Token refresh failed",
) from e
else:
return {"message": "Token refreshed successfully"}
@router.post("/logout")
@@ -176,7 +177,7 @@ async def logout(
) -> dict[str, str]:
"""Logout endpoint - clears cookies and revokes refresh token."""
user = None
# Try to get user from access token first
if access_token:
try:
@@ -188,7 +189,7 @@ async def logout(
logger.info("Found user from access token: %s", user.email)
except (HTTPException, Exception) as e:
logger.info("Access token validation failed: %s", str(e))
# If no user found, try refresh token
if not user and refresh_token:
try:
@@ -200,14 +201,14 @@ async def logout(
logger.info("Found user from refresh token: %s", user.email)
except (HTTPException, Exception) as e:
logger.info("Refresh token validation failed: %s", str(e))
# If we found a user, revoke their refresh token
if user:
await auth_service.revoke_refresh_token(user)
logger.info("Successfully revoked refresh token for user: %s", user.email)
else:
logger.info("No user found, skipping token revocation")
# Always clear both cookies regardless of token validity
response.delete_cookie(
key="access_token",
@@ -221,5 +222,5 @@ async def logout(
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
return {"message": "Successfully logged out"}

111
app/api/v1/oauth.py Normal file
View File

@@ -0,0 +1,111 @@
"""OAuth2 authentication endpoints."""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from fastapi.responses import RedirectResponse
from app.core.config import settings
from app.core.dependencies import get_auth_service, get_oauth_service
from app.core.logging import get_logger
from app.services.auth import AuthService
from app.services.oauth import OAuthService
router = APIRouter()
logger = get_logger(__name__)
@router.get("/{provider}/authorize")
async def oauth_authorize(
provider: str,
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
) -> dict[str, str]:
"""Get OAuth authorization URL."""
try:
# Generate secure state parameter
state = oauth_service.generate_state()
# Get authorization URL
auth_url = oauth_service.get_authorization_url(provider, state)
except HTTPException:
raise
except Exception as e:
logger.exception("OAuth authorization failed for provider: %s", provider)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth authorization failed",
) from e
else:
return {
"authorization_url": auth_url,
"state": state,
}
@router.get("/{provider}/callback")
async def oauth_callback(
provider: str,
response: Response,
code: Annotated[str, Query()],
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> RedirectResponse:
"""Handle OAuth callback."""
try:
# Handle OAuth callback and get user info
oauth_user_info = await oauth_service.handle_callback(provider, code)
# Perform OAuth login (link or create user)
auth_response = await auth_service.oauth_login(oauth_user_info)
# Create and store refresh token
user = await auth_service.get_current_user(auth_response.user.id)
refresh_token = await auth_service.create_and_store_refresh_token(user)
# Set HTTP-only cookies for both tokens
response.set_cookie(
key="access_token",
value=auth_response.token.access_token,
max_age=auth_response.token.expires_in,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
response.set_cookie(
key="refresh_token",
value=refresh_token,
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
httponly=True,
secure=settings.COOKIE_SECURE,
samesite=settings.COOKIE_SAMESITE,
)
logger.info(
"OAuth login successful for user: %s via %s",
auth_response.user.email,
provider,
)
# Redirect back to frontend after successful authentication
return RedirectResponse(
url="http://localhost:8001/?auth=success",
status_code=302,
)
except HTTPException:
raise
except Exception as e:
logger.exception("OAuth callback failed for provider: %s", provider)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth callback failed",
) from e
@router.get("/providers")
async def get_oauth_providers() -> dict[str, list[str]]:
"""Get list of available OAuth providers."""
return {
"providers": ["google", "github"],
}

View File

@@ -13,13 +13,16 @@ class Settings(BaseSettings):
extra="ignore",
)
# Application Configuration
HOST: str = "localhost"
PORT: int = 8000
RELOAD: bool = True
# Database Configuration
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
DATABASE_ECHO: bool = False
# Logging Configuration
LOG_LEVEL: str = "info"
LOG_FILE: str = "logs/app.log"
LOG_MAX_SIZE: int = 10 * 1024 * 1024
@@ -31,12 +34,19 @@ class Settings(BaseSettings):
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
)
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # Shorter-lived access token
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Longer-lived refresh token
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7
# Cookie Configuration
COOKIE_SECURE: bool = True # Set to False for development without HTTPS
COOKIE_SECURE: bool = True
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
# OAuth2 Configuration
GOOGLE_CLIENT_ID: str = ""
GOOGLE_CLIENT_SECRET: str = ""
GITHUB_CLIENT_ID: str = ""
GITHUB_CLIENT_SECRET: str = ""
OAUTH_REDIRECT_URL: str = "http://localhost:8001/auth/callback"
settings = Settings()

View File

@@ -1,6 +1,6 @@
"""FastAPI dependencies."""
from typing import Annotated, NoReturn, cast
from typing import Annotated, cast
from fastapi import Cookie, Depends, HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -9,27 +9,12 @@ from app.core.database import get_db
from app.core.logging import get_logger
from app.models.user import User
from app.services.auth import AuthService
from app.services.oauth import OAuthService
from app.utils.auth import JWTUtils
logger = get_logger(__name__)
def _raise_invalid_token_error() -> NoReturn:
"""Raise an invalid token HTTP exception."""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload",
)
def _raise_auth_error() -> NoReturn:
"""Raise an authentication HTTP exception."""
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
)
async def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
@@ -37,6 +22,13 @@ async def get_auth_service(
return AuthService(session)
async def get_oauth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> OAuthService:
"""Get the OAuth service."""
return OAuthService(session)
async def get_current_user(
access_token: Annotated[str | None, Cookie()],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
@@ -46,7 +38,10 @@ async def get_current_user(
# Check if access token cookie exists
if not access_token:
logger.warning("No access token cookie found")
_raise_auth_error()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
)
# Decode the JWT token
payload = JWTUtils.decode_access_token(access_token)
@@ -54,7 +49,10 @@ async def get_current_user(
# Extract user ID from token
user_id_str = payload.get("sub")
if not user_id_str:
_raise_invalid_token_error()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token payload",
)
# At this point user_id_str is guaranteed to be truthy, safe to cast
user_id_str = cast("str", user_id_str)
@@ -74,9 +72,12 @@ async def get_current_user(
except HTTPException:
# Re-raise HTTPExceptions without wrapping them
raise
except Exception:
except Exception as e:
logger.exception("Failed to authenticate user")
_raise_auth_error()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
) from e
async def get_current_active_user(

View File

@@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api import api_router
from app.core.database import init_db
@@ -28,6 +29,15 @@ def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
app = FastAPI(lifespan=lifespan)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:8001"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(LoggingMiddleware)
# Include API routes

View File

@@ -0,0 +1,117 @@
"""Repository for user OAuth operations."""
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.user_oauth import UserOauth
logger = get_logger(__name__)
class UserOauthRepository:
"""Repository for user OAuth operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize repository with database session."""
self.session = session
async def get_by_provider_user_id(
self,
provider: str,
provider_user_id: str,
) -> UserOauth | None:
"""Get user OAuth by provider and provider user ID."""
try:
statement = select(UserOauth).where(
UserOauth.provider == provider,
UserOauth.provider_user_id == provider_user_id,
)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception(
"Failed to get user OAuth by provider user ID: %s:%s",
provider,
provider_user_id,
)
raise
async def get_by_user_id_and_provider(
self,
user_id: int,
provider: str,
) -> UserOauth | None:
"""Get user OAuth by user ID and provider."""
try:
statement = select(UserOauth).where(
UserOauth.user_id == user_id,
UserOauth.provider == provider,
)
result = await self.session.exec(statement)
except Exception:
logger.exception(
"Failed to get user OAuth by user ID and provider: %s:%s",
user_id,
provider,
)
raise
else:
return result.first()
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
"""Create a new user OAuth record."""
try:
oauth = UserOauth(**oauth_data)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Created OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to create user OAuth")
raise
else:
return oauth
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
"""Update a user OAuth record."""
try:
for key, value in update_data.items():
setattr(oauth, key, value)
self.session.add(oauth)
await self.session.commit()
await self.session.refresh(oauth)
logger.info(
"Updated OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to update user OAuth")
raise
else:
return oauth
async def delete(self, oauth: UserOauth) -> None:
"""Delete a user OAuth record."""
try:
await self.session.delete(oauth)
await self.session.commit()
logger.info(
"Deleted OAuth link for user %s with provider %s",
oauth.user_id,
oauth.provider,
)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete user OAuth")
raise

View File

@@ -10,6 +10,7 @@ from app.core.config import settings
from app.core.logging import get_logger
from app.models.user import User
from app.repositories.user import UserRepository
from app.repositories.user_oauth import UserOauthRepository
from app.schemas.auth import (
AuthResponse,
TokenResponse,
@@ -17,6 +18,7 @@ from app.schemas.auth import (
UserRegisterRequest,
UserResponse,
)
from app.services.oauth import OAuthUserInfo
from app.utils.auth import JWTUtils, PasswordUtils
logger = get_logger(__name__)
@@ -29,6 +31,7 @@ class AuthService:
"""Initialize the auth service."""
self.session = session
self.user_repo = UserRepository(session)
self.oauth_repo = UserOauthRepository(session)
async def register(self, request: UserRegisterRequest) -> AuthResponse:
"""Register a new user."""
@@ -203,7 +206,7 @@ class AuthService:
# Check if refresh token is expired
if user.refresh_token_expires_at and datetime.now(
UTC
UTC,
) > user.refresh_token_expires_at.replace(tzinfo=UTC):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -272,3 +275,127 @@ class AuthService:
created_at=user.created_at,
updated_at=user.updated_at,
)
async def oauth_login(self, oauth_user_info: OAuthUserInfo) -> AuthResponse:
"""Handle OAuth login - link or create user."""
logger.info(
"OAuth login attempt for %s with provider %s",
oauth_user_info.email,
oauth_user_info.provider,
)
# Check if user already has OAuth link for this provider
existing_oauth = await self.oauth_repo.get_by_provider_user_id(
oauth_user_info.provider,
oauth_user_info.provider_user_id,
)
if existing_oauth:
# User exists with this OAuth provider, get the user
user = await self.user_repo.get_by_id(existing_oauth.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Linked user not found",
)
# Refresh user to avoid greenlet issues
await self.session.refresh(user)
# Update OAuth record with latest info
oauth_update_data = {
"email": oauth_user_info.email,
"name": oauth_user_info.name,
"picture": oauth_user_info.picture,
}
await self.oauth_repo.update(existing_oauth, oauth_update_data)
logger.info(
"OAuth login successful for existing user: %s",
oauth_user_info.email,
)
else:
# Check if user exists by email
user = await self.user_repo.get_by_email(oauth_user_info.email)
if user:
# Refresh user to avoid greenlet issues
await self.session.refresh(user)
# Store user picture value to avoid greenlet issues later
current_user_picture = user.picture
# Link existing user to OAuth provider
oauth_data = {
"user_id": user.id,
"provider": oauth_user_info.provider,
"provider_user_id": oauth_user_info.provider_user_id,
"email": oauth_user_info.email,
"name": oauth_user_info.name,
"picture": oauth_user_info.picture,
}
await self.oauth_repo.create(oauth_data)
# Update user profile with OAuth info if needed
user_update_data = {}
if not current_user_picture and oauth_user_info.picture:
user_update_data["picture"] = oauth_user_info.picture
if user_update_data:
await self.user_repo.update(user, user_update_data)
# Refresh user after update to avoid greenlet issues
await self.session.refresh(user)
logger.info(
"Linked existing user %s to OAuth provider %s",
oauth_user_info.email,
oauth_user_info.provider,
)
else:
# Create new user
user_data = {
"email": oauth_user_info.email,
"name": oauth_user_info.name,
"picture": oauth_user_info.picture,
"is_active": True,
# No password for OAuth users
"password_hash": None,
}
user = await self.user_repo.create(user_data)
# Create OAuth link
oauth_data = {
"user_id": user.id,
"provider": oauth_user_info.provider,
"provider_user_id": oauth_user_info.provider_user_id,
"email": oauth_user_info.email,
"name": oauth_user_info.name,
"picture": oauth_user_info.picture,
}
await self.oauth_repo.create(oauth_data)
logger.info(
"Created new user %s from OAuth provider %s",
oauth_user_info.email,
oauth_user_info.provider,
)
# Refresh user to avoid greenlet issues and check if user is active
await self.session.refresh(user)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Account is deactivated",
)
# Generate access token
token = self._create_access_token(user)
# Create response
user_response = await self.create_user_response(user)
logger.info(
"OAuth login completed successfully for user: %s",
oauth_user_info.email,
)
return AuthResponse(user=user_response, token=token)

329
app/services/oauth.py Normal file
View File

@@ -0,0 +1,329 @@
"""OAuth2 service for external authentication providers."""
import secrets
from abc import ABC, abstractmethod
from urllib.parse import urlencode
import httpx
from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
logger = get_logger(__name__)
class OAuthUserInfo:
"""OAuth user information."""
def __init__(
self,
provider: str,
provider_user_id: str,
email: str,
name: str,
picture: str | None = None,
) -> None:
"""Initialize OAuth user info."""
self.provider = provider
self.provider_user_id = provider_user_id
self.email = email
self.name = name
self.picture = picture
class OAuthProvider(ABC):
"""Abstract base class for OAuth providers."""
def __init__(self, client_id: str, client_secret: str) -> None:
"""Initialize OAuth provider with client ID and secret."""
self.client_id = client_id
self.client_secret = client_secret
@property
@abstractmethod
def provider_name(self) -> str:
"""Return the provider name."""
@property
@abstractmethod
def authorization_url(self) -> str:
"""Return the authorization URL."""
@property
@abstractmethod
def token_url(self) -> str:
"""Return the token URL."""
@property
@abstractmethod
def user_info_url(self) -> str:
"""Return the user info URL."""
@property
@abstractmethod
def scope(self) -> str:
"""Return the required scopes."""
def get_authorization_url(self, state: str) -> str:
"""Generate authorization URL with state parameter."""
# Construct provider-specific redirect URI
redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
)
params = {
"client_id": self.client_id,
"redirect_uri": redirect_uri,
"scope": self.scope,
"response_type": "code",
"state": state,
}
return f"{self.authorization_url}?{urlencode(params)}"
async def exchange_code_for_token(self, code: str) -> str:
"""Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request)
redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
)
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"redirect_uri": redirect_uri,
}
async with httpx.AsyncClient() as client:
response = await client.post(
self.token_url,
data=data,
headers={"Accept": "application/json"},
)
if response.status_code != status.HTTP_200_OK:
logger.error("Failed to exchange code for token: %s", response.text)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to exchange authorization code",
)
token_data = response.json()
return str(token_data["access_token"])
@abstractmethod
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
"""Get user information from the provider."""
class GoogleOAuthProvider(OAuthProvider):
"""Google OAuth provider."""
@property
def provider_name(self) -> str:
"""Return the provider name."""
return "google"
@property
def authorization_url(self) -> str:
"""Return the authorization URL."""
return "https://accounts.google.com/o/oauth2/v2/auth"
@property
def token_url(self) -> str:
"""Return the token URL."""
return "https://oauth2.googleapis.com/token"
@property
def user_info_url(self) -> str:
"""Return the user info URL."""
return "https://www.googleapis.com/oauth2/v2/userinfo"
@property
def scope(self) -> str:
"""Return the required scopes."""
return "openid email profile"
async def exchange_code_for_token(self, code: str) -> str:
"""Exchange authorization code for access token."""
# Construct provider-specific redirect URI (must match authorization request)
redirect_uri = (
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
)
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
}
async with httpx.AsyncClient() as client:
response = await client.post(
self.token_url,
data=data,
headers={"Accept": "application/json"},
)
if response.status_code != status.HTTP_200_OK:
logger.error("Failed to exchange code for token: %s", response.text)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to exchange authorization code",
)
token_data = response.json()
return str(token_data["access_token"])
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
"""Get user information from Google."""
async with httpx.AsyncClient() as client:
response = await client.get(
self.user_info_url,
headers={"Authorization": f"Bearer {access_token}"},
)
if response.status_code != status.HTTP_200_OK:
logger.error("Failed to get user info: %s", response.text)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to get user information",
)
user_data = response.json()
return OAuthUserInfo(
provider=self.provider_name,
provider_user_id=user_data["id"],
email=user_data["email"],
name=user_data["name"],
picture=user_data.get("picture"),
)
class GitHubOAuthProvider(OAuthProvider):
"""GitHub OAuth provider."""
@property
def provider_name(self) -> str:
"""Return the provider name."""
return "github"
@property
def authorization_url(self) -> str:
"""Return the authorization URL."""
return "https://github.com/login/oauth/authorize"
@property
def token_url(self) -> str:
"""Return the token URL."""
return "https://github.com/login/oauth/access_token"
@property
def user_info_url(self) -> str:
"""Return the user info URL."""
return "https://api.github.com/user"
@property
def scope(self) -> str:
"""Return the required scopes."""
return "user:email"
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
"""Get user information from GitHub."""
async with httpx.AsyncClient() as client:
# Get user profile
user_response = await client.get(
self.user_info_url,
headers={"Authorization": f"Bearer {access_token}"},
)
if user_response.status_code != status.HTTP_200_OK:
logger.error("Failed to get user info: %s", user_response.text)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to get user information",
)
user_data = user_response.json()
# Get user email (GitHub doesn't include email in basic profile)
email_response = await client.get(
"https://api.github.com/user/emails",
headers={"Authorization": f"Bearer {access_token}"},
)
if email_response.status_code != status.HTTP_200_OK:
logger.error("Failed to get user emails: %s", email_response.text)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Failed to get user email",
)
emails = email_response.json()
primary_email = next(
(email["email"] for email in emails if email["primary"]),
None,
)
if not primary_email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No primary email found",
)
return OAuthUserInfo(
provider=self.provider_name,
provider_user_id=str(user_data["id"]),
email=primary_email,
name=user_data.get("name") or user_data["login"],
picture=user_data.get("avatar_url"),
)
class OAuthService:
"""Service for handling OAuth authentication."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize OAuth service with database session."""
self.session = session
self.providers = {
"google": GoogleOAuthProvider(
settings.GOOGLE_CLIENT_ID,
settings.GOOGLE_CLIENT_SECRET,
),
"github": GitHubOAuthProvider(
settings.GITHUB_CLIENT_ID,
settings.GITHUB_CLIENT_SECRET,
),
}
def get_provider(self, provider_name: str) -> OAuthProvider:
"""Get OAuth provider by name."""
if provider_name not in self.providers:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider_name}",
)
return self.providers[provider_name]
def generate_state(self) -> str:
"""Generate a secure state parameter for OAuth flow."""
return secrets.token_urlsafe(32)
def get_authorization_url(self, provider_name: str, state: str) -> str:
"""Get authorization URL for the specified provider."""
provider = self.get_provider(provider_name)
return provider.get_authorization_url(state)
async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo:
"""Handle OAuth callback and return user info."""
provider = self.get_provider(provider_name)
# Exchange code for access token
access_token = await provider.exchange_code_for_token(code)
# Get user information
return await provider.get_user_info(access_token)