feat(auth): implement user plans and credits system with related endpoints
This commit is contained in:
@@ -2,11 +2,11 @@ import os
|
||||
from datetime import timedelta
|
||||
|
||||
from flask import Flask
|
||||
from flask_jwt_extended import JWTManager
|
||||
from flask_cors import CORS
|
||||
from flask_jwt_extended import JWTManager
|
||||
|
||||
from app.services.auth_service import AuthService
|
||||
from app.database import init_db
|
||||
from app.services.auth_service import AuthService
|
||||
|
||||
# Global auth service instance
|
||||
auth_service = AuthService()
|
||||
@@ -15,17 +15,19 @@ auth_service = AuthService()
|
||||
def create_app():
|
||||
"""Create and configure the Flask application."""
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
# Configure Flask secret key (required for sessions used by OAuth)
|
||||
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "dev-secret-key")
|
||||
|
||||
|
||||
# Configure SQLAlchemy database
|
||||
database_url = os.environ.get("DATABASE_URL", "sqlite:///soundboard.db")
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = database_url
|
||||
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
||||
|
||||
|
||||
# Configure Flask-JWT-Extended
|
||||
app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY", "jwt-secret-key")
|
||||
app.config["JWT_SECRET_KEY"] = os.environ.get(
|
||||
"JWT_SECRET_KEY", "jwt-secret-key"
|
||||
)
|
||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
|
||||
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
|
||||
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
|
||||
@@ -33,26 +35,35 @@ def create_app():
|
||||
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
|
||||
app.config["JWT_ACCESS_COOKIE_PATH"] = "/api/"
|
||||
app.config["JWT_REFRESH_COOKIE_PATH"] = "/api/auth/refresh"
|
||||
|
||||
|
||||
# Initialize CORS
|
||||
CORS(app,
|
||||
origins=["http://localhost:3000"], # Frontend URL
|
||||
supports_credentials=True, # Allow cookies
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
||||
|
||||
CORS(
|
||||
app,
|
||||
origins=["http://localhost:3000"], # Frontend URL
|
||||
supports_credentials=True, # Allow cookies
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
)
|
||||
|
||||
# Initialize JWT manager
|
||||
jwt = JWTManager(app)
|
||||
|
||||
|
||||
# Initialize database
|
||||
init_db(app)
|
||||
|
||||
|
||||
# Initialize database tables and seed data
|
||||
with app.app_context():
|
||||
from app.database_init import init_database
|
||||
|
||||
init_database()
|
||||
|
||||
# Initialize authentication service with app
|
||||
auth_service.init_app(app)
|
||||
|
||||
|
||||
# Register blueprints
|
||||
from app.routes import main, auth
|
||||
from app.routes import auth, main
|
||||
|
||||
app.register_blueprint(main.bp, url_prefix="/api")
|
||||
app.register_blueprint(auth.bp, url_prefix="/api/auth")
|
||||
|
||||
return app
|
||||
|
||||
return app
|
||||
|
||||
120
app/database_init.py
Normal file
120
app/database_init.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Database initialization and seeding functions."""
|
||||
|
||||
from app.database import db
|
||||
from app.models.plan import Plan
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize database tables and seed with default data."""
|
||||
# Create all tables
|
||||
db.create_all()
|
||||
|
||||
# Seed plans if they don't exist
|
||||
seed_plans()
|
||||
|
||||
# Migrate existing users to have plans
|
||||
migrate_users_to_plans()
|
||||
|
||||
|
||||
def seed_plans():
|
||||
"""Seed the plans table with default plans if empty."""
|
||||
# Check if plans already exist
|
||||
if Plan.query.count() > 0:
|
||||
return
|
||||
|
||||
# Create default plans
|
||||
plans_data = [
|
||||
{
|
||||
"code": "free",
|
||||
"name": "Free Plan",
|
||||
"description": "Basic features with limited usage",
|
||||
"credits": 25,
|
||||
"max_credits": 75,
|
||||
},
|
||||
{
|
||||
"code": "premium",
|
||||
"name": "Premium Plan",
|
||||
"description": "Enhanced features with increased usage limits",
|
||||
"credits": 50,
|
||||
"max_credits": 150,
|
||||
},
|
||||
{
|
||||
"code": "pro",
|
||||
"name": "Pro Plan",
|
||||
"description": "Full access with unlimited usage",
|
||||
"credits": 100,
|
||||
"max_credits": 300,
|
||||
},
|
||||
]
|
||||
|
||||
for plan_data in plans_data:
|
||||
plan = Plan(**plan_data)
|
||||
db.session.add(plan)
|
||||
|
||||
db.session.commit()
|
||||
print(f"Seeded {len(plans_data)} plans into database")
|
||||
|
||||
|
||||
def migrate_users_to_plans():
|
||||
"""Assign plans to existing users who don't have one."""
|
||||
from app.models.user import User
|
||||
|
||||
try:
|
||||
# Find users without plans
|
||||
users_without_plans = User.query.filter(User.plan_id.is_(None)).all()
|
||||
|
||||
# Find users with plans but NULL credits (only if credits column exists)
|
||||
# Note: We only migrate users with NULL credits, not 0 credits
|
||||
# 0 credits means they spent them, NULL means they never got assigned
|
||||
try:
|
||||
users_without_credits = User.query.filter(
|
||||
User.plan_id.isnot(None), User.credits.is_(None)
|
||||
).all()
|
||||
except Exception:
|
||||
# Credits column doesn't exist yet, will be handled by create_all
|
||||
users_without_credits = []
|
||||
|
||||
if not users_without_plans and not users_without_credits:
|
||||
return
|
||||
|
||||
# Get default and pro plans
|
||||
default_plan = Plan.get_default_plan()
|
||||
pro_plan = Plan.get_pro_plan()
|
||||
|
||||
# Get the first user (admin) from all users ordered by ID
|
||||
first_user = User.query.order_by(User.id).first()
|
||||
|
||||
updated_count = 0
|
||||
|
||||
# Assign plans to users without plans
|
||||
for user in users_without_plans:
|
||||
# First user gets pro plan, others get free plan
|
||||
if user.id == first_user.id:
|
||||
user.plan_id = pro_plan.id
|
||||
# Only set credits if the column exists
|
||||
try:
|
||||
user.credits = pro_plan.credits
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
user.plan_id = default_plan.id
|
||||
# Only set credits if the column exists
|
||||
try:
|
||||
user.credits = default_plan.credits
|
||||
except Exception:
|
||||
pass
|
||||
updated_count += 1
|
||||
|
||||
# Assign credits to users with plans but no credits
|
||||
for user in users_without_credits:
|
||||
user.credits = user.plan.credits
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0:
|
||||
db.session.commit()
|
||||
print(f"Updated {updated_count} existing users with plans and credits")
|
||||
|
||||
except Exception:
|
||||
# If there's any error (like missing columns), just skip migration
|
||||
# The database will be properly created by create_all()
|
||||
pass
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Database models."""
|
||||
|
||||
from .plan import Plan
|
||||
from .user import User
|
||||
from .user_oauth import UserOAuth
|
||||
|
||||
__all__ = ["User", "UserOAuth"]
|
||||
__all__ = ["Plan", "User", "UserOAuth"]
|
||||
58
app/models/plan.py
Normal file
58
app/models/plan.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Plan model for user subscription plans."""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import db
|
||||
|
||||
|
||||
class Plan(db.Model):
|
||||
"""Plan model for user subscription plans."""
|
||||
|
||||
__tablename__ = "plans"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
code = Column(String(50), unique=True, nullable=False, index=True)
|
||||
name = Column(String(100), nullable=False)
|
||||
description = Column(Text)
|
||||
credits = Column(Integer, default=0, nullable=False)
|
||||
max_credits = Column(Integer, default=0, nullable=False)
|
||||
|
||||
# Relationship with users
|
||||
users = relationship("User", back_populates="plan", lazy="dynamic")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of Plan."""
|
||||
return f"<Plan {self.code}: {self.name}>"
|
||||
|
||||
@classmethod
|
||||
def find_by_code(cls, code: str) -> "Plan | None":
|
||||
"""Find plan by code."""
|
||||
return cls.query.filter_by(code=code).first()
|
||||
|
||||
@classmethod
|
||||
def get_default_plan(cls) -> "Plan":
|
||||
"""Get the default plan (free)."""
|
||||
plan = cls.find_by_code("free")
|
||||
if not plan:
|
||||
raise ValueError("Default 'free' plan not found in database")
|
||||
return plan
|
||||
|
||||
@classmethod
|
||||
def get_pro_plan(cls) -> "Plan":
|
||||
"""Get the pro plan."""
|
||||
plan = cls.find_by_code("pro")
|
||||
if not plan:
|
||||
raise ValueError("'pro' plan not found in database")
|
||||
return plan
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert plan to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"code": self.code,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"credits": self.credits,
|
||||
"max_credits": self.max_credits,
|
||||
}
|
||||
@@ -5,13 +5,14 @@ from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
from werkzeug.security import check_password_hash, generate_password_hash
|
||||
from sqlalchemy import String, DateTime
|
||||
from sqlalchemy import String, DateTime, Integer, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import db
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user_oauth import UserOAuth
|
||||
from app.models.plan import Plan
|
||||
|
||||
|
||||
class User(db.Model):
|
||||
@@ -35,6 +36,12 @@ class User(db.Model):
|
||||
# User status
|
||||
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
||||
|
||||
# Plan relationship
|
||||
plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False)
|
||||
|
||||
# User credits (populated from plan credits on creation)
|
||||
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# API token for programmatic access
|
||||
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
@@ -51,6 +58,7 @@ class User(db.Model):
|
||||
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
||||
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of User."""
|
||||
@@ -69,6 +77,8 @@ class User(db.Model):
|
||||
"api_token": self.api_token,
|
||||
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
||||
"providers": [provider.provider for provider in self.oauth_providers],
|
||||
"plan": self.plan.to_dict() if self.plan else None,
|
||||
"credits": self.credits,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
@@ -159,6 +169,7 @@ class User(db.Model):
|
||||
) -> tuple["User", "UserOAuth"]:
|
||||
"""Find existing user or create new one from OAuth data."""
|
||||
from app.models.user_oauth import UserOAuth
|
||||
from app.models.plan import Plan
|
||||
|
||||
# First, try to find existing OAuth provider
|
||||
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
|
||||
@@ -178,16 +189,24 @@ class User(db.Model):
|
||||
user = cls.find_by_email(email)
|
||||
|
||||
if not user:
|
||||
# Check if this is the first user (admin)
|
||||
# Check if this is the first user (admin with pro plan)
|
||||
user_count = cls.query.count()
|
||||
role = "admin" if user_count == 0 else "user"
|
||||
|
||||
# Assign plan: first user gets pro, others get free
|
||||
if user_count == 0:
|
||||
plan = Plan.get_pro_plan()
|
||||
else:
|
||||
plan = Plan.get_default_plan()
|
||||
|
||||
# Create new user
|
||||
user = cls(
|
||||
email=email,
|
||||
name=name,
|
||||
picture=picture,
|
||||
role=role,
|
||||
plan_id=plan.id,
|
||||
credits=plan.credits, # Set credits from plan
|
||||
)
|
||||
user.generate_api_token() # Generate API token on creation
|
||||
db.session.add(user)
|
||||
@@ -209,20 +228,30 @@ class User(db.Model):
|
||||
@classmethod
|
||||
def create_with_password(cls, email: str, password: str, name: str) -> "User":
|
||||
"""Create new user with email and password."""
|
||||
from app.models.plan import Plan
|
||||
|
||||
# Check if user already exists
|
||||
existing_user = cls.find_by_email(email)
|
||||
if existing_user:
|
||||
raise ValueError("User with this email already exists")
|
||||
|
||||
# Check if this is the first user (admin)
|
||||
# Check if this is the first user (admin with pro plan)
|
||||
user_count = cls.query.count()
|
||||
role = "admin" if user_count == 0 else "user"
|
||||
|
||||
# Assign plan: first user gets pro, others get free
|
||||
if user_count == 0:
|
||||
plan = Plan.get_pro_plan()
|
||||
else:
|
||||
plan = Plan.get_default_plan()
|
||||
|
||||
# Create new user
|
||||
user = cls(
|
||||
email=email,
|
||||
name=name,
|
||||
role=role,
|
||||
plan_id=plan.id,
|
||||
credits=plan.credits, # Set credits from plan
|
||||
)
|
||||
user.set_password(password)
|
||||
user.generate_api_token() # Generate API token on creation
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from flask import Blueprint
|
||||
|
||||
from app.services.decorators import get_current_user, require_auth, require_role
|
||||
from app.services.decorators import get_current_user, require_auth, require_role, require_credits
|
||||
|
||||
bp = Blueprint("main", __name__)
|
||||
|
||||
@@ -52,3 +52,29 @@ def admin_only() -> dict[str, str]:
|
||||
def health() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@bp.route("/use-credits/<int:amount>")
|
||||
@require_auth
|
||||
@require_credits(5)
|
||||
def use_credits(amount: int) -> dict[str, str]:
|
||||
"""Test endpoint that costs 5 credits to use."""
|
||||
user = get_current_user()
|
||||
return {
|
||||
"message": f"Successfully used endpoint! You requested amount: {amount}",
|
||||
"user": user["email"],
|
||||
"remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator
|
||||
}
|
||||
|
||||
|
||||
@bp.route("/expensive-operation")
|
||||
@require_auth
|
||||
@require_credits(10)
|
||||
def expensive_operation() -> dict[str, str]:
|
||||
"""Test endpoint that costs 10 credits to use."""
|
||||
user = get_current_user()
|
||||
return {
|
||||
"message": "Expensive operation completed successfully!",
|
||||
"user": user["email"],
|
||||
"operation_cost": 10,
|
||||
}
|
||||
|
||||
@@ -81,17 +81,9 @@ class AuthService:
|
||||
response.status_code = 401
|
||||
return response
|
||||
|
||||
# Prepare user data for JWT token
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": oauth_provider.provider,
|
||||
"providers": [p.provider for p in user.oauth_providers],
|
||||
}
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -138,6 +130,8 @@ class AuthService:
|
||||
claims = get_jwt()
|
||||
|
||||
if current_user_id:
|
||||
# Get plan information from JWT claims
|
||||
plan_data = claims.get("plan")
|
||||
return {
|
||||
"id": current_user_id,
|
||||
"email": claims.get("email", ""),
|
||||
@@ -147,6 +141,8 @@ class AuthService:
|
||||
"is_active": claims.get("is_active", True),
|
||||
"provider": claims.get("provider", "unknown"),
|
||||
"providers": claims.get("providers", []),
|
||||
"plan": plan_data,
|
||||
"credits": claims.get("credits"),
|
||||
}
|
||||
return None
|
||||
|
||||
@@ -158,17 +154,9 @@ class AuthService:
|
||||
# Create user with password
|
||||
user = User.create_with_password(email, password, name)
|
||||
|
||||
# Prepare user data for JWT token
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": "password",
|
||||
"providers": ["password"],
|
||||
}
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password registration
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -209,21 +197,9 @@ class AuthService:
|
||||
response.status_code = 401
|
||||
return response
|
||||
|
||||
# Prepare user data for JWT token
|
||||
oauth_providers = [p.provider for p in user.oauth_providers]
|
||||
if user.has_password():
|
||||
oauth_providers.append("password")
|
||||
|
||||
jwt_user_data = {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"picture": user.picture,
|
||||
"role": user.role,
|
||||
"is_active": user.is_active,
|
||||
"provider": "password",
|
||||
"providers": oauth_providers,
|
||||
}
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password login
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||
|
||||
@@ -32,6 +32,8 @@ def get_user_from_jwt() -> dict[str, Any] | None:
|
||||
"is_active": is_active,
|
||||
"provider": claims.get("provider", "unknown"),
|
||||
"providers": claims.get("providers", []),
|
||||
"plan": claims.get("plan"),
|
||||
"credits": claims.get("credits"),
|
||||
}
|
||||
except Exception:
|
||||
return None
|
||||
@@ -64,6 +66,8 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
||||
"provider": "api_token",
|
||||
"providers": [p.provider for p in user.oauth_providers]
|
||||
+ ["api_token"],
|
||||
"plan": user.plan.to_dict() if user.plan else None,
|
||||
"credits": user.credits,
|
||||
}
|
||||
|
||||
return None
|
||||
@@ -126,3 +130,45 @@ def require_role(required_role: str):
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_credits(credits_needed: int):
|
||||
"""Decorator to require and deduct credits for routes."""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
from app.models.user import User
|
||||
from app.database import db
|
||||
|
||||
# First check authentication
|
||||
user_data = get_current_user()
|
||||
if not user_data:
|
||||
return jsonify({"error": "Authentication required"}), 401
|
||||
|
||||
# Get the actual user from database to check/update credits
|
||||
user = User.query.get(int(user_data["id"]))
|
||||
if not user or not user.is_active:
|
||||
return jsonify({"error": "User not found or inactive"}), 401
|
||||
|
||||
# Check if user has enough credits
|
||||
if user.credits < credits_needed:
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}"
|
||||
}
|
||||
),
|
||||
402, # Payment Required status code
|
||||
)
|
||||
|
||||
# Deduct credits
|
||||
user.credits -= credits_needed
|
||||
db.session.commit()
|
||||
|
||||
# Execute the function
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
@@ -20,6 +20,8 @@ class TokenService:
|
||||
"is_active": user_data.get("is_active"),
|
||||
"provider": user_data.get("provider"),
|
||||
"providers": user_data.get("providers", []),
|
||||
"plan": user_data.get("plan"),
|
||||
"credits": user_data.get("credits"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ cli = FlaskGroup(app)
|
||||
def init_db():
|
||||
"""Initialize the database."""
|
||||
print("Initializing database...")
|
||||
db.create_all()
|
||||
from app.database_init import init_database
|
||||
init_database()
|
||||
print("Database initialized successfully!")
|
||||
|
||||
@cli.command()
|
||||
@@ -21,7 +22,8 @@ def reset_db():
|
||||
"""Reset the database (drop all tables and recreate)."""
|
||||
print("Resetting database...")
|
||||
db.drop_all()
|
||||
db.create_all()
|
||||
from app.database_init import init_database
|
||||
init_database()
|
||||
print("Database reset successfully!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user