Refactor code for improved readability and consistency
- Cleaned up whitespace and formatting across multiple files for better readability.
This commit is contained in:
@@ -11,8 +11,8 @@ def init_db(app):
|
|||||||
"""Initialize database with Flask app."""
|
"""Initialize database with Flask app."""
|
||||||
db.init_app(app)
|
db.init_app(app)
|
||||||
migrate.init_app(app, db)
|
migrate.init_app(app, db)
|
||||||
|
|
||||||
# Import models here to ensure they are registered with SQLAlchemy
|
# Import models here to ensure they are registered with SQLAlchemy
|
||||||
from app.models import user, user_oauth # noqa: F401
|
from app.models import user, user_oauth # noqa: F401
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ def init_database():
|
|||||||
"""Initialize database tables and seed with default data."""
|
"""Initialize database tables and seed with default data."""
|
||||||
# Create all tables
|
# Create all tables
|
||||||
db.create_all()
|
db.create_all()
|
||||||
|
|
||||||
# Seed plans if they don't exist
|
# Seed plans if they don't exist
|
||||||
seed_plans()
|
seed_plans()
|
||||||
|
|
||||||
# Migrate existing users to have plans
|
# Migrate existing users to have plans
|
||||||
migrate_users_to_plans()
|
migrate_users_to_plans()
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ def seed_plans():
|
|||||||
# Check if plans already exist
|
# Check if plans already exist
|
||||||
if Plan.query.count() > 0:
|
if Plan.query.count() > 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create default plans
|
# Create default plans
|
||||||
plans_data = [
|
plans_data = [
|
||||||
{
|
{
|
||||||
@@ -46,11 +46,11 @@ def seed_plans():
|
|||||||
"max_credits": 300,
|
"max_credits": 300,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for plan_data in plans_data:
|
for plan_data in plans_data:
|
||||||
plan = Plan(**plan_data)
|
plan = Plan(**plan_data)
|
||||||
db.session.add(plan)
|
db.session.add(plan)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
print(f"Seeded {len(plans_data)} plans into database")
|
print(f"Seeded {len(plans_data)} plans into database")
|
||||||
|
|
||||||
@@ -58,11 +58,11 @@ def seed_plans():
|
|||||||
def migrate_users_to_plans():
|
def migrate_users_to_plans():
|
||||||
"""Assign plans to existing users who don't have one."""
|
"""Assign plans to existing users who don't have one."""
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Find users without plans
|
# Find users without plans
|
||||||
users_without_plans = User.query.filter(User.plan_id.is_(None)).all()
|
users_without_plans = User.query.filter(User.plan_id.is_(None)).all()
|
||||||
|
|
||||||
# Find users with plans but NULL credits (only if credits column exists)
|
# Find users with plans but NULL credits (only if credits column exists)
|
||||||
# Note: We only migrate users with NULL credits, not 0 credits
|
# Note: We only migrate users with NULL credits, not 0 credits
|
||||||
# 0 credits means they spent them, NULL means they never got assigned
|
# 0 credits means they spent them, NULL means they never got assigned
|
||||||
@@ -73,19 +73,19 @@ def migrate_users_to_plans():
|
|||||||
except Exception:
|
except Exception:
|
||||||
# Credits column doesn't exist yet, will be handled by create_all
|
# Credits column doesn't exist yet, will be handled by create_all
|
||||||
users_without_credits = []
|
users_without_credits = []
|
||||||
|
|
||||||
if not users_without_plans and not users_without_credits:
|
if not users_without_plans and not users_without_credits:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Get default and pro plans
|
# Get default and pro plans
|
||||||
default_plan = Plan.get_default_plan()
|
default_plan = Plan.get_default_plan()
|
||||||
pro_plan = Plan.get_pro_plan()
|
pro_plan = Plan.get_pro_plan()
|
||||||
|
|
||||||
# Get the first user (admin) from all users ordered by ID
|
# Get the first user (admin) from all users ordered by ID
|
||||||
first_user = User.query.order_by(User.id).first()
|
first_user = User.query.order_by(User.id).first()
|
||||||
|
|
||||||
updated_count = 0
|
updated_count = 0
|
||||||
|
|
||||||
# Assign plans to users without plans
|
# Assign plans to users without plans
|
||||||
for user in users_without_plans:
|
for user in users_without_plans:
|
||||||
# First user gets pro plan, others get free plan
|
# First user gets pro plan, others get free plan
|
||||||
@@ -104,17 +104,19 @@ def migrate_users_to_plans():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
updated_count += 1
|
updated_count += 1
|
||||||
|
|
||||||
# Assign credits to users with plans but no credits
|
# Assign credits to users with plans but no credits
|
||||||
for user in users_without_credits:
|
for user in users_without_credits:
|
||||||
user.credits = user.plan.credits
|
user.credits = user.plan.credits
|
||||||
updated_count += 1
|
updated_count += 1
|
||||||
|
|
||||||
if updated_count > 0:
|
if updated_count > 0:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
print(f"Updated {updated_count} existing users with plans and credits")
|
print(
|
||||||
|
f"Updated {updated_count} existing users with plans and credits"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# If there's any error (like missing columns), just skip migration
|
# If there's any error (like missing columns), just skip migration
|
||||||
# The database will be properly created by create_all()
|
# The database will be properly created by create_all()
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ from .plan import Plan
|
|||||||
from .user import User
|
from .user import User
|
||||||
from .user_oauth import UserOAuth
|
from .user_oauth import UserOAuth
|
||||||
|
|
||||||
__all__ = ["Plan", "User", "UserOAuth"]
|
__all__ = ["Plan", "User", "UserOAuth"]
|
||||||
|
|||||||
@@ -55,4 +55,4 @@ class Plan(db.Model):
|
|||||||
"description": self.description,
|
"description": self.description,
|
||||||
"credits": self.credits,
|
"credits": self.credits,
|
||||||
"max_credits": self.max_credits,
|
"max_credits": self.max_credits,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,54 +17,65 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class User(db.Model):
|
class User(db.Model):
|
||||||
"""User model for storing user information."""
|
"""User model for storing user information."""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
|
||||||
# Primary user information (can be updated from any connected provider)
|
# Primary user information (can be updated from any connected provider)
|
||||||
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||||
|
|
||||||
# Password authentication (optional - users can use OAuth instead)
|
# Password authentication (optional - users can use OAuth instead)
|
||||||
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
password_hash: Mapped[Optional[str]] = mapped_column(
|
||||||
|
String(255), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
# Role-based access control
|
# Role-based access control
|
||||||
role: Mapped[str] = mapped_column(String(50), nullable=False, default="user")
|
role: Mapped[str] = mapped_column(
|
||||||
|
String(50), nullable=False, default="user"
|
||||||
|
)
|
||||||
|
|
||||||
# User status
|
# User status
|
||||||
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
||||||
|
|
||||||
# Plan relationship
|
# Plan relationship
|
||||||
plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False)
|
plan_id: Mapped[int] = mapped_column(
|
||||||
|
Integer, ForeignKey("plans.id"), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
# User credits (populated from plan credits on creation)
|
# User credits (populated from plan credits on creation)
|
||||||
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
# API token for programmatic access
|
# API token for programmatic access
|
||||||
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(
|
||||||
|
DateTime, nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, nullable=False
|
DateTime, default=datetime.utcnow, nullable=False
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
DateTime,
|
||||||
|
default=datetime.utcnow,
|
||||||
|
onupdate=datetime.utcnow,
|
||||||
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
oauth_providers: Mapped[list["UserOAuth"]] = relationship(
|
||||||
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
"UserOAuth", back_populates="user", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
|
plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""String representation of User."""
|
"""String representation of User."""
|
||||||
provider_count = len(self.oauth_providers)
|
provider_count = len(self.oauth_providers)
|
||||||
return f"<User {self.email} ({provider_count} providers)>"
|
return f"<User {self.email} ({provider_count} providers)>"
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""Convert user to dictionary."""
|
"""Convert user to dictionary."""
|
||||||
# Build comprehensive providers list
|
# Build comprehensive providers list
|
||||||
@@ -73,7 +84,7 @@ class User(db.Model):
|
|||||||
providers.append("password")
|
providers.append("password")
|
||||||
if self.api_token:
|
if self.api_token:
|
||||||
providers.append("api_token")
|
providers.append("api_token")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": str(self.id),
|
"id": str(self.id),
|
||||||
"email": self.email,
|
"email": self.email,
|
||||||
@@ -82,25 +93,27 @@ class User(db.Model):
|
|||||||
"role": self.role,
|
"role": self.role,
|
||||||
"is_active": self.is_active,
|
"is_active": self.is_active,
|
||||||
"api_token": self.api_token,
|
"api_token": self.api_token,
|
||||||
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
"api_token_expires_at": self.api_token_expires_at.isoformat()
|
||||||
|
if self.api_token_expires_at
|
||||||
|
else None,
|
||||||
"providers": providers,
|
"providers": providers,
|
||||||
"plan": self.plan.to_dict() if self.plan else None,
|
"plan": self.plan.to_dict() if self.plan else None,
|
||||||
"credits": self.credits,
|
"credits": self.credits,
|
||||||
"created_at": self.created_at.isoformat(),
|
"created_at": self.created_at.isoformat(),
|
||||||
"updated_at": self.updated_at.isoformat(),
|
"updated_at": self.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_provider(self, provider_name: str) -> Optional["UserOAuth"]:
|
def get_provider(self, provider_name: str) -> Optional["UserOAuth"]:
|
||||||
"""Get specific OAuth provider for this user."""
|
"""Get specific OAuth provider for this user."""
|
||||||
for provider in self.oauth_providers:
|
for provider in self.oauth_providers:
|
||||||
if provider.provider == provider_name:
|
if provider.provider == provider_name:
|
||||||
return provider
|
return provider
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def has_provider(self, provider_name: str) -> bool:
|
def has_provider(self, provider_name: str) -> bool:
|
||||||
"""Check if user has specific OAuth provider connected."""
|
"""Check if user has specific OAuth provider connected."""
|
||||||
return self.get_provider(provider_name) is not None
|
return self.get_provider(provider_name) is not None
|
||||||
|
|
||||||
def update_from_provider(self, provider_data: dict) -> None:
|
def update_from_provider(self, provider_data: dict) -> None:
|
||||||
"""Update user info from provider data (email, name, picture)."""
|
"""Update user info from provider data (email, name, picture)."""
|
||||||
self.email = provider_data.get("email", self.email)
|
self.email = provider_data.get("email", self.email)
|
||||||
@@ -108,60 +121,60 @@ class User(db.Model):
|
|||||||
self.picture = provider_data.get("picture", self.picture)
|
self.picture = provider_data.get("picture", self.picture)
|
||||||
self.updated_at = datetime.utcnow()
|
self.updated_at = datetime.utcnow()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def set_password(self, password: str) -> None:
|
def set_password(self, password: str) -> None:
|
||||||
"""Hash and set user password."""
|
"""Hash and set user password."""
|
||||||
self.password_hash = generate_password_hash(password)
|
self.password_hash = generate_password_hash(password)
|
||||||
self.updated_at = datetime.utcnow()
|
self.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
def check_password(self, password: str) -> bool:
|
def check_password(self, password: str) -> bool:
|
||||||
"""Check if provided password matches user's password."""
|
"""Check if provided password matches user's password."""
|
||||||
if not self.password_hash:
|
if not self.password_hash:
|
||||||
return False
|
return False
|
||||||
return check_password_hash(self.password_hash, password)
|
return check_password_hash(self.password_hash, password)
|
||||||
|
|
||||||
def has_password(self) -> bool:
|
def has_password(self) -> bool:
|
||||||
"""Check if user has a password set."""
|
"""Check if user has a password set."""
|
||||||
return self.password_hash is not None
|
return self.password_hash is not None
|
||||||
|
|
||||||
def generate_api_token(self) -> str:
|
def generate_api_token(self) -> str:
|
||||||
"""Generate a new API token for the user."""
|
"""Generate a new API token for the user."""
|
||||||
self.api_token = secrets.token_urlsafe(32)
|
self.api_token = secrets.token_urlsafe(32)
|
||||||
self.api_token_expires_at = None # No expiration by default
|
self.api_token_expires_at = None # No expiration by default
|
||||||
self.updated_at = datetime.utcnow()
|
self.updated_at = datetime.utcnow()
|
||||||
return self.api_token
|
return self.api_token
|
||||||
|
|
||||||
def is_api_token_valid(self) -> bool:
|
def is_api_token_valid(self) -> bool:
|
||||||
"""Check if the user's API token is valid (exists and not expired)."""
|
"""Check if the user's API token is valid (exists and not expired)."""
|
||||||
if not self.api_token:
|
if not self.api_token:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.api_token_expires_at is None:
|
if self.api_token_expires_at is None:
|
||||||
return True # No expiration
|
return True # No expiration
|
||||||
|
|
||||||
return datetime.utcnow() < self.api_token_expires_at
|
return datetime.utcnow() < self.api_token_expires_at
|
||||||
|
|
||||||
def revoke_api_token(self) -> None:
|
def revoke_api_token(self) -> None:
|
||||||
"""Revoke the user's API token."""
|
"""Revoke the user's API token."""
|
||||||
self.api_token = None
|
self.api_token = None
|
||||||
self.api_token_expires_at = None
|
self.api_token_expires_at = None
|
||||||
self.updated_at = datetime.utcnow()
|
self.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
def activate(self) -> None:
|
def activate(self) -> None:
|
||||||
"""Activate the user account."""
|
"""Activate the user account."""
|
||||||
self.is_active = True
|
self.is_active = True
|
||||||
self.updated_at = datetime.utcnow()
|
self.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
def deactivate(self) -> None:
|
def deactivate(self) -> None:
|
||||||
"""Deactivate the user account."""
|
"""Deactivate the user account."""
|
||||||
self.is_active = False
|
self.is_active = False
|
||||||
self.updated_at = datetime.utcnow()
|
self.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_email(cls, email: str) -> Optional["User"]:
|
def find_by_email(cls, email: str) -> Optional["User"]:
|
||||||
"""Find user by email address."""
|
"""Find user by email address."""
|
||||||
return cls.query.filter_by(email=email).first()
|
return cls.query.filter_by(email=email).first()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_api_token(cls, api_token: str) -> Optional["User"]:
|
def find_by_api_token(cls, api_token: str) -> Optional["User"]:
|
||||||
"""Find user by API token if token is valid."""
|
"""Find user by API token if token is valid."""
|
||||||
@@ -169,18 +182,25 @@ class User(db.Model):
|
|||||||
if user and user.is_api_token_valid():
|
if user and user.is_api_token_valid():
|
||||||
return user
|
return user
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_or_create_from_oauth(
|
def find_or_create_from_oauth(
|
||||||
cls, provider: str, provider_id: str, email: str, name: str, picture: Optional[str] = None
|
cls,
|
||||||
|
provider: str,
|
||||||
|
provider_id: str,
|
||||||
|
email: str,
|
||||||
|
name: str,
|
||||||
|
picture: Optional[str] = None,
|
||||||
) -> tuple["User", "UserOAuth"]:
|
) -> tuple["User", "UserOAuth"]:
|
||||||
"""Find existing user or create new one from OAuth data."""
|
"""Find existing user or create new one from OAuth data."""
|
||||||
from app.models.user_oauth import UserOAuth
|
from app.models.user_oauth import UserOAuth
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
|
|
||||||
# First, try to find existing OAuth provider
|
# First, try to find existing OAuth provider
|
||||||
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
|
oauth_provider = UserOAuth.find_by_provider_and_id(
|
||||||
|
provider, provider_id
|
||||||
|
)
|
||||||
|
|
||||||
if oauth_provider:
|
if oauth_provider:
|
||||||
# Update existing provider and user info
|
# Update existing provider and user info
|
||||||
user = oauth_provider.user
|
user = oauth_provider.user
|
||||||
@@ -188,24 +208,26 @@ class User(db.Model):
|
|||||||
oauth_provider.name = name
|
oauth_provider.name = name
|
||||||
oauth_provider.picture = picture
|
oauth_provider.picture = picture
|
||||||
oauth_provider.updated_at = datetime.utcnow()
|
oauth_provider.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
# Update user info with latest data
|
# Update user info with latest data
|
||||||
user.update_from_provider({"email": email, "name": name, "picture": picture})
|
user.update_from_provider(
|
||||||
|
{"email": email, "name": name, "picture": picture}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Try to find user by email to link the new provider
|
# Try to find user by email to link the new provider
|
||||||
user = cls.find_by_email(email)
|
user = cls.find_by_email(email)
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
# Check if this is the first user (admin with pro plan)
|
# Check if this is the first user (admin with pro plan)
|
||||||
user_count = cls.query.count()
|
user_count = cls.query.count()
|
||||||
role = "admin" if user_count == 0 else "user"
|
role = "admin" if user_count == 0 else "user"
|
||||||
|
|
||||||
# Assign plan: first user gets pro, others get free
|
# Assign plan: first user gets pro, others get free
|
||||||
if user_count == 0:
|
if user_count == 0:
|
||||||
plan = Plan.get_pro_plan()
|
plan = Plan.get_pro_plan()
|
||||||
else:
|
else:
|
||||||
plan = Plan.get_default_plan()
|
plan = Plan.get_default_plan()
|
||||||
|
|
||||||
# Create new user
|
# Create new user
|
||||||
user = cls(
|
user = cls(
|
||||||
email=email,
|
email=email,
|
||||||
@@ -218,7 +240,7 @@ class User(db.Model):
|
|||||||
user.generate_api_token() # Generate API token on creation
|
user.generate_api_token() # Generate API token on creation
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.flush() # Flush to get user.id
|
db.session.flush() # Flush to get user.id
|
||||||
|
|
||||||
# Create new OAuth provider
|
# Create new OAuth provider
|
||||||
oauth_provider = UserOAuth.create_or_update(
|
oauth_provider = UserOAuth.create_or_update(
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
@@ -228,30 +250,32 @@ class User(db.Model):
|
|||||||
name=name,
|
name=name,
|
||||||
picture=picture,
|
picture=picture,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return user, oauth_provider
|
return user, oauth_provider
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_with_password(cls, email: str, password: str, name: str) -> "User":
|
def create_with_password(
|
||||||
|
cls, email: str, password: str, name: str
|
||||||
|
) -> "User":
|
||||||
"""Create new user with email and password."""
|
"""Create new user with email and password."""
|
||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
|
|
||||||
# Check if user already exists
|
# Check if user already exists
|
||||||
existing_user = cls.find_by_email(email)
|
existing_user = cls.find_by_email(email)
|
||||||
if existing_user:
|
if existing_user:
|
||||||
raise ValueError("User with this email already exists")
|
raise ValueError("User with this email already exists")
|
||||||
|
|
||||||
# Check if this is the first user (admin with pro plan)
|
# Check if this is the first user (admin with pro plan)
|
||||||
user_count = cls.query.count()
|
user_count = cls.query.count()
|
||||||
role = "admin" if user_count == 0 else "user"
|
role = "admin" if user_count == 0 else "user"
|
||||||
|
|
||||||
# Assign plan: first user gets pro, others get free
|
# Assign plan: first user gets pro, others get free
|
||||||
if user_count == 0:
|
if user_count == 0:
|
||||||
plan = Plan.get_pro_plan()
|
plan = Plan.get_pro_plan()
|
||||||
else:
|
else:
|
||||||
plan = Plan.get_default_plan()
|
plan = Plan.get_default_plan()
|
||||||
|
|
||||||
# Create new user
|
# Create new user
|
||||||
user = cls(
|
user = cls(
|
||||||
email=email,
|
email=email,
|
||||||
@@ -262,15 +286,17 @@ class User(db.Model):
|
|||||||
)
|
)
|
||||||
user.set_password(password)
|
user.set_password(password)
|
||||||
user.generate_api_token() # Generate API token on creation
|
user.generate_api_token() # Generate API token on creation
|
||||||
|
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def authenticate_with_password(cls, email: str, password: str) -> Optional["User"]:
|
def authenticate_with_password(
|
||||||
|
cls, email: str, password: str
|
||||||
|
) -> Optional["User"]:
|
||||||
"""Authenticate user with email and password."""
|
"""Authenticate user with email and password."""
|
||||||
user = cls.find_by_email(email)
|
user = cls.find_by_email(email)
|
||||||
if user and user.check_password(password) and user.is_active:
|
if user and user.check_password(password) and user.is_active:
|
||||||
return user
|
return user
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -14,43 +14,50 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class UserOAuth(db.Model):
|
class UserOAuth(db.Model):
|
||||||
"""Model for storing user's connected OAuth providers."""
|
"""Model for storing user's connected OAuth providers."""
|
||||||
|
|
||||||
__tablename__ = "user_oauth"
|
__tablename__ = "user_oauth"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
|
||||||
# User relationship
|
# User relationship
|
||||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
|
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
|
||||||
|
|
||||||
# OAuth provider information
|
# OAuth provider information
|
||||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
provider_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
provider_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|
||||||
# Provider-specific user information
|
# Provider-specific user information
|
||||||
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
email: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, nullable=False
|
DateTime, default=datetime.utcnow, nullable=False
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
DateTime,
|
||||||
|
default=datetime.utcnow,
|
||||||
|
onupdate=datetime.utcnow,
|
||||||
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Unique constraint on provider + provider_id combination
|
# Unique constraint on provider + provider_id combination
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
db.UniqueConstraint("provider", "provider_id", name="unique_provider_user"),
|
db.UniqueConstraint(
|
||||||
|
"provider", "provider_id", name="unique_provider_user"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
user: Mapped["User"] = relationship("User", back_populates="oauth_providers")
|
user: Mapped["User"] = relationship(
|
||||||
|
"User", back_populates="oauth_providers"
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""String representation of UserOAuth."""
|
"""String representation of UserOAuth."""
|
||||||
return f"<UserOAuth {self.email} ({self.provider})>"
|
return f"<UserOAuth {self.email} ({self.provider})>"
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""Convert oauth provider to dictionary."""
|
"""Convert oauth provider to dictionary."""
|
||||||
return {
|
return {
|
||||||
@@ -63,25 +70,29 @@ class UserOAuth(db.Model):
|
|||||||
"created_at": self.created_at.isoformat(),
|
"created_at": self.created_at.isoformat(),
|
||||||
"updated_at": self.updated_at.isoformat(),
|
"updated_at": self.updated_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_provider_and_id(cls, provider: str, provider_id: str) -> Optional["UserOAuth"]:
|
def find_by_provider_and_id(
|
||||||
|
cls, provider: str, provider_id: str
|
||||||
|
) -> Optional["UserOAuth"]:
|
||||||
"""Find OAuth provider by provider name and provider ID."""
|
"""Find OAuth provider by provider name and provider ID."""
|
||||||
return cls.query.filter_by(provider=provider, provider_id=provider_id).first()
|
return cls.query.filter_by(
|
||||||
|
provider=provider, provider_id=provider_id
|
||||||
|
).first()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_or_update(
|
def create_or_update(
|
||||||
cls,
|
cls,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
provider: str,
|
provider: str,
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
email: str,
|
email: str,
|
||||||
name: str,
|
name: str,
|
||||||
picture: Optional[str] = None
|
picture: Optional[str] = None,
|
||||||
) -> "UserOAuth":
|
) -> "UserOAuth":
|
||||||
"""Create new OAuth provider or update existing one."""
|
"""Create new OAuth provider or update existing one."""
|
||||||
oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
|
oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
|
||||||
|
|
||||||
if oauth_provider:
|
if oauth_provider:
|
||||||
# Update existing provider
|
# Update existing provider
|
||||||
oauth_provider.user_id = user_id
|
oauth_provider.user_id = user_id
|
||||||
@@ -100,6 +111,6 @@ class UserOAuth(db.Model):
|
|||||||
picture=picture,
|
picture=picture,
|
||||||
)
|
)
|
||||||
db.session.add(oauth_provider)
|
db.session.add(oauth_provider)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return oauth_provider
|
return oauth_provider
|
||||||
|
|||||||
@@ -281,19 +281,19 @@ def update_profile():
|
|||||||
from flask import request
|
from flask import request
|
||||||
from app.database import db
|
from app.database import db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data:
|
if not data:
|
||||||
return {"error": "No data provided"}, 400
|
return {"error": "No data provided"}, 400
|
||||||
|
|
||||||
user_data = get_current_user()
|
user_data = get_current_user()
|
||||||
if not user_data:
|
if not user_data:
|
||||||
return {"error": "User not authenticated"}, 401
|
return {"error": "User not authenticated"}, 401
|
||||||
|
|
||||||
user = User.query.get(int(user_data["id"]))
|
user = User.query.get(int(user_data["id"]))
|
||||||
if not user:
|
if not user:
|
||||||
return {"error": "User not found"}, 404
|
return {"error": "User not found"}, 404
|
||||||
|
|
||||||
# Update allowed fields
|
# Update allowed fields
|
||||||
if "name" in data:
|
if "name" in data:
|
||||||
name = data["name"].strip()
|
name = data["name"].strip()
|
||||||
@@ -302,10 +302,10 @@ def update_profile():
|
|||||||
if len(name) > 100:
|
if len(name) > 100:
|
||||||
return {"error": "Name too long (max 100 characters)"}, 400
|
return {"error": "Name too long (max 100 characters)"}, 400
|
||||||
user.name = name
|
user.name = name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Return fresh user data from database
|
# Return fresh user data from database
|
||||||
updated_user = {
|
updated_user = {
|
||||||
"id": str(user.id),
|
"id": str(user.id),
|
||||||
@@ -319,11 +319,8 @@ def update_profile():
|
|||||||
"plan": user.plan.to_dict() if user.plan else None,
|
"plan": user.plan.to_dict() if user.plan else None,
|
||||||
"credits": user.credits,
|
"credits": user.credits,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {"message": "Profile updated successfully", "user": updated_user}
|
||||||
"message": "Profile updated successfully",
|
|
||||||
"user": updated_user
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
return {"error": f"Failed to update profile: {str(e)}"}, 500
|
return {"error": f"Failed to update profile: {str(e)}"}, 500
|
||||||
@@ -337,50 +334,50 @@ def change_password():
|
|||||||
from app.database import db
|
from app.database import db
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from werkzeug.security import check_password_hash
|
from werkzeug.security import check_password_hash
|
||||||
|
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
if not data:
|
if not data:
|
||||||
return {"error": "No data provided"}, 400
|
return {"error": "No data provided"}, 400
|
||||||
|
|
||||||
user_data = get_current_user()
|
user_data = get_current_user()
|
||||||
if not user_data:
|
if not user_data:
|
||||||
return {"error": "User not authenticated"}, 401
|
return {"error": "User not authenticated"}, 401
|
||||||
|
|
||||||
user = User.query.get(int(user_data["id"]))
|
user = User.query.get(int(user_data["id"]))
|
||||||
if not user:
|
if not user:
|
||||||
return {"error": "User not found"}, 404
|
return {"error": "User not found"}, 404
|
||||||
|
|
||||||
new_password = data.get("new_password")
|
new_password = data.get("new_password")
|
||||||
current_password = data.get("current_password")
|
current_password = data.get("current_password")
|
||||||
|
|
||||||
if not new_password:
|
if not new_password:
|
||||||
return {"error": "New password is required"}, 400
|
return {"error": "New password is required"}, 400
|
||||||
|
|
||||||
# Password validation
|
# Password validation
|
||||||
if len(new_password) < 6:
|
if len(new_password) < 6:
|
||||||
return {"error": "Password must be at least 6 characters long"}, 400
|
return {"error": "Password must be at least 6 characters long"}, 400
|
||||||
|
|
||||||
# Check authentication method: if user logged in via password, require current password
|
# Check authentication method: if user logged in via password, require current password
|
||||||
# If user logged in via OAuth, they can change password without current password
|
# If user logged in via OAuth, they can change password without current password
|
||||||
current_auth_method = user_data.get("provider", "unknown")
|
current_auth_method = user_data.get("provider", "unknown")
|
||||||
|
|
||||||
if user.password_hash and current_auth_method == "password":
|
if user.password_hash and current_auth_method == "password":
|
||||||
# User has a password AND logged in via password, require current password for verification
|
# User has a password AND logged in via password, require current password for verification
|
||||||
if not current_password:
|
if not current_password:
|
||||||
return {"error": "Current password is required to change password"}, 400
|
return {
|
||||||
|
"error": "Current password is required to change password"
|
||||||
|
}, 400
|
||||||
|
|
||||||
if not check_password_hash(user.password_hash, current_password):
|
if not check_password_hash(user.password_hash, current_password):
|
||||||
return {"error": "Current password is incorrect"}, 400
|
return {"error": "Current password is incorrect"}, 400
|
||||||
# If user logged in via OAuth (google, github, etc.), they can change password without current password
|
# If user logged in via OAuth (google, github, etc.), they can change password without current password
|
||||||
|
|
||||||
# Set the new password
|
# Set the new password
|
||||||
try:
|
try:
|
||||||
user.set_password(new_password)
|
user.set_password(new_password)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {
|
return {"message": "Password updated successfully"}
|
||||||
"message": "Password updated successfully"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
return {"error": f"Failed to update password: {str(e)}"}, 500
|
return {"error": f"Failed to update password: {str(e)}"}, 500
|
||||||
|
|||||||
@@ -2,7 +2,12 @@
|
|||||||
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
|
|
||||||
from app.services.decorators import get_current_user, require_auth, require_role, require_credits
|
from app.services.decorators import (
|
||||||
|
get_current_user,
|
||||||
|
require_auth,
|
||||||
|
require_role,
|
||||||
|
require_credits,
|
||||||
|
)
|
||||||
|
|
||||||
bp = Blueprint("main", __name__)
|
bp = Blueprint("main", __name__)
|
||||||
|
|
||||||
@@ -63,7 +68,8 @@ def use_credits(amount: int) -> dict[str, str]:
|
|||||||
return {
|
return {
|
||||||
"message": f"Successfully used endpoint! You requested amount: {amount}",
|
"message": f"Successfully used endpoint! You requested amount: {amount}",
|
||||||
"user": user["email"],
|
"user": user["email"],
|
||||||
"remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator
|
"remaining_credits": user["credits"]
|
||||||
|
- 5, # Note: credits already deducted by decorator
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,9 @@ class AuthService:
|
|||||||
|
|
||||||
# Prepare user data for JWT token using user.to_dict()
|
# Prepare user data for JWT token using user.to_dict()
|
||||||
jwt_user_data = user.to_dict()
|
jwt_user_data = user.to_dict()
|
||||||
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
|
jwt_user_data["provider"] = (
|
||||||
|
oauth_provider.provider
|
||||||
|
) # Override provider for OAuth login
|
||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(
|
access_token = self.token_service.generate_access_token(
|
||||||
@@ -156,7 +158,9 @@ class AuthService:
|
|||||||
|
|
||||||
# Prepare user data for JWT token using user.to_dict()
|
# Prepare user data for JWT token using user.to_dict()
|
||||||
jwt_user_data = user.to_dict()
|
jwt_user_data = user.to_dict()
|
||||||
jwt_user_data["provider"] = "password" # Override provider for password registration
|
jwt_user_data["provider"] = (
|
||||||
|
"password" # Override provider for password registration
|
||||||
|
)
|
||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(
|
access_token = self.token_service.generate_access_token(
|
||||||
@@ -199,7 +203,9 @@ class AuthService:
|
|||||||
|
|
||||||
# Prepare user data for JWT token using user.to_dict()
|
# Prepare user data for JWT token using user.to_dict()
|
||||||
jwt_user_data = user.to_dict()
|
jwt_user_data = user.to_dict()
|
||||||
jwt_user_data["provider"] = "password" # Override provider for password login
|
jwt_user_data["provider"] = (
|
||||||
|
"password" # Override provider for password login
|
||||||
|
)
|
||||||
|
|
||||||
# Generate JWT tokens
|
# Generate JWT tokens
|
||||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ def get_user_from_jwt() -> dict[str, Any] | None:
|
|||||||
try:
|
try:
|
||||||
# Try to verify JWT token in request - this sets up the context
|
# Try to verify JWT token in request - this sets up the context
|
||||||
verify_jwt_in_request()
|
verify_jwt_in_request()
|
||||||
|
|
||||||
current_user_id = get_jwt_identity()
|
current_user_id = get_jwt_identity()
|
||||||
if not current_user_id:
|
if not current_user_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Query database for user data instead of using JWT claims
|
# Query database for user data instead of using JWT claims
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
||||||
user = User.query.get(int(current_user_id))
|
user = User.query.get(int(current_user_id))
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
return None
|
return None
|
||||||
@@ -70,7 +70,7 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
|||||||
providers.append("password")
|
providers.append("password")
|
||||||
if user.api_token:
|
if user.api_token:
|
||||||
providers.append("api_token")
|
providers.append("api_token")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": str(user.id),
|
"id": str(user.id),
|
||||||
"email": user.email,
|
"email": user.email,
|
||||||
@@ -148,22 +148,23 @@ def require_role(required_role: str):
|
|||||||
|
|
||||||
def require_credits(credits_needed: int):
|
def require_credits(credits_needed: int):
|
||||||
"""Decorator to require and deduct credits for routes."""
|
"""Decorator to require and deduct credits for routes."""
|
||||||
|
|
||||||
def decorator(f):
|
def decorator(f):
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.database import db
|
from app.database import db
|
||||||
|
|
||||||
# First check authentication
|
# First check authentication
|
||||||
user_data = get_current_user()
|
user_data = get_current_user()
|
||||||
if not user_data:
|
if not user_data:
|
||||||
return jsonify({"error": "Authentication required"}), 401
|
return jsonify({"error": "Authentication required"}), 401
|
||||||
|
|
||||||
# Get the actual user from database to check/update credits
|
# Get the actual user from database to check/update credits
|
||||||
user = User.query.get(int(user_data["id"]))
|
user = User.query.get(int(user_data["id"]))
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
return jsonify({"error": "User not found or inactive"}), 401
|
return jsonify({"error": "User not found or inactive"}), 401
|
||||||
|
|
||||||
# Check if user has enough credits
|
# Check if user has enough credits
|
||||||
if user.credits < credits_needed:
|
if user.credits < credits_needed:
|
||||||
return (
|
return (
|
||||||
@@ -174,15 +175,16 @@ def require_credits(credits_needed: int):
|
|||||||
),
|
),
|
||||||
402, # Payment Required status code
|
402, # Payment Required status code
|
||||||
)
|
)
|
||||||
|
|
||||||
# Deduct credits
|
# Deduct credits
|
||||||
user.credits -= credits_needed
|
user.credits -= credits_needed
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
# Execute the function
|
# Execute the function
|
||||||
result = f(*args, **kwargs)
|
result = f(*args, **kwargs)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|||||||
@@ -5,35 +5,35 @@ from authlib.integrations.flask_client import OAuth
|
|||||||
|
|
||||||
class OAuthProvider(ABC):
|
class OAuthProvider(ABC):
|
||||||
"""Abstract base class for OAuth providers."""
|
"""Abstract base class for OAuth providers."""
|
||||||
|
|
||||||
def __init__(self, oauth: OAuth, client_id: str, client_secret: str):
|
def __init__(self, oauth: OAuth, client_id: str, client_secret: str):
|
||||||
self.oauth = oauth
|
self.oauth = oauth
|
||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.client_secret = client_secret
|
self.client_secret = client_secret
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Provider name (e.g., 'google', 'github')."""
|
"""Provider name (e.g., 'google', 'github')."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def display_name(self) -> str:
|
def display_name(self) -> str:
|
||||||
"""Human-readable provider name (e.g., 'Google', 'GitHub')."""
|
"""Human-readable provider name (e.g., 'Google', 'GitHub')."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_client_config(self) -> Dict[str, Any]:
|
def get_client_config(self) -> Dict[str, Any]:
|
||||||
"""Return OAuth client configuration."""
|
"""Return OAuth client configuration."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Extract user information from OAuth token response."""
|
"""Extract user information from OAuth token response."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_client(self):
|
def get_client(self):
|
||||||
"""Get or create OAuth client."""
|
"""Get or create OAuth client."""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
@@ -42,27 +42,29 @@ class OAuthProvider(ABC):
|
|||||||
name=self.name,
|
name=self.name,
|
||||||
client_id=self.client_id,
|
client_id=self.client_id,
|
||||||
client_secret=self.client_secret,
|
client_secret=self.client_secret,
|
||||||
**config
|
**config,
|
||||||
)
|
)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def get_authorization_url(self, redirect_uri: str) -> str:
|
def get_authorization_url(self, redirect_uri: str) -> str:
|
||||||
"""Generate authorization URL for OAuth flow."""
|
"""Generate authorization URL for OAuth flow."""
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
return client.authorize_redirect(redirect_uri).location
|
return client.authorize_redirect(redirect_uri).location
|
||||||
|
|
||||||
def exchange_code_for_token(self, code: str = None, redirect_uri: str = None) -> Dict[str, Any]:
|
def exchange_code_for_token(
|
||||||
|
self, code: str = None, redirect_uri: str = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Exchange authorization code for access token."""
|
"""Exchange authorization code for access token."""
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
token = client.authorize_access_token()
|
token = client.authorize_access_token()
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
|
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Normalize user data to common format."""
|
"""Normalize user data to common format."""
|
||||||
return {
|
return {
|
||||||
'id': user_info.get('id'),
|
"id": user_info.get("id"),
|
||||||
'email': user_info.get('email'),
|
"email": user_info.get("email"),
|
||||||
'name': user_info.get('name'),
|
"name": user_info.get("name"),
|
||||||
'picture': user_info.get('picture'),
|
"picture": user_info.get("picture"),
|
||||||
'provider': self.name
|
"provider": self.name,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,49 +4,47 @@ from .base import OAuthProvider
|
|||||||
|
|
||||||
class GitHubOAuthProvider(OAuthProvider):
|
class GitHubOAuthProvider(OAuthProvider):
|
||||||
"""GitHub OAuth provider implementation."""
|
"""GitHub OAuth provider implementation."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return 'github'
|
return "github"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def display_name(self) -> str:
|
def display_name(self) -> str:
|
||||||
return 'GitHub'
|
return "GitHub"
|
||||||
|
|
||||||
def get_client_config(self) -> Dict[str, Any]:
|
def get_client_config(self) -> Dict[str, Any]:
|
||||||
"""Return GitHub OAuth client configuration."""
|
"""Return GitHub OAuth client configuration."""
|
||||||
return {
|
return {
|
||||||
'access_token_url': 'https://github.com/login/oauth/access_token',
|
"access_token_url": "https://github.com/login/oauth/access_token",
|
||||||
'authorize_url': 'https://github.com/login/oauth/authorize',
|
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||||
'api_base_url': 'https://api.github.com/',
|
"api_base_url": "https://api.github.com/",
|
||||||
'client_kwargs': {
|
"client_kwargs": {"scope": "user:email"},
|
||||||
'scope': 'user:email'
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Extract user information from GitHub OAuth token response."""
|
"""Extract user information from GitHub OAuth token response."""
|
||||||
client = self.get_client()
|
client = self.get_client()
|
||||||
|
|
||||||
# Get user profile
|
# Get user profile
|
||||||
user_resp = client.get('user', token=token)
|
user_resp = client.get("user", token=token)
|
||||||
user_data = user_resp.json()
|
user_data = user_resp.json()
|
||||||
|
|
||||||
# Get user email (may be private)
|
# Get user email (may be private)
|
||||||
email = user_data.get('email')
|
email = user_data.get("email")
|
||||||
if not email:
|
if not email:
|
||||||
# If email is private, get from emails endpoint
|
# If email is private, get from emails endpoint
|
||||||
emails_resp = client.get('user/emails', token=token)
|
emails_resp = client.get("user/emails", token=token)
|
||||||
emails = emails_resp.json()
|
emails = emails_resp.json()
|
||||||
# Find primary email
|
# Find primary email
|
||||||
for email_obj in emails:
|
for email_obj in emails:
|
||||||
if email_obj.get('primary', False):
|
if email_obj.get("primary", False):
|
||||||
email = email_obj.get('email')
|
email = email_obj.get("email")
|
||||||
break
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'id': str(user_data.get('id')),
|
"id": str(user_data.get("id")),
|
||||||
'email': email,
|
"email": email,
|
||||||
'name': user_data.get('name') or user_data.get('login'),
|
"name": user_data.get("name") or user_data.get("login"),
|
||||||
'picture': user_data.get('avatar_url')
|
"picture": user_data.get("avatar_url"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,38 +8,38 @@ from .github import GitHubOAuthProvider
|
|||||||
|
|
||||||
class OAuthProviderRegistry:
|
class OAuthProviderRegistry:
|
||||||
"""Registry for OAuth providers."""
|
"""Registry for OAuth providers."""
|
||||||
|
|
||||||
def __init__(self, oauth: OAuth):
|
def __init__(self, oauth: OAuth):
|
||||||
self.oauth = oauth
|
self.oauth = oauth
|
||||||
self._providers: Dict[str, OAuthProvider] = {}
|
self._providers: Dict[str, OAuthProvider] = {}
|
||||||
self._initialize_providers()
|
self._initialize_providers()
|
||||||
|
|
||||||
def _initialize_providers(self):
|
def _initialize_providers(self):
|
||||||
"""Initialize available providers based on environment variables."""
|
"""Initialize available providers based on environment variables."""
|
||||||
# Google OAuth
|
# Google OAuth
|
||||||
google_client_id = os.getenv('GOOGLE_CLIENT_ID')
|
google_client_id = os.getenv("GOOGLE_CLIENT_ID")
|
||||||
google_client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
|
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
|
||||||
if google_client_id and google_client_secret:
|
if google_client_id and google_client_secret:
|
||||||
self._providers['google'] = GoogleOAuthProvider(
|
self._providers["google"] = GoogleOAuthProvider(
|
||||||
self.oauth, google_client_id, google_client_secret
|
self.oauth, google_client_id, google_client_secret
|
||||||
)
|
)
|
||||||
|
|
||||||
# GitHub OAuth
|
# GitHub OAuth
|
||||||
github_client_id = os.getenv('GITHUB_CLIENT_ID')
|
github_client_id = os.getenv("GITHUB_CLIENT_ID")
|
||||||
github_client_secret = os.getenv('GITHUB_CLIENT_SECRET')
|
github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
|
||||||
if github_client_id and github_client_secret:
|
if github_client_id and github_client_secret:
|
||||||
self._providers['github'] = GitHubOAuthProvider(
|
self._providers["github"] = GitHubOAuthProvider(
|
||||||
self.oauth, github_client_id, github_client_secret
|
self.oauth, github_client_id, github_client_secret
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_provider(self, name: str) -> Optional[OAuthProvider]:
|
def get_provider(self, name: str) -> Optional[OAuthProvider]:
|
||||||
"""Get OAuth provider by name."""
|
"""Get OAuth provider by name."""
|
||||||
return self._providers.get(name)
|
return self._providers.get(name)
|
||||||
|
|
||||||
def get_available_providers(self) -> Dict[str, OAuthProvider]:
|
def get_available_providers(self) -> Dict[str, OAuthProvider]:
|
||||||
"""Get all available providers."""
|
"""Get all available providers."""
|
||||||
return self._providers.copy()
|
return self._providers.copy()
|
||||||
|
|
||||||
def is_provider_available(self, name: str) -> bool:
|
def is_provider_available(self, name: str) -> bool:
|
||||||
"""Check if provider is available."""
|
"""Check if provider is available."""
|
||||||
return name in self._providers
|
return name in self._providers
|
||||||
|
|||||||
@@ -9,22 +9,27 @@ from app.database import db
|
|||||||
app = create_app()
|
app = create_app()
|
||||||
cli = FlaskGroup(app)
|
cli = FlaskGroup(app)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def init_db():
|
def init_db():
|
||||||
"""Initialize the database."""
|
"""Initialize the database."""
|
||||||
print("Initializing database...")
|
print("Initializing database...")
|
||||||
from app.database_init import init_database
|
from app.database_init import init_database
|
||||||
|
|
||||||
init_database()
|
init_database()
|
||||||
print("Database initialized successfully!")
|
print("Database initialized successfully!")
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
def reset_db():
|
def reset_db():
|
||||||
"""Reset the database (drop all tables and recreate)."""
|
"""Reset the database (drop all tables and recreate)."""
|
||||||
print("Resetting database...")
|
print("Resetting database...")
|
||||||
db.drop_all()
|
db.drop_all()
|
||||||
from app.database_init import init_database
|
from app.database_init import init_database
|
||||||
|
|
||||||
init_database()
|
init_database()
|
||||||
print("Database reset successfully!")
|
print("Database reset successfully!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
cli()
|
cli()
|
||||||
|
|||||||
@@ -23,32 +23,45 @@ class TestAuthRoutesJWTExtended:
|
|||||||
@patch("app.routes.auth.auth_service.get_login_url")
|
@patch("app.routes.auth.auth_service.get_login_url")
|
||||||
def test_login_route(self, mock_get_login_url: Mock, client) -> None:
|
def test_login_route(self, mock_get_login_url: Mock, client) -> None:
|
||||||
"""Test the login route."""
|
"""Test the login route."""
|
||||||
mock_get_login_url.return_value = "https://accounts.google.com/oauth/authorize?..."
|
mock_get_login_url.return_value = (
|
||||||
|
"https://accounts.google.com/oauth/authorize?..."
|
||||||
|
)
|
||||||
|
|
||||||
response = client.get("/api/auth/login")
|
response = client.get("/api/auth/login")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert "login_url" in data
|
assert "login_url" in data
|
||||||
assert data["login_url"] == "https://accounts.google.com/oauth/authorize?..."
|
assert (
|
||||||
|
data["login_url"]
|
||||||
|
== "https://accounts.google.com/oauth/authorize?..."
|
||||||
|
)
|
||||||
|
|
||||||
@patch("app.routes.auth.auth_service.handle_callback")
|
@patch("app.routes.auth.auth_service.handle_callback")
|
||||||
def test_callback_route_success(self, mock_handle_callback: Mock, client) -> None:
|
def test_callback_route_success(
|
||||||
|
self, mock_handle_callback: Mock, client
|
||||||
|
) -> None:
|
||||||
"""Test successful callback route."""
|
"""Test successful callback route."""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.get_json.return_value = {
|
mock_response.get_json.return_value = {
|
||||||
"message": "Login successful",
|
"message": "Login successful",
|
||||||
"user": {"id": "123", "email": "test@example.com", "name": "Test User"}
|
"user": {
|
||||||
|
"id": "123",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
mock_handle_callback.return_value = mock_response
|
mock_handle_callback.return_value = mock_response
|
||||||
|
|
||||||
response = client.get("/api/auth/callback?code=test_code")
|
response = client.get("/api/auth/callback?code=test_code")
|
||||||
mock_handle_callback.assert_called_once()
|
mock_handle_callback.assert_called_once()
|
||||||
|
|
||||||
@patch("app.routes.auth.auth_service.handle_callback")
|
@patch("app.routes.auth.auth_service.handle_callback")
|
||||||
def test_callback_route_error(self, mock_handle_callback: Mock, client) -> None:
|
def test_callback_route_error(
|
||||||
|
self, mock_handle_callback: Mock, client
|
||||||
|
) -> None:
|
||||||
"""Test callback route with error."""
|
"""Test callback route with error."""
|
||||||
mock_handle_callback.side_effect = Exception("OAuth error")
|
mock_handle_callback.side_effect = Exception("OAuth error")
|
||||||
|
|
||||||
response = client.get("/api/auth/callback?code=test_code")
|
response = client.get("/api/auth/callback?code=test_code")
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
@@ -58,9 +71,11 @@ class TestAuthRoutesJWTExtended:
|
|||||||
def test_logout_route(self, mock_logout: Mock, client) -> None:
|
def test_logout_route(self, mock_logout: Mock, client) -> None:
|
||||||
"""Test logout route."""
|
"""Test logout route."""
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_response.get_json.return_value = {"message": "Logged out successfully"}
|
mock_response.get_json.return_value = {
|
||||||
|
"message": "Logged out successfully"
|
||||||
|
}
|
||||||
mock_logout.return_value = mock_response
|
mock_logout.return_value = mock_response
|
||||||
|
|
||||||
response = client.get("/api/auth/logout")
|
response = client.get("/api/auth/logout")
|
||||||
mock_logout.assert_called_once()
|
mock_logout.assert_called_once()
|
||||||
|
|
||||||
@@ -76,4 +91,4 @@ class TestAuthRoutesJWTExtended:
|
|||||||
response = client.post("/api/auth/refresh")
|
response = client.post("/api/auth/refresh")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert "msg" in data # Flask-JWT-Extended error format
|
assert "msg" in data # Flask-JWT-Extended error format
|
||||||
|
|||||||
@@ -21,13 +21,13 @@ class TestAuthServiceJWTExtended:
|
|||||||
"""Test initializing AuthService with Flask app."""
|
"""Test initializing AuthService with Flask app."""
|
||||||
mock_getenv.side_effect = lambda key: {
|
mock_getenv.side_effect = lambda key: {
|
||||||
"GOOGLE_CLIENT_ID": "test_client_id",
|
"GOOGLE_CLIENT_ID": "test_client_id",
|
||||||
"GOOGLE_CLIENT_SECRET": "test_client_secret"
|
"GOOGLE_CLIENT_SECRET": "test_client_secret",
|
||||||
}.get(key)
|
}.get(key)
|
||||||
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
auth_service = AuthService()
|
auth_service = AuthService()
|
||||||
auth_service.init_app(app)
|
auth_service.init_app(app)
|
||||||
|
|
||||||
# Verify OAuth was initialized
|
# Verify OAuth was initialized
|
||||||
assert auth_service.google is not None
|
assert auth_service.google is not None
|
||||||
|
|
||||||
@@ -39,10 +39,12 @@ class TestAuthServiceJWTExtended:
|
|||||||
with app.app_context():
|
with app.app_context():
|
||||||
mock_response = Mock()
|
mock_response = Mock()
|
||||||
mock_jsonify.return_value = mock_response
|
mock_jsonify.return_value = mock_response
|
||||||
|
|
||||||
auth_service = AuthService()
|
auth_service = AuthService()
|
||||||
result = auth_service.logout()
|
result = auth_service.logout()
|
||||||
|
|
||||||
assert result == mock_response
|
assert result == mock_response
|
||||||
mock_unset.assert_called_once_with(mock_response)
|
mock_unset.assert_called_once_with(mock_response)
|
||||||
mock_jsonify.assert_called_once_with({"message": "Logged out successfully"})
|
mock_jsonify.assert_called_once_with(
|
||||||
|
{"message": "Logged out successfully"}
|
||||||
|
)
|
||||||
|
|||||||
@@ -21,7 +21,10 @@ class TestMainRoutes:
|
|||||||
"""Test the index route."""
|
"""Test the index route."""
|
||||||
response = client.get("/api/")
|
response = client.get("/api/")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.get_json() == {"message": "API is running", "status": "ok"}
|
assert response.get_json() == {
|
||||||
|
"message": "API is running",
|
||||||
|
"status": "ok",
|
||||||
|
}
|
||||||
|
|
||||||
def test_health_route(self, client) -> None:
|
def test_health_route(self, client) -> None:
|
||||||
"""Test health check route."""
|
"""Test health check route."""
|
||||||
@@ -34,4 +37,4 @@ class TestMainRoutes:
|
|||||||
response = client.get("/api/protected")
|
response = client.get("/api/protected")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["error"] == "Authentication required (JWT or API token)"
|
assert data["error"] == "Authentication required (JWT or API token)"
|
||||||
|
|||||||
@@ -25,14 +25,18 @@ class TestTokenService:
|
|||||||
user_data = {
|
user_data = {
|
||||||
"id": "123",
|
"id": "123",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"name": "Test User"
|
"name": "Test User",
|
||||||
}
|
}
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
token = token_service.generate_access_token(user_data)
|
||||||
assert isinstance(token, str)
|
assert isinstance(token, str)
|
||||||
|
|
||||||
# Verify token content
|
# Verify token content
|
||||||
payload = jwt.decode(token, token_service.secret_key, algorithms=[token_service.algorithm])
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
token_service.secret_key,
|
||||||
|
algorithms=[token_service.algorithm],
|
||||||
|
)
|
||||||
assert payload["user_id"] == "123"
|
assert payload["user_id"] == "123"
|
||||||
assert payload["email"] == "test@example.com"
|
assert payload["email"] == "test@example.com"
|
||||||
assert payload["name"] == "Test User"
|
assert payload["name"] == "Test User"
|
||||||
@@ -44,25 +48,33 @@ class TestTokenService:
|
|||||||
user_data = {
|
user_data = {
|
||||||
"id": "123",
|
"id": "123",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"name": "Test User"
|
"name": "Test User",
|
||||||
}
|
}
|
||||||
|
|
||||||
token = token_service.generate_refresh_token(user_data)
|
token = token_service.generate_refresh_token(user_data)
|
||||||
assert isinstance(token, str)
|
assert isinstance(token, str)
|
||||||
|
|
||||||
# Verify token content
|
# Verify token content
|
||||||
payload = jwt.decode(token, token_service.secret_key, algorithms=[token_service.algorithm])
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
token_service.secret_key,
|
||||||
|
algorithms=[token_service.algorithm],
|
||||||
|
)
|
||||||
assert payload["user_id"] == "123"
|
assert payload["user_id"] == "123"
|
||||||
assert payload["type"] == "refresh"
|
assert payload["type"] == "refresh"
|
||||||
|
|
||||||
def test_verify_valid_token(self) -> None:
|
def test_verify_valid_token(self) -> None:
|
||||||
"""Test verifying a valid token."""
|
"""Test verifying a valid token."""
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
user_data = {
|
||||||
|
"id": "123",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
}
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
token = token_service.generate_access_token(user_data)
|
||||||
payload = token_service.verify_token(token)
|
payload = token_service.verify_token(token)
|
||||||
|
|
||||||
assert payload is not None
|
assert payload is not None
|
||||||
assert payload["user_id"] == "123"
|
assert payload["user_id"] == "123"
|
||||||
assert payload["type"] == "access"
|
assert payload["type"] == "access"
|
||||||
@@ -70,7 +82,7 @@ class TestTokenService:
|
|||||||
def test_verify_invalid_token(self) -> None:
|
def test_verify_invalid_token(self) -> None:
|
||||||
"""Test verifying an invalid token."""
|
"""Test verifying an invalid token."""
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
|
|
||||||
payload = token_service.verify_token("invalid.token.here")
|
payload = token_service.verify_token("invalid.token.here")
|
||||||
assert payload is None
|
assert payload is None
|
||||||
|
|
||||||
@@ -81,61 +93,75 @@ class TestTokenService:
|
|||||||
past_time = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
past_time = datetime(2020, 1, 1, tzinfo=timezone.utc)
|
||||||
mock_datetime.now.return_value = past_time
|
mock_datetime.now.return_value = past_time
|
||||||
mock_datetime.UTC = timezone.utc
|
mock_datetime.UTC = timezone.utc
|
||||||
|
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
user_data = {
|
||||||
|
"id": "123",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
}
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
token = token_service.generate_access_token(user_data)
|
||||||
|
|
||||||
# Reset mock to current time for verification
|
# Reset mock to current time for verification
|
||||||
mock_datetime.now.return_value = datetime.now(timezone.utc)
|
mock_datetime.now.return_value = datetime.now(timezone.utc)
|
||||||
|
|
||||||
payload = token_service.verify_token(token)
|
payload = token_service.verify_token(token)
|
||||||
assert payload is None
|
assert payload is None
|
||||||
|
|
||||||
def test_is_access_token(self) -> None:
|
def test_is_access_token(self) -> None:
|
||||||
"""Test access token type checking."""
|
"""Test access token type checking."""
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
|
|
||||||
access_payload = {"type": "access", "user_id": "123"}
|
access_payload = {"type": "access", "user_id": "123"}
|
||||||
refresh_payload = {"type": "refresh", "user_id": "123"}
|
refresh_payload = {"type": "refresh", "user_id": "123"}
|
||||||
|
|
||||||
assert token_service.is_access_token(access_payload)
|
assert token_service.is_access_token(access_payload)
|
||||||
assert not token_service.is_access_token(refresh_payload)
|
assert not token_service.is_access_token(refresh_payload)
|
||||||
|
|
||||||
def test_is_refresh_token(self) -> None:
|
def test_is_refresh_token(self) -> None:
|
||||||
"""Test refresh token type checking."""
|
"""Test refresh token type checking."""
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
|
|
||||||
access_payload = {"type": "access", "user_id": "123"}
|
access_payload = {"type": "access", "user_id": "123"}
|
||||||
refresh_payload = {"type": "refresh", "user_id": "123"}
|
refresh_payload = {"type": "refresh", "user_id": "123"}
|
||||||
|
|
||||||
assert token_service.is_refresh_token(refresh_payload)
|
assert token_service.is_refresh_token(refresh_payload)
|
||||||
assert not token_service.is_refresh_token(access_payload)
|
assert not token_service.is_refresh_token(access_payload)
|
||||||
|
|
||||||
def test_get_user_from_access_token_valid(self) -> None:
|
def test_get_user_from_access_token_valid(self) -> None:
|
||||||
"""Test extracting user from valid access token."""
|
"""Test extracting user from valid access token."""
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
user_data = {
|
||||||
|
"id": "123",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
}
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
token = token_service.generate_access_token(user_data)
|
||||||
extracted_user = token_service.get_user_from_access_token(token)
|
extracted_user = token_service.get_user_from_access_token(token)
|
||||||
|
|
||||||
assert extracted_user == user_data
|
assert extracted_user == user_data
|
||||||
|
|
||||||
def test_get_user_from_access_token_refresh_token(self) -> None:
|
def test_get_user_from_access_token_refresh_token(self) -> None:
|
||||||
"""Test extracting user from refresh token (should fail)."""
|
"""Test extracting user from refresh token (should fail)."""
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
user_data = {
|
||||||
|
"id": "123",
|
||||||
|
"email": "test@example.com",
|
||||||
|
"name": "Test User",
|
||||||
|
}
|
||||||
|
|
||||||
token = token_service.generate_refresh_token(user_data)
|
token = token_service.generate_refresh_token(user_data)
|
||||||
extracted_user = token_service.get_user_from_access_token(token)
|
extracted_user = token_service.get_user_from_access_token(token)
|
||||||
|
|
||||||
assert extracted_user is None
|
assert extracted_user is None
|
||||||
|
|
||||||
def test_get_user_from_access_token_invalid(self) -> None:
|
def test_get_user_from_access_token_invalid(self) -> None:
|
||||||
"""Test extracting user from invalid token."""
|
"""Test extracting user from invalid token."""
|
||||||
token_service = TokenService()
|
token_service = TokenService()
|
||||||
|
|
||||||
extracted_user = token_service.get_user_from_access_token("invalid.token")
|
extracted_user = token_service.get_user_from_access_token(
|
||||||
assert extracted_user is None
|
"invalid.token"
|
||||||
|
)
|
||||||
|
assert extracted_user is None
|
||||||
|
|||||||
@@ -18,9 +18,9 @@ class TestTokenServiceJWTExtended:
|
|||||||
"id": "123",
|
"id": "123",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"name": "Test User",
|
"name": "Test User",
|
||||||
"picture": "https://example.com/pic.jpg"
|
"picture": "https://example.com/pic.jpg",
|
||||||
}
|
}
|
||||||
|
|
||||||
token = token_service.generate_access_token(user_data)
|
token = token_service.generate_access_token(user_data)
|
||||||
assert isinstance(token, str)
|
assert isinstance(token, str)
|
||||||
assert len(token) > 0
|
assert len(token) > 0
|
||||||
@@ -33,9 +33,9 @@ class TestTokenServiceJWTExtended:
|
|||||||
user_data = {
|
user_data = {
|
||||||
"id": "123",
|
"id": "123",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"name": "Test User"
|
"name": "Test User",
|
||||||
}
|
}
|
||||||
|
|
||||||
token = token_service.generate_refresh_token(user_data)
|
token = token_service.generate_refresh_token(user_data)
|
||||||
assert isinstance(token, str)
|
assert isinstance(token, str)
|
||||||
assert len(token) > 0
|
assert len(token) > 0
|
||||||
@@ -48,10 +48,10 @@ class TestTokenServiceJWTExtended:
|
|||||||
user_data = {
|
user_data = {
|
||||||
"id": "123",
|
"id": "123",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"name": "Test User"
|
"name": "Test User",
|
||||||
}
|
}
|
||||||
|
|
||||||
access_token = token_service.generate_access_token(user_data)
|
access_token = token_service.generate_access_token(user_data)
|
||||||
refresh_token = token_service.generate_refresh_token(user_data)
|
refresh_token = token_service.generate_refresh_token(user_data)
|
||||||
|
|
||||||
assert access_token != refresh_token
|
assert access_token != refresh_token
|
||||||
|
|||||||
Reference in New Issue
Block a user