Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
@@ -351,11 +351,11 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_id(test_user: User):
|
||||
"""Get test user ID."""
|
||||
"""Get test user ID."""
|
||||
return test_user.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture
|
||||
def test_sound_id():
|
||||
"""Create a test sound ID."""
|
||||
import uuid
|
||||
@@ -364,7 +364,7 @@ def test_sound_id():
|
||||
|
||||
@pytest.fixture
|
||||
def test_playlist_id():
|
||||
"""Create a test playlist ID."""
|
||||
"""Create a test playlist ID."""
|
||||
import uuid
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
@@ -217,4 +215,4 @@ class TestScheduledTaskModel:
|
||||
assert RecurrenceType.WEEKLY == "weekly"
|
||||
assert RecurrenceType.MONTHLY == "monthly"
|
||||
assert RecurrenceType.YEARLY == "yearly"
|
||||
assert RecurrenceType.CRON == "cron"
|
||||
assert RecurrenceType.CRON == "cron"
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -491,4 +490,4 @@ class TestScheduledTaskRepository:
|
||||
updated_task = await repository.get_by_id(sample_task.id)
|
||||
assert updated_task.status == TaskStatus.FAILED
|
||||
# Non-recurring tasks should be deactivated on failure
|
||||
assert updated_task.is_active is False
|
||||
assert updated_task.is_active is False
|
||||
|
||||
@@ -51,7 +51,7 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a scheduled task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
assert task.id is not None
|
||||
@@ -68,7 +68,7 @@ class TestSchedulerService:
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test creating a user task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(
|
||||
user_id=test_user_id,
|
||||
**sample_task_data,
|
||||
@@ -83,7 +83,7 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a system task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
assert task.user_id is None
|
||||
@@ -95,7 +95,7 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a recurring task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
recurrence_count=5,
|
||||
@@ -114,11 +114,11 @@ class TestSchedulerService:
|
||||
"""Test creating task with timezone conversion."""
|
||||
# Use a specific datetime for testing
|
||||
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY
|
||||
|
||||
|
||||
sample_task_data["scheduled_at"] = ny_time
|
||||
sample_task_data["timezone"] = "America/New_York"
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# The scheduled_at should be converted to UTC
|
||||
@@ -134,11 +134,11 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test cancelling a task."""
|
||||
# Create a task first
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock the scheduler remove_job method
|
||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
|
||||
result = await scheduler_service.cancel_task(task.id)
|
||||
|
||||
assert result is True
|
||||
@@ -167,7 +167,7 @@ class TestSchedulerService:
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test getting user tasks."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
# Create user task
|
||||
await scheduler_service.create_task(
|
||||
user_id=test_user_id,
|
||||
@@ -188,12 +188,12 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test ensuring system tasks exist."""
|
||||
# Mock the repository to return no existing tasks
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
|
||||
mock_get.return_value = []
|
||||
|
||||
|
||||
await scheduler_service._ensure_system_tasks()
|
||||
|
||||
|
||||
# Should create daily credit recharge task
|
||||
mock_create.assert_called_once()
|
||||
created_task = mock_create.call_args[0][0]
|
||||
@@ -213,13 +213,13 @@ class TestSchedulerService:
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
|
||||
mock_get.return_value = [existing_task]
|
||||
|
||||
|
||||
await scheduler_service._ensure_system_tasks()
|
||||
|
||||
|
||||
# Should not create new task
|
||||
mock_create.assert_not_called()
|
||||
|
||||
@@ -294,7 +294,7 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test calculating next execution time."""
|
||||
now = datetime.utcnow()
|
||||
|
||||
|
||||
# Test different recurrence types
|
||||
test_cases = [
|
||||
(RecurrenceType.HOURLY, timedelta(hours=1)),
|
||||
@@ -312,7 +312,7 @@ class TestSchedulerService:
|
||||
recurrence_type=recurrence_type,
|
||||
)
|
||||
|
||||
with patch('app.services.scheduler.datetime') as mock_datetime:
|
||||
with patch("app.services.scheduler.datetime") as mock_datetime:
|
||||
mock_datetime.utcnow.return_value = now
|
||||
next_execution = scheduler_service._calculate_next_execution(task)
|
||||
|
||||
@@ -335,7 +335,7 @@ class TestSchedulerService:
|
||||
next_execution = scheduler_service._calculate_next_execution(task)
|
||||
assert next_execution is None
|
||||
|
||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||
@patch("app.services.task_handlers.TaskHandlerRegistry")
|
||||
async def test_execute_task_success(
|
||||
self,
|
||||
mock_handler_class,
|
||||
@@ -344,7 +344,7 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test successful task execution."""
|
||||
# Create task
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock handler registry
|
||||
@@ -365,7 +365,7 @@ class TestSchedulerService:
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
assert updated_task.executions_count == 1
|
||||
|
||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||
@patch("app.services.task_handlers.TaskHandlerRegistry")
|
||||
async def test_execute_task_failure(
|
||||
self,
|
||||
mock_handler_class,
|
||||
@@ -374,7 +374,7 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test task execution failure."""
|
||||
# Create task
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock handler to raise exception
|
||||
@@ -409,8 +409,8 @@ class TestSchedulerService:
|
||||
"""Test executing expired task."""
|
||||
# Create expired task
|
||||
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Execute task
|
||||
@@ -430,20 +430,20 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test prevention of concurrent task execution."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Add task to running set
|
||||
scheduler_service._running_tasks.add(str(task.id))
|
||||
|
||||
# Try to execute - should return without doing anything
|
||||
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class:
|
||||
with patch("app.services.task_handlers.TaskHandlerRegistry") as mock_handler_class:
|
||||
await scheduler_service._execute_task(task.id)
|
||||
|
||||
|
||||
# Handler should not be called
|
||||
mock_handler_class.assert_not_called()
|
||||
|
||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
|
||||
async def test_maintenance_job_expired_tasks(
|
||||
self,
|
||||
mock_repo_class,
|
||||
@@ -453,22 +453,22 @@ class TestSchedulerService:
|
||||
# Mock expired task
|
||||
expired_task = MagicMock()
|
||||
expired_task.id = uuid.uuid4()
|
||||
|
||||
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_expired_tasks.return_value = [expired_task]
|
||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
|
||||
mock_repo_class.return_value = mock_repo
|
||||
|
||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
|
||||
await scheduler_service._maintenance_job()
|
||||
|
||||
|
||||
# Should mark as cancelled and remove from scheduler
|
||||
assert expired_task.status == TaskStatus.CANCELLED
|
||||
assert expired_task.is_active is False
|
||||
mock_repo.update.assert_called_with(expired_task)
|
||||
mock_remove.assert_called_once_with(str(expired_task.id))
|
||||
|
||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
|
||||
async def test_maintenance_job_due_recurring_tasks(
|
||||
self,
|
||||
mock_repo_class,
|
||||
@@ -479,17 +479,17 @@ class TestSchedulerService:
|
||||
due_task = MagicMock()
|
||||
due_task.should_repeat.return_value = True
|
||||
due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5)
|
||||
|
||||
|
||||
mock_repo = AsyncMock()
|
||||
mock_repo.get_expired_tasks.return_value = []
|
||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
|
||||
mock_repo_class.return_value = mock_repo
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
|
||||
await scheduler_service._maintenance_job()
|
||||
|
||||
|
||||
# Should reset to pending and reschedule
|
||||
assert due_task.status == TaskStatus.PENDING
|
||||
assert due_task.scheduled_at == due_task.next_execution_at
|
||||
mock_repo.update.assert_called_with(due_task)
|
||||
mock_schedule.assert_called_once_with(due_task)
|
||||
mock_schedule.assert_called_once_with(due_task)
|
||||
|
||||
@@ -133,8 +133,8 @@ class TestTaskHandlerRegistry:
|
||||
mock_sound.id = test_sound_id
|
||||
mock_sound.filename = "test_sound.mp3"
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
|
||||
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_vlc_class.return_value = mock_vlc_service
|
||||
|
||||
@@ -186,7 +186,7 @@ class TestTaskHandlerRegistry:
|
||||
parameters={"sound_id": str(test_sound_id)},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None):
|
||||
with patch.object(task_registry.sound_repository, "get_by_id", return_value=None):
|
||||
with pytest.raises(TaskExecutionError, match="Sound not found"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
@@ -206,8 +206,8 @@ class TestTaskHandlerRegistry:
|
||||
mock_sound = MagicMock()
|
||||
mock_sound.filename = "test_sound.mp3"
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
|
||||
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_vlc_class.return_value = mock_vlc_service
|
||||
|
||||
@@ -238,7 +238,7 @@ class TestTaskHandlerRegistry:
|
||||
mock_playlist.id = test_playlist_id
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
|
||||
@@ -264,7 +264,7 @@ class TestTaskHandlerRegistry:
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
# Should use default values
|
||||
@@ -314,7 +314,7 @@ class TestTaskHandlerRegistry:
|
||||
parameters={"playlist_id": str(test_playlist_id)},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=None):
|
||||
with pytest.raises(TaskExecutionError, match="Playlist not found"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
@@ -327,7 +327,7 @@ class TestTaskHandlerRegistry:
|
||||
"""Test play playlist task with various valid play modes."""
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
|
||||
valid_modes = ["continuous", "loop", "loop_one", "random", "single"]
|
||||
|
||||
for mode in valid_modes:
|
||||
@@ -341,7 +341,7 @@ class TestTaskHandlerRegistry:
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
mock_player_service.set_mode.assert_called_with(mode)
|
||||
|
||||
@@ -368,7 +368,7 @@ class TestTaskHandlerRegistry:
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
# Should not call set_mode for invalid mode
|
||||
@@ -421,4 +421,4 @@ class TestTaskHandlerRegistry:
|
||||
TaskType.PLAY_SOUND,
|
||||
TaskType.PLAY_PLAYLIST,
|
||||
}
|
||||
assert set(registry._handlers.keys()) == expected_handlers
|
||||
assert set(registry._handlers.keys()) == expected_handlers
|
||||
|
||||
Reference in New Issue
Block a user