182 lines
6.1 KiB
Python
182 lines
6.1 KiB
Python
"""Repository for scheduled task operations."""
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.models.scheduled_task import (
|
|
RecurrenceType,
|
|
ScheduledTask,
|
|
TaskStatus,
|
|
TaskType,
|
|
)
|
|
from app.repositories.base import BaseRepository
|
|
|
|
|
|
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
    """Repository for scheduled task database operations.

    Read helpers return fully materialized lists. ``cancel_user_tasks`` and
    the ``mark_as_*`` helpers mutate the given task(s), add them to the
    session and commit before returning.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the repository.

        Args:
            session: Async session used for every query and commit.
        """
        super().__init__(ScheduledTask, session)

    async def get_pending_tasks(self) -> list[ScheduledTask]:
        """Get all pending tasks that are ready to be executed.

        A task qualifies when it is PENDING, active, and its ``scheduled_at``
        is at or before the current UTC time.

        Returns:
            The matching tasks, in database order.
        """
        now = datetime.now(tz=UTC)
        statement = select(ScheduledTask).where(
            ScheduledTask.status == TaskStatus.PENDING,
            ScheduledTask.is_active.is_(True),
            ScheduledTask.scheduled_at <= now,
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_active_tasks(self) -> list[ScheduledTask]:
        """Get all active tasks that are either PENDING or RUNNING.

        Returns:
            The matching tasks, in database order.
        """
        statement = select(ScheduledTask).where(
            ScheduledTask.is_active.is_(True),
            ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_user_tasks(
        self,
        user_id: int,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[ScheduledTask]:
        """Get tasks for a specific user, newest ``scheduled_at`` first.

        Args:
            user_id: Owner of the tasks.
            status: If given, only tasks with this status are returned.
            task_type: If given, only tasks of this type are returned.
            limit: Maximum number of rows to return. ``None`` means no limit;
                ``0`` is honoured and yields an empty result.
            offset: Number of rows to skip. ``None`` means no offset.

        Returns:
            The matching tasks ordered by ``scheduled_at`` descending.
        """
        statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)

        # Compare against None explicitly: a falsy-but-valid value (e.g. limit=0)
        # must still be applied rather than silently ignored.
        if status is not None:
            statement = statement.where(ScheduledTask.status == status)
        if task_type is not None:
            statement = statement.where(ScheduledTask.task_type == task_type)

        statement = statement.order_by(ScheduledTask.scheduled_at.desc())

        if offset is not None:
            statement = statement.offset(offset)
        if limit is not None:
            statement = statement.limit(limit)

        result = await self.session.exec(statement)
        return list(result.all())

    async def get_system_tasks(
        self,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
    ) -> list[ScheduledTask]:
        """Get system tasks (tasks with no user association).

        Args:
            status: If given, only tasks with this status are returned.
            task_type: If given, only tasks of this type are returned.

        Returns:
            Tasks whose ``user_id`` is NULL, ordered by ``scheduled_at``
            descending.
        """
        statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))

        if status is not None:
            statement = statement.where(ScheduledTask.status == status)
        if task_type is not None:
            statement = statement.where(ScheduledTask.task_type == task_type)

        statement = statement.order_by(ScheduledTask.scheduled_at.desc())
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
        """Get recurring tasks that need their next execution scheduled.

        A task qualifies when its recurrence type is not NONE, it is active,
        its status is COMPLETED, and ``next_execution_at`` is at or before the
        current UTC time.

        Returns:
            The matching tasks, in database order.
        """
        now = datetime.now(tz=UTC)
        statement = select(ScheduledTask).where(
            ScheduledTask.recurrence_type != RecurrenceType.NONE,
            ScheduledTask.is_active.is_(True),
            ScheduledTask.status == TaskStatus.COMPLETED,
            ScheduledTask.next_execution_at <= now,
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_expired_tasks(self) -> list[ScheduledTask]:
        """Get active tasks whose ``expires_at`` has passed.

        Tasks with no ``expires_at`` never expire.

        Returns:
            The matching tasks, in database order.
        """
        now = datetime.now(tz=UTC)
        statement = select(ScheduledTask).where(
            ScheduledTask.expires_at.is_not(None),
            ScheduledTask.expires_at <= now,
            ScheduledTask.is_active.is_(True),
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def cancel_user_tasks(
        self,
        user_id: int,
        task_type: TaskType | None = None,
    ) -> int:
        """Cancel all pending/running tasks for a user.

        Each matching task is set to CANCELLED and deactivated, then the
        session is committed (even when nothing matched).

        Args:
            user_id: Owner of the tasks to cancel.
            task_type: If given, only tasks of this type are cancelled.

        Returns:
            The number of tasks that were cancelled.
        """
        statement = select(ScheduledTask).where(
            ScheduledTask.user_id == user_id,
            ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
        )
        if task_type is not None:
            statement = statement.where(ScheduledTask.task_type == task_type)

        result = await self.session.exec(statement)
        tasks = list(result.all())

        for task in tasks:
            task.status = TaskStatus.CANCELLED
            task.is_active = False
            self.session.add(task)

        await self.session.commit()
        return len(tasks)

    async def mark_as_running(self, task: ScheduledTask) -> None:
        """Mark a task as RUNNING, commit, and refresh it from the database.

        Args:
            task: The task to update in place.
        """
        task.status = TaskStatus.RUNNING
        self.session.add(task)
        await self.session.commit()
        await self.session.refresh(task)

    async def mark_as_completed(
        self,
        task: ScheduledTask,
        next_execution_at: datetime | None = None,
    ) -> None:
        """Mark a task as completed and set next execution if recurring.

        The execution counter is incremented, ``last_executed_at`` is set to
        now (UTC), and any previous error message is cleared. Then:

        * if the task should repeat and ``next_execution_at`` is given, the
          task is re-armed as PENDING with that next execution time;
        * if the task should not repeat, it is deactivated;
        * otherwise (repeats but no next time given) it stays COMPLETED and
          active.

        Args:
            task: The task to update in place.
            next_execution_at: When the next run of a recurring task is due.
        """
        task.status = TaskStatus.COMPLETED
        task.last_executed_at = datetime.now(tz=UTC)
        task.executions_count += 1
        task.error_message = None

        # Evaluated once, after the counter bump, so both branches agree.
        repeats = task.should_repeat()
        if repeats and next_execution_at is not None:
            # NOTE(review): only next_execution_at is updated here, while
            # get_pending_tasks filters on scheduled_at. Unless scheduled_at
            # is advanced elsewhere, a re-armed PENDING task whose
            # scheduled_at is already in the past may be picked up again on
            # the next poll — confirm against the scheduler.
            task.next_execution_at = next_execution_at
            task.status = TaskStatus.PENDING
        elif not repeats:
            task.is_active = False

        self.session.add(task)
        await self.session.commit()
        await self.session.refresh(task)

    async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
        """Mark a task as FAILED with an error message, commit, and refresh.

        Non-recurring tasks are also deactivated; recurring tasks keep their
        ``is_active`` flag.

        Args:
            task: The task to update in place.
            error_message: Description of the failure, stored on the task.
        """
        task.status = TaskStatus.FAILED
        task.error_message = error_message
        task.last_executed_at = datetime.now(tz=UTC)

        # For non-recurring tasks, deactivate on failure
        if not task.is_recurring():
            task.is_active = False

        self.session.add(task)
        await self.session.commit()
        await self.session.refresh(task)