Refactor code for improved readability and consistency

- Cleaned up whitespace and formatting across multiple files for better readability.
This commit is contained in:
JSC
2025-07-02 10:37:48 +02:00
parent e63c7a0767
commit 171dbb9b63
19 changed files with 361 additions and 260 deletions

View File

@@ -11,8 +11,8 @@ def init_db(app):
"""Initialize database with Flask app.""" """Initialize database with Flask app."""
db.init_app(app) db.init_app(app)
migrate.init_app(app, db) migrate.init_app(app, db)
# Import models here to ensure they are registered with SQLAlchemy # Import models here to ensure they are registered with SQLAlchemy
from app.models import user, user_oauth # noqa: F401 from app.models import user, user_oauth # noqa: F401
return db return db

View File

@@ -8,10 +8,10 @@ def init_database():
"""Initialize database tables and seed with default data.""" """Initialize database tables and seed with default data."""
# Create all tables # Create all tables
db.create_all() db.create_all()
# Seed plans if they don't exist # Seed plans if they don't exist
seed_plans() seed_plans()
# Migrate existing users to have plans # Migrate existing users to have plans
migrate_users_to_plans() migrate_users_to_plans()
@@ -21,7 +21,7 @@ def seed_plans():
# Check if plans already exist # Check if plans already exist
if Plan.query.count() > 0: if Plan.query.count() > 0:
return return
# Create default plans # Create default plans
plans_data = [ plans_data = [
{ {
@@ -46,11 +46,11 @@ def seed_plans():
"max_credits": 300, "max_credits": 300,
}, },
] ]
for plan_data in plans_data: for plan_data in plans_data:
plan = Plan(**plan_data) plan = Plan(**plan_data)
db.session.add(plan) db.session.add(plan)
db.session.commit() db.session.commit()
print(f"Seeded {len(plans_data)} plans into database") print(f"Seeded {len(plans_data)} plans into database")
@@ -58,11 +58,11 @@ def seed_plans():
def migrate_users_to_plans(): def migrate_users_to_plans():
"""Assign plans to existing users who don't have one.""" """Assign plans to existing users who don't have one."""
from app.models.user import User from app.models.user import User
try: try:
# Find users without plans # Find users without plans
users_without_plans = User.query.filter(User.plan_id.is_(None)).all() users_without_plans = User.query.filter(User.plan_id.is_(None)).all()
# Find users with plans but NULL credits (only if credits column exists) # Find users with plans but NULL credits (only if credits column exists)
# Note: We only migrate users with NULL credits, not 0 credits # Note: We only migrate users with NULL credits, not 0 credits
# 0 credits means they spent them, NULL means they never got assigned # 0 credits means they spent them, NULL means they never got assigned
@@ -73,19 +73,19 @@ def migrate_users_to_plans():
except Exception: except Exception:
# Credits column doesn't exist yet, will be handled by create_all # Credits column doesn't exist yet, will be handled by create_all
users_without_credits = [] users_without_credits = []
if not users_without_plans and not users_without_credits: if not users_without_plans and not users_without_credits:
return return
# Get default and pro plans # Get default and pro plans
default_plan = Plan.get_default_plan() default_plan = Plan.get_default_plan()
pro_plan = Plan.get_pro_plan() pro_plan = Plan.get_pro_plan()
# Get the first user (admin) from all users ordered by ID # Get the first user (admin) from all users ordered by ID
first_user = User.query.order_by(User.id).first() first_user = User.query.order_by(User.id).first()
updated_count = 0 updated_count = 0
# Assign plans to users without plans # Assign plans to users without plans
for user in users_without_plans: for user in users_without_plans:
# First user gets pro plan, others get free plan # First user gets pro plan, others get free plan
@@ -104,17 +104,19 @@ def migrate_users_to_plans():
except Exception: except Exception:
pass pass
updated_count += 1 updated_count += 1
# Assign credits to users with plans but no credits # Assign credits to users with plans but no credits
for user in users_without_credits: for user in users_without_credits:
user.credits = user.plan.credits user.credits = user.plan.credits
updated_count += 1 updated_count += 1
if updated_count > 0: if updated_count > 0:
db.session.commit() db.session.commit()
print(f"Updated {updated_count} existing users with plans and credits") print(
f"Updated {updated_count} existing users with plans and credits"
)
except Exception: except Exception:
# If there's any error (like missing columns), just skip migration # If there's any error (like missing columns), just skip migration
# The database will be properly created by create_all() # The database will be properly created by create_all()
pass pass

View File

@@ -4,4 +4,4 @@ from .plan import Plan
from .user import User from .user import User
from .user_oauth import UserOAuth from .user_oauth import UserOAuth
__all__ = ["Plan", "User", "UserOAuth"] __all__ = ["Plan", "User", "UserOAuth"]

View File

@@ -55,4 +55,4 @@ class Plan(db.Model):
"description": self.description, "description": self.description,
"credits": self.credits, "credits": self.credits,
"max_credits": self.max_credits, "max_credits": self.max_credits,
} }

View File

@@ -17,54 +17,65 @@ if TYPE_CHECKING:
class User(db.Model): class User(db.Model):
"""User model for storing user information.""" """User model for storing user information."""
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
# Primary user information (can be updated from any connected provider) # Primary user information (can be updated from any connected provider)
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
# Password authentication (optional - users can use OAuth instead) # Password authentication (optional - users can use OAuth instead)
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) password_hash: Mapped[Optional[str]] = mapped_column(
String(255), nullable=True
)
# Role-based access control # Role-based access control
role: Mapped[str] = mapped_column(String(50), nullable=False, default="user") role: Mapped[str] = mapped_column(
String(50), nullable=False, default="user"
)
# User status # User status
is_active: Mapped[bool] = mapped_column(nullable=False, default=True) is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
# Plan relationship # Plan relationship
plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False) plan_id: Mapped[int] = mapped_column(
Integer, ForeignKey("plans.id"), nullable=False
)
# User credits (populated from plan credits on creation) # User credits (populated from plan credits on creation)
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0) credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
# API token for programmatic access # API token for programmatic access
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(
DateTime, nullable=True
)
# Timestamps # Timestamps
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, nullable=False DateTime, default=datetime.utcnow, nullable=False
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
nullable=False,
) )
# Relationships # Relationships
oauth_providers: Mapped[list["UserOAuth"]] = relationship( oauth_providers: Mapped[list["UserOAuth"]] = relationship(
"UserOAuth", back_populates="user", cascade="all, delete-orphan" "UserOAuth", back_populates="user", cascade="all, delete-orphan"
) )
plan: Mapped["Plan"] = relationship("Plan", back_populates="users") plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
def __repr__(self) -> str: def __repr__(self) -> str:
"""String representation of User.""" """String representation of User."""
provider_count = len(self.oauth_providers) provider_count = len(self.oauth_providers)
return f"<User {self.email} ({provider_count} providers)>" return f"<User {self.email} ({provider_count} providers)>"
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""Convert user to dictionary.""" """Convert user to dictionary."""
# Build comprehensive providers list # Build comprehensive providers list
@@ -73,7 +84,7 @@ class User(db.Model):
providers.append("password") providers.append("password")
if self.api_token: if self.api_token:
providers.append("api_token") providers.append("api_token")
return { return {
"id": str(self.id), "id": str(self.id),
"email": self.email, "email": self.email,
@@ -82,25 +93,27 @@ class User(db.Model):
"role": self.role, "role": self.role,
"is_active": self.is_active, "is_active": self.is_active,
"api_token": self.api_token, "api_token": self.api_token,
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None, "api_token_expires_at": self.api_token_expires_at.isoformat()
if self.api_token_expires_at
else None,
"providers": providers, "providers": providers,
"plan": self.plan.to_dict() if self.plan else None, "plan": self.plan.to_dict() if self.plan else None,
"credits": self.credits, "credits": self.credits,
"created_at": self.created_at.isoformat(), "created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(), "updated_at": self.updated_at.isoformat(),
} }
def get_provider(self, provider_name: str) -> Optional["UserOAuth"]: def get_provider(self, provider_name: str) -> Optional["UserOAuth"]:
"""Get specific OAuth provider for this user.""" """Get specific OAuth provider for this user."""
for provider in self.oauth_providers: for provider in self.oauth_providers:
if provider.provider == provider_name: if provider.provider == provider_name:
return provider return provider
return None return None
def has_provider(self, provider_name: str) -> bool: def has_provider(self, provider_name: str) -> bool:
"""Check if user has specific OAuth provider connected.""" """Check if user has specific OAuth provider connected."""
return self.get_provider(provider_name) is not None return self.get_provider(provider_name) is not None
def update_from_provider(self, provider_data: dict) -> None: def update_from_provider(self, provider_data: dict) -> None:
"""Update user info from provider data (email, name, picture).""" """Update user info from provider data (email, name, picture)."""
self.email = provider_data.get("email", self.email) self.email = provider_data.get("email", self.email)
@@ -108,60 +121,60 @@ class User(db.Model):
self.picture = provider_data.get("picture", self.picture) self.picture = provider_data.get("picture", self.picture)
self.updated_at = datetime.utcnow() self.updated_at = datetime.utcnow()
db.session.commit() db.session.commit()
def set_password(self, password: str) -> None: def set_password(self, password: str) -> None:
"""Hash and set user password.""" """Hash and set user password."""
self.password_hash = generate_password_hash(password) self.password_hash = generate_password_hash(password)
self.updated_at = datetime.utcnow() self.updated_at = datetime.utcnow()
def check_password(self, password: str) -> bool: def check_password(self, password: str) -> bool:
"""Check if provided password matches user's password.""" """Check if provided password matches user's password."""
if not self.password_hash: if not self.password_hash:
return False return False
return check_password_hash(self.password_hash, password) return check_password_hash(self.password_hash, password)
def has_password(self) -> bool: def has_password(self) -> bool:
"""Check if user has a password set.""" """Check if user has a password set."""
return self.password_hash is not None return self.password_hash is not None
def generate_api_token(self) -> str: def generate_api_token(self) -> str:
"""Generate a new API token for the user.""" """Generate a new API token for the user."""
self.api_token = secrets.token_urlsafe(32) self.api_token = secrets.token_urlsafe(32)
self.api_token_expires_at = None # No expiration by default self.api_token_expires_at = None # No expiration by default
self.updated_at = datetime.utcnow() self.updated_at = datetime.utcnow()
return self.api_token return self.api_token
def is_api_token_valid(self) -> bool: def is_api_token_valid(self) -> bool:
"""Check if the user's API token is valid (exists and not expired).""" """Check if the user's API token is valid (exists and not expired)."""
if not self.api_token: if not self.api_token:
return False return False
if self.api_token_expires_at is None: if self.api_token_expires_at is None:
return True # No expiration return True # No expiration
return datetime.utcnow() < self.api_token_expires_at return datetime.utcnow() < self.api_token_expires_at
def revoke_api_token(self) -> None: def revoke_api_token(self) -> None:
"""Revoke the user's API token.""" """Revoke the user's API token."""
self.api_token = None self.api_token = None
self.api_token_expires_at = None self.api_token_expires_at = None
self.updated_at = datetime.utcnow() self.updated_at = datetime.utcnow()
def activate(self) -> None: def activate(self) -> None:
"""Activate the user account.""" """Activate the user account."""
self.is_active = True self.is_active = True
self.updated_at = datetime.utcnow() self.updated_at = datetime.utcnow()
def deactivate(self) -> None: def deactivate(self) -> None:
"""Deactivate the user account.""" """Deactivate the user account."""
self.is_active = False self.is_active = False
self.updated_at = datetime.utcnow() self.updated_at = datetime.utcnow()
@classmethod @classmethod
def find_by_email(cls, email: str) -> Optional["User"]: def find_by_email(cls, email: str) -> Optional["User"]:
"""Find user by email address.""" """Find user by email address."""
return cls.query.filter_by(email=email).first() return cls.query.filter_by(email=email).first()
@classmethod @classmethod
def find_by_api_token(cls, api_token: str) -> Optional["User"]: def find_by_api_token(cls, api_token: str) -> Optional["User"]:
"""Find user by API token if token is valid.""" """Find user by API token if token is valid."""
@@ -169,18 +182,25 @@ class User(db.Model):
if user and user.is_api_token_valid(): if user and user.is_api_token_valid():
return user return user
return None return None
@classmethod @classmethod
def find_or_create_from_oauth( def find_or_create_from_oauth(
cls, provider: str, provider_id: str, email: str, name: str, picture: Optional[str] = None cls,
provider: str,
provider_id: str,
email: str,
name: str,
picture: Optional[str] = None,
) -> tuple["User", "UserOAuth"]: ) -> tuple["User", "UserOAuth"]:
"""Find existing user or create new one from OAuth data.""" """Find existing user or create new one from OAuth data."""
from app.models.user_oauth import UserOAuth from app.models.user_oauth import UserOAuth
from app.models.plan import Plan from app.models.plan import Plan
# First, try to find existing OAuth provider # First, try to find existing OAuth provider
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id) oauth_provider = UserOAuth.find_by_provider_and_id(
provider, provider_id
)
if oauth_provider: if oauth_provider:
# Update existing provider and user info # Update existing provider and user info
user = oauth_provider.user user = oauth_provider.user
@@ -188,24 +208,26 @@ class User(db.Model):
oauth_provider.name = name oauth_provider.name = name
oauth_provider.picture = picture oauth_provider.picture = picture
oauth_provider.updated_at = datetime.utcnow() oauth_provider.updated_at = datetime.utcnow()
# Update user info with latest data # Update user info with latest data
user.update_from_provider({"email": email, "name": name, "picture": picture}) user.update_from_provider(
{"email": email, "name": name, "picture": picture}
)
else: else:
# Try to find user by email to link the new provider # Try to find user by email to link the new provider
user = cls.find_by_email(email) user = cls.find_by_email(email)
if not user: if not user:
# Check if this is the first user (admin with pro plan) # Check if this is the first user (admin with pro plan)
user_count = cls.query.count() user_count = cls.query.count()
role = "admin" if user_count == 0 else "user" role = "admin" if user_count == 0 else "user"
# Assign plan: first user gets pro, others get free # Assign plan: first user gets pro, others get free
if user_count == 0: if user_count == 0:
plan = Plan.get_pro_plan() plan = Plan.get_pro_plan()
else: else:
plan = Plan.get_default_plan() plan = Plan.get_default_plan()
# Create new user # Create new user
user = cls( user = cls(
email=email, email=email,
@@ -218,7 +240,7 @@ class User(db.Model):
user.generate_api_token() # Generate API token on creation user.generate_api_token() # Generate API token on creation
db.session.add(user) db.session.add(user)
db.session.flush() # Flush to get user.id db.session.flush() # Flush to get user.id
# Create new OAuth provider # Create new OAuth provider
oauth_provider = UserOAuth.create_or_update( oauth_provider = UserOAuth.create_or_update(
user_id=user.id, user_id=user.id,
@@ -228,30 +250,32 @@ class User(db.Model):
name=name, name=name,
picture=picture, picture=picture,
) )
db.session.commit() db.session.commit()
return user, oauth_provider return user, oauth_provider
@classmethod @classmethod
def create_with_password(cls, email: str, password: str, name: str) -> "User": def create_with_password(
cls, email: str, password: str, name: str
) -> "User":
"""Create new user with email and password.""" """Create new user with email and password."""
from app.models.plan import Plan from app.models.plan import Plan
# Check if user already exists # Check if user already exists
existing_user = cls.find_by_email(email) existing_user = cls.find_by_email(email)
if existing_user: if existing_user:
raise ValueError("User with this email already exists") raise ValueError("User with this email already exists")
# Check if this is the first user (admin with pro plan) # Check if this is the first user (admin with pro plan)
user_count = cls.query.count() user_count = cls.query.count()
role = "admin" if user_count == 0 else "user" role = "admin" if user_count == 0 else "user"
# Assign plan: first user gets pro, others get free # Assign plan: first user gets pro, others get free
if user_count == 0: if user_count == 0:
plan = Plan.get_pro_plan() plan = Plan.get_pro_plan()
else: else:
plan = Plan.get_default_plan() plan = Plan.get_default_plan()
# Create new user # Create new user
user = cls( user = cls(
email=email, email=email,
@@ -262,15 +286,17 @@ class User(db.Model):
) )
user.set_password(password) user.set_password(password)
user.generate_api_token() # Generate API token on creation user.generate_api_token() # Generate API token on creation
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
return user return user
@classmethod @classmethod
def authenticate_with_password(cls, email: str, password: str) -> Optional["User"]: def authenticate_with_password(
cls, email: str, password: str
) -> Optional["User"]:
"""Authenticate user with email and password.""" """Authenticate user with email and password."""
user = cls.find_by_email(email) user = cls.find_by_email(email)
if user and user.check_password(password) and user.is_active: if user and user.check_password(password) and user.is_active:
return user return user
return None return None

View File

@@ -14,43 +14,50 @@ if TYPE_CHECKING:
class UserOAuth(db.Model): class UserOAuth(db.Model):
"""Model for storing user's connected OAuth providers.""" """Model for storing user's connected OAuth providers."""
__tablename__ = "user_oauth" __tablename__ = "user_oauth"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
# User relationship # User relationship
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False) user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
# OAuth provider information # OAuth provider information
provider: Mapped[str] = mapped_column(String(50), nullable=False) provider: Mapped[str] = mapped_column(String(50), nullable=False)
provider_id: Mapped[str] = mapped_column(String(255), nullable=False) provider_id: Mapped[str] = mapped_column(String(255), nullable=False)
# Provider-specific user information # Provider-specific user information
email: Mapped[str] = mapped_column(String(255), nullable=False) email: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True) picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# Timestamps # Timestamps
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, nullable=False DateTime, default=datetime.utcnow, nullable=False
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False DateTime,
default=datetime.utcnow,
onupdate=datetime.utcnow,
nullable=False,
) )
# Unique constraint on provider + provider_id combination # Unique constraint on provider + provider_id combination
__table_args__ = ( __table_args__ = (
db.UniqueConstraint("provider", "provider_id", name="unique_provider_user"), db.UniqueConstraint(
"provider", "provider_id", name="unique_provider_user"
),
) )
# Relationships # Relationships
user: Mapped["User"] = relationship("User", back_populates="oauth_providers") user: Mapped["User"] = relationship(
"User", back_populates="oauth_providers"
)
def __repr__(self) -> str: def __repr__(self) -> str:
"""String representation of UserOAuth.""" """String representation of UserOAuth."""
return f"<UserOAuth {self.email} ({self.provider})>" return f"<UserOAuth {self.email} ({self.provider})>"
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""Convert oauth provider to dictionary.""" """Convert oauth provider to dictionary."""
return { return {
@@ -63,25 +70,29 @@ class UserOAuth(db.Model):
"created_at": self.created_at.isoformat(), "created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(), "updated_at": self.updated_at.isoformat(),
} }
@classmethod @classmethod
def find_by_provider_and_id(cls, provider: str, provider_id: str) -> Optional["UserOAuth"]: def find_by_provider_and_id(
cls, provider: str, provider_id: str
) -> Optional["UserOAuth"]:
"""Find OAuth provider by provider name and provider ID.""" """Find OAuth provider by provider name and provider ID."""
return cls.query.filter_by(provider=provider, provider_id=provider_id).first() return cls.query.filter_by(
provider=provider, provider_id=provider_id
).first()
@classmethod @classmethod
def create_or_update( def create_or_update(
cls, cls,
user_id: int, user_id: int,
provider: str, provider: str,
provider_id: str, provider_id: str,
email: str, email: str,
name: str, name: str,
picture: Optional[str] = None picture: Optional[str] = None,
) -> "UserOAuth": ) -> "UserOAuth":
"""Create new OAuth provider or update existing one.""" """Create new OAuth provider or update existing one."""
oauth_provider = cls.find_by_provider_and_id(provider, provider_id) oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
if oauth_provider: if oauth_provider:
# Update existing provider # Update existing provider
oauth_provider.user_id = user_id oauth_provider.user_id = user_id
@@ -100,6 +111,6 @@ class UserOAuth(db.Model):
picture=picture, picture=picture,
) )
db.session.add(oauth_provider) db.session.add(oauth_provider)
db.session.commit() db.session.commit()
return oauth_provider return oauth_provider

View File

@@ -281,19 +281,19 @@ def update_profile():
from flask import request from flask import request
from app.database import db from app.database import db
from app.models.user import User from app.models.user import User
data = request.get_json() data = request.get_json()
if not data: if not data:
return {"error": "No data provided"}, 400 return {"error": "No data provided"}, 400
user_data = get_current_user() user_data = get_current_user()
if not user_data: if not user_data:
return {"error": "User not authenticated"}, 401 return {"error": "User not authenticated"}, 401
user = User.query.get(int(user_data["id"])) user = User.query.get(int(user_data["id"]))
if not user: if not user:
return {"error": "User not found"}, 404 return {"error": "User not found"}, 404
# Update allowed fields # Update allowed fields
if "name" in data: if "name" in data:
name = data["name"].strip() name = data["name"].strip()
@@ -302,10 +302,10 @@ def update_profile():
if len(name) > 100: if len(name) > 100:
return {"error": "Name too long (max 100 characters)"}, 400 return {"error": "Name too long (max 100 characters)"}, 400
user.name = name user.name = name
try: try:
db.session.commit() db.session.commit()
# Return fresh user data from database # Return fresh user data from database
updated_user = { updated_user = {
"id": str(user.id), "id": str(user.id),
@@ -319,11 +319,8 @@ def update_profile():
"plan": user.plan.to_dict() if user.plan else None, "plan": user.plan.to_dict() if user.plan else None,
"credits": user.credits, "credits": user.credits,
} }
return { return {"message": "Profile updated successfully", "user": updated_user}
"message": "Profile updated successfully",
"user": updated_user
}
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
return {"error": f"Failed to update profile: {str(e)}"}, 500 return {"error": f"Failed to update profile: {str(e)}"}, 500
@@ -337,50 +334,50 @@ def change_password():
from app.database import db from app.database import db
from app.models.user import User from app.models.user import User
from werkzeug.security import check_password_hash from werkzeug.security import check_password_hash
data = request.get_json() data = request.get_json()
if not data: if not data:
return {"error": "No data provided"}, 400 return {"error": "No data provided"}, 400
user_data = get_current_user() user_data = get_current_user()
if not user_data: if not user_data:
return {"error": "User not authenticated"}, 401 return {"error": "User not authenticated"}, 401
user = User.query.get(int(user_data["id"])) user = User.query.get(int(user_data["id"]))
if not user: if not user:
return {"error": "User not found"}, 404 return {"error": "User not found"}, 404
new_password = data.get("new_password") new_password = data.get("new_password")
current_password = data.get("current_password") current_password = data.get("current_password")
if not new_password: if not new_password:
return {"error": "New password is required"}, 400 return {"error": "New password is required"}, 400
# Password validation # Password validation
if len(new_password) < 6: if len(new_password) < 6:
return {"error": "Password must be at least 6 characters long"}, 400 return {"error": "Password must be at least 6 characters long"}, 400
# Check authentication method: if user logged in via password, require current password # Check authentication method: if user logged in via password, require current password
# If user logged in via OAuth, they can change password without current password # If user logged in via OAuth, they can change password without current password
current_auth_method = user_data.get("provider", "unknown") current_auth_method = user_data.get("provider", "unknown")
if user.password_hash and current_auth_method == "password": if user.password_hash and current_auth_method == "password":
# User has a password AND logged in via password, require current password for verification # User has a password AND logged in via password, require current password for verification
if not current_password: if not current_password:
return {"error": "Current password is required to change password"}, 400 return {
"error": "Current password is required to change password"
}, 400
if not check_password_hash(user.password_hash, current_password): if not check_password_hash(user.password_hash, current_password):
return {"error": "Current password is incorrect"}, 400 return {"error": "Current password is incorrect"}, 400
# If user logged in via OAuth (google, github, etc.), they can change password without current password # If user logged in via OAuth (google, github, etc.), they can change password without current password
# Set the new password # Set the new password
try: try:
user.set_password(new_password) user.set_password(new_password)
db.session.commit() db.session.commit()
return { return {"message": "Password updated successfully"}
"message": "Password updated successfully"
}
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
return {"error": f"Failed to update password: {str(e)}"}, 500 return {"error": f"Failed to update password: {str(e)}"}, 500

View File

@@ -2,7 +2,12 @@
from flask import Blueprint from flask import Blueprint
from app.services.decorators import get_current_user, require_auth, require_role, require_credits from app.services.decorators import (
get_current_user,
require_auth,
require_role,
require_credits,
)
bp = Blueprint("main", __name__) bp = Blueprint("main", __name__)
@@ -63,7 +68,8 @@ def use_credits(amount: int) -> dict[str, str]:
return { return {
"message": f"Successfully used endpoint! You requested amount: {amount}", "message": f"Successfully used endpoint! You requested amount: {amount}",
"user": user["email"], "user": user["email"],
"remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator "remaining_credits": user["credits"]
- 5, # Note: credits already deducted by decorator
} }

View File

@@ -83,7 +83,9 @@ class AuthService:
# Prepare user data for JWT token using user.to_dict() # Prepare user data for JWT token using user.to_dict()
jwt_user_data = user.to_dict() jwt_user_data = user.to_dict()
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login jwt_user_data["provider"] = (
oauth_provider.provider
) # Override provider for OAuth login
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token( access_token = self.token_service.generate_access_token(
@@ -156,7 +158,9 @@ class AuthService:
# Prepare user data for JWT token using user.to_dict() # Prepare user data for JWT token using user.to_dict()
jwt_user_data = user.to_dict() jwt_user_data = user.to_dict()
jwt_user_data["provider"] = "password" # Override provider for password registration jwt_user_data["provider"] = (
"password" # Override provider for password registration
)
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token( access_token = self.token_service.generate_access_token(
@@ -199,7 +203,9 @@ class AuthService:
# Prepare user data for JWT token using user.to_dict() # Prepare user data for JWT token using user.to_dict()
jwt_user_data = user.to_dict() jwt_user_data = user.to_dict()
jwt_user_data["provider"] = "password" # Override provider for password login jwt_user_data["provider"] = (
"password" # Override provider for password login
)
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token(jwt_user_data) access_token = self.token_service.generate_access_token(jwt_user_data)

View File

@@ -12,14 +12,14 @@ def get_user_from_jwt() -> dict[str, Any] | None:
try: try:
# Try to verify JWT token in request - this sets up the context # Try to verify JWT token in request - this sets up the context
verify_jwt_in_request() verify_jwt_in_request()
current_user_id = get_jwt_identity() current_user_id = get_jwt_identity()
if not current_user_id: if not current_user_id:
return None return None
# Query database for user data instead of using JWT claims # Query database for user data instead of using JWT claims
from app.models.user import User from app.models.user import User
user = User.query.get(int(current_user_id)) user = User.query.get(int(current_user_id))
if not user or not user.is_active: if not user or not user.is_active:
return None return None
@@ -70,7 +70,7 @@ def get_user_from_api_token() -> dict[str, Any] | None:
providers.append("password") providers.append("password")
if user.api_token: if user.api_token:
providers.append("api_token") providers.append("api_token")
return { return {
"id": str(user.id), "id": str(user.id),
"email": user.email, "email": user.email,
@@ -148,22 +148,23 @@ def require_role(required_role: str):
def require_credits(credits_needed: int): def require_credits(credits_needed: int):
"""Decorator to require and deduct credits for routes.""" """Decorator to require and deduct credits for routes."""
def decorator(f): def decorator(f):
@wraps(f) @wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
from app.models.user import User from app.models.user import User
from app.database import db from app.database import db
# First check authentication # First check authentication
user_data = get_current_user() user_data = get_current_user()
if not user_data: if not user_data:
return jsonify({"error": "Authentication required"}), 401 return jsonify({"error": "Authentication required"}), 401
# Get the actual user from database to check/update credits # Get the actual user from database to check/update credits
user = User.query.get(int(user_data["id"])) user = User.query.get(int(user_data["id"]))
if not user or not user.is_active: if not user or not user.is_active:
return jsonify({"error": "User not found or inactive"}), 401 return jsonify({"error": "User not found or inactive"}), 401
# Check if user has enough credits # Check if user has enough credits
if user.credits < credits_needed: if user.credits < credits_needed:
return ( return (
@@ -174,15 +175,16 @@ def require_credits(credits_needed: int):
), ),
402, # Payment Required status code 402, # Payment Required status code
) )
# Deduct credits # Deduct credits
user.credits -= credits_needed user.credits -= credits_needed
db.session.commit() db.session.commit()
# Execute the function # Execute the function
result = f(*args, **kwargs) result = f(*args, **kwargs)
return result return result
return wrapper return wrapper
return decorator return decorator

View File

@@ -5,35 +5,35 @@ from authlib.integrations.flask_client import OAuth
class OAuthProvider(ABC): class OAuthProvider(ABC):
"""Abstract base class for OAuth providers.""" """Abstract base class for OAuth providers."""
def __init__(self, oauth: OAuth, client_id: str, client_secret: str): def __init__(self, oauth: OAuth, client_id: str, client_secret: str):
self.oauth = oauth self.oauth = oauth
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret self.client_secret = client_secret
self._client = None self._client = None
@property @property
@abstractmethod @abstractmethod
def name(self) -> str: def name(self) -> str:
"""Provider name (e.g., 'google', 'github').""" """Provider name (e.g., 'google', 'github')."""
pass pass
@property @property
@abstractmethod @abstractmethod
def display_name(self) -> str: def display_name(self) -> str:
"""Human-readable provider name (e.g., 'Google', 'GitHub').""" """Human-readable provider name (e.g., 'Google', 'GitHub')."""
pass pass
@abstractmethod @abstractmethod
def get_client_config(self) -> Dict[str, Any]: def get_client_config(self) -> Dict[str, Any]:
"""Return OAuth client configuration.""" """Return OAuth client configuration."""
pass pass
@abstractmethod @abstractmethod
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
"""Extract user information from OAuth token response.""" """Extract user information from OAuth token response."""
pass pass
def get_client(self): def get_client(self):
"""Get or create OAuth client.""" """Get or create OAuth client."""
if self._client is None: if self._client is None:
@@ -42,27 +42,29 @@ class OAuthProvider(ABC):
name=self.name, name=self.name,
client_id=self.client_id, client_id=self.client_id,
client_secret=self.client_secret, client_secret=self.client_secret,
**config **config,
) )
return self._client return self._client
def get_authorization_url(self, redirect_uri: str) -> str: def get_authorization_url(self, redirect_uri: str) -> str:
"""Generate authorization URL for OAuth flow.""" """Generate authorization URL for OAuth flow."""
client = self.get_client() client = self.get_client()
return client.authorize_redirect(redirect_uri).location return client.authorize_redirect(redirect_uri).location
def exchange_code_for_token(self, code: str = None, redirect_uri: str = None) -> Dict[str, Any]: def exchange_code_for_token(
self, code: str = None, redirect_uri: str = None
) -> Dict[str, Any]:
"""Exchange authorization code for access token.""" """Exchange authorization code for access token."""
client = self.get_client() client = self.get_client()
token = client.authorize_access_token() token = client.authorize_access_token()
return token return token
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]: def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
"""Normalize user data to common format.""" """Normalize user data to common format."""
return { return {
'id': user_info.get('id'), "id": user_info.get("id"),
'email': user_info.get('email'), "email": user_info.get("email"),
'name': user_info.get('name'), "name": user_info.get("name"),
'picture': user_info.get('picture'), "picture": user_info.get("picture"),
'provider': self.name "provider": self.name,
} }

View File

@@ -4,49 +4,47 @@ from .base import OAuthProvider
class GitHubOAuthProvider(OAuthProvider): class GitHubOAuthProvider(OAuthProvider):
"""GitHub OAuth provider implementation.""" """GitHub OAuth provider implementation."""
@property @property
def name(self) -> str: def name(self) -> str:
return 'github' return "github"
@property @property
def display_name(self) -> str: def display_name(self) -> str:
return 'GitHub' return "GitHub"
def get_client_config(self) -> Dict[str, Any]: def get_client_config(self) -> Dict[str, Any]:
"""Return GitHub OAuth client configuration.""" """Return GitHub OAuth client configuration."""
return { return {
'access_token_url': 'https://github.com/login/oauth/access_token', "access_token_url": "https://github.com/login/oauth/access_token",
'authorize_url': 'https://github.com/login/oauth/authorize', "authorize_url": "https://github.com/login/oauth/authorize",
'api_base_url': 'https://api.github.com/', "api_base_url": "https://api.github.com/",
'client_kwargs': { "client_kwargs": {"scope": "user:email"},
'scope': 'user:email'
}
} }
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
"""Extract user information from GitHub OAuth token response.""" """Extract user information from GitHub OAuth token response."""
client = self.get_client() client = self.get_client()
# Get user profile # Get user profile
user_resp = client.get('user', token=token) user_resp = client.get("user", token=token)
user_data = user_resp.json() user_data = user_resp.json()
# Get user email (may be private) # Get user email (may be private)
email = user_data.get('email') email = user_data.get("email")
if not email: if not email:
# If email is private, get from emails endpoint # If email is private, get from emails endpoint
emails_resp = client.get('user/emails', token=token) emails_resp = client.get("user/emails", token=token)
emails = emails_resp.json() emails = emails_resp.json()
# Find primary email # Find primary email
for email_obj in emails: for email_obj in emails:
if email_obj.get('primary', False): if email_obj.get("primary", False):
email = email_obj.get('email') email = email_obj.get("email")
break break
return { return {
'id': str(user_data.get('id')), "id": str(user_data.get("id")),
'email': email, "email": email,
'name': user_data.get('name') or user_data.get('login'), "name": user_data.get("name") or user_data.get("login"),
'picture': user_data.get('avatar_url') "picture": user_data.get("avatar_url"),
} }

View File

@@ -8,38 +8,38 @@ from .github import GitHubOAuthProvider
class OAuthProviderRegistry: class OAuthProviderRegistry:
"""Registry for OAuth providers.""" """Registry for OAuth providers."""
def __init__(self, oauth: OAuth): def __init__(self, oauth: OAuth):
self.oauth = oauth self.oauth = oauth
self._providers: Dict[str, OAuthProvider] = {} self._providers: Dict[str, OAuthProvider] = {}
self._initialize_providers() self._initialize_providers()
def _initialize_providers(self): def _initialize_providers(self):
"""Initialize available providers based on environment variables.""" """Initialize available providers based on environment variables."""
# Google OAuth # Google OAuth
google_client_id = os.getenv('GOOGLE_CLIENT_ID') google_client_id = os.getenv("GOOGLE_CLIENT_ID")
google_client_secret = os.getenv('GOOGLE_CLIENT_SECRET') google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
if google_client_id and google_client_secret: if google_client_id and google_client_secret:
self._providers['google'] = GoogleOAuthProvider( self._providers["google"] = GoogleOAuthProvider(
self.oauth, google_client_id, google_client_secret self.oauth, google_client_id, google_client_secret
) )
# GitHub OAuth # GitHub OAuth
github_client_id = os.getenv('GITHUB_CLIENT_ID') github_client_id = os.getenv("GITHUB_CLIENT_ID")
github_client_secret = os.getenv('GITHUB_CLIENT_SECRET') github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
if github_client_id and github_client_secret: if github_client_id and github_client_secret:
self._providers['github'] = GitHubOAuthProvider( self._providers["github"] = GitHubOAuthProvider(
self.oauth, github_client_id, github_client_secret self.oauth, github_client_id, github_client_secret
) )
def get_provider(self, name: str) -> Optional[OAuthProvider]: def get_provider(self, name: str) -> Optional[OAuthProvider]:
"""Get OAuth provider by name.""" """Get OAuth provider by name."""
return self._providers.get(name) return self._providers.get(name)
def get_available_providers(self) -> Dict[str, OAuthProvider]: def get_available_providers(self) -> Dict[str, OAuthProvider]:
"""Get all available providers.""" """Get all available providers."""
return self._providers.copy() return self._providers.copy()
def is_provider_available(self, name: str) -> bool: def is_provider_available(self, name: str) -> bool:
"""Check if provider is available.""" """Check if provider is available."""
return name in self._providers return name in self._providers

View File

@@ -9,22 +9,27 @@ from app.database import db
app = create_app() app = create_app()
cli = FlaskGroup(app) cli = FlaskGroup(app)
@cli.command() @cli.command()
def init_db(): def init_db():
"""Initialize the database.""" """Initialize the database."""
print("Initializing database...") print("Initializing database...")
from app.database_init import init_database from app.database_init import init_database
init_database() init_database()
print("Database initialized successfully!") print("Database initialized successfully!")
@cli.command() @cli.command()
def reset_db(): def reset_db():
"""Reset the database (drop all tables and recreate).""" """Reset the database (drop all tables and recreate)."""
print("Resetting database...") print("Resetting database...")
db.drop_all() db.drop_all()
from app.database_init import init_database from app.database_init import init_database
init_database() init_database()
print("Database reset successfully!") print("Database reset successfully!")
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()

View File

@@ -23,32 +23,45 @@ class TestAuthRoutesJWTExtended:
@patch("app.routes.auth.auth_service.get_login_url") @patch("app.routes.auth.auth_service.get_login_url")
def test_login_route(self, mock_get_login_url: Mock, client) -> None: def test_login_route(self, mock_get_login_url: Mock, client) -> None:
"""Test the login route.""" """Test the login route."""
mock_get_login_url.return_value = "https://accounts.google.com/oauth/authorize?..." mock_get_login_url.return_value = (
"https://accounts.google.com/oauth/authorize?..."
)
response = client.get("/api/auth/login") response = client.get("/api/auth/login")
assert response.status_code == 200 assert response.status_code == 200
data = response.get_json() data = response.get_json()
assert "login_url" in data assert "login_url" in data
assert data["login_url"] == "https://accounts.google.com/oauth/authorize?..." assert (
data["login_url"]
== "https://accounts.google.com/oauth/authorize?..."
)
@patch("app.routes.auth.auth_service.handle_callback") @patch("app.routes.auth.auth_service.handle_callback")
def test_callback_route_success(self, mock_handle_callback: Mock, client) -> None: def test_callback_route_success(
self, mock_handle_callback: Mock, client
) -> None:
"""Test successful callback route.""" """Test successful callback route."""
mock_response = Mock() mock_response = Mock()
mock_response.get_json.return_value = { mock_response.get_json.return_value = {
"message": "Login successful", "message": "Login successful",
"user": {"id": "123", "email": "test@example.com", "name": "Test User"} "user": {
"id": "123",
"email": "test@example.com",
"name": "Test User",
},
} }
mock_handle_callback.return_value = mock_response mock_handle_callback.return_value = mock_response
response = client.get("/api/auth/callback?code=test_code") response = client.get("/api/auth/callback?code=test_code")
mock_handle_callback.assert_called_once() mock_handle_callback.assert_called_once()
@patch("app.routes.auth.auth_service.handle_callback") @patch("app.routes.auth.auth_service.handle_callback")
def test_callback_route_error(self, mock_handle_callback: Mock, client) -> None: def test_callback_route_error(
self, mock_handle_callback: Mock, client
) -> None:
"""Test callback route with error.""" """Test callback route with error."""
mock_handle_callback.side_effect = Exception("OAuth error") mock_handle_callback.side_effect = Exception("OAuth error")
response = client.get("/api/auth/callback?code=test_code") response = client.get("/api/auth/callback?code=test_code")
assert response.status_code == 400 assert response.status_code == 400
data = response.get_json() data = response.get_json()
@@ -58,9 +71,11 @@ class TestAuthRoutesJWTExtended:
def test_logout_route(self, mock_logout: Mock, client) -> None: def test_logout_route(self, mock_logout: Mock, client) -> None:
"""Test logout route.""" """Test logout route."""
mock_response = Mock() mock_response = Mock()
mock_response.get_json.return_value = {"message": "Logged out successfully"} mock_response.get_json.return_value = {
"message": "Logged out successfully"
}
mock_logout.return_value = mock_response mock_logout.return_value = mock_response
response = client.get("/api/auth/logout") response = client.get("/api/auth/logout")
mock_logout.assert_called_once() mock_logout.assert_called_once()
@@ -76,4 +91,4 @@ class TestAuthRoutesJWTExtended:
response = client.post("/api/auth/refresh") response = client.post("/api/auth/refresh")
assert response.status_code == 401 assert response.status_code == 401
data = response.get_json() data = response.get_json()
assert "msg" in data # Flask-JWT-Extended error format assert "msg" in data # Flask-JWT-Extended error format

View File

@@ -21,13 +21,13 @@ class TestAuthServiceJWTExtended:
"""Test initializing AuthService with Flask app.""" """Test initializing AuthService with Flask app."""
mock_getenv.side_effect = lambda key: { mock_getenv.side_effect = lambda key: {
"GOOGLE_CLIENT_ID": "test_client_id", "GOOGLE_CLIENT_ID": "test_client_id",
"GOOGLE_CLIENT_SECRET": "test_client_secret" "GOOGLE_CLIENT_SECRET": "test_client_secret",
}.get(key) }.get(key)
app = create_app() app = create_app()
auth_service = AuthService() auth_service = AuthService()
auth_service.init_app(app) auth_service.init_app(app)
# Verify OAuth was initialized # Verify OAuth was initialized
assert auth_service.google is not None assert auth_service.google is not None
@@ -39,10 +39,12 @@ class TestAuthServiceJWTExtended:
with app.app_context(): with app.app_context():
mock_response = Mock() mock_response = Mock()
mock_jsonify.return_value = mock_response mock_jsonify.return_value = mock_response
auth_service = AuthService() auth_service = AuthService()
result = auth_service.logout() result = auth_service.logout()
assert result == mock_response assert result == mock_response
mock_unset.assert_called_once_with(mock_response) mock_unset.assert_called_once_with(mock_response)
mock_jsonify.assert_called_once_with({"message": "Logged out successfully"}) mock_jsonify.assert_called_once_with(
{"message": "Logged out successfully"}
)

View File

@@ -21,7 +21,10 @@ class TestMainRoutes:
"""Test the index route.""" """Test the index route."""
response = client.get("/api/") response = client.get("/api/")
assert response.status_code == 200 assert response.status_code == 200
assert response.get_json() == {"message": "API is running", "status": "ok"} assert response.get_json() == {
"message": "API is running",
"status": "ok",
}
def test_health_route(self, client) -> None: def test_health_route(self, client) -> None:
"""Test health check route.""" """Test health check route."""
@@ -34,4 +37,4 @@ class TestMainRoutes:
response = client.get("/api/protected") response = client.get("/api/protected")
assert response.status_code == 401 assert response.status_code == 401
data = response.get_json() data = response.get_json()
assert data["error"] == "Authentication required (JWT or API token)" assert data["error"] == "Authentication required (JWT or API token)"

View File

@@ -25,14 +25,18 @@ class TestTokenService:
user_data = { user_data = {
"id": "123", "id": "123",
"email": "test@example.com", "email": "test@example.com",
"name": "Test User" "name": "Test User",
} }
token = token_service.generate_access_token(user_data) token = token_service.generate_access_token(user_data)
assert isinstance(token, str) assert isinstance(token, str)
# Verify token content # Verify token content
payload = jwt.decode(token, token_service.secret_key, algorithms=[token_service.algorithm]) payload = jwt.decode(
token,
token_service.secret_key,
algorithms=[token_service.algorithm],
)
assert payload["user_id"] == "123" assert payload["user_id"] == "123"
assert payload["email"] == "test@example.com" assert payload["email"] == "test@example.com"
assert payload["name"] == "Test User" assert payload["name"] == "Test User"
@@ -44,25 +48,33 @@ class TestTokenService:
user_data = { user_data = {
"id": "123", "id": "123",
"email": "test@example.com", "email": "test@example.com",
"name": "Test User" "name": "Test User",
} }
token = token_service.generate_refresh_token(user_data) token = token_service.generate_refresh_token(user_data)
assert isinstance(token, str) assert isinstance(token, str)
# Verify token content # Verify token content
payload = jwt.decode(token, token_service.secret_key, algorithms=[token_service.algorithm]) payload = jwt.decode(
token,
token_service.secret_key,
algorithms=[token_service.algorithm],
)
assert payload["user_id"] == "123" assert payload["user_id"] == "123"
assert payload["type"] == "refresh" assert payload["type"] == "refresh"
def test_verify_valid_token(self) -> None: def test_verify_valid_token(self) -> None:
"""Test verifying a valid token.""" """Test verifying a valid token."""
token_service = TokenService() token_service = TokenService()
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"} user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_access_token(user_data) token = token_service.generate_access_token(user_data)
payload = token_service.verify_token(token) payload = token_service.verify_token(token)
assert payload is not None assert payload is not None
assert payload["user_id"] == "123" assert payload["user_id"] == "123"
assert payload["type"] == "access" assert payload["type"] == "access"
@@ -70,7 +82,7 @@ class TestTokenService:
def test_verify_invalid_token(self) -> None: def test_verify_invalid_token(self) -> None:
"""Test verifying an invalid token.""" """Test verifying an invalid token."""
token_service = TokenService() token_service = TokenService()
payload = token_service.verify_token("invalid.token.here") payload = token_service.verify_token("invalid.token.here")
assert payload is None assert payload is None
@@ -81,61 +93,75 @@ class TestTokenService:
past_time = datetime(2020, 1, 1, tzinfo=timezone.utc) past_time = datetime(2020, 1, 1, tzinfo=timezone.utc)
mock_datetime.now.return_value = past_time mock_datetime.now.return_value = past_time
mock_datetime.UTC = timezone.utc mock_datetime.UTC = timezone.utc
token_service = TokenService() token_service = TokenService()
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"} user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_access_token(user_data) token = token_service.generate_access_token(user_data)
# Reset mock to current time for verification # Reset mock to current time for verification
mock_datetime.now.return_value = datetime.now(timezone.utc) mock_datetime.now.return_value = datetime.now(timezone.utc)
payload = token_service.verify_token(token) payload = token_service.verify_token(token)
assert payload is None assert payload is None
def test_is_access_token(self) -> None: def test_is_access_token(self) -> None:
"""Test access token type checking.""" """Test access token type checking."""
token_service = TokenService() token_service = TokenService()
access_payload = {"type": "access", "user_id": "123"} access_payload = {"type": "access", "user_id": "123"}
refresh_payload = {"type": "refresh", "user_id": "123"} refresh_payload = {"type": "refresh", "user_id": "123"}
assert token_service.is_access_token(access_payload) assert token_service.is_access_token(access_payload)
assert not token_service.is_access_token(refresh_payload) assert not token_service.is_access_token(refresh_payload)
def test_is_refresh_token(self) -> None: def test_is_refresh_token(self) -> None:
"""Test refresh token type checking.""" """Test refresh token type checking."""
token_service = TokenService() token_service = TokenService()
access_payload = {"type": "access", "user_id": "123"} access_payload = {"type": "access", "user_id": "123"}
refresh_payload = {"type": "refresh", "user_id": "123"} refresh_payload = {"type": "refresh", "user_id": "123"}
assert token_service.is_refresh_token(refresh_payload) assert token_service.is_refresh_token(refresh_payload)
assert not token_service.is_refresh_token(access_payload) assert not token_service.is_refresh_token(access_payload)
def test_get_user_from_access_token_valid(self) -> None: def test_get_user_from_access_token_valid(self) -> None:
"""Test extracting user from valid access token.""" """Test extracting user from valid access token."""
token_service = TokenService() token_service = TokenService()
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"} user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_access_token(user_data) token = token_service.generate_access_token(user_data)
extracted_user = token_service.get_user_from_access_token(token) extracted_user = token_service.get_user_from_access_token(token)
assert extracted_user == user_data assert extracted_user == user_data
def test_get_user_from_access_token_refresh_token(self) -> None: def test_get_user_from_access_token_refresh_token(self) -> None:
"""Test extracting user from refresh token (should fail).""" """Test extracting user from refresh token (should fail)."""
token_service = TokenService() token_service = TokenService()
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"} user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_refresh_token(user_data) token = token_service.generate_refresh_token(user_data)
extracted_user = token_service.get_user_from_access_token(token) extracted_user = token_service.get_user_from_access_token(token)
assert extracted_user is None assert extracted_user is None
def test_get_user_from_access_token_invalid(self) -> None: def test_get_user_from_access_token_invalid(self) -> None:
"""Test extracting user from invalid token.""" """Test extracting user from invalid token."""
token_service = TokenService() token_service = TokenService()
extracted_user = token_service.get_user_from_access_token("invalid.token") extracted_user = token_service.get_user_from_access_token(
assert extracted_user is None "invalid.token"
)
assert extracted_user is None

View File

@@ -18,9 +18,9 @@ class TestTokenServiceJWTExtended:
"id": "123", "id": "123",
"email": "test@example.com", "email": "test@example.com",
"name": "Test User", "name": "Test User",
"picture": "https://example.com/pic.jpg" "picture": "https://example.com/pic.jpg",
} }
token = token_service.generate_access_token(user_data) token = token_service.generate_access_token(user_data)
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 0 assert len(token) > 0
@@ -33,9 +33,9 @@ class TestTokenServiceJWTExtended:
user_data = { user_data = {
"id": "123", "id": "123",
"email": "test@example.com", "email": "test@example.com",
"name": "Test User" "name": "Test User",
} }
token = token_service.generate_refresh_token(user_data) token = token_service.generate_refresh_token(user_data)
assert isinstance(token, str) assert isinstance(token, str)
assert len(token) > 0 assert len(token) > 0
@@ -48,10 +48,10 @@ class TestTokenServiceJWTExtended:
user_data = { user_data = {
"id": "123", "id": "123",
"email": "test@example.com", "email": "test@example.com",
"name": "Test User" "name": "Test User",
} }
access_token = token_service.generate_access_token(user_data) access_token = token_service.generate_access_token(user_data)
refresh_token = token_service.generate_refresh_token(user_data) refresh_token = token_service.generate_refresh_token(user_data)
assert access_token != refresh_token assert access_token != refresh_token