"""User OAuth model for storing user's connected providers."""

from datetime import datetime
from typing import TYPE_CHECKING, Optional
from zoneinfo import ZoneInfo

from sqlalchemy import DateTime, ForeignKey, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.database import db

if TYPE_CHECKING:
    from app.models.user import User


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=ZoneInfo("UTC"))


class UserOAuth(db.Model):
    """Model for storing user's connected OAuth providers.

    Each row links one external identity, identified by the pair
    ``(provider, provider_id)``, to a local ``User`` through ``user_id``.
    The pair is unique (see ``__table_args__``), so a given external
    identity is linked to at most one local user at a time.
    """

    __tablename__ = "user_oauth"

    id: Mapped[int] = mapped_column(primary_key=True)

    # User relationship
    user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)

    # OAuth provider information
    provider: Mapped[str] = mapped_column(String(50), nullable=False)
    # Account identifier issued by the provider; unique per provider via
    # the ``unique_provider_user`` constraint below.
    provider_id: Mapped[str] = mapped_column(String(255), nullable=False)

    # Provider-specific user information (as reported by the provider)
    email: Mapped[str] = mapped_column(String(255), nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    picture: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Timestamps.
    # NOTE(review): values are generated timezone-aware (UTC) but the columns
    # are plain ``DateTime`` (no ``timezone=True``), so rows loaded back from
    # the database may come back naive on some backends, and ``to_dict`` would
    # then emit ISO strings without an offset. Confirm whether
    # ``DateTime(timezone=True)`` (plus a migration) is wanted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=_utc_now,
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=_utc_now,
        onupdate=_utc_now,
        nullable=False,
    )

    # Unique constraint on provider + provider_id combination
    __table_args__ = (
        db.UniqueConstraint(
            "provider",
            "provider_id",
            name="unique_provider_user",
        ),
    )

    # Relationships
    user: Mapped["User"] = relationship(
        "User",
        back_populates="oauth_providers",
    )

    def __repr__(self) -> str:
        """Return a debug representation of this UserOAuth.

        Only identifiers are included; email, name and picture are left out
        so that logging a model instance does not leak profile data.
        """
        return (
            f"<UserOAuth id={self.id} provider={self.provider!r} "
            f"provider_id={self.provider_id!r} user_id={self.user_id}>"
        )

    def to_dict(self) -> dict:
        """Convert oauth provider to dictionary.

        Timestamps are rendered with ``datetime.isoformat()``. ``user_id`` is
        not included.
        """
        return {
            "id": self.id,
            "provider": self.provider,
            "provider_id": self.provider_id,
            "email": self.email,
            "name": self.name,
            "picture": self.picture,
            "created_at": self.created_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }

    @classmethod
    def find_by_provider_and_id(
        cls,
        provider: str,
        provider_id: str,
    ) -> Optional["UserOAuth"]:
        """Find OAuth provider by provider name and provider ID.

        Args:
            provider: Provider name as stored in ``provider``.
            provider_id: The provider-issued account identifier.

        Returns:
            The matching row, or ``None`` if there is none. At most one row
            can match because of the ``unique_provider_user`` constraint.
        """
        return cls.query.filter_by(
            provider=provider,
            provider_id=provider_id,
        ).first()

    @classmethod
    def create_or_update(
        cls,
        user_id: int,
        provider: str,
        provider_id: str,
        email: str,
        name: str,
        picture: str | None = None,
    ) -> "UserOAuth":
        """Create new OAuth provider or update existing one.

        Looks up the row for ``(provider, provider_id)``. If one exists, its
        ``user_id``, ``email``, ``name``, ``picture`` and ``updated_at`` are
        overwritten; otherwise a new row is added. The session is then
        committed.

        Args:
            user_id: Local user to link the identity to.
            provider: Provider name.
            provider_id: Provider-issued account identifier.
            email: Email reported by the provider.
            name: Display name reported by the provider.
            picture: Optional avatar URL/text reported by the provider.

        Returns:
            The created or updated ``UserOAuth`` instance.

        Raises:
            Any exception raised by ``db.session.commit()`` (for example an
            integrity error if a concurrent request inserted the same
            ``(provider, provider_id)`` first). The session is rolled back
            before the exception is re-raised.
        """
        oauth_provider = cls.find_by_provider_and_id(provider, provider_id)

        if oauth_provider:
            # Update existing provider.
            # NOTE(review): this re-points an existing identity at ``user_id``
            # even if it was linked to a different user — confirm callers
            # verify ownership before calling.
            oauth_provider.user_id = user_id
            oauth_provider.email = email
            oauth_provider.name = name
            oauth_provider.picture = picture
            # Set explicitly so the timestamp is bumped even when no other
            # column value actually changed (``onupdate`` only fires on an
            # UPDATE statement).
            oauth_provider.updated_at = _utc_now()
        else:
            # Create new provider
            oauth_provider = cls(
                user_id=user_id,
                provider=provider,
                provider_id=provider_id,
                email=email,
                name=name,
                picture=picture,
            )
            db.session.add(oauth_provider)

        # Commits the whole session, including any other pending changes.
        try:
            db.session.commit()
        except Exception:
            # Leave the session usable for the caller; a failed flush/commit
            # otherwise puts it in an inactive state until rolled back.
            db.session.rollback()
            raise
        return oauth_provider