Add comprehensive tests for scheduled task repository, scheduler service, and task handlers
- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates. - Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs. - Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback. - Ensured proper error handling and edge cases in task execution scenarios. - Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
494
tests/test_scheduled_task_repository.py
Normal file
494
tests/test_scheduled_task_repository.py
Normal file
@@ -0,0 +1,494 @@
|
||||
"""Tests for scheduled task repository."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
class TestScheduledTaskRepository:
|
||||
"""Test cases for scheduled task repository."""
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, db_session: AsyncSession) -> ScheduledTaskRepository:
|
||||
"""Create repository fixture."""
|
||||
return ScheduledTaskRepository(db_session)
|
||||
|
||||
@pytest.fixture
|
||||
async def sample_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
) -> ScheduledTask:
|
||||
"""Create a sample scheduled task."""
|
||||
task = ScheduledTask(
|
||||
name="Test Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
parameters={"test": "value"},
|
||||
)
|
||||
return await repository.create(task)
|
||||
|
||||
@pytest.fixture
|
||||
async def user_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
) -> ScheduledTask:
|
||||
"""Create a user task."""
|
||||
task = ScheduledTask(
|
||||
name="User Task",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=2),
|
||||
user_id=test_user_id,
|
||||
parameters={"sound_id": str(uuid.uuid4())},
|
||||
)
|
||||
return await repository.create(task)
|
||||
|
||||
async def test_create_task(self, repository: ScheduledTaskRepository):
|
||||
"""Test creating a scheduled task."""
|
||||
task = ScheduledTask(
|
||||
name="Test Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
timezone="America/New_York",
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
parameters={"test": "value"},
|
||||
)
|
||||
|
||||
created_task = await repository.create(task)
|
||||
|
||||
assert created_task.id is not None
|
||||
assert created_task.name == "Test Task"
|
||||
assert created_task.task_type == TaskType.CREDIT_RECHARGE
|
||||
assert created_task.status == TaskStatus.PENDING
|
||||
assert created_task.timezone == "America/New_York"
|
||||
assert created_task.recurrence_type == RecurrenceType.DAILY
|
||||
assert created_task.parameters == {"test": "value"}
|
||||
assert created_task.is_active is True
|
||||
assert created_task.executions_count == 0
|
||||
|
||||
async def test_get_pending_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test getting pending tasks."""
|
||||
# Create tasks with different statuses and times
|
||||
past_pending = ScheduledTask(
|
||||
name="Past Pending",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
await repository.create(past_pending)
|
||||
|
||||
future_pending = ScheduledTask(
|
||||
name="Future Pending",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
await repository.create(future_pending)
|
||||
|
||||
completed_task = ScheduledTask(
|
||||
name="Completed",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
status=TaskStatus.COMPLETED,
|
||||
)
|
||||
await repository.create(completed_task)
|
||||
|
||||
inactive_task = ScheduledTask(
|
||||
name="Inactive",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
status=TaskStatus.PENDING,
|
||||
is_active=False,
|
||||
)
|
||||
await repository.create(inactive_task)
|
||||
|
||||
pending_tasks = await repository.get_pending_tasks()
|
||||
task_names = [task.name for task in pending_tasks]
|
||||
|
||||
# Only the past pending task should be returned
|
||||
assert len(pending_tasks) == 1
|
||||
assert "Past Pending" in task_names
|
||||
|
||||
async def test_get_user_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
user_task: ScheduledTask,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test getting tasks for a specific user."""
|
||||
# Create another user's task
|
||||
other_user_id = uuid.uuid4()
|
||||
other_task = ScheduledTask(
|
||||
name="Other User Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=other_user_id,
|
||||
)
|
||||
await repository.create(other_task)
|
||||
|
||||
# Create system task (no user)
|
||||
system_task = ScheduledTask(
|
||||
name="System Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
await repository.create(system_task)
|
||||
|
||||
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||
|
||||
assert len(user_tasks) == 1
|
||||
assert user_tasks[0].name == "User Task"
|
||||
assert user_tasks[0].user_id == test_user_id
|
||||
|
||||
async def test_get_user_tasks_with_filters(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test getting user tasks with status and type filters."""
|
||||
# Create tasks with different statuses and types
|
||||
tasks_data = [
|
||||
("Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
|
||||
("Task 2", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
|
||||
("Task 3", TaskStatus.PENDING, TaskType.PLAY_SOUND),
|
||||
("Task 4", TaskStatus.FAILED, TaskType.PLAY_PLAYLIST),
|
||||
]
|
||||
|
||||
for name, status, task_type in tasks_data:
|
||||
task = ScheduledTask(
|
||||
name=name,
|
||||
task_type=task_type,
|
||||
status=status,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(task)
|
||||
|
||||
# Test status filter
|
||||
pending_tasks = await repository.get_user_tasks(
|
||||
test_user_id,
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
assert len(pending_tasks) == 2
|
||||
assert all(task.status == TaskStatus.PENDING for task in pending_tasks)
|
||||
|
||||
# Test type filter
|
||||
credit_tasks = await repository.get_user_tasks(
|
||||
test_user_id,
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
assert len(credit_tasks) == 2
|
||||
assert all(task.task_type == TaskType.CREDIT_RECHARGE for task in credit_tasks)
|
||||
|
||||
# Test combined filters
|
||||
pending_credit_tasks = await repository.get_user_tasks(
|
||||
test_user_id,
|
||||
status=TaskStatus.PENDING,
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
assert len(pending_credit_tasks) == 1
|
||||
assert pending_credit_tasks[0].name == "Task 1"
|
||||
|
||||
async def test_get_system_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
user_task: ScheduledTask,
|
||||
):
|
||||
"""Test getting system tasks."""
|
||||
system_tasks = await repository.get_system_tasks()
|
||||
|
||||
assert len(system_tasks) == 1
|
||||
assert system_tasks[0].name == "Test Task"
|
||||
assert system_tasks[0].user_id is None
|
||||
|
||||
async def test_get_recurring_tasks_due_for_next_execution(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test getting recurring tasks due for next execution."""
|
||||
# Create completed recurring task that should be re-executed
|
||||
due_task = ScheduledTask(
|
||||
name="Due Recurring",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
status=TaskStatus.COMPLETED,
|
||||
next_execution_at=datetime.utcnow() - timedelta(minutes=5),
|
||||
)
|
||||
await repository.create(due_task)
|
||||
|
||||
# Create completed recurring task not due yet
|
||||
not_due_task = ScheduledTask(
|
||||
name="Not Due Recurring",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
status=TaskStatus.COMPLETED,
|
||||
next_execution_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
await repository.create(not_due_task)
|
||||
|
||||
# Create non-recurring completed task
|
||||
non_recurring = ScheduledTask(
|
||||
name="Non-recurring",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.NONE,
|
||||
status=TaskStatus.COMPLETED,
|
||||
)
|
||||
await repository.create(non_recurring)
|
||||
|
||||
due_tasks = await repository.get_recurring_tasks_due_for_next_execution()
|
||||
|
||||
assert len(due_tasks) == 1
|
||||
assert due_tasks[0].name == "Due Recurring"
|
||||
|
||||
async def test_get_expired_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test getting expired tasks."""
|
||||
# Create expired task
|
||||
expired_task = ScheduledTask(
|
||||
name="Expired Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||
)
|
||||
await repository.create(expired_task)
|
||||
|
||||
# Create non-expired task
|
||||
valid_task = ScheduledTask(
|
||||
name="Valid Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
expires_at=datetime.utcnow() + timedelta(hours=2),
|
||||
)
|
||||
await repository.create(valid_task)
|
||||
|
||||
# Create task with no expiry
|
||||
no_expiry_task = ScheduledTask(
|
||||
name="No Expiry",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
await repository.create(no_expiry_task)
|
||||
|
||||
expired_tasks = await repository.get_expired_tasks()
|
||||
|
||||
assert len(expired_tasks) == 1
|
||||
assert expired_tasks[0].name == "Expired Task"
|
||||
|
||||
async def test_cancel_user_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test cancelling user tasks."""
|
||||
# Create multiple user tasks
|
||||
tasks_data = [
|
||||
("Pending Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
|
||||
("Running Task", TaskStatus.RUNNING, TaskType.PLAY_SOUND),
|
||||
("Completed Task", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
|
||||
]
|
||||
|
||||
for name, status, task_type in tasks_data:
|
||||
task = ScheduledTask(
|
||||
name=name,
|
||||
task_type=task_type,
|
||||
status=status,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(task)
|
||||
|
||||
# Cancel all user tasks
|
||||
cancelled_count = await repository.cancel_user_tasks(test_user_id)
|
||||
|
||||
assert cancelled_count == 2 # Only pending and running tasks
|
||||
|
||||
# Verify tasks are cancelled
|
||||
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||
pending_or_running = [
|
||||
task for task in user_tasks
|
||||
if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING]
|
||||
]
|
||||
cancelled_tasks = [
|
||||
task for task in user_tasks
|
||||
if task.status == TaskStatus.CANCELLED
|
||||
]
|
||||
|
||||
assert len(pending_or_running) == 0
|
||||
assert len(cancelled_tasks) == 2
|
||||
|
||||
async def test_cancel_user_tasks_by_type(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test cancelling user tasks by type."""
|
||||
# Create tasks of different types
|
||||
credit_task = ScheduledTask(
|
||||
name="Credit Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(credit_task)
|
||||
|
||||
sound_task = ScheduledTask(
|
||||
name="Sound Task",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(sound_task)
|
||||
|
||||
# Cancel only credit tasks
|
||||
cancelled_count = await repository.cancel_user_tasks(
|
||||
test_user_id,
|
||||
TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
|
||||
assert cancelled_count == 1
|
||||
|
||||
# Verify only credit task is cancelled
|
||||
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||
credit_tasks = [
|
||||
task for task in user_tasks
|
||||
if task.task_type == TaskType.CREDIT_RECHARGE
|
||||
]
|
||||
sound_tasks = [
|
||||
task for task in user_tasks
|
||||
if task.task_type == TaskType.PLAY_SOUND
|
||||
]
|
||||
|
||||
assert len(credit_tasks) == 1
|
||||
assert credit_tasks[0].status == TaskStatus.CANCELLED
|
||||
assert len(sound_tasks) == 1
|
||||
assert sound_tasks[0].status == TaskStatus.PENDING
|
||||
|
||||
async def test_mark_as_running(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking task as running."""
|
||||
await repository.mark_as_running(sample_task)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.RUNNING
|
||||
|
||||
async def test_mark_as_completed(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking task as completed."""
|
||||
initial_count = sample_task.executions_count
|
||||
next_execution = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
await repository.mark_as_completed(sample_task, next_execution)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
assert updated_task.executions_count == initial_count + 1
|
||||
assert updated_task.last_executed_at is not None
|
||||
assert updated_task.error_message is None
|
||||
|
||||
async def test_mark_as_completed_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test marking recurring task as completed."""
|
||||
task = ScheduledTask(
|
||||
name="Recurring Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
)
|
||||
created_task = await repository.create(task)
|
||||
|
||||
next_execution = datetime.utcnow() + timedelta(days=1)
|
||||
await repository.mark_as_completed(created_task, next_execution)
|
||||
|
||||
updated_task = await repository.get_by_id(created_task.id)
|
||||
# Should be set back to pending for next execution
|
||||
assert updated_task.status == TaskStatus.PENDING
|
||||
assert updated_task.next_execution_at == next_execution
|
||||
assert updated_task.is_active is True
|
||||
|
||||
async def test_mark_as_completed_non_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking non-recurring task as completed."""
|
||||
await repository.mark_as_completed(sample_task, None)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
assert updated_task.is_active is False
|
||||
|
||||
async def test_mark_as_failed(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking task as failed."""
|
||||
error_message = "Task execution failed"
|
||||
|
||||
await repository.mark_as_failed(sample_task, error_message)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
assert updated_task.error_message == error_message
|
||||
assert updated_task.last_executed_at is not None
|
||||
|
||||
async def test_mark_as_failed_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test marking recurring task as failed."""
|
||||
task = ScheduledTask(
|
||||
name="Recurring Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
)
|
||||
created_task = await repository.create(task)
|
||||
|
||||
await repository.mark_as_failed(created_task, "Failed")
|
||||
|
||||
updated_task = await repository.get_by_id(created_task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
# Recurring tasks should remain active even after failure
|
||||
assert updated_task.is_active is True
|
||||
|
||||
async def test_mark_as_failed_non_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking non-recurring task as failed."""
|
||||
await repository.mark_as_failed(sample_task, "Failed")
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
# Non-recurring tasks should be deactivated on failure
|
||||
assert updated_task.is_active is False
|
||||
Reference in New Issue
Block a user