Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
"""API endpoints for scheduled task management."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -11,7 +9,7 @@ from app.core.dependencies import (
|
||||
get_admin_user,
|
||||
get_current_active_user,
|
||||
)
|
||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.user import User
|
||||
from app.schemas.scheduler import (
|
||||
ScheduledTaskCreate,
|
||||
@@ -54,15 +52,15 @@ async def create_task(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
|
||||
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_user_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
|
||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
limit: int | None = Query(50, description="Maximum number of tasks to return"),
|
||||
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
) -> List[ScheduledTask]:
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get user's scheduled tasks."""
|
||||
return await scheduler_service.get_user_tasks(
|
||||
user_id=current_user.id,
|
||||
@@ -152,15 +150,15 @@ async def cancel_task(
|
||||
|
||||
|
||||
# Admin-only endpoints
|
||||
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
|
||||
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_all_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
|
||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
limit: int | None = Query(100, description="Maximum number of tasks to return"),
|
||||
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> List[ScheduledTask]:
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get all scheduled tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
@@ -189,13 +187,13 @@ async def get_all_tasks(
|
||||
return list(result.all())
|
||||
|
||||
|
||||
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
|
||||
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_system_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> List[ScheduledTask]:
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get system tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
# Import all models to ensure SQLModel metadata discovery
|
||||
import app.models # noqa: F401
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.core.seeds import seed_all_data
|
||||
# Import all models to ensure SQLModel metadata discovery
|
||||
import app.models # noqa: F401
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
|
||||
@@ -11,11 +11,14 @@ from app.core.database import get_session_factory, init_db
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service
|
||||
from app.services.player import (
|
||||
get_player_service,
|
||||
initialize_player_service,
|
||||
shutdown_player_service,
|
||||
)
|
||||
from app.services.scheduler import SchedulerService
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
|
||||
scheduler_service = None
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Scheduled task model for flexible task scheduling with timezone support."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import JSON, Column, Field, SQLModel
|
||||
from sqlmodel import JSON, Column, Field
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
@@ -57,11 +56,11 @@ class ScheduledTask(BaseModel, table=True):
|
||||
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
|
||||
)
|
||||
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
|
||||
cron_expression: Optional[str] = Field(
|
||||
cron_expression: str | None = Field(
|
||||
default=None,
|
||||
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
|
||||
)
|
||||
recurrence_count: Optional[int] = Field(
|
||||
recurrence_count: int | None = Field(
|
||||
default=None,
|
||||
description="Number of times to repeat (None for infinite)",
|
||||
)
|
||||
@@ -75,29 +74,29 @@ class ScheduledTask(BaseModel, table=True):
|
||||
)
|
||||
|
||||
# User association (None for system tasks)
|
||||
user_id: Optional[int] = Field(
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
foreign_key="user.id",
|
||||
description="User who created the task (None for system tasks)",
|
||||
)
|
||||
|
||||
# Execution tracking
|
||||
last_executed_at: Optional[datetime] = Field(
|
||||
last_executed_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task was last executed (UTC)",
|
||||
)
|
||||
next_execution_at: Optional[datetime] = Field(
|
||||
next_execution_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task should be executed next (UTC, for recurring tasks)",
|
||||
)
|
||||
error_message: Optional[str] = Field(
|
||||
error_message: str | None = Field(
|
||||
default=None,
|
||||
description="Error message if execution failed",
|
||||
)
|
||||
|
||||
# Task lifecycle
|
||||
is_active: bool = Field(default=True, description="Whether the task is active")
|
||||
expires_at: Optional[datetime] = Field(
|
||||
expires_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task expires (UTC, optional)",
|
||||
)
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
"""Repository for scheduled task operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
@@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
"""Initialize the repository."""
|
||||
super().__init__(ScheduledTask, session)
|
||||
|
||||
async def get_pending_tasks(self) -> List[ScheduledTask]:
|
||||
async def get_pending_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all pending tasks that are ready to be executed."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_active_tasks(self) -> List[ScheduledTask]:
|
||||
async def get_active_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all active tasks."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.is_active.is_(True),
|
||||
@@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
||||
|
||||
@@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_system_tasks(
|
||||
self,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get system tasks (tasks with no user association)."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
||||
|
||||
@@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
|
||||
"""Get recurring tasks that need their next execution scheduled."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_expired_tasks(self) -> List[ScheduledTask]:
|
||||
async def get_expired_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get expired tasks that should be cleaned up."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
async def cancel_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
task_type: Optional[TaskType] = None,
|
||||
task_type: TaskType | None = None,
|
||||
) -> int:
|
||||
"""Cancel all pending/running tasks for a user."""
|
||||
statement = select(ScheduledTask).where(
|
||||
@@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
async def mark_as_completed(
|
||||
self,
|
||||
task: ScheduledTask,
|
||||
next_execution_at: Optional[datetime] = None,
|
||||
next_execution_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Mark a task as completed and set next execution if recurring."""
|
||||
task.status = TaskStatus.COMPLETED
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Schemas for scheduled task API."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -15,7 +15,7 @@ class ScheduledTaskBase(BaseModel):
|
||||
task_type: TaskType = Field(description="Type of task to execute")
|
||||
scheduled_at: datetime = Field(description="When the task should be executed")
|
||||
timezone: str = Field(default="UTC", description="Timezone for scheduling")
|
||||
parameters: Dict[str, Any] = Field(
|
||||
parameters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Task-specific parameters",
|
||||
)
|
||||
@@ -23,15 +23,15 @@ class ScheduledTaskBase(BaseModel):
|
||||
default=RecurrenceType.NONE,
|
||||
description="Recurrence pattern",
|
||||
)
|
||||
cron_expression: Optional[str] = Field(
|
||||
cron_expression: str | None = Field(
|
||||
default=None,
|
||||
description="Cron expression for custom recurrence",
|
||||
)
|
||||
recurrence_count: Optional[int] = Field(
|
||||
recurrence_count: int | None = Field(
|
||||
default=None,
|
||||
description="Number of times to repeat (None for infinite)",
|
||||
)
|
||||
expires_at: Optional[datetime] = Field(
|
||||
expires_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="When the task expires (optional)",
|
||||
)
|
||||
@@ -40,18 +40,17 @@ class ScheduledTaskBase(BaseModel):
|
||||
class ScheduledTaskCreate(ScheduledTaskBase):
|
||||
"""Schema for creating a scheduled task."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScheduledTaskUpdate(BaseModel):
|
||||
"""Schema for updating a scheduled task."""
|
||||
|
||||
name: Optional[str] = None
|
||||
scheduled_at: Optional[datetime] = None
|
||||
timezone: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
name: str | None = None
|
||||
scheduled_at: datetime | None = None
|
||||
timezone: str | None = None
|
||||
parameters: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class ScheduledTaskResponse(ScheduledTaskBase):
|
||||
@@ -59,11 +58,11 @@ class ScheduledTaskResponse(ScheduledTaskBase):
|
||||
|
||||
id: int
|
||||
status: TaskStatus
|
||||
user_id: Optional[int] = None
|
||||
user_id: int | None = None
|
||||
executions_count: int
|
||||
last_executed_at: Optional[datetime] = None
|
||||
next_execution_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
last_executed_at: datetime | None = None
|
||||
next_execution_at: datetime | None = None
|
||||
error_message: str | None = None
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
@@ -78,7 +77,7 @@ class ScheduledTaskResponse(ScheduledTaskBase):
|
||||
class CreditRechargeParameters(BaseModel):
|
||||
"""Parameters for credit recharge tasks."""
|
||||
|
||||
user_id: Optional[int] = Field(
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
description="Specific user ID to recharge (None for all users)",
|
||||
)
|
||||
@@ -109,10 +108,10 @@ class CreateCreditRechargeTask(BaseModel):
|
||||
scheduled_at: datetime
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: Optional[str] = None
|
||||
recurrence_count: Optional[int] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
user_id: Optional[int] = None
|
||||
cron_expression: str | None = None
|
||||
recurrence_count: int | None = None
|
||||
expires_at: datetime | None = None
|
||||
user_id: int | None = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
@@ -137,9 +136,9 @@ class CreatePlaySoundTask(BaseModel):
|
||||
sound_id: int
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: Optional[str] = None
|
||||
recurrence_count: Optional[int] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
cron_expression: str | None = None
|
||||
recurrence_count: int | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
@@ -166,9 +165,9 @@ class CreatePlayPlaylistTask(BaseModel):
|
||||
shuffle: bool = False
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: Optional[str] = None
|
||||
recurrence_count: Optional[int] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
cron_expression: str | None = None
|
||||
recurrence_count: int | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
@@ -86,13 +86,13 @@ class SchedulerService:
|
||||
name: str,
|
||||
task_type: TaskType,
|
||||
scheduled_at: datetime,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[int] = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
user_id: int | None = None,
|
||||
timezone: str = "UTC",
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE,
|
||||
cron_expression: Optional[str] = None,
|
||||
recurrence_count: Optional[int] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
cron_expression: str | None = None,
|
||||
recurrence_count: int | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
async with self.db_session_factory() as session:
|
||||
@@ -150,11 +150,11 @@ class SchedulerService:
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
@@ -182,7 +182,7 @@ class SchedulerService:
|
||||
|
||||
# Check if daily credit recharge task exists
|
||||
system_tasks = await repo.get_system_tasks(
|
||||
task_type=TaskType.CREDIT_RECHARGE
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
|
||||
daily_recharge_exists = any(
|
||||
@@ -194,7 +194,7 @@ class SchedulerService:
|
||||
if not daily_recharge_exists:
|
||||
# Create daily credit recharge task
|
||||
tomorrow_midnight = datetime.utcnow().replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
hour=0, minute=0, second=0, microsecond=0,
|
||||
) + timedelta(days=1)
|
||||
|
||||
task_data = {
|
||||
@@ -258,36 +258,36 @@ class SchedulerService:
|
||||
if task.recurrence_type == RecurrenceType.NONE:
|
||||
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.HOURLY:
|
||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||
if task.recurrence_type == RecurrenceType.DAILY:
|
||||
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
if task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
if task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
# Use cron trigger for monthly (more reliable than interval)
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz
|
||||
timezone=tz,
|
||||
)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||
if task.recurrence_type == RecurrenceType.YEARLY:
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
month=scheduled_time.month,
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz
|
||||
timezone=tz,
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -332,7 +332,7 @@ class SchedulerService:
|
||||
# Execute the task
|
||||
try:
|
||||
handler_registry = TaskHandlerRegistry(
|
||||
session, self.db_session_factory, self.credit_service, self.player_service
|
||||
session, self.db_session_factory, self.credit_service, self.player_service,
|
||||
)
|
||||
await handler_registry.execute_task(task)
|
||||
|
||||
@@ -352,25 +352,25 @@ class SchedulerService:
|
||||
|
||||
except Exception as e:
|
||||
await repo.mark_as_failed(task, str(e))
|
||||
logger.exception(f"Task {task_id} execution failed: {str(e)}")
|
||||
logger.exception(f"Task {task_id} execution failed: {e!s}")
|
||||
|
||||
finally:
|
||||
self._running_tasks.discard(task_id_str)
|
||||
|
||||
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
|
||||
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
|
||||
"""Calculate the next execution time for a recurring task."""
|
||||
now = datetime.utcnow()
|
||||
|
||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return now + timedelta(hours=1)
|
||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||
if task.recurrence_type == RecurrenceType.DAILY:
|
||||
return now + timedelta(days=1)
|
||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
if task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
return now + timedelta(weeks=1)
|
||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
if task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
# Add approximately one month
|
||||
return now + timedelta(days=30)
|
||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||
if task.recurrence_type == RecurrenceType.YEARLY:
|
||||
return now + timedelta(days=365)
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Task execution handlers for different task types."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
|
||||
class TaskExecutionError(Exception):
|
||||
"""Exception raised when task execution fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TaskHandlerRegistry:
|
||||
@@ -58,8 +56,8 @@ class TaskHandlerRegistry:
|
||||
await handler(task)
|
||||
logger.info(f"Task {task.id} executed successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"Task {task.id} execution failed: {str(e)}")
|
||||
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
|
||||
logger.exception(f"Task {task.id} execution failed: {e!s}")
|
||||
raise TaskExecutionError(f"Task execution failed: {e!s}") from e
|
||||
|
||||
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
||||
"""Handle credit recharge task."""
|
||||
@@ -105,7 +103,7 @@ class TaskHandlerRegistry:
|
||||
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
|
||||
except Exception as e:
|
||||
# Convert HTTP exceptions or credit errors to task execution errors
|
||||
raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e
|
||||
raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
|
||||
else:
|
||||
# System task: play without credit deduction
|
||||
sound = await self.sound_repository.get_by_id(sound_id_int)
|
||||
|
||||
@@ -7,6 +7,7 @@ from datetime import datetime
|
||||
from app.core.database import get_session_factory
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
async def check_tasks():
|
||||
session_factory = get_session_factory()
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.core.database import get_session_factory
|
||||
from app.models.scheduled_task import RecurrenceType, TaskType
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
from app.models.scheduled_task import TaskType, RecurrenceType
|
||||
|
||||
|
||||
async def create_future_task():
|
||||
session_factory = get_session_factory()
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.core.database import get_session_factory
|
||||
from app.main import get_global_scheduler_service
|
||||
from app.models.scheduled_task import TaskType, RecurrenceType
|
||||
from app.models.scheduled_task import RecurrenceType, TaskType
|
||||
|
||||
|
||||
async def test_api_task_creation():
|
||||
"""Test creating a task through the scheduler service (simulates API call)."""
|
||||
|
||||
@@ -5,8 +5,9 @@ import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.database import get_session_factory
|
||||
from app.models.scheduled_task import RecurrenceType, TaskType
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
from app.models.scheduled_task import TaskType, RecurrenceType
|
||||
|
||||
|
||||
async def create_test_task():
|
||||
session_factory = get_session_factory()
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -51,7 +51,7 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a scheduled task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
assert task.id is not None
|
||||
@@ -68,7 +68,7 @@ class TestSchedulerService:
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test creating a user task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(
|
||||
user_id=test_user_id,
|
||||
**sample_task_data,
|
||||
@@ -83,7 +83,7 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a system task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
assert task.user_id is None
|
||||
@@ -95,7 +95,7 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test creating a recurring task."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(
|
||||
recurrence_type=RecurrenceType.DAILY,
|
||||
recurrence_count=5,
|
||||
@@ -118,7 +118,7 @@ class TestSchedulerService:
|
||||
sample_task_data["scheduled_at"] = ny_time
|
||||
sample_task_data["timezone"] = "America/New_York"
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# The scheduled_at should be converted to UTC
|
||||
@@ -134,11 +134,11 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test cancelling a task."""
|
||||
# Create a task first
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock the scheduler remove_job method
|
||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
|
||||
result = await scheduler_service.cancel_task(task.id)
|
||||
|
||||
assert result is True
|
||||
@@ -167,7 +167,7 @@ class TestSchedulerService:
|
||||
test_user_id: uuid.UUID,
|
||||
):
|
||||
"""Test getting user tasks."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
# Create user task
|
||||
await scheduler_service.create_task(
|
||||
user_id=test_user_id,
|
||||
@@ -188,8 +188,8 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test ensuring system tasks exist."""
|
||||
# Mock the repository to return no existing tasks
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
|
||||
mock_get.return_value = []
|
||||
|
||||
await scheduler_service._ensure_system_tasks()
|
||||
@@ -214,8 +214,8 @@ class TestSchedulerService:
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
|
||||
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
|
||||
mock_get.return_value = [existing_task]
|
||||
|
||||
await scheduler_service._ensure_system_tasks()
|
||||
@@ -312,7 +312,7 @@ class TestSchedulerService:
|
||||
recurrence_type=recurrence_type,
|
||||
)
|
||||
|
||||
with patch('app.services.scheduler.datetime') as mock_datetime:
|
||||
with patch("app.services.scheduler.datetime") as mock_datetime:
|
||||
mock_datetime.utcnow.return_value = now
|
||||
next_execution = scheduler_service._calculate_next_execution(task)
|
||||
|
||||
@@ -335,7 +335,7 @@ class TestSchedulerService:
|
||||
next_execution = scheduler_service._calculate_next_execution(task)
|
||||
assert next_execution is None
|
||||
|
||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||
@patch("app.services.task_handlers.TaskHandlerRegistry")
|
||||
async def test_execute_task_success(
|
||||
self,
|
||||
mock_handler_class,
|
||||
@@ -344,7 +344,7 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test successful task execution."""
|
||||
# Create task
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock handler registry
|
||||
@@ -365,7 +365,7 @@ class TestSchedulerService:
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
assert updated_task.executions_count == 1
|
||||
|
||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||
@patch("app.services.task_handlers.TaskHandlerRegistry")
|
||||
async def test_execute_task_failure(
|
||||
self,
|
||||
mock_handler_class,
|
||||
@@ -374,7 +374,7 @@ class TestSchedulerService:
|
||||
):
|
||||
"""Test task execution failure."""
|
||||
# Create task
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Mock handler to raise exception
|
||||
@@ -410,7 +410,7 @@ class TestSchedulerService:
|
||||
# Create expired task
|
||||
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Execute task
|
||||
@@ -430,20 +430,20 @@ class TestSchedulerService:
|
||||
sample_task_data: dict,
|
||||
):
|
||||
"""Test prevention of concurrent task execution."""
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||
task = await scheduler_service.create_task(**sample_task_data)
|
||||
|
||||
# Add task to running set
|
||||
scheduler_service._running_tasks.add(str(task.id))
|
||||
|
||||
# Try to execute - should return without doing anything
|
||||
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class:
|
||||
with patch("app.services.task_handlers.TaskHandlerRegistry") as mock_handler_class:
|
||||
await scheduler_service._execute_task(task.id)
|
||||
|
||||
# Handler should not be called
|
||||
mock_handler_class.assert_not_called()
|
||||
|
||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
|
||||
async def test_maintenance_job_expired_tasks(
|
||||
self,
|
||||
mock_repo_class,
|
||||
@@ -459,7 +459,7 @@ class TestSchedulerService:
|
||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
|
||||
mock_repo_class.return_value = mock_repo
|
||||
|
||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
|
||||
await scheduler_service._maintenance_job()
|
||||
|
||||
# Should mark as cancelled and remove from scheduler
|
||||
@@ -468,7 +468,7 @@ class TestSchedulerService:
|
||||
mock_repo.update.assert_called_with(expired_task)
|
||||
mock_remove.assert_called_once_with(str(expired_task.id))
|
||||
|
||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
|
||||
async def test_maintenance_job_due_recurring_tasks(
|
||||
self,
|
||||
mock_repo_class,
|
||||
@@ -485,7 +485,7 @@ class TestSchedulerService:
|
||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
|
||||
mock_repo_class.return_value = mock_repo
|
||||
|
||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
|
||||
await scheduler_service._maintenance_job()
|
||||
|
||||
# Should reset to pending and reschedule
|
||||
|
||||
@@ -133,8 +133,8 @@ class TestTaskHandlerRegistry:
|
||||
mock_sound.id = test_sound_id
|
||||
mock_sound.filename = "test_sound.mp3"
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
|
||||
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_vlc_class.return_value = mock_vlc_service
|
||||
|
||||
@@ -186,7 +186,7 @@ class TestTaskHandlerRegistry:
|
||||
parameters={"sound_id": str(test_sound_id)},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None):
|
||||
with patch.object(task_registry.sound_repository, "get_by_id", return_value=None):
|
||||
with pytest.raises(TaskExecutionError, match="Sound not found"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
@@ -206,8 +206,8 @@ class TestTaskHandlerRegistry:
|
||||
mock_sound = MagicMock()
|
||||
mock_sound.filename = "test_sound.mp3"
|
||||
|
||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
|
||||
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
|
||||
mock_vlc_service = AsyncMock()
|
||||
mock_vlc_class.return_value = mock_vlc_service
|
||||
|
||||
@@ -238,7 +238,7 @@ class TestTaskHandlerRegistry:
|
||||
mock_playlist.id = test_playlist_id
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
|
||||
@@ -264,7 +264,7 @@ class TestTaskHandlerRegistry:
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
# Should use default values
|
||||
@@ -314,7 +314,7 @@ class TestTaskHandlerRegistry:
|
||||
parameters={"playlist_id": str(test_playlist_id)},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=None):
|
||||
with pytest.raises(TaskExecutionError, match="Playlist not found"):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
@@ -341,7 +341,7 @@ class TestTaskHandlerRegistry:
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
mock_player_service.set_mode.assert_called_with(mode)
|
||||
|
||||
@@ -368,7 +368,7 @@ class TestTaskHandlerRegistry:
|
||||
mock_playlist = MagicMock()
|
||||
mock_playlist.name = "Test Playlist"
|
||||
|
||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||
await task_registry.execute_task(task)
|
||||
|
||||
# Should not call set_mode for invalid mode
|
||||
|
||||
Reference in New Issue
Block a user