Files
sdb2-backend/app/services/auth.py
JSC 6b55ff0e81 Refactor user endpoint tests to include pagination and response structure validation
- Updated tests for listing users to validate pagination and response format.
- Changed mock return values to include total count and pagination details.
- Refactored user creation mocks for clarity and consistency.
- Enhanced assertions to check for presence of pagination fields in responses.
- Adjusted test cases for user retrieval and updates to ensure proper handling of user data.
- Improved readability by restructuring mock definitions and assertions across various test files.
2025-08-17 12:36:52 +02:00

523 lines
18 KiB
Python

"""Authentication service."""
import hashlib
import hmac
from datetime import UTC, datetime, timedelta

from fastapi import HTTPException, status
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.config import settings
from app.core.logging import get_logger
from app.models.user import User
from app.repositories.user import UserRepository
from app.repositories.user_oauth import UserOauthRepository
from app.schemas.auth import (
    AuthResponse,
    TokenResponse,
    UserLoginRequest,
    UserRegisterRequest,
    UserResponse,
)
from app.services.oauth import OAuthUserInfo
from app.utils.auth import JWTUtils, PasswordUtils, TokenUtils
# Module-level logger obtained from the project's logging helper
# (app.core.logging.get_logger); used by every AuthService method below.
logger = get_logger(__name__)
class AuthService:
    """Service for authentication operations.

    Combines user and OAuth-link persistence (``UserRepository`` /
    ``UserOauthRepository``) with password hashing, JWT access/refresh tokens
    and API tokens. Errors intended for HTTP clients are raised as
    ``fastapi.HTTPException``; password-change validation errors are raised as
    ``ValueError``.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the auth service.

        Args:
            session: Async DB session. Both repositories are built on it, and
                several methods also add/commit/refresh through it directly.
        """
        self.session = session
        self.user_repo = UserRepository(session)
        self.oauth_repo = UserOauthRepository(session)

    # ------------------------------------------------------------------
    # Email / password authentication
    # ------------------------------------------------------------------

    async def register(self, request: UserRegisterRequest) -> AuthResponse:
        """Register a new email/password user and log them in.

        Args:
            request: Registration payload (email, name, plain-text password).

        Returns:
            The created user's response model plus a fresh access token.

        Raises:
            HTTPException: 400 if the email address is already registered.
        """
        logger.info("Attempting to register user with email: %s", request.email)

        # NOTE(review): check-then-create is not atomic; concurrent sign-ups
        # with the same email presumably rely on a DB unique constraint.
        if await self.user_repo.email_exists(request.email):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email address is already registered",
            )

        user_data = {
            "email": request.email,
            "name": request.name,
            "password_hash": PasswordUtils.hash_password(request.password),
            "role": "user",
            "is_active": True,
        }
        user = await self.user_repo.create(user_data)

        token = self._create_access_token(user)
        user_response = await self.create_user_response(user)

        logger.info("Successfully registered user: %s", user.email)
        return AuthResponse(user=user_response, token=token)

    async def login(self, request: UserLoginRequest) -> AuthResponse:
        """Authenticate an email/password login.

        The password is verified *before* the account's active flag is
        inspected. A caller who does not know the password therefore gets the
        same generic error for every failure. They cannot use the
        "deactivated" message to learn that an email is registered.

        Args:
            request: Login payload (email and plain-text password).

        Returns:
            The user's response model plus a fresh access token.

        Raises:
            HTTPException: 401 "Invalid email or password" for an unknown
                email, a password-less (OAuth-only) account, or a wrong
                password; 401 "Account is deactivated" when the credentials
                are valid but the account is inactive.
        """
        logger.info("Attempting to login user with email: %s", request.email)

        user = await self.user_repo.get_by_email(request.email)
        if (
            not user
            or not user.password_hash
            or not PasswordUtils.verify_password(
                request.password,
                user.password_hash,
            )
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid email or password",
            )

        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Account is deactivated",
            )

        token = self._create_access_token(user)
        user_response = await self.create_user_response(user)

        logger.info("Successfully authenticated user: %s", user.email)
        return AuthResponse(user=user_response, token=token)

    async def get_current_user(self, user_id: int) -> User:
        """Load the authenticated user by ID and ensure the account is active.

        Args:
            user_id: Primary key of the user.

        Returns:
            The active user.

        Raises:
            HTTPException: 404 if no user has *user_id*; 401 if the account
                is deactivated.
        """
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="User not found",
            )
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Account is deactivated",
            )
        return user

    # ------------------------------------------------------------------
    # API tokens
    # ------------------------------------------------------------------

    async def get_user_by_api_token(self, api_token: str) -> User | None:
        """Return the user owning *api_token*, or ``None`` if none is found.

        NOTE(review): this method does not check ``api_token_expires_at`` or
        ``is_active`` itself; confirm the repository or the caller does.
        """
        return await self.user_repo.get_by_api_token(api_token)

    async def generate_api_token(self, user: User, expires_days: int = 365) -> str:
        """Generate and store a new API token for *user*.

        The token replaces any previously stored one, so only the newest token
        is kept on the user.

        Args:
            user: User to issue the token for.
            expires_days: Token lifetime in days, counted from now (UTC).

        Returns:
            The newly generated token.
        """
        api_token = TokenUtils.generate_api_token()
        expires_at = datetime.now(UTC) + timedelta(days=expires_days)

        update_data = {
            "api_token": api_token,
            "api_token_expires_at": expires_at,
        }
        await self.user_repo.update(user, update_data)

        logger.info("Generated new API token for user: %s", user.email)
        return api_token

    async def revoke_api_token(self, user: User) -> None:
        """Revoke *user*'s API token by clearing the token and its expiry."""
        update_data = {
            "api_token": None,
            "api_token_expires_at": None,
        }
        await self.user_repo.update(user, update_data)
        logger.info("Revoked API token for user: %s", user.email)

    # ------------------------------------------------------------------
    # JWT access / refresh tokens
    # ------------------------------------------------------------------

    def _create_access_token(self, user: User) -> TokenResponse:
        """Create a bearer access token for *user*.

        The JWT carries the user's ID (``sub``), email and role. It expires
        after ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES``, and ``expires_in``
        reports the same lifetime in seconds.
        """
        expire_minutes = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES
        token_data = {
            "sub": str(user.id),
            "email": user.email,
            "role": user.role,
        }
        access_token = JWTUtils.create_access_token(
            data=token_data,
            expires_delta=timedelta(minutes=expire_minutes),
        )
        return TokenResponse(
            access_token=access_token,
            token_type="bearer",  # noqa: S106 # This is OAuth2 standard, not a password
            expires_in=expire_minutes * 60,
        )

    @staticmethod
    def _hash_refresh_token(refresh_token: str) -> str:
        """Return the SHA-256 hex digest under which a refresh token is stored."""
        return hashlib.sha256(refresh_token.encode()).hexdigest()

    async def create_and_store_refresh_token(self, user: User) -> str:
        """Create a refresh token for *user* and persist its hash.

        Only the SHA-256 hash and the expiry are stored on the user. They
        replace any previous values, and ``refresh_access_token`` accepts only
        the token matching the stored hash. Issuing a new refresh token
        therefore invalidates the previous one.

        Returns:
            The plain-text refresh token, which cannot be recovered from the DB.
        """
        refresh_token_expires = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
        token_data = {
            "sub": str(user.id),
            "email": user.email,
        }
        refresh_token = JWTUtils.create_refresh_token(
            data=token_data,
            expires_delta=refresh_token_expires,
        )

        user.refresh_token_hash = self._hash_refresh_token(refresh_token)
        user.refresh_token_expires_at = datetime.now(UTC) + refresh_token_expires
        self.session.add(user)
        await self.session.commit()

        return refresh_token

    async def refresh_access_token(self, refresh_token: str) -> TokenResponse:
        """Create a new access token using a refresh token.

        All of the following must hold:

        - the token decodes as a refresh JWT with a ``sub`` claim;
        - ``sub`` names an existing user;
        - the token matches the hash stored by ``create_and_store_refresh_token``;
        - the token is unexpired;
        - the user is active.

        Args:
            refresh_token: The plain-text refresh token presented by the client.

        Returns:
            A fresh access token for the token's user.

        Raises:
            HTTPException: 401 on every failure. Any unexpected exception
                (e.g. from decoding the token, or a non-integer ``sub``) is
                logged and reported as "Invalid refresh token".
        """
        try:
            payload = JWTUtils.decode_refresh_token(refresh_token)
            user_id_str = payload.get("sub")
            if not user_id_str:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid refresh token",
                )

            user = await self.user_repo.get_by_id(int(user_id_str))
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid refresh token",
                )

            # Compare digests in constant time so response timing does not
            # reveal how much of the stored hash matched.
            presented_hash = self._hash_refresh_token(refresh_token)
            if not user.refresh_token_hash or not hmac.compare_digest(
                user.refresh_token_hash,
                presented_hash,
            ):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid refresh token",
                )

            # NOTE(review): the stored expiry is assumed to be UTC. replace()
            # attaches UTC to a naive value and would override any other tz.
            expires_at = user.refresh_token_expires_at
            if expires_at and datetime.now(UTC) > expires_at.replace(tzinfo=UTC):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Refresh token has expired",
                )

            if not user.is_active:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Account is deactivated",
                )

            return self._create_access_token(user)
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Failed to refresh access token")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            ) from e

    async def revoke_refresh_token(self, user: User) -> None:
        """Revoke *user*'s refresh token by clearing the stored hash and expiry.

        Raises:
            Exception: Any error from the update is logged and re-raised.
        """
        try:
            # Go through the repository so session handling stays consistent.
            update_data = {
                "refresh_token_hash": None,
                "refresh_token_expires_at": None,
            }
            await self.user_repo.update(user, update_data)
            logger.info("Refresh token revoked for user: %s", user.email)
        except Exception:
            logger.exception("Failed to revoke refresh token for user: %s", user.email)
            raise

    # ------------------------------------------------------------------
    # Response building
    # ------------------------------------------------------------------

    async def create_user_response(self, user: User) -> UserResponse:
        """Create a user response (including plan details) from a user model.

        Raises:
            ValueError: If *user* has no ID, i.e. was never persisted.
        """
        # Always refresh so the plan relationship is loaded explicitly rather
        # than through an implicit lazy load under the async session.
        await self.session.refresh(user, ["plan"])

        # Persisted users always have an ID; this also narrows the type.
        if user.id is None:
            msg = "User must have an ID to create response"
            raise ValueError(msg)

        plan = user.plan
        return UserResponse(
            id=user.id,
            email=user.email,
            name=user.name,
            picture=user.picture,
            role=user.role,
            credits=user.credits,
            is_active=user.is_active,
            plan={
                "id": plan.id,
                "code": plan.code,
                "name": plan.name,
                "description": plan.description,
                "credits": plan.credits,
                "max_credits": plan.max_credits,
            },
            created_at=user.created_at,
            updated_at=user.updated_at,
        )

    # ------------------------------------------------------------------
    # OAuth
    # ------------------------------------------------------------------

    @staticmethod
    def _oauth_link_data(user_id: int | None, oauth_user_info: OAuthUserInfo) -> dict:
        """Build the payload passed to ``oauth_repo.create`` to link a provider."""
        return {
            "user_id": user_id,
            "provider": oauth_user_info.provider,
            "provider_user_id": oauth_user_info.provider_user_id,
            "email": oauth_user_info.email,
            "name": oauth_user_info.name,
            "picture": oauth_user_info.picture,
        }

    async def oauth_login(self, oauth_user_info: OAuthUserInfo) -> AuthResponse:
        """Handle OAuth login - link or create user.

        Resolution order:

        1. An existing link for (provider, provider_user_id): use the linked
           user and refresh the link's stored email/name/picture.
        2. Otherwise, an existing user with the same email: link the provider
           to that user and adopt the provider picture if the user has none.
        3. Otherwise: create a new active user without a password, plus the
           link.

        NOTE(review): step 2 links accounts by email alone. Confirm that the
        providers only return verified emails; otherwise this allows account
        takeover.

        Raises:
            HTTPException: 404 if an OAuth link points at a missing user; 401
                if the resolved account is deactivated.
        """
        logger.info(
            "OAuth login attempt for %s with provider %s",
            oauth_user_info.email,
            oauth_user_info.provider,
        )

        existing_oauth = await self.oauth_repo.get_by_provider_user_id(
            oauth_user_info.provider,
            oauth_user_info.provider_user_id,
        )

        if existing_oauth:
            user = await self.user_repo.get_by_id(existing_oauth.user_id)
            if not user:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Linked user not found",
                )
            # Refresh user to avoid greenlet issues.
            await self.session.refresh(user)

            # Keep the stored provider profile in sync with the latest login.
            oauth_update_data = {
                "email": oauth_user_info.email,
                "name": oauth_user_info.name,
                "picture": oauth_user_info.picture,
            }
            await self.oauth_repo.update(existing_oauth, oauth_update_data)
            logger.info(
                "OAuth login successful for existing user: %s",
                oauth_user_info.email,
            )
        else:
            user = await self.user_repo.get_by_email(oauth_user_info.email)
            if user:
                # Refresh user to avoid greenlet issues.
                await self.session.refresh(user)
                # Capture the picture before further DB writes, again to avoid
                # greenlet issues when reading it later.
                current_user_picture = user.picture

                await self.oauth_repo.create(
                    self._oauth_link_data(user.id, oauth_user_info),
                )

                # Adopt the provider's picture only if the user has none yet.
                user_update_data = {}
                if not current_user_picture and oauth_user_info.picture:
                    user_update_data["picture"] = oauth_user_info.picture
                if user_update_data:
                    await self.user_repo.update(user, user_update_data)
                    # Refresh user after update to avoid greenlet issues.
                    await self.session.refresh(user)

                logger.info(
                    "Linked existing user %s to OAuth provider %s",
                    oauth_user_info.email,
                    oauth_user_info.provider,
                )
            else:
                user_data = {
                    "email": oauth_user_info.email,
                    "name": oauth_user_info.name,
                    "picture": oauth_user_info.picture,
                    "is_active": True,
                    # No password for OAuth users
                    "password_hash": None,
                }
                user = await self.user_repo.create(user_data)

                await self.oauth_repo.create(
                    self._oauth_link_data(user.id, oauth_user_info),
                )
                logger.info(
                    "Created new user %s from OAuth provider %s",
                    oauth_user_info.email,
                    oauth_user_info.provider,
                )

        # Re-read the user (avoids greenlet issues) before checking is_active.
        await self.session.refresh(user)
        if not user.is_active:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Account is deactivated",
            )

        token = self._create_access_token(user)
        user_response = await self.create_user_response(user)

        logger.info(
            "OAuth login completed successfully for user: %s",
            oauth_user_info.email,
        )
        return AuthResponse(user=user_response, token=token)

    # ------------------------------------------------------------------
    # Profile management
    # ------------------------------------------------------------------

    async def update_user_profile(self, user: User, data: dict) -> User:
        """Update the whitelisted profile fields of *user*.

        Only ``name`` is accepted; other keys in *data* are ignored. When no
        accepted key is present, the user is returned unchanged without
        touching the database.

        Returns:
            The (possibly updated) user; after an update its ``plan``
            relationship has been refreshed.
        """
        logger.info("Updating profile for user: %s", user.email)

        allowed_fields = {"name"}
        update_data = {k: v for k, v in data.items() if k in allowed_fields}
        if not update_data:
            return user

        for field, value in update_data.items():
            setattr(user, field, value)
        self.session.add(user)
        await self.session.commit()
        await self.session.refresh(user, ["plan"])

        logger.info("Profile updated successfully for user: %s", user.email)
        return user

    async def change_user_password(
        self,
        user: User,
        current_password: str | None,
        new_password: str,
    ) -> None:
        """Change the user's password, or set it for the first time.

        Users who already have a password must supply the correct current
        one. Users without a password (OAuth-only) may set a first password
        without it.

        NOTE(review): existing refresh/API tokens are left untouched here;
        consider revoking them after a password change.

        Raises:
            ValueError: If the current password is missing or incorrect for a
                user who already has one.
        """
        # Capture the email up front to avoid session detachment issues when
        # logging after the commit.
        user_email = user.email
        logger.info("Changing password for user: %s", user_email)

        # Record this before password_hash is overwritten; it selects the
        # verification path and the wording of the final log line.
        had_existing_password = user.password_hash is not None

        if had_existing_password:
            if not current_password:
                msg = "Current password is required when changing existing password"
                raise ValueError(msg)
            if not PasswordUtils.verify_password(current_password, user.password_hash):
                msg = "Current password is incorrect"
                raise ValueError(msg)
        else:
            logger.info("Setting first password for OAuth user: %s", user_email)

        user.password_hash = PasswordUtils.hash_password(new_password)
        self.session.add(user)
        await self.session.commit()

        logger.info(
            "Password %s successfully for user: %s",
            "changed" if had_existing_password else "set",
            user_email,
        )

    async def user_to_response(self, user: User) -> UserResponse:
        """Convert User model to UserResponse with plan information.

        The ``plan`` relationship is always refreshed first, as in
        ``create_user_response``. Probing an unloaded relationship with
        ``hasattr``/truthiness under an AsyncSession triggers an implicit
        load, which can raise instead of reporting "not loaded".

        NOTE(review): the ``plan`` payload built here (id, name, max_credits,
        features) differs from the one built by ``create_user_response``;
        confirm which shape clients expect before unifying them.
        """
        await self.session.refresh(user, ["plan"])

        plan = user.plan
        return UserResponse(
            id=user.id,
            email=user.email,
            name=user.name,
            picture=user.picture,
            role=user.role,
            credits=user.credits,
            is_active=user.is_active,
            plan={
                "id": plan.id,
                "name": plan.name,
                "max_credits": plan.max_credits,
                "features": [],  # Add features if needed
            },
            created_at=user.created_at,
            updated_at=user.updated_at,
        )

    async def get_user_oauth_providers(self, user: User) -> list:
        """Return the OAuth links stored for *user* via ``oauth_repo.get_by_user_id``."""
        return await self.oauth_repo.get_by_user_id(user.id)