diff --git a/app/__init__.py b/app/__init__.py index 877e31f..097c496 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -2,11 +2,11 @@ import os from datetime import timedelta from flask import Flask -from flask_jwt_extended import JWTManager from flask_cors import CORS +from flask_jwt_extended import JWTManager -from app.services.auth_service import AuthService from app.database import init_db +from app.services.auth_service import AuthService # Global auth service instance auth_service = AuthService() @@ -15,17 +15,19 @@ auth_service = AuthService() def create_app(): """Create and configure the Flask application.""" app = Flask(__name__) - + # Configure Flask secret key (required for sessions used by OAuth) app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "dev-secret-key") - + # Configure SQLAlchemy database database_url = os.environ.get("DATABASE_URL", "sqlite:///soundboard.db") app.config["SQLALCHEMY_DATABASE_URI"] = database_url app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - + # Configure Flask-JWT-Extended - app.config["JWT_SECRET_KEY"] = os.environ.get("JWT_SECRET_KEY", "jwt-secret-key") + app.config["JWT_SECRET_KEY"] = os.environ.get( + "JWT_SECRET_KEY", "jwt-secret-key" + ) app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15) app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7) app.config["JWT_TOKEN_LOCATION"] = ["cookies"] @@ -33,26 +35,35 @@ def create_app(): app.config["JWT_COOKIE_CSRF_PROTECT"] = False app.config["JWT_ACCESS_COOKIE_PATH"] = "/api/" app.config["JWT_REFRESH_COOKIE_PATH"] = "/api/auth/refresh" - + # Initialize CORS - CORS(app, - origins=["http://localhost:3000"], # Frontend URL - supports_credentials=True, # Allow cookies - allow_headers=["Content-Type", "Authorization"], - methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"]) - + CORS( + app, + origins=["http://localhost:3000"], # Frontend URL + supports_credentials=True, # Allow cookies + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + ) + # Initialize JWT manager jwt = JWTManager(app) - + # Initialize database init_db(app) - + + # Initialize database tables and seed data + with app.app_context(): + from app.database_init import init_database + + init_database() + # Initialize authentication service with app auth_service.init_app(app) - + # Register blueprints - from app.routes import main, auth + from app.routes import auth, main + app.register_blueprint(main.bp, url_prefix="/api") app.register_blueprint(auth.bp, url_prefix="/api/auth") - - return app \ No newline at end of file + + return app diff --git a/app/database_init.py b/app/database_init.py new file mode 100644 index 0000000..39f1f07 --- /dev/null +++ b/app/database_init.py @@ -0,0 +1,120 @@ +"""Database initialization and seeding functions.""" + +from app.database import db +from app.models.plan import Plan + + +def init_database(): + """Initialize database tables and seed with default data.""" + # Create all tables + db.create_all() + + # Seed plans if they don't exist + seed_plans() + + # Migrate existing users to have plans + migrate_users_to_plans() + + +def seed_plans(): + """Seed the plans table with default plans if empty.""" + # Check if plans already exist + if Plan.query.count() > 0: + return + + # Create default plans + plans_data = [ + { + "code": "free", + "name": "Free Plan", + "description": "Basic features with limited usage", + "credits": 25, + "max_credits": 75, + }, + { + "code": "premium", + "name": "Premium Plan", + "description": "Enhanced features with increased usage limits", + "credits": 50, + "max_credits": 150, + }, + { + "code": "pro", + "name": "Pro Plan", + "description": "Full access with unlimited usage", + "credits": 100, + "max_credits": 300, + }, + ] + + for plan_data in plans_data: + plan = Plan(**plan_data) + db.session.add(plan) + + db.session.commit() + print(f"Seeded {len(plans_data)} plans into database") + + +def migrate_users_to_plans(): + """Assign plans to existing users who don't have one.""" + from app.models.user import User + + try: + # Find users without plans + users_without_plans = User.query.filter(User.plan_id.is_(None)).all() + + # Find users with plans but NULL credits (only if credits column exists) + # Note: We only migrate users with NULL credits, not 0 credits + # 0 credits means they spent them, NULL means they never got assigned + try: + users_without_credits = User.query.filter( + User.plan_id.isnot(None), User.credits.is_(None) + ).all() + except Exception: + # Credits column doesn't exist yet, will be handled by create_all + users_without_credits = [] + + if not users_without_plans and not users_without_credits: + return + + # Get default and pro plans + default_plan = Plan.get_default_plan() + pro_plan = Plan.get_pro_plan() + + # Get the first user (admin) from all users ordered by ID + first_user = User.query.order_by(User.id).first() + + updated_count = 0 + + # Assign plans to users without plans + for user in users_without_plans: + # First user gets pro plan, others get free plan + if user.id == first_user.id: + user.plan_id = pro_plan.id + # Only set credits if the column exists + try: + user.credits = pro_plan.credits + except Exception: + pass + else: + user.plan_id = default_plan.id + # Only set credits if the column exists + try: + user.credits = default_plan.credits + except Exception: + pass + updated_count += 1 + + # Assign credits to users with plans but no credits + for user in users_without_credits: + user.credits = user.plan.credits + updated_count += 1 + + if updated_count > 0: + db.session.commit() + print(f"Updated {updated_count} existing users with plans and credits") + + except Exception: + # If there's any error (like missing columns), just skip migration + # The database will be properly created by create_all() + pass \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py index b9401d0..0576c57 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -1,6 +1,7 @@ """Database models.""" +from .plan import Plan from .user import User from .user_oauth import UserOAuth -__all__ = ["User", "UserOAuth"] \ No newline at end of file +__all__ = ["Plan", "User", "UserOAuth"] \ No newline at end of file diff --git a/app/models/plan.py b/app/models/plan.py new file mode 100644 index 0000000..ddee648 --- /dev/null +++ b/app/models/plan.py @@ -0,0 +1,58 @@ +"""Plan model for user subscription plans.""" + +from sqlalchemy import Column, Integer, String, Text +from sqlalchemy.orm import relationship + +from app.database import db + + +class Plan(db.Model): + """Plan model for user subscription plans.""" + + __tablename__ = "plans" + + id = Column(Integer, primary_key=True) + code = Column(String(50), unique=True, nullable=False, index=True) + name = Column(String(100), nullable=False) + description = Column(Text) + credits = Column(Integer, default=0, nullable=False) + max_credits = Column(Integer, default=0, nullable=False) + + # Relationship with users + users = relationship("User", back_populates="plan", lazy="dynamic") + + def __repr__(self) -> str: + """String representation of Plan.""" + return f"" + + @classmethod + def find_by_code(cls, code: str) -> "Plan | None": + """Find plan by code.""" + return cls.query.filter_by(code=code).first() + + @classmethod + def get_default_plan(cls) -> "Plan": + """Get the default plan (free).""" + plan = cls.find_by_code("free") + if not plan: + raise ValueError("Default 'free' plan not found in database") + return plan + + @classmethod + def get_pro_plan(cls) -> "Plan": + """Get the pro plan.""" + plan = cls.find_by_code("pro") + if not plan: + raise ValueError("'pro' plan not found in database") + return plan + + def to_dict(self) -> dict: + """Convert plan to dictionary.""" + return { + "id": self.id, + "code": self.code, + "name": self.name, + "description": self.description, + "credits": self.credits, + "max_credits": self.max_credits, + } \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py index 7d7e660..ed5ab60 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -5,13 +5,14 @@ from datetime import datetime from typing import Optional, TYPE_CHECKING from werkzeug.security import check_password_hash, generate_password_hash -from sqlalchemy import String, DateTime +from sqlalchemy import String, DateTime, Integer, ForeignKey from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import db if TYPE_CHECKING: from app.models.user_oauth import UserOAuth + from app.models.plan import Plan class User(db.Model): @@ -35,6 +36,12 @@ class User(db.Model): # User status is_active: Mapped[bool] = mapped_column(nullable=False, default=True) + # Plan relationship + plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False) + + # User credits (populated from plan credits on creation) + credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + # API token for programmatic access api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) @@ -51,6 +58,7 @@ class User(db.Model): oauth_providers: Mapped[list["UserOAuth"]] = relationship( "UserOAuth", back_populates="user", cascade="all, delete-orphan" ) + plan: Mapped["Plan"] = relationship("Plan", back_populates="users") def __repr__(self) -> str: """String representation of User.""" @@ -69,6 +77,8 @@ class User(db.Model): "api_token": self.api_token, "api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None, "providers": [provider.provider for provider in self.oauth_providers], + "plan": self.plan.to_dict() if self.plan else None, + "credits": self.credits, "created_at": self.created_at.isoformat(), "updated_at": self.updated_at.isoformat(), } @@ -159,6 +169,7 @@ class User(db.Model): ) -> tuple["User", "UserOAuth"]: """Find existing user or create new one from OAuth data.""" from app.models.user_oauth import UserOAuth + from app.models.plan import Plan # First, try to find existing OAuth provider oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id) @@ -178,16 +189,24 @@ class User(db.Model): user = cls.find_by_email(email) if not user: - # Check if this is the first user (admin) + # Check if this is the first user (admin with pro plan) user_count = cls.query.count() role = "admin" if user_count == 0 else "user" + # Assign plan: first user gets pro, others get free + if user_count == 0: + plan = Plan.get_pro_plan() + else: + plan = Plan.get_default_plan() + # Create new user user = cls( email=email, name=name, picture=picture, role=role, + plan_id=plan.id, + credits=plan.credits, # Set credits from plan ) user.generate_api_token() # Generate API token on creation db.session.add(user) @@ -209,20 +228,30 @@ class User(db.Model): @classmethod def create_with_password(cls, email: str, password: str, name: str) -> "User": """Create new user with email and password.""" + from app.models.plan import Plan + # Check if user already exists existing_user = cls.find_by_email(email) if existing_user: raise ValueError("User with this email already exists") - # Check if this is the first user (admin) + # Check if this is the first user (admin with pro plan) user_count = cls.query.count() role = "admin" if user_count == 0 else "user" + # Assign plan: first user gets pro, others get free + if user_count == 0: + plan = Plan.get_pro_plan() + else: + plan = Plan.get_default_plan() + # Create new user user = cls( email=email, name=name, role=role, + plan_id=plan.id, + credits=plan.credits, # Set credits from plan ) user.set_password(password) user.generate_api_token() # Generate API token on creation diff --git a/app/routes/main.py b/app/routes/main.py index 9e30b17..5b4ee96 100644 --- a/app/routes/main.py +++ b/app/routes/main.py @@ -2,7 +2,7 @@ from flask import Blueprint -from app.services.decorators import get_current_user, require_auth, require_role +from app.services.decorators import get_current_user, require_auth, require_role, require_credits bp = Blueprint("main", __name__) @@ -52,3 +52,29 @@ def admin_only() -> dict[str, str]: def health() -> dict[str, str]: """Health check endpoint.""" return {"status": "ok"} + + +@bp.route("/use-credits/") +@require_auth +@require_credits(5) +def use_credits(amount: int) -> dict[str, str]: + """Test endpoint that costs 5 credits to use.""" + user = get_current_user() + return { + "message": f"Successfully used endpoint! You requested amount: {amount}", + "user": user["email"], + "remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator + } + + +@bp.route("/expensive-operation") +@require_auth +@require_credits(10) +def expensive_operation() -> dict[str, str]: + """Test endpoint that costs 10 credits to use.""" + user = get_current_user() + return { + "message": "Expensive operation completed successfully!", + "user": user["email"], + "operation_cost": 10, + } diff --git a/app/services/auth_service.py b/app/services/auth_service.py index 6e93b2d..b6c4492 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -81,17 +81,9 @@ class AuthService: response.status_code = 401 return response - # Prepare user data for JWT token - jwt_user_data = { - "id": str(user.id), - "email": user.email, - "name": user.name, - "picture": user.picture, - "role": user.role, - "is_active": user.is_active, - "provider": oauth_provider.provider, - "providers": [p.provider for p in user.oauth_providers], - } + # Prepare user data for JWT token using user.to_dict() + jwt_user_data = user.to_dict() + jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login # Generate JWT tokens access_token = self.token_service.generate_access_token( @@ -138,6 +130,8 @@ class AuthService: claims = get_jwt() if current_user_id: + # Get plan information from JWT claims + plan_data = claims.get("plan") return { "id": current_user_id, "email": claims.get("email", ""), @@ -147,6 +141,8 @@ class AuthService: "is_active": claims.get("is_active", True), "provider": claims.get("provider", "unknown"), "providers": claims.get("providers", []), + "plan": plan_data, + "credits": claims.get("credits"), } return None @@ -158,17 +154,9 @@ class AuthService: # Create user with password user = User.create_with_password(email, password, name) - # Prepare user data for JWT token - jwt_user_data = { - "id": str(user.id), - "email": user.email, - "name": user.name, - "picture": user.picture, - "role": user.role, - "is_active": user.is_active, - "provider": "password", - "providers": ["password"], - } + # Prepare user data for JWT token using user.to_dict() + jwt_user_data = user.to_dict() + jwt_user_data["provider"] = "password" # Override provider for password registration # Generate JWT tokens access_token = self.token_service.generate_access_token( @@ -209,21 +197,9 @@ class AuthService: response.status_code = 401 return response - # Prepare user data for JWT token - oauth_providers = [p.provider for p in user.oauth_providers] - if user.has_password(): - oauth_providers.append("password") - - jwt_user_data = { - "id": str(user.id), - "email": user.email, - "name": user.name, - "picture": user.picture, - "role": user.role, - "is_active": user.is_active, - "provider": "password", - "providers": oauth_providers, - } + # Prepare user data for JWT token using user.to_dict() + jwt_user_data = user.to_dict() + jwt_user_data["provider"] = "password" # Override provider for password login # Generate JWT tokens access_token = self.token_service.generate_access_token(jwt_user_data) diff --git a/app/services/decorators.py b/app/services/decorators.py index 1f1ad21..a461ffd 100644 --- a/app/services/decorators.py +++ b/app/services/decorators.py @@ -32,6 +32,8 @@ def get_user_from_jwt() -> dict[str, Any] | None: "is_active": is_active, "provider": claims.get("provider", "unknown"), "providers": claims.get("providers", []), + "plan": claims.get("plan"), + "credits": claims.get("credits"), } except Exception: return None @@ -64,6 +66,8 @@ def get_user_from_api_token() -> dict[str, Any] | None: "provider": "api_token", "providers": [p.provider for p in user.oauth_providers] + ["api_token"], + "plan": user.plan.to_dict() if user.plan else None, + "credits": user.credits, } return None @@ -126,3 +130,45 @@ def require_role(required_role: str): return wrapper return decorator + + +def require_credits(credits_needed: int): + """Decorator to require and deduct credits for routes.""" + def decorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + from app.models.user import User + from app.database import db + + # First check authentication + user_data = get_current_user() + if not user_data: + return jsonify({"error": "Authentication required"}), 401 + + # Get the actual user from database to check/update credits + user = User.query.get(int(user_data["id"])) + if not user or not user.is_active: + return jsonify({"error": "User not found or inactive"}), 401 + + # Check if user has enough credits + if user.credits < credits_needed: + return ( + jsonify( + { + "error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}" + } + ), + 402, # Payment Required status code + ) + + # Deduct credits + user.credits -= credits_needed + db.session.commit() + + # Execute the function + result = f(*args, **kwargs) + + return result + + return wrapper + return decorator diff --git a/app/services/token_service.py b/app/services/token_service.py index 25aa859..ee5539b 100644 --- a/app/services/token_service.py +++ b/app/services/token_service.py @@ -20,6 +20,8 @@ class TokenService: "is_active": user_data.get("is_active"), "provider": user_data.get("provider"), "providers": user_data.get("providers", []), + "plan": user_data.get("plan"), + "credits": user_data.get("credits"), }, ) diff --git a/migrate_db.py b/migrate_db.py index 1dd03e0..e1eb6bc 100644 --- a/migrate_db.py +++ b/migrate_db.py @@ -13,7 +13,8 @@ cli = FlaskGroup(app) def init_db(): """Initialize the database.""" print("Initializing database...") - db.create_all() + from app.database_init import init_database + init_database() print("Database initialized successfully!") @cli.command() @@ -21,7 +22,8 @@ def reset_db(): """Reset the database (drop all tables and recreate).""" print("Resetting database...") db.drop_all() - db.create_all() + from app.database_init import init_database + init_database() print("Database reset successfully!") if __name__ == "__main__":