Add comprehensive tests for scheduled task repository, scheduler service, and task handlers

- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates.
- Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs.
- Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback.
- Ensured proper error handling and edge cases in task execution scenarios.
- Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
JSC
2025-08-28 22:37:43 +02:00
parent 7dee6e320e
commit 03abed6d39
23 changed files with 3415 additions and 103 deletions

View File

@@ -0,0 +1,177 @@
"""Repository for scheduled task operations."""
from datetime import datetime
from typing import List, Optional
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
from app.repositories.base import BaseRepository
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
"""Repository for scheduled task database operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the repository."""
super().__init__(ScheduledTask, session)
async def get_pending_tasks(self) -> List[ScheduledTask]:
"""Get all pending tasks that are ready to be executed."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
ScheduledTask.status == TaskStatus.PENDING,
ScheduledTask.is_active.is_(True),
ScheduledTask.scheduled_at <= now,
)
result = await self.session.exec(statement)
return list(result.all())
async def get_active_tasks(self) -> List[ScheduledTask]:
"""Get all active tasks."""
statement = select(ScheduledTask).where(
ScheduledTask.is_active.is_(True),
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
)
result = await self.session.exec(statement)
return list(result.all())
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get tasks for a specific user."""
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await self.session.exec(statement)
return list(result.all())
async def get_system_tasks(
self,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
) -> List[ScheduledTask]:
"""Get system tasks (tasks with no user association)."""
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
result = await self.session.exec(statement)
return list(result.all())
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
"""Get recurring tasks that need their next execution scheduled."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
ScheduledTask.recurrence_type != RecurrenceType.NONE,
ScheduledTask.is_active.is_(True),
ScheduledTask.status == TaskStatus.COMPLETED,
ScheduledTask.next_execution_at <= now,
)
result = await self.session.exec(statement)
return list(result.all())
async def get_expired_tasks(self) -> List[ScheduledTask]:
"""Get expired tasks that should be cleaned up."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
ScheduledTask.expires_at.is_not(None),
ScheduledTask.expires_at <= now,
ScheduledTask.is_active.is_(True),
)
result = await self.session.exec(statement)
return list(result.all())
async def cancel_user_tasks(
self,
user_id: int,
task_type: Optional[TaskType] = None,
) -> int:
"""Cancel all pending/running tasks for a user."""
statement = select(ScheduledTask).where(
ScheduledTask.user_id == user_id,
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
result = await self.session.exec(statement)
tasks = list(result.all())
count = 0
for task in tasks:
task.status = TaskStatus.CANCELLED
task.is_active = False
self.session.add(task)
count += 1
await self.session.commit()
return count
async def mark_as_running(self, task: ScheduledTask) -> None:
"""Mark a task as running."""
task.status = TaskStatus.RUNNING
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)
async def mark_as_completed(
self,
task: ScheduledTask,
next_execution_at: Optional[datetime] = None,
) -> None:
"""Mark a task as completed and set next execution if recurring."""
task.status = TaskStatus.COMPLETED
task.last_executed_at = datetime.utcnow()
task.executions_count += 1
task.error_message = None
if next_execution_at and task.should_repeat():
task.next_execution_at = next_execution_at
task.status = TaskStatus.PENDING
elif not task.should_repeat():
task.is_active = False
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)
async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
"""Mark a task as failed with error message."""
task.status = TaskStatus.FAILED
task.error_message = error_message
task.last_executed_at = datetime.utcnow()
# For non-recurring tasks, deactivate on failure
if not task.is_recurring():
task.is_active = False
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)