# File: sdb2-backend/tests/test_scheduled_task_repository.py
# Size: 493 lines, 17 KiB (Python)

"""Tests for scheduled task repository."""
from datetime import UTC, datetime, timedelta
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.scheduled_task import ScheduledTaskRepository
class TestScheduledTaskRepository:
"""Test cases for scheduled task repository."""
@pytest.fixture
def repository(self, db_session: AsyncSession) -> ScheduledTaskRepository:
    """Build a ScheduledTaskRepository around the ``db_session`` fixture."""
    repo = ScheduledTaskRepository(db_session)
    return repo
@pytest.fixture
async def sample_task(
    self,
    repository: ScheduledTaskRepository,
) -> ScheduledTask:
    """Create, through the repository, a credit-recharge task scheduled one hour ahead.

    No ``user_id`` is supplied in the creation payload.
    """
    run_at = datetime.now(tz=UTC) + timedelta(hours=1)
    payload = dict(
        name="Test Task",
        task_type=TaskType.CREDIT_RECHARGE,
        scheduled_at=run_at,
        parameters={"test": "value"},
    )
    created = await repository.create(payload)
    return created
@pytest.fixture
async def user_task(
    self,
    repository: ScheduledTaskRepository,
    test_user_id: int,
) -> ScheduledTask:
    """Create, through the repository, a play-sound task owned by ``test_user_id``.

    The task is scheduled two hours ahead of the current UTC time.
    """
    run_at = datetime.now(tz=UTC) + timedelta(hours=2)
    payload = dict(
        name="User Task",
        task_type=TaskType.PLAY_SOUND,
        scheduled_at=run_at,
        user_id=test_user_id,
        parameters={"sound_id": "1"},
    )
    created = await repository.create(payload)
    return created
async def test_create_task(self, repository: ScheduledTaskRepository):
    """Check that ``create`` returns a task carrying the given fields plus defaults."""
    payload = dict(
        name="Test Task",
        task_type=TaskType.CREDIT_RECHARGE,
        scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
        timezone="America/New_York",
        recurrence_type=RecurrenceType.DAILY,
        parameters={"test": "value"},
    )

    stored = await repository.create(payload)

    # An identifier must have been assigned.
    assert stored.id is not None

    # Fields that are expected to compare equal to a known value, in the
    # same order the checks were originally written.
    expected_fields = (
        ("name", "Test Task"),
        ("task_type", TaskType.CREDIT_RECHARGE),
        ("status", TaskStatus.PENDING),
        ("timezone", "America/New_York"),
        ("recurrence_type", RecurrenceType.DAILY),
        ("parameters", {"test": "value"}),
    )
    for field_name, wanted in expected_fields:
        assert getattr(stored, field_name) == wanted

    # Values not present in the payload.
    assert stored.is_active is True
    assert stored.executions_count == 0
async def test_get_pending_tasks(
    self,
    repository: ScheduledTaskRepository,
):
    """Check that ``get_pending_tasks`` returns only the past-due, active, pending task."""
    # (name, offset from now, status, extra constructor kwargs)
    seed_rows = (
        ("Past Pending", timedelta(hours=-1), TaskStatus.PENDING, {}),
        ("Future Pending", timedelta(hours=1), TaskStatus.PENDING, {}),
        ("Completed", timedelta(hours=-1), TaskStatus.COMPLETED, {}),
        ("Inactive", timedelta(hours=-1), TaskStatus.PENDING, {"is_active": False}),
    )
    for row_name, offset, row_status, extra in seed_rows:
        # The clock is read per task, right before it is created.
        seeded = ScheduledTask(
            name=row_name,
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC) + offset,
            status=row_status,
            **extra,
        )
        await repository.create(seeded)

    found = await repository.get_pending_tasks()

    # Exactly one result, and it is the past pending task.
    assert [item.name for item in found] == ["Past Pending"]
async def test_get_user_tasks(
    self,
    repository: ScheduledTaskRepository,
    user_task: ScheduledTask,
    test_user_id: int,
):
    """Check that ``get_user_tasks`` returns only the task owned by ``test_user_id``."""
    # Decoy 1: a task that belongs to a different user id.
    foreign_owner = 999
    await repository.create(
        ScheduledTask(
            name="Other User Task",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
            user_id=foreign_owner,
        ),
    )

    # Decoy 2: a task created without any user id.
    await repository.create(
        ScheduledTask(
            name="System Task",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
        ),
    )

    found = await repository.get_user_tasks(test_user_id)

    assert len(found) == 1
    only_match = found[0]
    assert only_match.name == "User Task"
    assert only_match.user_id == test_user_id
async def test_get_user_tasks_with_filters(
    self,
    repository: ScheduledTaskRepository,
    test_user_id: int,
):
    """Check the ``status`` and ``task_type`` filters of ``get_user_tasks``, alone and combined."""
    # (name, status, task_type) of four tasks, all owned by the same user.
    seed_rows = (
        ("Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
        ("Task 2", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
        ("Task 3", TaskStatus.PENDING, TaskType.PLAY_SOUND),
        ("Task 4", TaskStatus.FAILED, TaskType.PLAY_PLAYLIST),
    )
    for row_name, row_status, row_type in seed_rows:
        await repository.create(
            ScheduledTask(
                name=row_name,
                task_type=row_type,
                status=row_status,
                scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
                user_id=test_user_id,
            ),
        )

    # Status filter only: "Task 1" and "Task 3" are the pending rows.
    by_status = await repository.get_user_tasks(
        test_user_id,
        status=TaskStatus.PENDING,
    )
    assert len(by_status) == 2
    for match in by_status:
        assert match.status == TaskStatus.PENDING

    # Type filter only: "Task 1" and "Task 2" are the credit-recharge rows.
    by_type = await repository.get_user_tasks(
        test_user_id,
        task_type=TaskType.CREDIT_RECHARGE,
    )
    assert len(by_type) == 2
    for match in by_type:
        assert match.task_type == TaskType.CREDIT_RECHARGE

    # Both filters together leave a single row.
    by_both = await repository.get_user_tasks(
        test_user_id,
        status=TaskStatus.PENDING,
        task_type=TaskType.CREDIT_RECHARGE,
    )
    assert len(by_both) == 1
    assert by_both[0].name == "Task 1"
async def test_get_system_tasks(
    self,
    repository: ScheduledTaskRepository,
    sample_task: ScheduledTask,
    user_task: ScheduledTask,
):
    """Check that ``get_system_tasks`` returns the user-less task and not the user's task.

    Both fixtures are requested so that one task of each kind exists
    before the query runs.
    """
    found = await repository.get_system_tasks()

    assert len(found) == 1
    only_match = found[0]
    assert only_match.name == "Test Task"
    assert only_match.user_id is None
async def test_get_recurring_tasks_due_for_next_execution(
    self,
    repository: ScheduledTaskRepository,
):
    """Check that only the recurring task whose next execution is in the past is returned."""
    # (name, recurrence, offset of next_execution_at from now or None to omit it)
    seed_rows = (
        # Completed, daily, next run five minutes ago: should be picked up.
        ("Due Recurring", RecurrenceType.DAILY, timedelta(minutes=-5)),
        # Completed, daily, next run an hour from now: not due yet.
        ("Not Due Recurring", RecurrenceType.DAILY, timedelta(hours=1)),
        # Completed, no recurrence, no next_execution_at given.
        ("Non-recurring", RecurrenceType.NONE, None),
    )
    for row_name, recurrence, next_offset in seed_rows:
        fields = {
            "name": row_name,
            "task_type": TaskType.CREDIT_RECHARGE,
            "scheduled_at": datetime.now(tz=UTC) - timedelta(hours=1),
            "recurrence_type": recurrence,
            "status": TaskStatus.COMPLETED,
        }
        if next_offset is not None:
            fields["next_execution_at"] = datetime.now(tz=UTC) + next_offset
        await repository.create(ScheduledTask(**fields))

    found = await repository.get_recurring_tasks_due_for_next_execution()

    assert len(found) == 1
    assert found[0].name == "Due Recurring"
async def test_get_expired_tasks(
    self,
    repository: ScheduledTaskRepository,
):
    """Check that ``get_expired_tasks`` returns only the task whose ``expires_at`` has passed."""
    # (name, offset of expires_at from now or None to omit the field)
    seed_rows = (
        ("Expired Task", timedelta(hours=-1)),
        ("Valid Task", timedelta(hours=2)),
        ("No Expiry", None),
    )
    for row_name, expiry_offset in seed_rows:
        fields = {
            "name": row_name,
            "task_type": TaskType.CREDIT_RECHARGE,
            "scheduled_at": datetime.now(tz=UTC) + timedelta(hours=1),
        }
        if expiry_offset is not None:
            fields["expires_at"] = datetime.now(tz=UTC) + expiry_offset
        await repository.create(ScheduledTask(**fields))

    found = await repository.get_expired_tasks()

    assert len(found) == 1
    assert found[0].name == "Expired Task"
async def test_cancel_user_tasks(
    self,
    repository: ScheduledTaskRepository,
    test_user_id: int,
):
    """Check that ``cancel_user_tasks`` cancels the pending and running tasks of a user."""
    # (name, status, task_type): one pending, one running, one completed.
    seed_rows = (
        ("Pending Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
        ("Running Task", TaskStatus.RUNNING, TaskType.PLAY_SOUND),
        ("Completed Task", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
    )
    for row_name, row_status, row_type in seed_rows:
        await repository.create(
            ScheduledTask(
                name=row_name,
                task_type=row_type,
                status=row_status,
                scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
                user_id=test_user_id,
            ),
        )

    # No task type given, so the call is not restricted to one type.
    affected = await repository.cancel_user_tasks(test_user_id)
    # The pending and the running task are expected to count.
    assert affected == 2

    # Re-read the user's tasks and bucket them by resulting status.
    after = await repository.get_user_tasks(test_user_id)
    still_open = []
    now_cancelled = []
    for item in after:
        if item.status == TaskStatus.PENDING or item.status == TaskStatus.RUNNING:
            still_open.append(item)
        if item.status == TaskStatus.CANCELLED:
            now_cancelled.append(item)

    assert not still_open
    assert len(now_cancelled) == 2
async def test_cancel_user_tasks_by_type(
    self,
    repository: ScheduledTaskRepository,
    test_user_id: int,
):
    """Check that passing a task type to ``cancel_user_tasks`` leaves other types untouched."""
    # One task per type, both owned by the same user; status is left at
    # whatever the model assigns when none is given.
    seed_rows = (
        ("Credit Task", TaskType.CREDIT_RECHARGE),
        ("Sound Task", TaskType.PLAY_SOUND),
    )
    for row_name, row_type in seed_rows:
        await repository.create(
            ScheduledTask(
                name=row_name,
                task_type=row_type,
                scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
                user_id=test_user_id,
            ),
        )

    # Restrict the cancellation to credit-recharge tasks.
    affected = await repository.cancel_user_tasks(
        test_user_id,
        TaskType.CREDIT_RECHARGE,
    )
    assert affected == 1

    after = await repository.get_user_tasks(test_user_id)

    def of_type(wanted):
        # Tasks from the re-read list whose type equals ``wanted``.
        return [item for item in after if item.task_type == wanted]

    credit_side = of_type(TaskType.CREDIT_RECHARGE)
    sound_side = of_type(TaskType.PLAY_SOUND)

    # The credit task was cancelled ...
    assert len(credit_side) == 1
    assert credit_side[0].status == TaskStatus.CANCELLED
    # ... while the sound task is still pending.
    assert len(sound_side) == 1
    assert sound_side[0].status == TaskStatus.PENDING
async def test_mark_as_running(
    self,
    repository: ScheduledTaskRepository,
    sample_task: ScheduledTask,
):
    """Check that ``mark_as_running`` leaves the task with status RUNNING."""
    await repository.mark_as_running(sample_task)

    # Look the task up again by id rather than inspecting the fixture object.
    task_id = sample_task.id
    reloaded = await repository.get_by_id(task_id)
    assert reloaded.status == TaskStatus.RUNNING
async def test_mark_as_completed(
    self,
    repository: ScheduledTaskRepository,
    sample_task: ScheduledTask,
):
    """Check status, execution counter, timestamp and error field after ``mark_as_completed``."""
    # Capture the counter before the call so the increment can be verified.
    count_before = sample_task.executions_count
    tomorrow = datetime.now(tz=UTC) + timedelta(days=1)

    await repository.mark_as_completed(sample_task, tomorrow)

    reloaded = await repository.get_by_id(sample_task.id)
    assert reloaded.status == TaskStatus.COMPLETED
    assert reloaded.executions_count == count_before + 1
    assert reloaded.last_executed_at is not None
    assert reloaded.error_message is None
async def test_mark_as_completed_recurring_task(
    self,
    repository: ScheduledTaskRepository,
):
    """Check that completing a daily task returns it to PENDING with the next run stored."""
    daily = ScheduledTask(
        name="Recurring Task",
        task_type=TaskType.CREDIT_RECHARGE,
        scheduled_at=datetime.now(tz=UTC),
        recurrence_type=RecurrenceType.DAILY,
    )
    stored = await repository.create(daily)

    # The expected next run is deliberately made naive (tzinfo stripped).
    # NOTE(review): presumably the value read back has no tzinfo - confirm.
    naive_now = datetime.now(tz=UTC).replace(tzinfo=None)
    next_run = naive_now + timedelta(days=1)

    await repository.mark_as_completed(stored, next_run)

    reloaded = await repository.get_by_id(stored.id)
    # Status goes back to pending so the task can run again.
    assert reloaded.status == TaskStatus.PENDING
    assert reloaded.next_execution_at == next_run
    assert reloaded.is_active is True
async def test_mark_as_completed_non_recurring_task(
    self,
    repository: ScheduledTaskRepository,
    sample_task: ScheduledTask,
):
    """Check that completing the sample task with no next run marks it COMPLETED and inactive."""
    no_next_run = None
    await repository.mark_as_completed(sample_task, no_next_run)

    reloaded = await repository.get_by_id(sample_task.id)
    assert reloaded.status == TaskStatus.COMPLETED
    assert reloaded.is_active is False
async def test_mark_as_failed(
    self,
    repository: ScheduledTaskRepository,
    sample_task: ScheduledTask,
):
    """Check status, stored error text and execution timestamp after ``mark_as_failed``."""
    reason = "Task execution failed"

    await repository.mark_as_failed(sample_task, reason)

    reloaded = await repository.get_by_id(sample_task.id)
    assert reloaded.status == TaskStatus.FAILED
    assert reloaded.error_message == reason
    assert reloaded.last_executed_at is not None
async def test_mark_as_failed_recurring_task(
    self,
    repository: ScheduledTaskRepository,
):
    """Check that a daily task marked as failed is FAILED but still active."""
    stored = await repository.create(
        ScheduledTask(
            name="Recurring Task",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC),
            recurrence_type=RecurrenceType.DAILY,
        ),
    )

    await repository.mark_as_failed(stored, "Failed")

    reloaded = await repository.get_by_id(stored.id)
    assert reloaded.status == TaskStatus.FAILED
    # A failure must not deactivate a recurring task.
    assert reloaded.is_active is True
async def test_mark_as_failed_non_recurring_task(
    self,
    repository: ScheduledTaskRepository,
    sample_task: ScheduledTask,
):
    """Check that the sample task marked as failed is FAILED and no longer active."""
    await repository.mark_as_failed(sample_task, "Failed")

    task_id = sample_task.id
    reloaded = await repository.get_by_id(task_id)
    assert reloaded.status == TaskStatus.FAILED
    # A failure deactivates a task that has no recurrence.
    assert reloaded.is_active is False