109 lines
3.5 KiB
Python
109 lines
3.5 KiB
Python
"""OAuth provider linking service."""
|
|
|
|
from typing import Optional

from authlib.integrations.flask_client import OAuth

from app.models.user import User
from app.models.user_oauth import UserOAuth
from app.services.oauth_providers.registry import OAuthProviderRegistry
|
|
|
|
|
|
class OAuthLinkingService:
    """Service for linking and unlinking OAuth providers.

    All methods are static; the class only groups related operations.
    Failures are reported by raising ``ValueError`` with a human-readable
    message.
    """

    @staticmethod
    def link_provider_to_user(
        provider: str,
        current_user_id: int,
        code: Optional[str] = None,
        redirect_uri: Optional[str] = None,
    ) -> dict:
        """Link a new OAuth provider to an existing user account.

        Args:
            provider: Provider key passed to
                ``OAuthProviderRegistry.get_provider``; also stored as the
                provider name on the link.
            current_user_id: Id of the user the provider is linked to.
            code: Authorization code from the provider callback. It is
                forwarded as the first argument of
                ``exchange_code_for_token``. Defaults to ``None`` so callers
                that pass only ``provider`` and ``current_user_id`` keep
                working.
            redirect_uri: Forwarded as the second argument of
                ``exchange_code_for_token``.
                NOTE(review): assumed to be the redirect URI used in the
                authorization request. Confirm against the provider class.

        Returns:
            A dict with a ``"message"`` key confirming the link.

        Raises:
            ValueError: If the user does not exist, the provider is not
                configured, the provider returns no user id, or the provider
                account is already linked to a different user.
        """
        user = User.query.get(current_user_id)
        if not user:
            raise ValueError("User not found")

        # NOTE(review): this OAuth() instance is created per call and is not
        # bound to a Flask app here. Presumably OAuthProviderRegistry handles
        # client registration; confirm.
        oauth = OAuth()
        registry = OAuthProviderRegistry(oauth)
        oauth_provider = registry.get_provider(provider)
        if not oauth_provider:
            raise ValueError(f"OAuth provider '{provider}' not configured")

        # Exchange the authorization code for a token, then fetch the
        # provider's user info and normalize it into a common dict shape.
        token = oauth_provider.exchange_code_for_token(code, redirect_uri)
        raw_user_info = oauth_provider.get_user_info(token)
        provider_data = oauth_provider.normalize_user_data(raw_user_info)

        if not provider_data.get("id"):
            raise ValueError("Failed to get user information from provider")

        # A provider account may not be linked to a different local user.
        # If it is already linked to this same user, execution continues to
        # create_or_update below.
        existing_provider = UserOAuth.find_by_provider_and_id(
            provider,
            provider_data["id"],
        )
        if existing_provider and existing_provider.user_id != user.id:
            raise ValueError(
                "This provider account is already linked to another user",
            )

        UserOAuth.create_or_update(
            user_id=user.id,
            provider=provider,
            provider_id=provider_data["id"],
            email=provider_data["email"],
            name=provider_data["name"],
            picture=provider_data.get("picture"),
        )

        return {"message": f"{provider.title()} account linked successfully"}

    @staticmethod
    def unlink_provider_from_user(
        provider: str,
        current_user_id: int,
    ) -> dict:
        """Unlink an OAuth provider from a user account.

        Args:
            provider: Name of the linked provider to remove.
            current_user_id: Id of the user that owns the link.

        Returns:
            A dict with a ``"message"`` key confirming the unlink.

        Raises:
            ValueError: If the user does not exist, the provider is not
                linked to the account, or it is the account's only linked
                provider.
        """
        from app.database import db

        user = User.query.get(current_user_id)
        if not user:
            raise ValueError("User not found")

        # Resolve the link first. A request to unlink a provider that was
        # never linked then reports that fact, instead of tripping the
        # last-provider guard below.
        oauth_provider = user.get_provider(provider)
        if not oauth_provider:
            raise ValueError(
                f"Provider '{provider}' not linked to this account",
            )

        # Refuse to remove the only linked provider, so the user is not
        # locked out of the account.
        if len(user.oauth_providers) <= 1:
            raise ValueError("Cannot unlink last authentication provider")

        db.session.delete(oauth_provider)
        db.session.commit()

        return {"message": f"{provider.title()} account unlinked successfully"}

    @staticmethod
    def get_user_providers(user_id: int) -> dict:
        """Get all OAuth providers linked to a user.

        Args:
            user_id: Id of the user whose links are listed.

        Returns:
            ``{"providers": [...]}``. Each entry has the keys ``"provider"``,
            ``"email"``, ``"name"`` and ``"picture"``, copied from one
            element of ``user.oauth_providers``.

        Raises:
            ValueError: If the user does not exist.
        """
        user = User.query.get(user_id)
        if not user:
            raise ValueError("User not found")

        return {
            "providers": [
                {
                    "provider": link.provider,
                    "email": link.email,
                    "name": link.name,
                    "picture": link.picture,
                }
                for link in user.oauth_providers
            ],
        }
|