"""Database utility functions for common operations.""" from typing import Any, Dict, List, Optional, Type, TypeVar from sqlmodel import select, SQLModel from sqlmodel.ext.asyncio.session import AsyncSession T = TypeVar("T", bound=SQLModel) async def create_and_save( session: AsyncSession, model_class: Type[T], **kwargs: Any ) -> T: """Create, add, commit, and refresh a model instance. This consolidates the common database pattern of: - instance = ModelClass(**kwargs) - session.add(instance) - await session.commit() - await session.refresh(instance) Args: session: Database session model_class: SQLModel class to instantiate **kwargs: Arguments to pass to model constructor Returns: Created and refreshed model instance """ instance = model_class(**kwargs) session.add(instance) await session.commit() await session.refresh(instance) return instance async def get_or_create( session: AsyncSession, model_class: Type[T], defaults: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> tuple[T, bool]: """Get an existing instance or create a new one. Args: session: Database session model_class: SQLModel class defaults: Default values for creation (if not found) **kwargs: Filter criteria for lookup Returns: Tuple of (instance, created) where created is True if instance was created """ # Build filter conditions filters = [] for key, value in kwargs.items(): filters.append(getattr(model_class, key) == value) # Try to find existing instance statement = select(model_class).where(*filters) result = await session.exec(statement) instance = result.first() if instance: return instance, False # Create new instance create_kwargs = {**kwargs} if defaults: create_kwargs.update(defaults) instance = await create_and_save(session, model_class, **create_kwargs) return instance, True async def update_and_save( session: AsyncSession, instance: T, **updates: Any ) -> T: """Update model instance fields and save to database. Args: session: Database session instance: Model instance to update **updates: Field updates to apply Returns: Updated and refreshed model instance """ for field, value in updates.items(): setattr(instance, field, value) session.add(instance) await session.commit() await session.refresh(instance) return instance async def bulk_create( session: AsyncSession, model_class: Type[T], items: List[Dict[str, Any]] ) -> List[T]: """Create multiple model instances in bulk. Args: session: Database session model_class: SQLModel class to instantiate items: List of dictionaries with model data Returns: List of created model instances """ instances = [] for item_data in items: instance = model_class(**item_data) session.add(instance) instances.append(instance) await session.commit() # Refresh all instances for instance in instances: await session.refresh(instance) return instances async def delete_and_commit( session: AsyncSession, instance: T ) -> None: """Delete an instance and commit the transaction. Args: session: Database session instance: Model instance to delete """ await session.delete(instance) await session.commit() async def exists( session: AsyncSession, model_class: Type[T], **kwargs: Any ) -> bool: """Check if a model instance exists with given criteria. Args: session: Database session model_class: SQLModel class **kwargs: Filter criteria Returns: True if instance exists, False otherwise """ filters = [] for key, value in kwargs.items(): filters.append(getattr(model_class, key) == value) statement = select(model_class).where(*filters) result = await session.exec(statement) return result.first() is not None