"""Tests for scheduled task repository.""" from datetime import UTC, datetime, timedelta import pytest from sqlmodel.ext.asyncio.session import AsyncSession from app.models.scheduled_task import ( RecurrenceType, ScheduledTask, TaskStatus, TaskType, ) from app.repositories.scheduled_task import ScheduledTaskRepository class TestScheduledTaskRepository: """Test cases for scheduled task repository.""" @pytest.fixture def repository(self, db_session: AsyncSession) -> ScheduledTaskRepository: """Create repository fixture.""" return ScheduledTaskRepository(db_session) @pytest.fixture async def sample_task( self, repository: ScheduledTaskRepository, ) -> ScheduledTask: """Create a sample scheduled task.""" task_data = { "name": "Test Task", "task_type": TaskType.CREDIT_RECHARGE, "scheduled_at": datetime.now(tz=UTC) + timedelta(hours=1), "parameters": {"test": "value"}, } return await repository.create(task_data) @pytest.fixture async def user_task( self, repository: ScheduledTaskRepository, test_user_id: int, ) -> ScheduledTask: """Create a user task.""" task_data = { "name": "User Task", "task_type": TaskType.PLAY_SOUND, "scheduled_at": datetime.now(tz=UTC) + timedelta(hours=2), "user_id": test_user_id, "parameters": {"sound_id": "1"}, } return await repository.create(task_data) async def test_create_task(self, repository: ScheduledTaskRepository): """Test creating a scheduled task.""" task_data = { "name": "Test Task", "task_type": TaskType.CREDIT_RECHARGE, "scheduled_at": datetime.now(tz=UTC) + timedelta(hours=1), "timezone": "America/New_York", "recurrence_type": RecurrenceType.DAILY, "parameters": {"test": "value"}, } created_task = await repository.create(task_data) assert created_task.id is not None assert created_task.name == "Test Task" assert created_task.task_type == TaskType.CREDIT_RECHARGE assert created_task.status == TaskStatus.PENDING assert created_task.timezone == "America/New_York" assert created_task.recurrence_type == RecurrenceType.DAILY assert created_task.parameters == {"test": "value"} assert created_task.is_active is True assert created_task.executions_count == 0 async def test_get_pending_tasks( self, repository: ScheduledTaskRepository, ): """Test getting pending tasks.""" # Create tasks with different statuses and times past_pending = ScheduledTask( name="Past Pending", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1), status=TaskStatus.PENDING, ) await repository.create(past_pending) future_pending = ScheduledTask( name="Future Pending", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), status=TaskStatus.PENDING, ) await repository.create(future_pending) completed_task = ScheduledTask( name="Completed", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1), status=TaskStatus.COMPLETED, ) await repository.create(completed_task) inactive_task = ScheduledTask( name="Inactive", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1), status=TaskStatus.PENDING, is_active=False, ) await repository.create(inactive_task) pending_tasks = await repository.get_pending_tasks() task_names = [task.name for task in pending_tasks] # Only the past pending task should be returned assert len(pending_tasks) == 1 assert "Past Pending" in task_names async def test_get_user_tasks( self, repository: ScheduledTaskRepository, user_task: ScheduledTask, test_user_id: int, ): """Test getting tasks for a specific user.""" # Create another user's task other_user_id = 999 other_task = ScheduledTask( name="Other User Task", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), user_id=other_user_id, ) await repository.create(other_task) # Create system task (no user) system_task = ScheduledTask( name="System Task", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), ) await repository.create(system_task) user_tasks = await repository.get_user_tasks(test_user_id) assert len(user_tasks) == 1 assert user_tasks[0].name == "User Task" assert user_tasks[0].user_id == test_user_id async def test_get_user_tasks_with_filters( self, repository: ScheduledTaskRepository, test_user_id: int, ): """Test getting user tasks with status and type filters.""" # Create tasks with different statuses and types tasks_data = [ ("Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE), ("Task 2", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE), ("Task 3", TaskStatus.PENDING, TaskType.PLAY_SOUND), ("Task 4", TaskStatus.FAILED, TaskType.PLAY_PLAYLIST), ] for name, status, task_type in tasks_data: task = ScheduledTask( name=name, task_type=task_type, status=status, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), user_id=test_user_id, ) await repository.create(task) # Test status filter pending_tasks = await repository.get_user_tasks( test_user_id, status=TaskStatus.PENDING, ) assert len(pending_tasks) == 2 assert all(task.status == TaskStatus.PENDING for task in pending_tasks) # Test type filter credit_tasks = await repository.get_user_tasks( test_user_id, task_type=TaskType.CREDIT_RECHARGE, ) assert len(credit_tasks) == 2 assert all(task.task_type == TaskType.CREDIT_RECHARGE for task in credit_tasks) # Test combined filters pending_credit_tasks = await repository.get_user_tasks( test_user_id, status=TaskStatus.PENDING, task_type=TaskType.CREDIT_RECHARGE, ) assert len(pending_credit_tasks) == 1 assert pending_credit_tasks[0].name == "Task 1" async def test_get_system_tasks( self, repository: ScheduledTaskRepository, sample_task: ScheduledTask, user_task: ScheduledTask, ): """Test getting system tasks.""" system_tasks = await repository.get_system_tasks() assert len(system_tasks) == 1 assert system_tasks[0].name == "Test Task" assert system_tasks[0].user_id is None async def test_get_recurring_tasks_due_for_next_execution( self, repository: ScheduledTaskRepository, ): """Test getting recurring tasks due for next execution.""" # Create completed recurring task that should be re-executed due_task = ScheduledTask( name="Due Recurring", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1), recurrence_type=RecurrenceType.DAILY, status=TaskStatus.COMPLETED, next_execution_at=datetime.now(tz=UTC) - timedelta(minutes=5), ) await repository.create(due_task) # Create completed recurring task not due yet not_due_task = ScheduledTask( name="Not Due Recurring", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1), recurrence_type=RecurrenceType.DAILY, status=TaskStatus.COMPLETED, next_execution_at=datetime.now(tz=UTC) + timedelta(hours=1), ) await repository.create(not_due_task) # Create non-recurring completed task non_recurring = ScheduledTask( name="Non-recurring", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1), recurrence_type=RecurrenceType.NONE, status=TaskStatus.COMPLETED, ) await repository.create(non_recurring) due_tasks = await repository.get_recurring_tasks_due_for_next_execution() assert len(due_tasks) == 1 assert due_tasks[0].name == "Due Recurring" async def test_get_expired_tasks( self, repository: ScheduledTaskRepository, ): """Test getting expired tasks.""" # Create expired task expired_task = ScheduledTask( name="Expired Task", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), expires_at=datetime.now(tz=UTC) - timedelta(hours=1), ) await repository.create(expired_task) # Create non-expired task valid_task = ScheduledTask( name="Valid Task", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), expires_at=datetime.now(tz=UTC) + timedelta(hours=2), ) await repository.create(valid_task) # Create task with no expiry no_expiry_task = ScheduledTask( name="No Expiry", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), ) await repository.create(no_expiry_task) expired_tasks = await repository.get_expired_tasks() assert len(expired_tasks) == 1 assert expired_tasks[0].name == "Expired Task" async def test_cancel_user_tasks( self, repository: ScheduledTaskRepository, test_user_id: int, ): """Test cancelling user tasks.""" # Create multiple user tasks tasks_data = [ ("Pending Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE), ("Running Task", TaskStatus.RUNNING, TaskType.PLAY_SOUND), ("Completed Task", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE), ] for name, status, task_type in tasks_data: task = ScheduledTask( name=name, task_type=task_type, status=status, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), user_id=test_user_id, ) await repository.create(task) # Cancel all user tasks cancelled_count = await repository.cancel_user_tasks(test_user_id) assert cancelled_count == 2 # Only pending and running tasks # Verify tasks are cancelled user_tasks = await repository.get_user_tasks(test_user_id) pending_or_running = [ task for task in user_tasks if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING] ] cancelled_tasks = [ task for task in user_tasks if task.status == TaskStatus.CANCELLED ] assert len(pending_or_running) == 0 assert len(cancelled_tasks) == 2 async def test_cancel_user_tasks_by_type( self, repository: ScheduledTaskRepository, test_user_id: int, ): """Test cancelling user tasks by type.""" # Create tasks of different types credit_task = ScheduledTask( name="Credit Task", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), user_id=test_user_id, ) await repository.create(credit_task) sound_task = ScheduledTask( name="Sound Task", task_type=TaskType.PLAY_SOUND, scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1), user_id=test_user_id, ) await repository.create(sound_task) # Cancel only credit tasks cancelled_count = await repository.cancel_user_tasks( test_user_id, TaskType.CREDIT_RECHARGE, ) assert cancelled_count == 1 # Verify only credit task is cancelled user_tasks = await repository.get_user_tasks(test_user_id) credit_tasks = [ task for task in user_tasks if task.task_type == TaskType.CREDIT_RECHARGE ] sound_tasks = [ task for task in user_tasks if task.task_type == TaskType.PLAY_SOUND ] assert len(credit_tasks) == 1 assert credit_tasks[0].status == TaskStatus.CANCELLED assert len(sound_tasks) == 1 assert sound_tasks[0].status == TaskStatus.PENDING async def test_mark_as_running( self, repository: ScheduledTaskRepository, sample_task: ScheduledTask, ): """Test marking task as running.""" await repository.mark_as_running(sample_task) updated_task = await repository.get_by_id(sample_task.id) assert updated_task.status == TaskStatus.RUNNING async def test_mark_as_completed( self, repository: ScheduledTaskRepository, sample_task: ScheduledTask, ): """Test marking task as completed.""" initial_count = sample_task.executions_count next_execution = datetime.now(tz=UTC) + timedelta(days=1) await repository.mark_as_completed(sample_task, next_execution) updated_task = await repository.get_by_id(sample_task.id) assert updated_task.status == TaskStatus.COMPLETED assert updated_task.executions_count == initial_count + 1 assert updated_task.last_executed_at is not None assert updated_task.error_message is None async def test_mark_as_completed_recurring_task( self, repository: ScheduledTaskRepository, ): """Test marking recurring task as completed.""" task = ScheduledTask( name="Recurring Task", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC), recurrence_type=RecurrenceType.DAILY, ) created_task = await repository.create(task) next_execution = datetime.now(tz=UTC).replace(tzinfo=None) + timedelta(days=1) await repository.mark_as_completed(created_task, next_execution) updated_task = await repository.get_by_id(created_task.id) # Should be set back to pending for next execution assert updated_task.status == TaskStatus.PENDING assert updated_task.next_execution_at == next_execution assert updated_task.is_active is True async def test_mark_as_completed_non_recurring_task( self, repository: ScheduledTaskRepository, sample_task: ScheduledTask, ): """Test marking non-recurring task as completed.""" await repository.mark_as_completed(sample_task, None) updated_task = await repository.get_by_id(sample_task.id) assert updated_task.status == TaskStatus.COMPLETED assert updated_task.is_active is False async def test_mark_as_failed( self, repository: ScheduledTaskRepository, sample_task: ScheduledTask, ): """Test marking task as failed.""" error_message = "Task execution failed" await repository.mark_as_failed(sample_task, error_message) updated_task = await repository.get_by_id(sample_task.id) assert updated_task.status == TaskStatus.FAILED assert updated_task.error_message == error_message assert updated_task.last_executed_at is not None async def test_mark_as_failed_recurring_task( self, repository: ScheduledTaskRepository, ): """Test marking recurring task as failed.""" task = ScheduledTask( name="Recurring Task", task_type=TaskType.CREDIT_RECHARGE, scheduled_at=datetime.now(tz=UTC), recurrence_type=RecurrenceType.DAILY, ) created_task = await repository.create(task) await repository.mark_as_failed(created_task, "Failed") updated_task = await repository.get_by_id(created_task.id) assert updated_task.status == TaskStatus.FAILED # Recurring tasks should remain active even after failure assert updated_task.is_active is True async def test_mark_as_failed_non_recurring_task( self, repository: ScheduledTaskRepository, sample_task: ScheduledTask, ): """Test marking non-recurring task as failed.""" await repository.mark_as_failed(sample_task, "Failed") updated_task = await repository.get_by_id(sample_task.id) assert updated_task.status == TaskStatus.FAILED # Non-recurring tasks should be deactivated on failure assert updated_task.is_active is False