refactor(auth): improve code structure and add user registration endpoint
refactor(main): update index route response and remove greeting service refactor(decorators): streamline authentication decorators and remove unused ones test(routes): update tests to reflect changes in main routes and error messages
This commit is contained in:
@@ -1,10 +1,14 @@
|
|||||||
"""Authentication routes."""
|
"""Authentication routes."""
|
||||||
|
|
||||||
from flask import Blueprint, jsonify, url_for
|
from flask import Blueprint, jsonify, url_for
|
||||||
from flask_jwt_extended import create_access_token, get_jwt_identity, jwt_required
|
from flask_jwt_extended import (
|
||||||
|
create_access_token,
|
||||||
|
get_jwt_identity,
|
||||||
|
jwt_required,
|
||||||
|
)
|
||||||
|
|
||||||
from app import auth_service
|
from app import auth_service
|
||||||
from app.services.decorators import get_current_user
|
from app.services.decorators import get_current_user, require_auth
|
||||||
|
|
||||||
bp = Blueprint("auth", __name__)
|
bp = Blueprint("auth", __name__)
|
||||||
|
|
||||||
@@ -19,18 +23,20 @@ def login_oauth(provider):
|
|||||||
@bp.route("/callback/<provider>")
|
@bp.route("/callback/<provider>")
|
||||||
def callback(provider):
|
def callback(provider):
|
||||||
"""Handle OAuth callback from specified provider."""
|
"""Handle OAuth callback from specified provider."""
|
||||||
from flask import redirect, make_response
|
from flask import make_response, redirect
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_response = auth_service.handle_callback(provider)
|
auth_response = auth_service.handle_callback(provider)
|
||||||
|
|
||||||
# If successful, redirect to frontend dashboard with cookies
|
# If successful, redirect to frontend dashboard with cookies
|
||||||
if auth_response.status_code == 200:
|
if auth_response.status_code == 200:
|
||||||
redirect_response = make_response(redirect("http://localhost:3000/dashboard"))
|
redirect_response = make_response(
|
||||||
|
redirect("http://localhost:3000/dashboard")
|
||||||
|
)
|
||||||
|
|
||||||
# Copy all cookies from the auth response
|
# Copy all cookies from the auth response
|
||||||
for cookie in auth_response.headers.getlist('Set-Cookie'):
|
for cookie in auth_response.headers.getlist("Set-Cookie"):
|
||||||
redirect_response.headers.add('Set-Cookie', cookie)
|
redirect_response.headers.add("Set-Cookie", cookie)
|
||||||
|
|
||||||
return redirect_response
|
return redirect_response
|
||||||
else:
|
else:
|
||||||
@@ -38,7 +44,7 @@ def callback(provider):
|
|||||||
return redirect("http://localhost:3000/login?error=oauth_failed")
|
return redirect("http://localhost:3000/login?error=oauth_failed")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e).replace(' ', '_').replace('"', '')
|
error_msg = str(e).replace(" ", "_").replace('"', "")
|
||||||
return redirect(f"http://localhost:3000/login?error={error_msg}")
|
return redirect(f"http://localhost:3000/login?error={error_msg}")
|
||||||
|
|
||||||
|
|
||||||
@@ -48,157 +54,6 @@ def providers():
|
|||||||
return {"providers": auth_service.get_available_providers()}
|
return {"providers": auth_service.get_available_providers()}
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/login", methods=["POST"])
|
|
||||||
def login():
|
|
||||||
"""Login user with email and password."""
|
|
||||||
from flask import request
|
|
||||||
|
|
||||||
data = request.get_json()
|
|
||||||
if not data:
|
|
||||||
return {"error": "No data provided"}, 400
|
|
||||||
|
|
||||||
email = data.get("email")
|
|
||||||
password = data.get("password")
|
|
||||||
|
|
||||||
if not email or not password:
|
|
||||||
return {"error": "Email and password are required"}, 400
|
|
||||||
|
|
||||||
return auth_service.login_with_password(email, password)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/logout")
|
|
||||||
def logout():
|
|
||||||
"""Logout current user."""
|
|
||||||
return auth_service.logout()
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/me")
|
|
||||||
@jwt_required()
|
|
||||||
def me():
|
|
||||||
"""Get current user information."""
|
|
||||||
user = get_current_user()
|
|
||||||
return {"user": user}
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/refresh", methods=["POST"])
|
|
||||||
@jwt_required(refresh=True)
|
|
||||||
def refresh():
|
|
||||||
"""Refresh access token using refresh token."""
|
|
||||||
current_user_id = get_jwt_identity()
|
|
||||||
|
|
||||||
# Create new access token
|
|
||||||
new_access_token = create_access_token(identity=current_user_id)
|
|
||||||
|
|
||||||
response = jsonify({"message": "Token refreshed"})
|
|
||||||
|
|
||||||
# Set new access token cookie
|
|
||||||
from flask_jwt_extended import set_access_cookies
|
|
||||||
set_access_cookies(response, new_access_token)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/link/<provider>")
|
|
||||||
@jwt_required()
|
|
||||||
def link_provider(provider):
|
|
||||||
"""Link a new OAuth provider to current user account."""
|
|
||||||
redirect_uri = url_for("auth.link_callback", provider=provider, _external=True)
|
|
||||||
return auth_service.redirect_to_login(provider, redirect_uri)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/link/callback/<provider>")
|
|
||||||
@jwt_required()
|
|
||||||
def link_callback(provider):
|
|
||||||
"""Handle OAuth callback for linking new provider."""
|
|
||||||
try:
|
|
||||||
current_user_id = get_jwt_identity()
|
|
||||||
if not current_user_id:
|
|
||||||
return {"error": "User not authenticated"}, 401
|
|
||||||
|
|
||||||
# Get current user from database
|
|
||||||
from app.models.user import User
|
|
||||||
user = User.query.get(current_user_id)
|
|
||||||
if not user:
|
|
||||||
return {"error": "User not found"}, 404
|
|
||||||
|
|
||||||
# Process OAuth callback but link to existing user
|
|
||||||
from app.services.oauth_providers.registry import OAuthProviderRegistry
|
|
||||||
from authlib.integrations.flask_client import OAuth
|
|
||||||
|
|
||||||
oauth = OAuth()
|
|
||||||
registry = OAuthProviderRegistry(oauth)
|
|
||||||
oauth_provider = registry.get_provider(provider)
|
|
||||||
|
|
||||||
if not oauth_provider:
|
|
||||||
return {"error": f"OAuth provider '{provider}' not configured"}, 400
|
|
||||||
|
|
||||||
token = oauth_provider.exchange_code_for_token(None, None)
|
|
||||||
raw_user_info = oauth_provider.get_user_info(token)
|
|
||||||
provider_data = oauth_provider.normalize_user_data(raw_user_info)
|
|
||||||
|
|
||||||
if not provider_data.get("id"):
|
|
||||||
return {"error": "Failed to get user information from provider"}, 400
|
|
||||||
|
|
||||||
# Check if this provider is already linked to another user
|
|
||||||
from app.models.user_oauth import UserOAuth
|
|
||||||
existing_provider = UserOAuth.find_by_provider_and_id(
|
|
||||||
provider, provider_data["id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if existing_provider and existing_provider.user_id != user.id:
|
|
||||||
return {"error": "This provider account is already linked to another user"}, 409
|
|
||||||
|
|
||||||
# Link the provider to current user
|
|
||||||
UserOAuth.create_or_update(
|
|
||||||
user_id=user.id,
|
|
||||||
provider=provider,
|
|
||||||
provider_id=provider_data["id"],
|
|
||||||
email=provider_data["email"],
|
|
||||||
name=provider_data["name"],
|
|
||||||
picture=provider_data.get("picture")
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"message": f"{provider.title()} account linked successfully"}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": str(e)}, 400
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/unlink/<provider>", methods=["DELETE"])
|
|
||||||
@jwt_required()
|
|
||||||
def unlink_provider(provider):
|
|
||||||
"""Unlink an OAuth provider from current user account."""
|
|
||||||
try:
|
|
||||||
current_user_id = get_jwt_identity()
|
|
||||||
if not current_user_id:
|
|
||||||
return {"error": "User not authenticated"}, 401
|
|
||||||
|
|
||||||
from app.models.user import User
|
|
||||||
from app.models.user_oauth import UserOAuth
|
|
||||||
from app.database import db
|
|
||||||
|
|
||||||
user = User.query.get(current_user_id)
|
|
||||||
if not user:
|
|
||||||
return {"error": "User not found"}, 404
|
|
||||||
|
|
||||||
# Check if user has more than one provider (prevent locking out)
|
|
||||||
if len(user.oauth_providers) <= 1:
|
|
||||||
return {"error": "Cannot unlink last authentication provider"}, 400
|
|
||||||
|
|
||||||
# Find and remove the provider
|
|
||||||
oauth_provider = user.get_provider(provider)
|
|
||||||
if not oauth_provider:
|
|
||||||
return {"error": f"Provider '{provider}' not linked to this account"}, 404
|
|
||||||
|
|
||||||
db.session.delete(oauth_provider)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return {"message": f"{provider.title()} account unlinked successfully"}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {"error": str(e)}, 400
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/register", methods=["POST"])
|
@bp.route("/register", methods=["POST"])
|
||||||
def register():
|
def register():
|
||||||
"""Register new user with email and password."""
|
"""Register new user with email and password."""
|
||||||
@@ -226,6 +81,161 @@ def register():
|
|||||||
return auth_service.register_with_password(email, password, name)
|
return auth_service.register_with_password(email, password, name)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/login", methods=["POST"])
|
||||||
|
def login():
|
||||||
|
"""Login user with email and password."""
|
||||||
|
from flask import request
|
||||||
|
|
||||||
|
data = request.get_json()
|
||||||
|
if not data:
|
||||||
|
return {"error": "No data provided"}, 400
|
||||||
|
|
||||||
|
email = data.get("email")
|
||||||
|
password = data.get("password")
|
||||||
|
|
||||||
|
if not email or not password:
|
||||||
|
return {"error": "Email and password are required"}, 400
|
||||||
|
|
||||||
|
return auth_service.login_with_password(email, password)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/logout")
|
||||||
|
def logout():
|
||||||
|
"""Logout current user."""
|
||||||
|
return auth_service.logout()
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/refresh", methods=["POST"])
|
||||||
|
@jwt_required(refresh=True)
|
||||||
|
def refresh():
|
||||||
|
"""Refresh access token using refresh token."""
|
||||||
|
current_user_id = get_jwt_identity()
|
||||||
|
|
||||||
|
# Create new access token
|
||||||
|
new_access_token = create_access_token(identity=current_user_id)
|
||||||
|
|
||||||
|
response = jsonify({"message": "Token refreshed"})
|
||||||
|
|
||||||
|
# Set new access token cookie
|
||||||
|
from flask_jwt_extended import set_access_cookies
|
||||||
|
|
||||||
|
set_access_cookies(response, new_access_token)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/link/<provider>")
|
||||||
|
@jwt_required()
|
||||||
|
def link_provider(provider):
|
||||||
|
"""Link a new OAuth provider to current user account."""
|
||||||
|
redirect_uri = url_for(
|
||||||
|
"auth.link_callback", provider=provider, _external=True
|
||||||
|
)
|
||||||
|
return auth_service.redirect_to_login(provider, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/link/callback/<provider>")
|
||||||
|
@jwt_required()
|
||||||
|
def link_callback(provider):
|
||||||
|
"""Handle OAuth callback for linking new provider."""
|
||||||
|
try:
|
||||||
|
current_user_id = get_jwt_identity()
|
||||||
|
if not current_user_id:
|
||||||
|
return {"error": "User not authenticated"}, 401
|
||||||
|
|
||||||
|
# Get current user from database
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
user = User.query.get(current_user_id)
|
||||||
|
if not user:
|
||||||
|
return {"error": "User not found"}, 404
|
||||||
|
|
||||||
|
# Process OAuth callback but link to existing user
|
||||||
|
from authlib.integrations.flask_client import OAuth
|
||||||
|
|
||||||
|
from app.services.oauth_providers.registry import OAuthProviderRegistry
|
||||||
|
|
||||||
|
oauth = OAuth()
|
||||||
|
registry = OAuthProviderRegistry(oauth)
|
||||||
|
oauth_provider = registry.get_provider(provider)
|
||||||
|
|
||||||
|
if not oauth_provider:
|
||||||
|
return {"error": f"OAuth provider '{provider}' not configured"}, 400
|
||||||
|
|
||||||
|
token = oauth_provider.exchange_code_for_token(None, None)
|
||||||
|
raw_user_info = oauth_provider.get_user_info(token)
|
||||||
|
provider_data = oauth_provider.normalize_user_data(raw_user_info)
|
||||||
|
|
||||||
|
if not provider_data.get("id"):
|
||||||
|
return {
|
||||||
|
"error": "Failed to get user information from provider"
|
||||||
|
}, 400
|
||||||
|
|
||||||
|
# Check if this provider is already linked to another user
|
||||||
|
from app.models.user_oauth import UserOAuth
|
||||||
|
|
||||||
|
existing_provider = UserOAuth.find_by_provider_and_id(
|
||||||
|
provider, provider_data["id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_provider and existing_provider.user_id != user.id:
|
||||||
|
return {
|
||||||
|
"error": "This provider account is already linked to another user"
|
||||||
|
}, 409
|
||||||
|
|
||||||
|
# Link the provider to current user
|
||||||
|
UserOAuth.create_or_update(
|
||||||
|
user_id=user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_id=provider_data["id"],
|
||||||
|
email=provider_data["email"],
|
||||||
|
name=provider_data["name"],
|
||||||
|
picture=provider_data.get("picture"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": f"{provider.title()} account linked successfully"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/unlink/<provider>", methods=["DELETE"])
|
||||||
|
@jwt_required()
|
||||||
|
def unlink_provider(provider):
|
||||||
|
"""Unlink an OAuth provider from current user account."""
|
||||||
|
try:
|
||||||
|
current_user_id = get_jwt_identity()
|
||||||
|
if not current_user_id:
|
||||||
|
return {"error": "User not authenticated"}, 401
|
||||||
|
|
||||||
|
from app.database import db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.user_oauth import UserOAuth
|
||||||
|
|
||||||
|
user = User.query.get(current_user_id)
|
||||||
|
if not user:
|
||||||
|
return {"error": "User not found"}, 404
|
||||||
|
|
||||||
|
# Check if user has more than one provider (prevent locking out)
|
||||||
|
if len(user.oauth_providers) <= 1:
|
||||||
|
return {"error": "Cannot unlink last authentication provider"}, 400
|
||||||
|
|
||||||
|
# Find and remove the provider
|
||||||
|
oauth_provider = user.get_provider(provider)
|
||||||
|
if not oauth_provider:
|
||||||
|
return {
|
||||||
|
"error": f"Provider '{provider}' not linked to this account"
|
||||||
|
}, 404
|
||||||
|
|
||||||
|
db.session.delete(oauth_provider)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return {"message": f"{provider.title()} account unlinked successfully"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/regenerate-api-token", methods=["POST"])
|
@bp.route("/regenerate-api-token", methods=["POST"])
|
||||||
@jwt_required()
|
@jwt_required()
|
||||||
def regenerate_api_token():
|
def regenerate_api_token():
|
||||||
@@ -234,8 +244,8 @@ def regenerate_api_token():
|
|||||||
if not current_user_id:
|
if not current_user_id:
|
||||||
return {"error": "User not authenticated"}, 401
|
return {"error": "User not authenticated"}, 401
|
||||||
|
|
||||||
from app.models.user import User
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
user = User.query.get(current_user_id)
|
user = User.query.get(current_user_id)
|
||||||
if not user:
|
if not user:
|
||||||
@@ -248,7 +258,17 @@ def regenerate_api_token():
|
|||||||
return {
|
return {
|
||||||
"message": "API token regenerated successfully",
|
"message": "API token regenerated successfully",
|
||||||
"api_token": new_token,
|
"api_token": new_token,
|
||||||
"expires_at": user.api_token_expires_at.isoformat() if user.api_token_expires_at else None
|
"expires_at": (
|
||||||
|
user.api_token_expires_at.isoformat()
|
||||||
|
if user.api_token_expires_at
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/me")
|
||||||
|
@require_auth
|
||||||
|
def me():
|
||||||
|
"""Get current user information."""
|
||||||
|
user = get_current_user()
|
||||||
|
return {"user": user}
|
||||||
|
|||||||
@@ -2,60 +2,49 @@
|
|||||||
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
|
|
||||||
from app.services.decorators import get_current_user, require_auth, require_admin, require_auth_or_api_token, get_user_from_api_token
|
from app.services.decorators import get_current_user, require_auth, require_role
|
||||||
from app.services.greeting_service import GreetingService
|
|
||||||
|
|
||||||
bp = Blueprint("main", __name__)
|
bp = Blueprint("main", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/")
|
@bp.route("/")
|
||||||
def index() -> dict[str, str]:
|
def index() -> dict[str, str]:
|
||||||
"""Root endpoint that returns a greeting."""
|
"""Root endpoint that returns API status."""
|
||||||
return GreetingService.get_greeting()
|
return {"message": "API is running", "status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/hello")
|
|
||||||
@bp.route("/hello/<name>")
|
|
||||||
def hello(name: str | None = None) -> dict[str, str]:
|
|
||||||
"""Hello endpoint with optional name parameter."""
|
|
||||||
return GreetingService.get_greeting(name)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/protected")
|
@bp.route("/protected")
|
||||||
@require_auth
|
@require_auth
|
||||||
def protected() -> dict[str, str]:
|
def protected() -> dict[str, str]:
|
||||||
"""Protected endpoint that requires JWT authentication."""
|
"""Protected endpoint that requires authentication."""
|
||||||
user = get_current_user()
|
user = get_current_user()
|
||||||
return {
|
return {
|
||||||
"message": f"Hello {user['name']}, this is a protected endpoint!",
|
"message": f"Hello {user['name']}, this is a protected endpoint!",
|
||||||
"user": user
|
"user": user,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/api-protected")
|
@bp.route("/api-protected")
|
||||||
@require_auth_or_api_token
|
@require_auth
|
||||||
def api_protected() -> dict[str, str]:
|
def api_protected() -> dict[str, str]:
|
||||||
"""Protected endpoint that accepts JWT or API token authentication."""
|
"""Protected endpoint that accepts JWT or API token authentication."""
|
||||||
# Try to get user from JWT first, then API token
|
|
||||||
user = get_current_user()
|
user = get_current_user()
|
||||||
if not user:
|
|
||||||
user = get_user_from_api_token()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": f"Hello {user['name']}, you accessed this via {user['provider']}!",
|
"message": f"Hello {user['name']}, you accessed this via {user['provider']}!",
|
||||||
"user": user
|
"user": user,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/admin")
|
@bp.route("/admin")
|
||||||
@require_admin
|
@require_auth
|
||||||
|
@require_role("admin")
|
||||||
def admin_only() -> dict[str, str]:
|
def admin_only() -> dict[str, str]:
|
||||||
"""Admin-only endpoint to demonstrate role-based access."""
|
"""Admin-only endpoint to demonstrate role-based access."""
|
||||||
user = get_current_user()
|
user = get_current_user()
|
||||||
return {
|
return {
|
||||||
"message": f"Hello admin {user['name']}, you have admin access!",
|
"message": f"Hello admin {user['name']}, you have admin access!",
|
||||||
"user": user,
|
"user": user,
|
||||||
"admin_info": "This endpoint is only accessible to admin users"
|
"admin_info": "This endpoint is only accessible to admin users",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,17 +4,15 @@ from functools import wraps
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
from flask_jwt_extended import get_jwt, get_jwt_identity, jwt_required
|
from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request
|
||||||
|
|
||||||
|
|
||||||
def require_auth(f):
|
def get_user_from_jwt() -> dict[str, Any] | None:
|
||||||
"""Decorator to require authentication for routes."""
|
|
||||||
return jwt_required()(f)
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_user() -> dict[str, Any] | None:
|
|
||||||
"""Helper function to get current user from JWT token."""
|
"""Helper function to get current user from JWT token."""
|
||||||
try:
|
try:
|
||||||
|
# Try to verify JWT token in request - this sets up the context
|
||||||
|
verify_jwt_in_request()
|
||||||
|
|
||||||
current_user_id = get_jwt_identity()
|
current_user_id = get_jwt_identity()
|
||||||
if not current_user_id:
|
if not current_user_id:
|
||||||
return None
|
return None
|
||||||
@@ -22,7 +20,6 @@ def get_current_user() -> dict[str, Any] | None:
|
|||||||
claims = get_jwt()
|
claims = get_jwt()
|
||||||
is_active = claims.get("is_active", True)
|
is_active = claims.get("is_active", True)
|
||||||
|
|
||||||
# Check if user is active
|
|
||||||
if not is_active:
|
if not is_active:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -40,61 +37,19 @@ def get_current_user() -> dict[str, Any] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def require_role(required_role: str):
|
|
||||||
"""Decorator to require specific role for routes."""
|
|
||||||
def decorator(f):
|
|
||||||
@wraps(f)
|
|
||||||
@jwt_required()
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
user = get_current_user()
|
|
||||||
if not user:
|
|
||||||
return jsonify({"error": "Authentication required"}), 401
|
|
||||||
|
|
||||||
if user.get("role") != required_role:
|
|
||||||
return jsonify({"error": f"Access denied. {required_role.title()} role required"}), 403
|
|
||||||
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
return wrapper
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def require_admin(f):
|
|
||||||
"""Decorator to require admin role for routes."""
|
|
||||||
return require_role("admin")(f)
|
|
||||||
|
|
||||||
|
|
||||||
def require_user_or_admin(f):
|
|
||||||
"""Decorator to require user or admin role for routes."""
|
|
||||||
@wraps(f)
|
|
||||||
@jwt_required()
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
user = get_current_user()
|
|
||||||
if not user:
|
|
||||||
return jsonify({"error": "Authentication required"}), 401
|
|
||||||
|
|
||||||
if user.get("role") not in ["user", "admin"]:
|
|
||||||
return jsonify({"error": "Access denied"}), 403
|
|
||||||
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def get_user_from_api_token() -> dict[str, Any] | None:
|
def get_user_from_api_token() -> dict[str, Any] | None:
|
||||||
"""Get user from API token in request headers."""
|
"""Get user from API token in request headers."""
|
||||||
try:
|
try:
|
||||||
# Check for API token in Authorization header
|
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_header = request.headers.get("Authorization")
|
||||||
if not auth_header:
|
if not auth_header:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Expected format: "Bearer <token>" or "Token <token>"
|
|
||||||
parts = auth_header.split()
|
parts = auth_header.split()
|
||||||
if len(parts) != 2 or parts[0].lower() not in ["bearer", "token"]:
|
if len(parts) != 2 or parts[0].lower() not in ["bearer", "token"]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
api_token = parts[1]
|
api_token = parts[1]
|
||||||
|
|
||||||
# Import here to avoid circular imports
|
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
user = User.find_by_api_token(api_token)
|
user = User.find_by_api_token(api_token)
|
||||||
@@ -107,7 +62,8 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
|||||||
"role": user.role,
|
"role": user.role,
|
||||||
"is_active": user.is_active,
|
"is_active": user.is_active,
|
||||||
"provider": "api_token",
|
"provider": "api_token",
|
||||||
"providers": [p.provider for p in user.oauth_providers] + ["api_token"],
|
"providers": [p.provider for p in user.oauth_providers]
|
||||||
|
+ ["api_token"],
|
||||||
}
|
}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -115,34 +71,58 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def require_api_token(f):
|
def get_current_user() -> dict[str, Any] | None:
|
||||||
"""Decorator to require API token authentication for routes."""
|
"""Get current user from either JWT or API token."""
|
||||||
|
# Try JWT first
|
||||||
|
user = get_user_from_jwt()
|
||||||
|
if user:
|
||||||
|
return user
|
||||||
|
|
||||||
|
# Try API token
|
||||||
|
return get_user_from_api_token()
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth(f):
|
||||||
|
"""Decorator to require authentication (JWT or API token) for routes."""
|
||||||
|
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
user = get_user_from_api_token()
|
user = get_current_user()
|
||||||
if not user:
|
if not user:
|
||||||
return jsonify({"error": "Valid API token required"}), 401
|
return (
|
||||||
|
jsonify(
|
||||||
|
{"error": "Authentication required (JWT or API token)"}
|
||||||
|
),
|
||||||
|
401,
|
||||||
|
)
|
||||||
|
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_auth_or_api_token(f):
|
def require_role(required_role: str):
|
||||||
"""Decorator to accept either JWT or API token authentication."""
|
"""Decorator to require specific role for routes."""
|
||||||
@wraps(f)
|
|
||||||
def wrapper(*args, **kwargs):
|
def decorator(f):
|
||||||
# Try JWT authentication first
|
@wraps(f)
|
||||||
try:
|
def wrapper(*args, **kwargs):
|
||||||
user = get_current_user()
|
user = get_current_user()
|
||||||
if user:
|
if not user:
|
||||||
return f(*args, **kwargs)
|
return jsonify({"error": "Authentication required"}), 401
|
||||||
except Exception:
|
|
||||||
pass
|
if user.get("role") != required_role:
|
||||||
|
return (
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"error": f"Access denied. {required_role.title()} role required"
|
||||||
|
}
|
||||||
|
),
|
||||||
|
403,
|
||||||
|
)
|
||||||
|
|
||||||
# Try API token authentication
|
|
||||||
user = get_user_from_api_token()
|
|
||||||
if user:
|
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return jsonify({"error": "Authentication required (JWT or API token)"}), 401
|
return wrapper
|
||||||
return wrapper
|
|
||||||
|
return decorator
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
"""Service for handling greeting-related business logic."""
|
|
||||||
|
|
||||||
|
|
||||||
class GreetingService:
|
|
||||||
"""Service for greeting operations."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_greeting(name: str | None = None) -> dict[str, str]:
|
|
||||||
"""Get a greeting message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Optional name to personalize the greeting
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing the greeting message
|
|
||||||
"""
|
|
||||||
if name:
|
|
||||||
message = f"Hello, {name}!"
|
|
||||||
else:
|
|
||||||
message = "Hello from backend!"
|
|
||||||
|
|
||||||
return {"message": message}
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
"""Tests for GreetingService."""
|
|
||||||
|
|
||||||
from app.services.greeting_service import GreetingService
|
|
||||||
|
|
||||||
|
|
||||||
class TestGreetingService:
|
|
||||||
"""Test cases for GreetingService."""
|
|
||||||
|
|
||||||
def test_get_greeting_without_name(self) -> None:
|
|
||||||
"""Test getting greeting without providing a name."""
|
|
||||||
result = GreetingService.get_greeting()
|
|
||||||
assert result == {"message": "Hello from backend!"}
|
|
||||||
|
|
||||||
def test_get_greeting_with_name(self) -> None:
|
|
||||||
"""Test getting greeting with a name."""
|
|
||||||
result = GreetingService.get_greeting("Alice")
|
|
||||||
assert result == {"message": "Hello, Alice!"}
|
|
||||||
|
|
||||||
def test_get_greeting_with_empty_string(self) -> None:
|
|
||||||
"""Test getting greeting with empty string name."""
|
|
||||||
result = GreetingService.get_greeting("")
|
|
||||||
assert result == {"message": "Hello from backend!"}
|
|
||||||
@@ -21,19 +21,7 @@ class TestMainRoutes:
|
|||||||
"""Test the index route."""
|
"""Test the index route."""
|
||||||
response = client.get("/api/")
|
response = client.get("/api/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.get_json() == {"message": "Hello from backend!"}
|
assert response.get_json() == {"message": "API is running", "status": "ok"}
|
||||||
|
|
||||||
def test_hello_route_without_name(self, client) -> None:
|
|
||||||
"""Test hello route without name parameter."""
|
|
||||||
response = client.get("/api/hello")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.get_json() == {"message": "Hello from backend!"}
|
|
||||||
|
|
||||||
def test_hello_route_with_name(self, client) -> None:
|
|
||||||
"""Test hello route with name parameter."""
|
|
||||||
response = client.get("/api/hello/Alice")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.get_json() == {"message": "Hello, Alice!"}
|
|
||||||
|
|
||||||
def test_health_route(self, client) -> None:
|
def test_health_route(self, client) -> None:
|
||||||
"""Test health check route."""
|
"""Test health check route."""
|
||||||
@@ -46,4 +34,4 @@ class TestMainRoutes:
|
|||||||
response = client.get("/api/protected")
|
response = client.get("/api/protected")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["error"] == "Authentication required"
|
assert data["error"] == "Authentication required (JWT or API token)"
|
||||||
Reference in New Issue
Block a user