refactor: clean up code by adding missing commas and improving import order
This commit is contained in:
@@ -26,7 +26,7 @@ def create_app():
|
|||||||
|
|
||||||
# Configure Flask-JWT-Extended
|
# Configure Flask-JWT-Extended
|
||||||
app.config["JWT_SECRET_KEY"] = os.environ.get(
|
app.config["JWT_SECRET_KEY"] = os.environ.get(
|
||||||
"JWT_SECRET_KEY", "jwt-secret-key"
|
"JWT_SECRET_KEY", "jwt-secret-key",
|
||||||
)
|
)
|
||||||
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
|
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
|
||||||
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
|
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Database configuration and initialization."""
|
"""Database configuration and initialization."""
|
||||||
|
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
|
||||||
from flask_migrate import Migrate
|
from flask_migrate import Migrate
|
||||||
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
|
|
||||||
db = SQLAlchemy()
|
db = SQLAlchemy()
|
||||||
migrate = Migrate()
|
migrate = Migrate()
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ def migrate_users_to_plans():
|
|||||||
# 0 credits means they spent them, NULL means they never got assigned
|
# 0 credits means they spent them, NULL means they never got assigned
|
||||||
try:
|
try:
|
||||||
users_without_credits = User.query.filter(
|
users_without_credits = User.query.filter(
|
||||||
User.plan_id.isnot(None), User.credits.is_(None)
|
User.plan_id.isnot(None), User.credits.is_(None),
|
||||||
).all()
|
).all()
|
||||||
except Exception:
|
except Exception:
|
||||||
# Credits column doesn't exist yet, will be handled by create_all
|
# Credits column doesn't exist yet, will be handled by create_all
|
||||||
@@ -113,7 +113,7 @@ def migrate_users_to_plans():
|
|||||||
if updated_count > 0:
|
if updated_count > 0:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
print(
|
print(
|
||||||
f"Updated {updated_count} existing users with plans and credits"
|
f"Updated {updated_count} existing users with plans and credits",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -2,17 +2,17 @@
|
|||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from werkzeug.security import check_password_hash, generate_password_hash
|
from sqlalchemy import DateTime, ForeignKey, Integer, String
|
||||||
from sqlalchemy import String, DateTime, Integer, ForeignKey
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
from werkzeug.security import check_password_hash, generate_password_hash
|
||||||
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.models.user_oauth import UserOAuth
|
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
|
from app.models.user_oauth import UserOAuth
|
||||||
|
|
||||||
|
|
||||||
class User(db.Model):
|
class User(db.Model):
|
||||||
@@ -25,16 +25,16 @@ class User(db.Model):
|
|||||||
# Primary user information (can be updated from any connected provider)
|
# Primary user information (can be updated from any connected provider)
|
||||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
picture: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
|
||||||
# Password authentication (optional - users can use OAuth instead)
|
# Password authentication (optional - users can use OAuth instead)
|
||||||
password_hash: Mapped[Optional[str]] = mapped_column(
|
password_hash: Mapped[str | None] = mapped_column(
|
||||||
String(255), nullable=True
|
String(255), nullable=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Role-based access control
|
# Role-based access control
|
||||||
role: Mapped[str] = mapped_column(
|
role: Mapped[str] = mapped_column(
|
||||||
String(50), nullable=False, default="user"
|
String(50), nullable=False, default="user",
|
||||||
)
|
)
|
||||||
|
|
||||||
# User status
|
# User status
|
||||||
@@ -42,21 +42,21 @@ class User(db.Model):
|
|||||||
|
|
||||||
# Plan relationship
|
# Plan relationship
|
||||||
plan_id: Mapped[int] = mapped_column(
|
plan_id: Mapped[int] = mapped_column(
|
||||||
Integer, ForeignKey("plans.id"), nullable=False
|
Integer, ForeignKey("plans.id"), nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# User credits (populated from plan credits on creation)
|
# User credits (populated from plan credits on creation)
|
||||||
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
# API token for programmatic access
|
# API token for programmatic access
|
||||||
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
api_token: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(
|
api_token_expires_at: Mapped[datetime | None] = mapped_column(
|
||||||
DateTime, nullable=True
|
DateTime, nullable=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, nullable=False
|
DateTime, default=datetime.utcnow, nullable=False,
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime,
|
DateTime,
|
||||||
@@ -67,7 +67,7 @@ class User(db.Model):
|
|||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
||||||
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
"UserOAuth", back_populates="user", cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
|
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
|
||||||
|
|
||||||
@@ -190,15 +190,15 @@ class User(db.Model):
|
|||||||
provider_id: str,
|
provider_id: str,
|
||||||
email: str,
|
email: str,
|
||||||
name: str,
|
name: str,
|
||||||
picture: Optional[str] = None,
|
picture: str | None = None,
|
||||||
) -> tuple["User", "UserOAuth"]:
|
) -> tuple["User", "UserOAuth"]:
|
||||||
"""Find existing user or create new one from OAuth data."""
|
"""Find existing user or create new one from OAuth data."""
|
||||||
from app.models.user_oauth import UserOAuth
|
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
|
from app.models.user_oauth import UserOAuth
|
||||||
|
|
||||||
# First, try to find existing OAuth provider
|
# First, try to find existing OAuth provider
|
||||||
oauth_provider = UserOAuth.find_by_provider_and_id(
|
oauth_provider = UserOAuth.find_by_provider_and_id(
|
||||||
provider, provider_id
|
provider, provider_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if oauth_provider:
|
if oauth_provider:
|
||||||
@@ -211,7 +211,7 @@ class User(db.Model):
|
|||||||
|
|
||||||
# Update user info with latest data
|
# Update user info with latest data
|
||||||
user.update_from_provider(
|
user.update_from_provider(
|
||||||
{"email": email, "name": name, "picture": picture}
|
{"email": email, "name": name, "picture": picture},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Try to find user by email to link the new provider
|
# Try to find user by email to link the new provider
|
||||||
@@ -256,7 +256,7 @@ class User(db.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_with_password(
|
def create_with_password(
|
||||||
cls, email: str, password: str, name: str
|
cls, email: str, password: str, name: str,
|
||||||
) -> "User":
|
) -> "User":
|
||||||
"""Create new user with email and password."""
|
"""Create new user with email and password."""
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
@@ -293,7 +293,7 @@ class User(db.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def authenticate_with_password(
|
def authenticate_with_password(
|
||||||
cls, email: str, password: str
|
cls, email: str, password: str,
|
||||||
) -> Optional["User"]:
|
) -> Optional["User"]:
|
||||||
"""Authenticate user with email and password."""
|
"""Authenticate user with email and password."""
|
||||||
user = cls.find_by_email(email)
|
user = cls.find_by_email(email)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""User OAuth model for storing user's connected providers."""
|
"""User OAuth model for storing user's connected providers."""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from sqlalchemy import String, DateTime, Text, ForeignKey
|
from sqlalchemy import DateTime, ForeignKey, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
@@ -29,11 +29,11 @@ class UserOAuth(db.Model):
|
|||||||
# Provider-specific user information
|
# Provider-specific user information
|
||||||
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
picture: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, nullable=False
|
DateTime, default=datetime.utcnow, nullable=False,
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime,
|
DateTime,
|
||||||
@@ -45,13 +45,13 @@ class UserOAuth(db.Model):
|
|||||||
# Unique constraint on provider + provider_id combination
|
# Unique constraint on provider + provider_id combination
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.UniqueConstraint(
|
db.UniqueConstraint(
|
||||||
"provider", "provider_id", name="unique_provider_user"
|
"provider", "provider_id", name="unique_provider_user",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user: Mapped["User"] = relationship(
|
user: Mapped["User"] = relationship(
|
||||||
"User", back_populates="oauth_providers"
|
"User", back_populates="oauth_providers",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -73,11 +73,11 @@ class UserOAuth(db.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_provider_and_id(
|
def find_by_provider_and_id(
|
||||||
cls, provider: str, provider_id: str
|
cls, provider: str, provider_id: str,
|
||||||
) -> Optional["UserOAuth"]:
|
) -> Optional["UserOAuth"]:
|
||||||
"""Find OAuth provider by provider name and provider ID."""
|
"""Find OAuth provider by provider name and provider ID."""
|
||||||
return cls.query.filter_by(
|
return cls.query.filter_by(
|
||||||
provider=provider, provider_id=provider_id
|
provider=provider, provider_id=provider_id,
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -88,7 +88,7 @@ class UserOAuth(db.Model):
|
|||||||
provider_id: str,
|
provider_id: str,
|
||||||
email: str,
|
email: str,
|
||||||
name: str,
|
name: str,
|
||||||
picture: Optional[str] = None,
|
picture: str | None = None,
|
||||||
) -> "UserOAuth":
|
) -> "UserOAuth":
|
||||||
"""Create new OAuth provider or update existing one."""
|
"""Create new OAuth provider or update existing one."""
|
||||||
oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
|
oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ def callback(provider):
|
|||||||
# If successful, redirect to frontend dashboard with cookies
|
# If successful, redirect to frontend dashboard with cookies
|
||||||
if auth_response.status_code == 200:
|
if auth_response.status_code == 200:
|
||||||
redirect_response = make_response(
|
redirect_response = make_response(
|
||||||
redirect("http://localhost:3000/dashboard")
|
redirect("http://localhost:3000/dashboard"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copy all cookies from the auth response
|
# Copy all cookies from the auth response
|
||||||
@@ -39,9 +39,8 @@ def callback(provider):
|
|||||||
redirect_response.headers.add("Set-Cookie", cookie)
|
redirect_response.headers.add("Set-Cookie", cookie)
|
||||||
|
|
||||||
return redirect_response
|
return redirect_response
|
||||||
else:
|
# If there was an error, redirect to login with error
|
||||||
# If there was an error, redirect to login with error
|
return redirect("http://localhost:3000/login?error=oauth_failed")
|
||||||
return redirect("http://localhost:3000/login?error=oauth_failed")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e).replace(" ", "_").replace('"', "")
|
error_msg = str(e).replace(" ", "_").replace('"', "")
|
||||||
@@ -129,7 +128,7 @@ def refresh():
|
|||||||
def link_provider(provider):
|
def link_provider(provider):
|
||||||
"""Link a new OAuth provider to current user account."""
|
"""Link a new OAuth provider to current user account."""
|
||||||
redirect_uri = url_for(
|
redirect_uri = url_for(
|
||||||
"auth.link_callback", provider=provider, _external=True
|
"auth.link_callback", provider=provider, _external=True,
|
||||||
)
|
)
|
||||||
return auth_service.redirect_to_login(provider, redirect_uri)
|
return auth_service.redirect_to_login(provider, redirect_uri)
|
||||||
|
|
||||||
@@ -168,19 +167,19 @@ def link_callback(provider):
|
|||||||
|
|
||||||
if not provider_data.get("id"):
|
if not provider_data.get("id"):
|
||||||
return {
|
return {
|
||||||
"error": "Failed to get user information from provider"
|
"error": "Failed to get user information from provider",
|
||||||
}, 400
|
}, 400
|
||||||
|
|
||||||
# Check if this provider is already linked to another user
|
# Check if this provider is already linked to another user
|
||||||
from app.models.user_oauth import UserOAuth
|
from app.models.user_oauth import UserOAuth
|
||||||
|
|
||||||
existing_provider = UserOAuth.find_by_provider_and_id(
|
existing_provider = UserOAuth.find_by_provider_and_id(
|
||||||
provider, provider_data["id"]
|
provider, provider_data["id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if existing_provider and existing_provider.user_id != user.id:
|
if existing_provider and existing_provider.user_id != user.id:
|
||||||
return {
|
return {
|
||||||
"error": "This provider account is already linked to another user"
|
"error": "This provider account is already linked to another user",
|
||||||
}, 409
|
}, 409
|
||||||
|
|
||||||
# Link the provider to current user
|
# Link the provider to current user
|
||||||
@@ -210,7 +209,6 @@ def unlink_provider(provider):
|
|||||||
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.models.user_oauth import UserOAuth
|
|
||||||
|
|
||||||
user = User.query.get(current_user_id)
|
user = User.query.get(current_user_id)
|
||||||
if not user:
|
if not user:
|
||||||
@@ -224,7 +222,7 @@ def unlink_provider(provider):
|
|||||||
oauth_provider = user.get_provider(provider)
|
oauth_provider = user.get_provider(provider)
|
||||||
if not oauth_provider:
|
if not oauth_provider:
|
||||||
return {
|
return {
|
||||||
"error": f"Provider '{provider}' not linked to this account"
|
"error": f"Provider '{provider}' not linked to this account",
|
||||||
}, 404
|
}, 404
|
||||||
|
|
||||||
db.session.delete(oauth_provider)
|
db.session.delete(oauth_provider)
|
||||||
@@ -279,6 +277,7 @@ def me():
|
|||||||
def update_profile():
|
def update_profile():
|
||||||
"""Update current user profile information."""
|
"""Update current user profile information."""
|
||||||
from flask import request
|
from flask import request
|
||||||
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
@@ -323,7 +322,7 @@ def update_profile():
|
|||||||
return {"message": "Profile updated successfully", "user": updated_user}
|
return {"message": "Profile updated successfully", "user": updated_user}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
return {"error": f"Failed to update profile: {str(e)}"}, 500
|
return {"error": f"Failed to update profile: {e!s}"}, 500
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/password", methods=["PUT"])
|
@bp.route("/password", methods=["PUT"])
|
||||||
@@ -331,9 +330,10 @@ def update_profile():
|
|||||||
def change_password():
|
def change_password():
|
||||||
"""Change or set user password."""
|
"""Change or set user password."""
|
||||||
from flask import request
|
from flask import request
|
||||||
|
from werkzeug.security import check_password_hash
|
||||||
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from werkzeug.security import check_password_hash
|
|
||||||
|
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data:
|
if not data:
|
||||||
@@ -365,7 +365,7 @@ def change_password():
|
|||||||
# User has a password AND logged in via password, require current password for verification
|
# User has a password AND logged in via password, require current password for verification
|
||||||
if not current_password:
|
if not current_password:
|
||||||
return {
|
return {
|
||||||
"error": "Current password is required to change password"
|
"error": "Current password is required to change password",
|
||||||
}, 400
|
}, 400
|
||||||
|
|
||||||
if not check_password_hash(user.password_hash, current_password):
|
if not check_password_hash(user.password_hash, current_password):
|
||||||
@@ -380,4 +380,4 @@ def change_password():
|
|||||||
return {"message": "Password updated successfully"}
|
return {"message": "Password updated successfully"}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
return {"error": f"Failed to update password: {str(e)}"}, 500
|
return {"error": f"Failed to update password: {e!s}"}, 500
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ from flask import Blueprint
|
|||||||
from app.services.decorators import (
|
from app.services.decorators import (
|
||||||
get_current_user,
|
get_current_user,
|
||||||
require_auth,
|
require_auth,
|
||||||
require_role,
|
|
||||||
require_credits,
|
require_credits,
|
||||||
|
require_role,
|
||||||
)
|
)
|
||||||
|
|
||||||
bp = Blueprint("main", __name__)
|
bp = Blueprint("main", __name__)
|
||||||
|
|||||||
@@ -89,10 +89,10 @@ class AuthService:
|
|||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(
|
access_token = self.token_service.generate_access_token(
|
||||||
jwt_user_data
|
jwt_user_data,
|
||||||
)
|
)
|
||||||
refresh_token = self.token_service.generate_refresh_token(
|
refresh_token = self.token_service.generate_refresh_token(
|
||||||
jwt_user_data
|
jwt_user_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create response and set HTTP-only cookies
|
# Create response and set HTTP-only cookies
|
||||||
@@ -100,7 +100,7 @@ class AuthService:
|
|||||||
{
|
{
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
"user": jwt_user_data,
|
"user": jwt_user_data,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set JWT cookies
|
# Set JWT cookies
|
||||||
@@ -149,7 +149,7 @@ class AuthService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def register_with_password(
|
def register_with_password(
|
||||||
self, email: str, password: str, name: str
|
self, email: str, password: str, name: str,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Register new user with email and password."""
|
"""Register new user with email and password."""
|
||||||
try:
|
try:
|
||||||
@@ -164,10 +164,10 @@ class AuthService:
|
|||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(
|
access_token = self.token_service.generate_access_token(
|
||||||
jwt_user_data
|
jwt_user_data,
|
||||||
)
|
)
|
||||||
refresh_token = self.token_service.generate_refresh_token(
|
refresh_token = self.token_service.generate_refresh_token(
|
||||||
jwt_user_data
|
jwt_user_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create response and set HTTP-only cookies
|
# Create response and set HTTP-only cookies
|
||||||
@@ -175,7 +175,7 @@ class AuthService:
|
|||||||
{
|
{
|
||||||
"message": "Registration successful",
|
"message": "Registration successful",
|
||||||
"user": jwt_user_data,
|
"user": jwt_user_data,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set JWT cookies
|
# Set JWT cookies
|
||||||
@@ -196,7 +196,7 @@ class AuthService:
|
|||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
response = jsonify(
|
response = jsonify(
|
||||||
{"error": "Invalid email, password or disabled account"}
|
{"error": "Invalid email, password or disabled account"},
|
||||||
)
|
)
|
||||||
response.status_code = 401
|
response.status_code = 401
|
||||||
return response
|
return response
|
||||||
@@ -216,7 +216,7 @@ class AuthService:
|
|||||||
{
|
{
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
"user": jwt_user_data,
|
"user": jwt_user_data,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set JWT cookies
|
# Set JWT cookies
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from functools import wraps
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request
|
from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request
|
||||||
|
|
||||||
|
|
||||||
def get_user_from_jwt() -> dict[str, Any] | None:
|
def get_user_from_jwt() -> dict[str, Any] | None:
|
||||||
@@ -109,7 +109,7 @@ def require_auth(f):
|
|||||||
if not user:
|
if not user:
|
||||||
return (
|
return (
|
||||||
jsonify(
|
jsonify(
|
||||||
{"error": "Authentication required (JWT or API token)"}
|
{"error": "Authentication required (JWT or API token)"},
|
||||||
),
|
),
|
||||||
401,
|
401,
|
||||||
)
|
)
|
||||||
@@ -133,8 +133,8 @@ def require_role(required_role: str):
|
|||||||
return (
|
return (
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"error": f"Access denied. {required_role.title()} role required"
|
"error": f"Access denied. {required_role.title()} role required",
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
403,
|
403,
|
||||||
)
|
)
|
||||||
@@ -152,8 +152,8 @@ def require_credits(credits_needed: int):
|
|||||||
def decorator(f):
|
def decorator(f):
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
from app.models.user import User
|
|
||||||
from app.database import db
|
from app.database import db
|
||||||
|
from app.models.user import User
|
||||||
|
|
||||||
# First check authentication
|
# First check authentication
|
||||||
user_data = get_current_user()
|
user_data = get_current_user()
|
||||||
@@ -170,8 +170,8 @@ def require_credits(credits_needed: int):
|
|||||||
return (
|
return (
|
||||||
jsonify(
|
jsonify(
|
||||||
{
|
{
|
||||||
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}"
|
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}",
|
||||||
}
|
},
|
||||||
),
|
),
|
||||||
402, # Payment Required status code
|
402, # Payment Required status code
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from authlib.integrations.flask_client import OAuth
|
from authlib.integrations.flask_client import OAuth
|
||||||
|
|
||||||
|
|
||||||
@@ -16,23 +17,19 @@ class OAuthProvider(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Provider name (e.g., 'google', 'github')."""
|
"""Provider name (e.g., 'google', 'github')."""
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def display_name(self) -> str:
|
def display_name(self) -> str:
|
||||||
"""Human-readable provider name (e.g., 'Google', 'GitHub')."""
|
"""Human-readable provider name (e.g., 'Google', 'GitHub')."""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_client_config(self) -> Dict[str, Any]:
|
def get_client_config(self) -> dict[str, Any]:
|
||||||
"""Return OAuth client configuration."""
|
"""Return OAuth client configuration."""
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Extract user information from OAuth token response."""
|
"""Extract user information from OAuth token response."""
|
||||||
pass
|
|
||||||
|
|
||||||
def get_client(self):
|
def get_client(self):
|
||||||
"""Get or create OAuth client."""
|
"""Get or create OAuth client."""
|
||||||
@@ -52,14 +49,14 @@ class OAuthProvider(ABC):
|
|||||||
return client.authorize_redirect(redirect_uri).location
|
return client.authorize_redirect(redirect_uri).location
|
||||||
|
|
||||||
def exchange_code_for_token(
|
def exchange_code_for_token(
|
||||||
self, code: str = None, redirect_uri: str = None
|
self, code: str = None, redirect_uri: str = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Exchange authorization code for access token."""
|
"""Exchange authorization code for access token."""
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
token = client.authorize_access_token()
|
token = client.authorize_access_token()
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
|
def normalize_user_data(self, user_info: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Normalize user data to common format."""
|
"""Normalize user data to common format."""
|
||||||
return {
|
return {
|
||||||
"id": user_info.get("id"),
|
"id": user_info.get("id"),
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from typing import Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from .base import OAuthProvider
|
from .base import OAuthProvider
|
||||||
|
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ class GitHubOAuthProvider(OAuthProvider):
|
|||||||
def display_name(self) -> str:
|
def display_name(self) -> str:
|
||||||
return "GitHub"
|
return "GitHub"
|
||||||
|
|
||||||
def get_client_config(self) -> Dict[str, Any]:
|
def get_client_config(self) -> dict[str, Any]:
|
||||||
"""Return GitHub OAuth client configuration."""
|
"""Return GitHub OAuth client configuration."""
|
||||||
return {
|
return {
|
||||||
"access_token_url": "https://github.com/login/oauth/access_token",
|
"access_token_url": "https://github.com/login/oauth/access_token",
|
||||||
@@ -22,7 +23,7 @@ class GitHubOAuthProvider(OAuthProvider):
|
|||||||
"client_kwargs": {"scope": "user:email"},
|
"client_kwargs": {"scope": "user:email"},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Extract user information from GitHub OAuth token response."""
|
"""Extract user information from GitHub OAuth token response."""
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from .base import OAuthProvider
|
from .base import OAuthProvider
|
||||||
|
|
||||||
@@ -14,14 +14,14 @@ class GoogleOAuthProvider(OAuthProvider):
|
|||||||
def display_name(self) -> str:
|
def display_name(self) -> str:
|
||||||
return "Google"
|
return "Google"
|
||||||
|
|
||||||
def get_client_config(self) -> Dict[str, Any]:
|
def get_client_config(self) -> dict[str, Any]:
|
||||||
"""Return Google OAuth client configuration."""
|
"""Return Google OAuth client configuration."""
|
||||||
return {
|
return {
|
||||||
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
|
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
|
||||||
"client_kwargs": {"scope": "openid email profile"},
|
"client_kwargs": {"scope": "openid email profile"},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Extract user information from Google OAuth token response."""
|
"""Extract user information from Google OAuth token response."""
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
user_info = client.userinfo(token=token)
|
user_info = client.userinfo(token=token)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, Optional
|
|
||||||
from authlib.integrations.flask_client import OAuth
|
from authlib.integrations.flask_client import OAuth
|
||||||
|
|
||||||
from .base import OAuthProvider
|
from .base import OAuthProvider
|
||||||
from .google import GoogleOAuthProvider
|
|
||||||
from .github import GitHubOAuthProvider
|
from .github import GitHubOAuthProvider
|
||||||
|
from .google import GoogleOAuthProvider
|
||||||
|
|
||||||
|
|
||||||
class OAuthProviderRegistry:
|
class OAuthProviderRegistry:
|
||||||
@@ -11,7 +12,7 @@ class OAuthProviderRegistry:
|
|||||||
|
|
||||||
def __init__(self, oauth: OAuth):
|
def __init__(self, oauth: OAuth):
|
||||||
self.oauth = oauth
|
self.oauth = oauth
|
||||||
self._providers: Dict[str, OAuthProvider] = {}
|
self._providers: dict[str, OAuthProvider] = {}
|
||||||
self._initialize_providers()
|
self._initialize_providers()
|
||||||
|
|
||||||
def _initialize_providers(self):
|
def _initialize_providers(self):
|
||||||
@@ -21,7 +22,7 @@ class OAuthProviderRegistry:
|
|||||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
|
||||||
if google_client_id and google_client_secret:
|
if google_client_id and google_client_secret:
|
||||||
self._providers["google"] = GoogleOAuthProvider(
|
self._providers["google"] = GoogleOAuthProvider(
|
||||||
self.oauth, google_client_id, google_client_secret
|
self.oauth, google_client_id, google_client_secret,
|
||||||
)
|
)
|
||||||
|
|
||||||
# GitHub OAuth
|
# GitHub OAuth
|
||||||
@@ -29,14 +30,14 @@ class OAuthProviderRegistry:
|
|||||||
github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
|
github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
|
||||||
if github_client_id and github_client_secret:
|
if github_client_id and github_client_secret:
|
||||||
self._providers["github"] = GitHubOAuthProvider(
|
self._providers["github"] = GitHubOAuthProvider(
|
||||||
self.oauth, github_client_id, github_client_secret
|
self.oauth, github_client_id, github_client_secret,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_provider(self, name: str) -> Optional[OAuthProvider]:
|
def get_provider(self, name: str) -> OAuthProvider | None:
|
||||||
"""Get OAuth provider by name."""
|
"""Get OAuth provider by name."""
|
||||||
return self._providers.get(name)
|
return self._providers.get(name)
|
||||||
|
|
||||||
def get_available_providers(self) -> Dict[str, OAuthProvider]:
|
def get_available_providers(self) -> dict[str, OAuthProvider]:
|
||||||
"""Get all available providers."""
|
"""Get all available providers."""
|
||||||
return self._providers.copy()
|
return self._providers.copy()
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Database migration script for Flask-Migrate."""
|
"""Database migration script for Flask-Migrate."""
|
||||||
|
|
||||||
import os
|
|
||||||
from flask.cli import FlaskGroup
|
from flask.cli import FlaskGroup
|
||||||
|
|
||||||
from app import create_app
|
from app import create_app
|
||||||
from app.database import db
|
from app.database import db
|
||||||
|
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
"""Tests for authentication routes with Flask-JWT-Extended."""
|
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app import create_app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client():
|
|
||||||
"""Create a test client for the Flask application."""
|
|
||||||
app = create_app()
|
|
||||||
app.config["TESTING"] = True
|
|
||||||
app.config["JWT_COOKIE_SECURE"] = False # Allow cookies in testing
|
|
||||||
with app.test_client() as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuthRoutesJWTExtended:
|
|
||||||
"""Test cases for authentication routes with Flask-JWT-Extended."""
|
|
||||||
|
|
||||||
@patch("app.routes.auth.auth_service.get_login_url")
|
|
||||||
def test_login_route(self, mock_get_login_url: Mock, client) -> None:
|
|
||||||
"""Test the login route."""
|
|
||||||
mock_get_login_url.return_value = (
|
|
||||||
"https://accounts.google.com/oauth/authorize?..."
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.get("/api/auth/login")
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.get_json()
|
|
||||||
assert "login_url" in data
|
|
||||||
assert (
|
|
||||||
data["login_url"]
|
|
||||||
== "https://accounts.google.com/oauth/authorize?..."
|
|
||||||
)
|
|
||||||
|
|
||||||
@patch("app.routes.auth.auth_service.handle_callback")
|
|
||||||
def test_callback_route_success(
|
|
||||||
self, mock_handle_callback: Mock, client
|
|
||||||
) -> None:
|
|
||||||
"""Test successful callback route."""
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.get_json.return_value = {
|
|
||||||
"message": "Login successful",
|
|
||||||
"user": {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
mock_handle_callback.return_value = mock_response
|
|
||||||
|
|
||||||
response = client.get("/api/auth/callback?code=test_code")
|
|
||||||
mock_handle_callback.assert_called_once()
|
|
||||||
|
|
||||||
@patch("app.routes.auth.auth_service.handle_callback")
|
|
||||||
def test_callback_route_error(
|
|
||||||
self, mock_handle_callback: Mock, client
|
|
||||||
) -> None:
|
|
||||||
"""Test callback route with error."""
|
|
||||||
mock_handle_callback.side_effect = Exception("OAuth error")
|
|
||||||
|
|
||||||
response = client.get("/api/auth/callback?code=test_code")
|
|
||||||
assert response.status_code == 400
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["error"] == "OAuth error"
|
|
||||||
|
|
||||||
@patch("app.routes.auth.auth_service.logout")
|
|
||||||
def test_logout_route(self, mock_logout: Mock, client) -> None:
|
|
||||||
"""Test logout route."""
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_response.get_json.return_value = {
|
|
||||||
"message": "Logged out successfully"
|
|
||||||
}
|
|
||||||
mock_logout.return_value = mock_response
|
|
||||||
|
|
||||||
response = client.get("/api/auth/logout")
|
|
||||||
mock_logout.assert_called_once()
|
|
||||||
|
|
||||||
def test_me_route_not_authenticated(self, client) -> None:
|
|
||||||
"""Test /me route when not authenticated."""
|
|
||||||
response = client.get("/api/auth/me")
|
|
||||||
assert response.status_code == 401
|
|
||||||
data = response.get_json()
|
|
||||||
assert "msg" in data # Flask-JWT-Extended error format
|
|
||||||
|
|
||||||
def test_refresh_route_not_authenticated(self, client) -> None:
|
|
||||||
"""Test /refresh route when not authenticated."""
|
|
||||||
response = client.post("/api/auth/refresh")
|
|
||||||
assert response.status_code == 401
|
|
||||||
data = response.get_json()
|
|
||||||
assert "msg" in data # Flask-JWT-Extended error format
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
"""Tests for AuthService with Flask-JWT-Extended."""
|
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
from app import create_app
|
|
||||||
from app.services.auth_service import AuthService
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuthServiceJWTExtended:
|
|
||||||
"""Test cases for AuthService with Flask-JWT-Extended."""
|
|
||||||
|
|
||||||
def test_init_without_app(self) -> None:
|
|
||||||
"""Test initializing AuthService without Flask app."""
|
|
||||||
auth_service = AuthService()
|
|
||||||
assert auth_service.oauth is not None
|
|
||||||
assert auth_service.google is None
|
|
||||||
assert auth_service.token_service is not None
|
|
||||||
|
|
||||||
@patch("app.services.auth_service.os.getenv")
|
|
||||||
def test_init_app(self, mock_getenv: Mock) -> None:
|
|
||||||
"""Test initializing AuthService with Flask app."""
|
|
||||||
mock_getenv.side_effect = lambda key: {
|
|
||||||
"GOOGLE_CLIENT_ID": "test_client_id",
|
|
||||||
"GOOGLE_CLIENT_SECRET": "test_client_secret",
|
|
||||||
}.get(key)
|
|
||||||
|
|
||||||
app = create_app()
|
|
||||||
auth_service = AuthService()
|
|
||||||
auth_service.init_app(app)
|
|
||||||
|
|
||||||
# Verify OAuth was initialized
|
|
||||||
assert auth_service.google is not None
|
|
||||||
|
|
||||||
@patch("app.services.auth_service.unset_jwt_cookies")
|
|
||||||
@patch("app.services.auth_service.jsonify")
|
|
||||||
def test_logout(self, mock_jsonify: Mock, mock_unset: Mock) -> None:
|
|
||||||
"""Test logout functionality."""
|
|
||||||
app = create_app()
|
|
||||||
with app.app_context():
|
|
||||||
mock_response = Mock()
|
|
||||||
mock_jsonify.return_value = mock_response
|
|
||||||
|
|
||||||
auth_service = AuthService()
|
|
||||||
result = auth_service.logout()
|
|
||||||
|
|
||||||
assert result == mock_response
|
|
||||||
mock_unset.assert_called_once_with(mock_response)
|
|
||||||
mock_jsonify.assert_called_once_with(
|
|
||||||
{"message": "Logged out successfully"}
|
|
||||||
)
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
"""Tests for routes."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app import create_app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client():
|
|
||||||
"""Create a test client for the Flask application."""
|
|
||||||
app = create_app()
|
|
||||||
app.config["TESTING"] = True
|
|
||||||
with app.test_client() as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
class TestMainRoutes:
|
|
||||||
"""Test cases for main routes."""
|
|
||||||
|
|
||||||
def test_index_route(self, client) -> None:
|
|
||||||
"""Test the index route."""
|
|
||||||
response = client.get("/api/")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.get_json() == {
|
|
||||||
"message": "API is running",
|
|
||||||
"status": "ok",
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_health_route(self, client) -> None:
|
|
||||||
"""Test health check route."""
|
|
||||||
response = client.get("/api/health")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.get_json() == {"status": "ok"}
|
|
||||||
|
|
||||||
def test_protected_route_without_auth(self, client) -> None:
|
|
||||||
"""Test protected route without authentication."""
|
|
||||||
response = client.get("/api/protected")
|
|
||||||
assert response.status_code == 401
|
|
||||||
data = response.get_json()
|
|
||||||
assert data["error"] == "Authentication required (JWT or API token)"
|
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
"""Tests for TokenService."""
|
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.services.token_service import TokenService
|
|
||||||
|
|
||||||
|
|
||||||
class TestTokenService:
|
|
||||||
"""Test cases for TokenService."""
|
|
||||||
|
|
||||||
def test_init(self) -> None:
|
|
||||||
"""Test TokenService initialization."""
|
|
||||||
token_service = TokenService()
|
|
||||||
assert token_service.algorithm == "HS256"
|
|
||||||
assert token_service.access_token_expire_minutes == 15
|
|
||||||
assert token_service.refresh_token_expire_days == 7
|
|
||||||
|
|
||||||
def test_generate_access_token(self) -> None:
|
|
||||||
"""Test access token generation."""
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
|
||||||
assert isinstance(token, str)
|
|
||||||
|
|
||||||
# Verify token content
|
|
||||||
payload = jwt.decode(
|
|
||||||
token,
|
|
||||||
token_service.secret_key,
|
|
||||||
algorithms=[token_service.algorithm],
|
|
||||||
)
|
|
||||||
assert payload["user_id"] == "123"
|
|
||||||
assert payload["email"] == "test@example.com"
|
|
||||||
assert payload["name"] == "Test User"
|
|
||||||
assert payload["type"] == "access"
|
|
||||||
|
|
||||||
def test_generate_refresh_token(self) -> None:
|
|
||||||
"""Test refresh token generation."""
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_refresh_token(user_data)
|
|
||||||
assert isinstance(token, str)
|
|
||||||
|
|
||||||
# Verify token content
|
|
||||||
payload = jwt.decode(
|
|
||||||
token,
|
|
||||||
token_service.secret_key,
|
|
||||||
algorithms=[token_service.algorithm],
|
|
||||||
)
|
|
||||||
assert payload["user_id"] == "123"
|
|
||||||
assert payload["type"] == "refresh"
|
|
||||||
|
|
||||||
def test_verify_valid_token(self) -> None:
|
|
||||||
"""Test verifying a valid token."""
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
|
||||||
payload = token_service.verify_token(token)
|
|
||||||
|
|
||||||
assert payload is not None
|
|
||||||
assert payload["user_id"] == "123"
|
|
||||||
assert payload["type"] == "access"
|
|
||||||
|
|
||||||
def test_verify_invalid_token(self) -> None:
|
|
||||||
"""Test verifying an invalid token."""
|
|
||||||
token_service = TokenService()
|
|
||||||
|
|
||||||
payload = token_service.verify_token("invalid.token.here")
|
|
||||||
assert payload is None
|
|
||||||
|
|
||||||
@patch("app.services.token_service.datetime")
|
|
||||||
def test_verify_expired_token(self, mock_datetime) -> None:
|
|
||||||
"""Test verifying an expired token."""
|
|
||||||
# Set up mock to return a past time for token generation
|
|
||||||
past_time = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
|
||||||
mock_datetime.now.return_value = past_time
|
|
||||||
mock_datetime.UTC = timezone.utc
|
|
||||||
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
|
||||||
|
|
||||||
# Reset mock to current time for verification
|
|
||||||
mock_datetime.now.return_value = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
payload = token_service.verify_token(token)
|
|
||||||
assert payload is None
|
|
||||||
|
|
||||||
def test_is_access_token(self) -> None:
|
|
||||||
"""Test access token type checking."""
|
|
||||||
token_service = TokenService()
|
|
||||||
|
|
||||||
access_payload = {"type": "access", "user_id": "123"}
|
|
||||||
refresh_payload = {"type": "refresh", "user_id": "123"}
|
|
||||||
|
|
||||||
assert token_service.is_access_token(access_payload)
|
|
||||||
assert not token_service.is_access_token(refresh_payload)
|
|
||||||
|
|
||||||
def test_is_refresh_token(self) -> None:
|
|
||||||
"""Test refresh token type checking."""
|
|
||||||
token_service = TokenService()
|
|
||||||
|
|
||||||
access_payload = {"type": "access", "user_id": "123"}
|
|
||||||
refresh_payload = {"type": "refresh", "user_id": "123"}
|
|
||||||
|
|
||||||
assert token_service.is_refresh_token(refresh_payload)
|
|
||||||
assert not token_service.is_refresh_token(access_payload)
|
|
||||||
|
|
||||||
def test_get_user_from_access_token_valid(self) -> None:
|
|
||||||
"""Test extracting user from valid access token."""
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
|
||||||
extracted_user = token_service.get_user_from_access_token(token)
|
|
||||||
|
|
||||||
assert extracted_user == user_data
|
|
||||||
|
|
||||||
def test_get_user_from_access_token_refresh_token(self) -> None:
|
|
||||||
"""Test extracting user from refresh token (should fail)."""
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_refresh_token(user_data)
|
|
||||||
extracted_user = token_service.get_user_from_access_token(token)
|
|
||||||
|
|
||||||
assert extracted_user is None
|
|
||||||
|
|
||||||
def test_get_user_from_access_token_invalid(self) -> None:
|
|
||||||
"""Test extracting user from invalid token."""
|
|
||||||
token_service = TokenService()
|
|
||||||
|
|
||||||
extracted_user = token_service.get_user_from_access_token(
|
|
||||||
"invalid.token"
|
|
||||||
)
|
|
||||||
assert extracted_user is None
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
"""Tests for TokenService using Flask-JWT-Extended."""
|
|
||||||
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from app import create_app
|
|
||||||
from app.services.token_service import TokenService
|
|
||||||
|
|
||||||
|
|
||||||
class TestTokenServiceJWTExtended:
|
|
||||||
"""Test cases for TokenService with Flask-JWT-Extended."""
|
|
||||||
|
|
||||||
def test_generate_access_token(self) -> None:
|
|
||||||
"""Test access token generation."""
|
|
||||||
app = create_app()
|
|
||||||
with app.app_context():
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
"picture": "https://example.com/pic.jpg",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
|
||||||
assert isinstance(token, str)
|
|
||||||
assert len(token) > 0
|
|
||||||
|
|
||||||
def test_generate_refresh_token(self) -> None:
|
|
||||||
"""Test refresh token generation."""
|
|
||||||
app = create_app()
|
|
||||||
with app.app_context():
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
token = token_service.generate_refresh_token(user_data)
|
|
||||||
assert isinstance(token, str)
|
|
||||||
assert len(token) > 0
|
|
||||||
|
|
||||||
def test_generate_tokens_different(self) -> None:
|
|
||||||
"""Test that access and refresh tokens are different."""
|
|
||||||
app = create_app()
|
|
||||||
with app.app_context():
|
|
||||||
token_service = TokenService()
|
|
||||||
user_data = {
|
|
||||||
"id": "123",
|
|
||||||
"email": "test@example.com",
|
|
||||||
"name": "Test User",
|
|
||||||
}
|
|
||||||
|
|
||||||
access_token = token_service.generate_access_token(user_data)
|
|
||||||
refresh_token = token_service.generate_refresh_token(user_data)
|
|
||||||
|
|
||||||
assert access_token != refresh_token
|
|
||||||
Reference in New Issue
Block a user