- Added OAuth2 endpoints for Google and GitHub authentication.
- Created an OAuth service to handle provider interactions and user-info retrieval.
- Implemented a user OAuth repository for managing user OAuth links in the database.
- Updated the auth service to support linking existing users and creating new users via OAuth.
- Added CORS middleware to allow frontend access.
- Created tests for the OAuth endpoints and service functionality.
- Introduced environment configuration for OAuth client IDs and secrets.
- Added logging for OAuth operations and error handling.
118 lines
3.6 KiB
Python
118 lines
3.6 KiB
Python
"""Repository for user OAuth operations."""
|
|
|
|
from typing import Any
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.models.user_oauth import UserOauth
|
|
|
|
# Module-level logger named after this module, obtained from the project's
# logging helper (app.core.logging.get_logger); used by the repository
# methods below for success and failure messages.
logger = get_logger(__name__)
|
|
|
|
|
|
class UserOauthRepository:
    """Repository for user OAuth operations.

    Wraps an ``AsyncSession`` and provides lookup, create, update and delete
    operations for ``UserOauth`` rows, which associate a local ``user_id``
    with a ``provider`` and that provider's ``provider_user_id``.

    NOTE: the mutating methods call ``session.commit()`` on the session they
    were given, so any other pending changes on that same session are
    committed (or rolled back, on failure) together with the OAuth change.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize repository with database session.

        Args:
            session: The async database session used for every query and
                write performed by this repository.
        """
        self.session = session

    async def get_by_provider_user_id(
        self,
        provider: str,
        provider_user_id: str,
    ) -> UserOauth | None:
        """Get user OAuth by provider and provider user ID.

        Args:
            provider: Provider name matched against ``UserOauth.provider``.
            provider_user_id: Account identifier matched against
                ``UserOauth.provider_user_id``.

        Returns:
            The first matching ``UserOauth`` row, or ``None`` if none matches.

        Raises:
            Exception: Any error raised while building or executing the query
                is logged and re-raised unchanged.
        """
        try:
            statement = select(UserOauth).where(
                UserOauth.provider == provider,
                UserOauth.provider_user_id == provider_user_id,
            )
            result = await self.session.exec(statement)
        except Exception:
            logger.exception(
                "Failed to get user OAuth by provider user ID: %s:%s",
                provider,
                provider_user_id,
            )
            raise
        else:
            return result.first()

    async def get_by_user_id_and_provider(
        self,
        user_id: int,
        provider: str,
    ) -> UserOauth | None:
        """Get user OAuth by user ID and provider.

        Args:
            user_id: Local user ID matched against ``UserOauth.user_id``.
            provider: Provider name matched against ``UserOauth.provider``.

        Returns:
            The first matching ``UserOauth`` row, or ``None`` if none matches.

        Raises:
            Exception: Any error raised while building or executing the query
                is logged and re-raised unchanged.
        """
        try:
            statement = select(UserOauth).where(
                UserOauth.user_id == user_id,
                UserOauth.provider == provider,
            )
            result = await self.session.exec(statement)
        except Exception:
            logger.exception(
                "Failed to get user OAuth by user ID and provider: %s:%s",
                user_id,
                provider,
            )
            raise
        else:
            return result.first()

    async def create(self, oauth_data: dict[str, Any]) -> UserOauth:
        """Create a new user OAuth record.

        Args:
            oauth_data: Keyword arguments passed straight to the ``UserOauth``
                constructor (presumably ``user_id``, ``provider``,
                ``provider_user_id`` and any other model fields — confirm
                against the model).

        Returns:
            The persisted ``UserOauth`` instance, refreshed after commit.

        Raises:
            Exception: Construction, flush/commit or refresh errors are logged,
                the session is rolled back, and the error is re-raised.
        """
        try:
            oauth = UserOauth(**oauth_data)
            self.session.add(oauth)
            await self.session.commit()
            # Reload from the database so attributes reflect the stored row.
            await self.session.refresh(oauth)
        except Exception:
            # Log before rolling back: if the rollback itself fails, the
            # original error has still been recorded.
            logger.exception("Failed to create user OAuth")
            await self.session.rollback()
            raise
        else:
            logger.info(
                "Created OAuth link for user %s with provider %s",
                oauth.user_id,
                oauth.provider,
            )
            return oauth

    async def update(self, oauth: UserOauth, update_data: dict[str, Any]) -> UserOauth:
        """Update a user OAuth record.

        Args:
            oauth: The ``UserOauth`` instance to modify.
            update_data: Mapping of attribute name to new value. Each pair is
                applied with ``setattr``; keys are not checked against the
                model's columns.

        Returns:
            The same ``UserOauth`` instance, refreshed after commit.

        Raises:
            Exception: Errors while applying values, committing or refreshing
                are logged, the session is rolled back, and the error is
                re-raised.
        """
        try:
            for key, value in update_data.items():
                setattr(oauth, key, value)

            self.session.add(oauth)
            await self.session.commit()
            # Reload from the database so attributes reflect the stored row.
            await self.session.refresh(oauth)
        except Exception:
            # Log before rolling back: if the rollback itself fails, the
            # original error has still been recorded.
            logger.exception("Failed to update user OAuth")
            await self.session.rollback()
            raise
        else:
            logger.info(
                "Updated OAuth link for user %s with provider %s",
                oauth.user_id,
                oauth.provider,
            )
            return oauth

    async def delete(self, oauth: UserOauth) -> None:
        """Delete a user OAuth record.

        Args:
            oauth: The ``UserOauth`` instance to delete.

        Raises:
            Exception: Delete or commit errors are logged, the session is
                rolled back, and the error is re-raised.
        """
        try:
            await self.session.delete(oauth)
            await self.session.commit()
        except Exception:
            # Log before rolling back: if the rollback itself fails, the
            # original error has still been recorded.
            logger.exception("Failed to delete user OAuth")
            await self.session.rollback()
            raise
        else:
            logger.info(
                "Deleted OAuth link for user %s with provider %s",
                oauth.user_id,
                oauth.provider,
            )
|