533 lines
18 KiB
Python
533 lines
18 KiB
Python
"""Authentication endpoints."""
|
|
|
|
import secrets
|
|
import time
|
|
from typing import Annotated, Any
|
|
|
|
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response, status
|
|
from fastapi.responses import RedirectResponse
|
|
|
|
from app.core.config import settings
|
|
from app.core.dependencies import (
|
|
get_auth_service,
|
|
get_current_active_user,
|
|
get_current_active_user_flexible,
|
|
get_oauth_service,
|
|
)
|
|
from app.core.logging import get_logger
|
|
from app.models.user import User
|
|
from app.schemas.auth import (
|
|
ApiTokenRequest,
|
|
ApiTokenResponse,
|
|
ApiTokenStatusResponse,
|
|
ChangePasswordRequest,
|
|
UpdateProfileRequest,
|
|
UserLoginRequest,
|
|
UserRegisterRequest,
|
|
UserResponse,
|
|
)
|
|
from app.services.auth import AuthService
|
|
from app.services.oauth import OAuthService
|
|
from app.utils.auth import JWTUtils, TokenUtils
|
|
from app.utils.cookies import set_access_token_cookie, set_auth_cookies
|
|
|
|
router = APIRouter(tags=["authentication"], prefix="/auth")
logger = get_logger(__name__)

# One-time codes minted by the OAuth callback and redeemed by the
# /exchange-oauth-token endpoint; each value holds the user's id, tokens and a
# creation timestamp. This is a plain module-level dict, so it lives only in
# this process's memory (lost on restart, not shared between worker
# processes). A production deployment should use a shared store with a TTL,
# e.g. Redis.
_temp_oauth_codes: dict[str, dict[str, Any]] = dict()
|
|
|
|
|
|
# Authentication endpoints
|
|
@router.post(
    "/register",
    status_code=status.HTTP_201_CREATED,
)
async def register(
    request: UserRegisterRequest,
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
    """Create a new account and sign the user in immediately.

    The access and refresh tokens are delivered as HTTP-only cookies through
    ``set_auth_cookies``; the response body is the new user's profile.

    Raises:
        HTTPException: Propagated unchanged from ``AuthService``; any other
            failure is logged and reported as a 500 "Registration failed".
    """
    try:
        outcome = await auth_service.register(request)

        # The refresh-token helper wants the User object, not the response
        # schema, so fetch it by id first.
        db_user = await auth_service.get_current_user(outcome.user.id)
        new_refresh_token = await auth_service.create_and_store_refresh_token(db_user)

        set_auth_cookies(
            response=response,
            access_token=outcome.token.access_token,
            refresh_token=new_refresh_token,
            expires_in=outcome.token.expires_in,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Registration failed for email: %s", request.email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Registration failed",
        ) from exc

    return outcome.user
|
|
|
|
|
|
@router.post("/login")
async def login(
    request: UserLoginRequest,
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
    """Authenticate with credentials and start a cookie-based session.

    On success, the access and refresh tokens are written as HTTP-only
    cookies through ``set_auth_cookies``. The authenticated user's profile is
    returned as the body.

    Raises:
        HTTPException: Propagated unchanged from ``AuthService`` (e.g. bad
            credentials); any other failure is logged and reported as a 500
            "Login failed".
    """
    try:
        outcome = await auth_service.login(request)

        # A User object (not the response schema) is needed to mint and
        # persist the refresh token.
        db_user = await auth_service.get_current_user(outcome.user.id)
        new_refresh_token = await auth_service.create_and_store_refresh_token(db_user)

        set_auth_cookies(
            response=response,
            access_token=outcome.token.access_token,
            refresh_token=new_refresh_token,
            expires_in=outcome.token.expires_in,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Login failed for email: %s", request.email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Login failed",
        ) from exc

    return outcome.user
|
|
|
|
|
|
@router.get("/me")
async def get_current_user_info(
    current_user: Annotated[User, Depends(get_current_active_user_flexible)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
    """Return the profile of the currently authenticated user.

    The user is resolved by ``get_current_active_user_flexible`` and converted
    to the public response schema by ``AuthService.create_user_response``.

    Raises:
        HTTPException: Propagated unchanged from the service, so intentional
            4xx responses are not masked. Any other failure is logged and
            reported as a 500 "Failed to retrieve user information".
    """
    try:
        return await auth_service.create_user_response(current_user)
    except HTTPException:
        # HTTPException is an Exception subclass; without this clause the
        # generic handler below would turn it into a 500.
        raise
    except Exception as e:
        logger.exception("Failed to get current user info")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve user information",
        ) from e
|
|
|
|
|
|
@router.post("/refresh")
async def refresh_token(
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
    """Issue a new access-token cookie using the ``refresh_token`` cookie.

    Only the access-token cookie is written by this endpoint.

    Raises:
        HTTPException: 401 if the ``refresh_token`` cookie is missing or empty.
            HTTPExceptions raised by ``AuthService`` pass through unchanged.
            Any other failure is logged and reported as a 500
            "Token refresh failed".
    """
    # The parameter intentionally shares the function's name: FastAPI uses
    # the parameter name as the cookie key.
    if not refresh_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="No refresh token provided",
        )

    try:
        renewed = await auth_service.refresh_access_token(refresh_token)
        set_access_token_cookie(
            response=response,
            access_token=renewed.access_token,
            expires_in=renewed.expires_in,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Token refresh failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Token refresh failed",
        ) from exc

    return {"message": "Token refreshed successfully"}
|
|
|
|
|
|
async def _find_user_for_logout(
    auth_service: AuthService,
    token: str,
    token_kind: str,
) -> User | None:
    """Best-effort lookup of the user that a logout cookie token belongs to.

    Args:
        auth_service: Service used to load the user by id.
        token: Raw JWT taken from the request cookie.
        token_kind: ``"refresh"`` selects ``JWTUtils.decode_refresh_token``;
            anything else (``"access"``) selects ``decode_access_token``. Also
            used in log messages.

    Returns:
        The user named by the token's ``sub`` claim, or ``None`` if the claim
        is missing or decoding/lookup fails. Failures are deliberately
        swallowed and logged at INFO, because logout must still clear cookies
        when the tokens are expired or malformed.
    """
    decode = (
        JWTUtils.decode_refresh_token
        if token_kind == "refresh"
        else JWTUtils.decode_access_token
    )
    user = None
    try:
        payload = decode(token)
        user_id_str = payload.get("sub")
        if user_id_str:
            user = await auth_service.get_current_user(int(user_id_str))
            logger.info("Found user from %s token: %s", token_kind, user.email)
    except Exception as e:  # also covers HTTPException, a subclass
        logger.info("%s token validation failed: %s", token_kind.capitalize(), e)
    return user


@router.post("/logout")
async def logout(
    response: Response,
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
    access_token: Annotated[str | None, Cookie()] = None,
    refresh_token: Annotated[str | None, Cookie()] = None,
) -> dict[str, str]:
    """Logout endpoint - clears cookies and revokes refresh token.

    The user is identified from the access-token cookie first. If that is
    absent or unusable, the refresh-token cookie is tried instead. If a user
    is found, their stored refresh token is revoked. Both auth cookies are
    then deleted.

    NOTE(review): ``revoke_refresh_token`` is not guarded, so if it raises,
    the request fails before the cookies are deleted.
    """
    user = None
    if access_token:
        user = await _find_user_for_logout(auth_service, access_token, "access")
    if user is None and refresh_token:
        user = await _find_user_for_logout(auth_service, refresh_token, "refresh")

    if user is not None:
        await auth_service.revoke_refresh_token(user)
        logger.info("Successfully revoked refresh token for user: %s", user.email)
    else:
        logger.info("No user found, skipping token revocation")

    # Always clear both cookies regardless of token validity. The attributes
    # (notably domain) match those used when the cookies were set.
    for cookie_name in ("access_token", "refresh_token"):
        response.delete_cookie(
            key=cookie_name,
            httponly=True,
            secure=settings.COOKIE_SECURE,
            samesite=settings.COOKIE_SAMESITE,
            domain=settings.COOKIE_DOMAIN,
        )

    return {"message": "Successfully logged out"}
|
|
|
|
|
|
# OAuth2 endpoints
|
|
@router.get("/{provider}/authorize")
async def oauth_authorize(
    provider: str,
    oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
) -> dict[str, str]:
    """Build the provider's authorization URL together with a fresh state value.

    Returns:
        ``{"authorization_url": ..., "state": ...}``.

    Raises:
        HTTPException: Propagated unchanged from ``OAuthService`` (e.g. an
            unknown provider); any other failure is logged and reported as a
            500 "OAuth authorization failed".

    NOTE(review): the generated ``state`` is only returned to the client; it
    is not persisted here, and the callback endpoint does not accept it, so
    it is never verified server-side.
    """
    try:
        oauth_state = oauth_service.generate_state()
        authorization_url = oauth_service.get_authorization_url(provider, oauth_state)
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("OAuth authorization failed for provider: %s", provider)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="OAuth authorization failed",
        ) from exc

    return {"authorization_url": authorization_url, "state": oauth_state}
|
|
|
|
|
|
@router.get("/{provider}/callback")
async def oauth_callback(
    provider: str,
    response: Response,
    code: Annotated[str, Query()],
    oauth_service: Annotated[OAuthService, Depends(get_oauth_service)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> RedirectResponse:
    """Handle the OAuth provider callback and hand the session to the frontend.

    Steps:
        1. Exchange the provider ``code`` for user info.
        2. Log the user in, linking or creating the account.
        3. Issue and store a refresh token.
        4. Park the resulting tokens in ``_temp_oauth_codes`` under a random
           one-time code.
        5. Redirect (302) to ``{FRONTEND_URL}/auth/callback?code=<one-time code>``.

    The frontend then redeems the one-time code at ``/exchange-oauth-token``,
    which sets the auth cookies.

    Args:
        provider: OAuth provider name from the path.
        response: Injected response object. It is unused: this endpoint returns
            a ``RedirectResponse`` directly, and FastAPI does not copy
            headers/cookies from the injected response onto a returned
            Response. Kept so the signature stays unchanged.
        code: Authorization code sent back by the provider.
        oauth_service: Service that talks to the OAuth provider.
        auth_service: Service that logs the user in and manages tokens.

    Raises:
        HTTPException: Propagated unchanged from the services; any other
            failure is logged and reported as a 500 "OAuth callback failed".

    NOTE(review): no ``state`` query parameter is accepted here, so the state
    issued by ``oauth_authorize`` is never checked (no CSRF protection for
    this flow at this layer).
    """
    # Same window that exchange_oauth_token enforces (300 s there); codes
    # older than this can never be redeemed, so they are safe to purge.
    code_expiry_seconds = 300

    try:
        logger.info("OAuth callback started for provider: %s", provider)

        oauth_user_info = await oauth_service.handle_callback(provider, code)
        auth_response = await auth_service.oauth_login(oauth_user_info)

        user = await auth_service.get_current_user(auth_response.user.id)
        refresh_token = await auth_service.create_and_store_refresh_token(user)

        logger.info(
            "OAuth login successful for user: %s via %s",
            auth_response.user.email,
            provider,
        )

        # Purge codes that were never redeemed; otherwise the in-memory
        # store grows without bound.
        now = time.time()
        stale_codes = [
            key
            for key, entry in _temp_oauth_codes.items()
            if now - entry["created_at"] > code_expiry_seconds
        ]
        for key in stale_codes:
            _temp_oauth_codes.pop(key, None)

        # Cookies set on this API response would not reach the frontend in
        # the cross-port setup, so hand over a one-time code instead. The
        # frontend redeems it via /exchange-oauth-token.
        temp_code = secrets.token_urlsafe(32)
        _temp_oauth_codes[temp_code] = {
            "user_id": auth_response.user.id,
            "access_token": auth_response.token.access_token,
            "refresh_token": refresh_token,
            "expires_in": auth_response.token.expires_in,
            "created_at": now,
        }

        frontend_callback = f"{settings.FRONTEND_URL}/auth/callback"
        # Log only the destination: the one-time code is a bearer credential
        # that can be redeemed for the user's tokens.
        logger.info("Redirecting to: %s", frontend_callback)

        return RedirectResponse(
            url=f"{frontend_callback}?code={temp_code}",
            status_code=302,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("OAuth callback failed for provider: %s", provider)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="OAuth callback failed",
        ) from e
|
|
|
|
|
|
@router.get("/providers")
async def get_oauth_providers() -> dict[str, list[str]]:
    """List the OAuth provider names this API advertises to clients.

    The list is static and is not derived from configuration.
    """
    supported = ["google", "github"]
    return {"providers": supported}
|
|
|
|
|
|
@router.post("/exchange-oauth-token")
async def exchange_oauth_token(
    request: dict[str, str],
    response: Response,
) -> dict[str, str]:
    """Exchange temporary OAuth code for proper auth cookies.

    The code comes from the OAuth callback redirect. It is single use: it is
    removed from the store on the first lookup, whether or not it has expired.

    Args:
        request: JSON body; the ``"code"`` key holds the one-time code.
        response: Response on which the auth cookies are set.

    Returns:
        A success message and the user's id (as a string).

    Raises:
        HTTPException: 400 if the code is missing, unknown, already used, or
            older than 300 seconds.
    """
    code = request.get("code")
    if not code:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Code parameter is required",
        )

    # Only a short prefix is ever logged; the full code is a bearer credential.
    code_hint = code[:10] + "..."
    logger.info("OAuth token exchange requested with code: %s", code_hint)

    # A single pop both looks up and consumes the code, avoiding a separate
    # membership check followed by a second lookup.
    token_data = _temp_oauth_codes.pop(code, None)
    if token_data is None:
        logger.error("Invalid or expired OAuth code: %s", code_hint)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid or expired OAuth code",
        )

    # Check if code is too old (5 minutes max)
    code_expiry_seconds = 300
    if time.time() - token_data["created_at"] > code_expiry_seconds:
        logger.error("OAuth code expired: %s", code_hint)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="OAuth code expired",
        )

    set_auth_cookies(
        response=response,
        access_token=token_data["access_token"],
        refresh_token=token_data["refresh_token"],
        expires_in=token_data["expires_in"],
        path="/",
    )

    user_id = token_data["user_id"]
    logger.info("OAuth tokens exchanged successfully for user: %s", user_id)
    return {"message": "Tokens set successfully", "user_id": str(user_id)}
|
|
|
|
|
|
# API Token endpoints
|
|
@router.post("/api-token")
async def generate_api_token(
    request: ApiTokenRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> ApiTokenResponse:
    """Generate a new API token for the current user.

    Args:
        request: Carries ``expires_days``, which is forwarded to the service.
        current_user: Authenticated, active user.
        auth_service: Service that creates and stores the token.

    Returns:
        The token value and its expiry. The expiry is read from the user after
        a session refresh.

    Raises:
        HTTPException: Propagated unchanged from the service; any other
            failure is logged and reported as a 500
            "Failed to generate API token".
    """
    # Read the email up front so the failure log does not have to touch the
    # ORM object after a failed session operation (same approach as
    # change_password).
    user_email = current_user.email
    try:
        api_token = await auth_service.generate_api_token(
            current_user,
            expires_days=request.expires_days,
        )

        # Refresh user to get updated token info
        await auth_service.session.refresh(current_user)

        return ApiTokenResponse(
            api_token=api_token,
            expires_at=current_user.api_token_expires_at,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to generate API token for user: %s", user_email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate API token",
        ) from e
|
|
|
|
|
|
@router.get("/api-token/status")
async def get_api_token_status(
    current_user: Annotated[User, Depends(get_current_active_user)],
) -> ApiTokenStatusResponse:
    """Report whether the current user has an API token and whether it expired.

    A token with no expiry timestamp is reported as not expired.
    """
    expires_at = current_user.api_token_expires_at
    has_token = current_user.api_token is not None

    is_expired = (
        TokenUtils.is_token_expired(expires_at)
        if has_token and expires_at
        else False
    )

    return ApiTokenStatusResponse(
        has_token=has_token,
        expires_at=expires_at,
        is_expired=is_expired,
    )
|
|
|
|
|
|
@router.delete("/api-token")
async def revoke_api_token(
    current_user: Annotated[User, Depends(get_current_active_user)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> dict[str, str]:
    """Revoke the current user's API token.

    Raises:
        HTTPException: Propagated unchanged from the service (so intentional
            4xx responses are not masked); any other failure is logged and
            reported as a 500 "Failed to revoke API token".
    """
    # Captured before the service call so the failure log does not read the
    # ORM object after a failed session operation.
    user_email = current_user.email
    try:
        await auth_service.revoke_api_token(current_user)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to revoke API token for user: %s", user_email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to revoke API token",
        ) from e
    else:
        return {"message": "API token revoked successfully"}
|
|
|
|
|
|
# Profile management endpoints
|
|
@router.patch("/me")
async def update_profile(
    request: UpdateProfileRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> UserResponse:
    """Update the current user's profile.

    Only fields explicitly present in the request body are forwarded
    (``model_dump(exclude_unset=True)``), so omitted fields are left untouched.

    Raises:
        HTTPException: Propagated unchanged from the service; any other
            failure is logged and reported as a 500 "Failed to update profile".
    """
    # Captured before the update so the failure log does not read the ORM
    # object after a failed session operation.
    user_email = current_user.email
    try:
        updated_user = await auth_service.update_user_profile(
            current_user, request.model_dump(exclude_unset=True)
        )
        return await auth_service.user_to_response(updated_user)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to update profile for user: %s", user_email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update profile",
        ) from e
|
|
|
|
|
|
@router.post("/change-password")
async def change_password(
    request: ChangePasswordRequest,
    current_user: Annotated[User, Depends(get_current_active_user)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> dict[str, str]:
    """Change the current user's password.

    Raises:
        HTTPException: 400 with the error's message when the service raises
            ``ValueError``. The message is sent to the client verbatim, so it
            is assumed to be safe for display. HTTPExceptions from the service
            pass through unchanged. Anything else is logged and reported as a
            500 "Failed to change password".
    """
    # Store user email before operations to avoid session detachment issues
    user_email = current_user.email
    try:
        await auth_service.change_user_password(
            current_user, request.current_password, request.new_password
        )
    except HTTPException:
        # Must come before the generic handler, which would otherwise turn it
        # into a 500.
        raise
    except ValueError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        ) from e
    except Exception as e:
        logger.exception("Failed to change password for user: %s", user_email)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to change password",
        ) from e

    return {"message": "Password changed successfully"}
|
|
|
|
|
|
@router.get("/user-providers")
async def get_user_providers(
    current_user: Annotated[User, Depends(get_current_active_user)],
    auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> list[dict[str, str]]:
    """List the sign-in methods connected to the current user's account.

    A ``"password"`` entry comes first when the user has a password hash; its
    ``connected_at`` is the account creation time. One entry per linked OAuth
    provider follows, with ``connected_at`` taken from the link's creation
    time. Timestamps are ISO-8601 strings.
    """
    # Brand spellings that str.title() would get wrong (e.g. "Github").
    brand_names = {"github": "GitHub", "google": "Google"}

    connected: list[dict[str, str]] = []

    if current_user.password_hash:
        connected.append({
            "provider": "password",
            "display_name": "Password",
            "connected_at": current_user.created_at.isoformat(),
        })

    for link in await auth_service.get_user_oauth_providers(current_user):
        connected.append({
            "provider": link.provider,
            "display_name": brand_names.get(link.provider, link.provider.title()),
            "connected_at": link.created_at.isoformat(),
        })

    return connected
|