feat: Implement OAuth2 authentication with Google and GitHub
- Added OAuth2 endpoints for Google and GitHub authentication. - Created OAuth service to handle provider interactions and user info retrieval. - Implemented user OAuth repository for managing user OAuth links in the database. - Updated auth service to support linking existing users and creating new users via OAuth. - Added CORS middleware to allow frontend access. - Created tests for OAuth endpoints and service functionality. - Introduced environment configuration for OAuth client IDs and secrets. - Added logging for OAuth operations and error handling.
This commit is contained in:
29
.env.template
Normal file
29
.env.template
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
# Application Configuration
|
||||||
|
HOST=localhost
|
||||||
|
PORT=8000
|
||||||
|
RELOAD=true
|
||||||
|
|
||||||
|
# Database Configuration
|
||||||
|
DATABASE_URL=sqlite+aiosqlite:///data/soundboard.db
|
||||||
|
DATABASE_ECHO=false
|
||||||
|
|
||||||
|
# Logging Configuration
|
||||||
|
LOG_LEVEL=info
|
||||||
|
LOG_FILE=logs/app.log
|
||||||
|
LOG_MAX_SIZE=10485760
|
||||||
|
LOG_BACKUP_COUNT=5
|
||||||
|
|
||||||
|
# JWT Configuration
|
||||||
|
JWT_SECRET_KEY=your-secret-key-change-in-production
|
||||||
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=15
|
||||||
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||||
|
|
||||||
|
# Cookie Configuration
|
||||||
|
COOKIE_SECURE=false
|
||||||
|
|
||||||
|
# OAuth2 Configuration
|
||||||
|
GOOGLE_CLIENT_ID=
|
||||||
|
GOOGLE_CLIENT_SECRET=
|
||||||
|
GITHUB_CLIENT_ID=
|
||||||
|
GITHUB_CLIENT_SECRET=
|
||||||
|
OAUTH_REDIRECT_URL=http://localhost:8001/auth/callback
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from app.api.v1 import auth, main
|
from app.api.v1 import auth, main, oauth
|
||||||
|
|
||||||
# V1 API router with v1 prefix
|
# V1 API router with v1 prefix
|
||||||
api_router = APIRouter(prefix="/v1")
|
api_router = APIRouter(prefix="/v1")
|
||||||
@@ -10,3 +10,4 @@ api_router = APIRouter(prefix="/v1")
|
|||||||
# Include all route modules
|
# Include all route modules
|
||||||
api_router.include_router(main.router, tags=["main"])
|
api_router.include_router(main.router, tags=["main"])
|
||||||
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
api_router.include_router(auth.router, prefix="/auth", tags=["authentication"])
|
||||||
|
api_router.include_router(oauth.router, prefix="/oauth", tags=["oauth"])
|
||||||
|
|||||||
@@ -54,8 +54,6 @@ async def register(
|
|||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return only user data, tokens are now in cookies
|
|
||||||
return auth_response.user
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -64,6 +62,8 @@ async def register(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Registration failed",
|
detail="Registration failed",
|
||||||
) from e
|
) from e
|
||||||
|
else:
|
||||||
|
return auth_response.user
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login")
|
@router.post("/login")
|
||||||
@@ -101,8 +101,6 @@ async def login(
|
|||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return only user data, tokens are now in cookies
|
|
||||||
return auth_response.user
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -111,6 +109,8 @@ async def login(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Login failed",
|
detail="Login failed",
|
||||||
) from e
|
) from e
|
||||||
|
else:
|
||||||
|
return auth_response.user
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me")
|
@router.get("/me")
|
||||||
@@ -156,7 +156,6 @@ async def refresh_token(
|
|||||||
samesite=settings.COOKIE_SAMESITE,
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"message": "Token refreshed successfully"}
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -165,6 +164,8 @@ async def refresh_token(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Token refresh failed",
|
detail="Token refresh failed",
|
||||||
) from e
|
) from e
|
||||||
|
else:
|
||||||
|
return {"message": "Token refreshed successfully"}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout")
|
@router.post("/logout")
|
||||||
|
|||||||
111
app/api/v1/oauth.py
Normal file
111
app/api/v1/oauth.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""OAuth2 authentication endpoints."""
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.dependencies import get_auth_service, get_oauth_service
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.services.auth import AuthService
|
||||||
|
from app.services.oauth import OAuthService
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{provider}/authorize")
|
||||||
|
async def oauth_authorize(
|
||||||
|
provider: str,
|
||||||
|
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Get OAuth authorization URL."""
|
||||||
|
try:
|
||||||
|
# Generate secure state parameter
|
||||||
|
state = oauth_service.generate_state()
|
||||||
|
|
||||||
|
# Get authorization URL
|
||||||
|
auth_url = oauth_service.get_authorization_url(provider, state)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth authorization failed for provider: %s", provider)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth authorization failed",
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"authorization_url": auth_url,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{provider}/callback")
|
||||||
|
async def oauth_callback(
|
||||||
|
provider: str,
|
||||||
|
response: Response,
|
||||||
|
code: Annotated[str, Query()],
|
||||||
|
oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
|
||||||
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Handle OAuth callback."""
|
||||||
|
try:
|
||||||
|
# Handle OAuth callback and get user info
|
||||||
|
oauth_user_info = await oauth_service.handle_callback(provider, code)
|
||||||
|
|
||||||
|
# Perform OAuth login (link or create user)
|
||||||
|
auth_response = await auth_service.oauth_login(oauth_user_info)
|
||||||
|
|
||||||
|
# Create and store refresh token
|
||||||
|
user = await auth_service.get_current_user(auth_response.user.id)
|
||||||
|
refresh_token = await auth_service.create_and_store_refresh_token(user)
|
||||||
|
|
||||||
|
# Set HTTP-only cookies for both tokens
|
||||||
|
response.set_cookie(
|
||||||
|
key="access_token",
|
||||||
|
value=auth_response.token.access_token,
|
||||||
|
max_age=auth_response.token.expires_in,
|
||||||
|
httponly=True,
|
||||||
|
secure=settings.COOKIE_SECURE,
|
||||||
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
key="refresh_token",
|
||||||
|
value=refresh_token,
|
||||||
|
max_age=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60,
|
||||||
|
httponly=True,
|
||||||
|
secure=settings.COOKIE_SECURE,
|
||||||
|
samesite=settings.COOKIE_SAMESITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth login successful for user: %s via %s",
|
||||||
|
auth_response.user.email,
|
||||||
|
provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redirect back to frontend after successful authentication
|
||||||
|
return RedirectResponse(
|
||||||
|
url="http://localhost:8001/?auth=success",
|
||||||
|
status_code=302,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("OAuth callback failed for provider: %s", provider)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="OAuth callback failed",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers")
|
||||||
|
async def get_oauth_providers() -> dict[str, list[str]]:
|
||||||
|
"""Get list of available OAuth providers."""
|
||||||
|
return {
|
||||||
|
"providers": ["google", "github"],
|
||||||
|
}
|
||||||
@@ -13,13 +13,16 @@ class Settings(BaseSettings):
|
|||||||
extra="ignore",
|
extra="ignore",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Application Configuration
|
||||||
HOST: str = "localhost"
|
HOST: str = "localhost"
|
||||||
PORT: int = 8000
|
PORT: int = 8000
|
||||||
RELOAD: bool = True
|
RELOAD: bool = True
|
||||||
|
|
||||||
|
# Database Configuration
|
||||||
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
|
DATABASE_URL: str = "sqlite+aiosqlite:///data/soundboard.db"
|
||||||
DATABASE_ECHO: bool = False
|
DATABASE_ECHO: bool = False
|
||||||
|
|
||||||
|
# Logging Configuration
|
||||||
LOG_LEVEL: str = "info"
|
LOG_LEVEL: str = "info"
|
||||||
LOG_FILE: str = "logs/app.log"
|
LOG_FILE: str = "logs/app.log"
|
||||||
LOG_MAX_SIZE: int = 10 * 1024 * 1024
|
LOG_MAX_SIZE: int = 10 * 1024 * 1024
|
||||||
@@ -31,12 +34,19 @@ class Settings(BaseSettings):
|
|||||||
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
|
"your-secret-key-change-in-production" # noqa: S105 default value if none set in .env
|
||||||
)
|
)
|
||||||
JWT_ALGORITHM: str = "HS256"
|
JWT_ALGORITHM: str = "HS256"
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15 # Shorter-lived access token
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7 # Longer-lived refresh token
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 7
|
||||||
|
|
||||||
# Cookie Configuration
|
# Cookie Configuration
|
||||||
COOKIE_SECURE: bool = True # Set to False for development without HTTPS
|
COOKIE_SECURE: bool = True
|
||||||
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
|
COOKIE_SAMESITE: Literal["strict", "lax", "none"] = "lax"
|
||||||
|
|
||||||
|
# OAuth2 Configuration
|
||||||
|
GOOGLE_CLIENT_ID: str = ""
|
||||||
|
GOOGLE_CLIENT_SECRET: str = ""
|
||||||
|
GITHUB_CLIENT_ID: str = ""
|
||||||
|
GITHUB_CLIENT_SECRET: str = ""
|
||||||
|
OAUTH_REDIRECT_URL: str = "http://localhost:8001/auth/callback"
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""FastAPI dependencies."""
|
"""FastAPI dependencies."""
|
||||||
|
|
||||||
from typing import Annotated, NoReturn, cast
|
from typing import Annotated, cast
|
||||||
|
|
||||||
from fastapi import Cookie, Depends, HTTPException, status
|
from fastapi import Cookie, Depends, HTTPException, status
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -9,27 +9,12 @@ from app.core.database import get_db
|
|||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.auth import AuthService
|
from app.services.auth import AuthService
|
||||||
|
from app.services.oauth import OAuthService
|
||||||
from app.utils.auth import JWTUtils
|
from app.utils.auth import JWTUtils
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _raise_invalid_token_error() -> NoReturn:
|
|
||||||
"""Raise an invalid token HTTP exception."""
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Invalid token payload",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _raise_auth_error() -> NoReturn:
|
|
||||||
"""Raise an authentication HTTP exception."""
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Could not validate credentials",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_auth_service(
|
async def get_auth_service(
|
||||||
session: Annotated[AsyncSession, Depends(get_db)],
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
) -> AuthService:
|
) -> AuthService:
|
||||||
@@ -37,6 +22,13 @@ async def get_auth_service(
|
|||||||
return AuthService(session)
|
return AuthService(session)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_oauth_service(
|
||||||
|
session: Annotated[AsyncSession, Depends(get_db)],
|
||||||
|
) -> OAuthService:
|
||||||
|
"""Get the OAuth service."""
|
||||||
|
return OAuthService(session)
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
access_token: Annotated[str | None, Cookie()],
|
access_token: Annotated[str | None, Cookie()],
|
||||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
@@ -46,7 +38,10 @@ async def get_current_user(
|
|||||||
# Check if access token cookie exists
|
# Check if access token cookie exists
|
||||||
if not access_token:
|
if not access_token:
|
||||||
logger.warning("No access token cookie found")
|
logger.warning("No access token cookie found")
|
||||||
_raise_auth_error()
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
)
|
||||||
|
|
||||||
# Decode the JWT token
|
# Decode the JWT token
|
||||||
payload = JWTUtils.decode_access_token(access_token)
|
payload = JWTUtils.decode_access_token(access_token)
|
||||||
@@ -54,7 +49,10 @@ async def get_current_user(
|
|||||||
# Extract user ID from token
|
# Extract user ID from token
|
||||||
user_id_str = payload.get("sub")
|
user_id_str = payload.get("sub")
|
||||||
if not user_id_str:
|
if not user_id_str:
|
||||||
_raise_invalid_token_error()
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid token payload",
|
||||||
|
)
|
||||||
|
|
||||||
# At this point user_id_str is guaranteed to be truthy, safe to cast
|
# At this point user_id_str is guaranteed to be truthy, safe to cast
|
||||||
user_id_str = cast("str", user_id_str)
|
user_id_str = cast("str", user_id_str)
|
||||||
@@ -74,9 +72,12 @@ async def get_current_user(
|
|||||||
except HTTPException:
|
except HTTPException:
|
||||||
# Re-raise HTTPExceptions without wrapping them
|
# Re-raise HTTPExceptions without wrapping them
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.exception("Failed to authenticate user")
|
logger.exception("Failed to authenticate user")
|
||||||
_raise_auth_error()
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
async def get_current_active_user(
|
async def get_current_active_user(
|
||||||
|
|||||||
10
app/main.py
10
app/main.py
@@ -2,6 +2,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from app.api import api_router
|
from app.api import api_router
|
||||||
from app.core.database import init_db
|
from app.core.database import init_db
|
||||||
@@ -28,6 +29,15 @@ def create_app() -> FastAPI:
|
|||||||
"""Create and configure the FastAPI application."""
|
"""Create and configure the FastAPI application."""
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
# Add CORS middleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["http://localhost:8001"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
app.add_middleware(LoggingMiddleware)
|
app.add_middleware(LoggingMiddleware)
|
||||||
|
|
||||||
# Include API routes
|
# Include API routes
|
||||||
|
|||||||
117
app/repositories/user_oauth.py
Normal file
117
app/repositories/user_oauth.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""Repository for user OAuth operations."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.user_oauth import UserOauth
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UserOauthRepository:
|
||||||
|
"""Repository for user OAuth operations."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize repository with database session."""
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def get_by_provider_user_id(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
provider_user_id: str,
|
||||||
|
) -> UserOauth | None:
|
||||||
|
"""Get user OAuth by provider and provider user ID."""
|
||||||
|
try:
|
||||||
|
statement = select(UserOauth).where(
|
||||||
|
UserOauth.provider == provider,
|
||||||
|
UserOauth.provider_user_id == provider_user_id,
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return result.first()
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to get user OAuth by provider user ID: %s:%s",
|
||||||
|
provider,
|
||||||
|
provider_user_id,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_by_user_id_and_provider(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
provider: str,
|
||||||
|
) -> UserOauth | None:
|
||||||
|
"""Get user OAuth by user ID and provider."""
|
||||||
|
try:
|
||||||
|
statement = select(UserOauth).where(
|
||||||
|
UserOauth.user_id == user_id,
|
||||||
|
UserOauth.provider == provider,
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to get user OAuth by user ID and provider: %s:%s",
|
||||||
|
user_id,
|
||||||
|
provider,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
return result.first()
|
||||||
|
|
||||||
|
async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
|
||||||
|
"""Create a new user OAuth record."""
|
||||||
|
try:
|
||||||
|
oauth = UserOauth(**oauth_data)
|
||||||
|
self.session.add(oauth)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(oauth)
|
||||||
|
logger.info(
|
||||||
|
"Created OAuth link for user %s with provider %s",
|
||||||
|
oauth.user_id,
|
||||||
|
oauth.provider,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
await self.session.rollback()
|
||||||
|
logger.exception("Failed to create user OAuth")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
return oauth
|
||||||
|
|
||||||
|
async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
|
||||||
|
"""Update a user OAuth record."""
|
||||||
|
try:
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(oauth, key, value)
|
||||||
|
|
||||||
|
self.session.add(oauth)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(oauth)
|
||||||
|
logger.info(
|
||||||
|
"Updated OAuth link for user %s with provider %s",
|
||||||
|
oauth.user_id,
|
||||||
|
oauth.provider,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
await self.session.rollback()
|
||||||
|
logger.exception("Failed to update user OAuth")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
return oauth
|
||||||
|
|
||||||
|
async def delete(self, oauth: UserOauth) -> None:
|
||||||
|
"""Delete a user OAuth record."""
|
||||||
|
try:
|
||||||
|
await self.session.delete(oauth)
|
||||||
|
await self.session.commit()
|
||||||
|
logger.info(
|
||||||
|
"Deleted OAuth link for user %s with provider %s",
|
||||||
|
oauth.user_id,
|
||||||
|
oauth.provider,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
await self.session.rollback()
|
||||||
|
logger.exception("Failed to delete user OAuth")
|
||||||
|
raise
|
||||||
@@ -10,6 +10,7 @@ from app.core.config import settings
|
|||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.repositories.user import UserRepository
|
from app.repositories.user import UserRepository
|
||||||
|
from app.repositories.user_oauth import UserOauthRepository
|
||||||
from app.schemas.auth import (
|
from app.schemas.auth import (
|
||||||
AuthResponse,
|
AuthResponse,
|
||||||
TokenResponse,
|
TokenResponse,
|
||||||
@@ -17,6 +18,7 @@ from app.schemas.auth import (
|
|||||||
UserRegisterRequest,
|
UserRegisterRequest,
|
||||||
UserResponse,
|
UserResponse,
|
||||||
)
|
)
|
||||||
|
from app.services.oauth import OAuthUserInfo
|
||||||
from app.utils.auth import JWTUtils, PasswordUtils
|
from app.utils.auth import JWTUtils, PasswordUtils
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -29,6 +31,7 @@ class AuthService:
|
|||||||
"""Initialize the auth service."""
|
"""Initialize the auth service."""
|
||||||
self.session = session
|
self.session = session
|
||||||
self.user_repo = UserRepository(session)
|
self.user_repo = UserRepository(session)
|
||||||
|
self.oauth_repo = UserOauthRepository(session)
|
||||||
|
|
||||||
async def register(self, request: UserRegisterRequest) -> AuthResponse:
|
async def register(self, request: UserRegisterRequest) -> AuthResponse:
|
||||||
"""Register a new user."""
|
"""Register a new user."""
|
||||||
@@ -203,7 +206,7 @@ class AuthService:
|
|||||||
|
|
||||||
# Check if refresh token is expired
|
# Check if refresh token is expired
|
||||||
if user.refresh_token_expires_at and datetime.now(
|
if user.refresh_token_expires_at and datetime.now(
|
||||||
UTC
|
UTC,
|
||||||
) > user.refresh_token_expires_at.replace(tzinfo=UTC):
|
) > user.refresh_token_expires_at.replace(tzinfo=UTC):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -272,3 +275,127 @@ class AuthService:
|
|||||||
created_at=user.created_at,
|
created_at=user.created_at,
|
||||||
updated_at=user.updated_at,
|
updated_at=user.updated_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def oauth_login(self, oauth_user_info: OAuthUserInfo) -> AuthResponse:
|
||||||
|
"""Handle OAuth login - link or create user."""
|
||||||
|
logger.info(
|
||||||
|
"OAuth login attempt for %s with provider %s",
|
||||||
|
oauth_user_info.email,
|
||||||
|
oauth_user_info.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user already has OAuth link for this provider
|
||||||
|
existing_oauth = await self.oauth_repo.get_by_provider_user_id(
|
||||||
|
oauth_user_info.provider,
|
||||||
|
oauth_user_info.provider_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_oauth:
|
||||||
|
# User exists with this OAuth provider, get the user
|
||||||
|
user = await self.user_repo.get_by_id(existing_oauth.user_id)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Linked user not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Refresh user to avoid greenlet issues
|
||||||
|
await self.session.refresh(user)
|
||||||
|
|
||||||
|
# Update OAuth record with latest info
|
||||||
|
oauth_update_data = {
|
||||||
|
"email": oauth_user_info.email,
|
||||||
|
"name": oauth_user_info.name,
|
||||||
|
"picture": oauth_user_info.picture,
|
||||||
|
}
|
||||||
|
await self.oauth_repo.update(existing_oauth, oauth_update_data)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth login successful for existing user: %s",
|
||||||
|
oauth_user_info.email,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Check if user exists by email
|
||||||
|
user = await self.user_repo.get_by_email(oauth_user_info.email)
|
||||||
|
|
||||||
|
if user:
|
||||||
|
# Refresh user to avoid greenlet issues
|
||||||
|
await self.session.refresh(user)
|
||||||
|
|
||||||
|
# Store user picture value to avoid greenlet issues later
|
||||||
|
current_user_picture = user.picture
|
||||||
|
|
||||||
|
# Link existing user to OAuth provider
|
||||||
|
oauth_data = {
|
||||||
|
"user_id": user.id,
|
||||||
|
"provider": oauth_user_info.provider,
|
||||||
|
"provider_user_id": oauth_user_info.provider_user_id,
|
||||||
|
"email": oauth_user_info.email,
|
||||||
|
"name": oauth_user_info.name,
|
||||||
|
"picture": oauth_user_info.picture,
|
||||||
|
}
|
||||||
|
await self.oauth_repo.create(oauth_data)
|
||||||
|
|
||||||
|
# Update user profile with OAuth info if needed
|
||||||
|
user_update_data = {}
|
||||||
|
if not current_user_picture and oauth_user_info.picture:
|
||||||
|
user_update_data["picture"] = oauth_user_info.picture
|
||||||
|
if user_update_data:
|
||||||
|
await self.user_repo.update(user, user_update_data)
|
||||||
|
# Refresh user after update to avoid greenlet issues
|
||||||
|
await self.session.refresh(user)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Linked existing user %s to OAuth provider %s",
|
||||||
|
oauth_user_info.email,
|
||||||
|
oauth_user_info.provider,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Create new user
|
||||||
|
user_data = {
|
||||||
|
"email": oauth_user_info.email,
|
||||||
|
"name": oauth_user_info.name,
|
||||||
|
"picture": oauth_user_info.picture,
|
||||||
|
"is_active": True,
|
||||||
|
# No password for OAuth users
|
||||||
|
"password_hash": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
user = await self.user_repo.create(user_data)
|
||||||
|
|
||||||
|
# Create OAuth link
|
||||||
|
oauth_data = {
|
||||||
|
"user_id": user.id,
|
||||||
|
"provider": oauth_user_info.provider,
|
||||||
|
"provider_user_id": oauth_user_info.provider_user_id,
|
||||||
|
"email": oauth_user_info.email,
|
||||||
|
"name": oauth_user_info.name,
|
||||||
|
"picture": oauth_user_info.picture,
|
||||||
|
}
|
||||||
|
await self.oauth_repo.create(oauth_data)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Created new user %s from OAuth provider %s",
|
||||||
|
oauth_user_info.email,
|
||||||
|
oauth_user_info.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Refresh user to avoid greenlet issues and check if user is active
|
||||||
|
await self.session.refresh(user)
|
||||||
|
if not user.is_active:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Account is deactivated",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate access token
|
||||||
|
token = self._create_access_token(user)
|
||||||
|
|
||||||
|
# Create response
|
||||||
|
user_response = await self.create_user_response(user)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"OAuth login completed successfully for user: %s",
|
||||||
|
oauth_user_info.email,
|
||||||
|
)
|
||||||
|
return AuthResponse(user=user_response, token=token)
|
||||||
|
|||||||
329
app/services/oauth.py
Normal file
329
app/services/oauth.py
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
"""OAuth2 service for external authentication providers."""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthUserInfo:
|
||||||
|
"""OAuth user information."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider: str,
|
||||||
|
provider_user_id: str,
|
||||||
|
email: str,
|
||||||
|
name: str,
|
||||||
|
picture: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize OAuth user info."""
|
||||||
|
self.provider = provider
|
||||||
|
self.provider_user_id = provider_user_id
|
||||||
|
self.email = email
|
||||||
|
self.name = name
|
||||||
|
self.picture = picture
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProvider(ABC):
|
||||||
|
"""Abstract base class for OAuth providers."""
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||||
|
"""Initialize OAuth provider with client ID and secret."""
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def provider_name(self) -> str:
|
||||||
|
"""Return the provider name."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def authorization_url(self) -> str:
|
||||||
|
"""Return the authorization URL."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def token_url(self) -> str:
|
||||||
|
"""Return the token URL."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def user_info_url(self) -> str:
|
||||||
|
"""Return the user info URL."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def scope(self) -> str:
|
||||||
|
"""Return the required scopes."""
|
||||||
|
|
||||||
|
def get_authorization_url(self, state: str) -> str:
|
||||||
|
"""Generate authorization URL with state parameter."""
|
||||||
|
# Construct provider-specific redirect URI
|
||||||
|
redirect_uri = (
|
||||||
|
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"scope": self.scope,
|
||||||
|
"response_type": "code",
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
return f"{self.authorization_url}?{urlencode(params)}"
|
||||||
|
|
||||||
|
async def exchange_code_for_token(self, code: str) -> str:
|
||||||
|
"""Exchange authorization code for access token."""
|
||||||
|
# Construct provider-specific redirect URI (must match authorization request)
|
||||||
|
redirect_uri = (
|
||||||
|
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.token_url,
|
||||||
|
data=data,
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != status.HTTP_200_OK:
|
||||||
|
logger.error("Failed to exchange code for token: %s", response.text)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Failed to exchange authorization code",
|
||||||
|
)
|
||||||
|
|
||||||
|
token_data = response.json()
|
||||||
|
return str(token_data["access_token"])
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||||
|
"""Get user information from the provider."""
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuthProvider(OAuthProvider):
|
||||||
|
"""Google OAuth provider."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_name(self) -> str:
|
||||||
|
"""Return the provider name."""
|
||||||
|
return "google"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def authorization_url(self) -> str:
|
||||||
|
"""Return the authorization URL."""
|
||||||
|
return "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token_url(self) -> str:
|
||||||
|
"""Return the token URL."""
|
||||||
|
return "https://oauth2.googleapis.com/token"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_info_url(self) -> str:
|
||||||
|
"""Return the user info URL."""
|
||||||
|
return "https://www.googleapis.com/oauth2/v2/userinfo"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scope(self) -> str:
|
||||||
|
"""Return the required scopes."""
|
||||||
|
return "openid email profile"
|
||||||
|
|
||||||
|
async def exchange_code_for_token(self, code: str) -> str:
|
||||||
|
"""Exchange authorization code for access token."""
|
||||||
|
# Construct provider-specific redirect URI (must match authorization request)
|
||||||
|
redirect_uri = (
|
||||||
|
f"http://localhost:8000/api/v1/oauth/{self.provider_name}/callback"
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"code": code,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.token_url,
|
||||||
|
data=data,
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != status.HTTP_200_OK:
|
||||||
|
logger.error("Failed to exchange code for token: %s", response.text)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Failed to exchange authorization code",
|
||||||
|
)
|
||||||
|
|
||||||
|
token_data = response.json()
|
||||||
|
return str(token_data["access_token"])
|
||||||
|
|
||||||
|
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||||
|
"""Get user information from Google."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self.user_info_url,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != status.HTTP_200_OK:
|
||||||
|
logger.error("Failed to get user info: %s", response.text)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Failed to get user information",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_data = response.json()
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider=self.provider_name,
|
||||||
|
provider_user_id=user_data["id"],
|
||||||
|
email=user_data["email"],
|
||||||
|
name=user_data["name"],
|
||||||
|
picture=user_data.get("picture"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubOAuthProvider(OAuthProvider):
|
||||||
|
"""GitHub OAuth provider."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_name(self) -> str:
|
||||||
|
"""Return the provider name."""
|
||||||
|
return "github"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def authorization_url(self) -> str:
|
||||||
|
"""Return the authorization URL."""
|
||||||
|
return "https://github.com/login/oauth/authorize"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token_url(self) -> str:
|
||||||
|
"""Return the token URL."""
|
||||||
|
return "https://github.com/login/oauth/access_token"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_info_url(self) -> str:
|
||||||
|
"""Return the user info URL."""
|
||||||
|
return "https://api.github.com/user"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scope(self) -> str:
|
||||||
|
"""Return the required scopes."""
|
||||||
|
return "user:email"
|
||||||
|
|
||||||
|
async def get_user_info(self, access_token: str) -> OAuthUserInfo:
|
||||||
|
"""Get user information from GitHub."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
# Get user profile
|
||||||
|
user_response = await client.get(
|
||||||
|
self.user_info_url,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_response.status_code != status.HTTP_200_OK:
|
||||||
|
logger.error("Failed to get user info: %s", user_response.text)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Failed to get user information",
|
||||||
|
)
|
||||||
|
|
||||||
|
user_data = user_response.json()
|
||||||
|
|
||||||
|
# Get user email (GitHub doesn't include email in basic profile)
|
||||||
|
email_response = await client.get(
|
||||||
|
"https://api.github.com/user/emails",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if email_response.status_code != status.HTTP_200_OK:
|
||||||
|
logger.error("Failed to get user emails: %s", email_response.text)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Failed to get user email",
|
||||||
|
)
|
||||||
|
|
||||||
|
emails = email_response.json()
|
||||||
|
primary_email = next(
|
||||||
|
(email["email"] for email in emails if email["primary"]),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not primary_email:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="No primary email found",
|
||||||
|
)
|
||||||
|
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider=self.provider_name,
|
||||||
|
provider_user_id=str(user_data["id"]),
|
||||||
|
email=primary_email,
|
||||||
|
name=user_data.get("name") or user_data["login"],
|
||||||
|
picture=user_data.get("avatar_url"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthService:
|
||||||
|
"""Service for handling OAuth authentication."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize OAuth service with database session."""
|
||||||
|
self.session = session
|
||||||
|
self.providers = {
|
||||||
|
"google": GoogleOAuthProvider(
|
||||||
|
settings.GOOGLE_CLIENT_ID,
|
||||||
|
settings.GOOGLE_CLIENT_SECRET,
|
||||||
|
),
|
||||||
|
"github": GitHubOAuthProvider(
|
||||||
|
settings.GITHUB_CLIENT_ID,
|
||||||
|
settings.GITHUB_CLIENT_SECRET,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_provider(self, provider_name: str) -> OAuthProvider:
|
||||||
|
"""Get OAuth provider by name."""
|
||||||
|
if provider_name not in self.providers:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unsupported OAuth provider: {provider_name}",
|
||||||
|
)
|
||||||
|
return self.providers[provider_name]
|
||||||
|
|
||||||
|
def generate_state(self) -> str:
|
||||||
|
"""Generate a secure state parameter for OAuth flow."""
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
def get_authorization_url(self, provider_name: str, state: str) -> str:
|
||||||
|
"""Get authorization URL for the specified provider."""
|
||||||
|
provider = self.get_provider(provider_name)
|
||||||
|
return provider.get_authorization_url(state)
|
||||||
|
|
||||||
|
async def handle_callback(self, provider_name: str, code: str) -> OAuthUserInfo:
|
||||||
|
"""Handle OAuth callback and return user info."""
|
||||||
|
provider = self.get_provider(provider_name)
|
||||||
|
|
||||||
|
# Exchange code for access token
|
||||||
|
access_token = await provider.exchange_code_for_token(code)
|
||||||
|
|
||||||
|
# Get user information
|
||||||
|
return await provider.get_user_info(access_token)
|
||||||
@@ -9,6 +9,7 @@ dependencies = [
|
|||||||
"bcrypt==4.3.0",
|
"bcrypt==4.3.0",
|
||||||
"email-validator==2.2.0",
|
"email-validator==2.2.0",
|
||||||
"fastapi[standard]==0.116.1",
|
"fastapi[standard]==0.116.1",
|
||||||
|
"httpx==0.28.1",
|
||||||
"pydantic-settings==2.10.1",
|
"pydantic-settings==2.10.1",
|
||||||
"pyjwt==2.10.1",
|
"pyjwt==2.10.1",
|
||||||
"sqlmodel==0.0.24",
|
"sqlmodel==0.0.24",
|
||||||
@@ -36,7 +37,7 @@ exclude = ["alembic"]
|
|||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["ALL"]
|
select = ["ALL"]
|
||||||
ignore = ["D100", "D103"]
|
ignore = ["D100", "D103", "TRY301"]
|
||||||
|
|
||||||
[tool.ruff.per-file-ignores]
|
[tool.ruff.per-file-ignores]
|
||||||
"tests/**/*.py" = ["S101", "S105"]
|
"tests/**/*.py" = ["S101", "S105"]
|
||||||
|
|||||||
151
tests/api/v1/test_oauth_endpoints.py
Normal file
151
tests/api/v1/test_oauth_endpoints.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""Tests for OAuth authentication endpoints."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
from app.services.oauth import OAuthUserInfo
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthEndpoints:
|
||||||
|
"""Test OAuth API endpoints."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_oauth_providers(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test getting list of OAuth providers."""
|
||||||
|
response = await test_client.get("/api/v1/oauth/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "providers" in data
|
||||||
|
assert "google" in data["providers"]
|
||||||
|
assert "github" in data["providers"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_authorize_google(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test OAuth authorization URL generation for Google."""
|
||||||
|
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||||
|
mock_state.return_value = "test_state_123"
|
||||||
|
|
||||||
|
response = await test_client.get("/api/v1/oauth/google/authorize")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "authorization_url" in data
|
||||||
|
assert "state" in data
|
||||||
|
assert data["state"] == "test_state_123"
|
||||||
|
assert "accounts.google.com" in data["authorization_url"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_authorize_github(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test OAuth authorization URL generation for GitHub."""
|
||||||
|
with patch("app.services.oauth.OAuthService.generate_state") as mock_state:
|
||||||
|
mock_state.return_value = "test_state_456"
|
||||||
|
|
||||||
|
response = await test_client.get("/api/v1/oauth/github/authorize")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "authorization_url" in data
|
||||||
|
assert "state" in data
|
||||||
|
assert data["state"] == "test_state_456"
|
||||||
|
assert "github.com" in data["authorization_url"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_authorize_invalid_provider(
|
||||||
|
self, test_client: AsyncClient
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth authorization with invalid provider."""
|
||||||
|
response = await test_client.get("/api/v1/oauth/invalid/authorize")
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
data = response.json()
|
||||||
|
assert "Unsupported OAuth provider" in data["detail"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_new_user(
|
||||||
|
self, test_client: AsyncClient, ensure_plans: tuple[Any, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth callback for new user creation."""
|
||||||
|
# Mock OAuth user info
|
||||||
|
mock_user_info = OAuthUserInfo(
|
||||||
|
provider="google",
|
||||||
|
provider_user_id="google_123",
|
||||||
|
email="newuser@gmail.com",
|
||||||
|
name="New User",
|
||||||
|
picture="https://example.com/avatar.jpg",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||||
|
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||||
|
mock_callback.return_value = mock_user_info
|
||||||
|
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/oauth/google/callback",
|
||||||
|
params={"code": "auth_code_123", "state": "test_state"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# OAuth callback should successfully process and redirect to frontend
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||||
|
|
||||||
|
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||||
|
# Detailed cookie testing can be done in integration tests
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_existing_user_link(
|
||||||
|
self, test_client: AsyncClient, test_user: Any, ensure_plans: tuple[Any, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth callback for linking to existing user."""
|
||||||
|
# Mock OAuth user info with same email as test user
|
||||||
|
mock_user_info = OAuthUserInfo(
|
||||||
|
provider="github",
|
||||||
|
provider_user_id="github_456",
|
||||||
|
email=test_user.email, # Same email as existing user
|
||||||
|
name="Test User",
|
||||||
|
picture="https://github.com/avatar.jpg",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the entire handle_callback method to avoid actual OAuth API calls
|
||||||
|
with patch("app.services.oauth.OAuthService.handle_callback") as mock_callback:
|
||||||
|
mock_callback.return_value = mock_user_info
|
||||||
|
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/oauth/github/callback",
|
||||||
|
params={"code": "auth_code_456", "state": "test_state"},
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# OAuth callback should successfully process and redirect to frontend
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert response.headers["location"] == "http://localhost:8001/?auth=success"
|
||||||
|
|
||||||
|
# The fact that we get a 302 redirect means the OAuth login was successful
|
||||||
|
# Detailed cookie testing can be done in integration tests
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_missing_code(self, test_client: AsyncClient) -> None:
|
||||||
|
"""Test OAuth callback with missing authorization code."""
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/oauth/google/callback",
|
||||||
|
params={"state": "test_state"}, # Missing code parameter
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422 # Validation error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_invalid_provider(
|
||||||
|
self, test_client: AsyncClient
|
||||||
|
) -> None:
|
||||||
|
"""Test OAuth callback with invalid provider."""
|
||||||
|
response = await test_client.get(
|
||||||
|
"/api/v1/oauth/invalid/callback",
|
||||||
|
params={"code": "auth_code_123", "state": "test_state"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
data = response.json()
|
||||||
|
assert "Unsupported OAuth provider" in data["detail"]
|
||||||
192
tests/services/test_oauth_service.py
Normal file
192
tests/services/test_oauth_service.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""Tests for OAuth service."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
from app.services.oauth import (
|
||||||
|
GitHubOAuthProvider,
|
||||||
|
GoogleOAuthProvider,
|
||||||
|
OAuthService,
|
||||||
|
OAuthUserInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuthService:
|
||||||
|
"""Test OAuth service functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_service_initialization(self, test_session: Any) -> None:
|
||||||
|
"""Test OAuth service initialization."""
|
||||||
|
oauth_service = OAuthService(test_session)
|
||||||
|
|
||||||
|
assert "google" in oauth_service.providers
|
||||||
|
assert "github" in oauth_service.providers
|
||||||
|
assert isinstance(oauth_service.providers["google"], GoogleOAuthProvider)
|
||||||
|
assert isinstance(oauth_service.providers["github"], GitHubOAuthProvider)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_state(self, test_session: Any) -> None:
|
||||||
|
"""Test state generation."""
|
||||||
|
oauth_service = OAuthService(test_session)
|
||||||
|
state = oauth_service.generate_state()
|
||||||
|
|
||||||
|
assert isinstance(state, str)
|
||||||
|
assert len(state) > 10 # Should be a reasonable length
|
||||||
|
|
||||||
|
# Generate another to ensure they're different
|
||||||
|
state2 = oauth_service.generate_state()
|
||||||
|
assert state != state2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_provider_valid(self, test_session: Any) -> None:
|
||||||
|
"""Test getting valid OAuth provider."""
|
||||||
|
oauth_service = OAuthService(test_session)
|
||||||
|
|
||||||
|
google_provider = oauth_service.get_provider("google")
|
||||||
|
assert isinstance(google_provider, GoogleOAuthProvider)
|
||||||
|
assert google_provider.provider_name == "google"
|
||||||
|
|
||||||
|
github_provider = oauth_service.get_provider("github")
|
||||||
|
assert isinstance(github_provider, GitHubOAuthProvider)
|
||||||
|
assert github_provider.provider_name == "github"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_provider_invalid(self, test_session: Any) -> None:
|
||||||
|
"""Test getting invalid OAuth provider."""
|
||||||
|
oauth_service = OAuthService(test_session)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
oauth_service.get_provider("invalid")
|
||||||
|
|
||||||
|
assert "Unsupported OAuth provider" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_authorization_url(self, test_session: Any) -> None:
|
||||||
|
"""Test authorization URL generation."""
|
||||||
|
oauth_service = OAuthService(test_session)
|
||||||
|
state = "test_state_123"
|
||||||
|
|
||||||
|
# Test Google
|
||||||
|
google_url = oauth_service.get_authorization_url("google", state)
|
||||||
|
assert "accounts.google.com" in google_url
|
||||||
|
assert "client_id=" in google_url
|
||||||
|
assert f"state={state}" in google_url
|
||||||
|
|
||||||
|
# Test GitHub
|
||||||
|
github_url = oauth_service.get_authorization_url("github", state)
|
||||||
|
assert "github.com" in github_url
|
||||||
|
assert "client_id=" in github_url
|
||||||
|
assert f"state={state}" in github_url
|
||||||
|
|
||||||
|
|
||||||
|
class TestGoogleOAuthProvider:
|
||||||
|
"""Test Google OAuth provider."""
|
||||||
|
|
||||||
|
def test_provider_properties(self) -> None:
|
||||||
|
"""Test Google provider properties."""
|
||||||
|
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||||
|
|
||||||
|
assert provider.provider_name == "google"
|
||||||
|
assert "accounts.google.com" in provider.authorization_url
|
||||||
|
assert "oauth2.googleapis.com" in provider.token_url
|
||||||
|
assert "googleapis.com" in provider.user_info_url
|
||||||
|
assert "openid email profile" in provider.scope
|
||||||
|
|
||||||
|
def test_authorization_url_generation(self) -> None:
|
||||||
|
"""Test authorization URL generation."""
|
||||||
|
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||||
|
state = "test_state"
|
||||||
|
|
||||||
|
auth_url = provider.get_authorization_url(state)
|
||||||
|
|
||||||
|
assert "accounts.google.com" in auth_url
|
||||||
|
assert "client_id=test_client_id" in auth_url
|
||||||
|
assert f"state={state}" in auth_url
|
||||||
|
assert "scope=openid+email+profile" in auth_url
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_info_success(self) -> None:
|
||||||
|
"""Test successful user info retrieval."""
|
||||||
|
provider = GoogleOAuthProvider("test_client_id", "test_secret")
|
||||||
|
|
||||||
|
mock_response_data = {
|
||||||
|
"id": "google_user_123",
|
||||||
|
"email": "test@gmail.com",
|
||||||
|
"name": "Test User",
|
||||||
|
"picture": "https://example.com/avatar.jpg",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient.get") as mock_get:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = mock_response_data
|
||||||
|
mock_get.return_value = mock_response
|
||||||
|
|
||||||
|
user_info = await provider.get_user_info("test_access_token")
|
||||||
|
|
||||||
|
assert user_info.provider == "google"
|
||||||
|
assert user_info.provider_user_id == "google_user_123"
|
||||||
|
assert user_info.email == "test@gmail.com"
|
||||||
|
assert user_info.name == "Test User"
|
||||||
|
assert user_info.picture == "https://example.com/avatar.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitHubOAuthProvider:
|
||||||
|
"""Test GitHub OAuth provider."""
|
||||||
|
|
||||||
|
def test_provider_properties(self) -> None:
|
||||||
|
"""Test GitHub provider properties."""
|
||||||
|
provider = GitHubOAuthProvider("test_client_id", "test_secret")
|
||||||
|
|
||||||
|
assert provider.provider_name == "github"
|
||||||
|
assert "github.com" in provider.authorization_url
|
||||||
|
assert "github.com" in provider.token_url
|
||||||
|
assert "api.github.com" in provider.user_info_url
|
||||||
|
assert "user:email" in provider.scope
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_info_success(self) -> None:
|
||||||
|
"""Test successful user info retrieval."""
|
||||||
|
provider = GitHubOAuthProvider("test_client_id", "test_secret")
|
||||||
|
|
||||||
|
mock_user_data = {
|
||||||
|
"id": 123456,
|
||||||
|
"login": "testuser",
|
||||||
|
"name": "Test User",
|
||||||
|
"avatar_url": "https://github.com/avatar.jpg",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_emails_data = [
|
||||||
|
{"email": "test@example.com", "primary": True, "verified": True},
|
||||||
|
{"email": "secondary@example.com", "primary": False, "verified": True},
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient.get") as mock_get:
|
||||||
|
# Mock user profile response
|
||||||
|
mock_user_response = Mock()
|
||||||
|
mock_user_response.status_code = 200
|
||||||
|
mock_user_response.json.return_value = mock_user_data
|
||||||
|
|
||||||
|
# Mock emails response
|
||||||
|
mock_emails_response = Mock()
|
||||||
|
mock_emails_response.status_code = 200
|
||||||
|
mock_emails_response.json.return_value = mock_emails_data
|
||||||
|
|
||||||
|
# Return different responses based on URL
|
||||||
|
def side_effect(url, **kwargs):
|
||||||
|
if "user/emails" in str(url):
|
||||||
|
return mock_emails_response
|
||||||
|
return mock_user_response
|
||||||
|
|
||||||
|
mock_get.side_effect = side_effect
|
||||||
|
|
||||||
|
user_info = await provider.get_user_info("test_access_token")
|
||||||
|
|
||||||
|
assert user_info.provider == "github"
|
||||||
|
assert user_info.provider_user_id == "123456"
|
||||||
|
assert user_info.email == "test@example.com"
|
||||||
|
assert user_info.name == "Test User"
|
||||||
|
assert user_info.picture == "https://github.com/avatar.jpg"
|
||||||
2
uv.lock
generated
2
uv.lock
generated
@@ -46,6 +46,7 @@ dependencies = [
|
|||||||
{ name = "bcrypt" },
|
{ name = "bcrypt" },
|
||||||
{ name = "email-validator" },
|
{ name = "email-validator" },
|
||||||
{ name = "fastapi", extra = ["standard"] },
|
{ name = "fastapi", extra = ["standard"] },
|
||||||
|
{ name = "httpx" },
|
||||||
{ name = "pydantic-settings" },
|
{ name = "pydantic-settings" },
|
||||||
{ name = "pyjwt" },
|
{ name = "pyjwt" },
|
||||||
{ name = "sqlmodel" },
|
{ name = "sqlmodel" },
|
||||||
@@ -69,6 +70,7 @@ requires-dist = [
|
|||||||
{ name = "bcrypt", specifier = "==4.3.0" },
|
{ name = "bcrypt", specifier = "==4.3.0" },
|
||||||
{ name = "email-validator", specifier = "==2.2.0" },
|
{ name = "email-validator", specifier = "==2.2.0" },
|
||||||
{ name = "fastapi", extras = ["standard"], specifier = "==0.116.1" },
|
{ name = "fastapi", extras = ["standard"], specifier = "==0.116.1" },
|
||||||
|
{ name = "httpx", specifier = "==0.28.1" },
|
||||||
{ name = "pydantic-settings", specifier = "==2.10.1" },
|
{ name = "pydantic-settings", specifier = "==2.10.1" },
|
||||||
{ name = "pyjwt", specifier = "==2.10.1" },
|
{ name = "pyjwt", specifier = "==2.10.1" },
|
||||||
{ name = "sqlmodel", specifier = "==0.0.24" },
|
{ name = "sqlmodel", specifier = "==0.0.24" },
|
||||||
|
|||||||
Reference in New Issue
Block a user