diff --git a/app/__init__.py b/app/__init__.py index fceecfd..18bf292 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -26,7 +26,7 @@ def create_app(): # Configure Flask-JWT-Extended app.config["JWT_SECRET_KEY"] = os.environ.get( - "JWT_SECRET_KEY", "jwt-secret-key" + "JWT_SECRET_KEY", "jwt-secret-key", ) app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(minutes=15) app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=7) diff --git a/app/database.py b/app/database.py index 71588bd..20349b6 100644 --- a/app/database.py +++ b/app/database.py @@ -1,7 +1,7 @@ """Database configuration and initialization.""" -from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +from flask_sqlalchemy import SQLAlchemy db = SQLAlchemy() migrate = Migrate() diff --git a/app/database_init.py b/app/database_init.py index d48bcd4..2b3c7b7 100644 --- a/app/database_init.py +++ b/app/database_init.py @@ -68,7 +68,7 @@ def migrate_users_to_plans(): # 0 credits means they spent them, NULL means they never got assigned try: users_without_credits = User.query.filter( - User.plan_id.isnot(None), User.credits.is_(None) + User.plan_id.isnot(None), User.credits.is_(None), ).all() except Exception: # Credits column doesn't exist yet, will be handled by create_all @@ -113,7 +113,7 @@ def migrate_users_to_plans(): if updated_count > 0: db.session.commit() print( - f"Updated {updated_count} existing users with plans and credits" + f"Updated {updated_count} existing users with plans and credits", ) except Exception: diff --git a/app/models/user.py b/app/models/user.py index 0ee6234..9b84405 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -2,17 +2,17 @@ import secrets from datetime import datetime -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional -from werkzeug.security import check_password_hash, generate_password_hash -from sqlalchemy import String, DateTime, Integer, ForeignKey +from sqlalchemy import DateTime, ForeignKey, Integer, String from sqlalchemy.orm import Mapped, mapped_column, relationship +from werkzeug.security import check_password_hash, generate_password_hash from app.database import db if TYPE_CHECKING: - from app.models.user_oauth import UserOAuth from app.models.plan import Plan + from app.models.user_oauth import UserOAuth class User(db.Model): @@ -25,16 +25,16 @@ class User(db.Model): # Primary user information (can be updated from any connected provider) email: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) name: Mapped[str] = mapped_column(String(255), nullable=False) - picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True) + picture: Mapped[str | None] = mapped_column(String(500), nullable=True) # Password authentication (optional - users can use OAuth instead) - password_hash: Mapped[Optional[str]] = mapped_column( - String(255), nullable=True + password_hash: Mapped[str | None] = mapped_column( + String(255), nullable=True, ) # Role-based access control role: Mapped[str] = mapped_column( - String(50), nullable=False, default="user" + String(50), nullable=False, default="user", ) # User status @@ -42,21 +42,21 @@ class User(db.Model): # Plan relationship plan_id: Mapped[int] = mapped_column( - Integer, ForeignKey("plans.id"), nullable=False + Integer, ForeignKey("plans.id"), nullable=False, ) # User credits (populated from plan credits on creation) credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0) # API token for programmatic access - api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - api_token_expires_at: Mapped[Optional[datetime]] = mapped_column( - DateTime, nullable=True + api_token: Mapped[str | None] = mapped_column(String(255), nullable=True) + api_token_expires_at: Mapped[datetime | None] = mapped_column( + DateTime, nullable=True, ) # Timestamps created_at: Mapped[datetime] = mapped_column( - DateTime, default=datetime.utcnow, nullable=False + DateTime, default=datetime.utcnow, nullable=False, ) updated_at: Mapped[datetime] = mapped_column( DateTime, @@ -67,7 +67,7 @@ class User(db.Model): # Relationships oauth_providers: Mapped[list["UserOAuth"]] = relationship( - "UserOAuth", back_populates="user", cascade="all, delete-orphan" + "UserOAuth", back_populates="user", cascade="all, delete-orphan", ) plan: Mapped["Plan"] = relationship("Plan", back_populates="users") @@ -190,15 +190,15 @@ class User(db.Model): provider_id: str, email: str, name: str, - picture: Optional[str] = None, + picture: str | None = None, ) -> tuple["User", "UserOAuth"]: """Find existing user or create new one from OAuth data.""" - from app.models.user_oauth import UserOAuth from app.models.plan import Plan + from app.models.user_oauth import UserOAuth # First, try to find existing OAuth provider oauth_provider = UserOAuth.find_by_provider_and_id( - provider, provider_id + provider, provider_id, ) if oauth_provider: @@ -211,7 +211,7 @@ class User(db.Model): # Update user info with latest data user.update_from_provider( - {"email": email, "name": name, "picture": picture} + {"email": email, "name": name, "picture": picture}, ) else: # Try to find user by email to link the new provider @@ -256,7 +256,7 @@ class User(db.Model): @classmethod def create_with_password( - cls, email: str, password: str, name: str + cls, email: str, password: str, name: str, ) -> "User": """Create new user with email and password.""" from app.models.plan import Plan @@ -293,7 +293,7 @@ class User(db.Model): @classmethod def authenticate_with_password( - cls, email: str, password: str + cls, email: str, password: str, ) -> Optional["User"]: """Authenticate user with email and password.""" user = cls.find_by_email(email) diff --git a/app/models/user_oauth.py b/app/models/user_oauth.py index b034030..78e4703 100644 --- a/app/models/user_oauth.py +++ b/app/models/user_oauth.py @@ -1,9 +1,9 @@ """User OAuth model for storing user's connected providers.""" from datetime import datetime -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional -from sqlalchemy import String, DateTime, Text, ForeignKey +from sqlalchemy import DateTime, ForeignKey, String, Text from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import db @@ -29,11 +29,11 @@ class UserOAuth(db.Model): # Provider-specific user information email: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + picture: Mapped[str | None] = mapped_column(Text, nullable=True) # Timestamps created_at: Mapped[datetime] = mapped_column( - DateTime, default=datetime.utcnow, nullable=False + DateTime, default=datetime.utcnow, nullable=False, ) updated_at: Mapped[datetime] = mapped_column( DateTime, @@ -45,13 +45,13 @@ class UserOAuth(db.Model): # Unique constraint on provider + provider_id combination __table_args__ = ( db.UniqueConstraint( - "provider", "provider_id", name="unique_provider_user" + "provider", "provider_id", name="unique_provider_user", ), ) # Relationships user: Mapped["User"] = relationship( - "User", back_populates="oauth_providers" + "User", back_populates="oauth_providers", ) def __repr__(self) -> str: @@ -73,11 +73,11 @@ class UserOAuth(db.Model): @classmethod def find_by_provider_and_id( - cls, provider: str, provider_id: str + cls, provider: str, provider_id: str, ) -> Optional["UserOAuth"]: """Find OAuth provider by provider name and provider ID.""" return cls.query.filter_by( - provider=provider, provider_id=provider_id + provider=provider, provider_id=provider_id, ).first() @classmethod @@ -88,7 +88,7 @@ class UserOAuth(db.Model): provider_id: str, email: str, name: str, - picture: Optional[str] = None, + picture: str | None = None, ) -> "UserOAuth": """Create new OAuth provider or update existing one.""" oauth_provider = cls.find_by_provider_and_id(provider, provider_id) diff --git a/app/routes/auth.py b/app/routes/auth.py index b62b6a9..23496d9 100644 --- a/app/routes/auth.py +++ b/app/routes/auth.py @@ -31,7 +31,7 @@ def callback(provider): # If successful, redirect to frontend dashboard with cookies if auth_response.status_code == 200: redirect_response = make_response( - redirect("http://localhost:3000/dashboard") + redirect("http://localhost:3000/dashboard"), ) # Copy all cookies from the auth response @@ -39,9 +39,8 @@ def callback(provider): redirect_response.headers.add("Set-Cookie", cookie) return redirect_response - else: - # If there was an error, redirect to login with error - return redirect("http://localhost:3000/login?error=oauth_failed") + # If there was an error, redirect to login with error + return redirect("http://localhost:3000/login?error=oauth_failed") except Exception as e: error_msg = str(e).replace(" ", "_").replace('"', "") @@ -129,7 +128,7 @@ def refresh(): def link_provider(provider): """Link a new OAuth provider to current user account.""" redirect_uri = url_for( - "auth.link_callback", provider=provider, _external=True + "auth.link_callback", provider=provider, _external=True, ) return auth_service.redirect_to_login(provider, redirect_uri) @@ -168,19 +167,19 @@ def link_callback(provider): if not provider_data.get("id"): return { - "error": "Failed to get user information from provider" + "error": "Failed to get user information from provider", }, 400 # Check if this provider is already linked to another user from app.models.user_oauth import UserOAuth existing_provider = UserOAuth.find_by_provider_and_id( - provider, provider_data["id"] + provider, provider_data["id"], ) if existing_provider and existing_provider.user_id != user.id: return { - "error": "This provider account is already linked to another user" + "error": "This provider account is already linked to another user", }, 409 # Link the provider to current user @@ -210,7 +209,6 @@ def unlink_provider(provider): from app.database import db from app.models.user import User - from app.models.user_oauth import UserOAuth user = User.query.get(current_user_id) if not user: @@ -224,7 +222,7 @@ def unlink_provider(provider): oauth_provider = user.get_provider(provider) if not oauth_provider: return { - "error": f"Provider '{provider}' not linked to this account" + "error": f"Provider '{provider}' not linked to this account", }, 404 db.session.delete(oauth_provider) @@ -279,6 +277,7 @@ def me(): def update_profile(): """Update current user profile information.""" from flask import request + from app.database import db from app.models.user import User @@ -323,7 +322,7 @@ def update_profile(): return {"message": "Profile updated successfully", "user": updated_user} except Exception as e: db.session.rollback() - return {"error": f"Failed to update profile: {str(e)}"}, 500 + return {"error": f"Failed to update profile: {e!s}"}, 500 @bp.route("/password", methods=["PUT"]) @@ -331,9 +330,10 @@ def update_profile(): def change_password(): """Change or set user password.""" from flask import request + from werkzeug.security import check_password_hash + from app.database import db from app.models.user import User - from werkzeug.security import check_password_hash data = request.get_json() if not data: @@ -365,7 +365,7 @@ def change_password(): # User has a password AND logged in via password, require current password for verification if not current_password: return { - "error": "Current password is required to change password" + "error": "Current password is required to change password", }, 400 if not check_password_hash(user.password_hash, current_password): @@ -380,4 +380,4 @@ def change_password(): return {"message": "Password updated successfully"} except Exception as e: db.session.rollback() - return {"error": f"Failed to update password: {str(e)}"}, 500 + return {"error": f"Failed to update password: {e!s}"}, 500 diff --git a/app/routes/main.py b/app/routes/main.py index 4562a72..c5f8aab 100644 --- a/app/routes/main.py +++ b/app/routes/main.py @@ -5,8 +5,8 @@ from flask import Blueprint from app.services.decorators import ( get_current_user, require_auth, - require_role, require_credits, + require_role, ) bp = Blueprint("main", __name__) diff --git a/app/services/auth_service.py b/app/services/auth_service.py index 1f87dba..40f202c 100644 --- a/app/services/auth_service.py +++ b/app/services/auth_service.py @@ -89,10 +89,10 @@ class AuthService: # Generate JWT tokens access_token = self.token_service.generate_access_token( - jwt_user_data + jwt_user_data, ) refresh_token = self.token_service.generate_refresh_token( - jwt_user_data + jwt_user_data, ) # Create response and set HTTP-only cookies @@ -100,7 +100,7 @@ class AuthService: { "message": "Login successful", "user": jwt_user_data, - } + }, ) # Set JWT cookies @@ -149,7 +149,7 @@ class AuthService: return None def register_with_password( - self, email: str, password: str, name: str + self, email: str, password: str, name: str, ) -> Any: """Register new user with email and password.""" try: @@ -164,10 +164,10 @@ class AuthService: # Generate JWT tokens access_token = self.token_service.generate_access_token( - jwt_user_data + jwt_user_data, ) refresh_token = self.token_service.generate_refresh_token( - jwt_user_data + jwt_user_data, ) # Create response and set HTTP-only cookies @@ -175,7 +175,7 @@ class AuthService: { "message": "Registration successful", "user": jwt_user_data, - } + }, ) # Set JWT cookies @@ -196,7 +196,7 @@ class AuthService: if not user: response = jsonify( - {"error": "Invalid email, password or disabled account"} + {"error": "Invalid email, password or disabled account"}, ) response.status_code = 401 return response @@ -216,7 +216,7 @@ class AuthService: { "message": "Login successful", "user": jwt_user_data, - } + }, ) # Set JWT cookies diff --git a/app/services/decorators.py b/app/services/decorators.py index b508241..f5dc582 100644 --- a/app/services/decorators.py +++ b/app/services/decorators.py @@ -4,7 +4,7 @@ from functools import wraps from typing import Any from flask import jsonify, request -from flask_jwt_extended import get_jwt, get_jwt_identity, verify_jwt_in_request +from flask_jwt_extended import get_jwt_identity, verify_jwt_in_request def get_user_from_jwt() -> dict[str, Any] | None: @@ -109,7 +109,7 @@ def require_auth(f): if not user: return ( jsonify( - {"error": "Authentication required (JWT or API token)"} + {"error": "Authentication required (JWT or API token)"}, ), 401, ) @@ -133,8 +133,8 @@ def require_role(required_role: str): return ( jsonify( { - "error": f"Access denied. {required_role.title()} role required" - } + "error": f"Access denied. {required_role.title()} role required", + }, ), 403, ) @@ -152,8 +152,8 @@ def require_credits(credits_needed: int): def decorator(f): @wraps(f) def wrapper(*args, **kwargs): - from app.models.user import User from app.database import db + from app.models.user import User # First check authentication user_data = get_current_user() @@ -170,8 +170,8 @@ def require_credits(credits_needed: int): return ( jsonify( { - "error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}" - } + "error": f"Insufficient credits. Required: {credits_needed}, Available: {user.credits}", + }, ), 402, # Payment Required status code ) diff --git a/app/services/oauth_providers/base.py b/app/services/oauth_providers/base.py index 471a89f..725b8d0 100644 --- a/app/services/oauth_providers/base.py +++ b/app/services/oauth_providers/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Dict, Any, Optional +from typing import Any + from authlib.integrations.flask_client import OAuth @@ -16,23 +17,19 @@ class OAuthProvider(ABC): @abstractmethod def name(self) -> str: """Provider name (e.g., 'google', 'github').""" - pass @property @abstractmethod def display_name(self) -> str: """Human-readable provider name (e.g., 'Google', 'GitHub').""" - pass @abstractmethod - def get_client_config(self) -> Dict[str, Any]: + def get_client_config(self) -> dict[str, Any]: """Return OAuth client configuration.""" - pass @abstractmethod - def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: + def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: """Extract user information from OAuth token response.""" - pass def get_client(self): """Get or create OAuth client.""" @@ -52,14 +49,14 @@ class OAuthProvider(ABC): return client.authorize_redirect(redirect_uri).location def exchange_code_for_token( - self, code: str = None, redirect_uri: str = None - ) -> Dict[str, Any]: + self, code: str = None, redirect_uri: str = None, + ) -> dict[str, Any]: """Exchange authorization code for access token.""" client = self.get_client() token = client.authorize_access_token() return token - def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]: + def normalize_user_data(self, user_info: dict[str, Any]) -> dict[str, Any]: """Normalize user data to common format.""" return { "id": user_info.get("id"), diff --git a/app/services/oauth_providers/github.py b/app/services/oauth_providers/github.py index c11a958..eac11f2 100644 --- a/app/services/oauth_providers/github.py +++ b/app/services/oauth_providers/github.py @@ -1,4 +1,5 @@ -from typing import Dict, Any +from typing import Any + from .base import OAuthProvider @@ -13,7 +14,7 @@ class GitHubOAuthProvider(OAuthProvider): def display_name(self) -> str: return "GitHub" - def get_client_config(self) -> Dict[str, Any]: + def get_client_config(self) -> dict[str, Any]: """Return GitHub OAuth client configuration.""" return { "access_token_url": "https://github.com/login/oauth/access_token", @@ -22,7 +23,7 @@ class GitHubOAuthProvider(OAuthProvider): "client_kwargs": {"scope": "user:email"}, } - def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: + def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: """Extract user information from GitHub OAuth token response.""" client = self.get_client() diff --git a/app/services/oauth_providers/google.py b/app/services/oauth_providers/google.py index 33781dd..1925554 100644 --- a/app/services/oauth_providers/google.py +++ b/app/services/oauth_providers/google.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any from .base import OAuthProvider @@ -14,14 +14,14 @@ class GoogleOAuthProvider(OAuthProvider): def display_name(self) -> str: return "Google" - def get_client_config(self) -> Dict[str, Any]: + def get_client_config(self) -> dict[str, Any]: """Return Google OAuth client configuration.""" return { "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration", "client_kwargs": {"scope": "openid email profile"}, } - def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]: + def get_user_info(self, token: dict[str, Any]) -> dict[str, Any]: """Extract user information from Google OAuth token response.""" client = self.get_client() user_info = client.userinfo(token=token) diff --git a/app/services/oauth_providers/registry.py b/app/services/oauth_providers/registry.py index b697d20..851746c 100644 --- a/app/services/oauth_providers/registry.py +++ b/app/services/oauth_providers/registry.py @@ -1,9 +1,10 @@ import os -from typing import Dict, Optional + from authlib.integrations.flask_client import OAuth + from .base import OAuthProvider -from .google import GoogleOAuthProvider from .github import GitHubOAuthProvider +from .google import GoogleOAuthProvider class OAuthProviderRegistry: @@ -11,7 +12,7 @@ class OAuthProviderRegistry: def __init__(self, oauth: OAuth): self.oauth = oauth - self._providers: Dict[str, OAuthProvider] = {} + self._providers: dict[str, OAuthProvider] = {} self._initialize_providers() def _initialize_providers(self): @@ -21,7 +22,7 @@ class OAuthProviderRegistry: google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET") if google_client_id and google_client_secret: self._providers["google"] = GoogleOAuthProvider( - self.oauth, google_client_id, google_client_secret + self.oauth, google_client_id, google_client_secret, ) # GitHub OAuth @@ -29,14 +30,14 @@ class OAuthProviderRegistry: github_client_secret = os.getenv("GITHUB_CLIENT_SECRET") if github_client_id and github_client_secret: self._providers["github"] = GitHubOAuthProvider( - self.oauth, github_client_id, github_client_secret + self.oauth, github_client_id, github_client_secret, ) - def get_provider(self, name: str) -> Optional[OAuthProvider]: + def get_provider(self, name: str) -> OAuthProvider | None: """Get OAuth provider by name.""" return self._providers.get(name) - def get_available_providers(self) -> Dict[str, OAuthProvider]: + def get_available_providers(self) -> dict[str, OAuthProvider]: """Get all available providers.""" return self._providers.copy() diff --git a/migrate_db.py b/migrate_db.py index 6446f51..a5699e8 100644 --- a/migrate_db.py +++ b/migrate_db.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 """Database migration script for Flask-Migrate.""" -import os from flask.cli import FlaskGroup + from app import create_app from app.database import db diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_auth_routes.py b/tests/test_auth_routes.py deleted file mode 100644 index 5adc93e..0000000 --- a/tests/test_auth_routes.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Tests for authentication routes with Flask-JWT-Extended.""" - -from unittest.mock import Mock, patch - -import pytest - -from app import create_app - - -@pytest.fixture -def client(): - """Create a test client for the Flask application.""" - app = create_app() - app.config["TESTING"] = True - app.config["JWT_COOKIE_SECURE"] = False # Allow cookies in testing - with app.test_client() as client: - yield client - - -class TestAuthRoutesJWTExtended: - """Test cases for authentication routes with Flask-JWT-Extended.""" - - @patch("app.routes.auth.auth_service.get_login_url") - def test_login_route(self, mock_get_login_url: Mock, client) -> None: - """Test the login route.""" - mock_get_login_url.return_value = ( - "https://accounts.google.com/oauth/authorize?..." - ) - - response = client.get("/api/auth/login") - assert response.status_code == 200 - data = response.get_json() - assert "login_url" in data - assert ( - data["login_url"] - == "https://accounts.google.com/oauth/authorize?..." - ) - - @patch("app.routes.auth.auth_service.handle_callback") - def test_callback_route_success( - self, mock_handle_callback: Mock, client - ) -> None: - """Test successful callback route.""" - mock_response = Mock() - mock_response.get_json.return_value = { - "message": "Login successful", - "user": { - "id": "123", - "email": "test@example.com", - "name": "Test User", - }, - } - mock_handle_callback.return_value = mock_response - - response = client.get("/api/auth/callback?code=test_code") - mock_handle_callback.assert_called_once() - - @patch("app.routes.auth.auth_service.handle_callback") - def test_callback_route_error( - self, mock_handle_callback: Mock, client - ) -> None: - """Test callback route with error.""" - mock_handle_callback.side_effect = Exception("OAuth error") - - response = client.get("/api/auth/callback?code=test_code") - assert response.status_code == 400 - data = response.get_json() - assert data["error"] == "OAuth error" - - @patch("app.routes.auth.auth_service.logout") - def test_logout_route(self, mock_logout: Mock, client) -> None: - """Test logout route.""" - mock_response = Mock() - mock_response.get_json.return_value = { - "message": "Logged out successfully" - } - mock_logout.return_value = mock_response - - response = client.get("/api/auth/logout") - mock_logout.assert_called_once() - - def test_me_route_not_authenticated(self, client) -> None: - """Test /me route when not authenticated.""" - response = client.get("/api/auth/me") - assert response.status_code == 401 - data = response.get_json() - assert "msg" in data # Flask-JWT-Extended error format - - def test_refresh_route_not_authenticated(self, client) -> None: - """Test /refresh route when not authenticated.""" - response = client.post("/api/auth/refresh") - assert response.status_code == 401 - data = response.get_json() - assert "msg" in data # Flask-JWT-Extended error format diff --git a/tests/test_auth_service.py b/tests/test_auth_service.py deleted file mode 100644 index e08dfba..0000000 --- a/tests/test_auth_service.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Tests for AuthService with Flask-JWT-Extended.""" - -from unittest.mock import Mock, patch - -from app import create_app -from app.services.auth_service import AuthService - - -class TestAuthServiceJWTExtended: - """Test cases for AuthService with Flask-JWT-Extended.""" - - def test_init_without_app(self) -> None: - """Test initializing AuthService without Flask app.""" - auth_service = AuthService() - assert auth_service.oauth is not None - assert auth_service.google is None - assert auth_service.token_service is not None - - @patch("app.services.auth_service.os.getenv") - def test_init_app(self, mock_getenv: Mock) -> None: - """Test initializing AuthService with Flask app.""" - mock_getenv.side_effect = lambda key: { - "GOOGLE_CLIENT_ID": "test_client_id", - "GOOGLE_CLIENT_SECRET": "test_client_secret", - }.get(key) - - app = create_app() - auth_service = AuthService() - auth_service.init_app(app) - - # Verify OAuth was initialized - assert auth_service.google is not None - - @patch("app.services.auth_service.unset_jwt_cookies") - @patch("app.services.auth_service.jsonify") - def test_logout(self, mock_jsonify: Mock, mock_unset: Mock) -> None: - """Test logout functionality.""" - app = create_app() - with app.app_context(): - mock_response = Mock() - mock_jsonify.return_value = mock_response - - auth_service = AuthService() - result = auth_service.logout() - - assert result == mock_response - mock_unset.assert_called_once_with(mock_response) - mock_jsonify.assert_called_once_with( - {"message": "Logged out successfully"} - ) diff --git a/tests/test_routes.py b/tests/test_routes.py deleted file mode 100644 index 77db409..0000000 --- a/tests/test_routes.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Tests for routes.""" - -import pytest - -from app import create_app - - -@pytest.fixture -def client(): - """Create a test client for the Flask application.""" - app = create_app() - app.config["TESTING"] = True - with app.test_client() as client: - yield client - - -class TestMainRoutes: - """Test cases for main routes.""" - - def test_index_route(self, client) -> None: - """Test the index route.""" - response = client.get("/api/") - assert response.status_code == 200 - assert response.get_json() == { - "message": "API is running", - "status": "ok", - } - - def test_health_route(self, client) -> None: - """Test health check route.""" - response = client.get("/api/health") - assert response.status_code == 200 - assert response.get_json() == {"status": "ok"} - - def test_protected_route_without_auth(self, client) -> None: - """Test protected route without authentication.""" - response = client.get("/api/protected") - assert response.status_code == 401 - data = response.get_json() - assert data["error"] == "Authentication required (JWT or API token)" diff --git a/tests/test_token_service.py b/tests/test_token_service.py deleted file mode 100644 index 983e50a..0000000 --- a/tests/test_token_service.py +++ /dev/null @@ -1,167 +0,0 @@ -"""Tests for TokenService.""" - -from datetime import datetime, timezone -from unittest.mock import patch - -import jwt -import pytest - -from app.services.token_service import TokenService - - -class TestTokenService: - """Test cases for TokenService.""" - - def test_init(self) -> None: - """Test TokenService initialization.""" - token_service = TokenService() - assert token_service.algorithm == "HS256" - assert token_service.access_token_expire_minutes == 15 - assert token_service.refresh_token_expire_days == 7 - - def test_generate_access_token(self) -> None: - """Test access token generation.""" - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - token = token_service.generate_access_token(user_data) - assert isinstance(token, str) - - # Verify token content - payload = jwt.decode( - token, - token_service.secret_key, - algorithms=[token_service.algorithm], - ) - assert payload["user_id"] == "123" - assert payload["email"] == "test@example.com" - assert payload["name"] == "Test User" - assert payload["type"] == "access" - - def test_generate_refresh_token(self) -> None: - """Test refresh token generation.""" - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - token = token_service.generate_refresh_token(user_data) - assert isinstance(token, str) - - # Verify token content - payload = jwt.decode( - token, - token_service.secret_key, - algorithms=[token_service.algorithm], - ) - assert payload["user_id"] == "123" - assert payload["type"] == "refresh" - - def test_verify_valid_token(self) -> None: - """Test verifying a valid token.""" - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - token = token_service.generate_access_token(user_data) - payload = token_service.verify_token(token) - - assert payload is not None - assert payload["user_id"] == "123" - assert payload["type"] == "access" - - def test_verify_invalid_token(self) -> None: - """Test verifying an invalid token.""" - token_service = TokenService() - - payload = token_service.verify_token("invalid.token.here") - assert payload is None - - @patch("app.services.token_service.datetime") - def test_verify_expired_token(self, mock_datetime) -> None: - """Test verifying an expired token.""" - # Set up mock to return a past time for token generation - past_time = datetime(2020, 1, 1, tzinfo=timezone.utc) - mock_datetime.now.return_value = past_time - mock_datetime.UTC = timezone.utc - - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - token = token_service.generate_access_token(user_data) - - # Reset mock to current time for verification - mock_datetime.now.return_value = datetime.now(timezone.utc) - - payload = token_service.verify_token(token) - assert payload is None - - def test_is_access_token(self) -> None: - """Test access token type checking.""" - token_service = TokenService() - - access_payload = {"type": "access", "user_id": "123"} - refresh_payload = {"type": "refresh", "user_id": "123"} - - assert token_service.is_access_token(access_payload) - assert not token_service.is_access_token(refresh_payload) - - def test_is_refresh_token(self) -> None: - """Test refresh token type checking.""" - token_service = TokenService() - - access_payload = {"type": "access", "user_id": "123"} - refresh_payload = {"type": "refresh", "user_id": "123"} - - assert token_service.is_refresh_token(refresh_payload) - assert not token_service.is_refresh_token(access_payload) - - def test_get_user_from_access_token_valid(self) -> None: - """Test extracting user from valid access token.""" - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - token = token_service.generate_access_token(user_data) - extracted_user = token_service.get_user_from_access_token(token) - - assert extracted_user == user_data - - def test_get_user_from_access_token_refresh_token(self) -> None: - """Test extracting user from refresh token (should fail).""" - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - token = token_service.generate_refresh_token(user_data) - extracted_user = token_service.get_user_from_access_token(token) - - assert extracted_user is None - - def test_get_user_from_access_token_invalid(self) -> None: - """Test extracting user from invalid token.""" - token_service = TokenService() - - extracted_user = token_service.get_user_from_access_token( - "invalid.token" - ) - assert extracted_user is None diff --git a/tests/test_token_service_jwt_extended.py b/tests/test_token_service_jwt_extended.py deleted file mode 100644 index 1ad0b04..0000000 --- a/tests/test_token_service_jwt_extended.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Tests for TokenService using Flask-JWT-Extended.""" - -from unittest.mock import patch - -from app import create_app -from app.services.token_service import TokenService - - -class TestTokenServiceJWTExtended: - """Test cases for TokenService with Flask-JWT-Extended.""" - - def test_generate_access_token(self) -> None: - """Test access token generation.""" - app = create_app() - with app.app_context(): - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - "picture": "https://example.com/pic.jpg", - } - - token = token_service.generate_access_token(user_data) - assert isinstance(token, str) - assert len(token) > 0 - - def test_generate_refresh_token(self) -> None: - """Test refresh token generation.""" - app = create_app() - with app.app_context(): - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - token = token_service.generate_refresh_token(user_data) - assert isinstance(token, str) - assert len(token) > 0 - - def test_generate_tokens_different(self) -> None: - """Test that access and refresh tokens are different.""" - app = create_app() - with app.app_context(): - token_service = TokenService() - user_data = { - "id": "123", - "email": "test@example.com", - "name": "Test User", - } - - access_token = token_service.generate_access_token(user_data) - refresh_token = token_service.generate_refresh_token(user_data) - - assert access_token != refresh_token