Refactor code structure for improved readability and maintainability
This commit is contained in:
@@ -72,18 +72,22 @@ class BaseRepository[ModelType]:
|
||||
logger.exception("Failed to get all %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||
async def create(self, entity_data: dict[str, Any] | ModelType) -> ModelType:
|
||||
"""Create a new entity.
|
||||
|
||||
Args:
|
||||
entity_data: Dictionary of entity data
|
||||
entity_data: Dictionary of entity data or model instance
|
||||
|
||||
Returns:
|
||||
The created entity
|
||||
|
||||
"""
|
||||
try:
|
||||
entity = self.model(**entity_data)
|
||||
if isinstance(entity_data, dict):
|
||||
entity = self.model(**entity_data)
|
||||
else:
|
||||
# Already a model instance
|
||||
entity = entity_data
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Repository for scheduled task operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -23,7 +23,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_pending_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all pending tasks that are ready to be executed."""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.status == TaskStatus.PENDING,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
@@ -90,7 +90,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
|
||||
"""Get recurring tasks that need their next execution scheduled."""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.recurrence_type != RecurrenceType.NONE,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
@@ -102,7 +102,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_expired_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get expired tasks that should be cleaned up."""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.expires_at.is_not(None),
|
||||
ScheduledTask.expires_at <= now,
|
||||
@@ -152,7 +152,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
) -> None:
|
||||
"""Mark a task as completed and set next execution if recurring."""
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
task.last_executed_at = datetime.now(tz=UTC)
|
||||
task.executions_count += 1
|
||||
task.error_message = None
|
||||
|
||||
@@ -170,7 +170,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
"""Mark a task as failed with error message."""
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error_message = error_message
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
task.last_executed_at = datetime.now(tz=UTC)
|
||||
|
||||
# For non-recurring tasks, deactivate on failure
|
||||
if not task.is_recurring():
|
||||
|
||||
Reference in New Issue
Block a user