Refactor scheduled task repository and schemas for improved type hints and consistency

- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase.
- Refactored import statements for better organization and clarity.
- Enhanced the ScheduledTaskBase schema to use modern type hints.
- Cleaned up unnecessary comments and whitespace in various files.
- Improved error handling and logging in task execution handlers.
- Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
JSC
2025-08-28 23:38:47 +02:00
parent 96801dc4d6
commit dc89e45675
19 changed files with 292 additions and 291 deletions

View File

@@ -1,7 +1,5 @@
"""API endpoints for scheduled task management."""
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -11,7 +9,7 @@ from app.core.dependencies import (
get_admin_user,
get_current_active_user,
)
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
from app.models.user import User
from app.schemas.scheduler import (
ScheduledTaskCreate,
@@ -54,15 +52,15 @@ async def create_task(
raise HTTPException(status_code=400, detail=str(e))
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
async def get_user_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: int | None = Query(50, description="Maximum number of tasks to return"),
offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_current_active_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
) -> List[ScheduledTask]:
) -> list[ScheduledTask]:
"""Get user's scheduled tasks."""
return await scheduler_service.get_user_tasks(
user_id=current_user.id,
@@ -152,15 +150,15 @@ async def cancel_task(
# Admin-only endpoints
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
async def get_all_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: int | None = Query(100, description="Maximum number of tasks to return"),
offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]:
) -> list[ScheduledTask]:
"""Get all scheduled tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository
@@ -189,13 +187,13 @@ async def get_all_tasks(
return list(result.all())
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
async def get_system_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: TaskType | None = Query(None, description="Filter by task type"),
current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]:
) -> list[ScheduledTask]:
"""Get system tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository

View File

@@ -4,11 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
from app.core.config import settings
from app.core.logging import get_logger
from app.core.seeds import seed_all_data
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
engine: AsyncEngine = create_async_engine(
settings.DATABASE_URL,

View File

@@ -11,11 +11,14 @@ from app.core.database import get_session_factory, init_db
from app.core.logging import get_logger, setup_logging
from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service
from app.services.player import (
get_player_service,
initialize_player_service,
shutdown_player_service,
)
from app.services.scheduler import SchedulerService
from app.services.socket import socket_manager
scheduler_service = None

View File

@@ -1,11 +1,10 @@
"""Scheduled task model for flexible task scheduling with timezone support."""
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Optional
from typing import Any
from sqlmodel import JSON, Column, Field, SQLModel
from sqlmodel import JSON, Column, Field
from app.models.base import BaseModel
@@ -57,11 +56,11 @@ class ScheduledTask(BaseModel, table=True):
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
)
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
cron_expression: Optional[str] = Field(
cron_expression: str | None = Field(
default=None,
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
)
recurrence_count: Optional[int] = Field(
recurrence_count: int | None = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
@@ -75,29 +74,29 @@ class ScheduledTask(BaseModel, table=True):
)
# User association (None for system tasks)
user_id: Optional[int] = Field(
user_id: int | None = Field(
default=None,
foreign_key="user.id",
description="User who created the task (None for system tasks)",
)
# Execution tracking
last_executed_at: Optional[datetime] = Field(
last_executed_at: datetime | None = Field(
default=None,
description="When the task was last executed (UTC)",
)
next_execution_at: Optional[datetime] = Field(
next_execution_at: datetime | None = Field(
default=None,
description="When the task should be executed next (UTC, for recurring tasks)",
)
error_message: Optional[str] = Field(
error_message: str | None = Field(
default=None,
description="Error message if execution failed",
)
# Task lifecycle
is_active: bool = Field(default=True, description="Whether the task is active")
expires_at: Optional[datetime] = Field(
expires_at: datetime | None = Field(
default=None,
description="When the task expires (UTC, optional)",
)

View File

@@ -1,12 +1,16 @@
"""Repository for scheduled task operations."""
from datetime import datetime
from typing import List, Optional
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.base import BaseRepository
@@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
"""Initialize the repository."""
super().__init__(ScheduledTask, session)
async def get_pending_tasks(self) -> List[ScheduledTask]:
async def get_pending_tasks(self) -> list[ScheduledTask]:
"""Get all pending tasks that are ready to be executed."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
@@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement)
return list(result.all())
async def get_active_tasks(self) -> List[ScheduledTask]:
async def get_active_tasks(self) -> list[ScheduledTask]:
"""Get all active tasks."""
statement = select(ScheduledTask).where(
ScheduledTask.is_active.is_(True),
@@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
status: TaskStatus | None = None,
task_type: TaskType | None = None,
limit: int | None = None,
offset: int | None = None,
) -> list[ScheduledTask]:
"""Get tasks for a specific user."""
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
@@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_system_tasks(
self,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
) -> List[ScheduledTask]:
status: TaskStatus | None = None,
task_type: TaskType | None = None,
) -> list[ScheduledTask]:
"""Get system tasks (tasks with no user association)."""
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
@@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement)
return list(result.all())
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
"""Get recurring tasks that need their next execution scheduled."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
@@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement)
return list(result.all())
async def get_expired_tasks(self) -> List[ScheduledTask]:
async def get_expired_tasks(self) -> list[ScheduledTask]:
"""Get expired tasks that should be cleaned up."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
@@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def cancel_user_tasks(
self,
user_id: int,
task_type: Optional[TaskType] = None,
task_type: TaskType | None = None,
) -> int:
"""Cancel all pending/running tasks for a user."""
statement = select(ScheduledTask).where(
@@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def mark_as_completed(
self,
task: ScheduledTask,
next_execution_at: Optional[datetime] = None,
next_execution_at: datetime | None = None,
) -> None:
"""Mark a task as completed and set next execution if recurring."""
task.status = TaskStatus.COMPLETED

View File

@@ -1,7 +1,7 @@
"""Schemas for scheduled task API."""
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@@ -15,7 +15,7 @@ class ScheduledTaskBase(BaseModel):
task_type: TaskType = Field(description="Type of task to execute")
scheduled_at: datetime = Field(description="When the task should be executed")
timezone: str = Field(default="UTC", description="Timezone for scheduling")
parameters: Dict[str, Any] = Field(
parameters: dict[str, Any] = Field(
default_factory=dict,
description="Task-specific parameters",
)
@@ -23,15 +23,15 @@ class ScheduledTaskBase(BaseModel):
default=RecurrenceType.NONE,
description="Recurrence pattern",
)
cron_expression: Optional[str] = Field(
cron_expression: str | None = Field(
default=None,
description="Cron expression for custom recurrence",
)
recurrence_count: Optional[int] = Field(
recurrence_count: int | None = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
expires_at: Optional[datetime] = Field(
expires_at: datetime | None = Field(
default=None,
description="When the task expires (optional)",
)
@@ -40,18 +40,17 @@ class ScheduledTaskBase(BaseModel):
class ScheduledTaskCreate(ScheduledTaskBase):
"""Schema for creating a scheduled task."""
pass
class ScheduledTaskUpdate(BaseModel):
"""Schema for updating a scheduled task."""
name: Optional[str] = None
scheduled_at: Optional[datetime] = None
timezone: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = None
expires_at: Optional[datetime] = None
name: str | None = None
scheduled_at: datetime | None = None
timezone: str | None = None
parameters: dict[str, Any] | None = None
is_active: bool | None = None
expires_at: datetime | None = None
class ScheduledTaskResponse(ScheduledTaskBase):
@@ -59,11 +58,11 @@ class ScheduledTaskResponse(ScheduledTaskBase):
id: int
status: TaskStatus
user_id: Optional[int] = None
user_id: int | None = None
executions_count: int
last_executed_at: Optional[datetime] = None
next_execution_at: Optional[datetime] = None
error_message: Optional[str] = None
last_executed_at: datetime | None = None
next_execution_at: datetime | None = None
error_message: str | None = None
is_active: bool
created_at: datetime
updated_at: datetime
@@ -78,7 +77,7 @@ class ScheduledTaskResponse(ScheduledTaskBase):
class CreditRechargeParameters(BaseModel):
"""Parameters for credit recharge tasks."""
user_id: Optional[int] = Field(
user_id: int | None = Field(
default=None,
description="Specific user ID to recharge (None for all users)",
)
@@ -109,10 +108,10 @@ class CreateCreditRechargeTask(BaseModel):
scheduled_at: datetime
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
user_id: Optional[int] = None
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
user_id: int | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
@@ -137,9 +136,9 @@ class CreatePlaySoundTask(BaseModel):
sound_id: int
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
@@ -166,9 +165,9 @@ class CreatePlayPlaylistTask(BaseModel):
shuffle: bool = False
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""

View File

@@ -2,7 +2,7 @@
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import Any
import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -86,13 +86,13 @@ class SchedulerService:
name: str,
task_type: TaskType,
scheduled_at: datetime,
parameters: Optional[Dict[str, Any]] = None,
user_id: Optional[int] = None,
parameters: dict[str, Any] | None = None,
user_id: int | None = None,
timezone: str = "UTC",
recurrence_type: RecurrenceType = RecurrenceType.NONE,
cron_expression: Optional[str] = None,
recurrence_count: Optional[int] = None,
expires_at: Optional[datetime] = None,
cron_expression: str | None = None,
recurrence_count: int | None = None,
expires_at: datetime | None = None,
) -> ScheduledTask:
"""Create a new scheduled task."""
async with self.db_session_factory() as session:
@@ -150,11 +150,11 @@ class SchedulerService:
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
status: TaskStatus | None = None,
task_type: TaskType | None = None,
limit: int | None = None,
offset: int | None = None,
) -> list[ScheduledTask]:
"""Get tasks for a specific user."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
@@ -182,7 +182,7 @@ class SchedulerService:
# Check if daily credit recharge task exists
system_tasks = await repo.get_system_tasks(
task_type=TaskType.CREDIT_RECHARGE
task_type=TaskType.CREDIT_RECHARGE,
)
daily_recharge_exists = any(
@@ -194,7 +194,7 @@ class SchedulerService:
if not daily_recharge_exists:
# Create daily credit recharge task
tomorrow_midnight = datetime.utcnow().replace(
hour=0, minute=0, second=0, microsecond=0
hour=0, minute=0, second=0, microsecond=0,
) + timedelta(days=1)
task_data = {
@@ -258,36 +258,36 @@ class SchedulerService:
if task.recurrence_type == RecurrenceType.NONE:
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
elif task.recurrence_type == RecurrenceType.HOURLY:
if task.recurrence_type == RecurrenceType.HOURLY:
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.DAILY:
if task.recurrence_type == RecurrenceType.DAILY:
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.WEEKLY:
if task.recurrence_type == RecurrenceType.WEEKLY:
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.MONTHLY:
if task.recurrence_type == RecurrenceType.MONTHLY:
# Use cron trigger for monthly (more reliable than interval)
scheduled_time = task.scheduled_at
return CronTrigger(
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
timezone=tz,
)
elif task.recurrence_type == RecurrenceType.YEARLY:
if task.recurrence_type == RecurrenceType.YEARLY:
scheduled_time = task.scheduled_at
return CronTrigger(
month=scheduled_time.month,
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
timezone=tz,
)
return None
@@ -332,7 +332,7 @@ class SchedulerService:
# Execute the task
try:
handler_registry = TaskHandlerRegistry(
session, self.db_session_factory, self.credit_service, self.player_service
session, self.db_session_factory, self.credit_service, self.player_service,
)
await handler_registry.execute_task(task)
@@ -352,25 +352,25 @@ class SchedulerService:
except Exception as e:
await repo.mark_as_failed(task, str(e))
logger.exception(f"Task {task_id} execution failed: {str(e)}")
logger.exception(f"Task {task_id} execution failed: {e!s}")
finally:
self._running_tasks.discard(task_id_str)
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
"""Calculate the next execution time for a recurring task."""
now = datetime.utcnow()
if task.recurrence_type == RecurrenceType.HOURLY:
return now + timedelta(hours=1)
elif task.recurrence_type == RecurrenceType.DAILY:
if task.recurrence_type == RecurrenceType.DAILY:
return now + timedelta(days=1)
elif task.recurrence_type == RecurrenceType.WEEKLY:
if task.recurrence_type == RecurrenceType.WEEKLY:
return now + timedelta(weeks=1)
elif task.recurrence_type == RecurrenceType.MONTHLY:
if task.recurrence_type == RecurrenceType.MONTHLY:
# Add approximately one month
return now + timedelta(days=30)
elif task.recurrence_type == RecurrenceType.YEARLY:
if task.recurrence_type == RecurrenceType.YEARLY:
return now + timedelta(days=365)
return None

View File

@@ -1,6 +1,5 @@
"""Task execution handlers for different task types."""
from typing import Any, Dict, Optional
from collections.abc import Callable
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
class TaskExecutionError(Exception):
"""Exception raised when task execution fails."""
pass
class TaskHandlerRegistry:
@@ -58,8 +56,8 @@ class TaskHandlerRegistry:
await handler(task)
logger.info(f"Task {task.id} executed successfully")
except Exception as e:
logger.exception(f"Task {task.id} execution failed: {str(e)}")
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
logger.exception(f"Task {task.id} execution failed: {e!s}")
raise TaskExecutionError(f"Task execution failed: {e!s}") from e
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
"""Handle credit recharge task."""
@@ -105,7 +103,7 @@ class TaskHandlerRegistry:
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
except Exception as e:
# Convert HTTP exceptions or credit errors to task execution errors
raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e
raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
else:
# System task: play without credit deduction
sound = await self.sound_repository.get_by_id(sound_id_int)

View File

@@ -7,6 +7,7 @@ from datetime import datetime
from app.core.database import get_session_factory
from app.repositories.scheduled_task import ScheduledTaskRepository
async def check_tasks():
session_factory = get_session_factory()

View File

@@ -5,8 +5,9 @@ import asyncio
from datetime import datetime, timedelta
from app.core.database import get_session_factory
from app.models.scheduled_task import RecurrenceType, TaskType
from app.repositories.scheduled_task import ScheduledTaskRepository
from app.models.scheduled_task import TaskType, RecurrenceType
async def create_future_task():
session_factory = get_session_factory()

View File

@@ -4,9 +4,9 @@
import asyncio
from datetime import datetime, timedelta
from app.core.database import get_session_factory
from app.main import get_global_scheduler_service
from app.models.scheduled_task import TaskType, RecurrenceType
from app.models.scheduled_task import RecurrenceType, TaskType
async def test_api_task_creation():
"""Test creating a task through the scheduler service (simulates API call)."""

View File

@@ -5,8 +5,9 @@ import asyncio
from datetime import datetime
from app.core.database import get_session_factory
from app.models.scheduled_task import RecurrenceType, TaskType
from app.repositories.scheduled_task import ScheduledTaskRepository
from app.models.scheduled_task import TaskType, RecurrenceType
async def create_test_task():
session_factory = get_session_factory()

View File

@@ -3,8 +3,6 @@
import uuid
from datetime import datetime, timedelta
import pytest
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,

View File

@@ -2,7 +2,6 @@
import uuid
from datetime import datetime, timedelta
from typing import List
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession

View File

@@ -51,7 +51,7 @@ class TestSchedulerService:
sample_task_data: dict,
):
"""Test creating a scheduled task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
task = await scheduler_service.create_task(**sample_task_data)
assert task.id is not None
@@ -68,7 +68,7 @@ class TestSchedulerService:
test_user_id: uuid.UUID,
):
"""Test creating a user task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(
user_id=test_user_id,
**sample_task_data,
@@ -83,7 +83,7 @@ class TestSchedulerService:
sample_task_data: dict,
):
"""Test creating a system task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
assert task.user_id is None
@@ -95,7 +95,7 @@ class TestSchedulerService:
sample_task_data: dict,
):
"""Test creating a recurring task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(
recurrence_type=RecurrenceType.DAILY,
recurrence_count=5,
@@ -118,7 +118,7 @@ class TestSchedulerService:
sample_task_data["scheduled_at"] = ny_time
sample_task_data["timezone"] = "America/New_York"
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
# The scheduled_at should be converted to UTC
@@ -134,11 +134,11 @@ class TestSchedulerService:
):
"""Test cancelling a task."""
# Create a task first
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
# Mock the scheduler remove_job method
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
result = await scheduler_service.cancel_task(task.id)
assert result is True
@@ -167,7 +167,7 @@ class TestSchedulerService:
test_user_id: uuid.UUID,
):
"""Test getting user tasks."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
# Create user task
await scheduler_service.create_task(
user_id=test_user_id,
@@ -188,8 +188,8 @@ class TestSchedulerService:
):
"""Test ensuring system tasks exist."""
# Mock the repository to return no existing tasks
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
mock_get.return_value = []
await scheduler_service._ensure_system_tasks()
@@ -214,8 +214,8 @@ class TestSchedulerService:
is_active=True,
)
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
mock_get.return_value = [existing_task]
await scheduler_service._ensure_system_tasks()
@@ -312,7 +312,7 @@ class TestSchedulerService:
recurrence_type=recurrence_type,
)
with patch('app.services.scheduler.datetime') as mock_datetime:
with patch("app.services.scheduler.datetime") as mock_datetime:
mock_datetime.utcnow.return_value = now
next_execution = scheduler_service._calculate_next_execution(task)
@@ -335,7 +335,7 @@ class TestSchedulerService:
next_execution = scheduler_service._calculate_next_execution(task)
assert next_execution is None
@patch('app.services.task_handlers.TaskHandlerRegistry')
@patch("app.services.task_handlers.TaskHandlerRegistry")
async def test_execute_task_success(
self,
mock_handler_class,
@@ -344,7 +344,7 @@ class TestSchedulerService:
):
"""Test successful task execution."""
# Create task
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
# Mock handler registry
@@ -365,7 +365,7 @@ class TestSchedulerService:
assert updated_task.status == TaskStatus.COMPLETED
assert updated_task.executions_count == 1
@patch('app.services.task_handlers.TaskHandlerRegistry')
@patch("app.services.task_handlers.TaskHandlerRegistry")
async def test_execute_task_failure(
self,
mock_handler_class,
@@ -374,7 +374,7 @@ class TestSchedulerService:
):
"""Test task execution failure."""
# Create task
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
# Mock handler to raise exception
@@ -410,7 +410,7 @@ class TestSchedulerService:
# Create expired task
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
# Execute task
@@ -430,20 +430,20 @@ class TestSchedulerService:
sample_task_data: dict,
):
"""Test prevention of concurrent task execution."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data)
# Add task to running set
scheduler_service._running_tasks.add(str(task.id))
# Try to execute - should return without doing anything
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class:
with patch("app.services.task_handlers.TaskHandlerRegistry") as mock_handler_class:
await scheduler_service._execute_task(task.id)
# Handler should not be called
mock_handler_class.assert_not_called()
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
async def test_maintenance_job_expired_tasks(
self,
mock_repo_class,
@@ -459,7 +459,7 @@ class TestSchedulerService:
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
mock_repo_class.return_value = mock_repo
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
await scheduler_service._maintenance_job()
# Should mark as cancelled and remove from scheduler
@@ -468,7 +468,7 @@ class TestSchedulerService:
mock_repo.update.assert_called_with(expired_task)
mock_remove.assert_called_once_with(str(expired_task.id))
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
async def test_maintenance_job_due_recurring_tasks(
self,
mock_repo_class,
@@ -485,7 +485,7 @@ class TestSchedulerService:
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
mock_repo_class.return_value = mock_repo
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
await scheduler_service._maintenance_job()
# Should reset to pending and reschedule

View File

@@ -133,8 +133,8 @@ class TestTaskHandlerRegistry:
mock_sound.id = test_sound_id
mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock()
mock_vlc_class.return_value = mock_vlc_service
@@ -186,7 +186,7 @@ class TestTaskHandlerRegistry:
parameters={"sound_id": str(test_sound_id)},
)
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None):
with patch.object(task_registry.sound_repository, "get_by_id", return_value=None):
with pytest.raises(TaskExecutionError, match="Sound not found"):
await task_registry.execute_task(task)
@@ -206,8 +206,8 @@ class TestTaskHandlerRegistry:
mock_sound = MagicMock()
mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock()
mock_vlc_class.return_value = mock_vlc_service
@@ -238,7 +238,7 @@ class TestTaskHandlerRegistry:
mock_playlist.id = test_playlist_id
mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task)
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
@@ -264,7 +264,7 @@ class TestTaskHandlerRegistry:
mock_playlist = MagicMock()
mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task)
# Should use default values
@@ -314,7 +314,7 @@ class TestTaskHandlerRegistry:
parameters={"playlist_id": str(test_playlist_id)},
)
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None):
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=None):
with pytest.raises(TaskExecutionError, match="Playlist not found"):
await task_registry.execute_task(task)
@@ -341,7 +341,7 @@ class TestTaskHandlerRegistry:
},
)
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task)
mock_player_service.set_mode.assert_called_with(mode)
@@ -368,7 +368,7 @@ class TestTaskHandlerRegistry:
mock_playlist = MagicMock()
mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task)
# Should not call set_mode for invalid mode