117 lines
3.5 KiB
Python
117 lines
3.5 KiB
Python
"""User OAuth model for storing user's connected providers."""
|
|
|
|
from datetime import datetime
|
|
from typing import Optional, TYPE_CHECKING
|
|
|
|
from sqlalchemy import String, DateTime, Text, ForeignKey
|
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
|
|
from app.database import db
|
|
|
|
if TYPE_CHECKING:
|
|
from app.models.user import User
|
|
|
|
|
|
def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()`` (deprecated since Python
    3.12): an aware UTC "now" is taken and its tzinfo stripped, so the value
    still matches the timezone-naive ``DateTime`` columns used below.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class UserOAuth(db.Model):
    """Model for storing user's connected OAuth providers.

    Each row links one external provider account, identified by the unique
    pair ``(provider, provider_id)``, to a local user through ``user_id``,
    and stores the profile fields supplied for that account.
    """

    __tablename__ = "user_oauth"

    id: Mapped[int] = mapped_column(primary_key=True)

    # User relationship
    user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)

    # OAuth provider information
    provider: Mapped[str] = mapped_column(String(50), nullable=False)
    provider_id: Mapped[str] = mapped_column(String(255), nullable=False)

    # Provider-specific user information
    email: Mapped[str] = mapped_column(String(255), nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    picture: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Timestamps: naive UTC values (the DateTime columns are not
    # timezone-aware). Defaults are applied by SQLAlchemy at INSERT/UPDATE
    # time, so they are None on an instance that has not been flushed yet.
    created_at: Mapped[datetime] = mapped_column(
        DateTime, default=_utcnow, nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=_utcnow,
        onupdate=_utcnow,
        nullable=False,
    )

    # Unique constraint on provider + provider_id combination
    __table_args__ = (
        db.UniqueConstraint(
            "provider", "provider_id", name="unique_provider_user"
        ),
    )

    # Relationships
    user: Mapped["User"] = relationship(
        "User", back_populates="oauth_providers"
    )

    def __repr__(self) -> str:
        """String representation of UserOAuth."""
        return f"<UserOAuth {self.email} ({self.provider})>"

    def to_dict(self) -> dict:
        """Convert oauth provider to dictionary.

        Returns:
            A dict of the row's columns (``user_id`` excluded). Timestamps
            are ISO-8601 strings, or ``None`` if the instance has not been
            flushed yet and its column defaults have not been applied.
        """
        created_at = self.created_at
        updated_at = self.updated_at
        return {
            "id": self.id,
            "provider": self.provider,
            "provider_id": self.provider_id,
            "email": self.email,
            "name": self.name,
            "picture": self.picture,
            # Guard: calling .isoformat() on None raised AttributeError for
            # transient/pending instances.
            "created_at": created_at.isoformat() if created_at is not None else None,
            "updated_at": updated_at.isoformat() if updated_at is not None else None,
        }

    @classmethod
    def find_by_provider_and_id(
        cls, provider: str, provider_id: str
    ) -> Optional["UserOAuth"]:
        """Find OAuth provider by provider name and provider ID.

        Args:
            provider: Provider name as stored in the ``provider`` column.
            provider_id: The account's identifier at that provider.

        Returns:
            The matching ``UserOAuth`` row, or ``None`` if there is none.
        """
        return cls.query.filter_by(
            provider=provider, provider_id=provider_id
        ).first()

    @classmethod
    def create_or_update(
        cls,
        user_id: int,
        provider: str,
        provider_id: str,
        email: str,
        name: str,
        picture: Optional[str] = None,
    ) -> "UserOAuth":
        """Create new OAuth provider or update existing one.

        Looks the link up by ``(provider, provider_id)``. If it exists, its
        owner and profile fields are overwritten; otherwise a new row is
        added. The session is committed before returning.

        Args:
            user_id: Local user the provider account should belong to.
            provider: Provider name.
            provider_id: The account's identifier at that provider.
            email: Email reported for the provider account.
            name: Display name reported for the provider account.
            picture: Optional avatar URL/text; overwrites any stored value,
                including with ``None``.

        Returns:
            The persisted ``UserOAuth`` instance.

        Raises:
            Whatever ``db.session.commit()`` raises (e.g. an integrity error
            if a concurrent request inserted the same provider/provider_id).
            The session is rolled back first so it remains usable.
        """
        oauth_provider = cls.find_by_provider_and_id(provider, provider_id)

        if oauth_provider is not None:
            # Update existing provider.
            # NOTE(review): this re-points the link at ``user_id`` even when
            # it currently belongs to a different user, i.e. it can move an
            # external account between local users. Confirm that is intended.
            oauth_provider.user_id = user_id
            oauth_provider.email = email
            oauth_provider.name = name
            oauth_provider.picture = picture
            # Bump explicitly so the timestamp advances even if none of the
            # assignments above actually change a column value.
            oauth_provider.updated_at = _utcnow()
        else:
            # Create new provider
            oauth_provider = cls(
                user_id=user_id,
                provider=provider,
                provider_id=provider_id,
                email=email,
                name=name,
                picture=picture,
            )
            db.session.add(oauth_provider)

        try:
            db.session.commit()
        except Exception:
            # A failed flush/commit leaves the session in a state where it
            # must be rolled back before further use; do that, then let the
            # caller see the original error.
            db.session.rollback()
            raise
        return oauth_provider
|