Refactor code for improved readability and consistency
- Cleaned up whitespace and formatting across multiple files for better readability.
This commit is contained in:
@@ -83,7 +83,9 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = oauth_provider.provider # Override provider for OAuth login
|
||||
jwt_user_data["provider"] = (
|
||||
oauth_provider.provider
|
||||
) # Override provider for OAuth login
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -156,7 +158,9 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password registration
|
||||
jwt_user_data["provider"] = (
|
||||
"password" # Override provider for password registration
|
||||
)
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(
|
||||
@@ -199,7 +203,9 @@ class AuthService:
|
||||
|
||||
# Prepare user data for JWT token using user.to_dict()
|
||||
jwt_user_data = user.to_dict()
|
||||
jwt_user_data["provider"] = "password" # Override provider for password login
|
||||
jwt_user_data["provider"] = (
|
||||
"password" # Override provider for password login
|
||||
)
|
||||
|
||||
# Generate JWT tokens
|
||||
access_token = self.token_service.generate_access_token(jwt_user_data)
|
||||
|
||||
@@ -12,14 +12,14 @@ def get_user_from_jwt() -> dict[str, Any] | None:
|
||||
try:
|
||||
# Try to verify JWT token in request - this sets up the context
|
||||
verify_jwt_in_request()
|
||||
|
||||
|
||||
current_user_id = get_jwt_identity()
|
||||
if not current_user_id:
|
||||
return None
|
||||
|
||||
# Query database for user data instead of using JWT claims
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
user = User.query.get(int(current_user_id))
|
||||
if not user or not user.is_active:
|
||||
return None
|
||||
@@ -70,7 +70,7 @@ def get_user_from_api_token() -> dict[str, Any] | None:
|
||||
providers.append("password")
|
||||
if user.api_token:
|
||||
providers.append("api_token")
|
||||
|
||||
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
@@ -148,22 +148,23 @@ def require_role(required_role: str):
|
||||
|
||||
def require_credits(credits_needed: int):
|
||||
"""Decorator to require and deduct credits for routes."""
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
from app.models.user import User
|
||||
from app.database import db
|
||||
|
||||
|
||||
# First check authentication
|
||||
user_data = get_current_user()
|
||||
if not user_data:
|
||||
return jsonify({"error": "Authentication required"}), 401
|
||||
|
||||
|
||||
# Get the actual user from database to check/update credits
|
||||
user = User.query.get(int(user_data["id"]))
|
||||
if not user or not user.is_active:
|
||||
return jsonify({"error": "User not found or inactive"}), 401
|
||||
|
||||
|
||||
# Check if user has enough credits
|
||||
if user.credits < credits_needed:
|
||||
return (
|
||||
@@ -174,15 +175,16 @@ def require_credits(credits_needed: int):
|
||||
),
|
||||
402, # Payment Required status code
|
||||
)
|
||||
|
||||
|
||||
# Deduct credits
|
||||
user.credits -= credits_needed
|
||||
db.session.commit()
|
||||
|
||||
|
||||
# Execute the function
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -5,35 +5,35 @@ from authlib.integrations.flask_client import OAuth
|
||||
|
||||
class OAuthProvider(ABC):
|
||||
"""Abstract base class for OAuth providers."""
|
||||
|
||||
|
||||
def __init__(self, oauth: OAuth, client_id: str, client_secret: str):
|
||||
self.oauth = oauth
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self._client = None
|
||||
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Provider name (e.g., 'google', 'github')."""
|
||||
pass
|
||||
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def display_name(self) -> str:
|
||||
"""Human-readable provider name (e.g., 'Google', 'GitHub')."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
"""Return OAuth client configuration."""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Extract user information from OAuth token response."""
|
||||
pass
|
||||
|
||||
|
||||
def get_client(self):
|
||||
"""Get or create OAuth client."""
|
||||
if self._client is None:
|
||||
@@ -42,27 +42,29 @@ class OAuthProvider(ABC):
|
||||
name=self.name,
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
**config
|
||||
**config,
|
||||
)
|
||||
return self._client
|
||||
|
||||
|
||||
def get_authorization_url(self, redirect_uri: str) -> str:
|
||||
"""Generate authorization URL for OAuth flow."""
|
||||
client = self.get_client()
|
||||
return client.authorize_redirect(redirect_uri).location
|
||||
|
||||
def exchange_code_for_token(self, code: str = None, redirect_uri: str = None) -> Dict[str, Any]:
|
||||
|
||||
def exchange_code_for_token(
|
||||
self, code: str = None, redirect_uri: str = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Exchange authorization code for access token."""
|
||||
client = self.get_client()
|
||||
token = client.authorize_access_token()
|
||||
return token
|
||||
|
||||
|
||||
def normalize_user_data(self, user_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Normalize user data to common format."""
|
||||
return {
|
||||
'id': user_info.get('id'),
|
||||
'email': user_info.get('email'),
|
||||
'name': user_info.get('name'),
|
||||
'picture': user_info.get('picture'),
|
||||
'provider': self.name
|
||||
}
|
||||
"id": user_info.get("id"),
|
||||
"email": user_info.get("email"),
|
||||
"name": user_info.get("name"),
|
||||
"picture": user_info.get("picture"),
|
||||
"provider": self.name,
|
||||
}
|
||||
|
||||
@@ -4,49 +4,47 @@ from .base import OAuthProvider
|
||||
|
||||
class GitHubOAuthProvider(OAuthProvider):
|
||||
"""GitHub OAuth provider implementation."""
|
||||
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return 'github'
|
||||
|
||||
return "github"
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return 'GitHub'
|
||||
|
||||
return "GitHub"
|
||||
|
||||
def get_client_config(self) -> Dict[str, Any]:
|
||||
"""Return GitHub OAuth client configuration."""
|
||||
return {
|
||||
'access_token_url': 'https://github.com/login/oauth/access_token',
|
||||
'authorize_url': 'https://github.com/login/oauth/authorize',
|
||||
'api_base_url': 'https://api.github.com/',
|
||||
'client_kwargs': {
|
||||
'scope': 'user:email'
|
||||
}
|
||||
"access_token_url": "https://github.com/login/oauth/access_token",
|
||||
"authorize_url": "https://github.com/login/oauth/authorize",
|
||||
"api_base_url": "https://api.github.com/",
|
||||
"client_kwargs": {"scope": "user:email"},
|
||||
}
|
||||
|
||||
|
||||
def get_user_info(self, token: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Extract user information from GitHub OAuth token response."""
|
||||
client = self.get_client()
|
||||
|
||||
|
||||
# Get user profile
|
||||
user_resp = client.get('user', token=token)
|
||||
user_resp = client.get("user", token=token)
|
||||
user_data = user_resp.json()
|
||||
|
||||
|
||||
# Get user email (may be private)
|
||||
email = user_data.get('email')
|
||||
email = user_data.get("email")
|
||||
if not email:
|
||||
# If email is private, get from emails endpoint
|
||||
emails_resp = client.get('user/emails', token=token)
|
||||
emails_resp = client.get("user/emails", token=token)
|
||||
emails = emails_resp.json()
|
||||
# Find primary email
|
||||
for email_obj in emails:
|
||||
if email_obj.get('primary', False):
|
||||
email = email_obj.get('email')
|
||||
if email_obj.get("primary", False):
|
||||
email = email_obj.get("email")
|
||||
break
|
||||
|
||||
|
||||
return {
|
||||
'id': str(user_data.get('id')),
|
||||
'email': email,
|
||||
'name': user_data.get('name') or user_data.get('login'),
|
||||
'picture': user_data.get('avatar_url')
|
||||
}
|
||||
"id": str(user_data.get("id")),
|
||||
"email": email,
|
||||
"name": user_data.get("name") or user_data.get("login"),
|
||||
"picture": user_data.get("avatar_url"),
|
||||
}
|
||||
|
||||
@@ -8,38 +8,38 @@ from .github import GitHubOAuthProvider
|
||||
|
||||
class OAuthProviderRegistry:
|
||||
"""Registry for OAuth providers."""
|
||||
|
||||
|
||||
def __init__(self, oauth: OAuth):
|
||||
self.oauth = oauth
|
||||
self._providers: Dict[str, OAuthProvider] = {}
|
||||
self._initialize_providers()
|
||||
|
||||
|
||||
def _initialize_providers(self):
|
||||
"""Initialize available providers based on environment variables."""
|
||||
# Google OAuth
|
||||
google_client_id = os.getenv('GOOGLE_CLIENT_ID')
|
||||
google_client_secret = os.getenv('GOOGLE_CLIENT_SECRET')
|
||||
google_client_id = os.getenv("GOOGLE_CLIENT_ID")
|
||||
google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")
|
||||
if google_client_id and google_client_secret:
|
||||
self._providers['google'] = GoogleOAuthProvider(
|
||||
self._providers["google"] = GoogleOAuthProvider(
|
||||
self.oauth, google_client_id, google_client_secret
|
||||
)
|
||||
|
||||
|
||||
# GitHub OAuth
|
||||
github_client_id = os.getenv('GITHUB_CLIENT_ID')
|
||||
github_client_secret = os.getenv('GITHUB_CLIENT_SECRET')
|
||||
github_client_id = os.getenv("GITHUB_CLIENT_ID")
|
||||
github_client_secret = os.getenv("GITHUB_CLIENT_SECRET")
|
||||
if github_client_id and github_client_secret:
|
||||
self._providers['github'] = GitHubOAuthProvider(
|
||||
self._providers["github"] = GitHubOAuthProvider(
|
||||
self.oauth, github_client_id, github_client_secret
|
||||
)
|
||||
|
||||
|
||||
def get_provider(self, name: str) -> Optional[OAuthProvider]:
|
||||
"""Get OAuth provider by name."""
|
||||
return self._providers.get(name)
|
||||
|
||||
|
||||
def get_available_providers(self) -> Dict[str, OAuthProvider]:
|
||||
"""Get all available providers."""
|
||||
return self._providers.copy()
|
||||
|
||||
|
||||
def is_provider_available(self, name: str) -> bool:
|
||||
"""Check if provider is available."""
|
||||
return name in self._providers
|
||||
return name in self._providers
|
||||
|
||||
Reference in New Issue
Block a user