"""Repository for scheduled task operations.""" from datetime import datetime from typing import List, Optional from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType from app.repositories.base import BaseRepository class ScheduledTaskRepository(BaseRepository[ScheduledTask]): """Repository for scheduled task database operations.""" def __init__(self, session: AsyncSession) -> None: """Initialize the repository.""" super().__init__(ScheduledTask, session) async def get_pending_tasks(self) -> List[ScheduledTask]: """Get all pending tasks that are ready to be executed.""" now = datetime.utcnow() statement = select(ScheduledTask).where( ScheduledTask.status == TaskStatus.PENDING, ScheduledTask.is_active.is_(True), ScheduledTask.scheduled_at <= now, ) result = await self.session.exec(statement) return list(result.all()) async def get_active_tasks(self) -> List[ScheduledTask]: """Get all active tasks.""" statement = select(ScheduledTask).where( ScheduledTask.is_active.is_(True), ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]), ) result = await self.session.exec(statement) return list(result.all()) async def get_user_tasks( self, user_id: int, status: Optional[TaskStatus] = None, task_type: Optional[TaskType] = None, limit: Optional[int] = None, offset: Optional[int] = None, ) -> List[ScheduledTask]: """Get tasks for a specific user.""" statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id) if status: statement = statement.where(ScheduledTask.status == status) if task_type: statement = statement.where(ScheduledTask.task_type == task_type) statement = statement.order_by(ScheduledTask.scheduled_at.desc()) if offset: statement = statement.offset(offset) if limit: statement = statement.limit(limit) result = await self.session.exec(statement) return list(result.all()) async def get_system_tasks( self, status: Optional[TaskStatus] = None, task_type: Optional[TaskType] = None, ) -> List[ScheduledTask]: """Get system tasks (tasks with no user association).""" statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None)) if status: statement = statement.where(ScheduledTask.status == status) if task_type: statement = statement.where(ScheduledTask.task_type == task_type) statement = statement.order_by(ScheduledTask.scheduled_at.desc()) result = await self.session.exec(statement) return list(result.all()) async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]: """Get recurring tasks that need their next execution scheduled.""" now = datetime.utcnow() statement = select(ScheduledTask).where( ScheduledTask.recurrence_type != RecurrenceType.NONE, ScheduledTask.is_active.is_(True), ScheduledTask.status == TaskStatus.COMPLETED, ScheduledTask.next_execution_at <= now, ) result = await self.session.exec(statement) return list(result.all()) async def get_expired_tasks(self) -> List[ScheduledTask]: """Get expired tasks that should be cleaned up.""" now = datetime.utcnow() statement = select(ScheduledTask).where( ScheduledTask.expires_at.is_not(None), ScheduledTask.expires_at <= now, ScheduledTask.is_active.is_(True), ) result = await self.session.exec(statement) return list(result.all()) async def cancel_user_tasks( self, user_id: int, task_type: Optional[TaskType] = None, ) -> int: """Cancel all pending/running tasks for a user.""" statement = select(ScheduledTask).where( ScheduledTask.user_id == user_id, ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]), ) if task_type: statement = statement.where(ScheduledTask.task_type == task_type) result = await self.session.exec(statement) tasks = list(result.all()) count = 0 for task in tasks: task.status = TaskStatus.CANCELLED task.is_active = False self.session.add(task) count += 1 await self.session.commit() return count async def mark_as_running(self, task: ScheduledTask) -> None: """Mark a task as running.""" task.status = TaskStatus.RUNNING self.session.add(task) await self.session.commit() await self.session.refresh(task) async def mark_as_completed( self, task: ScheduledTask, next_execution_at: Optional[datetime] = None, ) -> None: """Mark a task as completed and set next execution if recurring.""" task.status = TaskStatus.COMPLETED task.last_executed_at = datetime.utcnow() task.executions_count += 1 task.error_message = None if next_execution_at and task.should_repeat(): task.next_execution_at = next_execution_at task.status = TaskStatus.PENDING elif not task.should_repeat(): task.is_active = False self.session.add(task) await self.session.commit() await self.session.refresh(task) async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None: """Mark a task as failed with error message.""" task.status = TaskStatus.FAILED task.error_message = error_message task.last_executed_at = datetime.utcnow() # For non-recurring tasks, deactivate on failure if not task.is_recurring(): task.is_active = False self.session.add(task) await self.session.commit() await self.session.refresh(task)