feat(auth): implement user plans and credits system with related endpoints
This commit is contained in:
@@ -2,11 +2,11 @@ import os
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from flask_jwt_extended import JWTManager
|
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
|
from flask_jwt_extended import JWTManager
|
||||||
|
|
||||||
from app.services.auth_service import AuthService
|
|
||||||
from app.database import init_db
|
from app.database import init_db
|
||||||
|
from app.services.auth_service import AuthService
|
||||||
|
|
||||||
# Global auth service instance
|
# Global auth service instance
|
||||||
auth_service = AuthService()
|
auth_service = AuthService()
|
||||||
@@ -15,17 +15,19 @@ auth_service = AuthService()
|
|||||||
def create_app():
|
def create_app():
|
||||||
"""Create and configure the Flask application."""
|
"""Create and configure the Flask application."""
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# Configure Flask secret key (required for sessions used by OAuth)
|
# Configure Flask secret key (required for sessions used by OAuth)
|
||||||
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "dev-secret-key")
|
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "dev-secret-key")
|
||||||
|
|
||||||
# Configure SQLAlchemy database
|
# Configure SQLAlchemy database
|
||||||
database_url = os.environ.get("DATABASE_URL", "sqlite:///soundboard.db")
|
database_url = os.environ.get("DATABASE_URL", "sqlite:///soundboard.db")
|
||||||
app.config["SQLALCHEMY_DATABASE_URI"] = database_url
|
app.config["SQLALCHEMY_DATABASE_URI"] = database_url
|
||||||
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
|
||||||
|
|
||||||
# Configure Flask-JWT-Extended
|
# Configure Flask-JWT-Extended
|
||||||
app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY", "jwt-secret-key")
|
app.config["JWT_SECRET_KEY"] = os.environ.get(
|
||||||
|
"JWT_SECRET_KEY", "jwt-secret-key"
|
||||||
|
)
|
||||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
|
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
|
||||||
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
|
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
|
||||||
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
|
app.config["JWT_TOKEN_LOCATION"] = ["cookies"]
|
||||||
@@ -33,26 +35,35 @@ def create_app():
|
|||||||
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
|
app.config["JWT_COOKIE_CSRF_PROTECT"] = False
|
||||||
app.config["JWT_ACCESS_COOKIE_PATH"] = "/api/"
|
app.config["JWT_ACCESS_COOKIE_PATH"] = "/api/"
|
||||||
app.config["JWT_REFRESH_COOKIE_PATH"] = "/api/auth/refresh"
|
app.config["JWT_REFRESH_COOKIE_PATH"] = "/api/auth/refresh"
|
||||||
|
|
||||||
# Initialize CORS
|
# Initialize CORS
|
||||||
CORS(app,
|
CORS(
|
||||||
origins=["http://localhost:3000"], # Frontend URL
|
app,
|
||||||
supports_credentials=True, # Allow cookies
|
origins=["http://localhost:3000"], # Frontend URL
|
||||||
allow_headers=["Content-Type", "Authorization"],
|
supports_credentials=True, # Allow cookies
|
||||||
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"])
|
allow_headers=["Content-Type", "Authorization"],
|
||||||
|
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize JWT manager
|
# Initialize JWT manager
|
||||||
jwt = JWTManager(app)
|
jwt = JWTManager(app)
|
||||||
|
|
||||||
# Initialize database
|
# Initialize database
|
||||||
init_db(app)
|
init_db(app)
|
||||||
|
|
||||||
|
# Initialize database tables and seed data
|
||||||
|
with app.app_context():
|
||||||
|
from app.database_init import init_database
|
||||||
|
|
||||||
|
init_database()
|
||||||
|
|
||||||
# Initialize authentication service with app
|
# Initialize authentication service with app
|
||||||
auth_service.init_app(app)
|
auth_service.init_app(app)
|
||||||
|
|
||||||
# Register blueprints
|
# Register blueprints
|
||||||
from app.routes import main, auth
|
from app.routes import auth, main
|
||||||
|
|
||||||
app.register_blueprint(main.bp, url_prefix="/api")
|
app.register_blueprint(main.bp, url_prefix="/api")
|
||||||
app.register_blueprint(auth.bp, url_prefix="/api/auth")
|
app.register_blueprint(auth.bp, url_prefix="/api/auth")
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|||||||
120
app/database_init.py
Normal file
120
app/database_init.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Database initialization and seeding functions."""
|
||||||
|
|
||||||
|
from app.database import db
|
||||||
|
from app.models.plan import Plan
|
||||||
|
|
||||||
|
|
||||||
|
def init_database():
|
||||||
|
"""Initialize database tables and seed with default data."""
|
||||||
|
# Create all tables
|
||||||
|
db.create_all()
|
||||||
|
|
||||||
|
# Seed plans if they don't exist
|
||||||
|
seed_plans()
|
||||||
|
|
||||||
|
# Migrate existing users to have plans
|
||||||
|
migrate_users_to_plans()
|
||||||
|
|
||||||
|
|
||||||
|
def seed_plans():
|
||||||
|
"""Seed the plans table with default plans if empty."""
|
||||||
|
# Check if plans already exist
|
||||||
|
if Plan.query.count() > 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create default plans
|
||||||
|
plans_data = [
|
||||||
|
{
|
||||||
|
"code": "free",
|
||||||
|
"name": "Free Plan",
|
||||||
|
"description": "Basic features with limited usage",
|
||||||
|
"credits": 25,
|
||||||
|
"max_credits": 75,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"code": "premium",
|
||||||
|
"name": "Premium Plan",
|
||||||
|
"description": "Enhanced features with increased usage limits",
|
||||||
|
"credits": 50,
|
||||||
|
"max_credits": 150,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"code": "pro",
|
||||||
|
"name": "Pro Plan",
|
||||||
|
"description": "Full access with unlimited usage",
|
||||||
|
"credits": 100,
|
||||||
|
"max_credits": 300,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for plan_data in plans_data:
|
||||||
|
plan = Plan(**plan_data)
|
||||||
|
db.session.add(plan)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
print(f"Seeded {len(plans_data)} plans into database")
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_users_to_plans():
|
||||||
|
"""Assign plans to existing users who don't have one."""
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Find users without plans
|
||||||
|
users_without_plans = User.query.filter(User.plan_id.is_(None)).all()
|
||||||
|
|
||||||
|
# Find users with plans but NULL credits (only if credits column exists)
|
||||||
|
# Note: We only migrate users with NULL credits, not 0 credits
|
||||||
|
# 0 credits means they spent them, NULL means they never got assigned
|
||||||
|
try:
|
||||||
|
users_without_credits = User.query.filter(
|
||||||
|
User.plan_id.isnot(None), User.credits.is_(None)
|
||||||
|
).all()
|
||||||
|
except Exception:
|
||||||
|
# Credits column doesn't exist yet, will be handled by create_all
|
||||||
|
users_without_credits = []
|
||||||
|
|
||||||
|
if not users_without_plans and not users_without_credits:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get default and pro plans
|
||||||
|
default_plan = Plan.get_default_plan()
|
||||||
|
pro_plan = Plan.get_pro_plan()
|
||||||
|
|
||||||
|
# Get the first user (admin) from all users ordered by ID
|
||||||
|
first_user = User.query.order_by(User.id).first()
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
|
||||||
|
# Assign plans to users without plans
|
||||||
|
for user in users_without_plans:
|
||||||
|
# First user gets pro plan, others get free plan
|
||||||
|
if user.id == first_user.id:
|
||||||
|
user.plan_id = pro_plan.id
|
||||||
|
# Only set credits if the column exists
|
||||||
|
try:
|
||||||
|
user.credits = pro_plan.credits
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
user.plan_id = default_plan.id
|
||||||
|
# Only set credits if the column exists
|
||||||
|
try:
|
||||||
|
user.credits = default_plan.credits
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
# Assign credits to users with plans but no credits
|
||||||
|
for user in users_without_credits:
|
||||||
|
user.credits = user.plan.credits
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
if updated_count > 0:
|
||||||
|
db.session.commit()
|
||||||
|
print(f"Updated {updated_count} existing users with plans and credits")
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If there's any error (like missing columns), just skip migration
|
||||||
|
# The database will be properly created by create_all()
|
||||||
|
pass
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Database models."""
|
"""Database models."""
|
||||||
|
|
||||||
|
from .plan import Plan
|
||||||
from .user import User
|
from .user import User
|
||||||
from .user_oauth import UserOAuth
|
from .user_oauth import UserOAuth
|
||||||
|
|
||||||
__all__ = ["User", "UserOAuth"]
|
__all__ = ["Plan", "User", "UserOAuth"]
|
||||||
58
app/models/plan.py
Normal file
58
app/models/plan.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Plan model for user subscription plans."""
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, String, Text
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.database import db
|
||||||
|
|
||||||
|
|
||||||
|
class Plan(db.Model):
|
||||||
|
"""Plan model for user subscription plans."""
|
||||||
|
|
||||||
|
__tablename__ = "plans"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True)
|
||||||
|
code = Column(String(50), unique=True, nullable=False, index=True)
|
||||||
|
name = Column(String(100), nullable=False)
|
||||||
|
description = Column(Text)
|
||||||
|
credits = Column(Integer, default=0, nullable=False)
|
||||||
|
max_credits = Column(Integer, default=0, nullable=False)
|
||||||
|
|
||||||
|
# Relationship with users
|
||||||
|
users = relationship("User", back_populates="plan", lazy="dynamic")
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
"""String representation of Plan."""
|
||||||
|
return f"<Plan {self.code}: {self.name}>"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def find_by_code(cls, code: str) -> "Plan | None":
|
||||||
|
"""Find plan by code."""
|
||||||
|
return cls.query.filter_by(code=code).first()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_plan(cls) -> "Plan":
|
||||||
|
"""Get the default plan (free)."""
|
||||||
|
plan = cls.find_by_code("free")
|
||||||
|
if not plan:
|
||||||
|
raise ValueError("Default 'free' plan not found in database")
|
||||||
|
return plan
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_pro_plan(cls) -> "Plan":
|
||||||
|
"""Get the pro plan."""
|
||||||
|
plan = cls.find_by_code("pro")
|
||||||
|
if not plan:
|
||||||
|
raise ValueError("'pro' plan not found in database")
|
||||||
|
return plan
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert plan to dictionary."""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"code": self.code,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"credits": self.credits,
|
||||||
|
"max_credits": self.max_credits,
|
||||||
|
}
|
||||||
@@ -5,13 +5,14 @@ from datetime import datetime
|
|||||||
from typing import Optional, TYPE_CHECKING
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from werkzeug.security import check_password_hash, generate_password_hash
|
from werkzeug.security import check_password_hash, generate_password_hash
|
||||||
from sqlalchemy import String, DateTime
|
from sqlalchemy import String, DateTime, Integer, ForeignKey
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.models.user_oauth import UserOAuth
|
from app.models.user_oauth import UserOAuth
|
||||||
|
from app.models.plan import Plan
|
||||||
|
|
||||||
|
|
||||||
class User(db.Model):
|
class User(db.Model):
|
||||||
@@ -35,6 +36,12 @@ class User(db.Model):
|
|||||||
# User status
|
# User status
|
||||||
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
||||||
|
|
||||||
|
# Plan relationship
|
||||||
|
plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False)
|
||||||
|
|
||||||
|
# User credits (populated from plan credits on creation)
|
||||||
|
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
# API token for programmatic access
|
# API token for programmatic access
|
||||||
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||||
@@ -51,6 +58,7 @@ class User(db.Model):
|
|||||||
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
||||||
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""String representation of User."""
|
"""String representation of User."""
|
||||||
@@ -69,6 +77,8 @@ class User(db.Model):
|
|||||||
"api_token": self.api_token,
|
"api_token": self.api_token,
|
||||||
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
||||||
"providers": [provider.provider for provider in self.oauth_providers],
|
"providers": [provider.provider for provider in self.oauth_providers],
|
||||||
|
"plan": self.plan.to_dict() if self.plan else None,
|
||||||
|
"credits": self.credits,
|
||||||
"created_at": self.created_at.isoformat(),
|
"created_at": self.created_at.isoformat(),
|
||||||
"updated_at": self.updated_at.isoformat(),
|
"updated_at": self.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
@@ -159,6 +169,7 @@ class User(db.Model):
|
|||||||
) -> tuple["User", "UserOAuth"]:
|
) -> tuple["User", "UserOAuth"]:
|
||||||
"""Find existing user or create new one from OAuth data."""
|
"""Find existing user or create new one from OAuth data."""
|
||||||
from app.models.user_oauth import UserOAuth
|
from app.models.user_oauth import UserOAuth
|
||||||
|
from app.models.plan import Plan
|
||||||
|
|
||||||
# First, try to find existing OAuth provider
|
# First, try to find existing OAuth provider
|
||||||
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
|
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
|
||||||
@@ -178,16 +189,24 @@ class User(db.Model):
|
|||||||
user = cls.find_by_email(email)
|
user = cls.find_by_email(email)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
# Check if this is the first user (admin)
|
# Check if this is the first user (admin with pro plan)
|
||||||
user_count = cls.query.count()
|
user_count = cls.query.count()
|
||||||
role = "admin" if user_count == 0 else "user"
|
role = "admin" if user_count == 0 else "user"
|
||||||
|
|
||||||
|
# Assign plan: first user gets pro, others get free
|
||||||
|
if user_count == 0:
|
||||||
|
plan = Plan.get_pro_plan()
|
||||||
|
else:
|
||||||
|
plan = Plan.get_default_plan()
|
||||||
|
|
||||||
# Create new user
|
# Create new user
|
||||||
user = cls(
|
user = cls(
|
||||||
email=email,
|
email=email,
|
||||||
name=name,
|
name=name,
|
||||||
picture=picture,
|
picture=picture,
|
||||||
role=role,
|
role=role,
|
||||||
|
plan_id=plan.id,
|
||||||
|
credits=plan.credits, # Set credits from plan
|
||||||
)
|
)
|
||||||
user.generate_api_token() # Generate API token on creation
|
user.generate_api_token() # Generate API token on creation
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
@@ -209,20 +228,30 @@ class User(db.Model):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def create_with_password(cls, email: str, password: str, name: str) -> "User":
|
def create_with_password(cls, email: str, password: str, name: str) -> "User":
|
||||||
"""Create new user with email and password."""
|
"""Create new user with email and password."""
|
||||||
|
from app.models.plan import Plan
|
||||||
|
|
||||||
# Check if user already exists
|
# Check if user already exists
|
||||||
existing_user = cls.find_by_email(email)
|
existing_user = cls.find_by_email(email)
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise ValueError("User with this email already exists")
|
raise ValueError("User with this email already exists")
|
||||||
|
|
||||||
# Check if this is the first user (admin)
|
# Check if this is the first user (admin with pro plan)
|
||||||
user_count = cls.query.count()
|
user_count = cls.query.count()
|
||||||
role = "admin" if user_count == 0 else "user"
|
role = "admin" if user_count == 0 else "user"
|
||||||
|
|
||||||
|
# Assign plan: first user gets pro, others get free
|
||||||
|
if user_count == 0:
|
||||||
|
plan = Plan.get_pro_plan()
|
||||||
|
else:
|
||||||
|
plan = Plan.get_default_plan()
|
||||||
|
|
||||||
# Create new user
|
# Create new user
|
||||||
user = cls(
|
user = cls(
|
||||||
email=email,
|
email=email,
|
||||||
name=name,
|
name=name,
|
||||||
role=role,
|
role=role,
|
||||||
|
plan_id=plan.id,
|
||||||
|
credits=plan.credits, # Set credits from plan
|
||||||
)
|
)
|
||||||
user.set_password(password)
|
user.set_password(password)
|
||||||
user.generate_api_token() # Generate API token on creation
|
user.generate_api_token() # Generate API token on creation
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
|
|
||||||
from app.services.decorators import get_current_user, require_auth, require_role
|
from app.services.decorators import get_current_user, require_auth, require_role, require_credits
|
||||||
|
|
||||||
bp = Blueprint("main", __name__)
|
bp = Blueprint("main", __name__)
|
||||||
|
|
||||||
@@ -52,3 +52,29 @@ def admin_only() -> dict[str, str]:
|
|||||||
def health() -> dict[str, str]:
|
def health() -> dict[str, str]:
|
||||||
"""Health check endpoint."""
|
"""Health check endpoint."""
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/use-credits/<int:amount>")
|
||||||
|
@require_auth
|
||||||
|
@require_credits(5)
|
||||||
|
def use_credits(amount: int) -> dict[str, str]:
|
||||||
|
"""Test endpoint that costs 5 credits to use."""
|
||||||
|
user = get_current_user()
|
||||||
|
return {
|
||||||
|
"message": f"Successfully used endpoint! You requested amount: {amount}",
|
||||||
|
"user": user["email"],
|
||||||
|
"remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/expensive-operation")
|
||||||
|
@require_auth
|
||||||
|
@require_credits(10)
|
||||||
|
def expensive_operation() -> dict[str, str]:
|
||||||
|
"""Test endpoint that costs 10 credits to use."""
|
||||||
|
user = get_current_user()
|
||||||
|
return {
|
||||||
|
"message": "Expensive operation completed successfully!",
|
||||||
|
"user": user["email"],
|
||||||
|
"operation_cost": 10,
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,17 +81,9 @@ class AuthService:
|
|||||||
response.status_code = 401
|
response.status_code = 401
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Prepare user data for JWT token
|
# Prepare user data for JWT token using user.to_dict()
|
||||||
jwt_user_data = {
|
jwt_user_data = user.to_dict()
|
||||||
"id": str(user.id),
|
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"picture": user.picture,
|
|
||||||
"role": user.role,
|
|
||||||
"is_active": user.is_active,
|
|
||||||
"provider": oauth_provider.provider,
|
|
||||||
"providers": [p.provider for p in user.oauth_providers],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(
|
access_token = self.token_service.generate_access_token(
|
||||||
@@ -138,6 +130,8 @@ class AuthService:
|
|||||||
claims = get_jwt()
|
claims = get_jwt()
|
||||||
|
|
||||||
if current_user_id:
|
if current_user_id:
|
||||||
|
# Get plan information from JWT claims
|
||||||
|
plan_data = claims.get("plan")
|
||||||
return {
|
return {
|
||||||
"id": current_user_id,
|
"id": current_user_id,
|
||||||
"email": claims.get("email", ""),
|
"email": claims.get("email", ""),
|
||||||
@@ -147,6 +141,8 @@ class AuthService:
|
|||||||
"is_active": claims.get("is_active", True),
|
"is_active": claims.get("is_active", True),
|
||||||
"provider": claims.get("provider", "unknown"),
|
"provider": claims.get("provider", "unknown"),
|
||||||
"providers": claims.get("providers", []),
|
"providers": claims.get("providers", []),
|
||||||
|
"plan": plan_data,
|
||||||
|
"credits": claims.get("credits"),
|
||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -158,17 +154,9 @@ class AuthService:
|
|||||||
# Create user with password
|
# Create user with password
|
||||||
user = User.create_with_password(email, password, name)
|
user = User.create_with_password(email, password, name)
|
||||||
|
|
||||||
# Prepare user data for JWT token
|
# Prepare user data for JWT token using user.to_dict()
|
||||||
jwt_user_data = {
|
jwt_user_data = user.to_dict()
|
||||||
"id": str(user.id),
|
jwt_user_data["provider"] = "password" # Override provider for password registration
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"picture": user.picture,
|
|
||||||
"role": user.role,
|
|
||||||
"is_active": user.is_active,
|
|
||||||
"provider": "password",
|
|
||||||
"providers": ["password"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(
|
access_token = self.token_service.generate_access_token(
|
||||||
@@ -209,21 +197,9 @@ class AuthService:
|
|||||||
response.status_code = 401
|
response.status_code = 401
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Prepare user data for JWT token
|
# Prepare user data for JWT token using user.to_dict()
|
||||||
oauth_providers = [p.provider for p in user.oauth_providers]
|
jwt_user_data = user.to_dict()
|
||||||
if user.has_password():
|
jwt_user_data["provider"] = "password" # Override provider for password login
|
||||||
oauth_providers.append("password")
|
|
||||||
|
|
||||||
jwt_user_data = {
|
|
||||||
"id": str(user.id),
|
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"picture": user.picture,
|
|
||||||
"role": user.role,
|
|
||||||
"is_active": user.is_active,
|
|
||||||
"provider": "password",
|
|
||||||
"providers": oauth_providers,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ def get_user_from_jwt() -> dict[str, Any] | None:
|
|||||||
"is_active": is_active,
|
"is_active": is_active,
|
||||||
"provider": claims.get("provider", "unknown"),
|
"provider": claims.get("provider", "unknown"),
|
||||||
"providers": claims.get("providers", []),
|
"providers": claims.get("providers", []),
|
||||||
|
"plan": claims.get("plan"),
|
||||||
|
"credits": claims.get("credits"),
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -64,6 +66,8 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
|||||||
"provider": "api_token",
|
"provider": "api_token",
|
||||||
"providers": [p.provider for p in user.oauth_providers]
|
"providers": [p.provider for p in user.oauth_providers]
|
||||||
+ ["api_token"],
|
+ ["api_token"],
|
||||||
|
"plan": user.plan.to_dict() if user.plan else None,
|
||||||
|
"credits": user.credits,
|
||||||
}
|
}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
@@ -126,3 +130,45 @@ def require_role(required_role: str):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def require_credits(credits_needed: int):
|
||||||
|
"""Decorator to require and deduct credits for routes."""
|
||||||
|
def decorator(f):
|
||||||
|
@wraps(f)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
from app.models.user import User
|
||||||
|
from app.database import db
|
||||||
|
|
||||||
|
# First check authentication
|
||||||
|
user_data = get_current_user()
|
||||||
|
if not user_data:
|
||||||
|
return jsonify({"error": "Authentication required"}), 401
|
||||||
|
|
||||||
|
# Get the actual user from database to check/update credits
|
||||||
|
user = User.query.get(int(user_data["id"]))
|
||||||
|
if not user or not user.is_active:
|
||||||
|
return jsonify({"error": "User not found or inactive"}), 401
|
||||||
|
|
||||||
|
# Check if user has enough credits
|
||||||
|
if user.credits < credits_needed:
|
||||||
|
return (
|
||||||
|
jsonify(
|
||||||
|
{
|
||||||
|
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}"
|
||||||
|
}
|
||||||
|
),
|
||||||
|
402, # Payment Required status code
|
||||||
|
)
|
||||||
|
|
||||||
|
# Deduct credits
|
||||||
|
user.credits -= credits_needed
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# Execute the function
|
||||||
|
result = f(*args, **kwargs)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ class TokenService:
|
|||||||
"is_active": user_data.get("is_active"),
|
"is_active": user_data.get("is_active"),
|
||||||
"provider": user_data.get("provider"),
|
"provider": user_data.get("provider"),
|
||||||
"providers": user_data.get("providers", []),
|
"providers": user_data.get("providers", []),
|
||||||
|
"plan": user_data.get("plan"),
|
||||||
|
"credits": user_data.get("credits"),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ cli = FlaskGroup(app)
|
|||||||
def init_db():
|
def init_db():
|
||||||
"""Initialize the database."""
|
"""Initialize the database."""
|
||||||
print("Initializing database...")
|
print("Initializing database...")
|
||||||
db.create_all()
|
from app.database_init import init_database
|
||||||
|
init_database()
|
||||||
print("Database initialized successfully!")
|
print("Database initialized successfully!")
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@@ -21,7 +22,8 @@ def reset_db():
|
|||||||
"""Reset the database (drop all tables and recreate)."""
|
"""Reset the database (drop all tables and recreate)."""
|
||||||
print("Resetting database...")
|
print("Resetting database...")
|
||||||
db.drop_all()
|
db.drop_all()
|
||||||
db.create_all()
|
from app.database_init import init_database
|
||||||
|
init_database()
|
||||||
print("Database reset successfully!")
|
print("Database reset successfully!")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user