"""Base repository with common CRUD operations."""

from typing import Any, Generic, TypeVar

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger

# Type variable for the model
ModelType = TypeVar("ModelType")

logger = get_logger(__name__)


class BaseRepository(Generic[ModelType]):
    """Base repository with common CRUD operations.

    Wraps a model class and an ``AsyncSession``. Every operation logs failures
    with ``logger.exception`` and re-raises the original exception. Write
    operations (``create``, ``update``, ``delete``) commit immediately and roll
    the session back if the transactional work fails.
    """

    def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
        """Initialize the repository.

        Args:
            model: The SQLModel class
            session: Database session
        """
        self.model = model
        self.session = session

    async def get_by_id(self, entity_id: int) -> ModelType | None:
        """Get an entity by ID.

        Args:
            entity_id: The entity ID

        Returns:
            The entity if found, None otherwise

        Raises:
            Exception: Any error from building or executing the query is
                logged and re-raised unchanged.
        """
        try:
            # NOTE(review): getattr is presumably used instead of
            # ``self.model.id`` because ModelType is an unbound TypeVar, so the
            # attribute is not statically known to exist.
            statement = select(self.model).where(getattr(self.model, "id") == entity_id)
            result = await self.session.exec(statement)
            return result.first()
        except Exception:
            logger.exception("Failed to get %s by ID: %s", self.model.__name__, entity_id)
            raise

    async def get_all(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ModelType]:
        """Get all entities with pagination.

        When the model exposes an ``id`` attribute, rows are ordered by it.
        Without an ORDER BY, SQL does not guarantee which rows a LIMIT/OFFSET
        window returns, so successive pages could overlap or skip rows.

        Args:
            limit: Maximum number of entities to return
            offset: Number of entities to skip

        Returns:
            List of entities

        Raises:
            Exception: Any error from building or executing the query is
                logged and re-raised unchanged.
        """
        try:
            statement = select(self.model)
            id_column = getattr(self.model, "id", None)
            # Explicit None check: don't rely on the truthiness of a SQL
            # expression object.
            if id_column is not None:
                statement = statement.order_by(id_column)
            statement = statement.limit(limit).offset(offset)
            result = await self.session.exec(statement)
            return list(result.all())
        except Exception:
            logger.exception("Failed to get all %s", self.model.__name__)
            raise

    async def create(self, entity_data: dict[str, Any]) -> ModelType:
        """Create a new entity.

        Args:
            entity_data: Dictionary of entity data

        Returns:
            The created entity, refreshed from the database after commit

        Raises:
            Exception: If construction, commit or refresh fails, the session
                is rolled back and the error is logged and re-raised.
        """
        try:
            entity = self.model(**entity_data)
            self.session.add(entity)
            await self.session.commit()
            await self.session.refresh(entity)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to create %s", self.model.__name__)
            raise
        else:
            # Kept outside the try: an error while logging must not roll back
            # or report failure for a row that has already been committed.
            logger.info("Created new %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
            return entity

    async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
        """Update an entity.

        Each key in ``update_data`` is assigned to ``entity`` with ``setattr``
        before the commit.

        Args:
            entity: The entity to update
            update_data: Dictionary of fields to update

        Returns:
            The updated entity, refreshed from the database after commit

        Raises:
            Exception: If assignment, commit or refresh fails, the session is
                rolled back and the error is logged and re-raised.
        """
        try:
            for field, value in update_data.items():
                setattr(entity, field, value)
            self.session.add(entity)
            await self.session.commit()
            await self.session.refresh(entity)
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to update %s", self.model.__name__)
            raise
        else:
            # Kept outside the try: an error while logging must not roll back
            # or report failure for a change that has already been committed.
            logger.info("Updated %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))
            return entity

    async def delete(self, entity: ModelType) -> None:
        """Delete an entity.

        Args:
            entity: The entity to delete

        Raises:
            Exception: If the delete or commit fails, the session is rolled
                back and the error is logged and re-raised.
        """
        try:
            await self.session.delete(entity)
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            logger.exception("Failed to delete %s", self.model.__name__)
            raise
        else:
            # Kept outside the try: an error while logging must not roll back
            # or report failure for a delete that has already been committed.
            logger.info("Deleted %s with ID: %s", self.model.__name__, getattr(entity, "id", "unknown"))