Refactor code for improved readability and consistency
- Cleaned up whitespace and formatting across multiple files for better readability.
This commit is contained in:
@@ -112,7 +112,9 @@ def migrate_users_to_plans():
|
||||
|
||||
if updated_count > 0:
|
||||
db.session.commit()
|
||||
print(f"Updated {updated_count} existing users with plans and credits")
|
||||
print(
|
||||
f"Updated {updated_count} existing users with plans and credits"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# If there's any error (like missing columns), just skip migration
|
||||
|
||||
@@ -28,30 +28,41 @@ class User(db.Model):
|
||||
picture: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
|
||||
|
||||
# Password authentication (optional - users can use OAuth instead)
|
||||
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
password_hash: Mapped[Optional[str]] = mapped_column(
|
||||
String(255), nullable=True
|
||||
)
|
||||
|
||||
# Role-based access control
|
||||
role: Mapped[str] = mapped_column(String(50), nullable=False, default="user")
|
||||
role: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False, default="user"
|
||||
)
|
||||
|
||||
# User status
|
||||
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
|
||||
|
||||
# Plan relationship
|
||||
plan_id: Mapped[int] = mapped_column(Integer, ForeignKey("plans.id"), nullable=False)
|
||||
plan_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("plans.id"), nullable=False
|
||||
)
|
||||
|
||||
# User credits (populated from plan credits on creation)
|
||||
credits: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# API token for programmatic access
|
||||
api_token: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
|
||||
api_token_expires_at: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
@@ -82,7 +93,9 @@ class User(db.Model):
|
||||
"role": self.role,
|
||||
"is_active": self.is_active,
|
||||
"api_token": self.api_token,
|
||||
"api_token_expires_at": self.api_token_expires_at.isoformat() if self.api_token_expires_at else None,
|
||||
"api_token_expires_at": self.api_token_expires_at.isoformat()
|
||||
if self.api_token_expires_at
|
||||
else None,
|
||||
"providers": providers,
|
||||
"plan": self.plan.to_dict() if self.plan else None,
|
||||
"credits": self.credits,
|
||||
@@ -172,14 +185,21 @@ class User(db.Model):
|
||||
|
||||
@classmethod
|
||||
def find_or_create_from_oauth(
|
||||
cls, provider: str, provider_id: str, email: str, name: str, picture: Optional[str] = None
|
||||
cls,
|
||||
provider: str,
|
||||
provider_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: Optional[str] = None,
|
||||
) -> tuple["User", "UserOAuth"]:
|
||||
"""Find existing user or create new one from OAuth data."""
|
||||
from app.models.user_oauth import UserOAuth
|
||||
from app.models.plan import Plan
|
||||
|
||||
# First, try to find existing OAuth provider
|
||||
oauth_provider = UserOAuth.find_by_provider_and_id(provider, provider_id)
|
||||
oauth_provider = UserOAuth.find_by_provider_and_id(
|
||||
provider, provider_id
|
||||
)
|
||||
|
||||
if oauth_provider:
|
||||
# Update existing provider and user info
|
||||
@@ -190,7 +210,9 @@ class User(db.Model):
|
||||
oauth_provider.updated_at = datetime.utcnow()
|
||||
|
||||
# Update user info with latest data
|
||||
user.update_from_provider({"email": email, "name": name, "picture": picture})
|
||||
user.update_from_provider(
|
||||
{"email": email, "name": name, "picture": picture}
|
||||
)
|
||||
else:
|
||||
# Try to find user by email to link the new provider
|
||||
user = cls.find_by_email(email)
|
||||
@@ -233,7 +255,9 @@ class User(db.Model):
|
||||
return user, oauth_provider
|
||||
|
||||
@classmethod
|
||||
def create_with_password(cls, email: str, password: str, name: str) -> "User":
|
||||
def create_with_password(
|
||||
cls, email: str, password: str, name: str
|
||||
) -> "User":
|
||||
"""Create new user with email and password."""
|
||||
from app.models.plan import Plan
|
||||
|
||||
@@ -268,7 +292,9 @@ class User(db.Model):
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def authenticate_with_password(cls, email: str, password: str) -> Optional["User"]:
|
||||
def authenticate_with_password(
|
||||
cls, email: str, password: str
|
||||
) -> Optional["User"]:
|
||||
"""Authenticate user with email and password."""
|
||||
user = cls.find_by_email(email)
|
||||
if user and user.check_password(password) and user.is_active:
|
||||
|
||||
@@ -36,16 +36,23 @@ class UserOAuth(db.Model):
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
DateTime,
|
||||
default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Unique constraint on provider + provider_id combination
|
||||
__table_args__ = (
|
||||
db.UniqueConstraint("provider", "provider_id", name="unique_provider_user"),
|
||||
db.UniqueConstraint(
|
||||
"provider", "provider_id", name="unique_provider_user"
|
||||
),
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship("User", back_populates="oauth_providers")
|
||||
user: Mapped["User"] = relationship(
|
||||
"User", back_populates="oauth_providers"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of UserOAuth."""
|
||||
@@ -65,9 +72,13 @@ class UserOAuth(db.Model):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def find_by_provider_and_id(cls, provider: str, provider_id: str) -> Optional["UserOAuth"]:
|
||||
def find_by_provider_and_id(
|
||||
cls, provider: str, provider_id: str
|
||||
) -> Optional["UserOAuth"]:
|
||||
"""Find OAuth provider by provider name and provider ID."""
|
||||
return cls.query.filter_by(provider=provider, provider_id=provider_id).first()
|
||||
return cls.query.filter_by(
|
||||
provider=provider, provider_id=provider_id
|
||||
).first()
|
||||
|
||||
@classmethod
|
||||
def create_or_update(
|
||||
@@ -77,7 +88,7 @@ class UserOAuth(db.Model):
|
||||
provider_id: str,
|
||||
email: str,
|
||||
name: str,
|
||||
picture: Optional[str] = None
|
||||
picture: Optional[str] = None,
|
||||
) -> "UserOAuth":
|
||||
"""Create new OAuth provider or update existing one."""
|
||||
oauth_provider = cls.find_by_provider_and_id(provider, provider_id)
|
||||
|
||||
@@ -320,10 +320,7 @@ def update_profile():
|
||||
"credits": user.credits,
|
||||
}
|
||||
|
||||
return {
|
||||
"message": "Profile updated successfully",
|
||||
"user": updated_user
|
||||
}
|
||||
return {"message": "Profile updated successfully", "user": updated_user}
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return {"error": f"Failed to update profile: {str(e)}"}, 500
|
||||
@@ -367,7 +364,9 @@ def change_password():
|
||||
if user.password_hash and current_auth_method == "password":
|
||||
# User has a password AND logged in via password, require current password for verification
|
||||
if not current_password:
|
||||
return {"error": "Current password is required to change password"}, 400
|
||||
return {
|
||||
"error": "Current password is required to change password"
|
||||
}, 400
|
||||
|
||||
if not check_password_hash(user.password_hash, current_password):
|
||||
return {"error": "Current password is incorrect"}, 400
|
||||
@@ -378,9 +377,7 @@ def change_password():
|
||||
user.set_password(new_password)
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
"message": "Password updated successfully"
|
||||
}
|
||||
return {"message": "Password updated successfully"}
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
return {"error": f"Failed to update password: {str(e)}"}, 500
|
||||
|
||||
@@ -2,7 +2,12 @@
|
||||
|
||||
from flask import Blueprint
|
||||
|
||||
from app.services.decorators import get_current_user, require_auth, require_role, require_credits
|
||||
from app.services.decorators import (
|
||||
get_current_user,
|
||||
require_auth,
|
||||
require_role,
|
||||
require_credits,
|
||||
)
|
||||
|
||||
bp = Blueprint("main", __name__)
|
||||
|
||||
@@ -63,7 +68,8 @@ def use_credits(amount: int) -> dict[str, str]:
|
||||
return {
|
||||
"message": f"Successfully used endpoint! You requested amount: {amount}",
|
||||
"user": user["email"],
|
||||
"remaining_credits": user["credits"] - 5, # Note: credits already deducted by decorator
|
||||
"remaining_credits": user["credits"]
|
||||
- 5, # Note: credits already deducted by decorator
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -83,7 +83,9 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
|
||||
jwt_user_data["provider"] = (
|
||||
oauth_provider.provider
|
||||
) # Override provider for OAuth login
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -156,7 +158,9 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password registration
|
||||
jwt_user_data["provider"] = (
|
||||
"password" # Override provider for password registration
|
||||
)
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -199,7 +203,9 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password login
|
||||
jwt_user_data["provider"] = (
|
||||
"password" # Override provider for password login
|
||||
)
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||
|
||||
@@ -148,6 +148,7 @@ def require_role(required_role: str):
|
||||
|
||||
def require_credits(credits_needed: int):
|
||||
"""Decorator to require and deduct credits for routes."""
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -185,4 +186,5 @@ def require_credits(credits_needed: int):
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -42,7 +42,7 @@ class OAuthProvider(ABC):
|
||||
name=self.name,
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
**config
|
||||
**config,
|
||||
)
|
||||
return self._client
|
||||
|
||||
@@ -51,7 +51,9 @@ class OAuthProvider(ABC):
|
||||
client = self.get_client()
|
||||
return client.authorize_redirect(redirect_uri).location
|
||||
|
||||
def exchange_code_for_token(self, code: str = None, redirect_uri: str = None) -> Dict[str, Any]:
|
||||
def exchange_code_for_token(
|
||||
self, code: str = None, redirect_uri: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Exchange authorization code for access token."""
|
||||
client = self.get_client()
|
||||
token = client.authorize_access_token()
|
||||
@@ -60,9 +62,9 @@ class OAuthProvider(ABC):
|
||||
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Normalize user data to common format."""
|
||||
return {
|
||||
'id': user_info.get('id'),
|
||||
'email': user_info.get('email'),
|
||||
'name': user_info.get('name'),
|
||||
'picture': user_info.get('picture'),
|
||||
'provider': self.name
|
||||
"id": user_info.get("id"),
|
||||
"email": user_info.get("email"),
|
||||
"name": user_info.get("name"),
|
||||
"picture": user_info.get("picture"),
|
||||
"provider": self.name,
|
||||
}
|
||||
@@ -7,21 +7,19 @@ class GitHubOAuthProvider(OAuthProvider):
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return 'github'
|
||||
return "github"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return 'GitHub'
|
||||
return "GitHub"
|
||||
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
"""Return GitHub OAuth client configuration."""
|
||||
return {
|
||||
'access_token_url': 'https://github.com/login/oauth/access_token',
|
||||
'authorize_url': 'https://github.com/login/oauth/authorize',
|
||||
'api_base_url': 'https://api.github.com/',
|
||||
'client_kwargs': {
|
||||
'scope': 'user:email'
|
||||
}
|
||||
"access_token_url": "https://github.com/login/oauth/access_token",
|
||||
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||
"api_base_url": "https://api.github.com/",
|
||||
"client_kwargs": {"scope": "user:email"},
|
||||
}
|
||||
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -29,24 +27,24 @@ class GitHubOAuthProvider(OAuthProvider):
|
||||
client = self.get_client()
|
||||
|
||||
# Get user profile
|
||||
user_resp = client.get('user', token=token)
|
||||
user_resp = client.get("user", token=token)
|
||||
user_data = user_resp.json()
|
||||
|
||||
# Get user email (may be private)
|
||||
email = user_data.get('email')
|
||||
email = user_data.get("email")
|
||||
if not email:
|
||||
# If email is private, get from emails endpoint
|
||||
emails_resp = client.get('user/emails', token=token)
|
||||
emails_resp = client.get("user/emails", token=token)
|
||||
emails = emails_resp.json()
|
||||
# Find primary email
|
||||
for email_obj in emails:
|
||||
if email_obj.get('primary', False):
|
||||
email = email_obj.get('email')
|
||||
if email_obj.get("primary", False):
|
||||
email = email_obj.get("email")
|
||||
break
|
||||
|
||||
return {
|
||||
'id': str(user_data.get('id')),
|
||||
'email': email,
|
||||
'name': user_data.get('name') or user_data.get('login'),
|
||||
'picture': user_data.get('avatar_url')
|
||||
"id": str(user_data.get("id")),
|
||||
"email": email,
|
||||
"name": user_data.get("name") or user_data.get("login"),
|
||||
"picture": user_data.get("avatar_url"),
|
||||
}
|
||||
@@ -17,18 +17,18 @@ class OAuthProviderRegistry:
|
||||
def _initialize_providers(self):
|
||||
"""Initialize available providers based on environment variables."""
|
||||
# Google OAuth
|
||||
google_client_id = os.getenv('GOOGLE_CLIENT_ID')
|
||||
google_client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID")
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
|
||||
if google_client_id and google_client_secret:
|
||||
self._providers['google'] = GoogleOAuthProvider(
|
||||
self._providers["google"] = GoogleOAuthProvider(
|
||||
self.oauth, google_client_id, google_client_secret
|
||||
)
|
||||
|
||||
# GitHub OAuth
|
||||
github_client_id = os.getenv('GITHUB_CLIENT_ID')
|
||||
github_client_secret = os.getenv('GITHUB_CLIENT_SECRET')
|
||||
github_client_id = os.getenv("GITHUB_CLIENT_ID")
|
||||
github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
|
||||
if github_client_id and github_client_secret:
|
||||
self._providers['github'] = GitHubOAuthProvider(
|
||||
self._providers["github"] = GitHubOAuthProvider(
|
||||
self.oauth, github_client_id, github_client_secret
|
||||
)
|
||||
|
||||
|
||||
@@ -9,22 +9,27 @@ from app.database import db
|
||||
app = create_app()
|
||||
cli = FlaskGroup(app)
|
||||
|
||||
|
||||
@cli.command()
|
||||
def init_db():
|
||||
"""Initialize the database."""
|
||||
print("Initializing database...")
|
||||
from app.database_init import init_database
|
||||
|
||||
init_database()
|
||||
print("Database initialized successfully!")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def reset_db():
|
||||
"""Reset the database (drop all tables and recreate)."""
|
||||
print("Resetting database...")
|
||||
db.drop_all()
|
||||
from app.database_init import init_database
|
||||
|
||||
init_database()
|
||||
print("Database reset successfully!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -23,21 +23,32 @@ class TestAuthRoutesJWTExtended:
|
||||
@patch("app.routes.auth.auth_service.get_login_url")
|
||||
def test_login_route(self, mock_get_login_url: Mock, client) -> None:
|
||||
"""Test the login route."""
|
||||
mock_get_login_url.return_value = "https://accounts.google.com/oauth/authorize?..."
|
||||
mock_get_login_url.return_value = (
|
||||
"https://accounts.google.com/oauth/authorize?..."
|
||||
)
|
||||
|
||||
response = client.get("/api/auth/login")
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert "login_url" in data
|
||||
assert data["login_url"] == "https://accounts.google.com/oauth/authorize?..."
|
||||
assert (
|
||||
data["login_url"]
|
||||
== "https://accounts.google.com/oauth/authorize?..."
|
||||
)
|
||||
|
||||
@patch("app.routes.auth.auth_service.handle_callback")
|
||||
def test_callback_route_success(self, mock_handle_callback: Mock, client) -> None:
|
||||
def test_callback_route_success(
|
||||
self, mock_handle_callback: Mock, client
|
||||
) -> None:
|
||||
"""Test successful callback route."""
|
||||
mock_response = Mock()
|
||||
mock_response.get_json.return_value = {
|
||||
"message": "Login successful",
|
||||
"user": {"id": "123", "email": "test@example.com", "name": "Test User"}
|
||||
"user": {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
}
|
||||
mock_handle_callback.return_value = mock_response
|
||||
|
||||
@@ -45,7 +56,9 @@ class TestAuthRoutesJWTExtended:
|
||||
mock_handle_callback.assert_called_once()
|
||||
|
||||
@patch("app.routes.auth.auth_service.handle_callback")
|
||||
def test_callback_route_error(self, mock_handle_callback: Mock, client) -> None:
|
||||
def test_callback_route_error(
|
||||
self, mock_handle_callback: Mock, client
|
||||
) -> None:
|
||||
"""Test callback route with error."""
|
||||
mock_handle_callback.side_effect = Exception("OAuth error")
|
||||
|
||||
@@ -58,7 +71,9 @@ class TestAuthRoutesJWTExtended:
|
||||
def test_logout_route(self, mock_logout: Mock, client) -> None:
|
||||
"""Test logout route."""
|
||||
mock_response = Mock()
|
||||
mock_response.get_json.return_value = {"message": "Logged out successfully"}
|
||||
mock_response.get_json.return_value = {
|
||||
"message": "Logged out successfully"
|
||||
}
|
||||
mock_logout.return_value = mock_response
|
||||
|
||||
response = client.get("/api/auth/logout")
|
||||
|
||||
@@ -21,7 +21,7 @@ class TestAuthServiceJWTExtended:
|
||||
"""Test initializing AuthService with Flask app."""
|
||||
mock_getenv.side_effect = lambda key: {
|
||||
"GOOGLE_CLIENT_ID": "test_client_id",
|
||||
"GOOGLE_CLIENT_SECRET": "test_client_secret"
|
||||
"GOOGLE_CLIENT_SECRET": "test_client_secret",
|
||||
}.get(key)
|
||||
|
||||
app = create_app()
|
||||
@@ -45,4 +45,6 @@ class TestAuthServiceJWTExtended:
|
||||
|
||||
assert result == mock_response
|
||||
mock_unset.assert_called_once_with(mock_response)
|
||||
mock_jsonify.assert_called_once_with({"message": "Logged out successfully"})
|
||||
mock_jsonify.assert_called_once_with(
|
||||
{"message": "Logged out successfully"}
|
||||
)
|
||||
|
||||
@@ -21,7 +21,10 @@ class TestMainRoutes:
|
||||
"""Test the index route."""
|
||||
response = client.get("/api/")
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"message": "API is running", "status": "ok"}
|
||||
assert response.get_json() == {
|
||||
"message": "API is running",
|
||||
"status": "ok",
|
||||
}
|
||||
|
||||
def test_health_route(self, client) -> None:
|
||||
"""Test health check route."""
|
||||
|
||||
@@ -25,14 +25,18 @@ class TestTokenService:
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User"
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
token = token_service.generate_access_token(user_data)
|
||||
assert isinstance(token, str)
|
||||
|
||||
# Verify token content
|
||||
payload = jwt.decode(token, token_service.secret_key, algorithms=[token_service.algorithm])
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
token_service.secret_key,
|
||||
algorithms=[token_service.algorithm],
|
||||
)
|
||||
assert payload["user_id"] == "123"
|
||||
assert payload["email"] == "test@example.com"
|
||||
assert payload["name"] == "Test User"
|
||||
@@ -44,21 +48,29 @@ class TestTokenService:
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User"
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
token = token_service.generate_refresh_token(user_data)
|
||||
assert isinstance(token, str)
|
||||
|
||||
# Verify token content
|
||||
payload = jwt.decode(token, token_service.secret_key, algorithms=[token_service.algorithm])
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
token_service.secret_key,
|
||||
algorithms=[token_service.algorithm],
|
||||
)
|
||||
assert payload["user_id"] == "123"
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_verify_valid_token(self) -> None:
|
||||
"""Test verifying a valid token."""
|
||||
token_service = TokenService()
|
||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
token = token_service.generate_access_token(user_data)
|
||||
payload = token_service.verify_token(token)
|
||||
@@ -83,7 +95,11 @@ class TestTokenService:
|
||||
mock_datetime.UTC = timezone.utc
|
||||
|
||||
token_service = TokenService()
|
||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
token = token_service.generate_access_token(user_data)
|
||||
|
||||
@@ -116,7 +132,11 @@ class TestTokenService:
|
||||
def test_get_user_from_access_token_valid(self) -> None:
|
||||
"""Test extracting user from valid access token."""
|
||||
token_service = TokenService()
|
||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
token = token_service.generate_access_token(user_data)
|
||||
extracted_user = token_service.get_user_from_access_token(token)
|
||||
@@ -126,7 +146,11 @@ class TestTokenService:
|
||||
def test_get_user_from_access_token_refresh_token(self) -> None:
|
||||
"""Test extracting user from refresh token (should fail)."""
|
||||
token_service = TokenService()
|
||||
user_data = {"id": "123", "email": "test@example.com", "name": "Test User"}
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
token = token_service.generate_refresh_token(user_data)
|
||||
extracted_user = token_service.get_user_from_access_token(token)
|
||||
@@ -137,5 +161,7 @@ class TestTokenService:
|
||||
"""Test extracting user from invalid token."""
|
||||
token_service = TokenService()
|
||||
|
||||
extracted_user = token_service.get_user_from_access_token("invalid.token")
|
||||
extracted_user = token_service.get_user_from_access_token(
|
||||
"invalid.token"
|
||||
)
|
||||
assert extracted_user is None
|
||||
@@ -18,7 +18,7 @@ class TestTokenServiceJWTExtended:
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User",
|
||||
"picture": "https://example.com/pic.jpg"
|
||||
"picture": "https://example.com/pic.jpg",
|
||||
}
|
||||
|
||||
token = token_service.generate_access_token(user_data)
|
||||
@@ -33,7 +33,7 @@ class TestTokenServiceJWTExtended:
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User"
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
token = token_service.generate_refresh_token(user_data)
|
||||
@@ -48,7 +48,7 @@ class TestTokenServiceJWTExtended:
|
||||
user_data = {
|
||||
"id": "123",
|
||||
"email": "test@example.com",
|
||||
"name": "Test User"
|
||||
"name": "Test User",
|
||||
}
|
||||
|
||||
access_token = token_service.generate_access_token(user_data)
|
||||
|
||||
Reference in New Issue
Block a user