"""Authentication service.""" import hashlib from datetime import UTC, datetime, timedelta from fastapi import HTTPException, status from sqlmodel.ext.asyncio.session import AsyncSession from app.core.config import settings from app.core.logging import get_logger from app.models.user import User from app.repositories.user import UserRepository from app.repositories.user_oauth import UserOauthRepository from app.schemas.auth import ( AuthResponse, TokenResponse, UserLoginRequest, UserRegisterRequest, UserResponse, ) from app.services.oauth import OAuthUserInfo from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils logger = get_logger(__name__) class AuthService: """Service for authentication operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize the auth service.""" self.session = session self.user_repo = UserRepository(session) self.oauth_repo = UserOauthRepository(session) async def register(self, request: UserRegisterRequest) -> AuthResponse: """Register a new user.""" logger.info("Attempting to register user with email: %s", request.email) # Check if email already exists if await self.user_repo.email_exists(request.email): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Email address is already registered", ) # Hash the password hashed_password = PasswordUtils.hash_password(request.password) # Create user data user_data = { "email": request.email, "name": request.name, "password_hash": hashed_password, "role": "user", "is_active": True, } # Create the user user = await self.user_repo.create(user_data) # Generate access token token = self._create_access_token(user) # Create response user_response = await self.create_user_response(user) logger.info("Successfully registered user: %s", user.email) return AuthResponse(user=user_response, token=token) async def login(self, request: UserLoginRequest) -> AuthResponse: """Authenticate a user login.""" logger.info("Attempting to login user with email: %s", request.email) # Get user by email user = await self.user_repo.get_by_email(request.email) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password", ) # Check if user is active if not user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) # Verify password if not user.password_hash or not PasswordUtils.verify_password( request.password, user.password_hash, ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password", ) # Generate access token token = self._create_access_token(user) # Create response user_response = await self.create_user_response(user) logger.info("Successfully authenticated user: %s", user.email) return AuthResponse(user=user_response, token=token) async def get_current_user(self, user_id: int) -> User: """Get the current authenticated user.""" user = await self.user_repo.get_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="User not found", ) if not user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) return user async def get_user_by_api_token(self, api_token: str) -> User | None: """Get a user by their API token.""" return await self.user_repo.get_by_api_token(api_token) async def generate_api_token(self, user: User, expires_days: int = 365) -> str: """Generate a new API token for a user.""" # Generate a secure random token api_token = TokenUtils.generate_api_token() # Set expiration date expires_at = datetime.now(UTC) + timedelta(days=expires_days) # Update user with new API token update_data = { "api_token": api_token, "api_token_expires_at": expires_at, } await self.user_repo.update(user, update_data) logger.info("Generated new API token for user: %s", user.email) return api_token async def revoke_api_token(self, user: User) -> None: """Revoke a user's API token.""" update_data = { "api_token": None, "api_token_expires_at": None, } await self.user_repo.update(user, update_data) logger.info("Revoked API token for user: %s", user.email) def _create_access_token(self, user: User) -> TokenResponse: """Create an access token for a user.""" access_token_expires = timedelta( minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES, ) token_data = { "sub": str(user.id), "email": user.email, "role": user.role, } access_token = JWTUtils.create_access_token( data=token_data, expires_delta=access_token_expires, ) return TokenResponse( access_token=access_token, token_type="bearer", # noqa: S106 # This is OAuth2 standard, not a password expires_in=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60, ) async def create_and_store_refresh_token(self, user: User) -> str: """Create and store a refresh token for a user.""" refresh_token_expires = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) token_data = { "sub": str(user.id), "email": user.email, } refresh_token = JWTUtils.create_refresh_token( data=token_data, expires_delta=refresh_token_expires, ) # Hash the refresh token for storage refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() # Store hash and expiration in database user.refresh_token_hash = refresh_token_hash user.refresh_token_expires_at = datetime.now(UTC) + refresh_token_expires self.session.add(user) await self.session.commit() return refresh_token async def refresh_access_token(self, refresh_token: str) -> TokenResponse: """Create a new access token using a refresh token.""" try: # Decode the refresh token payload = JWTUtils.decode_refresh_token(refresh_token) user_id_str = payload.get("sub") if not user_id_str: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token", ) user_id = int(user_id_str) # Get the user user = await self.user_repo.get_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token", ) # Check if refresh token hash matches stored hash refresh_token_hash = hashlib.sha256(refresh_token.encode()).hexdigest() if ( not user.refresh_token_hash or user.refresh_token_hash != refresh_token_hash ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token", ) # Check if refresh token is expired if user.refresh_token_expires_at and datetime.now( UTC, ) > user.refresh_token_expires_at.replace(tzinfo=UTC): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token has expired", ) # Check if user is active if not user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) # Create new access token return self._create_access_token(user) except HTTPException: raise except Exception as e: logger.exception("Failed to refresh access token") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token", ) from e async def revoke_refresh_token(self, user: User) -> None: """Revoke a user's refresh token.""" try: # Use the repository to update the user to ensure proper session handling update_data = { "refresh_token_hash": None, "refresh_token_expires_at": None, } await self.user_repo.update(user, update_data) logger.info("Refresh token revoked for user: %s", user.email) except Exception: logger.exception("Failed to revoke refresh token for user: %s", user.email) raise async def create_user_response(self, user: User) -> UserResponse: """Create a user response from a user model.""" # Always refresh to ensure the plan relationship is loaded await self.session.refresh(user, ["plan"]) # Ensure user has an ID (should always be true for persisted users) if user.id is None: msg = "User must have an ID to create response" raise ValueError(msg) return UserResponse( id=user.id, email=user.email, name=user.name, picture=user.picture, role=user.role, credits=user.credits, is_active=user.is_active, plan={ "id": user.plan.id, "code": user.plan.code, "name": user.plan.name, "description": user.plan.description, "credits": user.plan.credits, "max_credits": user.plan.max_credits, }, created_at=user.created_at, updated_at=user.updated_at, ) async def oauth_login(self, oauth_user_info: OAuthUserInfo) -> AuthResponse: """Handle OAuth login - link or create user.""" logger.info( "OAuth login attempt for %s with provider %s", oauth_user_info.email, oauth_user_info.provider, ) # Check if user already has OAuth link for this provider existing_oauth = await self.oauth_repo.get_by_provider_user_id( oauth_user_info.provider, oauth_user_info.provider_user_id, ) if existing_oauth: # User exists with this OAuth provider, get the user user = await self.user_repo.get_by_id(existing_oauth.user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Linked user not found", ) # Refresh user to avoid greenlet issues await self.session.refresh(user) # Update OAuth record with latest info oauth_update_data = { "email": oauth_user_info.email, "name": oauth_user_info.name, "picture": oauth_user_info.picture, } await self.oauth_repo.update(existing_oauth, oauth_update_data) logger.info( "OAuth login successful for existing user: %s", oauth_user_info.email, ) else: # Check if user exists by email user = await self.user_repo.get_by_email(oauth_user_info.email) if user: # Refresh user to avoid greenlet issues await self.session.refresh(user) # Store user picture value to avoid greenlet issues later current_user_picture = user.picture # Link existing user to OAuth provider oauth_data = { "user_id": user.id, "provider": oauth_user_info.provider, "provider_user_id": oauth_user_info.provider_user_id, "email": oauth_user_info.email, "name": oauth_user_info.name, "picture": oauth_user_info.picture, } await self.oauth_repo.create(oauth_data) # Update user profile with OAuth info if needed user_update_data = {} if not current_user_picture and oauth_user_info.picture: user_update_data["picture"] = oauth_user_info.picture if user_update_data: await self.user_repo.update(user, user_update_data) # Refresh user after update to avoid greenlet issues await self.session.refresh(user) logger.info( "Linked existing user %s to OAuth provider %s", oauth_user_info.email, oauth_user_info.provider, ) else: # Create new user user_data = { "email": oauth_user_info.email, "name": oauth_user_info.name, "picture": oauth_user_info.picture, "is_active": True, # No password for OAuth users "password_hash": None, } user = await self.user_repo.create(user_data) # Create OAuth link oauth_data = { "user_id": user.id, "provider": oauth_user_info.provider, "provider_user_id": oauth_user_info.provider_user_id, "email": oauth_user_info.email, "name": oauth_user_info.name, "picture": oauth_user_info.picture, } await self.oauth_repo.create(oauth_data) logger.info( "Created new user %s from OAuth provider %s", oauth_user_info.email, oauth_user_info.provider, ) # Refresh user to avoid greenlet issues and check if user is active await self.session.refresh(user) if not user.is_active: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated", ) # Generate access token token = self._create_access_token(user) # Create response user_response = await self.create_user_response(user) logger.info( "OAuth login completed successfully for user: %s", oauth_user_info.email, ) return AuthResponse(user=user_response, token=token) async def update_user_profile(self, user: User, data: dict) -> User: """Update user profile information.""" logger.info("Updating profile for user: %s", user.email) # Only allow updating specific fields allowed_fields = {"name"} update_data = {k: v for k, v in data.items() if k in allowed_fields} if not update_data: return user # Update user for field, value in update_data.items(): setattr(user, field, value) self.session.add(user) await self.session.commit() await self.session.refresh(user, ["plan"]) logger.info("Profile updated successfully for user: %s", user.email) return user async def change_user_password( self, user: User, current_password: str | None, new_password: str, ) -> None: """Change user's password.""" # Store user email before any operations to avoid session detachment issues user_email = user.email logger.info("Changing password for user: %s", user_email) # Store whether user had existing password before we modify it had_existing_password = user.password_hash is not None # If user has existing password, verify it if had_existing_password: if not current_password: msg = "Current password is required when changing existing password" raise ValueError(msg) if not PasswordUtils.verify_password(current_password, user.password_hash): msg = "Current password is incorrect" raise ValueError(msg) else: # User doesn't have a password (OAuth-only user), setting first password logger.info("Setting first password for OAuth user: %s", user_email) # Hash new password new_password_hash = PasswordUtils.hash_password(new_password) # Update user user.password_hash = new_password_hash self.session.add(user) await self.session.commit() logger.info( "Password %s successfully for user: %s", "changed" if had_existing_password else "set", user_email, ) async def user_to_response(self, user: User) -> UserResponse: """Convert User model to UserResponse with plan information.""" # Load plan relationship if not already loaded if not hasattr(user, "plan") or not user.plan: await self.session.refresh(user, ["plan"]) return UserResponse( id=user.id, email=user.email, name=user.name, picture=user.picture, role=user.role, credits=user.credits, is_active=user.is_active, plan={ "id": user.plan.id, "name": user.plan.name, "max_credits": user.plan.max_credits, "features": [], # Add features if needed }, created_at=user.created_at, updated_at=user.updated_at, ) async def get_user_oauth_providers(self, user: User) -> list: """Get OAuth providers connected to the user.""" return await self.oauth_repo.get_by_user_id(user.id)