Files
sdb2-backend/tests/test_scheduler_service.py

520 lines
20 KiB
Python

"""Tests for scheduler service."""
import uuid
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.schemas.scheduler import ScheduledTaskCreate
from app.services.scheduler import SchedulerService
class TestSchedulerService:
    """Test cases for scheduler service.

    ``SchedulerService`` is constructed with a session factory that returns the
    test's ``db_session`` fixture on every call, so the service and the
    verification reads in these tests operate on the same session object.
    APScheduler job registration (``_schedule_apscheduler_job``) is patched out
    wherever a test only cares about the persisted task.
    """

    # Dotted path of the repository class; used as a prefix for ``patch`` targets.
    _REPO_PATH = "app.repositories.scheduled_task.ScheduledTaskRepository"

    @pytest.fixture
    def mock_player_service(self):
        """Create mock player service."""
        return MagicMock()

    @pytest.fixture
    def scheduler_service(
        self,
        db_session: AsyncSession,
        mock_player_service,
    ) -> SchedulerService:
        """Create a scheduler service whose session factory always yields ``db_session``."""

        def session_factory():
            return db_session

        return SchedulerService(session_factory, mock_player_service)

    @pytest.fixture
    def sample_task_data(self) -> dict:
        """Return field values for a credit-recharge task scheduled one hour from now (UTC)."""
        return {
            "name": "Test Task",
            "task_type": TaskType.CREDIT_RECHARGE,
            "scheduled_at": datetime.now(tz=UTC) + timedelta(hours=1),
            "parameters": {"test": "value"},
            "timezone": "UTC",
        }

    def _create_task_schema(self, task_data: dict, **overrides) -> ScheduledTaskCreate:
        """Build a ``ScheduledTaskCreate`` from *task_data*; *overrides* take precedence."""
        data = {**task_data, **overrides}
        return ScheduledTaskCreate(**data)

    async def _create_unscheduled_task(
        self,
        scheduler_service: SchedulerService,
        schema: ScheduledTaskCreate,
        **create_kwargs,
    ):
        """Create a task through the service with APScheduler registration patched out.

        Args:
            scheduler_service: Service under test.
            schema: Task definition passed as ``task_data``.
            **create_kwargs: Extra keyword arguments for ``create_task``
                (e.g. ``user_id``).

        Returns:
            Whatever ``scheduler_service.create_task`` returns.
        """
        with patch.object(scheduler_service, "_schedule_apscheduler_job"):
            return await scheduler_service.create_task(task_data=schema, **create_kwargs)

    async def _get_task_from_db(
        self,
        scheduler_service: SchedulerService,
        task_id: uuid.UUID,
    ):
        """Re-read a task via ``ScheduledTaskRepository`` using the service's session factory."""
        from app.repositories.scheduled_task import ScheduledTaskRepository

        async with scheduler_service.db_session_factory() as session:
            return await ScheduledTaskRepository(session).get_by_id(task_id)

    async def test_create_task(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """Creating a task persists its fields as PENDING and registers exactly one job."""
        with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
            schema = self._create_task_schema(sample_task_data)
            task = await scheduler_service.create_task(task_data=schema)

        assert task.id is not None
        assert task.name == sample_task_data["name"]
        assert task.task_type == sample_task_data["task_type"]
        assert task.status == TaskStatus.PENDING
        assert task.parameters == sample_task_data["parameters"]
        mock_schedule.assert_called_once_with(task)

    async def test_create_user_task(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
        test_user_id: uuid.UUID,
    ):
        """A task created with a ``user_id`` is owned by that user, not the system."""
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(sample_task_data),
            user_id=test_user_id,
        )

        assert task.user_id == test_user_id
        assert not task.is_system_task()

    async def test_create_system_task(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """A task created without a ``user_id`` is a system task."""
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(sample_task_data),
        )

        assert task.user_id is None
        assert task.is_system_task()

    async def test_create_recurring_task(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """Recurrence settings on the schema are stored on the created task."""
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(
                sample_task_data,
                recurrence_type=RecurrenceType.DAILY,
                recurrence_count=5,
            ),
        )

        assert task.recurrence_type == RecurrenceType.DAILY
        assert task.recurrence_count == 5
        assert task.is_recurring()

    async def test_create_task_with_timezone_conversion(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """A naive local time in the task's timezone is converted to UTC on creation."""
        # Noon New York wall-clock time, deliberately naive.
        ny_time = datetime(2024, 1, 1, 12, 0, 0)  # noqa: DTZ001
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(
                sample_task_data,
                scheduled_at=ny_time,
                timezone="America/New_York",
            ),
        )

        assert task.timezone == "America/New_York"
        # 1 January is outside US daylight saving time, so New York is on EST
        # (UTC-5): noon local must be exactly 17:00 UTC. 16:00 would indicate
        # a wrong (EDT) offset and must not be accepted.
        assert task.scheduled_at.hour == 17
        assert task.scheduled_at.minute == 0

    async def test_cancel_task(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """Cancelling removes the APScheduler job and marks the task CANCELLED/inactive."""
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(sample_task_data),
        )

        with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
            result = await scheduler_service.cancel_task(task.id)

        assert result is True
        mock_remove.assert_called_once_with(str(task.id))

        updated_task = await self._get_task_from_db(scheduler_service, task.id)
        assert updated_task.status == TaskStatus.CANCELLED
        assert updated_task.is_active is False

    async def test_cancel_nonexistent_task(
        self,
        scheduler_service: SchedulerService,
    ):
        """Cancelling an unknown task id returns False."""
        result = await scheduler_service.cancel_task(uuid.uuid4())
        assert result is False

    async def test_get_user_tasks(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
        test_user_id: uuid.UUID,
    ):
        """``get_user_tasks`` returns only the user's tasks, excluding system tasks."""
        await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(sample_task_data),
            user_id=test_user_id,
        )
        await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(sample_task_data),
        )

        user_tasks = await scheduler_service.get_user_tasks(test_user_id)

        assert len(user_tasks) == 1
        assert user_tasks[0].user_id == test_user_id

    async def test_ensure_system_tasks(
        self,
        scheduler_service: SchedulerService,
    ):
        """With no existing system tasks, the daily credit recharge task is created."""
        with (
            patch(f"{self._REPO_PATH}.get_system_tasks", return_value=[]),
            patch(f"{self._REPO_PATH}.create") as mock_create,
        ):
            await scheduler_service._ensure_system_tasks()

        mock_create.assert_called_once()
        created_task_data = mock_create.call_args.args[0]
        assert created_task_data["name"] == "Daily Credit Recharge"
        assert created_task_data["task_type"] == TaskType.CREDIT_RECHARGE
        assert created_task_data["recurrence_type"] == RecurrenceType.DAILY

    async def test_ensure_system_tasks_already_exist(
        self,
        scheduler_service: SchedulerService,
    ):
        """An existing active daily credit-recharge system task prevents re-creation."""
        existing_task = ScheduledTask(
            name="Existing Daily Credit Recharge",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC),
            recurrence_type=RecurrenceType.DAILY,
            is_active=True,
        )

        with (
            patch(f"{self._REPO_PATH}.get_system_tasks", return_value=[existing_task]),
            patch(f"{self._REPO_PATH}.create") as mock_create,
        ):
            await scheduler_service._ensure_system_tasks()

        mock_create.assert_not_called()

    def test_create_trigger_one_shot(
        self,
        scheduler_service: SchedulerService,
    ):
        """A non-recurring task gets a ``DateTrigger``."""
        task = ScheduledTask(
            name="One Shot",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
            recurrence_type=RecurrenceType.NONE,
        )

        trigger = scheduler_service._create_trigger(task)

        assert trigger is not None
        assert type(trigger).__name__ == "DateTrigger"

    def test_create_trigger_daily(
        self,
        scheduler_service: SchedulerService,
    ):
        """A daily task gets an ``IntervalTrigger``."""
        task = ScheduledTask(
            name="Daily",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
            recurrence_type=RecurrenceType.DAILY,
        )

        trigger = scheduler_service._create_trigger(task)

        assert trigger is not None
        assert type(trigger).__name__ == "IntervalTrigger"

    def test_create_trigger_cron(
        self,
        scheduler_service: SchedulerService,
    ):
        """A task with a cron expression gets a ``CronTrigger``."""
        task = ScheduledTask(
            name="Cron",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
            recurrence_type=RecurrenceType.CRON,
            cron_expression="0 9 * * *",  # 9 AM daily
        )

        trigger = scheduler_service._create_trigger(task)

        assert trigger is not None
        assert type(trigger).__name__ == "CronTrigger"

    def test_create_trigger_monthly(
        self,
        scheduler_service: SchedulerService,
    ):
        """A monthly task is scheduled with a ``CronTrigger``."""
        task = ScheduledTask(
            name="Monthly",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC),  # 15th at 10:30 AM
            recurrence_type=RecurrenceType.MONTHLY,
        )

        trigger = scheduler_service._create_trigger(task)

        assert trigger is not None
        assert type(trigger).__name__ == "CronTrigger"

    @pytest.mark.parametrize(
        ("recurrence_type", "expected_delta"),
        [
            (RecurrenceType.HOURLY, timedelta(hours=1)),
            (RecurrenceType.DAILY, timedelta(days=1)),
            (RecurrenceType.WEEKLY, timedelta(weeks=1)),
            (RecurrenceType.MONTHLY, timedelta(days=30)),
            (RecurrenceType.YEARLY, timedelta(days=365)),
        ],
        ids=["hourly", "daily", "weekly", "monthly", "yearly"],
    )
    def test_calculate_next_execution(
        self,
        scheduler_service: SchedulerService,
        recurrence_type: RecurrenceType,
        expected_delta: timedelta,
    ):
        """Next execution is ``now`` plus the recurrence's interval (per recurrence type)."""
        now = datetime.now(tz=UTC)
        task = ScheduledTask(
            name="Test",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=now,
            recurrence_type=recurrence_type,
        )

        # Freeze "now" as seen by the scheduler module so the delta is exact.
        with patch("app.services.scheduler.datetime") as mock_datetime:
            mock_datetime.now.return_value = now
            next_execution = scheduler_service._calculate_next_execution(task)

        assert next_execution is not None
        # Allow some tolerance for execution time
        assert abs((next_execution - now) - expected_delta) < timedelta(seconds=1)

    def test_calculate_next_execution_none_recurrence(
        self,
        scheduler_service: SchedulerService,
    ):
        """A non-recurring task has no next execution."""
        task = ScheduledTask(
            name="One Shot",
            task_type=TaskType.CREDIT_RECHARGE,
            scheduled_at=datetime.now(tz=UTC),
            recurrence_type=RecurrenceType.NONE,
        )

        assert scheduler_service._calculate_next_execution(task) is None

    @patch("app.services.scheduler.TaskHandlerRegistry")
    async def test_execute_task_success(
        self,
        mock_handler_class,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """A due task whose handler succeeds is marked COMPLETED with one execution."""
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(
                sample_task_data,
                scheduled_at=datetime.now(tz=UTC) - timedelta(minutes=1),
            ),
        )
        mock_handler = AsyncMock()
        mock_handler_class.return_value = mock_handler

        await scheduler_service._execute_task(task.id)

        mock_handler.execute_task.assert_called_once()
        updated_task = await self._get_task_from_db(scheduler_service, task.id)
        assert updated_task.status == TaskStatus.COMPLETED
        assert updated_task.executions_count == 1

    @patch("app.services.scheduler.TaskHandlerRegistry")
    async def test_execute_task_failure(
        self,
        mock_handler_class,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """A handler exception marks the task FAILED and records the error message."""
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(
                sample_task_data,
                scheduled_at=datetime.now(tz=UTC) - timedelta(minutes=1),
            ),
        )
        mock_handler = AsyncMock()
        mock_handler.execute_task.side_effect = Exception("Task failed")
        mock_handler_class.return_value = mock_handler

        await scheduler_service._execute_task(task.id)

        updated_task = await self._get_task_from_db(scheduler_service, task.id)
        assert updated_task.status == TaskStatus.FAILED
        assert "Task failed" in updated_task.error_message

    async def test_execute_nonexistent_task(
        self,
        scheduler_service: SchedulerService,
    ):
        """Executing an unknown task id completes without raising."""
        await scheduler_service._execute_task(uuid.uuid4())

    async def test_execute_expired_task(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """Executing a task past its ``expires_at`` cancels and deactivates it."""
        # expires_at is passed as a naive UTC datetime one hour in the past.
        expires_at = datetime.now(tz=UTC).replace(tzinfo=None) - timedelta(hours=1)
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(sample_task_data, expires_at=expires_at),
        )

        await scheduler_service._execute_task(task.id)

        updated_task = await self._get_task_from_db(scheduler_service, task.id)
        assert updated_task.status == TaskStatus.CANCELLED
        assert updated_task.is_active is False

    async def test_concurrent_task_execution_prevention(
        self,
        scheduler_service: SchedulerService,
        sample_task_data: dict,
    ):
        """A task already marked as running is not executed a second time."""
        task = await self._create_unscheduled_task(
            scheduler_service,
            self._create_task_schema(sample_task_data),
        )
        scheduler_service._running_tasks.add(str(task.id))

        # Patch the name where the scheduler module looks it up; patching
        # app.services.task_handlers would leave the scheduler's own reference
        # untouched and make assert_not_called() pass vacuously.
        with patch("app.services.scheduler.TaskHandlerRegistry") as mock_handler_class:
            await scheduler_service._execute_task(task.id)

        mock_handler_class.assert_not_called()
        untouched_task = await self._get_task_from_db(scheduler_service, task.id)
        assert untouched_task.status == TaskStatus.PENDING

    @patch("app.services.scheduler.ScheduledTaskRepository")
    async def test_maintenance_job_expired_tasks(
        self,
        mock_repo_class,
        scheduler_service: SchedulerService,
    ):
        """The maintenance job cancels expired tasks and removes their scheduler jobs."""
        expired_task = MagicMock()
        expired_task.id = uuid.uuid4()
        mock_repo = AsyncMock()
        mock_repo.get_expired_tasks.return_value = [expired_task]
        mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
        mock_repo_class.return_value = mock_repo

        with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
            await scheduler_service._maintenance_job()

        mock_repo.update.assert_called_with(
            expired_task,
            {"status": TaskStatus.CANCELLED, "is_active": False},
        )
        mock_remove.assert_called_once_with(str(expired_task.id))

    @patch("app.services.scheduler.ScheduledTaskRepository")
    async def test_maintenance_job_due_recurring_tasks(
        self,
        mock_repo_class,
        scheduler_service: SchedulerService,
    ):
        """The maintenance job resets due recurring tasks to PENDING and reschedules them."""
        due_task = MagicMock()
        due_task.should_repeat.return_value = True
        due_task.next_execution_at = datetime.now(tz=UTC) - timedelta(minutes=5)
        mock_repo = AsyncMock()
        mock_repo.get_expired_tasks.return_value = []
        mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
        mock_repo_class.return_value = mock_repo

        with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
            await scheduler_service._maintenance_job()

        mock_repo.update.assert_called_with(
            due_task,
            {"status": TaskStatus.PENDING, "scheduled_at": due_task.next_execution_at},
        )
        mock_schedule.assert_called_once_with(due_task)