"""Repository for user OAuth operations."""

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.user_oauth import UserOauth
from app.repositories.base import BaseRepository

logger = get_logger(__name__)


class UserOauthRepository(BaseRepository[UserOauth]):
    """Repository for user OAuth operations.

    Looks up ``UserOauth`` rows either by the provider-side identity
    (provider + provider user ID) or by the local user ID + provider.
    Any error raised while building or executing a query is logged via
    ``logger.exception`` and then re-raised unchanged.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize repository with database session.

        Args:
            session: Async SQLModel session; forwarded to ``BaseRepository``
                together with the ``UserOauth`` model class.
        """
        super().__init__(UserOauth, session)

    async def get_by_provider_user_id(
        self,
        provider: str,
        provider_user_id: str,
    ) -> UserOauth | None:
        """Get user OAuth by provider and provider user ID.

        Args:
            provider: Value compared against ``UserOauth.provider``.
            provider_user_id: Value compared against
                ``UserOauth.provider_user_id``.

        Returns:
            The first matching ``UserOauth`` row, or ``None`` if no row
            matches. The query has no ORDER BY, so if several rows match,
            which one is returned is unspecified.

        Raises:
            Exception: Any exception raised while building or executing the
                query is logged and re-raised as-is.
        """
        try:
            statement = select(UserOauth).where(
                UserOauth.provider == provider,
                UserOauth.provider_user_id == provider_user_id,
            )
            result = await self.session.exec(statement)
        except Exception:
            logger.exception(
                "Failed to get user OAuth by provider user ID: %s:%s",
                provider,
                provider_user_id,
            )
            raise
        else:
            # Success path lives in ``else`` (ruff TRY300), matching
            # get_by_user_id_and_provider, so the ``try`` guards only the query.
            return result.first()

    async def get_by_user_id_and_provider(
        self,
        user_id: int,
        provider: str,
    ) -> UserOauth | None:
        """Get user OAuth by user ID and provider.

        Args:
            user_id: Value compared against ``UserOauth.user_id``.
            provider: Value compared against ``UserOauth.provider``.

        Returns:
            The first matching ``UserOauth`` row, or ``None`` if no row
            matches. The query has no ORDER BY, so if several rows match,
            which one is returned is unspecified.

        Raises:
            Exception: Any exception raised while building or executing the
                query is logged and re-raised as-is.
        """
        try:
            statement = select(UserOauth).where(
                UserOauth.user_id == user_id,
                UserOauth.provider == provider,
            )
            result = await self.session.exec(statement)
        except Exception:
            logger.exception(
                "Failed to get user OAuth by user ID and provider: %s:%s",
                user_id,
                provider,
            )
            raise
        else:
            return result.first()