auth email/password
This commit is contained in:
@@ -1,19 +1,46 @@
|
||||
import os
|
||||
from datetime import timedelta
|
||||
|
||||
from flask import Flask
|
||||
from flask_jwt_extended import JWTManager
|
||||
|
||||
from app.services.auth_service import AuthService
|
||||
from app.database import init_db
|
||||
|
||||
# Global auth service instance
|
||||
auth_service = AuthService()
|
||||
|
||||
|
||||
def create_app():
|
||||
"""Create and configure the Flask application."""
|
||||
app = Flask(__name__)
|
||||
|
||||
# Configure session
|
||||
# Configure Flask secret key (required for sessions used by OAuth)
|
||||
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "dev-secret-key")
|
||||
|
||||
# Initialize authentication service
|
||||
auth_service = AuthService(app)
|
||||
# Configure SQLAlchemy database
|
||||
database_url = os.environ.get("DATABASE_URL", "sqlite:///soundboard.db")
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = database_url
|
||||
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
||||
|
||||
# Configure Flask-JWT-Extended
|
||||
app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY", "jwt-secret-key")
|
||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
|
||||
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
|
||||
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
|
||||
app.config["JWT_COOKIE_SECURE"] = False # Set to True in production
|
||||
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
|
||||
app.config["JWT_ACCESS_COOKIE_PATH"] = "/api/"
|
||||
app.config["JWT_REFRESH_COOKIE_PATH"] = "/api/auth/refresh"
|
||||
|
||||
# Initialize JWT manager
|
||||
jwt = JWTManager(app)
|
||||
|
||||
# Initialize database
|
||||
init_db(app)
|
||||
|
||||
# Initialize authentication service with app
|
||||
auth_service.init_app(app)
|
||||
|
||||
# Register blueprints
|
||||
from app.routes import main, auth
|
||||
|
||||
18
app/database.py
Normal file
18
app/database.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Database configuration and initialization."""
|
||||
|
||||
from flask_sqlalchemy import SQLAlchemy
|
||||
from flask_migrate import Migrate
|
||||
|
||||
db = SQLAlchemy()
|
||||
migrate = Migrate()
|
||||
|
||||
|
||||
def init_db(app):
|
||||
"""Initialize database with Flask app."""
|
||||
db.init_app(app)
|
||||
migrate.init_app(app, db)
|
||||
|
||||
# Import models here to ensure they are registered with SQLAlchemy
|
||||
from app.models import user, user_oauth # noqa: F401
|
||||
|
||||
return db
|
||||
6
app/models/__init__.py
Normal file
6
app/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Database models."""
|
||||
|
||||
from .user import User
|
||||
from .user_oauth import UserOAuth
|
||||
|
||||
__all__ = ["User", "UserOAuth"]
|
||||
240
app/models/user.py
Normal file
240
app/models/user.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""User model for authentication."""
|
||||
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
from sqlalchemy import String, DateTime
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user_oauth import UserOAuth
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
"""User model for storing user information."""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# Primary user information (can be updated from any connected provider)
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
|
||||
# Password authentication (optional - users can use OAuth instead)
|
||||
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
# Role-based access control
|
||||
role: Mapped[str] = mapped_column(String(50), nullable=False, default="user")
|
||||
|
||||
# User status
|
||||
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
||||
|
||||
# API token for programmatic access
|
||||
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
||||
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of User."""
|
||||
provider_count = len(self.oauth_providers)
|
||||
return f"<User {self.email} ({provider_count} providers)>"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert user to dictionary."""
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"email": self.email,
|
||||
"name": self.name,
|
||||
"picture": self.picture,
|
||||
"role": self.role,
|
||||
"is_active": self.is_active,
|
||||
"api_token": self.api_token,
|
||||
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
||||
"providers": [provider.provider for provider in self.oauth_providers],
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
def get_provider(self, provider_name: str) -> Optional["UserOAuth"]:
|
||||
"""Get specific OAuth provider for this user."""
|
||||
for provider in self.oauth_providers:
|
||||
if provider.provider == provider_name:
|
||||
return provider
|
||||
return None
|
||||
|
||||
def has_provider(self, provider_name: str) -> bool:
|
||||
"""Check if user has specific OAuth provider connected."""
|
||||
return self.get_provider(provider_name) is not None
|
||||
|
||||
def update_from_provider(self, provider_data: dict) -> None:
|
||||
"""Update user info from provider data (email, name, picture)."""
|
||||
self.email = provider_data.get("email", self.email)
|
||||
self.name = provider_data.get("name", self.name)
|
||||
self.picture = provider_data.get("picture", self.picture)
|
||||
self.updated_at = datetime.utcnow()
|
||||
db.session.commit()
|
||||
|
||||
def set_password(self, password: str) -> None:
|
||||
"""Hash and set user password."""
|
||||
self.password_hash = generate_password_hash(password)
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def check_password(self, password: str) -> bool:
|
||||
"""Check if provided password matches user's password."""
|
||||
if not self.password_hash:
|
||||
return False
|
||||
return check_password_hash(self.password_hash, password)
|
||||
|
||||
def has_password(self) -> bool:
|
||||
"""Check if user has a password set."""
|
||||
return self.password_hash is not None
|
||||
|
||||
def generate_api_token(self) -> str:
|
||||
"""Generate a new API token for the user."""
|
||||
self.api_token = secrets.token_urlsafe(32)
|
||||
self.api_token_expires_at = None # No expiration by default
|
||||
self.updated_at = datetime.utcnow()
|
||||
return self.api_token
|
||||
|
||||
def is_api_token_valid(self) -> bool:
|
||||
"""Check if the user's API token is valid (exists and not expired)."""
|
||||
if not self.api_token:
|
||||
return False
|
||||
|
||||
if self.api_token_expires_at is None:
|
||||
return True # No expiration
|
||||
|
||||
return datetime.utcnow() < self.api_token_expires_at
|
||||
|
||||
def revoke_api_token(self) -> None:
|
||||
"""Revoke the user's API token."""
|
||||
self.api_token = None
|
||||
self.api_token_expires_at = None
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def activate(self) -> None:
|
||||
"""Activate the user account."""
|
||||
self.is_active = True
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
def deactivate(self) -> None:
|
||||
"""Deactivate the user account."""
|
||||
self.is_active = False
|
||||
self.updated_at = datetime.utcnow()
|
||||
|
||||
@classmethod
|
||||
def find_by_email(cls, email: str) -> Optional["User"]:
|
||||
"""Find user by email address."""
|
||||
return cls.query.filter_by(email=email).first()
|
||||
|
||||
@classmethod
|
||||
def find_by_api_token(cls, api_token: str) -> Optional["User"]:
|
||||
"""Find user by API token if token is valid."""
|
||||
user = cls.query.filter_by(api_token=api_token).first()
|
||||
if user and user.is_api_token_valid():
|
||||
return user
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_or_create_from_oauth(
|
||||
cls, provider: str, provider_id: str, email: str, name: str, picture: Optional[str] = None
|
||||
) -> tuple["User", "UserOAuth"]:
|
||||
"""Find existing user or create new one from OAuth data."""
|
||||
from app.models.user_oauth import UserOAuth
|
||||
|
||||
# First, try to find existing OAuth provider
|
||||
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
|
||||
|
||||
if oauth_provider:
|
||||
# Update existing provider and user info
|
||||
user = oauth_provider.user
|
||||
oauth_provider.email = email
|
||||
oauth_provider.name = name
|
||||
oauth_provider.picture = picture
|
||||
oauth_provider.updated_at = datetime.utcnow()
|
||||
|
||||
# Update user info with latest data
|
||||
user.update_from_provider({"email": email, "name": name, "picture": picture})
|
||||
else:
|
||||
# Try to find user by email to link the new provider
|
||||
user = cls.find_by_email(email)
|
||||
|
||||
if not user:
|
||||
# Check if this is the first user (admin)
|
||||
user_count = cls.query.count()
|
||||
role = "admin" if user_count == 0 else "user"
|
||||
|
||||
# Create new user
|
||||
user = cls(
|
||||
email=email,
|
||||
name=name,
|
||||
picture=picture,
|
||||
role=role,
|
||||
)
|
||||
user.generate_api_token() # Generate API token on creation
|
||||
db.session.add(user)
|
||||
db.session.flush() # Flush to get user.id
|
||||
|
||||
# Create new OAuth provider
|
||||
oauth_provider = UserOAuth.create_or_update(
|
||||
user_id=user.id,
|
||||
provider=provider,
|
||||
provider_id=provider_id,
|
||||
email=email,
|
||||
name=name,
|
||||
picture=picture,
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
return user, oauth_provider
|
||||
|
||||
@classmethod
|
||||
def create_with_password(cls, email: str, password: str, name: str) -> "User":
|
||||
"""Create new user with email and password."""
|
||||
# Check if user already exists
|
||||
existing_user = cls.find_by_email(email)
|
||||
if existing_user:
|
||||
raise ValueError("User with this email already exists")
|
||||
|
||||
# Check if this is the first user (admin)
|
||||
user_count = cls.query.count()
|
||||
role = "admin" if user_count == 0 else "user"
|
||||
|
||||
# Create new user
|
||||
user = cls(
|
||||
email=email,
|
||||
name=name,
|
||||
role=role,
|
||||
)
|
||||
user.set_password(password)
|
||||
user.generate_api_token() # Generate API token on creation
|
||||
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def authenticate_with_password(cls, email: str, password: str) -> Optional["User"]:
|
||||
"""Authenticate user with email and password."""
|
||||
user = cls.find_by_email(email)
|
||||
if user and user.check_password(password) and user.is_active:
|
||||
return user
|
||||
return None
|
||||
105
app/models/user_oauth.py
Normal file
105
app/models/user_oauth.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""User OAuth model for storing user's connected providers."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import String, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class UserOAuth(db.Model):
|
||||
"""Model for storing user's connected OAuth providers."""
|
||||
|
||||
__tablename__ = "user_oauth"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# User relationship
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# OAuth provider information
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
provider_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
# Provider-specific user information
|
||||
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
# Unique constraint on provider + provider_id combination
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("provider", "provider_id", name="unique_provider_user"),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship("User", back_populates="oauth_providers")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of UserOAuth."""
|
||||
return f"<UserOAuth {self.email} ({self.provider})>"
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert oauth provider to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"provider": self.provider,
|
||||
"provider_id": self.provider_id,
|
||||
"email": self.email,
|
||||
"name": self.name,
|
||||
"picture": self.picture,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def find_by_provider_and_id(cls, provider: str, provider_id: str) -> Optional["UserOAuth"]:
|
||||
"""Find OAuth provider by provider name and provider ID."""
|
||||
return cls.query.filter_by(provider=provider, provider_id=provider_id).first()
|
||||
|
||||
@classmethod
|
||||
def create_or_update(
|
||||
cls,
|
||||
user_id: int,
|
||||
provider: str,
|
||||
provider_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: Optional[str] = None
|
||||
) -> "UserOAuth":
|
||||
"""Create new OAuth provider or update existing one."""
|
||||
oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
|
||||
|
||||
if oauth_provider:
|
||||
# Update existing provider
|
||||
oauth_provider.user_id = user_id
|
||||
oauth_provider.email = email
|
||||
oauth_provider.name = name
|
||||
oauth_provider.picture = picture
|
||||
oauth_provider.updated_at = datetime.utcnow()
|
||||
else:
|
||||
# Create new provider
|
||||
oauth_provider = cls(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
provider_id=provider_id,
|
||||
email=email,
|
||||
name=name,
|
||||
picture=picture,
|
||||
)
|
||||
db.session.add(oauth_provider)
|
||||
|
||||
db.session.commit()
|
||||
return oauth_provider
|
||||
@@ -1,31 +1,54 @@
|
||||
"""Authentication routes."""
|
||||
|
||||
from flask import Blueprint, url_for
|
||||
from flask import Blueprint, jsonify, url_for
|
||||
from flask_jwt_extended import create_access_token, get_jwt_identity, jwt_required
|
||||
|
||||
from app.services.auth_service import AuthService
|
||||
from app import auth_service
|
||||
from app.services.decorators import get_current_user
|
||||
|
||||
bp = Blueprint("auth", __name__)
|
||||
auth_service = AuthService()
|
||||
|
||||
|
||||
@bp.route("/login")
|
||||
def login() -> dict[str, str]:
|
||||
"""Initiate Google OAuth login."""
|
||||
redirect_uri = url_for("auth.callback", _external=True)
|
||||
login_url = auth_service.get_login_url(redirect_uri)
|
||||
return {"login_url": login_url}
|
||||
@bp.route("/login/<provider>")
|
||||
def login_oauth(provider):
|
||||
"""Initiate OAuth login for specified provider."""
|
||||
redirect_uri = url_for("auth.callback", provider=provider, _external=True)
|
||||
return auth_service.redirect_to_login(provider, redirect_uri)
|
||||
|
||||
|
||||
@bp.route("/callback")
|
||||
def callback():
|
||||
"""Handle OAuth callback from Google."""
|
||||
@bp.route("/callback/<provider>")
|
||||
def callback(provider):
|
||||
"""Handle OAuth callback from specified provider."""
|
||||
try:
|
||||
user_data, response = auth_service.handle_callback()
|
||||
return response
|
||||
return auth_service.handle_callback(provider)
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
@bp.route("/providers")
|
||||
def providers():
|
||||
"""Get list of available OAuth providers."""
|
||||
return {"providers": auth_service.get_available_providers()}
|
||||
|
||||
|
||||
@bp.route("/login", methods=["POST"])
|
||||
def login():
|
||||
"""Login user with email and password."""
|
||||
from flask import request
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return {"error": "No data provided"}, 400
|
||||
|
||||
email = data.get("email")
|
||||
password = data.get("password")
|
||||
|
||||
if not email or not password:
|
||||
return {"error": "Email and password are required"}, 400
|
||||
|
||||
return auth_service.login_with_password(email, password)
|
||||
|
||||
|
||||
@bp.route("/logout")
|
||||
def logout():
|
||||
"""Logout current user."""
|
||||
@@ -33,20 +56,182 @@ def logout():
|
||||
|
||||
|
||||
@bp.route("/me")
|
||||
def me() -> dict[str, str] | tuple[dict[str, str], int]:
|
||||
@jwt_required()
|
||||
def me():
|
||||
"""Get current user information."""
|
||||
user = auth_service.get_current_user()
|
||||
if not user:
|
||||
return {"error": "Not authenticated"}, 401
|
||||
|
||||
user = get_current_user()
|
||||
return {"user": user}
|
||||
|
||||
|
||||
@bp.route("/refresh")
|
||||
@bp.route("/refresh", methods=["POST"])
|
||||
@jwt_required(refresh=True)
|
||||
def refresh():
|
||||
"""Refresh access token using refresh token."""
|
||||
response = auth_service.refresh_tokens()
|
||||
if not response:
|
||||
return {"error": "Invalid or expired refresh token"}, 401
|
||||
current_user_id = get_jwt_identity()
|
||||
|
||||
# Create new access token
|
||||
new_access_token = create_access_token(identity=current_user_id)
|
||||
|
||||
response = jsonify({"message": "Token refreshed"})
|
||||
|
||||
# Set new access token cookie
|
||||
from flask_jwt_extended import set_access_cookies
|
||||
set_access_cookies(response, new_access_token)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@bp.route("/link/<provider>")
|
||||
@jwt_required()
|
||||
def link_provider(provider):
|
||||
"""Link a new OAuth provider to current user account."""
|
||||
redirect_uri = url_for("auth.link_callback", provider=provider, _external=True)
|
||||
return auth_service.redirect_to_login(provider, redirect_uri)
|
||||
|
||||
|
||||
@bp.route("/link/callback/<provider>")
|
||||
@jwt_required()
|
||||
def link_callback(provider):
|
||||
"""Handle OAuth callback for linking new provider."""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
if not current_user_id:
|
||||
return {"error": "User not authenticated"}, 401
|
||||
|
||||
# Get current user from database
|
||||
from app.models.user import User
|
||||
user = User.query.get(current_user_id)
|
||||
if not user:
|
||||
return {"error": "User not found"}, 404
|
||||
|
||||
# Process OAuth callback but link to existing user
|
||||
from app.services.oauth_providers.registry import OAuthProviderRegistry
|
||||
from authlib.integrations.flask_client import OAuth
|
||||
|
||||
oauth = OAuth()
|
||||
registry = OAuthProviderRegistry(oauth)
|
||||
oauth_provider = registry.get_provider(provider)
|
||||
|
||||
if not oauth_provider:
|
||||
return {"error": f"OAuth provider '{provider}' not configured"}, 400
|
||||
|
||||
token = oauth_provider.exchange_code_for_token(None, None)
|
||||
raw_user_info = oauth_provider.get_user_info(token)
|
||||
provider_data = oauth_provider.normalize_user_data(raw_user_info)
|
||||
|
||||
if not provider_data.get("id"):
|
||||
return {"error": "Failed to get user information from provider"}, 400
|
||||
|
||||
# Check if this provider is already linked to another user
|
||||
from app.models.user_oauth import UserOAuth
|
||||
existing_provider = UserOAuth.find_by_provider_and_id(
|
||||
provider, provider_data["id"]
|
||||
)
|
||||
|
||||
if existing_provider and existing_provider.user_id != user.id:
|
||||
return {"error": "This provider account is already linked to another user"}, 409
|
||||
|
||||
# Link the provider to current user
|
||||
UserOAuth.create_or_update(
|
||||
user_id=user.id,
|
||||
provider=provider,
|
||||
provider_id=provider_data["id"],
|
||||
email=provider_data["email"],
|
||||
name=provider_data["name"],
|
||||
picture=provider_data.get("picture")
|
||||
)
|
||||
|
||||
return {"message": f"{provider.title()} account linked successfully"}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
@bp.route("/unlink/<provider>", methods=["DELETE"])
|
||||
@jwt_required()
|
||||
def unlink_provider(provider):
|
||||
"""Unlink an OAuth provider from current user account."""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
if not current_user_id:
|
||||
return {"error": "User not authenticated"}, 401
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.user_oauth import UserOAuth
|
||||
from app.database import db
|
||||
|
||||
user = User.query.get(current_user_id)
|
||||
if not user:
|
||||
return {"error": "User not found"}, 404
|
||||
|
||||
# Check if user has more than one provider (prevent locking out)
|
||||
if len(user.oauth_providers) <= 1:
|
||||
return {"error": "Cannot unlink last authentication provider"}, 400
|
||||
|
||||
# Find and remove the provider
|
||||
oauth_provider = user.get_provider(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": f"Provider '{provider}' not linked to this account"}, 404
|
||||
|
||||
db.session.delete(oauth_provider)
|
||||
db.session.commit()
|
||||
|
||||
return {"message": f"{provider.title()} account unlinked successfully"}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
@bp.route("/register", methods=["POST"])
|
||||
def register():
|
||||
"""Register new user with email and password."""
|
||||
from flask import request
|
||||
|
||||
data = request.get_json()
|
||||
if not data:
|
||||
return {"error": "No data provided"}, 400
|
||||
|
||||
email = data.get("email")
|
||||
password = data.get("password")
|
||||
name = data.get("name")
|
||||
|
||||
if not email or not password or not name:
|
||||
return {"error": "Email, password, and name are required"}, 400
|
||||
|
||||
# Basic email validation
|
||||
if "@" not in email or "." not in email:
|
||||
return {"error": "Invalid email format"}, 400
|
||||
|
||||
# Basic password validation
|
||||
if len(password) < 6:
|
||||
return {"error": "Password must be at least 6 characters long"}, 400
|
||||
|
||||
return auth_service.register_with_password(email, password, name)
|
||||
|
||||
|
||||
@bp.route("/regenerate-api-token", methods=["POST"])
|
||||
@jwt_required()
|
||||
def regenerate_api_token():
|
||||
"""Regenerate API token for current user."""
|
||||
current_user_id = get_jwt_identity()
|
||||
if not current_user_id:
|
||||
return {"error": "User not authenticated"}, 401
|
||||
|
||||
from app.models.user import User
|
||||
from app.database import db
|
||||
|
||||
user = User.query.get(current_user_id)
|
||||
if not user:
|
||||
return {"error": "User not found"}, 404
|
||||
|
||||
# Generate new API token
|
||||
new_token = user.generate_api_token()
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"message": "API token regenerated successfully",
|
||||
"api_token": new_token,
|
||||
"expires_at": user.api_token_expires_at.isoformat() if user.api_token_expires_at else None
|
||||
}
|
||||
|
||||
|
||||
return response
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from flask import Blueprint
|
||||
|
||||
from app.services.decorators import get_current_user, require_auth
|
||||
from app.services.decorators import get_current_user, require_auth, require_admin, require_auth_or_api_token, get_user_from_api_token
|
||||
from app.services.greeting_service import GreetingService
|
||||
|
||||
bp = Blueprint("main", __name__)
|
||||
@@ -24,7 +24,7 @@ def hello(name: str | None = None) -> dict[str, str]:
|
||||
@bp.route("/protected")
|
||||
@require_auth
|
||||
def protected() -> dict[str, str]:
|
||||
"""Protected endpoint that requires authentication."""
|
||||
"""Protected endpoint that requires JWT authentication."""
|
||||
user = get_current_user()
|
||||
return {
|
||||
"message": f"Hello {user['name']}, this is a protected endpoint!",
|
||||
@@ -32,6 +32,33 @@ def protected() -> dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
@bp.route("/api-protected")
|
||||
@require_auth_or_api_token
|
||||
def api_protected() -> dict[str, str]:
|
||||
"""Protected endpoint that accepts JWT or API token authentication."""
|
||||
# Try to get user from JWT first, then API token
|
||||
user = get_current_user()
|
||||
if not user:
|
||||
user = get_user_from_api_token()
|
||||
|
||||
return {
|
||||
"message": f"Hello {user['name']}, you accessed this via {user['provider']}!",
|
||||
"user": user
|
||||
}
|
||||
|
||||
|
||||
@bp.route("/admin")
|
||||
@require_admin
|
||||
def admin_only() -> dict[str, str]:
|
||||
"""Admin-only endpoint to demonstrate role-based access."""
|
||||
user = get_current_user()
|
||||
return {
|
||||
"message": f"Hello admin {user['name']}, you have admin access!",
|
||||
"user": user,
|
||||
"admin_info": "This endpoint is only accessible to admin users"
|
||||
}
|
||||
|
||||
|
||||
@bp.route("/health")
|
||||
def health() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
|
||||
@@ -1,21 +1,29 @@
|
||||
"""Authentication service for Google OAuth."""
|
||||
"""Authentication service for multiple OAuth providers."""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from authlib.integrations.flask_client import OAuth
|
||||
from flask import Flask, make_response, request
|
||||
from flask import Flask, jsonify
|
||||
from flask_jwt_extended import (
|
||||
get_jwt_identity,
|
||||
jwt_required,
|
||||
set_access_cookies,
|
||||
set_refresh_cookies,
|
||||
unset_jwt_cookies,
|
||||
)
|
||||
|
||||
from app.models.user import User
|
||||
from app.services.oauth_providers.registry import OAuthProviderRegistry
|
||||
from app.services.token_service import TokenService
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Service for handling Google OAuth authentication."""
|
||||
"""Service for handling multiple OAuth providers authentication."""
|
||||
|
||||
def __init__(self, app: Flask | None = None) -> None:
|
||||
"""Initialize the authentication service."""
|
||||
self.oauth = OAuth()
|
||||
self.google = None
|
||||
self.provider_registry = None
|
||||
self.token_service = TokenService()
|
||||
if app:
|
||||
self.init_app(app)
|
||||
@@ -24,120 +32,219 @@ class AuthService:
|
||||
"""Initialize the service with Flask app."""
|
||||
self.oauth.init_app(app)
|
||||
|
||||
# Configure Google OAuth
|
||||
self.google = self.oauth.register(
|
||||
name="google",
|
||||
client_id=os.getenv("GOOGLE_CLIENT_ID"),
|
||||
client_secret=os.getenv("GOOGLE_CLIENT_SECRET"),
|
||||
server_metadata_url="https://accounts.google.com/.well-known/openid_configuration",
|
||||
client_kwargs={"scope": "openid email profile"},
|
||||
)
|
||||
# Initialize provider registry
|
||||
self.provider_registry = OAuthProviderRegistry(self.oauth)
|
||||
|
||||
def get_login_url(self, redirect_uri: str) -> str:
|
||||
"""Generate Google OAuth login URL."""
|
||||
if not self.google:
|
||||
msg = "Google OAuth not configured"
|
||||
def redirect_to_login(self, provider_name: str, redirect_uri: str):
|
||||
"""Redirect to OAuth provider login."""
|
||||
provider = self.provider_registry.get_provider(provider_name)
|
||||
if not provider:
|
||||
msg = f"OAuth provider '{provider_name}' not configured"
|
||||
raise RuntimeError(msg)
|
||||
return self.google.authorize_redirect(redirect_uri).location
|
||||
|
||||
def handle_callback(self) -> tuple[dict[str, Any], Any]:
|
||||
client = provider.get_client()
|
||||
return client.authorize_redirect(redirect_uri)
|
||||
|
||||
def get_login_url(self, provider_name: str, redirect_uri: str) -> str:
|
||||
"""Generate OAuth provider login URL (for testing or manual use)."""
|
||||
provider = self.provider_registry.get_provider(provider_name)
|
||||
if not provider:
|
||||
msg = f"OAuth provider '{provider_name}' not configured"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
return provider.get_authorization_url(redirect_uri)
|
||||
|
||||
def handle_callback(self, provider_name: str) -> Any:
|
||||
"""Handle OAuth callback and exchange code for token."""
|
||||
if not self.google:
|
||||
msg = "Google OAuth not configured"
|
||||
provider = self.provider_registry.get_provider(provider_name)
|
||||
if not provider:
|
||||
msg = f"OAuth provider '{provider_name}' not configured"
|
||||
raise RuntimeError(msg)
|
||||
|
||||
token = self.google.authorize_access_token()
|
||||
user_info = token.get("userinfo")
|
||||
token = provider.exchange_code_for_token(None, None)
|
||||
raw_user_info = provider.get_user_info(token)
|
||||
user_data = provider.normalize_user_data(raw_user_info)
|
||||
|
||||
if user_info:
|
||||
user_data = {
|
||||
"id": user_info["sub"],
|
||||
"email": user_info["email"],
|
||||
"name": user_info["name"],
|
||||
"picture": user_info.get("picture"),
|
||||
if user_data and user_data.get("id"):
|
||||
# Find or create user in database
|
||||
user, oauth_provider = User.find_or_create_from_oauth(
|
||||
provider=provider_name,
|
||||
provider_id=user_data["id"],
|
||||
email=user_data["email"],
|
||||
name=user_data["name"],
|
||||
picture=user_data.get("picture"),
|
||||
)
|
||||
|
||||
# Check if user account is active
|
||||
if not user.is_active:
|
||||
response = jsonify({"error": "Account is disabled"})
|
||||
response.status_code = 401
|
||||
return response
|
||||
|
||||
# Prepare user data for JWT token
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": oauth_provider.provider,
|
||||
"providers": [p.provider for p in user.oauth_providers],
|
||||
}
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(user_data)
|
||||
refresh_token = self.token_service.generate_refresh_token(user_data)
|
||||
access_token = self.token_service.generate_access_token(
|
||||
jwt_user_data
|
||||
)
|
||||
refresh_token = self.token_service.generate_refresh_token(
|
||||
jwt_user_data
|
||||
)
|
||||
|
||||
# Create response and set HTTP-only cookies
|
||||
response = make_response({
|
||||
"message": "Login successful",
|
||||
"user": user_data,
|
||||
})
|
||||
|
||||
response.set_cookie(
|
||||
"access_token",
|
||||
access_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="Lax",
|
||||
max_age=15 * 60, # 15 minutes
|
||||
response = jsonify(
|
||||
{
|
||||
"message": "Login successful",
|
||||
"user": jwt_user_data,
|
||||
}
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
"refresh_token",
|
||||
refresh_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="Lax",
|
||||
max_age=7 * 24 * 60 * 60, # 7 days
|
||||
)
|
||||
# Set JWT cookies
|
||||
set_access_cookies(response, access_token)
|
||||
set_refresh_cookies(response, refresh_token)
|
||||
|
||||
return user_data, response
|
||||
return response
|
||||
|
||||
msg = "Failed to get user information from Google"
|
||||
msg = f"Failed to get user information from {provider.display_name}"
|
||||
raise ValueError(msg)
|
||||
|
||||
def get_current_user(self) -> dict[str, Any] | None:
|
||||
"""Get current user from access token."""
|
||||
access_token = request.cookies.get("access_token")
|
||||
if not access_token:
|
||||
return None
|
||||
def get_available_providers(self) -> dict[str, Any]:
|
||||
"""Get list of available OAuth providers."""
|
||||
if not self.provider_registry:
|
||||
return {}
|
||||
|
||||
return self.token_service.get_user_from_access_token(access_token)
|
||||
|
||||
def refresh_tokens(self) -> Any:
|
||||
"""Refresh access token using refresh token."""
|
||||
refresh_token = request.cookies.get("refresh_token")
|
||||
if not refresh_token:
|
||||
return None
|
||||
|
||||
payload = self.token_service.verify_token(refresh_token)
|
||||
if not payload or not self.token_service.is_refresh_token(payload):
|
||||
return None
|
||||
|
||||
# For refresh, we need to get user data (in a real app, from database)
|
||||
# For now, we'll extract what we can from the refresh token
|
||||
user_data = {
|
||||
"id": payload["user_id"],
|
||||
"email": "", # Would need to fetch from database
|
||||
"name": "", # Would need to fetch from database
|
||||
providers = self.provider_registry.get_available_providers()
|
||||
return {
|
||||
name: {"name": provider.name, "display_name": provider.display_name}
|
||||
for name, provider in providers.items()
|
||||
}
|
||||
|
||||
# Generate new access token
|
||||
new_access_token = self.token_service.generate_access_token(user_data)
|
||||
@jwt_required()
|
||||
def get_current_user(self) -> dict[str, Any] | None:
|
||||
"""Get current user from JWT token."""
|
||||
from flask_jwt_extended import get_jwt
|
||||
|
||||
response = make_response({"message": "Token refreshed"})
|
||||
response.set_cookie(
|
||||
"access_token",
|
||||
new_access_token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="Lax",
|
||||
max_age=15 * 60, # 15 minutes
|
||||
current_user_id = get_jwt_identity()
|
||||
claims = get_jwt()
|
||||
|
||||
if current_user_id:
|
||||
return {
|
||||
"id": current_user_id,
|
||||
"email": claims.get("email", ""),
|
||||
"name": claims.get("name", ""),
|
||||
"picture": claims.get("picture"),
|
||||
"role": claims.get("role", "user"),
|
||||
"is_active": claims.get("is_active", True),
|
||||
"provider": claims.get("provider", "unknown"),
|
||||
"providers": claims.get("providers", []),
|
||||
}
|
||||
return None
|
||||
|
||||
def register_with_password(
|
||||
self, email: str, password: str, name: str
|
||||
) -> Any:
|
||||
"""Register new user with email and password."""
|
||||
try:
|
||||
# Create user with password
|
||||
user = User.create_with_password(email, password, name)
|
||||
|
||||
# Prepare user data for JWT token
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": "password",
|
||||
"providers": ["password"],
|
||||
}
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
jwt_user_data
|
||||
)
|
||||
refresh_token = self.token_service.generate_refresh_token(
|
||||
jwt_user_data
|
||||
)
|
||||
|
||||
# Create response and set HTTP-only cookies
|
||||
response = jsonify(
|
||||
{
|
||||
"message": "Registration successful",
|
||||
"user": jwt_user_data,
|
||||
}
|
||||
)
|
||||
|
||||
# Set JWT cookies
|
||||
set_access_cookies(response, access_token)
|
||||
set_refresh_cookies(response, refresh_token)
|
||||
|
||||
return response
|
||||
|
||||
except ValueError as e:
|
||||
response = jsonify({"error": str(e)})
|
||||
response.status_code = 400
|
||||
return response
|
||||
|
||||
def login_with_password(self, email: str, password: str) -> Any:
|
||||
"""Login user with email and password."""
|
||||
# Authenticate user
|
||||
user = User.authenticate_with_password(email, password)
|
||||
|
||||
if not user:
|
||||
response = jsonify(
|
||||
{"error": "Invalid email, password or disabled account"}
|
||||
)
|
||||
response.status_code = 401
|
||||
return response
|
||||
|
||||
# Prepare user data for JWT token
|
||||
oauth_providers = [p.provider for p in user.oauth_providers]
|
||||
if user.has_password():
|
||||
oauth_providers.append("password")
|
||||
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": "password",
|
||||
"providers": oauth_providers,
|
||||
}
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||
refresh_token = self.token_service.generate_refresh_token(jwt_user_data)
|
||||
|
||||
# Create response and set HTTP-only cookies
|
||||
response = jsonify(
|
||||
{
|
||||
"message": "Login successful",
|
||||
"user": jwt_user_data,
|
||||
}
|
||||
)
|
||||
|
||||
# Set JWT cookies
|
||||
set_access_cookies(response, access_token)
|
||||
set_refresh_cookies(response, refresh_token)
|
||||
|
||||
return response
|
||||
|
||||
def logout(self) -> Any:
|
||||
"""Clear authentication cookies."""
|
||||
response = make_response({"message": "Logged out successfully"})
|
||||
response.set_cookie("access_token", "", expires=0)
|
||||
response.set_cookie("refresh_token", "", expires=0)
|
||||
response = jsonify({"message": "Logged out successfully"})
|
||||
unset_jwt_cookies(response)
|
||||
return response
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if user is authenticated."""
|
||||
return self.get_current_user() is not None
|
||||
@@ -1,37 +1,148 @@
|
||||
"""Authentication decorators and middleware."""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from flask import jsonify, request
|
||||
|
||||
from app.services.token_service import TokenService
|
||||
from flask_jwt_extended import get_jwt, get_jwt_identity, jwt_required
|
||||
|
||||
|
||||
def require_auth(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def require_auth(f):
|
||||
"""Decorator to require authentication for routes."""
|
||||
@wraps(f)
|
||||
def decorated_function(*args: Any, **kwargs: Any) -> Any:
|
||||
token_service = TokenService()
|
||||
access_token = request.cookies.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
return jsonify({"error": "Authentication required"}), 401
|
||||
|
||||
user_data = token_service.get_user_from_access_token(access_token)
|
||||
if not user_data:
|
||||
return jsonify({"error": "Invalid or expired token"}), 401
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return decorated_function
|
||||
return jwt_required()(f)
|
||||
|
||||
|
||||
def get_current_user() -> dict[str, Any] | None:
|
||||
"""Helper function to get current user from access token."""
|
||||
token_service = TokenService()
|
||||
access_token = request.cookies.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
"""Helper function to get current user from JWT token."""
|
||||
try:
|
||||
current_user_id = get_jwt_identity()
|
||||
if not current_user_id:
|
||||
return None
|
||||
|
||||
claims = get_jwt()
|
||||
is_active = claims.get("is_active", True)
|
||||
|
||||
# Check if user is active
|
||||
if not is_active:
|
||||
return None
|
||||
|
||||
return {
|
||||
"id": current_user_id,
|
||||
"email": claims.get("email", ""),
|
||||
"name": claims.get("name", ""),
|
||||
"picture": claims.get("picture"),
|
||||
"role": claims.get("role", "user"),
|
||||
"is_active": is_active,
|
||||
"provider": claims.get("provider", "unknown"),
|
||||
"providers": claims.get("providers", []),
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return token_service.get_user_from_access_token(access_token)
|
||||
|
||||
|
||||
def require_role(required_role: str):
|
||||
"""Decorator to require specific role for routes."""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
@jwt_required()
|
||||
def wrapper(*args, **kwargs):
|
||||
user = get_current_user()
|
||||
if not user:
|
||||
return jsonify({"error": "Authentication required"}), 401
|
||||
|
||||
if user.get("role") != required_role:
|
||||
return jsonify({"error": f"Access denied. {required_role.title()} role required"}), 403
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def require_admin(f):
|
||||
"""Decorator to require admin role for routes."""
|
||||
return require_role("admin")(f)
|
||||
|
||||
|
||||
def require_user_or_admin(f):
|
||||
"""Decorator to require user or admin role for routes."""
|
||||
@wraps(f)
|
||||
@jwt_required()
|
||||
def wrapper(*args, **kwargs):
|
||||
user = get_current_user()
|
||||
if not user:
|
||||
return jsonify({"error": "Authentication required"}), 401
|
||||
|
||||
if user.get("role") not in ["user", "admin"]:
|
||||
return jsonify({"error": "Access denied"}), 403
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_user_from_api_token() -> dict[str, Any] | None:
|
||||
"""Get user from API token in request headers."""
|
||||
try:
|
||||
# Check for API token in Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return None
|
||||
|
||||
# Expected format: "Bearer <token>" or "Token <token>"
|
||||
parts = auth_header.split()
|
||||
if len(parts) != 2 or parts[0].lower() not in ["bearer", "token"]:
|
||||
return None
|
||||
|
||||
api_token = parts[1]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from app.models.user import User
|
||||
|
||||
user = User.find_by_api_token(api_token)
|
||||
if user and user.is_active:
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": "api_token",
|
||||
"providers": [p.provider for p in user.oauth_providers] + ["api_token"],
|
||||
}
|
||||
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def require_api_token(f):
|
||||
"""Decorator to require API token authentication for routes."""
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
user = get_user_from_api_token()
|
||||
if not user:
|
||||
return jsonify({"error": "Valid API token required"}), 401
|
||||
|
||||
return f(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_auth_or_api_token(f):
|
||||
"""Decorator to accept either JWT or API token authentication."""
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Try JWT authentication first
|
||||
try:
|
||||
user = get_current_user()
|
||||
if user:
|
||||
return f(*args, **kwargs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try API token authentication
|
||||
user = get_user_from_api_token()
|
||||
if user:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return jsonify({"error": "Authentication required (JWT or API token)"}), 401
|
||||
return wrapper
|
||||
0
app/services/oauth_providers/__init__.py
Normal file
0
app/services/oauth_providers/__init__.py
Normal file
68
app/services/oauth_providers/base.py
Normal file
68
app/services/oauth_providers/base.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional
|
||||
from authlib.integrations.flask_client import OAuth
|
||||
|
||||
|
||||
class OAuthProvider(ABC):
|
||||
"""Abstract base class for OAuth providers."""
|
||||
|
||||
def __init__(self, oauth: OAuth, client_id: str, client_secret: str):
|
||||
self.oauth = oauth
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Provider name (e.g., 'google', 'github')."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable provider name (e.g., 'Google', 'GitHub')."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
"""Return OAuth client configuration."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Extract user information from OAuth token response."""
|
||||
pass
|
||||
|
||||
def get_client(self):
|
||||
"""Get or create OAuth client."""
|
||||
if self._client is None:
|
||||
config = self.get_client_config()
|
||||
self._client = self.oauth.register(
|
||||
name=self.name,
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
**config
|
||||
)
|
||||
return self._client
|
||||
|
||||
def get_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Generate authorization URL for OAuth flow."""
|
||||
client = self.get_client()
|
||||
return client.authorize_redirect(redirect_uri).location
|
||||
|
||||
def exchange_code_for_token(self, code: str = None, redirect_uri: str = None) -> Dict[str, Any]:
|
||||
"""Exchange authorization code for access token."""
|
||||
client = self.get_client()
|
||||
token = client.authorize_access_token()
|
||||
return token
|
||||
|
||||
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Normalize user data to common format."""
|
||||
return {
|
||||
'id': user_info.get('id'),
|
||||
'email': user_info.get('email'),
|
||||
'name': user_info.get('name'),
|
||||
'picture': user_info.get('picture'),
|
||||
'provider': self.name
|
||||
}
|
||||
52
app/services/oauth_providers/github.py
Normal file
52
app/services/oauth_providers/github.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Dict, Any
|
||||
from .base import OAuthProvider
|
||||
|
||||
|
||||
class GitHubOAuthProvider(OAuthProvider):
|
||||
"""GitHub OAuth provider implementation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return 'github'
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return 'GitHub'
|
||||
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
"""Return GitHub OAuth client configuration."""
|
||||
return {
|
||||
'access_token_url': 'https://github.com/login/oauth/access_token',
|
||||
'authorize_url': 'https://github.com/login/oauth/authorize',
|
||||
'api_base_url': 'https://api.github.com/',
|
||||
'client_kwargs': {
|
||||
'scope': 'user:email'
|
||||
}
|
||||
}
|
||||
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Extract user information from GitHub OAuth token response."""
|
||||
client = self.get_client()
|
||||
|
||||
# Get user profile
|
||||
user_resp = client.get('user', token=token)
|
||||
user_data = user_resp.json()
|
||||
|
||||
# Get user email (may be private)
|
||||
email = user_data.get('email')
|
||||
if not email:
|
||||
# If email is private, get from emails endpoint
|
||||
emails_resp = client.get('user/emails', token=token)
|
||||
emails = emails_resp.json()
|
||||
# Find primary email
|
||||
for email_obj in emails:
|
||||
if email_obj.get('primary', False):
|
||||
email = email_obj.get('email')
|
||||
break
|
||||
|
||||
return {
|
||||
'id': str(user_data.get('id')),
|
||||
'email': email,
|
||||
'name': user_data.get('name') or user_data.get('login'),
|
||||
'picture': user_data.get('avatar_url')
|
||||
}
|
||||
34
app/services/oauth_providers/google.py
Normal file
34
app/services/oauth_providers/google.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from .base import OAuthProvider
|
||||
|
||||
|
||||
class GoogleOAuthProvider(OAuthProvider):
|
||||
"""Google OAuth provider implementation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "google"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return "Google"
|
||||
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
"""Return Google OAuth client configuration."""
|
||||
return {
|
||||
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
|
||||
"client_kwargs": {"scope": "openid email profile"},
|
||||
}
|
||||
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Extract user information from Google OAuth token response."""
|
||||
client = self.get_client()
|
||||
user_info = client.userinfo(token=token)
|
||||
|
||||
return {
|
||||
"id": user_info.get("sub"),
|
||||
"email": user_info.get("email"),
|
||||
"name": user_info.get("name"),
|
||||
"picture": user_info.get("picture"),
|
||||
}
|
||||
45
app/services/oauth_providers/registry.py
Normal file
45
app/services/oauth_providers/registry.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from authlib.integrations.flask_client import OAuth
|
||||
from .base import OAuthProvider
|
||||
from .google import GoogleOAuthProvider
|
||||
from .github import GitHubOAuthProvider
|
||||
|
||||
|
||||
class OAuthProviderRegistry:
|
||||
"""Registry for OAuth providers."""
|
||||
|
||||
def __init__(self, oauth: OAuth):
|
||||
self.oauth = oauth
|
||||
self._providers: Dict[str, OAuthProvider] = {}
|
||||
self._initialize_providers()
|
||||
|
||||
def _initialize_providers(self):
|
||||
"""Initialize available providers based on environment variables."""
|
||||
# Google OAuth
|
||||
google_client_id = os.getenv('GOOGLE_CLIENT_ID')
|
||||
google_client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
|
||||
if google_client_id and google_client_secret:
|
||||
self._providers['google'] = GoogleOAuthProvider(
|
||||
self.oauth, google_client_id, google_client_secret
|
||||
)
|
||||
|
||||
# GitHub OAuth
|
||||
github_client_id = os.getenv('GITHUB_CLIENT_ID')
|
||||
github_client_secret = os.getenv('GITHUB_CLIENT_SECRET')
|
||||
if github_client_id and github_client_secret:
|
||||
self._providers['github'] = GitHubOAuthProvider(
|
||||
self.oauth, github_client_id, github_client_secret
|
||||
)
|
||||
|
||||
def get_provider(self, name: str) -> Optional[OAuthProvider]:
|
||||
"""Get OAuth provider by name."""
|
||||
return self._providers.get(name)
|
||||
|
||||
def get_available_providers(self) -> Dict[str, OAuthProvider]:
|
||||
"""Get all available providers."""
|
||||
return self._providers.copy()
|
||||
|
||||
def is_provider_available(self, name: str) -> bool:
|
||||
"""Check if provider is available."""
|
||||
return name in self._providers
|
||||
@@ -1,75 +1,24 @@
|
||||
"""JWT token service for handling access and refresh tokens."""
|
||||
"""JWT token service using Flask-JWT-Extended."""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from flask_jwt_extended import create_access_token, create_refresh_token
|
||||
|
||||
|
||||
class TokenService:
|
||||
"""Service for handling JWT tokens."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the token service."""
|
||||
self.secret_key = os.environ.get("JWT_SECRET_KEY", "jwt-secret-key")
|
||||
self.algorithm = "HS256"
|
||||
self.access_token_expire_minutes = 15
|
||||
self.refresh_token_expire_days = 7
|
||||
"""Service for handling JWT tokens using Flask-JWT-Extended."""
|
||||
|
||||
def generate_access_token(self, user_data: dict[str, Any]) -> str:
|
||||
"""Generate an access token for the user."""
|
||||
payload = {
|
||||
"user_id": user_data["id"],
|
||||
"email": user_data["email"],
|
||||
"name": user_data["name"],
|
||||
"type": "access",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(
|
||||
minutes=self.access_token_expire_minutes
|
||||
),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
}
|
||||
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
return create_access_token(
|
||||
identity=user_data["id"],
|
||||
additional_claims={
|
||||
"email": user_data["email"],
|
||||
"name": user_data["name"],
|
||||
"picture": user_data.get("picture"),
|
||||
},
|
||||
)
|
||||
|
||||
def generate_refresh_token(self, user_data: dict[str, Any]) -> str:
|
||||
"""Generate a refresh token for the user."""
|
||||
payload = {
|
||||
"user_id": user_data["id"],
|
||||
"type": "refresh",
|
||||
"exp": datetime.now(timezone.utc) + timedelta(
|
||||
days=self.refresh_token_expire_days
|
||||
),
|
||||
"iat": datetime.now(timezone.utc),
|
||||
}
|
||||
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def verify_token(self, token: str) -> dict[str, Any] | None:
|
||||
"""Verify and decode a JWT token."""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, self.secret_key, algorithms=[self.algorithm]
|
||||
)
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
||||
|
||||
def is_access_token(self, payload: dict[str, Any]) -> bool:
|
||||
"""Check if the token payload is for an access token."""
|
||||
return payload.get("type") == "access"
|
||||
|
||||
def is_refresh_token(self, payload: dict[str, Any]) -> bool:
|
||||
"""Check if the token payload is for a refresh token."""
|
||||
return payload.get("type") == "refresh"
|
||||
|
||||
def get_user_from_access_token(self, token: str) -> dict[str, Any] | None:
|
||||
"""Extract user data from access token."""
|
||||
payload = self.verify_token(token)
|
||||
if payload and self.is_access_token(payload):
|
||||
return {
|
||||
"id": payload["user_id"],
|
||||
"email": payload["email"],
|
||||
"name": payload["name"],
|
||||
}
|
||||
return None
|
||||
return create_refresh_token(identity=user_data["id"])
|
||||
Reference in New Issue
Block a user