166 lines
4.2 KiB
Python
166 lines
4.2 KiB
Python
"""Database utility functions for common operations."""
|
|
|
|
from typing import Any, Dict, List, Optional, Type, TypeVar
|
|
from sqlmodel import select, SQLModel
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
# Generic model type bound to SQLModel: helpers that accept ``Type[T]`` (or an
# instance of ``T``) are typed as returning that same concrete model class.
T = TypeVar("T", bound=SQLModel)
|
|
|
|
|
|
async def create_and_save(
    session: AsyncSession,
    model_class: Type[T],
    **kwargs: Any
) -> T:
    """Instantiate ``model_class``, persist it, and return the reloaded object.

    Shortcut for the construct / ``add`` / ``commit`` / ``refresh`` sequence
    that otherwise gets repeated at every call site.

    Args:
        session: Async database session used for the write.
        model_class: SQLModel class to build.
        **kwargs: Field values forwarded unchanged to the model constructor.

    Returns:
        The newly created instance, refreshed after the commit.
    """
    new_obj = model_class(**kwargs)
    session.add(new_obj)
    # Commit before refreshing: refresh() re-reads the persisted row so values
    # assigned by the database (e.g. generated keys) appear on the object.
    await session.commit()
    await session.refresh(new_obj)
    return new_obj
|
|
|
|
|
|
async def get_or_create(
    session: AsyncSession,
    model_class: Type[T],
    defaults: Optional[Dict[str, Any]] = None,
    **kwargs: Any
) -> tuple[T, bool]:
    """Get an existing instance or create a new one.

    Looks for a row whose columns equal every value in ``kwargs``. If none is
    found, a new instance is built from ``kwargs`` merged with ``defaults``
    and persisted via ``create_and_save``.

    Args:
        session: Database session
        model_class: SQLModel class
        defaults: Extra values used only when creating (ignored if a row is
            found). On key collisions these override the ``kwargs`` values.
        **kwargs: Filter criteria for lookup; also used as creation values.
            With no criteria at all, any existing row of the table matches.

    Returns:
        Tuple of (instance, created) where created is True if instance was created

    Raises:
        AttributeError: if a key in ``kwargs`` is not an attribute of
            ``model_class``.
    """
    # One equality condition per lookup key (getattr raises on unknown keys).
    filters = [getattr(model_class, key) == value for key, value in kwargs.items()]

    # Only the first match is used, so let the database stop after one row
    # instead of returning every matching row.
    statement = select(model_class).where(*filters).limit(1)
    result = await session.exec(statement)
    instance = result.first()

    # Explicit None check: a model defining __bool__/__len__ must not be
    # mistaken for "not found".
    if instance is not None:
        return instance, False

    # Not found: defaults are layered on top of the lookup values.
    create_kwargs = {**kwargs, **(defaults or {})}
    # NOTE(review): SELECT-then-INSERT is not atomic; a concurrent caller may
    # insert the same row in between. A unique constraint is the real guard.
    instance = await create_and_save(session, model_class, **create_kwargs)
    return instance, True
|
|
|
|
|
|
async def update_and_save(
    session: AsyncSession,
    instance: T,
    **updates: Any
) -> T:
    """Apply field changes to ``instance``, commit them, and reload it.

    Args:
        session: Async database session that performs the write.
        instance: Model object to modify in place.
        **updates: Attribute name -> new value; each pair is assigned with
            ``setattr`` in the order given.

    Returns:
        The same ``instance`` object, refreshed after the commit.
    """
    for attr_name in updates:
        setattr(instance, attr_name, updates[attr_name])

    # add() places the object in this session (a no-op if it is already there)
    # so the commit below writes the modified attributes.
    session.add(instance)
    await session.commit()
    await session.refresh(instance)
    return instance
|
|
|
|
|
|
async def bulk_create(
    session: AsyncSession,
    model_class: Type[T],
    items: List[Dict[str, Any]]
) -> List[T]:
    """Build one ``model_class`` object per dict in ``items`` and persist them.

    Every object is added to the session and written by a single commit;
    afterwards each one is refreshed individually.

    Args:
        session: Async database session used for the writes.
        model_class: SQLModel class to instantiate for each entry.
        items: Constructor keyword arguments, one dict per object.

    Returns:
        The created objects, in the same order as ``items``.
    """
    created: List[T] = []
    for fields in items:
        obj = model_class(**fields)
        created.append(obj)
        session.add(obj)

    # A single commit covers every object added above.
    await session.commit()

    # Each refresh() call re-reads one object, i.e. one round trip per item.
    for obj in created:
        await session.refresh(obj)

    return created
|
|
|
|
|
|
async def delete_and_commit(
    session: AsyncSession,
    instance: T
) -> None:
    """Remove ``instance`` from the database and commit right away.

    Args:
        session: Async database session that owns the transaction.
        instance: Model object to delete.
    """
    # Mark the object for deletion; the commit then issues the DELETE.
    await session.delete(instance)
    await session.commit()
|
|
|
|
|
|
async def exists(
    session: AsyncSession,
    model_class: Type[T],
    **kwargs: Any
) -> bool:
    """Check if a model instance exists with given criteria.

    Args:
        session: Database session
        model_class: SQLModel class
        **kwargs: Filter criteria; each key must name an attribute of
            ``model_class`` and is matched by equality. With no criteria,
            this reports whether the table contains any row at all.

    Returns:
        True if instance exists, False otherwise

    Raises:
        AttributeError: if a key in ``kwargs`` is not an attribute of
            ``model_class``.
    """
    filters = [getattr(model_class, key) == value for key, value in kwargs.items()]
    # Existence only needs one matching row, so cap the query with LIMIT 1
    # instead of fetching the whole matching result set.
    statement = select(model_class).where(*filters).limit(1)
    result = await session.exec(statement)
    return result.first() is not None