- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase.
- Refactored import statements for better organization and clarity.
- Enhanced the ScheduledTaskBase schema to use modern type hints.
- Cleaned up unnecessary comments and whitespace in various files.
- Improved error handling and logging in task execution handlers.
- Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
494 lines
17 KiB
Python
494 lines
17 KiB
Python
"""Tests for scheduled task repository."""
|
|
|
|
import uuid
from datetime import datetime, timedelta, timezone

import pytest
from sqlmodel.ext.asyncio.session import AsyncSession

from app.models.scheduled_task import (
    RecurrenceType,
    ScheduledTask,
    TaskStatus,
    TaskType,
)
from app.repositories.scheduled_task import ScheduledTaskRepository
|
|
|
|
|
|
class TestScheduledTaskRepository:
    """Test cases for scheduled task repository."""

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive datetime.

        Drop-in replacement for ``datetime.utcnow()``, which is deprecated
        since Python 3.12. The returned value is the same naive UTC
        timestamp, so it stays comparable with the other naive datetimes
        used throughout these tests.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @pytest.fixture
    def repository(self, db_session: AsyncSession) -> ScheduledTaskRepository:
        """Build the repository under test on top of the session fixture."""
        return ScheduledTaskRepository(db_session)

    @pytest.fixture
    async def sample_task(
        self,
        repository: ScheduledTaskRepository,
    ) -> ScheduledTask:
        """Persist a task without a user, scheduled one hour from now."""
        return await repository.create(
            ScheduledTask(
                name="Test Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=self._utc_now() + timedelta(hours=1),
                parameters={"test": "value"},
            )
        )

    @pytest.fixture
    async def user_task(
        self,
        repository: ScheduledTaskRepository,
        test_user_id: uuid.UUID,
    ) -> ScheduledTask:
        """Persist a task owned by ``test_user_id``, scheduled in two hours."""
        return await repository.create(
            ScheduledTask(
                name="User Task",
                task_type=TaskType.PLAY_SOUND,
                scheduled_at=self._utc_now() + timedelta(hours=2),
                user_id=test_user_id,
                parameters={"sound_id": str(uuid.uuid4())},
            )
        )

    async def test_create_task(self, repository: ScheduledTaskRepository) -> None:
        """A created task keeps its fields and gets the expected defaults."""
        created_task = await repository.create(
            ScheduledTask(
                name="Test Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=self._utc_now() + timedelta(hours=1),
                timezone="America/New_York",
                recurrence_type=RecurrenceType.DAILY,
                parameters={"test": "value"},
            )
        )

        # Explicitly supplied values are stored as given.
        assert created_task.id is not None
        assert created_task.name == "Test Task"
        assert created_task.task_type == TaskType.CREDIT_RECHARGE
        assert created_task.timezone == "America/New_York"
        assert created_task.recurrence_type == RecurrenceType.DAILY
        assert created_task.parameters == {"test": "value"}

        # Values that were not supplied fall back to the expected defaults.
        assert created_task.status == TaskStatus.PENDING
        assert created_task.is_active is True
        assert created_task.executions_count == 0

    async def test_get_pending_tasks(
        self,
        repository: ScheduledTaskRepository,
    ) -> None:
        """Only active, pending tasks scheduled in the past are returned."""
        now = self._utc_now()
        one_hour = timedelta(hours=1)

        # The single task that should match.
        await repository.create(
            ScheduledTask(
                name="Past Pending",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=now - one_hour,
                status=TaskStatus.PENDING,
            )
        )
        # Excluded: scheduled in the future.
        await repository.create(
            ScheduledTask(
                name="Future Pending",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=now + one_hour,
                status=TaskStatus.PENDING,
            )
        )
        # Excluded: status is not pending.
        await repository.create(
            ScheduledTask(
                name="Completed",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=now - one_hour,
                status=TaskStatus.COMPLETED,
            )
        )
        # Excluded: task is inactive.
        await repository.create(
            ScheduledTask(
                name="Inactive",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=now - one_hour,
                status=TaskStatus.PENDING,
                is_active=False,
            )
        )

        pending_tasks = await repository.get_pending_tasks()
        task_names = [task.name for task in pending_tasks]

        assert len(pending_tasks) == 1
        assert "Past Pending" in task_names

    async def test_get_user_tasks(
        self,
        repository: ScheduledTaskRepository,
        user_task: ScheduledTask,
        test_user_id: uuid.UUID,
    ) -> None:
        """Tasks of other users and user-less tasks are not returned."""
        in_one_hour = self._utc_now() + timedelta(hours=1)

        # A task owned by a different, random user.
        await repository.create(
            ScheduledTask(
                name="Other User Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=in_one_hour,
                user_id=uuid.uuid4(),
            )
        )
        # A system task, i.e. one created without any user_id.
        await repository.create(
            ScheduledTask(
                name="System Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=in_one_hour,
            )
        )

        user_tasks = await repository.get_user_tasks(test_user_id)

        assert len(user_tasks) == 1
        assert user_tasks[0].name == "User Task"
        assert user_tasks[0].user_id == test_user_id

    async def test_get_user_tasks_with_filters(
        self,
        repository: ScheduledTaskRepository,
        test_user_id: uuid.UUID,
    ) -> None:
        """Status and type filters work on their own and combined."""
        in_one_hour = self._utc_now() + timedelta(hours=1)
        tasks_data = [
            ("Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
            ("Task 2", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
            ("Task 3", TaskStatus.PENDING, TaskType.PLAY_SOUND),
            ("Task 4", TaskStatus.FAILED, TaskType.PLAY_PLAYLIST),
        ]
        for name, status, task_type in tasks_data:
            await repository.create(
                ScheduledTask(
                    name=name,
                    task_type=task_type,
                    status=status,
                    scheduled_at=in_one_hour,
                    user_id=test_user_id,
                )
            )

        # Filter by status only: "Task 1" and "Task 3".
        pending_tasks = await repository.get_user_tasks(
            test_user_id,
            status=TaskStatus.PENDING,
        )
        assert len(pending_tasks) == 2
        assert all(t.status == TaskStatus.PENDING for t in pending_tasks)

        # Filter by type only: "Task 1" and "Task 2".
        credit_tasks = await repository.get_user_tasks(
            test_user_id,
            task_type=TaskType.CREDIT_RECHARGE,
        )
        assert len(credit_tasks) == 2
        assert all(t.task_type == TaskType.CREDIT_RECHARGE for t in credit_tasks)

        # Both filters together leave only "Task 1".
        pending_credit_tasks = await repository.get_user_tasks(
            test_user_id,
            status=TaskStatus.PENDING,
            task_type=TaskType.CREDIT_RECHARGE,
        )
        assert len(pending_credit_tasks) == 1
        assert pending_credit_tasks[0].name == "Task 1"

    async def test_get_system_tasks(
        self,
        repository: ScheduledTaskRepository,
        sample_task: ScheduledTask,
        user_task: ScheduledTask,
    ) -> None:
        """Only the task without a user is reported as a system task."""
        system_tasks = await repository.get_system_tasks()

        assert len(system_tasks) == 1
        assert system_tasks[0].name == "Test Task"
        assert system_tasks[0].user_id is None

    async def test_get_recurring_tasks_due_for_next_execution(
        self,
        repository: ScheduledTaskRepository,
    ) -> None:
        """Only recurring tasks whose next execution time has passed match."""
        now = self._utc_now()
        an_hour_ago = now - timedelta(hours=1)

        # Completed daily task whose next run was due five minutes ago.
        await repository.create(
            ScheduledTask(
                name="Due Recurring",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=an_hour_ago,
                recurrence_type=RecurrenceType.DAILY,
                status=TaskStatus.COMPLETED,
                next_execution_at=now - timedelta(minutes=5),
            )
        )
        # Completed daily task whose next run is still an hour away.
        await repository.create(
            ScheduledTask(
                name="Not Due Recurring",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=an_hour_ago,
                recurrence_type=RecurrenceType.DAILY,
                status=TaskStatus.COMPLETED,
                next_execution_at=now + timedelta(hours=1),
            )
        )
        # Completed one-off task with no recurrence at all.
        await repository.create(
            ScheduledTask(
                name="Non-recurring",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=an_hour_ago,
                recurrence_type=RecurrenceType.NONE,
                status=TaskStatus.COMPLETED,
            )
        )

        due_tasks = await repository.get_recurring_tasks_due_for_next_execution()

        assert len(due_tasks) == 1
        assert due_tasks[0].name == "Due Recurring"

    async def test_get_expired_tasks(
        self,
        repository: ScheduledTaskRepository,
    ) -> None:
        """Only tasks whose expiry time lies in the past are returned."""
        now = self._utc_now()
        in_one_hour = now + timedelta(hours=1)

        # Expired an hour ago.
        await repository.create(
            ScheduledTask(
                name="Expired Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=in_one_hour,
                expires_at=now - timedelta(hours=1),
            )
        )
        # Expires two hours from now.
        await repository.create(
            ScheduledTask(
                name="Valid Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=in_one_hour,
                expires_at=now + timedelta(hours=2),
            )
        )
        # No expiry set.
        await repository.create(
            ScheduledTask(
                name="No Expiry",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=in_one_hour,
            )
        )

        expired_tasks = await repository.get_expired_tasks()

        assert len(expired_tasks) == 1
        assert expired_tasks[0].name == "Expired Task"

    async def test_cancel_user_tasks(
        self,
        repository: ScheduledTaskRepository,
        test_user_id: uuid.UUID,
    ) -> None:
        """Cancelling affects pending and running tasks, not completed ones."""
        in_one_hour = self._utc_now() + timedelta(hours=1)
        tasks_data = [
            ("Pending Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
            ("Running Task", TaskStatus.RUNNING, TaskType.PLAY_SOUND),
            ("Completed Task", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
        ]
        for name, status, task_type in tasks_data:
            await repository.create(
                ScheduledTask(
                    name=name,
                    task_type=task_type,
                    status=status,
                    scheduled_at=in_one_hour,
                    user_id=test_user_id,
                )
            )

        cancelled_count = await repository.cancel_user_tasks(test_user_id)

        # The completed task is not counted.
        assert cancelled_count == 2

        # Re-read the user's tasks and bucket them by resulting status.
        user_tasks = await repository.get_user_tasks(test_user_id)
        still_open = [
            task
            for task in user_tasks
            if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING]
        ]
        cancelled_tasks = [
            task for task in user_tasks if task.status == TaskStatus.CANCELLED
        ]

        assert len(still_open) == 0
        assert len(cancelled_tasks) == 2

    async def test_cancel_user_tasks_by_type(
        self,
        repository: ScheduledTaskRepository,
        test_user_id: uuid.UUID,
    ) -> None:
        """Cancelling with a task type leaves tasks of other types alone."""
        in_one_hour = self._utc_now() + timedelta(hours=1)
        await repository.create(
            ScheduledTask(
                name="Credit Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=in_one_hour,
                user_id=test_user_id,
            )
        )
        await repository.create(
            ScheduledTask(
                name="Sound Task",
                task_type=TaskType.PLAY_SOUND,
                scheduled_at=in_one_hour,
                user_id=test_user_id,
            )
        )

        cancelled_count = await repository.cancel_user_tasks(
            test_user_id,
            TaskType.CREDIT_RECHARGE,
        )

        assert cancelled_count == 1

        # Split the user's tasks by type to check each one separately.
        user_tasks = await repository.get_user_tasks(test_user_id)
        credit_tasks = [
            task for task in user_tasks if task.task_type == TaskType.CREDIT_RECHARGE
        ]
        sound_tasks = [
            task for task in user_tasks if task.task_type == TaskType.PLAY_SOUND
        ]

        assert len(credit_tasks) == 1
        assert credit_tasks[0].status == TaskStatus.CANCELLED
        assert len(sound_tasks) == 1
        assert sound_tasks[0].status == TaskStatus.PENDING

    async def test_mark_as_running(
        self,
        repository: ScheduledTaskRepository,
        sample_task: ScheduledTask,
    ) -> None:
        """Marking a task as running is visible when it is read back."""
        await repository.mark_as_running(sample_task)

        updated_task = await repository.get_by_id(sample_task.id)
        # Fail with a clear assertion rather than an AttributeError below.
        assert updated_task is not None
        assert updated_task.status == TaskStatus.RUNNING

    async def test_mark_as_completed(
        self,
        repository: ScheduledTaskRepository,
        sample_task: ScheduledTask,
    ) -> None:
        """Completion bumps the execution count and clears the error."""
        # Capture the count before the call so the increment can be checked.
        initial_count = sample_task.executions_count
        next_execution = self._utc_now() + timedelta(days=1)

        await repository.mark_as_completed(sample_task, next_execution)

        updated_task = await repository.get_by_id(sample_task.id)
        # Fail with a clear assertion rather than an AttributeError below.
        assert updated_task is not None
        assert updated_task.status == TaskStatus.COMPLETED
        assert updated_task.executions_count == initial_count + 1
        assert updated_task.last_executed_at is not None
        assert updated_task.error_message is None

    async def test_mark_as_completed_recurring_task(
        self,
        repository: ScheduledTaskRepository,
    ) -> None:
        """A completed daily task goes back to pending for its next run."""
        created_task = await repository.create(
            ScheduledTask(
                name="Recurring Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=self._utc_now(),
                recurrence_type=RecurrenceType.DAILY,
            )
        )

        next_execution = self._utc_now() + timedelta(days=1)
        await repository.mark_as_completed(created_task, next_execution)

        updated_task = await repository.get_by_id(created_task.id)
        # Fail with a clear assertion rather than an AttributeError below.
        assert updated_task is not None
        # Expected to be pending again, waiting for the next execution.
        assert updated_task.status == TaskStatus.PENDING
        assert updated_task.next_execution_at == next_execution
        assert updated_task.is_active is True

    async def test_mark_as_completed_non_recurring_task(
        self,
        repository: ScheduledTaskRepository,
        sample_task: ScheduledTask,
    ) -> None:
        """A completed one-off task ends up completed and inactive."""
        await repository.mark_as_completed(sample_task, None)

        updated_task = await repository.get_by_id(sample_task.id)
        # Fail with a clear assertion rather than an AttributeError below.
        assert updated_task is not None
        assert updated_task.status == TaskStatus.COMPLETED
        assert updated_task.is_active is False

    async def test_mark_as_failed(
        self,
        repository: ScheduledTaskRepository,
        sample_task: ScheduledTask,
    ) -> None:
        """Failure stores the error message and the execution timestamp."""
        error_message = "Task execution failed"

        await repository.mark_as_failed(sample_task, error_message)

        updated_task = await repository.get_by_id(sample_task.id)
        # Fail with a clear assertion rather than an AttributeError below.
        assert updated_task is not None
        assert updated_task.status == TaskStatus.FAILED
        assert updated_task.error_message == error_message
        assert updated_task.last_executed_at is not None

    async def test_mark_as_failed_recurring_task(
        self,
        repository: ScheduledTaskRepository,
    ) -> None:
        """A failed daily task is marked failed but stays active."""
        created_task = await repository.create(
            ScheduledTask(
                name="Recurring Task",
                task_type=TaskType.CREDIT_RECHARGE,
                scheduled_at=self._utc_now(),
                recurrence_type=RecurrenceType.DAILY,
            )
        )

        await repository.mark_as_failed(created_task, "Failed")

        updated_task = await repository.get_by_id(created_task.id)
        # Fail with a clear assertion rather than an AttributeError below.
        assert updated_task is not None
        assert updated_task.status == TaskStatus.FAILED
        # Recurring tasks are expected to remain active after a failure.
        assert updated_task.is_active is True

    async def test_mark_as_failed_non_recurring_task(
        self,
        repository: ScheduledTaskRepository,
        sample_task: ScheduledTask,
    ) -> None:
        """A failed one-off task is marked failed and deactivated."""
        await repository.mark_as_failed(sample_task, "Failed")

        updated_task = await repository.get_by_id(sample_task.id)
        # Fail with a clear assertion rather than an AttributeError below.
        assert updated_task is not None
        assert updated_task.status == TaskStatus.FAILED
        # Non-recurring tasks are expected to be deactivated on failure.
        assert updated_task.is_active is False