feat: Implement OAuth2 authentication with Google and GitHub
- Added OAuth2 endpoints for Google and GitHub authentication. - Created OAuth service to handle provider interactions and user info retrieval. - Implemented user OAuth repository for managing user OAuth links in the database. - Updated auth service to support linking existing users and creating new users via OAuth. - Added CORS middleware to allow frontend access. - Created tests for OAuth endpoints and service functionality. - Introduced environment configuration for OAuth client IDs and secrets. - Added logging for OAuth operations and error handling.
This commit is contained in:
29
.env.template
Normal file
29
.env.template
Normal file
@@ -0,0 +1,29 @@
|
||||
# Application Configuration
|
||||
HOST=localhost
|
||||
PORT=8000
|
||||
RELOAD=true
|
||||
|
||||
# Database Configuration
|
||||
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
|
||||
DATABASE_ECHO=false
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=info
|
||||
LOG_FILE=logs/app.log
|
||||
LOG_MAX_SIZE=10485760
|
||||
LOG_BACKUP_COUNT=5
|
||||
|
||||
# JWT Configuration
|
||||
JWT_SECRET_KEY=your-secret-key-change-in-production
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# Cookie Configuration
|
||||
COOKIE_SECURE=false
|
||||
|
||||
# OAuth2 Configuration
|
||||
GOOGLE_CLIENT_ID=
|
||||
GOOGLE_CLIENT_SECRET=
|
||||
GITHUB_CLIENT_ID=
|
||||
GITHUB_CLIENT_SECRET=
|
||||
OAUTH_REDIRECT_URL=http://localhost:8001/auth/callback
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1 import auth, main
|
||||
from app.api.v1 import auth, main, oauth
|
||||
|
||||
# V1 API router with v1 prefix
|
||||
api_router = APIRouter(prefix="/v1")
|
||||
@@ -10,3 +10,4 @@ api_router = APIRouter(prefix="/v1")
|
||||
# Include all route modules
|
||||
api_router.include_router(main.router, tags=["main"])
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
||||
api_router.include_router(oauth.router, prefix="/oauth", tags=["oauth"])
|
||||
|
||||
@@ -54,8 +54,6 @@ async def register(
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
# Return only user data, tokens are now in cookies
|
||||
return auth_response.user
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -64,6 +62,8 @@ async def register(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Registration failed",
|
||||
) from e
|
||||
else:
|
||||
return auth_response.user
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
@@ -101,8 +101,6 @@ async def login(
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
# Return only user data, tokens are now in cookies
|
||||
return auth_response.user
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -111,6 +109,8 @@ async def login(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Login failed",
|
||||
) from e
|
||||
else:
|
||||
return auth_response.user
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
@@ -156,7 +156,6 @@ async def refresh_token(
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
return {"message": "Token refreshed successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -165,6 +164,8 @@ async def refresh_token(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Token refresh failed",
|
||||
) from e
|
||||
else:
|
||||
return {"message": "Token refreshed successfully"}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
|
||||
111
app/api/v1/oauth.py
Normal file
111
app/api/v1/oauth.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""OAuth2 authentication endpoints."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.dependencies import get_auth_service, get_oauth_service
|
||||
from app.core.logging import get_logger
|
||||
from app.services.auth import AuthService
|
||||
from app.services.oauth import OAuthService
|
||||
|
||||
router = APIRouter()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get("/{provider}/authorize")
|
||||
async def oauth_authorize(
|
||||
provider: str,
|
||||
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
||||
) -> dict[str, str]:
|
||||
"""Get OAuth authorization URL."""
|
||||
try:
|
||||
# Generate secure state parameter
|
||||
state = oauth_service.generate_state()
|
||||
|
||||
# Get authorization URL
|
||||
auth_url = oauth_service.get_authorization_url(provider, state)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("OAuth authorization failed for provider: %s", provider)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth authorization failed",
|
||||
) from e
|
||||
else:
|
||||
return {
|
||||
"authorization_url": auth_url,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{provider}/callback")
|
||||
async def oauth_callback(
|
||||
provider: str,
|
||||
response: Response,
|
||||
code: Annotated[str, Query()],
|
||||
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> RedirectResponse:
|
||||
"""Handle OAuth callback."""
|
||||
try:
|
||||
# Handle OAuth callback and get user info
|
||||
oauth_user_info = await oauth_service.handle_callback(provider, code)
|
||||
|
||||
# Perform OAuth login (link or create user)
|
||||
auth_response = await auth_service.oauth_login(oauth_user_info)
|
||||
|
||||
# Create and store refresh token
|
||||
user = await auth_service.get_current_user(auth_response.user.id)
|
||||
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||
|
||||
# Set HTTP-only cookies for both tokens
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=auth_response.token.access_token,
|
||||
max_age=auth_response.token.expires_in,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=refresh_token,
|
||||
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
||||
httponly=True,
|
||||
secure=settings.COOKIE_SECURE,
|
||||
samesite=settings.COOKIE_SAMESITE,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"OAuth login successful for user: %s via %s",
|
||||
auth_response.user.email,
|
||||
provider,
|
||||
)
|
||||
|
||||
# Redirect back to frontend after successful authentication
|
||||
return RedirectResponse(
|
||||
url="http://localhost:8001/?auth=success",
|
||||
status_code=302,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("OAuth callback failed for provider: %s", provider)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="OAuth callback failed",
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/providers")
|
||||
async def get_oauth_providers() -> dict[str, list[str]]:
|
||||
"""Get list of available OAuth providers."""
|
||||
return {
|
||||
"providers": ["google", "github"],
|
||||
}
|
||||
@@ -13,13 +13,16 @@ class Settings(BaseSettings):
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Application Configuration
|
||||
HOST: str = "localhost"
|
||||
PORT: int = 8000
|
||||
RELOAD: bool = True
|
||||
|
||||
# Database Configuration
|
||||
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
|
||||
DATABASE_ECHO: bool = False
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL: str = "info"
|
||||
LOG_FILE: str = "logs/app.log"
|
||||
LOG_MAX_SIZE: int = 10 * 1024 * 1024
|
||||
@@ -31,12 +34,19 @@ class Settings(BaseSettings):
|
||||
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
|
||||
)
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # Shorter-lived access token
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Longer-lived refresh token
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||
|
||||
# Cookie Configuration
|
||||
COOKIE_SECURE: bool = True # Set to False for development without HTTPS
|
||||
COOKIE_SECURE: bool = True
|
||||
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
|
||||
|
||||
# OAuth2 Configuration
|
||||
GOOGLE_CLIENT_ID: str = ""
|
||||
GOOGLE_CLIENT_SECRET: str = ""
|
||||
GITHUB_CLIENT_ID: str = ""
|
||||
GITHUB_CLIENT_SECRET: str = ""
|
||||
OAUTH_REDIRECT_URL: str = "http://localhost:8001/auth/callback"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""FastAPI dependencies."""
|
||||
|
||||
from typing import Annotated, NoReturn, cast
|
||||
from typing import Annotated, cast
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -9,27 +9,12 @@ from app.core.database import get_db
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.services.auth import AuthService
|
||||
from app.services.oauth import OAuthService
|
||||
from app.utils.auth import JWTUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _raise_invalid_token_error() -> NoReturn:
|
||||
"""Raise an invalid token HTTP exception."""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
)
|
||||
|
||||
|
||||
def _raise_auth_error() -> NoReturn:
|
||||
"""Raise an authentication HTTP exception."""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
|
||||
|
||||
async def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
@@ -37,6 +22,13 @@ async def get_auth_service(
|
||||
return AuthService(session)
|
||||
|
||||
|
||||
async def get_oauth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> OAuthService:
|
||||
"""Get the OAuth service."""
|
||||
return OAuthService(session)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
access_token: Annotated[str | None, Cookie()],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
@@ -46,7 +38,10 @@ async def get_current_user(
|
||||
# Check if access token cookie exists
|
||||
if not access_token:
|
||||
logger.warning("No access token cookie found")
|
||||
_raise_auth_error()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
)
|
||||
|
||||
# Decode the JWT token
|
||||
payload = JWTUtils.decode_access_token(access_token)
|
||||
@@ -54,7 +49,10 @@ async def get_current_user(
|
||||
# Extract user ID from token
|
||||
user_id_str = payload.get("sub")
|
||||
if not user_id_str:
|
||||
_raise_invalid_token_error()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid token payload",
|
||||
)
|
||||
|
||||
# At this point user_id_str is guaranteed to be truthy, safe to cast
|
||||
user_id_str = cast("str", user_id_str)
|
||||
@@ -74,9 +72,12 @@ async def get_current_user(
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions without wrapping them
|
||||
raise
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to authenticate user")
|
||||
_raise_auth_error()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
) from e
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
|
||||
10
app/main.py
10
app/main.py
@@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api import api_router
|
||||
from app.core.database import init_db
|
||||
@@ -28,6 +29,15 @@ def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:8001"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
# Include API routes
|
||||
|
||||
117
app/repositories/user_oauth.py
Normal file
117
app/repositories/user_oauth.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Repository for user OAuth operations."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user_oauth import UserOauth
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class UserOauthRepository:
|
||||
"""Repository for user OAuth operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize repository with database session."""
|
||||
self.session = session
|
||||
|
||||
async def get_by_provider_user_id(
|
||||
self,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
) -> UserOauth | None:
|
||||
"""Get user OAuth by provider and provider user ID."""
|
||||
try:
|
||||
statement = select(UserOauth).where(
|
||||
UserOauth.provider == provider,
|
||||
UserOauth.provider_user_id == provider_user_id,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return result.first()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get user OAuth by provider user ID: %s:%s",
|
||||
provider,
|
||||
provider_user_id,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_user_id_and_provider(
|
||||
self,
|
||||
user_id: int,
|
||||
provider: str,
|
||||
) -> UserOauth | None:
|
||||
"""Get user OAuth by user ID and provider."""
|
||||
try:
|
||||
statement = select(UserOauth).where(
|
||||
UserOauth.user_id == user_id,
|
||||
UserOauth.provider == provider,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to get user OAuth by user ID and provider: %s:%s",
|
||||
user_id,
|
||||
provider,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
return result.first()
|
||||
|
||||
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
|
||||
"""Create a new user OAuth record."""
|
||||
try:
|
||||
oauth = UserOauth(**oauth_data)
|
||||
self.session.add(oauth)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(oauth)
|
||||
logger.info(
|
||||
"Created OAuth link for user %s with provider %s",
|
||||
oauth.user_id,
|
||||
oauth.provider,
|
||||
)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to create user OAuth")
|
||||
raise
|
||||
else:
|
||||
return oauth
|
||||
|
||||
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
|
||||
"""Update a user OAuth record."""
|
||||
try:
|
||||
for key, value in update_data.items():
|
||||
setattr(oauth, key, value)
|
||||
|
||||
self.session.add(oauth)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(oauth)
|
||||
logger.info(
|
||||
"Updated OAuth link for user %s with provider %s",
|
||||
oauth.user_id,
|
||||
oauth.provider,
|
||||
)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to update user OAuth")
|
||||
raise
|
||||
else:
|
||||
return oauth
|
||||
|
||||
async def delete(self, oauth: UserOauth) -> None:
|
||||
"""Delete a user OAuth record."""
|
||||
try:
|
||||
await self.session.delete(oauth)
|
||||
await self.session.commit()
|
||||
logger.info(
|
||||
"Deleted OAuth link for user %s with provider %s",
|
||||
oauth.user_id,
|
||||
oauth.provider,
|
||||
)
|
||||
except Exception:
|
||||
await self.session.rollback()
|
||||
logger.exception("Failed to delete user OAuth")
|
||||
raise
|
||||
@@ -10,6 +10,7 @@ from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.repositories.user_oauth import UserOauthRepository
|
||||
from app.schemas.auth import (
|
||||
AuthResponse,
|
||||
TokenResponse,
|
||||
@@ -17,6 +18,7 @@ from app.schemas.auth import (
|
||||
UserRegisterRequest,
|
||||
UserResponse,
|
||||
)
|
||||
from app.services.oauth import OAuthUserInfo
|
||||
from app.utils.auth import JWTUtils, PasswordUtils
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -29,6 +31,7 @@ class AuthService:
|
||||
"""Initialize the auth service."""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
self.oauth_repo = UserOauthRepository(session)
|
||||
|
||||
async def register(self, request: UserRegisterRequest) -> AuthResponse:
|
||||
"""Register a new user."""
|
||||
@@ -203,7 +206,7 @@ class AuthService:
|
||||
|
||||
# Check if refresh token is expired
|
||||
if user.refresh_token_expires_at and datetime.now(
|
||||
UTC
|
||||
UTC,
|
||||
) > user.refresh_token_expires_at.replace(tzinfo=UTC):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -272,3 +275,127 @@ class AuthService:
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
)
|
||||
|
||||
async def oauth_login(self, oauth_user_info: OAuthUserInfo) -> AuthResponse:
|
||||
"""Handle OAuth login - link or create user."""
|
||||
logger.info(
|
||||
"OAuth login attempt for %s with provider %s",
|
||||
oauth_user_info.email,
|
||||
oauth_user_info.provider,
|
||||
)
|
||||
|
||||
# Check if user already has OAuth link for this provider
|
||||
existing_oauth = await self.oauth_repo.get_by_provider_user_id(
|
||||
oauth_user_info.provider,
|
||||
oauth_user_info.provider_user_id,
|
||||
)
|
||||
|
||||
if existing_oauth:
|
||||
# User exists with this OAuth provider, get the user
|
||||
user = await self.user_repo.get_by_id(existing_oauth.user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Linked user not found",
|
||||
)
|
||||
|
||||
# Refresh user to avoid greenlet issues
|
||||
await self.session.refresh(user)
|
||||
|
||||
# Update OAuth record with latest info
|
||||
oauth_update_data = {
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
}
|
||||
await self.oauth_repo.update(existing_oauth, oauth_update_data)
|
||||
|
||||
logger.info(
|
||||
"OAuth login successful for existing user: %s",
|
||||
oauth_user_info.email,
|
||||
)
|
||||
else:
|
||||
# Check if user exists by email
|
||||
user = await self.user_repo.get_by_email(oauth_user_info.email)
|
||||
|
||||
if user:
|
||||
# Refresh user to avoid greenlet issues
|
||||
await self.session.refresh(user)
|
||||
|
||||
# Store user picture value to avoid greenlet issues later
|
||||
current_user_picture = user.picture
|
||||
|
||||
# Link existing user to OAuth provider
|
||||
oauth_data = {
|
||||
"user_id": user.id,
|
||||
"provider": oauth_user_info.provider,
|
||||
"provider_user_id": oauth_user_info.provider_user_id,
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
}
|
||||
await self.oauth_repo.create(oauth_data)
|
||||
|
||||
# Update user profile with OAuth info if needed
|
||||
user_update_data = {}
|
||||
if not current_user_picture and oauth_user_info.picture:
|
||||
user_update_data["picture"] = oauth_user_info.picture
|
||||
if user_update_data:
|
||||
await self.user_repo.update(user, user_update_data)
|
||||
# Refresh user after update to avoid greenlet issues
|
||||
await self.session.refresh(user)
|
||||
|
||||
logger.info(
|
||||
"Linked existing user %s to OAuth provider %s",
|
||||
oauth_user_info.email,
|
||||
oauth_user_info.provider,
|
||||
)
|
||||
else:
|
||||
# Create new user
|
||||
user_data = {
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
"is_active": True,
|
||||
# No password for OAuth users
|
||||
"password_hash": None,
|
||||
}
|
||||
|
||||
user = await self.user_repo.create(user_data)
|
||||
|
||||
# Create OAuth link
|
||||
oauth_data = {
|
||||
"user_id": user.id,
|
||||
"provider": oauth_user_info.provider,
|
||||
"provider_user_id": oauth_user_info.provider_user_id,
|
||||
"email": oauth_user_info.email,
|
||||
"name": oauth_user_info.name,
|
||||
"picture": oauth_user_info.picture,
|
||||
}
|
||||
await self.oauth_repo.create(oauth_data)
|
||||
|
||||
logger.info(
|
||||
"Created new user %s from OAuth provider %s",
|
||||
oauth_user_info.email,
|
||||
oauth_user_info.provider,
|
||||
)
|
||||
|
||||
# Refresh user to avoid greenlet issues and check if user is active
|
||||
await self.session.refresh(user)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Account is deactivated",
|
||||
)
|
||||
|
||||
# Generate access token
|
||||
token = self._create_access_token(user)
|
||||
|
||||
# Create response
|
||||
user_response = await self.create_user_response(user)
|
||||
|
||||
logger.info(
|
||||
"OAuth login completed successfully for user: %s",
|
||||
oauth_user_info.email,
|
||||
)
|
||||
return AuthResponse(user=user_response, token=token)
|
||||
|
||||
329
app/services/oauth.py
Normal file
329
app/services/oauth.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""OAuth2 service for external authentication providers."""
|
||||
|
||||
import secrets
|
||||
from abc import ABC, abstractmethod
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException, status
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OAuthUserInfo:
|
||||
"""OAuth user information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str,
|
||||
provider_user_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize OAuth user info."""
|
||||
self.provider = provider
|
||||
self.provider_user_id = provider_user_id
|
||||
self.email = email
|
||||
self.name = name
|
||||
self.picture = picture
|
||||
|
||||
|
||||
class OAuthProvider(ABC):
|
||||
"""Abstract base class for OAuth providers."""
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||
"""Initialize OAuth provider with client ID and secret."""
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def provider_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def authorization_url(self) -> str:
|
||||
"""Return the authorization URL."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def token_url(self) -> str:
|
||||
"""Return the token URL."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def user_info_url(self) -> str:
|
||||
"""Return the user info URL."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def scope(self) -> str:
|
||||
"""Return the required scopes."""
|
||||
|
||||
def get_authorization_url(self, state: str) -> str:
|
||||
"""Generate authorization URL with state parameter."""
|
||||
# Construct provider-specific redirect URI
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": self.scope,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
}
|
||||
return f"{self.authorization_url}?{urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_token(self, code: str) -> str:
|
||||
"""Exchange authorization code for access token."""
|
||||
# Construct provider-specific redirect URI (must match authorization request)
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to exchange code for token: %s", response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to exchange authorization code",
|
||||
)
|
||||
|
||||
token_data = response.json()
|
||||
return str(token_data["access_token"])
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Get user information from the provider."""
|
||||
|
||||
|
||||
class GoogleOAuthProvider(OAuthProvider):
|
||||
"""Google OAuth provider."""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "google"
|
||||
|
||||
@property
|
||||
def authorization_url(self) -> str:
|
||||
"""Return the authorization URL."""
|
||||
return "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
"""Return the token URL."""
|
||||
return "https://oauth2.googleapis.com/token"
|
||||
|
||||
@property
|
||||
def user_info_url(self) -> str:
|
||||
"""Return the user info URL."""
|
||||
return "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||
|
||||
@property
|
||||
def scope(self) -> str:
|
||||
"""Return the required scopes."""
|
||||
return "openid email profile"
|
||||
|
||||
async def exchange_code_for_token(self, code: str) -> str:
|
||||
"""Exchange authorization code for access token."""
|
||||
# Construct provider-specific redirect URI (must match authorization request)
|
||||
redirect_uri = (
|
||||
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||
)
|
||||
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to exchange code for token: %s", response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to exchange authorization code",
|
||||
)
|
||||
|
||||
token_data = response.json()
|
||||
return str(token_data["access_token"])
|
||||
|
||||
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Get user information from Google."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self.user_info_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to get user info: %s", response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to get user information",
|
||||
)
|
||||
|
||||
user_data = response.json()
|
||||
return OAuthUserInfo(
|
||||
provider=self.provider_name,
|
||||
provider_user_id=user_data["id"],
|
||||
email=user_data["email"],
|
||||
name=user_data["name"],
|
||||
picture=user_data.get("picture"),
|
||||
)
|
||||
|
||||
|
||||
class GitHubOAuthProvider(OAuthProvider):
|
||||
"""GitHub OAuth provider."""
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "github"
|
||||
|
||||
@property
|
||||
def authorization_url(self) -> str:
|
||||
"""Return the authorization URL."""
|
||||
return "https://github.com/login/oauth/authorize"
|
||||
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
"""Return the token URL."""
|
||||
return "https://github.com/login/oauth/access_token"
|
||||
|
||||
@property
|
||||
def user_info_url(self) -> str:
|
||||
"""Return the user info URL."""
|
||||
return "https://api.github.com/user"
|
||||
|
||||
@property
|
||||
def scope(self) -> str:
|
||||
"""Return the required scopes."""
|
||||
return "user:email"
|
||||
|
||||
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Get user information from GitHub."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Get user profile
|
||||
user_response = await client.get(
|
||||
self.user_info_url,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if user_response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to get user info: %s", user_response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to get user information",
|
||||
)
|
||||
|
||||
user_data = user_response.json()
|
||||
|
||||
# Get user email (GitHub doesn't include email in basic profile)
|
||||
email_response = await client.get(
|
||||
"https://api.github.com/user/emails",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
|
||||
if email_response.status_code != status.HTTP_200_OK:
|
||||
logger.error("Failed to get user emails: %s", email_response.text)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Failed to get user email",
|
||||
)
|
||||
|
||||
emails = email_response.json()
|
||||
primary_email = next(
|
||||
(email["email"] for email in emails if email["primary"]),
|
||||
None,
|
||||
)
|
||||
|
||||
if not primary_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No primary email found",
|
||||
)
|
||||
|
||||
return OAuthUserInfo(
|
||||
provider=self.provider_name,
|
||||
provider_user_id=str(user_data["id"]),
|
||||
email=primary_email,
|
||||
name=user_data.get("name") or user_data["login"],
|
||||
picture=user_data.get("avatar_url"),
|
||||
)
|
||||
|
||||
|
||||
class OAuthService:
|
||||
"""Service for handling OAuth authentication."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize OAuth service with database session."""
|
||||
self.session = session
|
||||
self.providers = {
|
||||
"google": GoogleOAuthProvider(
|
||||
settings.GOOGLE_CLIENT_ID,
|
||||
settings.GOOGLE_CLIENT_SECRET,
|
||||
),
|
||||
"github": GitHubOAuthProvider(
|
||||
settings.GITHUB_CLIENT_ID,
|
||||
settings.GITHUB_CLIENT_SECRET,
|
||||
),
|
||||
}
|
||||
|
||||
def get_provider(self, provider_name: str) -> OAuthProvider:
|
||||
"""Get OAuth provider by name."""
|
||||
if provider_name not in self.providers:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported OAuth provider: {provider_name}",
|
||||
)
|
||||
return self.providers[provider_name]
|
||||
|
||||
def generate_state(self) -> str:
|
||||
"""Generate a secure state parameter for OAuth flow."""
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def get_authorization_url(self, provider_name: str, state: str) -> str:
|
||||
"""Get authorization URL for the specified provider."""
|
||||
provider = self.get_provider(provider_name)
|
||||
return provider.get_authorization_url(state)
|
||||
|
||||
async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo:
|
||||
"""Handle OAuth callback and return user info."""
|
||||
provider = self.get_provider(provider_name)
|
||||
|
||||
# Exchange code for access token
|
||||
access_token = await provider.exchange_code_for_token(code)
|
||||
|
||||
# Get user information
|
||||
return await provider.get_user_info(access_token)
|
||||
@@ -9,6 +9,7 @@ dependencies = [
|
||||
"bcrypt==4.3.0",
|
||||
"email-validator==2.2.0",
|
||||
"fastapi[standard]==0.116.1",
|
||||
"httpx==0.28.1",
|
||||
"pydantic-settings==2.10.1",
|
||||
"pyjwt==2.10.1",
|
||||
"sqlmodel==0.0.24",
|
||||
@@ -36,7 +37,7 @@ exclude = ["alembic"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["ALL"]
|
||||
ignore = ["D100", "D103"]
|
||||
ignore = ["D100", "D103", "TRY301"]
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
"tests/**/*.py" = ["S101", "S105"]
|
||||
|
||||
151
tests/api/v1/test_oauth_endpoints.py
Normal file
151
tests/api/v1/test_oauth_endpoints.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for OAuth authentication endpoints."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.services.oauth import OAuthUserInfo
|
||||
|
||||
|
||||
class TestOAuthEndpoints:
|
||||
"""Test OAuth API endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_oauth_providers(self, test_client: AsyncClient) -> None:
|
||||
"""Test getting list of OAuth providers."""
|
||||
response = await test_client.get("/api/v1/oauth/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "providers" in data
|
||||
assert "google" in data["providers"]
|
||||
assert "github" in data["providers"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None:
|
||||
"""Test OAuth authorization URL generation for Google."""
|
||||
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||
mock_state.return_value = "test_state_123"
|
||||
|
||||
response = await test_client.get("/api/v1/oauth/google/authorize")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
assert "state" in data
|
||||
assert data["state"] == "test_state_123"
|
||||
assert "accounts.google.com" in data["authorization_url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None:
|
||||
"""Test OAuth authorization URL generation for GitHub."""
|
||||
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||
mock_state.return_value = "test_state_456"
|
||||
|
||||
response = await test_client.get("/api/v1/oauth/github/authorize")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "authorization_url" in data
|
||||
assert "state" in data
|
||||
assert data["state"] == "test_state_456"
|
||||
assert "github.com" in data["authorization_url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_authorize_invalid_provider(
|
||||
self, test_client: AsyncClient
|
||||
) -> None:
|
||||
"""Test OAuth authorization with invalid provider."""
|
||||
response = await test_client.get("/api/v1/oauth/invalid/authorize")
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert "Unsupported OAuth provider" in data["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_new_user(
|
||||
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any]
|
||||
) -> None:
|
||||
"""Test OAuth callback for new user creation."""
|
||||
# Mock OAuth user info
|
||||
mock_user_info = OAuthUserInfo(
|
||||
provider="google",
|
||||
provider_user_id="google_123",
|
||||
email="newuser@gmail.com",
|
||||
name="New User",
|
||||
picture="https://example.com/avatar.jpg",
|
||||
)
|
||||
|
||||
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||
mock_callback.return_value = mock_user_info
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/google/callback",
|
||||
params={"code": "auth_code_123", "state": "test_state"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
# OAuth callback should successfully process and redirect to frontend
|
||||
assert response.status_code == 302
|
||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||
|
||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||
# Detailed cookie testing can be done in integration tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_existing_user_link(
|
||||
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any]
|
||||
) -> None:
|
||||
"""Test OAuth callback for linking to existing user."""
|
||||
# Mock OAuth user info with same email as test user
|
||||
mock_user_info = OAuthUserInfo(
|
||||
provider="github",
|
||||
provider_user_id="github_456",
|
||||
email=test_user.email, # Same email as existing user
|
||||
name="Test User",
|
||||
picture="https://github.com/avatar.jpg",
|
||||
)
|
||||
|
||||
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||
mock_callback.return_value = mock_user_info
|
||||
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/github/callback",
|
||||
params={"code": "auth_code_456", "state": "test_state"},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
# OAuth callback should successfully process and redirect to frontend
|
||||
assert response.status_code == 302
|
||||
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||
|
||||
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||
# Detailed cookie testing can be done in integration tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None:
|
||||
"""Test OAuth callback with missing authorization code."""
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/google/callback",
|
||||
params={"state": "test_state"}, # Missing code parameter
|
||||
)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_invalid_provider(
|
||||
self, test_client: AsyncClient
|
||||
) -> None:
|
||||
"""Test OAuth callback with invalid provider."""
|
||||
response = await test_client.get(
|
||||
"/api/v1/oauth/invalid/callback",
|
||||
params={"code": "auth_code_123", "state": "test_state"},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert "Unsupported OAuth provider" in data["detail"]
|
||||
192
tests/services/test_oauth_service.py
Normal file
192
tests/services/test_oauth_service.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for OAuth service."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.services.oauth import (
|
||||
GitHubOAuthProvider,
|
||||
GoogleOAuthProvider,
|
||||
OAuthService,
|
||||
OAuthUserInfo,
|
||||
)
|
||||
|
||||
|
||||
class TestOAuthService:
|
||||
"""Test OAuth service functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_service_initialization(self, test_session: Any) -> None:
|
||||
"""Test OAuth service initialization."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
|
||||
assert "google" in oauth_service.providers
|
||||
assert "github" in oauth_service.providers
|
||||
assert isinstance(oauth_service.providers["google"], GoogleOAuthProvider)
|
||||
assert isinstance(oauth_service.providers["github"], GitHubOAuthProvider)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_state(self, test_session: Any) -> None:
|
||||
"""Test state generation."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
state = oauth_service.generate_state()
|
||||
|
||||
assert isinstance(state, str)
|
||||
assert len(state) > 10 # Should be a reasonable length
|
||||
|
||||
# Generate another to ensure they're different
|
||||
state2 = oauth_service.generate_state()
|
||||
assert state != state2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_valid(self, test_session: Any) -> None:
|
||||
"""Test getting valid OAuth provider."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
|
||||
google_provider = oauth_service.get_provider("google")
|
||||
assert isinstance(google_provider, GoogleOAuthProvider)
|
||||
assert google_provider.provider_name == "google"
|
||||
|
||||
github_provider = oauth_service.get_provider("github")
|
||||
assert isinstance(github_provider, GitHubOAuthProvider)
|
||||
assert github_provider.provider_name == "github"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_invalid(self, test_session: Any) -> None:
|
||||
"""Test getting invalid OAuth provider."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
oauth_service.get_provider("invalid")
|
||||
|
||||
assert "Unsupported OAuth provider" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authorization_url(self, test_session: Any) -> None:
|
||||
"""Test authorization URL generation."""
|
||||
oauth_service = OAuthService(test_session)
|
||||
state = "test_state_123"
|
||||
|
||||
# Test Google
|
||||
google_url = oauth_service.get_authorization_url("google", state)
|
||||
assert "accounts.google.com" in google_url
|
||||
assert "client_id=" in google_url
|
||||
assert f"state={state}" in google_url
|
||||
|
||||
# Test GitHub
|
||||
github_url = oauth_service.get_authorization_url("github", state)
|
||||
assert "github.com" in github_url
|
||||
assert "client_id=" in github_url
|
||||
assert f"state={state}" in github_url
|
||||
|
||||
|
||||
class TestGoogleOAuthProvider:
|
||||
"""Test Google OAuth provider."""
|
||||
|
||||
def test_provider_properties(self) -> None:
|
||||
"""Test Google provider properties."""
|
||||
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
assert provider.provider_name == "google"
|
||||
assert "accounts.google.com" in provider.authorization_url
|
||||
assert "oauth2.googleapis.com" in provider.token_url
|
||||
assert "googleapis.com" in provider.user_info_url
|
||||
assert "openid email profile" in provider.scope
|
||||
|
||||
def test_authorization_url_generation(self) -> None:
|
||||
"""Test authorization URL generation."""
|
||||
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||
state = "test_state"
|
||||
|
||||
auth_url = provider.get_authorization_url(state)
|
||||
|
||||
assert "accounts.google.com" in auth_url
|
||||
assert "client_id=test_client_id" in auth_url
|
||||
assert f"state={state}" in auth_url
|
||||
assert "scope=openid+email+profile" in auth_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info_success(self) -> None:
|
||||
"""Test successful user info retrieval."""
|
||||
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
mock_response_data = {
|
||||
"id": "google_user_123",
|
||||
"email": "test@gmail.com",
|
||||
"name": "Test User",
|
||||
"picture": "https://example.com/avatar.jpg",
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient.get") as mock_get:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = mock_response_data
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
user_info = await provider.get_user_info("test_access_token")
|
||||
|
||||
assert user_info.provider == "google"
|
||||
assert user_info.provider_user_id == "google_user_123"
|
||||
assert user_info.email == "test@gmail.com"
|
||||
assert user_info.name == "Test User"
|
||||
assert user_info.picture == "https://example.com/avatar.jpg"
|
||||
|
||||
|
||||
class TestGitHubOAuthProvider:
|
||||
"""Test GitHub OAuth provider."""
|
||||
|
||||
def test_provider_properties(self) -> None:
|
||||
"""Test GitHub provider properties."""
|
||||
provider = GitHubOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
assert provider.provider_name == "github"
|
||||
assert "github.com" in provider.authorization_url
|
||||
assert "github.com" in provider.token_url
|
||||
assert "api.github.com" in provider.user_info_url
|
||||
assert "user:email" in provider.scope
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info_success(self) -> None:
|
||||
"""Test successful user info retrieval."""
|
||||
provider = GitHubOAuthProvider("test_client_id", "test_secret")
|
||||
|
||||
mock_user_data = {
|
||||
"id": 123456,
|
||||
"login": "testuser",
|
||||
"name": "Test User",
|
||||
"avatar_url": "https://github.com/avatar.jpg",
|
||||
}
|
||||
|
||||
mock_emails_data = [
|
||||
{"email": "test@example.com", "primary": True, "verified": True},
|
||||
{"email": "secondary@example.com", "primary": False, "verified": True},
|
||||
]
|
||||
|
||||
with patch("httpx.AsyncClient.get") as mock_get:
|
||||
# Mock user profile response
|
||||
mock_user_response = Mock()
|
||||
mock_user_response.status_code = 200
|
||||
mock_user_response.json.return_value = mock_user_data
|
||||
|
||||
# Mock emails response
|
||||
mock_emails_response = Mock()
|
||||
mock_emails_response.status_code = 200
|
||||
mock_emails_response.json.return_value = mock_emails_data
|
||||
|
||||
# Return different responses based on URL
|
||||
def side_effect(url, **kwargs):
|
||||
if "user/emails" in str(url):
|
||||
return mock_emails_response
|
||||
return mock_user_response
|
||||
|
||||
mock_get.side_effect = side_effect
|
||||
|
||||
user_info = await provider.get_user_info("test_access_token")
|
||||
|
||||
assert user_info.provider == "github"
|
||||
assert user_info.provider_user_id == "123456"
|
||||
assert user_info.email == "test@example.com"
|
||||
assert user_info.name == "Test User"
|
||||
assert user_info.picture == "https://github.com/avatar.jpg"
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -46,6 +46,7 @@ dependencies = [
|
||||
{ name = "bcrypt" },
|
||||
{ name = "email-validator" },
|
||||
{ name = "fastapi", extra = ["standard"] },
|
||||
{ name = "httpx" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "sqlmodel" },
|
||||
@@ -69,6 +70,7 @@ requires-dist = [
|
||||
{ name = "bcrypt", specifier = "==4.3.0" },
|
||||
{ name = "email-validator", specifier = "==2.2.0" },
|
||||
{ name = "fastapi", extras = ["standard"], specifier = "==0.116.1" },
|
||||
{ name = "httpx", specifier = "==0.28.1" },
|
||||
{ name = "pydantic-settings", specifier = "==2.10.1" },
|
||||
{ name = "pyjwt", specifier = "==2.10.1" },
|
||||
{ name = "sqlmodel", specifier = "==0.0.24" },
|
||||
|
||||
Reference in New Issue
Block a user