"""OAuth provider linking service."""

from authlib.integrations.flask_client import OAuth

from app.models.user import User
from app.models.user_oauth import UserOAuth
from app.services.oauth_providers.registry import OAuthProviderRegistry


class OAuthLinkingService:
    """Service for linking and unlinking OAuth providers."""

    @staticmethod
    def _get_user_or_raise(user_id: int) -> User:
        """Load a user by primary key, raising ``ValueError`` if absent."""
        user = User.query.get(user_id)
        if not user:
            raise ValueError("User not found")
        return user

    @staticmethod
    def link_provider_to_user(
        provider: str,
        current_user_id: int,
    ) -> dict:
        """Link a new OAuth provider to an existing user account.

        Args:
            provider: Provider key looked up in ``OAuthProviderRegistry``.
            current_user_id: Primary key of the user receiving the link.

        Returns:
            A dict with a human-readable ``"message"`` key.

        Raises:
            ValueError: If the user does not exist, the provider is not
                configured, the provider returned no user id, or the
                provider account is already linked to a different user.
        """
        user = OAuthLinkingService._get_user_or_raise(current_user_id)

        # NOTE(review): a fresh ``OAuth()`` is built on every call with no
        # app passed in here — confirm the registry binds/configures it.
        oauth = OAuth()
        registry = OAuthProviderRegistry(oauth)
        oauth_provider = registry.get_provider(provider)
        if not oauth_provider:
            raise ValueError(f"OAuth provider '{provider}' not configured")

        # NOTE(review): both arguments are hard-coded to None, so this method
        # never receives an authorization code from its caller. Presumably
        # the provider reads it from the current request — confirm.
        token = oauth_provider.exchange_code_for_token(None, None)
        raw_user_info = oauth_provider.get_user_info(token)
        provider_data = oauth_provider.normalize_user_data(raw_user_info)

        if not provider_data.get("id"):
            raise ValueError("Failed to get user information from provider")

        # One provider account may belong to at most one local user;
        # re-linking to the same user falls through to create_or_update.
        existing_provider = UserOAuth.find_by_provider_and_id(
            provider,
            provider_data["id"],
        )
        if existing_provider and existing_provider.user_id != user.id:
            raise ValueError(
                "This provider account is already linked to another user",
            )

        UserOAuth.create_or_update(
            user_id=user.id,
            provider=provider,
            provider_id=provider_data["id"],
            email=provider_data["email"],
            name=provider_data["name"],
            picture=provider_data.get("picture"),
        )

        return {"message": f"{provider.title()} account linked successfully"}

    @staticmethod
    def unlink_provider_from_user(
        provider: str,
        current_user_id: int,
    ) -> dict:
        """Unlink an OAuth provider from a user account.

        Args:
            provider: Provider key to remove from the user.
            current_user_id: Primary key of the user.

        Returns:
            A dict with a human-readable ``"message"`` key.

        Raises:
            ValueError: If the user does not exist, the provider is not
                linked to the account, or it is the user's only linked
                provider.
        """
        # Deferred import kept from the original; presumably avoids an
        # import cycle — confirm before hoisting.
        from app.database import db

        user = OAuthLinkingService._get_user_or_raise(current_user_id)

        # Resolve the link first so that asking to unlink a provider that was
        # never linked reports exactly that, instead of being masked by the
        # last-provider guard below.
        oauth_provider = user.get_provider(provider)
        if not oauth_provider:
            raise ValueError(
                f"Provider '{provider}' not linked to this account",
            )

        # Refuse to remove the only remaining provider (prevents lock-out).
        if len(user.oauth_providers) <= 1:
            raise ValueError("Cannot unlink last authentication provider")

        db.session.delete(oauth_provider)
        try:
            db.session.commit()
        except Exception:
            # Leave the session usable for the rest of the request.
            db.session.rollback()
            raise

        return {"message": f"{provider.title()} account unlinked successfully"}

    @staticmethod
    def get_user_providers(user_id: int) -> dict:
        """Get all OAuth providers linked to a user.

        Args:
            user_id: Primary key of the user.

        Returns:
            ``{"providers": [...]}``, where each entry holds the provider
            key, email, name and picture stored on the link.

        Raises:
            ValueError: If the user does not exist.
        """
        user = OAuthLinkingService._get_user_or_raise(user_id)

        return {
            "providers": [
                {
                    "provider": link.provider,
                    "email": link.email,
                    "name": link.name,
                    "picture": link.picture,
                }
                for link in user.oauth_providers
            ],
        }