- Updated type hints from `List`/`Optional` to built-in `list` and `X | None` unions for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
182 lines
6.1 KiB
Python
182 lines
6.1 KiB
Python
"""Repository for scheduled task operations."""
|
|
|
|
from datetime import datetime, timezone

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.scheduled_task import (
    RecurrenceType,
    ScheduledTask,
    TaskStatus,
    TaskType,
)
from app.repositories.base import BaseRepository
|
|
|
|
|
|
def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Replacement for ``datetime.utcnow()``, which is deprecated since Python
    3.12. The ``tzinfo`` is stripped so the value keeps the naive-UTC
    convention that the previous ``utcnow()`` calls used.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
    """Repository for scheduled task database operations."""

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the repository.

        Args:
            session: Async session used for every query and commit below.
        """
        super().__init__(ScheduledTask, session)

    @staticmethod
    def _apply_optional_filters(
        statement,
        status: TaskStatus | None,
        task_type: TaskType | None,
    ):
        """Narrow a ScheduledTask select by status and/or task type when given.

        Compares against ``None`` explicitly, so any member of the enums
        (whatever its truthiness) is treated as a real filter value.
        """
        if status is not None:
            statement = statement.where(ScheduledTask.status == status)
        if task_type is not None:
            statement = statement.where(ScheduledTask.task_type == task_type)
        return statement

    async def get_pending_tasks(self) -> list[ScheduledTask]:
        """Get all pending tasks that are ready to be executed.

        Returns:
            Active tasks in PENDING status whose ``scheduled_at`` is at or
            before the current UTC time.
        """
        now = _utcnow()
        statement = select(ScheduledTask).where(
            ScheduledTask.status == TaskStatus.PENDING,
            ScheduledTask.is_active.is_(True),
            ScheduledTask.scheduled_at <= now,
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_active_tasks(self) -> list[ScheduledTask]:
        """Get all active tasks.

        Returns:
            Tasks flagged ``is_active`` whose status is PENDING or RUNNING.
        """
        statement = select(ScheduledTask).where(
            ScheduledTask.is_active.is_(True),
            ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_user_tasks(
        self,
        user_id: int,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[ScheduledTask]:
        """Get tasks for a specific user, newest ``scheduled_at`` first.

        Args:
            user_id: Owner of the tasks.
            status: Optional status filter.
            task_type: Optional task-type filter.
            limit: Maximum number of rows; ``None`` means unlimited. ``0`` now
                yields no rows (it used to be ignored and return everything).
            offset: Number of rows to skip; ``None`` means no offset.

        Returns:
            The matching tasks.
        """
        statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
        statement = self._apply_optional_filters(statement, status, task_type)
        statement = statement.order_by(ScheduledTask.scheduled_at.desc())

        # Explicit None checks: a truthiness test would silently drop limit=0.
        if offset is not None:
            statement = statement.offset(offset)
        if limit is not None:
            statement = statement.limit(limit)

        result = await self.session.exec(statement)
        return list(result.all())

    async def get_system_tasks(
        self,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
    ) -> list[ScheduledTask]:
        """Get system tasks (tasks with no user association).

        Args:
            status: Optional status filter.
            task_type: Optional task-type filter.

        Returns:
            Tasks whose ``user_id`` is NULL, newest ``scheduled_at`` first.
        """
        statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
        statement = self._apply_optional_filters(statement, status, task_type)
        statement = statement.order_by(ScheduledTask.scheduled_at.desc())
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
        """Get recurring tasks that need their next execution scheduled.

        Returns:
            Active, COMPLETED tasks with a recurrence type other than NONE
            whose ``next_execution_at`` is at or before the current UTC time.
        """
        now = _utcnow()
        statement = select(ScheduledTask).where(
            ScheduledTask.recurrence_type != RecurrenceType.NONE,
            ScheduledTask.is_active.is_(True),
            ScheduledTask.status == TaskStatus.COMPLETED,
            ScheduledTask.next_execution_at <= now,
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_expired_tasks(self) -> list[ScheduledTask]:
        """Get expired tasks that should be cleaned up.

        Returns:
            Active tasks with a non-NULL ``expires_at`` at or before the
            current UTC time.
        """
        now = _utcnow()
        statement = select(ScheduledTask).where(
            ScheduledTask.expires_at.is_not(None),
            ScheduledTask.expires_at <= now,
            ScheduledTask.is_active.is_(True),
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def cancel_user_tasks(
        self,
        user_id: int,
        task_type: TaskType | None = None,
    ) -> int:
        """Cancel all pending/running tasks for a user.

        Each matching task is set to CANCELLED and deactivated, then the
        session is committed (even when nothing matched).

        Args:
            user_id: Owner whose tasks are cancelled.
            task_type: When given, only tasks of this type are cancelled.

        Returns:
            The number of tasks cancelled.
        """
        statement = select(ScheduledTask).where(
            ScheduledTask.user_id == user_id,
            ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
        )
        statement = self._apply_optional_filters(statement, None, task_type)

        result = await self.session.exec(statement)
        tasks = list(result.all())

        for task in tasks:
            task.status = TaskStatus.CANCELLED
            task.is_active = False
            self.session.add(task)

        await self.session.commit()
        return len(tasks)

    async def mark_as_running(self, task: ScheduledTask) -> None:
        """Mark a task as running, commit, and refresh it from the database."""
        task.status = TaskStatus.RUNNING
        self.session.add(task)
        await self.session.commit()
        await self.session.refresh(task)

    async def mark_as_completed(
        self,
        task: ScheduledTask,
        next_execution_at: datetime | None = None,
    ) -> None:
        """Mark a task as completed and set next execution if recurring.

        Records the execution time, bumps ``executions_count`` and clears any
        previous error. Then:

        * if the task should repeat and ``next_execution_at`` is given, it is
          stored and the task goes back to PENDING;
        * if the task should not repeat, it is deactivated;
        * if it should repeat but no next time is given, it stays COMPLETED
          and active.

        Args:
            task: The task that finished.
            next_execution_at: When the next run should happen, for recurring
                tasks.
        """
        task.status = TaskStatus.COMPLETED
        task.last_executed_at = _utcnow()
        task.executions_count += 1
        task.error_message = None

        # Evaluated once, after the counters above are updated.
        repeats = task.should_repeat()
        if repeats and next_execution_at is not None:
            task.next_execution_at = next_execution_at
            # NOTE(review): scheduled_at is not advanced here, yet
            # get_pending_tasks selects on scheduled_at <= now. Unless other
            # code moves scheduled_at, this re-pended task is due again
            # immediately — confirm.
            task.status = TaskStatus.PENDING
        elif not repeats:
            task.is_active = False

        self.session.add(task)
        await self.session.commit()
        await self.session.refresh(task)

    async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
        """Mark a task as failed with error message.

        Args:
            task: The task that failed.
            error_message: Description of the failure, stored on the task.
        """
        task.status = TaskStatus.FAILED
        task.error_message = error_message
        task.last_executed_at = _utcnow()

        # For non-recurring tasks, deactivate on failure.
        # NOTE(review): recurring tasks stay active but FAILED. None of the
        # queries in this repository select FAILED tasks, so something else
        # must reschedule them — confirm.
        if not task.is_recurring():
            task.is_active = False

        self.session.add(task)
        await self.session.commit()
        await self.session.refresh(task)