refactor: clean up code by adding missing commas and improving import order

This commit is contained in:
JSC
2025-07-02 10:46:53 +02:00
parent 171dbb9b63
commit 703212656f
20 changed files with 87 additions and 496 deletions

View File

@@ -26,7 +26,7 @@ def create_app():
# Configure Flask-JWT-Extended # Configure Flask-JWT-Extended
app.config["JWT_SECRET_KEY"] = os.environ.get( app.config["JWT_SECRET_KEY"] = os.environ.get(
"JWT_SECRET_KEY", "jwt-secret-key" "JWT_SECRET_KEY", "jwt-secret-key",
) )
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15) app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7) app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7)

View File

@@ -1,7 +1,7 @@
"""Database configuration and initialization.""" """Database configuration and initialization."""
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
db = SQLAlchemy() db = SQLAlchemy()
migrate = Migrate() migrate = Migrate()

View File

@@ -68,7 +68,7 @@ def migrate_users_to_plans():
# 0 credits means they spent them, NULL means they never got assigned # 0 credits means they spent them, NULL means they never got assigned
try: try:
users_without_credits = User.query.filter( users_without_credits = User.query.filter(
User.plan_id.isnot(None), User.credits.is_(None) User.plan_id.isnot(None), User.credits.is_(None),
).all() ).all()
except Exception: except Exception:
# Credits column doesn't exist yet, will be handled by create_all # Credits column doesn't exist yet, will be handled by create_all
@@ -113,7 +113,7 @@ def migrate_users_to_plans():
if updated_count > 0: if updated_count > 0:
db.session.commit() db.session.commit()
print( print(
f"Updated {updated_count} existing users with plans and credits" f"Updated {updated_count} existing users with plans and credits",
) )
except Exception: except Exception:

View File

@@ -2,17 +2,17 @@
import secrets import secrets
from datetime import datetime from datetime import datetime
from typing import Optional, TYPE_CHECKING from typing import TYPE_CHECKING, Optional
from werkzeug.security import check_password_hash, generate_password_hash from sqlalchemy import DateTime, ForeignKey, Integer, String
from sqlalchemy import String, DateTime, Integer, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from werkzeug.security import check_password_hash, generate_password_hash
from app.database import db from app.database import db
if TYPE_CHECKING: if TYPE_CHECKING:
from app.models.user_oauth import UserOAuth
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user_oauth import UserOAuth
class User(db.Model): class User(db.Model):
@@ -25,16 +25,16 @@ class User(db.Model):
# Primary user information (can be updated from any connected provider) # Primary user information (can be updated from any connected provider)
email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True)
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) picture: Mapped[str | None] = mapped_column(String(500), nullable=True)
# Password authentication (optional - users can use OAuth instead) # Password authentication (optional - users can use OAuth instead)
password_hash: Mapped[Optional[str]] = mapped_column( password_hash: Mapped[str | None] = mapped_column(
String(255), nullable=True String(255), nullable=True,
) )
# Role-based access control # Role-based access control
role: Mapped[str] = mapped_column( role: Mapped[str] = mapped_column(
String(50), nullable=False, default="user" String(50), nullable=False, default="user",
) )
# User status # User status
@@ -42,21 +42,21 @@ class User(db.Model):
# Plan relationship # Plan relationship
plan_id: Mapped[int] = mapped_column( plan_id: Mapped[int] = mapped_column(
Integer, ForeignKey("plans.id"), nullable=False Integer, ForeignKey("plans.id"), nullable=False,
) )
# User credits (populated from plan credits on creation) # User credits (populated from plan credits on creation)
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0) credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
# API token for programmatic access # API token for programmatic access
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) api_token: Mapped[str | None] = mapped_column(String(255), nullable=True)
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column( api_token_expires_at: Mapped[datetime | None] = mapped_column(
DateTime, nullable=True DateTime, nullable=True,
) )
# Timestamps # Timestamps
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, nullable=False DateTime, default=datetime.utcnow, nullable=False,
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, DateTime,
@@ -67,7 +67,7 @@ class User(db.Model):
# Relationships # Relationships
oauth_providers: Mapped[list["UserOAuth"]] = relationship( oauth_providers: Mapped[list["UserOAuth"]] = relationship(
"UserOAuth", back_populates="user", cascade="all, delete-orphan" "UserOAuth", back_populates="user", cascade="all, delete-orphan",
) )
plan: Mapped["Plan"] = relationship("Plan", back_populates="users") plan: Mapped["Plan"] = relationship("Plan", back_populates="users")
@@ -190,15 +190,15 @@ class User(db.Model):
provider_id: str, provider_id: str,
email: str, email: str,
name: str, name: str,
picture: Optional[str] = None, picture: str | None = None,
) -> tuple["User", "UserOAuth"]: ) -> tuple["User", "UserOAuth"]:
"""Find existing user or create new one from OAuth data.""" """Find existing user or create new one from OAuth data."""
from app.models.user_oauth import UserOAuth
from app.models.plan import Plan from app.models.plan import Plan
from app.models.user_oauth import UserOAuth
# First, try to find existing OAuth provider # First, try to find existing OAuth provider
oauth_provider = UserOAuth.find_by_provider_and_id( oauth_provider = UserOAuth.find_by_provider_and_id(
provider, provider_id provider, provider_id,
) )
if oauth_provider: if oauth_provider:
@@ -211,7 +211,7 @@ class User(db.Model):
# Update user info with latest data # Update user info with latest data
user.update_from_provider( user.update_from_provider(
{"email": email, "name": name, "picture": picture} {"email": email, "name": name, "picture": picture},
) )
else: else:
# Try to find user by email to link the new provider # Try to find user by email to link the new provider
@@ -256,7 +256,7 @@ class User(db.Model):
@classmethod @classmethod
def create_with_password( def create_with_password(
cls, email: str, password: str, name: str cls, email: str, password: str, name: str,
) -> "User": ) -> "User":
"""Create new user with email and password.""" """Create new user with email and password."""
from app.models.plan import Plan from app.models.plan import Plan
@@ -293,7 +293,7 @@ class User(db.Model):
@classmethod @classmethod
def authenticate_with_password( def authenticate_with_password(
cls, email: str, password: str cls, email: str, password: str,
) -> Optional["User"]: ) -> Optional["User"]:
"""Authenticate user with email and password.""" """Authenticate user with email and password."""
user = cls.find_by_email(email) user = cls.find_by_email(email)

View File

@@ -1,9 +1,9 @@
"""User OAuth model for storing user's connected providers.""" """User OAuth model for storing user's connected providers."""
from datetime import datetime from datetime import datetime
from typing import Optional, TYPE_CHECKING from typing import TYPE_CHECKING, Optional
from sqlalchemy import String, DateTime, Text, ForeignKey from sqlalchemy import DateTime, ForeignKey, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import db from app.database import db
@@ -29,11 +29,11 @@ class UserOAuth(db.Model):
# Provider-specific user information # Provider-specific user information
email: Mapped[str] = mapped_column(String(255), nullable=False) email: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True) picture: Mapped[str | None] = mapped_column(Text, nullable=True)
# Timestamps # Timestamps
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime, default=datetime.utcnow, nullable=False DateTime, default=datetime.utcnow, nullable=False,
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, DateTime,
@@ -45,13 +45,13 @@ class UserOAuth(db.Model):
# Unique constraint on provider + provider_id combination # Unique constraint on provider + provider_id combination
__table_args__ = ( __table_args__ = (
db.UniqueConstraint( db.UniqueConstraint(
"provider", "provider_id", name="unique_provider_user" "provider", "provider_id", name="unique_provider_user",
), ),
) )
# Relationships # Relationships
user: Mapped["User"] = relationship( user: Mapped["User"] = relationship(
"User", back_populates="oauth_providers" "User", back_populates="oauth_providers",
) )
def __repr__(self) -> str: def __repr__(self) -> str:
@@ -73,11 +73,11 @@ class UserOAuth(db.Model):
@classmethod @classmethod
def find_by_provider_and_id( def find_by_provider_and_id(
cls, provider: str, provider_id: str cls, provider: str, provider_id: str,
) -> Optional["UserOAuth"]: ) -> Optional["UserOAuth"]:
"""Find OAuth provider by provider name and provider ID.""" """Find OAuth provider by provider name and provider ID."""
return cls.query.filter_by( return cls.query.filter_by(
provider=provider, provider_id=provider_id provider=provider, provider_id=provider_id,
).first() ).first()
@classmethod @classmethod
@@ -88,7 +88,7 @@ class UserOAuth(db.Model):
provider_id: str, provider_id: str,
email: str, email: str,
name: str, name: str,
picture: Optional[str] = None, picture: str | None = None,
) -> "UserOAuth": ) -> "UserOAuth":
"""Create new OAuth provider or update existing one.""" """Create new OAuth provider or update existing one."""
oauth_provider = cls.find_by_provider_and_id(provider, provider_id) oauth_provider = cls.find_by_provider_and_id(provider, provider_id)

View File

@@ -31,7 +31,7 @@ def callback(provider):
# If successful, redirect to frontend dashboard with cookies # If successful, redirect to frontend dashboard with cookies
if auth_response.status_code == 200: if auth_response.status_code == 200:
redirect_response = make_response( redirect_response = make_response(
redirect("http://localhost:3000/dashboard") redirect("http://localhost:3000/dashboard"),
) )
# Copy all cookies from the auth response # Copy all cookies from the auth response
@@ -39,9 +39,8 @@ def callback(provider):
redirect_response.headers.add("Set-Cookie", cookie) redirect_response.headers.add("Set-Cookie", cookie)
return redirect_response return redirect_response
else: # If there was an error, redirect to login with error
# If there was an error, redirect to login with error return redirect("http://localhost:3000/login?error=oauth_failed")
return redirect("http://localhost:3000/login?error=oauth_failed")
except Exception as e: except Exception as e:
error_msg = str(e).replace(" ", "_").replace('"', "") error_msg = str(e).replace(" ", "_").replace('"', "")
@@ -129,7 +128,7 @@ def refresh():
def link_provider(provider): def link_provider(provider):
"""Link a new OAuth provider to current user account.""" """Link a new OAuth provider to current user account."""
redirect_uri = url_for( redirect_uri = url_for(
"auth.link_callback", provider=provider, _external=True "auth.link_callback", provider=provider, _external=True,
) )
return auth_service.redirect_to_login(provider, redirect_uri) return auth_service.redirect_to_login(provider, redirect_uri)
@@ -168,19 +167,19 @@ def link_callback(provider):
if not provider_data.get("id"): if not provider_data.get("id"):
return { return {
"error": "Failed to get user information from provider" "error": "Failed to get user information from provider",
}, 400 }, 400
# Check if this provider is already linked to another user # Check if this provider is already linked to another user
from app.models.user_oauth import UserOAuth from app.models.user_oauth import UserOAuth
existing_provider = UserOAuth.find_by_provider_and_id( existing_provider = UserOAuth.find_by_provider_and_id(
provider, provider_data["id"] provider, provider_data["id"],
) )
if existing_provider and existing_provider.user_id != user.id: if existing_provider and existing_provider.user_id != user.id:
return { return {
"error": "This provider account is already linked to another user" "error": "This provider account is already linked to another user",
}, 409 }, 409
# Link the provider to current user # Link the provider to current user
@@ -210,7 +209,6 @@ def unlink_provider(provider):
from app.database import db from app.database import db
from app.models.user import User from app.models.user import User
from app.models.user_oauth import UserOAuth
user = User.query.get(current_user_id) user = User.query.get(current_user_id)
if not user: if not user:
@@ -224,7 +222,7 @@ def unlink_provider(provider):
oauth_provider = user.get_provider(provider) oauth_provider = user.get_provider(provider)
if not oauth_provider: if not oauth_provider:
return { return {
"error": f"Provider '{provider}' not linked to this account" "error": f"Provider '{provider}' not linked to this account",
}, 404 }, 404
db.session.delete(oauth_provider) db.session.delete(oauth_provider)
@@ -279,6 +277,7 @@ def me():
def update_profile(): def update_profile():
"""Update current user profile information.""" """Update current user profile information."""
from flask import request from flask import request
from app.database import db from app.database import db
from app.models.user import User from app.models.user import User
@@ -323,7 +322,7 @@ def update_profile():
return {"message": "Profile updated successfully", "user": updated_user} return {"message": "Profile updated successfully", "user": updated_user}
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
return {"error": f"Failed to update profile: {str(e)}"}, 500 return {"error": f"Failed to update profile: {e!s}"}, 500
@bp.route("/password", methods=["PUT"]) @bp.route("/password", methods=["PUT"])
@@ -331,9 +330,10 @@ def update_profile():
def change_password(): def change_password():
"""Change or set user password.""" """Change or set user password."""
from flask import request from flask import request
from werkzeug.security import check_password_hash
from app.database import db from app.database import db
from app.models.user import User from app.models.user import User
from werkzeug.security import check_password_hash
data = request.get_json() data = request.get_json()
if not data: if not data:
@@ -365,7 +365,7 @@ def change_password():
# User has a password AND logged in via password, require current password for verification # User has a password AND logged in via password, require current password for verification
if not current_password: if not current_password:
return { return {
"error": "Current password is required to change password" "error": "Current password is required to change password",
}, 400 }, 400
if not check_password_hash(user.password_hash, current_password): if not check_password_hash(user.password_hash, current_password):
@@ -380,4 +380,4 @@ def change_password():
return {"message": "Password updated successfully"} return {"message": "Password updated successfully"}
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
return {"error": f"Failed to update password: {str(e)}"}, 500 return {"error": f"Failed to update password: {e!s}"}, 500

View File

@@ -5,8 +5,8 @@ from flask import Blueprint
from app.services.decorators import ( from app.services.decorators import (
get_current_user, get_current_user,
require_auth, require_auth,
require_role,
require_credits, require_credits,
require_role,
) )
bp = Blueprint("main", __name__) bp = Blueprint("main", __name__)

View File

@@ -89,10 +89,10 @@ class AuthService:
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token( access_token = self.token_service.generate_access_token(
jwt_user_data jwt_user_data,
) )
refresh_token = self.token_service.generate_refresh_token( refresh_token = self.token_service.generate_refresh_token(
jwt_user_data jwt_user_data,
) )
# Create response and set HTTP-only cookies # Create response and set HTTP-only cookies
@@ -100,7 +100,7 @@ class AuthService:
{ {
"message": "Login successful", "message": "Login successful",
"user": jwt_user_data, "user": jwt_user_data,
} },
) )
# Set JWT cookies # Set JWT cookies
@@ -149,7 +149,7 @@ class AuthService:
return None return None
def register_with_password( def register_with_password(
self, email: str, password: str, name: str self, email: str, password: str, name: str,
) -> Any: ) -> Any:
"""Register new user with email and password.""" """Register new user with email and password."""
try: try:
@@ -164,10 +164,10 @@ class AuthService:
# Generate JWT tokens # Generate JWT tokens
access_token = self.token_service.generate_access_token( access_token = self.token_service.generate_access_token(
jwt_user_data jwt_user_data,
) )
refresh_token = self.token_service.generate_refresh_token( refresh_token = self.token_service.generate_refresh_token(
jwt_user_data jwt_user_data,
) )
# Create response and set HTTP-only cookies # Create response and set HTTP-only cookies
@@ -175,7 +175,7 @@ class AuthService:
{ {
"message": "Registration successful", "message": "Registration successful",
"user": jwt_user_data, "user": jwt_user_data,
} },
) )
# Set JWT cookies # Set JWT cookies
@@ -196,7 +196,7 @@ class AuthService:
if not user: if not user:
response = jsonify( response = jsonify(
{"error": "Invalid email, password or disabled account"} {"error": "Invalid email, password or disabled account"},
) )
response.status_code = 401 response.status_code = 401
return response return response
@@ -216,7 +216,7 @@ class AuthService:
{ {
"message": "Login successful", "message": "Login successful",
"user": jwt_user_data, "user": jwt_user_data,
} },
) )
# Set JWT cookies # Set JWT cookies

View File

@@ -4,7 +4,7 @@ from functools import wraps
from typing import Any from typing import Any
from flask import jsonify, request from flask import jsonify, request
from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request
def get_user_from_jwt() -> dict[str, Any] | None: def get_user_from_jwt() -> dict[str, Any] | None:
@@ -109,7 +109,7 @@ def require_auth(f):
if not user: if not user:
return ( return (
jsonify( jsonify(
{"error": "Authentication required (JWT or API token)"} {"error": "Authentication required (JWT or API token)"},
), ),
401, 401,
) )
@@ -133,8 +133,8 @@ def require_role(required_role: str):
return ( return (
jsonify( jsonify(
{ {
"error": f"Access denied. {required_role.title()} role required" "error": f"Access denied. {required_role.title()} role required",
} },
), ),
403, 403,
) )
@@ -152,8 +152,8 @@ def require_credits(credits_needed: int):
def decorator(f): def decorator(f):
@wraps(f) @wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
from app.models.user import User
from app.database import db from app.database import db
from app.models.user import User
# First check authentication # First check authentication
user_data = get_current_user() user_data = get_current_user()
@@ -170,8 +170,8 @@ def require_credits(credits_needed: int):
return ( return (
jsonify( jsonify(
{ {
"error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}" "error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}",
} },
), ),
402, # Payment Required status code 402, # Payment Required status code
) )

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Any, Optional from typing import Any
from authlib.integrations.flask_client import OAuth from authlib.integrations.flask_client import OAuth
@@ -16,23 +17,19 @@ class OAuthProvider(ABC):
@abstractmethod @abstractmethod
def name(self) -> str: def name(self) -> str:
"""Provider name (e.g., 'google', 'github').""" """Provider name (e.g., 'google', 'github')."""
pass
@property @property
@abstractmethod @abstractmethod
def display_name(self) -> str: def display_name(self) -> str:
"""Human-readable provider name (e.g., 'Google', 'GitHub').""" """Human-readable provider name (e.g., 'Google', 'GitHub')."""
pass
@abstractmethod @abstractmethod
def get_client_config(self) -> Dict[str, Any]: def get_client_config(self) -> dict[str, Any]:
"""Return OAuth client configuration.""" """Return OAuth client configuration."""
pass
@abstractmethod @abstractmethod
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
"""Extract user information from OAuth token response.""" """Extract user information from OAuth token response."""
pass
def get_client(self): def get_client(self):
"""Get or create OAuth client.""" """Get or create OAuth client."""
@@ -52,14 +49,14 @@ class OAuthProvider(ABC):
return client.authorize_redirect(redirect_uri).location return client.authorize_redirect(redirect_uri).location
def exchange_code_for_token( def exchange_code_for_token(
self, code: str = None, redirect_uri: str = None self, code: str = None, redirect_uri: str = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Exchange authorization code for access token.""" """Exchange authorization code for access token."""
client = self.get_client() client = self.get_client()
token = client.authorize_access_token() token = client.authorize_access_token()
return token return token
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]: def normalize_user_data(self, user_info: dict[str, Any]) -> dict[str, Any]:
"""Normalize user data to common format.""" """Normalize user data to common format."""
return { return {
"id": user_info.get("id"), "id": user_info.get("id"),

View File

@@ -1,4 +1,5 @@
from typing import Dict, Any from typing import Any
from .base import OAuthProvider from .base import OAuthProvider
@@ -13,7 +14,7 @@ class GitHubOAuthProvider(OAuthProvider):
def display_name(self) -> str: def display_name(self) -> str:
return "GitHub" return "GitHub"
def get_client_config(self) -> Dict[str, Any]: def get_client_config(self) -> dict[str, Any]:
"""Return GitHub OAuth client configuration.""" """Return GitHub OAuth client configuration."""
return { return {
"access_token_url": "https://github.com/login/oauth/access_token", "access_token_url": "https://github.com/login/oauth/access_token",
@@ -22,7 +23,7 @@ class GitHubOAuthProvider(OAuthProvider):
"client_kwargs": {"scope": "user:email"}, "client_kwargs": {"scope": "user:email"},
} }
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
"""Extract user information from GitHub OAuth token response.""" """Extract user information from GitHub OAuth token response."""
client = self.get_client() client = self.get_client()

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict from typing import Any
from .base import OAuthProvider from .base import OAuthProvider
@@ -14,14 +14,14 @@ class GoogleOAuthProvider(OAuthProvider):
def display_name(self) -> str: def display_name(self) -> str:
return "Google" return "Google"
def get_client_config(self) -> Dict[str, Any]: def get_client_config(self) -> dict[str, Any]:
"""Return Google OAuth client configuration.""" """Return Google OAuth client configuration."""
return { return {
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration", "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
"client_kwargs": {"scope": "openid email profile"}, "client_kwargs": {"scope": "openid email profile"},
} }
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]:
"""Extract user information from Google OAuth token response.""" """Extract user information from Google OAuth token response."""
client = self.get_client() client = self.get_client()
user_info = client.userinfo(token=token) user_info = client.userinfo(token=token)

View File

@@ -1,9 +1,10 @@
import os import os
from typing import Dict, Optional
from authlib.integrations.flask_client import OAuth from authlib.integrations.flask_client import OAuth
from .base import OAuthProvider from .base import OAuthProvider
from .google import GoogleOAuthProvider
from .github import GitHubOAuthProvider from .github import GitHubOAuthProvider
from .google import GoogleOAuthProvider
class OAuthProviderRegistry: class OAuthProviderRegistry:
@@ -11,7 +12,7 @@ class OAuthProviderRegistry:
def __init__(self, oauth: OAuth): def __init__(self, oauth: OAuth):
self.oauth = oauth self.oauth = oauth
self._providers: Dict[str, OAuthProvider] = {} self._providers: dict[str, OAuthProvider] = {}
self._initialize_providers() self._initialize_providers()
def _initialize_providers(self): def _initialize_providers(self):
@@ -21,7 +22,7 @@ class OAuthProviderRegistry:
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET") google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
if google_client_id and google_client_secret: if google_client_id and google_client_secret:
self._providers["google"] = GoogleOAuthProvider( self._providers["google"] = GoogleOAuthProvider(
self.oauth, google_client_id, google_client_secret self.oauth, google_client_id, google_client_secret,
) )
# GitHub OAuth # GitHub OAuth
@@ -29,14 +30,14 @@ class OAuthProviderRegistry:
github_client_secret = os.getenv("GITHUB_CLIENT_SECRET") github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
if github_client_id and github_client_secret: if github_client_id and github_client_secret:
self._providers["github"] = GitHubOAuthProvider( self._providers["github"] = GitHubOAuthProvider(
self.oauth, github_client_id, github_client_secret self.oauth, github_client_id, github_client_secret,
) )
def get_provider(self, name: str) -> Optional[OAuthProvider]: def get_provider(self, name: str) -> OAuthProvider | None:
"""Get OAuth provider by name.""" """Get OAuth provider by name."""
return self._providers.get(name) return self._providers.get(name)
def get_available_providers(self) -> Dict[str, OAuthProvider]: def get_available_providers(self) -> dict[str, OAuthProvider]:
"""Get all available providers.""" """Get all available providers."""
return self._providers.copy() return self._providers.copy()

View File

@@ -1,8 +1,8 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Database migration script for Flask-Migrate.""" """Database migration script for Flask-Migrate."""
import os
from flask.cli import FlaskGroup from flask.cli import FlaskGroup
from app import create_app from app import create_app
from app.database import db from app.database import db

View File

View File

@@ -1,94 +0,0 @@
"""Tests for authentication routes with Flask-JWT-Extended."""
from unittest.mock import Mock, patch
import pytest
from app import create_app
@pytest.fixture
def client():
"""Create a test client for the Flask application."""
app = create_app()
app.config["TESTING"] = True
app.config["JWT_COOKIE_SECURE"] = False # Allow cookies in testing
with app.test_client() as client:
yield client
class TestAuthRoutesJWTExtended:
"""Test cases for authentication routes with Flask-JWT-Extended."""
@patch("app.routes.auth.auth_service.get_login_url")
def test_login_route(self, mock_get_login_url: Mock, client) -> None:
"""Test the login route."""
mock_get_login_url.return_value = (
"https://accounts.google.com/oauth/authorize?..."
)
response = client.get("/api/auth/login")
assert response.status_code == 200
data = response.get_json()
assert "login_url" in data
assert (
data["login_url"]
== "https://accounts.google.com/oauth/authorize?..."
)
@patch("app.routes.auth.auth_service.handle_callback")
def test_callback_route_success(
self, mock_handle_callback: Mock, client
) -> None:
"""Test successful callback route."""
mock_response = Mock()
mock_response.get_json.return_value = {
"message": "Login successful",
"user": {
"id": "123",
"email": "test@example.com",
"name": "Test User",
},
}
mock_handle_callback.return_value = mock_response
response = client.get("/api/auth/callback?code=test_code")
mock_handle_callback.assert_called_once()
@patch("app.routes.auth.auth_service.handle_callback")
def test_callback_route_error(
self, mock_handle_callback: Mock, client
) -> None:
"""Test callback route with error."""
mock_handle_callback.side_effect = Exception("OAuth error")
response = client.get("/api/auth/callback?code=test_code")
assert response.status_code == 400
data = response.get_json()
assert data["error"] == "OAuth error"
@patch("app.routes.auth.auth_service.logout")
def test_logout_route(self, mock_logout: Mock, client) -> None:
"""Test logout route."""
mock_response = Mock()
mock_response.get_json.return_value = {
"message": "Logged out successfully"
}
mock_logout.return_value = mock_response
response = client.get("/api/auth/logout")
mock_logout.assert_called_once()
def test_me_route_not_authenticated(self, client) -> None:
"""Test /me route when not authenticated."""
response = client.get("/api/auth/me")
assert response.status_code == 401
data = response.get_json()
assert "msg" in data # Flask-JWT-Extended error format
def test_refresh_route_not_authenticated(self, client) -> None:
"""Test /refresh route when not authenticated."""
response = client.post("/api/auth/refresh")
assert response.status_code == 401
data = response.get_json()
assert "msg" in data # Flask-JWT-Extended error format

View File

@@ -1,50 +0,0 @@
"""Tests for AuthService with Flask-JWT-Extended."""
from unittest.mock import Mock, patch
from app import create_app
from app.services.auth_service import AuthService
class TestAuthServiceJWTExtended:
"""Test cases for AuthService with Flask-JWT-Extended."""
def test_init_without_app(self) -> None:
"""Test initializing AuthService without Flask app."""
auth_service = AuthService()
assert auth_service.oauth is not None
assert auth_service.google is None
assert auth_service.token_service is not None
@patch("app.services.auth_service.os.getenv")
def test_init_app(self, mock_getenv: Mock) -> None:
"""Test initializing AuthService with Flask app."""
mock_getenv.side_effect = lambda key: {
"GOOGLE_CLIENT_ID": "test_client_id",
"GOOGLE_CLIENT_SECRET": "test_client_secret",
}.get(key)
app = create_app()
auth_service = AuthService()
auth_service.init_app(app)
# Verify OAuth was initialized
assert auth_service.google is not None
@patch("app.services.auth_service.unset_jwt_cookies")
@patch("app.services.auth_service.jsonify")
def test_logout(self, mock_jsonify: Mock, mock_unset: Mock) -> None:
"""Test logout functionality."""
app = create_app()
with app.app_context():
mock_response = Mock()
mock_jsonify.return_value = mock_response
auth_service = AuthService()
result = auth_service.logout()
assert result == mock_response
mock_unset.assert_called_once_with(mock_response)
mock_jsonify.assert_called_once_with(
{"message": "Logged out successfully"}
)

View File

@@ -1,40 +0,0 @@
"""Tests for routes."""
import pytest
from app import create_app
@pytest.fixture
def client():
"""Create a test client for the Flask application."""
app = create_app()
app.config["TESTING"] = True
with app.test_client() as client:
yield client
class TestMainRoutes:
"""Test cases for main routes."""
def test_index_route(self, client) -> None:
"""Test the index route."""
response = client.get("/api/")
assert response.status_code == 200
assert response.get_json() == {
"message": "API is running",
"status": "ok",
}
def test_health_route(self, client) -> None:
"""Test health check route."""
response = client.get("/api/health")
assert response.status_code == 200
assert response.get_json() == {"status": "ok"}
def test_protected_route_without_auth(self, client) -> None:
"""Test protected route without authentication."""
response = client.get("/api/protected")
assert response.status_code == 401
data = response.get_json()
assert data["error"] == "Authentication required (JWT or API token)"

View File

@@ -1,167 +0,0 @@
"""Tests for TokenService."""
from datetime import datetime, timezone
from unittest.mock import patch
import jwt
import pytest
from app.services.token_service import TokenService
class TestTokenService:
"""Test cases for TokenService."""
def test_init(self) -> None:
"""Test TokenService initialization."""
token_service = TokenService()
assert token_service.algorithm == "HS256"
assert token_service.access_token_expire_minutes == 15
assert token_service.refresh_token_expire_days == 7
def test_generate_access_token(self) -> None:
"""Test access token generation."""
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_access_token(user_data)
assert isinstance(token, str)
# Verify token content
payload = jwt.decode(
token,
token_service.secret_key,
algorithms=[token_service.algorithm],
)
assert payload["user_id"] == "123"
assert payload["email"] == "test@example.com"
assert payload["name"] == "Test User"
assert payload["type"] == "access"
def test_generate_refresh_token(self) -> None:
"""Test refresh token generation."""
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_refresh_token(user_data)
assert isinstance(token, str)
# Verify token content
payload = jwt.decode(
token,
token_service.secret_key,
algorithms=[token_service.algorithm],
)
assert payload["user_id"] == "123"
assert payload["type"] == "refresh"
def test_verify_valid_token(self) -> None:
"""Test verifying a valid token."""
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_access_token(user_data)
payload = token_service.verify_token(token)
assert payload is not None
assert payload["user_id"] == "123"
assert payload["type"] == "access"
def test_verify_invalid_token(self) -> None:
"""Test verifying an invalid token."""
token_service = TokenService()
payload = token_service.verify_token("invalid.token.here")
assert payload is None
@patch("app.services.token_service.datetime")
def test_verify_expired_token(self, mock_datetime) -> None:
"""Test verifying an expired token."""
# Set up mock to return a past time for token generation
past_time = datetime(2020, 1, 1, tzinfo=timezone.utc)
mock_datetime.now.return_value = past_time
mock_datetime.UTC = timezone.utc
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_access_token(user_data)
# Reset mock to current time for verification
mock_datetime.now.return_value = datetime.now(timezone.utc)
payload = token_service.verify_token(token)
assert payload is None
def test_is_access_token(self) -> None:
"""Test access token type checking."""
token_service = TokenService()
access_payload = {"type": "access", "user_id": "123"}
refresh_payload = {"type": "refresh", "user_id": "123"}
assert token_service.is_access_token(access_payload)
assert not token_service.is_access_token(refresh_payload)
def test_is_refresh_token(self) -> None:
"""Test refresh token type checking."""
token_service = TokenService()
access_payload = {"type": "access", "user_id": "123"}
refresh_payload = {"type": "refresh", "user_id": "123"}
assert token_service.is_refresh_token(refresh_payload)
assert not token_service.is_refresh_token(access_payload)
def test_get_user_from_access_token_valid(self) -> None:
"""Test extracting user from valid access token."""
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_access_token(user_data)
extracted_user = token_service.get_user_from_access_token(token)
assert extracted_user == user_data
def test_get_user_from_access_token_refresh_token(self) -> None:
"""Test extracting user from refresh token (should fail)."""
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_refresh_token(user_data)
extracted_user = token_service.get_user_from_access_token(token)
assert extracted_user is None
def test_get_user_from_access_token_invalid(self) -> None:
"""Test extracting user from invalid token."""
token_service = TokenService()
extracted_user = token_service.get_user_from_access_token(
"invalid.token"
)
assert extracted_user is None

View File

@@ -1,57 +0,0 @@
"""Tests for TokenService using Flask-JWT-Extended."""
from unittest.mock import patch
from app import create_app
from app.services.token_service import TokenService
class TestTokenServiceJWTExtended:
"""Test cases for TokenService with Flask-JWT-Extended."""
def test_generate_access_token(self) -> None:
"""Test access token generation."""
app = create_app()
with app.app_context():
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
"picture": "https://example.com/pic.jpg",
}
token = token_service.generate_access_token(user_data)
assert isinstance(token, str)
assert len(token) > 0
def test_generate_refresh_token(self) -> None:
"""Test refresh token generation."""
app = create_app()
with app.app_context():
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
token = token_service.generate_refresh_token(user_data)
assert isinstance(token, str)
assert len(token) > 0
def test_generate_tokens_different(self) -> None:
"""Test that access and refresh tokens are different."""
app = create_app()
with app.app_context():
token_service = TokenService()
user_data = {
"id": "123",
"email": "test@example.com",
"name": "Test User",
}
access_token = token_service.generate_access_token(user_data)
refresh_token = token_service.generate_refresh_token(user_data)
assert access_token != refresh_token