Files
sdb-back/app/services/oauth_linking_service.py

109 lines
3.5 KiB
Python

"""OAuth provider linking service."""
from typing import Optional

from authlib.integrations.flask_client import OAuth

from app.models.user import User
from app.models.user_oauth import UserOAuth
from app.services.oauth_providers.registry import OAuthProviderRegistry
class OAuthLinkingService:
    """Service for linking and unlinking OAuth providers.

    All methods are static. Failures are reported by raising ``ValueError``
    with a user-facing message.
    """

    @staticmethod
    def _get_user_or_raise(user_id: int) -> "User":
        """Return the ``User`` with the given id.

        Raises:
            ValueError: If no such user exists.
        """
        user = User.query.get(user_id)
        if not user:
            raise ValueError("User not found")
        return user

    @staticmethod
    def link_provider_to_user(
        provider: str,
        current_user_id: int,
        *,
        code: Optional[str] = None,
        redirect_uri: Optional[str] = None,
    ) -> dict:
        """Link a new OAuth provider to an existing user account.

        Args:
            provider: Provider key looked up in ``OAuthProviderRegistry``.
            current_user_id: ID of the user the provider account is linked to.
            code: Forwarded as the first argument of the provider's
                ``exchange_code_for_token``. Presumably the authorization
                code from the OAuth callback -- confirm against the provider
                implementation. Defaults to ``None``.
            redirect_uri: Forwarded as the second argument of
                ``exchange_code_for_token`` (presumably the callback redirect
                URI -- confirm). Defaults to ``None``.

        Returns:
            A dict with a success ``message``.

        Raises:
            ValueError: If the user does not exist, the provider is not
                configured, the provider returned no user id, or the provider
                account is already linked to a different user.
        """
        user = OAuthLinkingService._get_user_or_raise(current_user_id)

        # NOTE(review): a fresh OAuth() that is never bound to the Flask app
        # is created on every call; confirm the providers obtained from it can
        # perform the token exchange.
        registry = OAuthProviderRegistry(OAuth())
        oauth_provider = registry.get_provider(provider)
        if not oauth_provider:
            raise ValueError(f"OAuth provider '{provider}' not configured")

        # The callback values come from the caller; with the defaults this is
        # the same call as exchanging (None, None).
        token = oauth_provider.exchange_code_for_token(code, redirect_uri)
        raw_user_info = oauth_provider.get_user_info(token)
        provider_data = oauth_provider.normalize_user_data(raw_user_info)
        if not provider_data.get("id"):
            raise ValueError("Failed to get user information from provider")

        # Refuse to take over a provider account linked to someone else;
        # an existing link to this same user falls through to the update.
        existing_link = UserOAuth.find_by_provider_and_id(
            provider,
            provider_data["id"],
        )
        if existing_link and existing_link.user_id != user.id:
            raise ValueError(
                "This provider account is already linked to another user",
            )

        UserOAuth.create_or_update(
            user_id=user.id,
            provider=provider,
            provider_id=provider_data["id"],
            email=provider_data["email"],
            name=provider_data["name"],
            picture=provider_data.get("picture"),
        )
        return {"message": f"{provider.title()} account linked successfully"}

    @staticmethod
    def unlink_provider_from_user(
        provider: str,
        current_user_id: int,
    ) -> dict:
        """Unlink an OAuth provider from a user account.

        Args:
            provider: Provider key of the link to remove.
            current_user_id: ID of the user whose link is removed.

        Returns:
            A dict with a success ``message``.

        Raises:
            ValueError: If the user does not exist, the provider is not linked
                to this account, or it is the account's only linked provider.
        """
        # Local import kept from the original (possibly to avoid an import
        # cycle -- unverified).
        from app.database import db

        user = OAuthLinkingService._get_user_or_raise(current_user_id)

        # Resolve the link first so that asking to unlink a provider that was
        # never linked reports exactly that, not the "last provider" error.
        oauth_link = user.get_provider(provider)
        if not oauth_link:
            raise ValueError(
                f"Provider '{provider}' not linked to this account",
            )

        # Never remove the only remaining provider (prevents locking out).
        if len(user.oauth_providers) <= 1:
            raise ValueError("Cannot unlink last authentication provider")

        db.session.delete(oauth_link)
        db.session.commit()
        return {"message": f"{provider.title()} account unlinked successfully"}

    @staticmethod
    def get_user_providers(user_id: int) -> dict:
        """Get all OAuth providers linked to a user.

        Returns:
            ``{"providers": [...]}``, one entry per linked account with its
            ``provider``, ``email``, ``name`` and ``picture``.

        Raises:
            ValueError: If the user does not exist.
        """
        user = OAuthLinkingService._get_user_or_raise(user_id)
        return {
            "providers": [
                {
                    "provider": link.provider,
                    "email": link.email,
                    "name": link.name,
                    "picture": link.picture,
                }
                for link in user.oauth_providers
            ],
        }