Files
sdb2-backend/app/repositories/scheduled_task.py
JSC dc89e45675 Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase.
- Refactored import statements for better organization and clarity.
- Enhanced the ScheduledTaskBase schema to use modern type hints.
- Cleaned up unnecessary comments and whitespace in various files.
- Improved error handling and logging in task execution handlers.
- Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
2025-08-28 23:38:47 +02:00

182 lines
6.1 KiB
Python

"""Repository for scheduled task operations."""
from datetime import datetime, timezone

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.scheduled_task import (
    RecurrenceType,
    ScheduledTask,
    TaskStatus,
    TaskType,
)
from app.repositories.base import BaseRepository
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
    """Repository for scheduled task database operations.

    Query methods build a ``select`` over ``ScheduledTask`` and return the
    matching rows as a list. State-changing methods mutate the given task(s)
    and commit the session themselves.
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize the repository.

        Args:
            session: Async session handed to the base repository together
                with the ``ScheduledTask`` model.
        """
        super().__init__(ScheduledTask, session)

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Yields the same value as the deprecated ``datetime.utcnow()``: the
        tzinfo is stripped so timestamps stay naive, exactly as before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _commit_and_refresh(self, task: ScheduledTask) -> None:
        """Add ``task`` to the session, commit, and reload it."""
        self.session.add(task)
        await self.session.commit()
        await self.session.refresh(task)

    async def get_pending_tasks(self) -> list[ScheduledTask]:
        """Get all pending tasks that are ready to be executed.

        Returns:
            Active tasks with status ``PENDING`` whose ``scheduled_at`` is
            not later than the current UTC time.
        """
        now = self._utc_now()
        statement = select(ScheduledTask).where(
            ScheduledTask.status == TaskStatus.PENDING,
            ScheduledTask.is_active.is_(True),
            ScheduledTask.scheduled_at <= now,
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_active_tasks(self) -> list[ScheduledTask]:
        """Get all active tasks.

        Returns:
            Tasks flagged active whose status is ``PENDING`` or ``RUNNING``.
        """
        statement = select(ScheduledTask).where(
            ScheduledTask.is_active.is_(True),
            ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_user_tasks(
        self,
        user_id: int,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[ScheduledTask]:
        """Get tasks for a specific user.

        Args:
            user_id: Owner of the tasks.
            status: Only return tasks with this status; ``None`` disables
                the filter.
            task_type: Only return tasks of this type; ``None`` disables
                the filter.
            limit: Maximum number of rows; ``None`` means no limit.
            offset: Number of rows to skip; ``None`` means no offset.

        Returns:
            Matching tasks ordered by ``scheduled_at``, newest first.
        """
        statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
        if status is not None:
            statement = statement.where(ScheduledTask.status == status)
        if task_type is not None:
            statement = statement.where(ScheduledTask.task_type == task_type)
        statement = statement.order_by(ScheduledTask.scheduled_at.desc())
        # Compare against None rather than truthiness: a truthiness test
        # silently dropped ``limit=0`` and returned every row instead.
        if offset is not None:
            statement = statement.offset(offset)
        if limit is not None:
            statement = statement.limit(limit)
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_system_tasks(
        self,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
    ) -> list[ScheduledTask]:
        """Get system tasks (tasks with no user association).

        Args:
            status: Only return tasks with this status; ``None`` disables
                the filter.
            task_type: Only return tasks of this type; ``None`` disables
                the filter.

        Returns:
            Tasks whose ``user_id`` is NULL, ordered by ``scheduled_at``,
            newest first.
        """
        statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
        if status is not None:
            statement = statement.where(ScheduledTask.status == status)
        if task_type is not None:
            statement = statement.where(ScheduledTask.task_type == task_type)
        statement = statement.order_by(ScheduledTask.scheduled_at.desc())
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
        """Get recurring tasks that need their next execution scheduled.

        Returns:
            Active tasks with a recurrence type other than ``NONE``, status
            ``COMPLETED``, and ``next_execution_at`` not later than the
            current UTC time.
        """
        now = self._utc_now()
        statement = select(ScheduledTask).where(
            ScheduledTask.recurrence_type != RecurrenceType.NONE,
            ScheduledTask.is_active.is_(True),
            ScheduledTask.status == TaskStatus.COMPLETED,
            ScheduledTask.next_execution_at <= now,
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def get_expired_tasks(self) -> list[ScheduledTask]:
        """Get expired tasks that should be cleaned up.

        Returns:
            Active tasks with a non-NULL ``expires_at`` that is not later
            than the current UTC time.
        """
        now = self._utc_now()
        statement = select(ScheduledTask).where(
            ScheduledTask.expires_at.is_not(None),
            ScheduledTask.expires_at <= now,
            ScheduledTask.is_active.is_(True),
        )
        result = await self.session.exec(statement)
        return list(result.all())

    async def cancel_user_tasks(
        self,
        user_id: int,
        task_type: TaskType | None = None,
    ) -> int:
        """Cancel all pending/running tasks for a user.

        Args:
            user_id: Owner of the tasks to cancel.
            task_type: Restrict cancellation to this type; ``None`` cancels
                every pending/running task of the user.

        Returns:
            Number of tasks that were set to ``CANCELLED`` and deactivated.
        """
        statement = select(ScheduledTask).where(
            ScheduledTask.user_id == user_id,
            ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
        )
        if task_type is not None:
            statement = statement.where(ScheduledTask.task_type == task_type)
        result = await self.session.exec(statement)
        tasks = list(result.all())
        for task in tasks:
            task.status = TaskStatus.CANCELLED
            task.is_active = False
            self.session.add(task)
        # Single commit for the whole batch; issued even when nothing matched,
        # as before.
        await self.session.commit()
        return len(tasks)

    async def mark_as_running(self, task: ScheduledTask) -> None:
        """Mark a task as running and persist the change.

        Args:
            task: Task to update; it is refreshed from the session after
                the commit.
        """
        task.status = TaskStatus.RUNNING
        await self._commit_and_refresh(task)

    async def mark_as_completed(
        self,
        task: ScheduledTask,
        next_execution_at: datetime | None = None,
    ) -> None:
        """Mark a task as completed and set next execution if recurring.

        Records the execution time, increments ``executions_count`` and
        clears ``error_message``. Afterwards:

        * if ``task.should_repeat()`` is true and ``next_execution_at`` is
          given, the task is rescheduled and put back to ``PENDING``;
        * if ``task.should_repeat()`` is false, the task is deactivated;
        * otherwise the task stays ``COMPLETED`` and active.

        Args:
            task: Task to update; it is refreshed from the session after
                the commit.
            next_execution_at: When the task should run next, if known.
        """
        task.status = TaskStatus.COMPLETED
        task.last_executed_at = self._utc_now()
        task.executions_count += 1
        task.error_message = None
        # Evaluated once, after the counter was incremented, so both branches
        # below see the same answer.
        repeats = task.should_repeat()
        if repeats and next_execution_at:
            task.next_execution_at = next_execution_at
            task.status = TaskStatus.PENDING
        elif not repeats:
            task.is_active = False
        await self._commit_and_refresh(task)

    async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
        """Mark a task as failed with error message.

        Args:
            task: Task to update; it is refreshed from the session after
                the commit.
            error_message: Text stored in the task's ``error_message``.
        """
        task.status = TaskStatus.FAILED
        task.error_message = error_message
        task.last_executed_at = self._utc_now()
        # For non-recurring tasks, deactivate on failure
        if not task.is_recurring():
            task.is_active = False
        await self._commit_and_refresh(task)