feat(auth): implement user plans and credits system with related endpoints
This commit is contained in:
@@ -81,17 +81,9 @@ class AuthService:
|
||||
response.status_code = 401
|
||||
return response
|
||||
|
||||
# Prepare user data for JWT token
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": oauth_provider.provider,
|
||||
"providers": [p.provider for p in user.oauth_providers],
|
||||
}
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -138,6 +130,8 @@ class AuthService:
|
||||
claims = get_jwt()
|
||||
|
||||
if current_user_id:
|
||||
# Get plan information from JWT claims
|
||||
plan_data = claims.get("plan")
|
||||
return {
|
||||
"id": current_user_id,
|
||||
"email": claims.get("email", ""),
|
||||
@@ -147,6 +141,8 @@ class AuthService:
|
||||
"is_active": claims.get("is_active", True),
|
||||
"provider": claims.get("provider", "unknown"),
|
||||
"providers": claims.get("providers", []),
|
||||
"plan": plan_data,
|
||||
"credits": claims.get("credits"),
|
||||
}
|
||||
return None
|
||||
|
||||
@@ -158,17 +154,9 @@ class AuthService:
|
||||
# Create user with password
|
||||
user = User.create_with_password(email, password, name)
|
||||
|
||||
# Prepare user data for JWT token
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": "password",
|
||||
"providers": ["password"],
|
||||
}
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password registration
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -209,21 +197,9 @@ class AuthService:
|
||||
response.status_code = 401
|
||||
return response
|
||||
|
||||
# Prepare user data for JWT token
|
||||
oauth_providers = [p.provider for p in user.oauth_providers]
|
||||
if user.has_password():
|
||||
oauth_providers.append("password")
|
||||
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": "password",
|
||||
"providers": oauth_providers,
|
||||
}
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password login
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||
|
||||
@@ -32,6 +32,8 @@ def get_user_from_jwt() -> dict[str, Any] | None:
|
||||
"is_active": is_active,
|
||||
"provider": claims.get("provider", "unknown"),
|
||||
"providers": claims.get("providers", []),
|
||||
"plan": claims.get("plan"),
|
||||
"credits": claims.get("credits"),
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
@@ -64,6 +66,8 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
||||
"provider": "api_token",
|
||||
"providers": [p.provider for p in user.oauth_providers]
|
||||
+ ["api_token"],
|
||||
"plan": user.plan.to_dict() if user.plan else None,
|
||||
"credits": user.credits,
|
||||
}
|
||||
|
||||
return None
|
||||
@@ -126,3 +130,45 @@ def require_role(required_role: str):
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_credits(credits_needed: int):
|
||||
"""Decorator to require and deduct credits for routes."""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
from app.models.user import User
|
||||
from app.database import db
|
||||
|
||||
# First check authentication
|
||||
user_data = get_current_user()
|
||||
if not user_data:
|
||||
return jsonify({"error": "Authentication required"}), 401
|
||||
|
||||
# Get the actual user from database to check/update credits
|
||||
user = User.query.get(int(user_data["id"]))
|
||||
if not user or not user.is_active:
|
||||
return jsonify({"error": "User not found or inactive"}), 401
|
||||
|
||||
# Check if user has enough credits
|
||||
if user.credits < credits_needed:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}"
|
||||
}
|
||||
),
|
||||
402, # Payment Required status code
|
||||
)
|
||||
|
||||
# Deduct credits
|
||||
user.credits -= credits_needed
|
||||
db.session.commit()
|
||||
|
||||
# Execute the function
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
@@ -20,6 +20,8 @@ class TokenService:
|
||||
"is_active": user_data.get("is_active"),
|
||||
"provider": user_data.get("provider"),
|
||||
"providers": user_data.get("providers", []),
|
||||
"plan": user_data.get("plan"),
|
||||
"credits": user_data.get("credits"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user