diff --git a/app/routes/auth.py b/app/routes/auth.py index fedaedd..3dd42eb 100644 --- a/app/routes/auth.py +++ b/app/routes/auth.py @@ -1,10 +1,14 @@ """Authentication routes.""" from flask import Blueprint, jsonify, url_for -from flask_jwt_extended import create_access_token, get_jwt_identity, jwt_required +from flask_jwt_extended import ( + create_access_token, + get_jwt_identity, + jwt_required, +) from app import auth_service -from app.services.decorators import get_current_user +from app.services.decorators import get_current_user, require_auth bp = Blueprint("auth", __name__) @@ -19,26 +23,28 @@ def login_oauth(provider): @bp.route("/callback/") def callback(provider): """Handle OAuth callback from specified provider.""" - from flask import redirect, make_response - + from flask import make_response, redirect + try: auth_response = auth_service.handle_callback(provider) - + # If successful, redirect to frontend dashboard with cookies if auth_response.status_code == 200: - redirect_response = make_response(redirect("http://localhost:3000/dashboard")) - + redirect_response = make_response( + redirect("http://localhost:3000/dashboard") + ) + # Copy all cookies from the auth response - for cookie in auth_response.headers.getlist('Set-Cookie'): - redirect_response.headers.add('Set-Cookie', cookie) - + for cookie in auth_response.headers.getlist("Set-Cookie"): + redirect_response.headers.add("Set-Cookie", cookie) + return redirect_response else: # If there was an error, redirect to login with error return redirect("http://localhost:3000/login?error=oauth_failed") - + except Exception as e: - error_msg = str(e).replace(' ', '_').replace('"', '') + error_msg = str(e).replace(" ", "_").replace('"', "") return redirect(f"http://localhost:3000/login?error={error_msg}") @@ -48,21 +54,48 @@ def providers(): return {"providers": auth_service.get_available_providers()} +@bp.route("/register", methods=["POST"]) +def register(): + """Register new user with email and password.""" + from flask import request + + data = request.get_json() + if not data: + return {"error": "No data provided"}, 400 + + email = data.get("email") + password = data.get("password") + name = data.get("name") + + if not email or not password or not name: + return {"error": "Email, password, and name are required"}, 400 + + # Basic email validation + if "@" not in email or "." not in email: + return {"error": "Invalid email format"}, 400 + + # Basic password validation + if len(password) < 6: + return {"error": "Password must be at least 6 characters long"}, 400 + + return auth_service.register_with_password(email, password, name) + + @bp.route("/login", methods=["POST"]) def login(): """Login user with email and password.""" from flask import request - + data = request.get_json() if not data: return {"error": "No data provided"}, 400 - + email = data.get("email") password = data.get("password") - + if not email or not password: return {"error": "Email and password are required"}, 400 - + return auth_service.login_with_password(email, password) @@ -72,29 +105,22 @@ def logout(): return auth_service.logout() -@bp.route("/me") -@jwt_required() -def me(): - """Get current user information.""" - user = get_current_user() - return {"user": user} - - @bp.route("/refresh", methods=["POST"]) @jwt_required(refresh=True) def refresh(): """Refresh access token using refresh token.""" current_user_id = get_jwt_identity() - + # Create new access token new_access_token = create_access_token(identity=current_user_id) - + response = jsonify({"message": "Token refreshed"}) - + # Set new access token cookie from flask_jwt_extended import set_access_cookies + set_access_cookies(response, new_access_token) - + return response @@ -102,7 +128,9 @@ def refresh(): @jwt_required() def link_provider(provider): """Link a new OAuth provider to current user account.""" - redirect_uri = url_for("auth.link_callback", provider=provider, _external=True) + redirect_uri = url_for( + "auth.link_callback", provider=provider, _external=True + ) return auth_service.redirect_to_login(provider, redirect_uri) @@ -114,40 +142,47 @@ def link_callback(provider): current_user_id = get_jwt_identity() if not current_user_id: return {"error": "User not authenticated"}, 401 - + # Get current user from database from app.models.user import User + user = User.query.get(current_user_id) if not user: return {"error": "User not found"}, 404 - + # Process OAuth callback but link to existing user - from app.services.oauth_providers.registry import OAuthProviderRegistry from authlib.integrations.flask_client import OAuth - + + from app.services.oauth_providers.registry import OAuthProviderRegistry + oauth = OAuth() registry = OAuthProviderRegistry(oauth) oauth_provider = registry.get_provider(provider) - + if not oauth_provider: return {"error": f"OAuth provider '{provider}' not configured"}, 400 - + token = oauth_provider.exchange_code_for_token(None, None) raw_user_info = oauth_provider.get_user_info(token) provider_data = oauth_provider.normalize_user_data(raw_user_info) - + if not provider_data.get("id"): - return {"error": "Failed to get user information from provider"}, 400 - + return { + "error": "Failed to get user information from provider" + }, 400 + # Check if this provider is already linked to another user from app.models.user_oauth import UserOAuth + existing_provider = UserOAuth.find_by_provider_and_id( provider, provider_data["id"] ) - + if existing_provider and existing_provider.user_id != user.id: - return {"error": "This provider account is already linked to another user"}, 409 - + return { + "error": "This provider account is already linked to another user" + }, 409 + # Link the provider to current user UserOAuth.create_or_update( user_id=user.id, @@ -155,11 +190,11 @@ def link_callback(provider): provider_id=provider_data["id"], email=provider_data["email"], name=provider_data["name"], - picture=provider_data.get("picture") + picture=provider_data.get("picture"), ) - + return {"message": f"{provider.title()} account linked successfully"} - + except Exception as e: return {"error": str(e)}, 400 @@ -172,60 +207,35 @@ def unlink_provider(provider): current_user_id = get_jwt_identity() if not current_user_id: return {"error": "User not authenticated"}, 401 - + + from app.database import db from app.models.user import User from app.models.user_oauth import UserOAuth - from app.database import db - + user = User.query.get(current_user_id) if not user: return {"error": "User not found"}, 404 - + # Check if user has more than one provider (prevent locking out) if len(user.oauth_providers) <= 1: return {"error": "Cannot unlink last authentication provider"}, 400 - + # Find and remove the provider oauth_provider = user.get_provider(provider) if not oauth_provider: - return {"error": f"Provider '{provider}' not linked to this account"}, 404 - + return { + "error": f"Provider '{provider}' not linked to this account" + }, 404 + db.session.delete(oauth_provider) db.session.commit() - + return {"message": f"{provider.title()} account unlinked successfully"} - + except Exception as e: return {"error": str(e)}, 400 -@bp.route("/register", methods=["POST"]) -def register(): - """Register new user with email and password.""" - from flask import request - - data = request.get_json() - if not data: - return {"error": "No data provided"}, 400 - - email = data.get("email") - password = data.get("password") - name = data.get("name") - - if not email or not password or not name: - return {"error": "Email, password, and name are required"}, 400 - - # Basic email validation - if "@" not in email or "." not in email: - return {"error": "Invalid email format"}, 400 - - # Basic password validation - if len(password) < 6: - return {"error": "Password must be at least 6 characters long"}, 400 - - return auth_service.register_with_password(email, password, name) - - @bp.route("/regenerate-api-token", methods=["POST"]) @jwt_required() def regenerate_api_token(): @@ -233,22 +243,32 @@ def regenerate_api_token(): current_user_id = get_jwt_identity() if not current_user_id: return {"error": "User not authenticated"}, 401 - - from app.models.user import User + from app.database import db - + from app.models.user import User + user = User.query.get(current_user_id) if not user: return {"error": "User not found"}, 404 - + # Generate new API token new_token = user.generate_api_token() db.session.commit() - + return { "message": "API token regenerated successfully", "api_token": new_token, - "expires_at": user.api_token_expires_at.isoformat() if user.api_token_expires_at else None + "expires_at": ( + user.api_token_expires_at.isoformat() + if user.api_token_expires_at + else None + ), } +@bp.route("/me") +@require_auth +def me(): + """Get current user information.""" + user = get_current_user() + return {"user": user} diff --git a/app/routes/main.py b/app/routes/main.py index 8e4f747..9e30b17 100644 --- a/app/routes/main.py +++ b/app/routes/main.py @@ -2,64 +2,53 @@ from flask import Blueprint -from app.services.decorators import get_current_user, require_auth, require_admin, require_auth_or_api_token, get_user_from_api_token -from app.services.greeting_service import GreetingService +from app.services.decorators import get_current_user, require_auth, require_role bp = Blueprint("main", __name__) @bp.route("/") def index() -> dict[str, str]: - """Root endpoint that returns a greeting.""" - return GreetingService.get_greeting() - - -@bp.route("/hello") -@bp.route("/hello/") -def hello(name: str | None = None) -> dict[str, str]: - """Hello endpoint with optional name parameter.""" - return GreetingService.get_greeting(name) + """Root endpoint that returns API status.""" + return {"message": "API is running", "status": "ok"} @bp.route("/protected") @require_auth def protected() -> dict[str, str]: - """Protected endpoint that requires JWT authentication.""" + """Protected endpoint that requires authentication.""" user = get_current_user() return { "message": f"Hello {user['name']}, this is a protected endpoint!", - "user": user + "user": user, } @bp.route("/api-protected") -@require_auth_or_api_token +@require_auth def api_protected() -> dict[str, str]: """Protected endpoint that accepts JWT or API token authentication.""" - # Try to get user from JWT first, then API token user = get_current_user() - if not user: - user = get_user_from_api_token() - return { "message": f"Hello {user['name']}, you accessed this via {user['provider']}!", - "user": user + "user": user, } @bp.route("/admin") -@require_admin +@require_auth +@require_role("admin") def admin_only() -> dict[str, str]: """Admin-only endpoint to demonstrate role-based access.""" user = get_current_user() return { "message": f"Hello admin {user['name']}, you have admin access!", "user": user, - "admin_info": "This endpoint is only accessible to admin users" + "admin_info": "This endpoint is only accessible to admin users", } @bp.route("/health") def health() -> dict[str, str]: """Health check endpoint.""" - return {"status": "ok"} \ No newline at end of file + return {"status": "ok"} diff --git a/app/services/decorators.py b/app/services/decorators.py index 7261594..1f1ad21 100644 --- a/app/services/decorators.py +++ b/app/services/decorators.py @@ -4,28 +4,25 @@ from functools import wraps from typing import Any from flask import jsonify, request -from flask_jwt_extended import get_jwt, get_jwt_identity, jwt_required +from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request -def require_auth(f): - """Decorator to require authentication for routes.""" - return jwt_required()(f) - - -def get_current_user() -> dict[str, Any] | None: +def get_user_from_jwt() -> dict[str, Any] | None: """Helper function to get current user from JWT token.""" try: + # Try to verify JWT token in request - this sets up the context + verify_jwt_in_request() + current_user_id = get_jwt_identity() if not current_user_id: return None - + claims = get_jwt() is_active = claims.get("is_active", True) - - # Check if user is active + if not is_active: return None - + return { "id": current_user_id, "email": claims.get("email", ""), @@ -40,63 +37,21 @@ def get_current_user() -> dict[str, Any] | None: return None -def require_role(required_role: str): - """Decorator to require specific role for routes.""" - def decorator(f): - @wraps(f) - @jwt_required() - def wrapper(*args, **kwargs): - user = get_current_user() - if not user: - return jsonify({"error": "Authentication required"}), 401 - - if user.get("role") != required_role: - return jsonify({"error": f"Access denied. {required_role.title()} role required"}), 403 - - return f(*args, **kwargs) - return wrapper - return decorator - - -def require_admin(f): - """Decorator to require admin role for routes.""" - return require_role("admin")(f) - - -def require_user_or_admin(f): - """Decorator to require user or admin role for routes.""" - @wraps(f) - @jwt_required() - def wrapper(*args, **kwargs): - user = get_current_user() - if not user: - return jsonify({"error": "Authentication required"}), 401 - - if user.get("role") not in ["user", "admin"]: - return jsonify({"error": "Access denied"}), 403 - - return f(*args, **kwargs) - return wrapper - - def get_user_from_api_token() -> dict[str, Any] | None: """Get user from API token in request headers.""" try: - # Check for API token in Authorization header auth_header = request.headers.get("Authorization") if not auth_header: return None - - # Expected format: "Bearer " or "Token " + parts = auth_header.split() if len(parts) != 2 or parts[0].lower() not in ["bearer", "token"]: return None - + api_token = parts[1] - - # Import here to avoid circular imports + from app.models.user import User - + user = User.find_by_api_token(api_token) if user and user.is_active: return { @@ -107,42 +62,67 @@ def get_user_from_api_token() -> dict[str, Any] | None: "role": user.role, "is_active": user.is_active, "provider": "api_token", - "providers": [p.provider for p in user.oauth_providers] + ["api_token"], + "providers": [p.provider for p in user.oauth_providers] + + ["api_token"], } - + return None except Exception: return None -def require_api_token(f): - """Decorator to require API token authentication for routes.""" +def get_current_user() -> dict[str, Any] | None: + """Get current user from either JWT or API token.""" + # Try JWT first + user = get_user_from_jwt() + if user: + return user + + # Try API token + return get_user_from_api_token() + + +def require_auth(f): + """Decorator to require authentication (JWT or API token) for routes.""" + @wraps(f) def wrapper(*args, **kwargs): - user = get_user_from_api_token() + user = get_current_user() if not user: - return jsonify({"error": "Valid API token required"}), 401 - + return ( + jsonify( + {"error": "Authentication required (JWT or API token)"} + ), + 401, + ) + return f(*args, **kwargs) + return wrapper -def require_auth_or_api_token(f): - """Decorator to accept either JWT or API token authentication.""" - @wraps(f) - def wrapper(*args, **kwargs): - # Try JWT authentication first - try: +def require_role(required_role: str): + """Decorator to require specific role for routes.""" + + def decorator(f): + @wraps(f) + def wrapper(*args, **kwargs): user = get_current_user() - if user: - return f(*args, **kwargs) - except Exception: - pass - - # Try API token authentication - user = get_user_from_api_token() - if user: + if not user: + return jsonify({"error": "Authentication required"}), 401 + + if user.get("role") != required_role: + return ( + jsonify( + { + "error": f"Access denied. {required_role.title()} role required" + } + ), + 403, + ) + return f(*args, **kwargs) - - return jsonify({"error": "Authentication required (JWT or API token)"}), 401 - return wrapper \ No newline at end of file + + return wrapper + + return decorator diff --git a/app/services/greeting_service.py b/app/services/greeting_service.py deleted file mode 100644 index 35c842e..0000000 --- a/app/services/greeting_service.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Service for handling greeting-related business logic.""" - - -class GreetingService: - """Service for greeting operations.""" - - @staticmethod - def get_greeting(name: str | None = None) -> dict[str, str]: - """Get a greeting message. - - Args: - name: Optional name to personalize the greeting - - Returns: - Dictionary containing the greeting message - """ - if name: - message = f"Hello, {name}!" - else: - message = "Hello from backend!" - - return {"message": message} \ No newline at end of file diff --git a/tests/test_greeting_service.py b/tests/test_greeting_service.py deleted file mode 100644 index e019582..0000000 --- a/tests/test_greeting_service.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Tests for GreetingService.""" - -from app.services.greeting_service import GreetingService - - -class TestGreetingService: - """Test cases for GreetingService.""" - - def test_get_greeting_without_name(self) -> None: - """Test getting greeting without providing a name.""" - result = GreetingService.get_greeting() - assert result == {"message": "Hello from backend!"} - - def test_get_greeting_with_name(self) -> None: - """Test getting greeting with a name.""" - result = GreetingService.get_greeting("Alice") - assert result == {"message": "Hello, Alice!"} - - def test_get_greeting_with_empty_string(self) -> None: - """Test getting greeting with empty string name.""" - result = GreetingService.get_greeting("") - assert result == {"message": "Hello from backend!"} \ No newline at end of file diff --git a/tests/test_routes.py b/tests/test_routes.py index 63bcc48..b2363bb 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -21,19 +21,7 @@ class TestMainRoutes: """Test the index route.""" response = client.get("/api/") assert response.status_code == 200 - assert response.get_json() == {"message": "Hello from backend!"} - - def test_hello_route_without_name(self, client) -> None: - """Test hello route without name parameter.""" - response = client.get("/api/hello") - assert response.status_code == 200 - assert response.get_json() == {"message": "Hello from backend!"} - - def test_hello_route_with_name(self, client) -> None: - """Test hello route with name parameter.""" - response = client.get("/api/hello/Alice") - assert response.status_code == 200 - assert response.get_json() == {"message": "Hello, Alice!"} + assert response.get_json() == {"message": "API is running", "status": "ok"} def test_health_route(self, client) -> None: """Test health check route.""" @@ -46,4 +34,4 @@ class TestMainRoutes: response = client.get("/api/protected") assert response.status_code == 401 data = response.get_json() - assert data["error"] == "Authentication required" \ No newline at end of file + assert data["error"] == "Authentication required (JWT or API token)" \ No newline at end of file