152 lines
4.2 KiB
Python
152 lines
4.2 KiB
Python
"""Base repository with common CRUD operations."""
|
|
|
|
from typing import Any, TypeVar
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
|
|
# Module-level type variable for the model type.
# NOTE: ``BaseRepository`` declares its own PEP 695 type parameter
# (``class BaseRepository[ModelType]``), which shadows this TypeVar inside the
# class body, so the class itself does not use this object. It is kept because
# other modules may import ``ModelType`` from here — verify before removing.
ModelType = TypeVar("ModelType")


logger = get_logger(__name__)
|
|
|
|
|
|
class BaseRepository[ModelType]:
|
|
"""Base repository with common CRUD operations."""
|
|
|
|
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
|
|
"""Initialize the repository.
|
|
|
|
Args:
|
|
model: The SQLModel class
|
|
session: Database session
|
|
|
|
"""
|
|
self.model = model
|
|
self.session = session
|
|
|
|
async def get_by_id(self, entity_id: int) -> ModelType | None:
|
|
"""Get an entity by ID.
|
|
|
|
Args:
|
|
entity_id: The entity ID
|
|
|
|
Returns:
|
|
The entity if found, None otherwise
|
|
|
|
"""
|
|
try:
|
|
statement = select(self.model).where(self.model.id == entity_id)
|
|
result = await self.session.exec(statement)
|
|
return result.first()
|
|
except Exception:
|
|
logger.exception(
|
|
"Failed to get %s by ID: %s",
|
|
self.model.__name__,
|
|
entity_id,
|
|
)
|
|
raise
|
|
|
|
async def get_all(
|
|
self,
|
|
limit: int = 100,
|
|
offset: int = 0,
|
|
) -> list[ModelType]:
|
|
"""Get all entities with pagination.
|
|
|
|
Args:
|
|
limit: Maximum number of entities to return
|
|
offset: Number of entities to skip
|
|
|
|
Returns:
|
|
List of entities
|
|
|
|
"""
|
|
try:
|
|
statement = select(self.model).limit(limit).offset(offset)
|
|
result = await self.session.exec(statement)
|
|
return list(result.all())
|
|
except Exception:
|
|
logger.exception("Failed to get all %s", self.model.__name__)
|
|
raise
|
|
|
|
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
|
"""Create a new entity.
|
|
|
|
Args:
|
|
entity_data: Dictionary of entity data
|
|
|
|
Returns:
|
|
The created entity
|
|
|
|
"""
|
|
try:
|
|
entity = self.model(**entity_data)
|
|
self.session.add(entity)
|
|
await self.session.commit()
|
|
await self.session.refresh(entity)
|
|
except Exception:
|
|
await self.session.rollback()
|
|
logger.exception("Failed to create %s", self.model.__name__)
|
|
raise
|
|
else:
|
|
logger.info(
|
|
"Created new %s with ID: %s",
|
|
self.model.__name__,
|
|
getattr(entity, "id", "unknown"),
|
|
)
|
|
return entity
|
|
|
|
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
|
|
"""Update an entity.
|
|
|
|
Args:
|
|
entity: The entity to update
|
|
update_data: Dictionary of fields to update
|
|
|
|
Returns:
|
|
The updated entity
|
|
|
|
"""
|
|
try:
|
|
for field, value in update_data.items():
|
|
setattr(entity, field, value)
|
|
|
|
# The updated_at timestamp will be automatically set by the SQLAlchemy event listener
|
|
self.session.add(entity)
|
|
await self.session.commit()
|
|
await self.session.refresh(entity)
|
|
except Exception:
|
|
await self.session.rollback()
|
|
logger.exception("Failed to update %s", self.model.__name__)
|
|
raise
|
|
else:
|
|
logger.info(
|
|
"Updated %s with ID: %s",
|
|
self.model.__name__,
|
|
getattr(entity, "id", "unknown"),
|
|
)
|
|
return entity
|
|
|
|
async def delete(self, entity: ModelType) -> None:
|
|
"""Delete an entity.
|
|
|
|
Args:
|
|
entity: The entity to delete
|
|
|
|
"""
|
|
try:
|
|
await self.session.delete(entity)
|
|
await self.session.commit()
|
|
logger.info(
|
|
"Deleted %s with ID: %s",
|
|
self.model.__name__,
|
|
getattr(entity, "id", "unknown"),
|
|
)
|
|
except Exception:
|
|
await self.session.rollback()
|
|
logger.exception("Failed to delete %s", self.model.__name__)
|
|
raise
|