Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
"""Repository for scheduled task operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
@@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
"""Initialize the repository."""
|
||||
super().__init__(ScheduledTask, session)
|
||||
|
||||
async def get_pending_tasks(self) -> List[ScheduledTask]:
|
||||
async def get_pending_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all pending tasks that are ready to be executed."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_active_tasks(self) -> List[ScheduledTask]:
|
||||
async def get_active_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all active tasks."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.is_active.is_(True),
|
||||
@@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
||||
|
||||
@@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_system_tasks(
|
||||
self,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get system tasks (tasks with no user association)."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
||||
|
||||
@@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
|
||||
"""Get recurring tasks that need their next execution scheduled."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_expired_tasks(self) -> List[ScheduledTask]:
|
||||
async def get_expired_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get expired tasks that should be cleaned up."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
async def cancel_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
task_type: Optional[TaskType] = None,
|
||||
task_type: TaskType | None = None,
|
||||
) -> int:
|
||||
"""Cancel all pending/running tasks for a user."""
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
async def mark_as_completed(
|
||||
self,
|
||||
task: ScheduledTask,
|
||||
next_execution_at: Optional[datetime] = None,
|
||||
next_execution_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Mark a task as completed and set next execution if recurring."""
|
||||
task.status = TaskStatus.COMPLETED
|
||||
@@ -174,4 +178,4 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
await self.session.refresh(task)
|
||||
|
||||
Reference in New Issue
Block a user