Add comprehensive tests for scheduled task repository, scheduler service, and task handlers
- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates. - Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs. - Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback. - Ensured proper error handling and edge cases in task execution scenarios. - Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
177
app/repositories/scheduled_task.py
Normal file
177
app/repositories/scheduled_task.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Repository for scheduled task operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
"""Repository for scheduled task database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the repository."""
|
||||
super().__init__(ScheduledTask, session)
|
||||
|
||||
async def get_pending_tasks(self) -> List[ScheduledTask]:
|
||||
"""Get all pending tasks that are ready to be executed."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.status == TaskStatus.PENDING,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.scheduled_at <= now,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_active_tasks(self) -> List[ScheduledTask]:
|
||||
"""Get all active tasks."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_system_tasks(
|
||||
self,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get system tasks (tasks with no user association)."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
|
||||
"""Get recurring tasks that need their next execution scheduled."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.recurrence_type != RecurrenceType.NONE,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.status == TaskStatus.COMPLETED,
|
||||
ScheduledTask.next_execution_at <= now,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_expired_tasks(self) -> List[ScheduledTask]:
|
||||
"""Get expired tasks that should be cleaned up."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.expires_at.is_not(None),
|
||||
ScheduledTask.expires_at <= now,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def cancel_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
task_type: Optional[TaskType] = None,
|
||||
) -> int:
|
||||
"""Cancel all pending/running tasks for a user."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.user_id == user_id,
|
||||
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||
)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
tasks = list(result.all())
|
||||
|
||||
count = 0
|
||||
for task in tasks:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
self.session.add(task)
|
||||
count += 1
|
||||
|
||||
await self.session.commit()
|
||||
return count
|
||||
|
||||
async def mark_as_running(self, task: ScheduledTask) -> None:
|
||||
"""Mark a task as running."""
|
||||
task.status = TaskStatus.RUNNING
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
|
||||
async def mark_as_completed(
|
||||
self,
|
||||
task: ScheduledTask,
|
||||
next_execution_at: Optional[datetime] = None,
|
||||
) -> None:
|
||||
"""Mark a task as completed and set next execution if recurring."""
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
task.executions_count += 1
|
||||
task.error_message = None
|
||||
|
||||
if next_execution_at and task.should_repeat():
|
||||
task.next_execution_at = next_execution_at
|
||||
task.status = TaskStatus.PENDING
|
||||
elif not task.should_repeat():
|
||||
task.is_active = False
|
||||
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
|
||||
async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
|
||||
"""Mark a task as failed with error message."""
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error_message = error_message
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
|
||||
# For non-recurring tasks, deactivate on failure
|
||||
if not task.is_recurring():
|
||||
task.is_active = False
|
||||
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
Reference in New Issue
Block a user