feat(auth): implement user plans and credits system with related endpoints

This commit is contained in:
JSC
2025-06-29 16:40:54 +02:00
parent 52c60db811
commit 91648a858e
10 changed files with 334 additions and 63 deletions

View File

@@ -2,11 +2,11 @@ import os
from datetime import timedelta from datetime import timedelta
from flask import Flask from flask import Flask
from flask_jwt_extended import JWTManager
from flask_cors import CORS from flask_cors import CORS
from flask_jwt_extended import JWTManager
from app.services.auth_service import AuthService
from app.database import init_db from app.database import init_db
from app.services.auth_service import AuthService
# Global auth service instance # Global auth service instance
auth_service = AuthService() auth_service = AuthService()
@@ -25,7 +25,9 @@ def create_app():
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
# Configure Flask-JWT-Extended # Configure Flask-JWT-Extended
app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY", "jwt-secret-key") app.config["JWT_SECRET_KEY"] = os.environ.get(
"JWT_SECRET_KEY", "jwt-secret-key"
)
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15) app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7) app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
app.config["JWT_TOKEN_LOCATION"] = ["cookies"] app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
@@ -35,11 +37,13 @@ def create_app():
app.config["JWT_REFRESH_COOKIE_PATH"] = "/api/auth/refresh" app.config["JWT_REFRESH_COOKIE_PATH"] = "/api/auth/refresh"
# Initialize CORS # Initialize CORS
CORS(app, CORS(
origins=["http://localhost:3000"], # Frontend URL app,
supports_credentials=True, # Allow cookies origins=["http://localhost:3000"], # Frontend URL
allow_headers=["Content-Type", "Authorization"], supports_credentials=True, # Allow cookies
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"]) allow_headers=["Content-Type", "Authorization"],
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
)
# Initialize JWT manager # Initialize JWT manager
jwt = JWTManager(app) jwt = JWTManager(app)
@@ -47,11 +51,18 @@ def create_app():
# Initialize database # Initialize database
init_db(app) init_db(app)
# Initialize database tables and seed data
with app.app_context():
from app.database_init import init_database
init_database()
# Initialize authentication service with app # Initialize authentication service with app
auth_service.init_app(app) auth_service.init_app(app)
# Register blueprints # Register blueprints
from app.routes import main, auth from app.routes import auth, main
app.register_blueprint(main.bp, url_prefix="/api") app.register_blueprint(main.bp, url_prefix="/api")
app.register_blueprint(auth.bp, url_prefix="/api/auth") app.register_blueprint(auth.bp, url_prefix="/api/auth")

120
app/database_init.py Normal file
View File

@@ -0,0 +1,120 @@
"""Database initialization and seeding functions."""
from app.database import db
from app.models.plan import Plan
def init_database():
"""Initialize database tables and seed with default data."""
# Create all tables
db.create_all()
# Seed plans if they don't exist
seed_plans()
# Migrate existing users to have plans
migrate_users_to_plans()
def seed_plans():
"""Seed the plans table with default plans if empty."""
# Check if plans already exist
if Plan.query.count() > 0:
return
# Create default plans
plans_data = [
{
"code": "free",
"name": "Free Plan",
"description": "Basic features with limited usage",
"credits": 25,
"max_credits": 75,
},
{
"code": "premium",
"name": "Premium Plan",
"description": "Enhanced features with increased usage limits",
"credits": 50,
"max_credits": 150,
},
{
"code": "pro",
"name": "Pro Plan",
"description": "Full access with unlimited usage",
"credits": 100,
"max_credits": 300,
},
]
for plan_data in plans_data:
plan = Plan(**plan_data)
db.session.add(plan)
db.session.commit()
print(f"Seeded {len(plans_data)} plans into database")
def migrate_users_to_plans():
"""Assign plans to existing users who don't have one."""
from app.models.user import User
try:
# Find users without plans
users_without_plans = User.query.filter(User.plan_id.is_(None)).all()
# Find users with plans but NULL credits (only if credits column exists)
# Note: We only migrate users with NULL credits, not 0 credits
# 0 credits means they spent them, NULL means they never got assigned
try:
users_without_credits = User.query.filter(
User.plan_id.isnot(None), User.credits.is_(None)
).all()
except Exception:
# Credits column doesn't exist yet, will be handled by create_all
users_without_credits = []
if not users_without_plans and not users_without_credits:
return
# Get default and pro plans
default_plan = Plan.get_default_plan()
pro_plan = Plan.get_pro_plan()
# Get the first user (admin) from all users ordered by ID
first_user = User.query.order_by(User.id).first()
updated_count = 0
# Assign plans to users without plans
for user in users_without_plans:
# First user gets pro plan, others get free plan
if user.id == first_user.id:
user.plan_id = pro_plan.id
# Only set credits if the column exists
try:
user.credits = pro_plan.credits
except Exception:
pass
else:
user.plan_id = default_plan.id
# Only set credits if the column exists
try:
user.credits = default_plan.credits
except Exception:
pass
updated_count += 1
# Assign credits to users with plans but no credits
for user in users_without_credits:
user.credits = user.plan.credits
updated_count += 1
if updated_count > 0:
db.session.commit()
print(f"Updated {updated_count} existing users with plans and credits")
except Exception:
# If there's any error (like missing columns), just skip migration
# The database will be properly created by create_all()
pass

View File

@@ -1,6 +1,7 @@
"""Database models.""" """Database models."""
from .plan import Plan
from .user import User from .user import User
from .user_oauth import UserOAuth from .user_oauth import UserOAuth
__all__ = ["User", "UserOAuth"] __all__ = ["Plan", "User", "UserOAuth"]

58
app/models/plan.py Normal file
View File

@@ -0,0 +1,58 @@
"""Plan model for user subscription plans."""
from sqlalchemy import Column, Integer, String, Text
from sqlalchemy.orm import relationship
from app.database import db
class Plan(db.Model):
"""Plan model for user subscription plans."""
__tablename__ = "plans"
id = Column(Integer, primary_key=True)
code = Column(String(50), unique=True, nullable=False, index=True)
name = Column(String(100), nullable=False)
description = Column(Text)
credits = Column(Integer, default=0, nullable=False)
max_credits = Column(Integer, default=0, nullable=False)
# Relationship with users
users = relationship("User", back_populates="plan", lazy="dynamic")
def __repr__(self) -> str:
"""String representation of Plan."""
return f"<Plan {self.code}: {self.name}>"
@classmethod
def find_by_code(cls, code: str) -> "Plan | None":
"""Find plan by code."""
return cls.query.filter_by(code=code).first()
@classmethod
def get_default_plan(cls) -> "Plan":
"""Get the default plan (free)."""
plan = cls.find_by_code("free")
if not plan:
raise ValueError("Default 'free' plan not found in database")
return plan
@classmethod
def get_pro_plan(cls) -> "Plan":
"""Get the pro plan."""
plan = cls.find_by_code("pro")
if not plan:
raise ValueError("'pro' plan not found in database")
return plan
def to_dict(self) -> dict:
"""Convert plan to dictionary."""
return {
"id": self.id,
"code": self.code,
"name": self.name,
"description": self.description,
"credits": self.credits,
"max_credits": self.max_credits,
}

View File

@@ -5,13 +5,14 @@ from datetime import datetime
from typing import Optional, TYPE_CHECKING from typing import Optional, TYPE_CHECKING
from werkzeug.security import check_password_hash, generate_password_hash from werkzeug.security import check_password_hash, generate_password_hash
from sqlalchemy import String, DateTime from sqlalchemy import String, DateTime, Integer, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import db from app.database import db
if TYPE_CHECKING: if TYPE_CHECKING:
from app.models.user_oauth import UserOAuth from app.models.user_oauth import UserOAuth
from app.models.plan import Plan
class User(db.Model): class User(db.Model):
@@ -35,6 +36,12 @@ class User(db.Model):
# User status # User status
is_active: Mapped[bool] = mapped_column(nullable=False, default=True) is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
# Plan relationship
plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False)
# User credits (populated from plan credits on creation)
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
# API token for programmatic access # API token for programmatic access
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
@@ -51,6 +58,7 @@ class User(db.Model):
oauth_providers: Mapped[list["UserOAuth"]] = relationship( oauth_providers: Mapped[list["UserOAuth"]] = relationship(
"UserOAuth", back_populates="user", cascade="all, delete-orphan" "UserOAuth", back_populates="user", cascade="all, delete-orphan"
) )
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
def __repr__(self) -> str: def __repr__(self) -> str:
"""String representation of User.""" """String representation of User."""
@@ -69,6 +77,8 @@ class User(db.Model):
"api_token": self.api_token, "api_token": self.api_token,
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None, "api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
"providers": [provider.provider for provider in self.oauth_providers], "providers": [provider.provider for provider in self.oauth_providers],
"plan": self.plan.to_dict() if self.plan else None,
"credits": self.credits,
"created_at": self.created_at.isoformat(), "created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(), "updated_at": self.updated_at.isoformat(),
} }
@@ -159,6 +169,7 @@ class User(db.Model):
) -> tuple["User", "UserOAuth"]: ) -> tuple["User", "UserOAuth"]:
"""Find existing user or create new one from OAuth data.""" """Find existing user or create new one from OAuth data."""
from app.models.user_oauth import UserOAuth from app.models.user_oauth import UserOAuth
from app.models.plan import Plan
# First, try to find existing OAuth provider # First, try to find existing OAuth provider
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id) oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
@@ -178,16 +189,24 @@ class User(db.Model):
user = cls.find_by_email(email) user = cls.find_by_email(email)
if not user: if not user:
# Check if this is the first user (admin) # Check if this is the first user (admin with pro plan)
user_count = cls.query.count() user_count = cls.query.count()
role = "admin" if user_count == 0 else "user" role = "admin" if user_count == 0 else "user"
# Assign plan: first user gets pro, others get free
if user_count == 0:
plan = Plan.get_pro_plan()
else:
plan = Plan.get_default_plan()
# Create new user # Create new user
user = cls( user = cls(
email=email, email=email,
name=name, name=name,
picture=picture, picture=picture,
role=role, role=role,
plan_id=plan.id,
credits=plan.credits, # Set credits from plan
) )
user.generate_api_token() # Generate API token on creation user.generate_api_token() # Generate API token on creation
db.session.add(user) db.session.add(user)
@@ -209,20 +228,30 @@ class User(db.Model):
@classmethod @classmethod
def create_with_password(cls, email: str, password: str, name: str) -> "User": def create_with_password(cls, email: str, password: str, name: str) -> "User":
"""Create new user with email and password.""" """Create new user with email and password."""
from app.models.plan import Plan
# Check if user already exists # Check if user already exists
existing_user = cls.find_by_email(email) existing_user = cls.find_by_email(email)
if existing_user: if existing_user:
raise ValueError("User with this email already exists") raise ValueError("User with this email already exists")
# Check if this is the first user (admin) # Check if this is the first user (admin with pro plan)
user_count = cls.query.count() user_count = cls.query.count()
role = "admin" if user_count == 0 else "user" role = "admin" if user_count == 0 else "user"
# Assign plan: first user gets pro, others get free
if user_count == 0:
plan = Plan.get_pro_plan()
else:
plan = Plan.get_default_plan()
# Create new user # Create new user
user = cls( user = cls(
email=email, email=email,
name=name, name=name,
role=role, role=role,
plan_id=plan.id,
credits=plan.credits, # Set credits from plan
) )
user.set_password(password) user.set_password(password)
user.generate_api_token() # Generate API token on creation user.generate_api_token() # Generate API token on creation

View File

@@ -2,7 +2,7 @@
from flask import Blueprint from flask import Blueprint
from app.services.decorators import get_current_user, require_auth, require_role from app.services.decorators import get_current_user, require_auth, require_role, require_credits
bp = Blueprint("main", __name__) bp = Blueprint("main", __name__)
@@ -52,3 +52,29 @@ def admin_only() -> dict[str, str]:
def health() -> dict[str, str]: def health() -> dict[str, str]:
"""Health check endpoint.""" """Health check endpoint."""
return {"status": "ok"} return {"status": "ok"}
@bp.route("/use-credits/<int:amount>")
@require_auth
@require_credits(5)
def use_credits(amount: int) -> dict[str, str]:
"""Test endpoint that costs 5 credits to use."""
user = get_current_user()
return {
"message": f"Successfully used endpoint! You requested amount: {amount}",
"user": user["email"],
"remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator
}
@bp.route("/expensive-operation")
@require_auth
@require_credits(10)
def expensive_operation() -> dict[str, str]:
"""Test endpoint that costs 10 credits to use."""
user = get_current_user()
return {
"message": "Expensive operation completed successfully!",
"user": user["email"],
"operation_cost": 10,
}

View File

@@ -81,17 +81,9 @@ class AuthService:
response.status_code = 401 response.status_code = 401
return response return response
# Prepare user data for JWT token # Prepare user data for JWT token using user.to_dict()
jwt_user_data = { jwt_user_data = user.to_dict()
"id": str(user.id), jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
"email": user.email,
"name": user.name,
"picture": user.picture,
"role": user.role,
"is_active": user.is_active,
"provider": oauth_provider.provider,
"providers": [p.provider for p in user.oauth_providers],
}
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token( access_token = self.token_service.generate_access_token(
@@ -138,6 +130,8 @@ class AuthService:
claims = get_jwt() claims = get_jwt()
if current_user_id: if current_user_id:
# Get plan information from JWT claims
plan_data = claims.get("plan")
return { return {
"id": current_user_id, "id": current_user_id,
"email": claims.get("email", ""), "email": claims.get("email", ""),
@@ -147,6 +141,8 @@ class AuthService:
"is_active": claims.get("is_active", True), "is_active": claims.get("is_active", True),
"provider": claims.get("provider", "unknown"), "provider": claims.get("provider", "unknown"),
"providers": claims.get("providers", []), "providers": claims.get("providers", []),
"plan": plan_data,
"credits": claims.get("credits"),
} }
return None return None
@@ -158,17 +154,9 @@ class AuthService:
# Create user with password # Create user with password
user = User.create_with_password(email, password, name) user = User.create_with_password(email, password, name)
# Prepare user data for JWT token # Prepare user data for JWT token using user.to_dict()
jwt_user_data = { jwt_user_data = user.to_dict()
"id": str(user.id), jwt_user_data["provider"] = "password" # Override provider for password registration
"email": user.email,
"name": user.name,
"picture": user.picture,
"role": user.role,
"is_active": user.is_active,
"provider": "password",
"providers": ["password"],
}
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token( access_token = self.token_service.generate_access_token(
@@ -209,21 +197,9 @@ class AuthService:
response.status_code = 401 response.status_code = 401
return response return response
# Prepare user data for JWT token # Prepare user data for JWT token using user.to_dict()
oauth_providers = [p.provider for p in user.oauth_providers] jwt_user_data = user.to_dict()
if user.has_password(): jwt_user_data["provider"] = "password" # Override provider for password login
oauth_providers.append("password")
jwt_user_data = {
"id": str(user.id),
"email": user.email,
"name": user.name,
"picture": user.picture,
"role": user.role,
"is_active": user.is_active,
"provider": "password",
"providers": oauth_providers,
}
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token(jwt_user_data) access_token = self.token_service.generate_access_token(jwt_user_data)

View File

@@ -32,6 +32,8 @@ def get_user_from_jwt() -> dict[str, Any] | None:
"is_active": is_active, "is_active": is_active,
"provider": claims.get("provider", "unknown"), "provider": claims.get("provider", "unknown"),
"providers": claims.get("providers", []), "providers": claims.get("providers", []),
"plan": claims.get("plan"),
"credits": claims.get("credits"),
} }
except Exception: except Exception:
return None return None
@@ -64,6 +66,8 @@ def get_user_from_api_token() -> dict[str, Any] | None:
"provider": "api_token", "provider": "api_token",
"providers": [p.provider for p in user.oauth_providers] "providers": [p.provider for p in user.oauth_providers]
+ ["api_token"], + ["api_token"],
"plan": user.plan.to_dict() if user.plan else None,
"credits": user.credits,
} }
return None return None
@@ -126,3 +130,45 @@ def require_role(required_role: str):
return wrapper return wrapper
return decorator return decorator
def require_credits(credits_needed: int):
"""Decorator to require and deduct credits for routes."""
def decorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
from app.models.user import User
from app.database import db
# First check authentication
user_data = get_current_user()
if not user_data:
return jsonify({"error": "Authentication required"}), 401
# Get the actual user from database to check/update credits
user = User.query.get(int(user_data["id"]))
if not user or not user.is_active:
return jsonify({"error": "User not found or inactive"}), 401
# Check if user has enough credits
if user.credits < credits_needed:
return (
jsonify(
{
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}"
}
),
402, # Payment Required status code
)
# Deduct credits
user.credits -= credits_needed
db.session.commit()
# Execute the function
result = f(*args, **kwargs)
return result
return wrapper
return decorator

View File

@@ -20,6 +20,8 @@ class TokenService:
"is_active": user_data.get("is_active"), "is_active": user_data.get("is_active"),
"provider": user_data.get("provider"), "provider": user_data.get("provider"),
"providers": user_data.get("providers", []), "providers": user_data.get("providers", []),
"plan": user_data.get("plan"),
"credits": user_data.get("credits"),
}, },
) )

View File

@@ -13,7 +13,8 @@ cli = FlaskGroup(app)
def init_db(): def init_db():
"""Initialize the database.""" """Initialize the database."""
print("Initializing database...") print("Initializing database...")
db.create_all() from app.database_init import init_database
init_database()
print("Database initialized successfully!") print("Database initialized successfully!")
@cli.command() @cli.command()
@@ -21,7 +22,8 @@ def reset_db():
"""Reset the database (drop all tables and recreate).""" """Reset the database (drop all tables and recreate)."""
print("Resetting database...") print("Resetting database...")
db.drop_all() db.drop_all()
db.create_all() from app.database_init import init_database
init_database()
print("Database reset successfully!") print("Database reset successfully!")
if __name__ == "__main__": if __name__ == "__main__":