Files
sdb2-backend/app/repositories/base.py
JSC 6068599a47
All checks were successful
Backend CI / lint (push) Successful in 9m49s
Backend CI / test (push) Successful in 6m15s
Refactor test cases for improved readability and consistency
- Adjusted function signatures in various test files to enhance clarity by aligning parameters.
- Updated patching syntax for better readability across test cases.
- Improved formatting and spacing in test assertions and mock setups.
- Ensured consistent use of async/await patterns in async test functions.
- Enhanced comments for better understanding of test intentions.
2025-08-01 20:53:30 +02:00

151 lines
4.1 KiB
Python

"""Base repository with common CRUD operations."""
from typing import Any, TypeVar
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
# Module-level type variable for the model. BaseRepository below declares its
# own type parameter of the same name in its class header, so this variable is
# not what parameterizes that class.
# NOTE(review): presumably kept so other modules can import it - confirm
# before removing.
ModelType = TypeVar("ModelType")
# Module-scoped logger, named after this module, used by every repository method.
logger = get_logger(__name__)
class BaseRepository[ModelType]:
"""Base repository with common CRUD operations."""
def __init__(self, model: type[ModelType], session: AsyncSession) -> None:
"""Initialize the repository.
Args:
model: The SQLModel class
session: Database session
"""
self.model = model
self.session = session
async def get_by_id(self, entity_id: int) -> ModelType | None:
"""Get an entity by ID.
Args:
entity_id: The entity ID
Returns:
The entity if found, None otherwise
"""
try:
statement = select(self.model).where(self.model.id == entity_id)
result = await self.session.exec(statement)
return result.first()
except Exception:
logger.exception(
"Failed to get %s by ID: %s",
self.model.__name__,
entity_id,
)
raise
async def get_all(
self,
limit: int = 100,
offset: int = 0,
) -> list[ModelType]:
"""Get all entities with pagination.
Args:
limit: Maximum number of entities to return
offset: Number of entities to skip
Returns:
List of entities
"""
try:
statement = select(self.model).limit(limit).offset(offset)
result = await self.session.exec(statement)
return list(result.all())
except Exception:
logger.exception("Failed to get all %s", self.model.__name__)
raise
async def create(self, entity_data: dict[str, Any]) -> ModelType:
"""Create a new entity.
Args:
entity_data: Dictionary of entity data
Returns:
The created entity
"""
try:
entity = self.model(**entity_data)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
except Exception:
await self.session.rollback()
logger.exception("Failed to create %s", self.model.__name__)
raise
else:
logger.info(
"Created new %s with ID: %s",
self.model.__name__,
getattr(entity, "id", "unknown"),
)
return entity
async def update(self, entity: ModelType, update_data: dict[str, Any]) -> ModelType:
"""Update an entity.
Args:
entity: The entity to update
update_data: Dictionary of fields to update
Returns:
The updated entity
"""
try:
for field, value in update_data.items():
setattr(entity, field, value)
self.session.add(entity)
await self.session.commit()
await self.session.refresh(entity)
except Exception:
await self.session.rollback()
logger.exception("Failed to update %s", self.model.__name__)
raise
else:
logger.info(
"Updated %s with ID: %s",
self.model.__name__,
getattr(entity, "id", "unknown"),
)
return entity
async def delete(self, entity: ModelType) -> None:
"""Delete an entity.
Args:
entity: The entity to delete
"""
try:
await self.session.delete(entity)
await self.session.commit()
logger.info(
"Deleted %s with ID: %s",
self.model.__name__,
getattr(entity, "id", "unknown"),
)
except Exception:
await self.session.rollback()
logger.exception("Failed to delete %s", self.model.__name__)
raise