Refactor code structure for improved readability and maintainability

This commit is contained in:
JSC
2025-08-29 15:27:12 +02:00
parent dc89e45675
commit 2bdd109492
23 changed files with 652 additions and 719 deletions

View File

@@ -358,15 +358,13 @@ def test_user_id(test_user: User):
@pytest.fixture
def test_sound_id():
"""Create a test sound ID."""
import uuid
return uuid.uuid4()
return 1
@pytest.fixture
def test_playlist_id():
"""Create a test playlist ID."""
import uuid
return uuid.uuid4()
return 1
@pytest.fixture

View File

@@ -20,7 +20,9 @@ class TestSchedulerService:
@pytest.fixture
def scheduler_service(self, mock_db_session_factory):
"""Create a scheduler service instance for testing."""
return SchedulerService(mock_db_session_factory)
from unittest.mock import MagicMock
mock_player_service = MagicMock()
return SchedulerService(mock_db_session_factory, mock_player_service)
@pytest.mark.asyncio
async def test_start_scheduler(self, scheduler_service) -> None:
@@ -31,20 +33,18 @@ class TestSchedulerService:
):
await scheduler_service.start()
# Verify job was added
mock_add_job.assert_called_once_with(
scheduler_service._daily_credit_recharge,
"cron",
hour=0,
minute=0,
id="daily_credit_recharge",
name="Daily Credit Recharge",
replace_existing=True,
)
# Verify scheduler was started
# Verify scheduler start was called
mock_start.assert_called_once()
# Verify jobs were added (2 calls: initialize_system_tasks and scheduler_maintenance)
assert mock_add_job.call_count == 2
# Check that the jobs are the expected ones
calls = mock_add_job.call_args_list
job_ids = [call[1]["id"] for call in calls]
assert "initialize_system_tasks" in job_ids
assert "scheduler_maintenance" in job_ids
@pytest.mark.asyncio
async def test_stop_scheduler(self, scheduler_service) -> None:
"""Test stopping the scheduler service."""
@@ -52,36 +52,3 @@ class TestSchedulerService:
await scheduler_service.stop()
mock_shutdown.assert_called_once()
@pytest.mark.asyncio
async def test_daily_credit_recharge_success(self, scheduler_service) -> None:
"""Test successful daily credit recharge task."""
mock_stats = {
"total_users": 10,
"recharged_users": 8,
"skipped_users": 2,
"total_credits_added": 500,
}
with patch.object(
scheduler_service.credit_service,
"recharge_all_users_credits",
) as mock_recharge:
mock_recharge.return_value = mock_stats
await scheduler_service._daily_credit_recharge()
mock_recharge.assert_called_once()
@pytest.mark.asyncio
async def test_daily_credit_recharge_failure(self, scheduler_service) -> None:
"""Test daily credit recharge task with failure."""
with patch.object(
scheduler_service.credit_service,
"recharge_all_users_credits",
) as mock_recharge:
mock_recharge.side_effect = Exception("Database error")
# Should not raise exception, just log it
await scheduler_service._daily_credit_recharge()
mock_recharge.assert_called_once()

View File

@@ -2,7 +2,7 @@
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@@ -405,8 +405,17 @@ class TestVLCPlayerService:
async def test_record_play_count_success(self, vlc_service_with_db) -> None:
"""Test successful play count recording."""
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
mock_session = MagicMock()
# Make async methods async mocks but keep sync methods as regular mocks
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
mock_session.close = AsyncMock()
# Mock the context manager behavior
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_session)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
vlc_service_with_db.db_session_factory.return_value = mock_context_manager
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
@@ -449,18 +458,18 @@ class TestVLCPlayerService:
# Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1)
mock_sound_repo.update.assert_called_once_with(
test_sound,
{"play_count": 1},
)
# Verify user repository calls
mock_user_repo.get_by_id.assert_called_once_with(1)
# Verify session operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.close.assert_called_once()
# Verify session operations (called twice: once for sound, once for sound_played)
assert mock_session.add.call_count == 2
# Commit is called twice: once after updating sound, once after adding sound_played
assert mock_session.commit.call_count == 2
# Context manager handles session cleanup, so no explicit close() call
# Verify the sound's play count was incremented
assert test_sound.play_count == 1
# Verify socket broadcast
mock_socket.broadcast_to_all.assert_called_once_with(
@@ -488,8 +497,17 @@ class TestVLCPlayerService:
) -> None:
"""Test play count recording always creates a new SoundPlayed record."""
# Mock session and repositories
mock_session = AsyncMock()
vlc_service_with_db.db_session_factory.return_value = mock_session
mock_session = MagicMock()
# Make async methods async mocks but keep sync methods as regular mocks
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
mock_session.close = AsyncMock()
# Mock the context manager behavior
mock_context_manager = AsyncMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_session)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)
vlc_service_with_db.db_session_factory.return_value = mock_context_manager
mock_sound_repo = AsyncMock()
mock_user_repo = AsyncMock()
@@ -530,17 +548,19 @@ class TestVLCPlayerService:
await vlc_service_with_db._record_play_count(1, "Test Sound")
# Verify sound play count was updated
mock_sound_repo.update.assert_called_once_with(
test_sound,
{"play_count": 6},
)
# Verify sound repository calls
mock_sound_repo.get_by_id.assert_called_once_with(1)
# Verify new SoundPlayed record was always added
mock_session.add.assert_called_once()
# Verify user repository calls
mock_user_repo.get_by_id.assert_called_once_with(1)
# Verify commit happened
mock_session.commit.assert_called_once()
# Verify session operations (called twice: once for sound, once for sound_played)
assert mock_session.add.call_count == 2
# Commit is called twice: once after updating sound, once after adding sound_played
assert mock_session.commit.call_count == 2
# Verify the sound's play count was incremented from 5 to 6
assert test_sound.play_count == 6
def test_uses_shared_sound_path_utility(self, vlc_service, sample_sound) -> None:
"""Test that VLC service uses the shared sound path utility."""

View File

@@ -1,7 +1,7 @@
"""Tests for scheduled task model."""
import uuid
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
from app.models.scheduled_task import (
RecurrenceType,
@@ -19,7 +19,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="Test Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
)
assert task.name == "Test Task"
@@ -38,7 +38,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="User Task",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
user_id=user_id,
)
@@ -50,7 +50,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="System Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
)
assert task.user_id is None
@@ -61,7 +61,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="Recurring Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.DAILY,
recurrence_count=5,
)
@@ -74,7 +74,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="One-shot Task",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.NONE,
)
@@ -86,7 +86,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="Infinite Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.DAILY,
recurrence_count=None, # Infinite
)
@@ -103,7 +103,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="Limited Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.DAILY,
recurrence_count=3,
)
@@ -120,21 +120,21 @@ class TestScheduledTaskModel:
def test_task_expiration(self):
"""Test task expiration."""
# Non-expired task
# Non-expired task (using naive UTC datetimes)
task = ScheduledTask(
name="Valid Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
expires_at=datetime.utcnow() + timedelta(hours=2),
scheduled_at=datetime.now(tz=UTC).replace(tzinfo=None) + timedelta(hours=1),
expires_at=datetime.now(tz=UTC).replace(tzinfo=None) + timedelta(hours=2),
)
assert not task.is_expired()
# Expired task
# Expired task (using naive UTC datetimes)
expired_task = ScheduledTask(
name="Expired Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
expires_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC).replace(tzinfo=None) + timedelta(hours=1),
expires_at=datetime.now(tz=UTC).replace(tzinfo=None) - timedelta(hours=1),
)
assert expired_task.is_expired()
@@ -142,7 +142,7 @@ class TestScheduledTaskModel:
no_expiry_task = ScheduledTask(
name="No Expiry Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
)
assert not no_expiry_task.is_expired()
@@ -157,7 +157,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="Parametrized Task",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
parameters=parameters,
)
@@ -171,7 +171,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="NY Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
timezone="America/New_York",
)
@@ -184,7 +184,7 @@ class TestScheduledTaskModel:
task = ScheduledTask(
name="Cron Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.CRON,
cron_expression=cron_expr,
)

View File

@@ -1,7 +1,6 @@
"""Tests for scheduled task repository."""
import uuid
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -29,42 +28,42 @@ class TestScheduledTaskRepository:
repository: ScheduledTaskRepository,
) -> ScheduledTask:
"""Create a sample scheduled task."""
task = ScheduledTask(
name="Test Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
parameters={"test": "value"},
)
return await repository.create(task)
task_data = {
"name": "Test Task",
"task_type": TaskType.CREDIT_RECHARGE,
"scheduled_at": datetime.now(tz=UTC) + timedelta(hours=1),
"parameters": {"test": "value"},
}
return await repository.create(task_data)
@pytest.fixture
async def user_task(
self,
repository: ScheduledTaskRepository,
test_user_id: uuid.UUID,
test_user_id: int,
) -> ScheduledTask:
"""Create a user task."""
task = ScheduledTask(
name="User Task",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow() + timedelta(hours=2),
user_id=test_user_id,
parameters={"sound_id": str(uuid.uuid4())},
)
return await repository.create(task)
task_data = {
"name": "User Task",
"task_type": TaskType.PLAY_SOUND,
"scheduled_at": datetime.now(tz=UTC) + timedelta(hours=2),
"user_id": test_user_id,
"parameters": {"sound_id": "1"},
}
return await repository.create(task_data)
async def test_create_task(self, repository: ScheduledTaskRepository):
"""Test creating a scheduled task."""
task = ScheduledTask(
name="Test Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
timezone="America/New_York",
recurrence_type=RecurrenceType.DAILY,
parameters={"test": "value"},
)
task_data = {
"name": "Test Task",
"task_type": TaskType.CREDIT_RECHARGE,
"scheduled_at": datetime.now(tz=UTC) + timedelta(hours=1),
"timezone": "America/New_York",
"recurrence_type": RecurrenceType.DAILY,
"parameters": {"test": "value"},
}
created_task = await repository.create(task)
created_task = await repository.create(task_data)
assert created_task.id is not None
assert created_task.name == "Test Task"
@@ -85,7 +84,7 @@ class TestScheduledTaskRepository:
past_pending = ScheduledTask(
name="Past Pending",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1),
status=TaskStatus.PENDING,
)
await repository.create(past_pending)
@@ -93,7 +92,7 @@ class TestScheduledTaskRepository:
future_pending = ScheduledTask(
name="Future Pending",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
status=TaskStatus.PENDING,
)
await repository.create(future_pending)
@@ -101,7 +100,7 @@ class TestScheduledTaskRepository:
completed_task = ScheduledTask(
name="Completed",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1),
status=TaskStatus.COMPLETED,
)
await repository.create(completed_task)
@@ -109,7 +108,7 @@ class TestScheduledTaskRepository:
inactive_task = ScheduledTask(
name="Inactive",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1),
status=TaskStatus.PENDING,
is_active=False,
)
@@ -126,15 +125,15 @@ class TestScheduledTaskRepository:
self,
repository: ScheduledTaskRepository,
user_task: ScheduledTask,
test_user_id: uuid.UUID,
test_user_id: int,
):
"""Test getting tasks for a specific user."""
# Create another user's task
other_user_id = uuid.uuid4()
other_user_id = 999
other_task = ScheduledTask(
name="Other User Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
user_id=other_user_id,
)
await repository.create(other_task)
@@ -143,7 +142,7 @@ class TestScheduledTaskRepository:
system_task = ScheduledTask(
name="System Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
)
await repository.create(system_task)
@@ -156,7 +155,7 @@ class TestScheduledTaskRepository:
async def test_get_user_tasks_with_filters(
self,
repository: ScheduledTaskRepository,
test_user_id: uuid.UUID,
test_user_id: int,
):
"""Test getting user tasks with status and type filters."""
# Create tasks with different statuses and types
@@ -172,7 +171,7 @@ class TestScheduledTaskRepository:
name=name,
task_type=task_type,
status=status,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
user_id=test_user_id,
)
await repository.create(task)
@@ -224,10 +223,10 @@ class TestScheduledTaskRepository:
due_task = ScheduledTask(
name="Due Recurring",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1),
recurrence_type=RecurrenceType.DAILY,
status=TaskStatus.COMPLETED,
next_execution_at=datetime.utcnow() - timedelta(minutes=5),
next_execution_at=datetime.now(tz=UTC) - timedelta(minutes=5),
)
await repository.create(due_task)
@@ -235,10 +234,10 @@ class TestScheduledTaskRepository:
not_due_task = ScheduledTask(
name="Not Due Recurring",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1),
recurrence_type=RecurrenceType.DAILY,
status=TaskStatus.COMPLETED,
next_execution_at=datetime.utcnow() + timedelta(hours=1),
next_execution_at=datetime.now(tz=UTC) + timedelta(hours=1),
)
await repository.create(not_due_task)
@@ -246,7 +245,7 @@ class TestScheduledTaskRepository:
non_recurring = ScheduledTask(
name="Non-recurring",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) - timedelta(hours=1),
recurrence_type=RecurrenceType.NONE,
status=TaskStatus.COMPLETED,
)
@@ -266,8 +265,8 @@ class TestScheduledTaskRepository:
expired_task = ScheduledTask(
name="Expired Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
expires_at=datetime.utcnow() - timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
expires_at=datetime.now(tz=UTC) - timedelta(hours=1),
)
await repository.create(expired_task)
@@ -275,8 +274,8 @@ class TestScheduledTaskRepository:
valid_task = ScheduledTask(
name="Valid Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
expires_at=datetime.utcnow() + timedelta(hours=2),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
expires_at=datetime.now(tz=UTC) + timedelta(hours=2),
)
await repository.create(valid_task)
@@ -284,7 +283,7 @@ class TestScheduledTaskRepository:
no_expiry_task = ScheduledTask(
name="No Expiry",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
)
await repository.create(no_expiry_task)
@@ -296,7 +295,7 @@ class TestScheduledTaskRepository:
async def test_cancel_user_tasks(
self,
repository: ScheduledTaskRepository,
test_user_id: uuid.UUID,
test_user_id: int,
):
"""Test cancelling user tasks."""
# Create multiple user tasks
@@ -311,7 +310,7 @@ class TestScheduledTaskRepository:
name=name,
task_type=task_type,
status=status,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
user_id=test_user_id,
)
await repository.create(task)
@@ -338,14 +337,14 @@ class TestScheduledTaskRepository:
async def test_cancel_user_tasks_by_type(
self,
repository: ScheduledTaskRepository,
test_user_id: uuid.UUID,
test_user_id: int,
):
"""Test cancelling user tasks by type."""
# Create tasks of different types
credit_task = ScheduledTask(
name="Credit Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
user_id=test_user_id,
)
await repository.create(credit_task)
@@ -353,7 +352,7 @@ class TestScheduledTaskRepository:
sound_task = ScheduledTask(
name="Sound Task",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
user_id=test_user_id,
)
await repository.create(sound_task)
@@ -400,7 +399,7 @@ class TestScheduledTaskRepository:
):
"""Test marking task as completed."""
initial_count = sample_task.executions_count
next_execution = datetime.utcnow() + timedelta(days=1)
next_execution = datetime.now(tz=UTC) + timedelta(days=1)
await repository.mark_as_completed(sample_task, next_execution)
@@ -418,12 +417,12 @@ class TestScheduledTaskRepository:
task = ScheduledTask(
name="Recurring Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
recurrence_type=RecurrenceType.DAILY,
)
created_task = await repository.create(task)
next_execution = datetime.utcnow() + timedelta(days=1)
next_execution = datetime.now(tz=UTC).replace(tzinfo=None) + timedelta(days=1)
await repository.mark_as_completed(created_task, next_execution)
updated_task = await repository.get_by_id(created_task.id)
@@ -467,7 +466,7 @@ class TestScheduledTaskRepository:
task = ScheduledTask(
name="Recurring Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
recurrence_type=RecurrenceType.DAILY,
)
created_task = await repository.create(task)

View File

@@ -1,7 +1,7 @@
"""Tests for scheduler service."""
import uuid
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -13,6 +13,7 @@ from app.models.scheduled_task import (
TaskStatus,
TaskType,
)
from app.schemas.scheduler import ScheduledTaskCreate
from app.services.scheduler import SchedulerService
@@ -31,7 +32,8 @@ class TestSchedulerService:
mock_player_service,
) -> SchedulerService:
"""Create scheduler service fixture."""
session_factory = lambda: db_session
def session_factory():
return db_session
return SchedulerService(session_factory, mock_player_service)
@pytest.fixture
@@ -40,11 +42,16 @@ class TestSchedulerService:
return {
"name": "Test Task",
"task_type": TaskType.CREDIT_RECHARGE,
"scheduled_at": datetime.utcnow() + timedelta(hours=1),
"scheduled_at": datetime.now(tz=UTC) + timedelta(hours=1),
"parameters": {"test": "value"},
"timezone": "UTC",
}
def _create_task_schema(self, task_data: dict, **overrides) -> ScheduledTaskCreate:
"""Create ScheduledTaskCreate schema from dict."""
data = {**task_data, **overrides}
return ScheduledTaskCreate(**data)
async def test_create_task(
self,
scheduler_service: SchedulerService,
@@ -52,7 +59,8 @@ class TestSchedulerService:
):
"""Test creating a scheduled task."""
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
task = await scheduler_service.create_task(**sample_task_data)
schema = self._create_task_schema(sample_task_data)
task = await scheduler_service.create_task(task_data=schema)
assert task.id is not None
assert task.name == sample_task_data["name"]
@@ -69,9 +77,10 @@ class TestSchedulerService:
):
"""Test creating a user task."""
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
schema = self._create_task_schema(sample_task_data)
task = await scheduler_service.create_task(
task_data=schema,
user_id=test_user_id,
**sample_task_data,
)
assert task.user_id == test_user_id
@@ -84,7 +93,8 @@ class TestSchedulerService:
):
"""Test creating a system task."""
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
schema = self._create_task_schema(sample_task_data)
task = await scheduler_service.create_task(task_data=schema)
assert task.user_id is None
assert task.is_system_task()
@@ -96,11 +106,12 @@ class TestSchedulerService:
):
"""Test creating a recurring task."""
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(
schema = self._create_task_schema(
sample_task_data,
recurrence_type=RecurrenceType.DAILY,
recurrence_count=5,
**sample_task_data,
)
task = await scheduler_service.create_task(task_data=schema)
assert task.recurrence_type == RecurrenceType.DAILY
assert task.recurrence_count == 5
@@ -113,13 +124,15 @@ class TestSchedulerService:
):
"""Test creating task with timezone conversion."""
# Use a specific datetime for testing
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY
sample_task_data["scheduled_at"] = ny_time
sample_task_data["timezone"] = "America/New_York"
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY # noqa: DTZ001
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
schema = self._create_task_schema(
sample_task_data,
scheduled_at=ny_time,
timezone="America/New_York",
)
task = await scheduler_service.create_task(task_data=schema)
# The scheduled_at should be converted to UTC
assert task.timezone == "America/New_York"
@@ -135,7 +148,8 @@ class TestSchedulerService:
"""Test cancelling a task."""
# Create a task first
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
schema = self._create_task_schema(sample_task_data)
task = await scheduler_service.create_task(task_data=schema)
# Mock the scheduler remove_job method
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
@@ -169,13 +183,15 @@ class TestSchedulerService:
"""Test getting user tasks."""
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
# Create user task
schema = self._create_task_schema(sample_task_data)
await scheduler_service.create_task(
task_data=schema,
user_id=test_user_id,
**sample_task_data,
)
# Create system task
await scheduler_service.create_task(**sample_task_data)
system_schema = self._create_task_schema(sample_task_data)
await scheduler_service.create_task(task_data=system_schema)
user_tasks = await scheduler_service.get_user_tasks(test_user_id)
@@ -196,10 +212,10 @@ class TestSchedulerService:
# Should create daily credit recharge task
mock_create.assert_called_once()
created_task = mock_create.call_args[0][0]
assert created_task.name == "Daily Credit Recharge"
assert created_task.task_type == TaskType.CREDIT_RECHARGE
assert created_task.recurrence_type == RecurrenceType.DAILY
created_task_data = mock_create.call_args[0][0]
assert created_task_data["name"] == "Daily Credit Recharge"
assert created_task_data["task_type"] == TaskType.CREDIT_RECHARGE
assert created_task_data["recurrence_type"] == RecurrenceType.DAILY
async def test_ensure_system_tasks_already_exist(
self,
@@ -209,7 +225,7 @@ class TestSchedulerService:
existing_task = ScheduledTask(
name="Existing Daily Credit Recharge",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
recurrence_type=RecurrenceType.DAILY,
is_active=True,
)
@@ -231,7 +247,7 @@ class TestSchedulerService:
task = ScheduledTask(
name="One Shot",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.NONE,
)
@@ -247,7 +263,7 @@ class TestSchedulerService:
task = ScheduledTask(
name="Daily",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.DAILY,
)
@@ -263,7 +279,7 @@ class TestSchedulerService:
task = ScheduledTask(
name="Cron",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow() + timedelta(hours=1),
scheduled_at=datetime.now(tz=UTC) + timedelta(hours=1),
recurrence_type=RecurrenceType.CRON,
cron_expression="0 9 * * *", # 9 AM daily
)
@@ -280,7 +296,7 @@ class TestSchedulerService:
task = ScheduledTask(
name="Monthly",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime(2024, 1, 15, 10, 30, 0), # 15th at 10:30 AM
scheduled_at=datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC), # 15th at 10:30 AM
recurrence_type=RecurrenceType.MONTHLY,
)
@@ -293,7 +309,7 @@ class TestSchedulerService:
scheduler_service: SchedulerService,
):
"""Test calculating next execution time."""
now = datetime.utcnow()
now = datetime.now(tz=UTC)
# Test different recurrence types
test_cases = [
@@ -313,7 +329,7 @@ class TestSchedulerService:
)
with patch("app.services.scheduler.datetime") as mock_datetime:
mock_datetime.utcnow.return_value = now
mock_datetime.now.return_value = now
next_execution = scheduler_service._calculate_next_execution(task)
assert next_execution is not None
@@ -328,14 +344,14 @@ class TestSchedulerService:
task = ScheduledTask(
name="One Shot",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
recurrence_type=RecurrenceType.NONE,
)
next_execution = scheduler_service._calculate_next_execution(task)
assert next_execution is None
@patch("app.services.task_handlers.TaskHandlerRegistry")
@patch("app.services.scheduler.TaskHandlerRegistry")
async def test_execute_task_success(
self,
mock_handler_class,
@@ -343,9 +359,11 @@ class TestSchedulerService:
sample_task_data: dict,
):
"""Test successful task execution."""
# Create task
# Create task ready for immediate execution
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
ready_data = {**sample_task_data, "scheduled_at": datetime.now(tz=UTC) - timedelta(minutes=1)}
schema = self._create_task_schema(ready_data)
task = await scheduler_service.create_task(task_data=schema)
# Mock handler registry
mock_handler = AsyncMock()
@@ -365,7 +383,7 @@ class TestSchedulerService:
assert updated_task.status == TaskStatus.COMPLETED
assert updated_task.executions_count == 1
@patch("app.services.task_handlers.TaskHandlerRegistry")
@patch("app.services.scheduler.TaskHandlerRegistry")
async def test_execute_task_failure(
self,
mock_handler_class,
@@ -373,9 +391,11 @@ class TestSchedulerService:
sample_task_data: dict,
):
"""Test task execution failure."""
# Create task
# Create task ready for immediate execution
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
ready_data = {**sample_task_data, "scheduled_at": datetime.now(tz=UTC) - timedelta(minutes=1)}
schema = self._create_task_schema(ready_data)
task = await scheduler_service.create_task(task_data=schema)
# Mock handler to raise exception
mock_handler = AsyncMock()
@@ -407,11 +427,12 @@ class TestSchedulerService:
sample_task_data: dict,
):
"""Test executing expired task."""
# Create expired task
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
# Create expired task (stored as naive UTC datetime)
expires_at = datetime.now(tz=UTC).replace(tzinfo=None) - timedelta(hours=1)
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
schema = self._create_task_schema(sample_task_data, expires_at=expires_at)
task = await scheduler_service.create_task(task_data=schema)
# Execute task
await scheduler_service._execute_task(task.id)
@@ -431,7 +452,8 @@ class TestSchedulerService:
):
"""Test prevention of concurrent task execution."""
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
schema = self._create_task_schema(sample_task_data)
task = await scheduler_service.create_task(task_data=schema)
# Add task to running set
scheduler_service._running_tasks.add(str(task.id))
@@ -443,7 +465,7 @@ class TestSchedulerService:
# Handler should not be called
mock_handler_class.assert_not_called()
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
@patch("app.services.scheduler.ScheduledTaskRepository")
async def test_maintenance_job_expired_tasks(
self,
mock_repo_class,
@@ -463,12 +485,13 @@ class TestSchedulerService:
await scheduler_service._maintenance_job()
# Should mark as cancelled and remove from scheduler
assert expired_task.status == TaskStatus.CANCELLED
assert expired_task.is_active is False
mock_repo.update.assert_called_with(expired_task)
mock_repo.update.assert_called_with(expired_task, {
"status": TaskStatus.CANCELLED,
"is_active": False,
})
mock_remove.assert_called_once_with(str(expired_task.id))
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
@patch("app.services.scheduler.ScheduledTaskRepository")
async def test_maintenance_job_due_recurring_tasks(
self,
mock_repo_class,
@@ -478,7 +501,7 @@ class TestSchedulerService:
# Mock due recurring task
due_task = MagicMock()
due_task.should_repeat.return_value = True
due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5)
due_task.next_execution_at = datetime.now(tz=UTC) - timedelta(minutes=5)
mock_repo = AsyncMock()
mock_repo.get_expired_tasks.return_value = []
@@ -489,7 +512,8 @@ class TestSchedulerService:
await scheduler_service._maintenance_job()
# Should reset to pending and reschedule
assert due_task.status == TaskStatus.PENDING
assert due_task.scheduled_at == due_task.next_execution_at
mock_repo.update.assert_called_with(due_task)
mock_repo.update.assert_called_with(due_task, {
"status": TaskStatus.PENDING,
"scheduled_at": due_task.next_execution_at,
})
mock_schedule.assert_called_once_with(due_task)

View File

@@ -1,6 +1,7 @@
"""Tests for task handlers."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -21,7 +22,7 @@ class TestTaskHandlerRegistry:
@pytest.fixture
def mock_player_service(self):
"""Create mock player service."""
return MagicMock()
return AsyncMock()
@pytest.fixture
def task_registry(
@@ -31,8 +32,11 @@ class TestTaskHandlerRegistry:
mock_player_service,
) -> TaskHandlerRegistry:
"""Create task handler registry fixture."""
def mock_db_session_factory():
return db_session
return TaskHandlerRegistry(
db_session,
mock_db_session_factory,
mock_credit_service,
mock_player_service,
)
@@ -46,7 +50,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Unknown Task",
task_type="UNKNOWN_TYPE", # Invalid type
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
)
with pytest.raises(TaskExecutionError, match="No handler registered"):
@@ -61,7 +65,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Daily Credit Recharge",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={},
)
@@ -84,7 +88,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="User Credit Recharge",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"user_id": str(test_user_id)},
)
@@ -107,7 +111,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="User Credit Recharge",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"user_id": test_user_id}, # UUID object instead of string
)
@@ -118,13 +122,13 @@ class TestTaskHandlerRegistry:
async def test_handle_play_sound_success(
self,
task_registry: TaskHandlerRegistry,
test_sound_id: uuid.UUID,
test_sound_id: int,
):
"""Test successful play sound task handling."""
task = ScheduledTask(
name="Play Sound",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"sound_id": str(test_sound_id)},
)
@@ -134,8 +138,9 @@ class TestTaskHandlerRegistry:
mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
with patch("app.services.task_handlers.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock()
mock_vlc_service.play_sound.return_value = True
mock_vlc_class.return_value = mock_vlc_service
await task_registry.execute_task(task)
@@ -151,7 +156,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Play Sound",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={}, # Missing sound_id
)
@@ -166,7 +171,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Play Sound",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"sound_id": "invalid-uuid"},
)
@@ -176,13 +181,13 @@ class TestTaskHandlerRegistry:
async def test_handle_play_sound_not_found(
self,
task_registry: TaskHandlerRegistry,
test_sound_id: uuid.UUID,
test_sound_id: int,
):
"""Test play sound task with non-existent sound."""
task = ScheduledTask(
name="Play Sound",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"sound_id": str(test_sound_id)},
)
@@ -193,13 +198,13 @@ class TestTaskHandlerRegistry:
async def test_handle_play_sound_uuid_parameter(
self,
task_registry: TaskHandlerRegistry,
test_sound_id: uuid.UUID,
test_sound_id: int,
):
"""Test play sound task with UUID parameter (not string)."""
task = ScheduledTask(
name="Play Sound",
task_type=TaskType.PLAY_SOUND,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"sound_id": test_sound_id}, # UUID object
)
@@ -207,7 +212,7 @@ class TestTaskHandlerRegistry:
mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
with patch("app.services.task_handlers.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock()
mock_vlc_class.return_value = mock_vlc_service
@@ -219,13 +224,13 @@ class TestTaskHandlerRegistry:
self,
task_registry: TaskHandlerRegistry,
mock_player_service,
test_playlist_id: uuid.UUID,
test_playlist_id: int,
):
"""Test successful play playlist task handling."""
task = ScheduledTask(
name="Play Playlist",
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={
"playlist_id": str(test_playlist_id),
"play_mode": "loop",
@@ -244,20 +249,20 @@ class TestTaskHandlerRegistry:
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
mock_player_service.load_playlist.assert_called_once_with(test_playlist_id)
mock_player_service.set_mode.assert_called_once_with("loop")
mock_player_service.set_shuffle.assert_called_once_with(True)
mock_player_service.set_shuffle.assert_called_once_with(shuffle=True)
mock_player_service.play.assert_called_once()
async def test_handle_play_playlist_minimal_parameters(
self,
task_registry: TaskHandlerRegistry,
mock_player_service,
test_playlist_id: uuid.UUID,
test_playlist_id: int,
):
"""Test play playlist task with minimal parameters."""
task = ScheduledTask(
name="Play Playlist",
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"playlist_id": str(test_playlist_id)},
)
@@ -269,7 +274,7 @@ class TestTaskHandlerRegistry:
# Should use default values
mock_player_service.set_mode.assert_called_once_with("continuous")
mock_player_service.set_shuffle.assert_called_once_with(False)
mock_player_service.set_shuffle.assert_not_called()
async def test_handle_play_playlist_missing_playlist_id(
self,
@@ -279,7 +284,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Play Playlist",
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={}, # Missing playlist_id
)
@@ -294,7 +299,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Play Playlist",
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"playlist_id": "invalid-uuid"},
)
@@ -304,13 +309,13 @@ class TestTaskHandlerRegistry:
async def test_handle_play_playlist_not_found(
self,
task_registry: TaskHandlerRegistry,
test_playlist_id: uuid.UUID,
test_playlist_id: int,
):
"""Test play playlist task with non-existent playlist."""
task = ScheduledTask(
name="Play Playlist",
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={"playlist_id": str(test_playlist_id)},
)
@@ -322,7 +327,7 @@ class TestTaskHandlerRegistry:
self,
task_registry: TaskHandlerRegistry,
mock_player_service,
test_playlist_id: uuid.UUID,
test_playlist_id: int,
):
"""Test play playlist task with various valid play modes."""
mock_playlist = MagicMock()
@@ -334,7 +339,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Play Playlist",
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={
"playlist_id": str(test_playlist_id),
"play_mode": mode,
@@ -352,13 +357,13 @@ class TestTaskHandlerRegistry:
self,
task_registry: TaskHandlerRegistry,
mock_player_service,
test_playlist_id: uuid.UUID,
test_playlist_id: int,
):
"""Test play playlist task with invalid play mode."""
task = ScheduledTask(
name="Play Playlist",
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={
"playlist_id": str(test_playlist_id),
"play_mode": "invalid_mode",
@@ -386,7 +391,7 @@ class TestTaskHandlerRegistry:
task = ScheduledTask(
name="Failing Task",
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=datetime.utcnow(),
scheduled_at=datetime.now(tz=UTC),
parameters={},
)
@@ -403,13 +408,17 @@ class TestTaskHandlerRegistry:
mock_player_service,
):
"""Test task registry initialization."""
def mock_db_session_factory():
return db_session
registry = TaskHandlerRegistry(
db_session,
mock_db_session_factory,
mock_credit_service,
mock_player_service,
)
assert registry.db_session == db_session
assert registry.db_session_factory == mock_db_session_factory
assert registry.credit_service == mock_credit_service
assert registry.player_service == mock_player_service
assert registry.sound_repository is not None