Add comprehensive tests for scheduled task repository, scheduler service, and task handlers
- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates. - Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs. - Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback. - Ensured proper error handling and edge cases in task execution scenarios. - Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
@@ -25,6 +25,7 @@ from app.models.favorite import Favorite # noqa: F401
|
||||
from app.models.plan import Plan
|
||||
from app.models.playlist import Playlist # noqa: F401
|
||||
from app.models.playlist_sound import PlaylistSound # noqa: F401
|
||||
from app.models.scheduled_task import ScheduledTask # noqa: F401
|
||||
from app.models.sound import Sound # noqa: F401
|
||||
from app.models.sound_played import SoundPlayed # noqa: F401
|
||||
from app.models.user import User
|
||||
@@ -346,3 +347,29 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
|
||||
access_token = JWTUtils.create_access_token(token_data)
|
||||
|
||||
return {"access_token": access_token}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id(test_user: User):
|
||||
"""Get test user ID."""
|
||||
return test_user.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_sound_id():
|
||||
"""Create a test sound ID."""
|
||||
import uuid
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_playlist_id():
|
||||
"""Create a test playlist ID."""
|
||||
import uuid
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(test_session: AsyncSession) -> AsyncSession:
|
||||
"""Alias for test_session to match test expectations."""
|
||||
return test_session
|
||||
|
||||
220
tests/test_scheduled_task_model.py
Normal file
220
tests/test_scheduled_task_model.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Tests for scheduled task model."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
|
||||
|
||||
class TestScheduledTaskModel:
|
||||
"""Test cases for scheduled task model."""
|
||||
|
||||
def test_task_creation(self):
|
||||
"""Test basic task creation."""
|
||||
task = ScheduledTask(
|
||||
name="Test Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
|
||||
assert task.name == "Test Task"
|
||||
assert task.task_type == TaskType.CREDIT_RECHARGE
|
||||
assert task.status == TaskStatus.PENDING
|
||||
assert task.timezone == "UTC"
|
||||
assert task.recurrence_type == RecurrenceType.NONE
|
||||
assert task.parameters == {}
|
||||
assert task.user_id is None
|
||||
assert task.executions_count == 0
|
||||
assert task.is_active is True
|
||||
|
||||
def test_task_with_user(self):
|
||||
"""Test task creation with user association."""
|
||||
user_id = uuid.uuid4()
|
||||
task = ScheduledTask(
|
||||
name="User Task",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
assert task.user_id == user_id
|
||||
assert not task.is_system_task()
|
||||
|
||||
def test_system_task(self):
|
||||
"""Test system task (no user association)."""
|
||||
task = ScheduledTask(
|
||||
name="System Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
|
||||
assert task.user_id is None
|
||||
assert task.is_system_task()
|
||||
|
||||
def test_recurring_task(self):
|
||||
"""Test recurring task properties."""
|
||||
task = ScheduledTask(
|
||||
name="Recurring Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
recurrence_count=5,
|
||||
)
|
||||
|
||||
assert task.is_recurring()
|
||||
assert task.should_repeat()
|
||||
|
||||
def test_non_recurring_task(self):
|
||||
"""Test non-recurring task properties."""
|
||||
task = ScheduledTask(
|
||||
name="One-shot Task",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.NONE,
|
||||
)
|
||||
|
||||
assert not task.is_recurring()
|
||||
assert not task.should_repeat()
|
||||
|
||||
def test_infinite_recurring_task(self):
|
||||
"""Test infinitely recurring task."""
|
||||
task = ScheduledTask(
|
||||
name="Infinite Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
recurrence_count=None, # Infinite
|
||||
)
|
||||
|
||||
assert task.is_recurring()
|
||||
assert task.should_repeat()
|
||||
|
||||
# Even after many executions
|
||||
task.executions_count = 100
|
||||
assert task.should_repeat()
|
||||
|
||||
def test_recurring_task_execution_limit(self):
|
||||
"""Test recurring task with execution limit."""
|
||||
task = ScheduledTask(
|
||||
name="Limited Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
recurrence_count=3,
|
||||
)
|
||||
|
||||
assert task.should_repeat()
|
||||
|
||||
# After 3 executions, should not repeat
|
||||
task.executions_count = 3
|
||||
assert not task.should_repeat()
|
||||
|
||||
# After more than limit, still should not repeat
|
||||
task.executions_count = 5
|
||||
assert not task.should_repeat()
|
||||
|
||||
def test_task_expiration(self):
|
||||
"""Test task expiration."""
|
||||
# Non-expired task
|
||||
task = ScheduledTask(
|
||||
name="Valid Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
expires_at=datetime.utcnow() + timedelta(hours=2),
|
||||
)
|
||||
assert not task.is_expired()
|
||||
|
||||
# Expired task
|
||||
expired_task = ScheduledTask(
|
||||
name="Expired Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||
)
|
||||
assert expired_task.is_expired()
|
||||
|
||||
# Task with no expiration
|
||||
no_expiry_task = ScheduledTask(
|
||||
name="No Expiry Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
assert not no_expiry_task.is_expired()
|
||||
|
||||
def test_task_with_parameters(self):
|
||||
"""Test task with custom parameters."""
|
||||
parameters = {
|
||||
"sound_id": str(uuid.uuid4()),
|
||||
"volume": 80,
|
||||
"repeat": True,
|
||||
}
|
||||
|
||||
task = ScheduledTask(
|
||||
name="Parametrized Task",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
parameters=parameters,
|
||||
)
|
||||
|
||||
assert task.parameters == parameters
|
||||
assert task.parameters["sound_id"] == parameters["sound_id"]
|
||||
assert task.parameters["volume"] == 80
|
||||
assert task.parameters["repeat"] is True
|
||||
|
||||
def test_task_with_timezone(self):
|
||||
"""Test task with custom timezone."""
|
||||
task = ScheduledTask(
|
||||
name="NY Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
timezone="America/New_York",
|
||||
)
|
||||
|
||||
assert task.timezone == "America/New_York"
|
||||
|
||||
def test_task_with_cron_expression(self):
|
||||
"""Test task with cron expression."""
|
||||
cron_expr = "0 9 * * 1-5" # 9 AM on weekdays
|
||||
|
||||
task = ScheduledTask(
|
||||
name="Cron Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.CRON,
|
||||
cron_expression=cron_expr,
|
||||
)
|
||||
|
||||
assert task.recurrence_type == RecurrenceType.CRON
|
||||
assert task.cron_expression == cron_expr
|
||||
assert task.is_recurring()
|
||||
|
||||
def test_task_status_enum_values(self):
|
||||
"""Test all task status enum values."""
|
||||
assert TaskStatus.PENDING == "pending"
|
||||
assert TaskStatus.RUNNING == "running"
|
||||
assert TaskStatus.COMPLETED == "completed"
|
||||
assert TaskStatus.FAILED == "failed"
|
||||
assert TaskStatus.CANCELLED == "cancelled"
|
||||
|
||||
def test_task_type_enum_values(self):
|
||||
"""Test all task type enum values."""
|
||||
assert TaskType.CREDIT_RECHARGE == "credit_recharge"
|
||||
assert TaskType.PLAY_SOUND == "play_sound"
|
||||
assert TaskType.PLAY_PLAYLIST == "play_playlist"
|
||||
|
||||
def test_recurrence_type_enum_values(self):
|
||||
"""Test all recurrence type enum values."""
|
||||
assert RecurrenceType.NONE == "none"
|
||||
assert RecurrenceType.HOURLY == "hourly"
|
||||
assert RecurrenceType.DAILY == "daily"
|
||||
assert RecurrenceType.WEEKLY == "weekly"
|
||||
assert RecurrenceType.MONTHLY == "monthly"
|
||||
assert RecurrenceType.YEARLY == "yearly"
|
||||
assert RecurrenceType.CRON == "cron"
|
||||
494
tests/test_scheduled_task_repository.py
Normal file
494
tests/test_scheduled_task_repository.py
Normal file
@@ -0,0 +1,494 @@
|
||||
"""Tests for scheduled task repository."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
class TestScheduledTaskRepository:
|
||||
"""Test cases for scheduled task repository."""
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, db_session: AsyncSession) -> ScheduledTaskRepository:
|
||||
"""Create repository fixture."""
|
||||
return ScheduledTaskRepository(db_session)
|
||||
|
||||
@pytest.fixture
|
||||
async def sample_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
) -> ScheduledTask:
|
||||
"""Create a sample scheduled task."""
|
||||
task = ScheduledTask(
|
||||
name="Test Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
parameters={"test": "value"},
|
||||
)
|
||||
return await repository.create(task)
|
||||
|
||||
@pytest.fixture
|
||||
async def user_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
) -> ScheduledTask:
|
||||
"""Create a user task."""
|
||||
task = ScheduledTask(
|
||||
name="User Task",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=2),
|
||||
user_id=test_user_id,
|
||||
parameters={"sound_id": str(uuid.uuid4())},
|
||||
)
|
||||
return await repository.create(task)
|
||||
|
||||
async def test_create_task(self, repository: ScheduledTaskRepository):
|
||||
"""Test creating a scheduled task."""
|
||||
task = ScheduledTask(
|
||||
name="Test Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
timezone="America/New_York",
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
parameters={"test": "value"},
|
||||
)
|
||||
|
||||
created_task = await repository.create(task)
|
||||
|
||||
assert created_task.id is not None
|
||||
assert created_task.name == "Test Task"
|
||||
assert created_task.task_type == TaskType.CREDIT_RECHARGE
|
||||
assert created_task.status == TaskStatus.PENDING
|
||||
assert created_task.timezone == "America/New_York"
|
||||
assert created_task.recurrence_type == RecurrenceType.DAILY
|
||||
assert created_task.parameters == {"test": "value"}
|
||||
assert created_task.is_active is True
|
||||
assert created_task.executions_count == 0
|
||||
|
||||
async def test_get_pending_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test getting pending tasks."""
|
||||
# Create tasks with different statuses and times
|
||||
past_pending = ScheduledTask(
|
||||
name="Past Pending",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
await repository.create(past_pending)
|
||||
|
||||
future_pending = ScheduledTask(
|
||||
name="Future Pending",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
await repository.create(future_pending)
|
||||
|
||||
completed_task = ScheduledTask(
|
||||
name="Completed",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
status=TaskStatus.COMPLETED,
|
||||
)
|
||||
await repository.create(completed_task)
|
||||
|
||||
inactive_task = ScheduledTask(
|
||||
name="Inactive",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
status=TaskStatus.PENDING,
|
||||
is_active=False,
|
||||
)
|
||||
await repository.create(inactive_task)
|
||||
|
||||
pending_tasks = await repository.get_pending_tasks()
|
||||
task_names = [task.name for task in pending_tasks]
|
||||
|
||||
# Only the past pending task should be returned
|
||||
assert len(pending_tasks) == 1
|
||||
assert "Past Pending" in task_names
|
||||
|
||||
async def test_get_user_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
user_task: ScheduledTask,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test getting tasks for a specific user."""
|
||||
# Create another user's task
|
||||
other_user_id = uuid.uuid4()
|
||||
other_task = ScheduledTask(
|
||||
name="Other User Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=other_user_id,
|
||||
)
|
||||
await repository.create(other_task)
|
||||
|
||||
# Create system task (no user)
|
||||
system_task = ScheduledTask(
|
||||
name="System Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
await repository.create(system_task)
|
||||
|
||||
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||
|
||||
assert len(user_tasks) == 1
|
||||
assert user_tasks[0].name == "User Task"
|
||||
assert user_tasks[0].user_id == test_user_id
|
||||
|
||||
async def test_get_user_tasks_with_filters(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test getting user tasks with status and type filters."""
|
||||
# Create tasks with different statuses and types
|
||||
tasks_data = [
|
||||
("Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
|
||||
("Task 2", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
|
||||
("Task 3", TaskStatus.PENDING, TaskType.PLAY_SOUND),
|
||||
("Task 4", TaskStatus.FAILED, TaskType.PLAY_PLAYLIST),
|
||||
]
|
||||
|
||||
for name, status, task_type in tasks_data:
|
||||
task = ScheduledTask(
|
||||
name=name,
|
||||
task_type=task_type,
|
||||
status=status,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(task)
|
||||
|
||||
# Test status filter
|
||||
pending_tasks = await repository.get_user_tasks(
|
||||
test_user_id,
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
assert len(pending_tasks) == 2
|
||||
assert all(task.status == TaskStatus.PENDING for task in pending_tasks)
|
||||
|
||||
# Test type filter
|
||||
credit_tasks = await repository.get_user_tasks(
|
||||
test_user_id,
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
assert len(credit_tasks) == 2
|
||||
assert all(task.task_type == TaskType.CREDIT_RECHARGE for task in credit_tasks)
|
||||
|
||||
# Test combined filters
|
||||
pending_credit_tasks = await repository.get_user_tasks(
|
||||
test_user_id,
|
||||
status=TaskStatus.PENDING,
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
assert len(pending_credit_tasks) == 1
|
||||
assert pending_credit_tasks[0].name == "Task 1"
|
||||
|
||||
async def test_get_system_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
user_task: ScheduledTask,
|
||||
):
|
||||
"""Test getting system tasks."""
|
||||
system_tasks = await repository.get_system_tasks()
|
||||
|
||||
assert len(system_tasks) == 1
|
||||
assert system_tasks[0].name == "Test Task"
|
||||
assert system_tasks[0].user_id is None
|
||||
|
||||
async def test_get_recurring_tasks_due_for_next_execution(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test getting recurring tasks due for next execution."""
|
||||
# Create completed recurring task that should be re-executed
|
||||
due_task = ScheduledTask(
|
||||
name="Due Recurring",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
status=TaskStatus.COMPLETED,
|
||||
next_execution_at=datetime.utcnow() - timedelta(minutes=5),
|
||||
)
|
||||
await repository.create(due_task)
|
||||
|
||||
# Create completed recurring task not due yet
|
||||
not_due_task = ScheduledTask(
|
||||
name="Not Due Recurring",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
status=TaskStatus.COMPLETED,
|
||||
next_execution_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
await repository.create(not_due_task)
|
||||
|
||||
# Create non-recurring completed task
|
||||
non_recurring = ScheduledTask(
|
||||
name="Non-recurring",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.NONE,
|
||||
status=TaskStatus.COMPLETED,
|
||||
)
|
||||
await repository.create(non_recurring)
|
||||
|
||||
due_tasks = await repository.get_recurring_tasks_due_for_next_execution()
|
||||
|
||||
assert len(due_tasks) == 1
|
||||
assert due_tasks[0].name == "Due Recurring"
|
||||
|
||||
async def test_get_expired_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test getting expired tasks."""
|
||||
# Create expired task
|
||||
expired_task = ScheduledTask(
|
||||
name="Expired Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||
)
|
||||
await repository.create(expired_task)
|
||||
|
||||
# Create non-expired task
|
||||
valid_task = ScheduledTask(
|
||||
name="Valid Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
expires_at=datetime.utcnow() + timedelta(hours=2),
|
||||
)
|
||||
await repository.create(valid_task)
|
||||
|
||||
# Create task with no expiry
|
||||
no_expiry_task = ScheduledTask(
|
||||
name="No Expiry",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
)
|
||||
await repository.create(no_expiry_task)
|
||||
|
||||
expired_tasks = await repository.get_expired_tasks()
|
||||
|
||||
assert len(expired_tasks) == 1
|
||||
assert expired_tasks[0].name == "Expired Task"
|
||||
|
||||
async def test_cancel_user_tasks(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test cancelling user tasks."""
|
||||
# Create multiple user tasks
|
||||
tasks_data = [
|
||||
("Pending Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
|
||||
("Running Task", TaskStatus.RUNNING, TaskType.PLAY_SOUND),
|
||||
("Completed Task", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
|
||||
]
|
||||
|
||||
for name, status, task_type in tasks_data:
|
||||
task = ScheduledTask(
|
||||
name=name,
|
||||
task_type=task_type,
|
||||
status=status,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(task)
|
||||
|
||||
# Cancel all user tasks
|
||||
cancelled_count = await repository.cancel_user_tasks(test_user_id)
|
||||
|
||||
assert cancelled_count == 2 # Only pending and running tasks
|
||||
|
||||
# Verify tasks are cancelled
|
||||
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||
pending_or_running = [
|
||||
task for task in user_tasks
|
||||
if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING]
|
||||
]
|
||||
cancelled_tasks = [
|
||||
task for task in user_tasks
|
||||
if task.status == TaskStatus.CANCELLED
|
||||
]
|
||||
|
||||
assert len(pending_or_running) == 0
|
||||
assert len(cancelled_tasks) == 2
|
||||
|
||||
async def test_cancel_user_tasks_by_type(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test cancelling user tasks by type."""
|
||||
# Create tasks of different types
|
||||
credit_task = ScheduledTask(
|
||||
name="Credit Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(credit_task)
|
||||
|
||||
sound_task = ScheduledTask(
|
||||
name="Sound Task",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
user_id=test_user_id,
|
||||
)
|
||||
await repository.create(sound_task)
|
||||
|
||||
# Cancel only credit tasks
|
||||
cancelled_count = await repository.cancel_user_tasks(
|
||||
test_user_id,
|
||||
TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
|
||||
assert cancelled_count == 1
|
||||
|
||||
# Verify only credit task is cancelled
|
||||
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||
credit_tasks = [
|
||||
task for task in user_tasks
|
||||
if task.task_type == TaskType.CREDIT_RECHARGE
|
||||
]
|
||||
sound_tasks = [
|
||||
task for task in user_tasks
|
||||
if task.task_type == TaskType.PLAY_SOUND
|
||||
]
|
||||
|
||||
assert len(credit_tasks) == 1
|
||||
assert credit_tasks[0].status == TaskStatus.CANCELLED
|
||||
assert len(sound_tasks) == 1
|
||||
assert sound_tasks[0].status == TaskStatus.PENDING
|
||||
|
||||
async def test_mark_as_running(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking task as running."""
|
||||
await repository.mark_as_running(sample_task)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.RUNNING
|
||||
|
||||
async def test_mark_as_completed(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking task as completed."""
|
||||
initial_count = sample_task.executions_count
|
||||
next_execution = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
await repository.mark_as_completed(sample_task, next_execution)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
assert updated_task.executions_count == initial_count + 1
|
||||
assert updated_task.last_executed_at is not None
|
||||
assert updated_task.error_message is None
|
||||
|
||||
async def test_mark_as_completed_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test marking recurring task as completed."""
|
||||
task = ScheduledTask(
|
||||
name="Recurring Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
)
|
||||
created_task = await repository.create(task)
|
||||
|
||||
next_execution = datetime.utcnow() + timedelta(days=1)
|
||||
await repository.mark_as_completed(created_task, next_execution)
|
||||
|
||||
updated_task = await repository.get_by_id(created_task.id)
|
||||
# Should be set back to pending for next execution
|
||||
assert updated_task.status == TaskStatus.PENDING
|
||||
assert updated_task.next_execution_at == next_execution
|
||||
assert updated_task.is_active is True
|
||||
|
||||
async def test_mark_as_completed_non_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking non-recurring task as completed."""
|
||||
await repository.mark_as_completed(sample_task, None)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
assert updated_task.is_active is False
|
||||
|
||||
async def test_mark_as_failed(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking task as failed."""
|
||||
error_message = "Task execution failed"
|
||||
|
||||
await repository.mark_as_failed(sample_task, error_message)
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
assert updated_task.error_message == error_message
|
||||
assert updated_task.last_executed_at is not None
|
||||
|
||||
async def test_mark_as_failed_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
):
|
||||
"""Test marking recurring task as failed."""
|
||||
task = ScheduledTask(
|
||||
name="Recurring Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
)
|
||||
created_task = await repository.create(task)
|
||||
|
||||
await repository.mark_as_failed(created_task, "Failed")
|
||||
|
||||
updated_task = await repository.get_by_id(created_task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
# Recurring tasks should remain active even after failure
|
||||
assert updated_task.is_active is True
|
||||
|
||||
async def test_mark_as_failed_non_recurring_task(
|
||||
self,
|
||||
repository: ScheduledTaskRepository,
|
||||
sample_task: ScheduledTask,
|
||||
):
|
||||
"""Test marking non-recurring task as failed."""
|
||||
await repository.mark_as_failed(sample_task, "Failed")
|
||||
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
# Non-recurring tasks should be deactivated on failure
|
||||
assert updated_task.is_active is False
|
||||
495
tests/test_scheduler_service.py
Normal file
495
tests/test_scheduler_service.py
Normal file
@@ -0,0 +1,495 @@
|
||||
"""Tests for scheduler service."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.services.scheduler import SchedulerService
|
||||
|
||||
|
||||
class TestSchedulerService:
|
||||
"""Test cases for scheduler service."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_player_service(self):
|
||||
"""Create mock player service."""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def scheduler_service(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
mock_player_service,
|
||||
) -> SchedulerService:
|
||||
"""Create scheduler service fixture."""
|
||||
session_factory = lambda: db_session
|
||||
return SchedulerService(session_factory, mock_player_service)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_task_data(self) -> dict:
|
||||
"""Sample task data for testing."""
|
||||
return {
|
||||
"name": "Test Task",
|
||||
"task_type": TaskType.CREDIT_RECHARGE,
|
||||
"scheduled_at": datetime.utcnow() + timedelta(hours=1),
|
||||
"parameters": {"test": "value"},
|
||||
"timezone": "UTC",
|
||||
}
|
||||
|
||||
async def test_create_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a scheduled task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
assert task.id is not None
|
||||
assert task.name == sample_task_data["name"]
|
||||
assert task.task_type == sample_task_data["task_type"]
|
||||
assert task.status == TaskStatus.PENDING
|
||||
assert task.parameters == sample_task_data["parameters"]
|
||||
mock_schedule.assert_called_once_with(task)
|
||||
|
||||
async def test_create_user_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test creating a user task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(
|
||||
user_id=test_user_id,
|
||||
**sample_task_data,
|
||||
)
|
||||
|
||||
assert task.user_id == test_user_id
|
||||
assert not task.is_system_task()
|
||||
|
||||
async def test_create_system_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a system task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
assert task.user_id is None
|
||||
assert task.is_system_task()
|
||||
|
||||
async def test_create_recurring_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a recurring task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
recurrence_count=5,
|
||||
**sample_task_data,
|
||||
)
|
||||
|
||||
assert task.recurrence_type == RecurrenceType.DAILY
|
||||
assert task.recurrence_count == 5
|
||||
assert task.is_recurring()
|
||||
|
||||
async def test_create_task_with_timezone_conversion(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating task with timezone conversion."""
|
||||
# Use a specific datetime for testing
|
||||
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY
|
||||
|
||||
sample_task_data["scheduled_at"] = ny_time
|
||||
sample_task_data["timezone"] = "America/New_York"
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# The scheduled_at should be converted to UTC
|
||||
assert task.timezone == "America/New_York"
|
||||
# In winter, EST is UTC-5, so noon EST becomes 5 PM UTC
|
||||
# Note: This test might need adjustment based on DST
|
||||
assert task.scheduled_at.hour in [16, 17] # Account for DST
|
||||
|
||||
async def test_cancel_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test cancelling a task."""
|
||||
# Create a task first
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock the scheduler remove_job method
|
||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||
result = await scheduler_service.cancel_task(task.id)
|
||||
|
||||
assert result is True
|
||||
mock_remove.assert_called_once_with(str(task.id))
|
||||
|
||||
# Check task is cancelled in database
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
async with scheduler_service.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
updated_task = await repo.get_by_id(task.id)
|
||||
assert updated_task.status == TaskStatus.CANCELLED
|
||||
assert updated_task.is_active is False
|
||||
|
||||
async def test_cancel_nonexistent_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test cancelling a non-existent task."""
|
||||
result = await scheduler_service.cancel_task(uuid.uuid4())
|
||||
assert result is False
|
||||
|
||||
async def test_get_user_tasks(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test getting user tasks."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
# Create user task
|
||||
await scheduler_service.create_task(
|
||||
user_id=test_user_id,
|
||||
**sample_task_data,
|
||||
)
|
||||
|
||||
# Create system task
|
||||
await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
user_tasks = await scheduler_service.get_user_tasks(test_user_id)
|
||||
|
||||
assert len(user_tasks) == 1
|
||||
assert user_tasks[0].user_id == test_user_id
|
||||
|
||||
async def test_ensure_system_tasks(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test ensuring system tasks exist."""
|
||||
# Mock the repository to return no existing tasks
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||
mock_get.return_value = []
|
||||
|
||||
await scheduler_service._ensure_system_tasks()
|
||||
|
||||
# Should create daily credit recharge task
|
||||
mock_create.assert_called_once()
|
||||
created_task = mock_create.call_args[0][0]
|
||||
assert created_task.name == "Daily Credit Recharge"
|
||||
assert created_task.task_type == TaskType.CREDIT_RECHARGE
|
||||
assert created_task.recurrence_type == RecurrenceType.DAILY
|
||||
|
||||
async def test_ensure_system_tasks_already_exist(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test ensuring system tasks when they already exist."""
|
||||
existing_task = ScheduledTask(
|
||||
name="Existing Daily Credit Recharge",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||
mock_get.return_value = [existing_task]
|
||||
|
||||
await scheduler_service._ensure_system_tasks()
|
||||
|
||||
# Should not create new task
|
||||
mock_create.assert_not_called()
|
||||
|
||||
def test_create_trigger_one_shot(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test creating one-shot trigger."""
|
||||
task = ScheduledTask(
|
||||
name="One Shot",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.NONE,
|
||||
)
|
||||
|
||||
trigger = scheduler_service._create_trigger(task)
|
||||
assert trigger is not None
|
||||
assert trigger.__class__.__name__ == "DateTrigger"
|
||||
|
||||
def test_create_trigger_daily(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test creating daily interval trigger."""
|
||||
task = ScheduledTask(
|
||||
name="Daily",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
)
|
||||
|
||||
trigger = scheduler_service._create_trigger(task)
|
||||
assert trigger is not None
|
||||
assert trigger.__class__.__name__ == "IntervalTrigger"
|
||||
|
||||
def test_create_trigger_cron(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test creating cron trigger."""
|
||||
task = ScheduledTask(
|
||||
name="Cron",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||
recurrence_type=RecurrenceType.CRON,
|
||||
cron_expression="0 9 * * *", # 9 AM daily
|
||||
)
|
||||
|
||||
trigger = scheduler_service._create_trigger(task)
|
||||
assert trigger is not None
|
||||
assert trigger.__class__.__name__ == "CronTrigger"
|
||||
|
||||
def test_create_trigger_monthly(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test creating monthly cron trigger."""
|
||||
task = ScheduledTask(
|
||||
name="Monthly",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime(2024, 1, 15, 10, 30, 0), # 15th at 10:30 AM
|
||||
recurrence_type=RecurrenceType.MONTHLY,
|
||||
)
|
||||
|
||||
trigger = scheduler_service._create_trigger(task)
|
||||
assert trigger is not None
|
||||
assert trigger.__class__.__name__ == "CronTrigger"
|
||||
|
||||
def test_calculate_next_execution(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test calculating next execution time."""
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Test different recurrence types
|
||||
test_cases = [
|
||||
(RecurrenceType.HOURLY, timedelta(hours=1)),
|
||||
(RecurrenceType.DAILY, timedelta(days=1)),
|
||||
(RecurrenceType.WEEKLY, timedelta(weeks=1)),
|
||||
(RecurrenceType.MONTHLY, timedelta(days=30)),
|
||||
(RecurrenceType.YEARLY, timedelta(days=365)),
|
||||
]
|
||||
|
||||
for recurrence_type, expected_delta in test_cases:
|
||||
task = ScheduledTask(
|
||||
name="Test",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=now,
|
||||
recurrence_type=recurrence_type,
|
||||
)
|
||||
|
||||
with patch('app.services.scheduler.datetime') as mock_datetime:
|
||||
mock_datetime.utcnow.return_value = now
|
||||
next_execution = scheduler_service._calculate_next_execution(task)
|
||||
|
||||
assert next_execution is not None
|
||||
# Allow some tolerance for execution time
|
||||
assert abs((next_execution - now) - expected_delta) < timedelta(seconds=1)
|
||||
|
||||
def test_calculate_next_execution_none_recurrence(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test calculating next execution for non-recurring task."""
|
||||
task = ScheduledTask(
|
||||
name="One Shot",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
recurrence_type=RecurrenceType.NONE,
|
||||
)
|
||||
|
||||
next_execution = scheduler_service._calculate_next_execution(task)
|
||||
assert next_execution is None
|
||||
|
||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||
async def test_execute_task_success(
|
||||
self,
|
||||
mock_handler_class,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test successful task execution."""
|
||||
# Create task
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock handler registry
|
||||
mock_handler = AsyncMock()
|
||||
mock_handler_class.return_value = mock_handler
|
||||
|
||||
# Execute task
|
||||
await scheduler_service._execute_task(task.id)
|
||||
|
||||
# Verify handler was called
|
||||
mock_handler.execute_task.assert_called_once()
|
||||
|
||||
# Check task is marked as completed
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
async with scheduler_service.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
updated_task = await repo.get_by_id(task.id)
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
assert updated_task.executions_count == 1
|
||||
|
||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||
async def test_execute_task_failure(
|
||||
self,
|
||||
mock_handler_class,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test task execution failure."""
|
||||
# Create task
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock handler to raise exception
|
||||
mock_handler = AsyncMock()
|
||||
mock_handler.execute_task.side_effect = Exception("Task failed")
|
||||
mock_handler_class.return_value = mock_handler
|
||||
|
||||
# Execute task
|
||||
await scheduler_service._execute_task(task.id)
|
||||
|
||||
# Check task is marked as failed
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
async with scheduler_service.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
updated_task = await repo.get_by_id(task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
assert "Task failed" in updated_task.error_message
|
||||
|
||||
async def test_execute_nonexistent_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test executing non-existent task."""
|
||||
# Should handle gracefully
|
||||
await scheduler_service._execute_task(uuid.uuid4())
|
||||
|
||||
async def test_execute_expired_task(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test executing expired task."""
|
||||
# Create expired task
|
||||
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Execute task
|
||||
await scheduler_service._execute_task(task.id)
|
||||
|
||||
# Check task is cancelled
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
async with scheduler_service.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
updated_task = await repo.get_by_id(task.id)
|
||||
assert updated_task.status == TaskStatus.CANCELLED
|
||||
assert updated_task.is_active is False
|
||||
|
||||
async def test_concurrent_task_execution_prevention(
|
||||
self,
|
||||
scheduler_service: SchedulerService,
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test prevention of concurrent task execution."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Add task to running set
|
||||
scheduler_service._running_tasks.add(str(task.id))
|
||||
|
||||
# Try to execute - should return without doing anything
|
||||
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class:
|
||||
await scheduler_service._execute_task(task.id)
|
||||
|
||||
# Handler should not be called
|
||||
mock_handler_class.assert_not_called()
|
||||
|
||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||
async def test_maintenance_job_expired_tasks(
|
||||
self,
|
||||
mock_repo_class,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test maintenance job handling expired tasks."""
|
||||
# Mock expired task
|
||||
expired_task = MagicMock()
|
||||
expired_task.id = uuid.uuid4()
|
||||
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_expired_tasks.return_value = [expired_task]
|
||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
|
||||
mock_repo_class.return_value = mock_repo
|
||||
|
||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||
await scheduler_service._maintenance_job()
|
||||
|
||||
# Should mark as cancelled and remove from scheduler
|
||||
assert expired_task.status == TaskStatus.CANCELLED
|
||||
assert expired_task.is_active is False
|
||||
mock_repo.update.assert_called_with(expired_task)
|
||||
mock_remove.assert_called_once_with(str(expired_task.id))
|
||||
|
||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||
async def test_maintenance_job_due_recurring_tasks(
|
||||
self,
|
||||
mock_repo_class,
|
||||
scheduler_service: SchedulerService,
|
||||
):
|
||||
"""Test maintenance job handling due recurring tasks."""
|
||||
# Mock due recurring task
|
||||
due_task = MagicMock()
|
||||
due_task.should_repeat.return_value = True
|
||||
due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5)
|
||||
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_expired_tasks.return_value = []
|
||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
|
||||
mock_repo_class.return_value = mock_repo
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||
await scheduler_service._maintenance_job()
|
||||
|
||||
# Should reset to pending and reschedule
|
||||
assert due_task.status == TaskStatus.PENDING
|
||||
assert due_task.scheduled_at == due_task.next_execution_at
|
||||
mock_repo.update.assert_called_with(due_task)
|
||||
mock_schedule.assert_called_once_with(due_task)
|
||||
424
tests/test_task_handlers.py
Normal file
424
tests/test_task_handlers.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""Tests for task handlers."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import ScheduledTask, TaskType
|
||||
from app.services.task_handlers import TaskExecutionError, TaskHandlerRegistry
|
||||
|
||||
|
||||
class TestTaskHandlerRegistry:
|
||||
"""Test cases for task handler registry."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_service(self):
|
||||
"""Create mock credit service."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_player_service(self):
|
||||
"""Create mock player service."""
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def task_registry(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
mock_credit_service,
|
||||
mock_player_service,
|
||||
) -> TaskHandlerRegistry:
|
||||
"""Create task handler registry fixture."""
|
||||
return TaskHandlerRegistry(
|
||||
db_session,
|
||||
mock_credit_service,
|
||||
mock_player_service,
|
||||
)
|
||||
|
||||
async def test_execute_task_unknown_type(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
):
|
||||
"""Test executing task with unknown type."""
|
||||
# Create task with invalid type
|
||||
task = ScheduledTask(
|
||||
name="Unknown Task",
|
||||
task_type="UNKNOWN_TYPE", # Invalid type
|
||||
scheduled_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
with pytest.raises(TaskExecutionError, match="No handler registered"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_handle_credit_recharge_all_users(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_credit_service,
|
||||
):
|
||||
"""Test handling credit recharge for all users."""
|
||||
task = ScheduledTask(
|
||||
name="Daily Credit Recharge",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={},
|
||||
)
|
||||
|
||||
mock_credit_service.recharge_all_users_credits.return_value = {
|
||||
"users_recharged": 10,
|
||||
"total_credits": 1000,
|
||||
}
|
||||
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
mock_credit_service.recharge_all_users_credits.assert_called_once()
|
||||
|
||||
async def test_handle_credit_recharge_specific_user(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_credit_service,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test handling credit recharge for specific user."""
|
||||
task = ScheduledTask(
|
||||
name="User Credit Recharge",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"user_id": str(test_user_id)},
|
||||
)
|
||||
|
||||
mock_credit_service.recharge_user_credits.return_value = {
|
||||
"user_id": str(test_user_id),
|
||||
"credits_added": 100,
|
||||
}
|
||||
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
mock_credit_service.recharge_user_credits.assert_called_once_with(test_user_id)
|
||||
|
||||
async def test_handle_credit_recharge_uuid_user_id(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_credit_service,
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test handling credit recharge with UUID user_id parameter."""
|
||||
task = ScheduledTask(
|
||||
name="User Credit Recharge",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"user_id": test_user_id}, # UUID object instead of string
|
||||
)
|
||||
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
mock_credit_service.recharge_user_credits.assert_called_once_with(test_user_id)
|
||||
|
||||
async def test_handle_play_sound_success(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
test_sound_id: uuid.UUID,
|
||||
):
|
||||
"""Test successful play sound task handling."""
|
||||
task = ScheduledTask(
|
||||
name="Play Sound",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"sound_id": str(test_sound_id)},
|
||||
)
|
||||
|
||||
# Mock sound repository
|
||||
mock_sound = MagicMock()
|
||||
mock_sound.id = test_sound_id
|
||||
mock_sound.filename = "test_sound.mp3"
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_vlc_class.return_value = mock_vlc_service
|
||||
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
task_registry.sound_repository.get_by_id.assert_called_once_with(test_sound_id)
|
||||
mock_vlc_service.play_sound.assert_called_once_with(mock_sound)
|
||||
|
||||
async def test_handle_play_sound_missing_sound_id(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
):
|
||||
"""Test play sound task with missing sound_id parameter."""
|
||||
task = ScheduledTask(
|
||||
name="Play Sound",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={}, # Missing sound_id
|
||||
)
|
||||
|
||||
with pytest.raises(TaskExecutionError, match="sound_id parameter is required"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_handle_play_sound_invalid_sound_id(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
):
|
||||
"""Test play sound task with invalid sound_id parameter."""
|
||||
task = ScheduledTask(
|
||||
name="Play Sound",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"sound_id": "invalid-uuid"},
|
||||
)
|
||||
|
||||
with pytest.raises(TaskExecutionError, match="Invalid sound_id format"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_handle_play_sound_not_found(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
test_sound_id: uuid.UUID,
|
||||
):
|
||||
"""Test play sound task with non-existent sound."""
|
||||
task = ScheduledTask(
|
||||
name="Play Sound",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"sound_id": str(test_sound_id)},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None):
|
||||
with pytest.raises(TaskExecutionError, match="Sound not found"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_handle_play_sound_uuid_parameter(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
test_sound_id: uuid.UUID,
|
||||
):
|
||||
"""Test play sound task with UUID parameter (not string)."""
|
||||
task = ScheduledTask(
|
||||
name="Play Sound",
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"sound_id": test_sound_id}, # UUID object
|
||||
)
|
||||
|
||||
mock_sound = MagicMock()
|
||||
mock_sound.filename = "test_sound.mp3"
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_vlc_class.return_value = mock_vlc_service
|
||||
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
task_registry.sound_repository.get_by_id.assert_called_once_with(test_sound_id)
|
||||
|
||||
async def test_handle_play_playlist_success(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_player_service,
|
||||
test_playlist_id: uuid.UUID,
|
||||
):
|
||||
"""Test successful play playlist task handling."""
|
||||
task = ScheduledTask(
|
||||
name="Play Playlist",
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={
|
||||
"playlist_id": str(test_playlist_id),
|
||||
"play_mode": "loop",
|
||||
"shuffle": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Mock playlist repository
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.id = test_playlist_id
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
|
||||
mock_player_service.load_playlist.assert_called_once_with(test_playlist_id)
|
||||
mock_player_service.set_mode.assert_called_once_with("loop")
|
||||
mock_player_service.set_shuffle.assert_called_once_with(True)
|
||||
mock_player_service.play.assert_called_once()
|
||||
|
||||
async def test_handle_play_playlist_minimal_parameters(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_player_service,
|
||||
test_playlist_id: uuid.UUID,
|
||||
):
|
||||
"""Test play playlist task with minimal parameters."""
|
||||
task = ScheduledTask(
|
||||
name="Play Playlist",
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"playlist_id": str(test_playlist_id)},
|
||||
)
|
||||
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
# Should use default values
|
||||
mock_player_service.set_mode.assert_called_once_with("continuous")
|
||||
mock_player_service.set_shuffle.assert_called_once_with(False)
|
||||
|
||||
async def test_handle_play_playlist_missing_playlist_id(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
):
|
||||
"""Test play playlist task with missing playlist_id parameter."""
|
||||
task = ScheduledTask(
|
||||
name="Play Playlist",
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={}, # Missing playlist_id
|
||||
)
|
||||
|
||||
with pytest.raises(TaskExecutionError, match="playlist_id parameter is required"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_handle_play_playlist_invalid_playlist_id(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
):
|
||||
"""Test play playlist task with invalid playlist_id parameter."""
|
||||
task = ScheduledTask(
|
||||
name="Play Playlist",
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"playlist_id": "invalid-uuid"},
|
||||
)
|
||||
|
||||
with pytest.raises(TaskExecutionError, match="Invalid playlist_id format"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_handle_play_playlist_not_found(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
test_playlist_id: uuid.UUID,
|
||||
):
|
||||
"""Test play playlist task with non-existent playlist."""
|
||||
task = ScheduledTask(
|
||||
name="Play Playlist",
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={"playlist_id": str(test_playlist_id)},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None):
|
||||
with pytest.raises(TaskExecutionError, match="Playlist not found"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_handle_play_playlist_valid_play_modes(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_player_service,
|
||||
test_playlist_id: uuid.UUID,
|
||||
):
|
||||
"""Test play playlist task with various valid play modes."""
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
valid_modes = ["continuous", "loop", "loop_one", "random", "single"]
|
||||
|
||||
for mode in valid_modes:
|
||||
task = ScheduledTask(
|
||||
name="Play Playlist",
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={
|
||||
"playlist_id": str(test_playlist_id),
|
||||
"play_mode": mode,
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
mock_player_service.set_mode.assert_called_with(mode)
|
||||
|
||||
# Reset mock for next iteration
|
||||
mock_player_service.reset_mock()
|
||||
|
||||
async def test_handle_play_playlist_invalid_play_mode(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_player_service,
|
||||
test_playlist_id: uuid.UUID,
|
||||
):
|
||||
"""Test play playlist task with invalid play mode."""
|
||||
task = ScheduledTask(
|
||||
name="Play Playlist",
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={
|
||||
"playlist_id": str(test_playlist_id),
|
||||
"play_mode": "invalid_mode",
|
||||
},
|
||||
)
|
||||
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
# Should not call set_mode for invalid mode
|
||||
mock_player_service.set_mode.assert_not_called()
|
||||
# But should still load playlist and play
|
||||
mock_player_service.load_playlist.assert_called_once()
|
||||
mock_player_service.play.assert_called_once()
|
||||
|
||||
async def test_task_execution_exception_handling(
|
||||
self,
|
||||
task_registry: TaskHandlerRegistry,
|
||||
mock_credit_service,
|
||||
):
|
||||
"""Test exception handling during task execution."""
|
||||
task = ScheduledTask(
|
||||
name="Failing Task",
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=datetime.utcnow(),
|
||||
parameters={},
|
||||
)
|
||||
|
||||
# Make credit service raise an exception
|
||||
mock_credit_service.recharge_all_users_credits.side_effect = Exception("Service error")
|
||||
|
||||
with pytest.raises(TaskExecutionError, match="Task execution failed: Service error"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
async def test_task_registry_initialization(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
mock_credit_service,
|
||||
mock_player_service,
|
||||
):
|
||||
"""Test task registry initialization."""
|
||||
registry = TaskHandlerRegistry(
|
||||
db_session,
|
||||
mock_credit_service,
|
||||
mock_player_service,
|
||||
)
|
||||
|
||||
assert registry.db_session == db_session
|
||||
assert registry.credit_service == mock_credit_service
|
||||
assert registry.player_service == mock_player_service
|
||||
assert registry.sound_repository is not None
|
||||
assert registry.playlist_repository is not None
|
||||
|
||||
# Check all handlers are registered
|
||||
expected_handlers = {
|
||||
TaskType.CREDIT_RECHARGE,
|
||||
TaskType.PLAY_SOUND,
|
||||
TaskType.PLAY_PLAYLIST,
|
||||
}
|
||||
assert set(registry._handlers.keys()) == expected_handlers
|
||||
Reference in New Issue
Block a user