- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
414 lines
15 KiB
Python
414 lines
15 KiB
Python
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
|
|
|
|
from collections.abc import Callable
from contextlib import suppress
from datetime import datetime, timedelta
from typing import Any

import pytz
from apscheduler.jobstores.base import JobLookupError
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core.logging import get_logger
from app.models.scheduled_task import (
    RecurrenceType,
    ScheduledTask,
    TaskStatus,
    TaskType,
)
from app.repositories.scheduled_task import ScheduledTaskRepository
from app.services.credit import CreditService
from app.services.player import PlayerService
from app.services.task_handlers import TaskHandlerRegistry
|
|
|
|
# Module-level logger, named after this module via app.core.logging.get_logger.
logger = get_logger(__name__)
|
|
|
|
|
|
class SchedulerService:
    """Enhanced service for managing scheduled tasks with timezone support.

    Task run times are persisted as naive UTC datetimes (``create_task``
    normalizes them). Each task's own ``timezone`` is applied when its
    APScheduler trigger is built, so wall-clock based recurrences (cron,
    monthly, yearly) follow the task's local time.
    """

    def __init__(
        self,
        db_session_factory: Callable[[], AsyncSession],
        player_service: PlayerService,
    ) -> None:
        """Initialize the scheduler service.

        Args:
            db_session_factory: Factory function to create database sessions
            player_service: Player service for audio playback tasks

        """
        self.db_session_factory = db_session_factory
        # The scheduler clock is UTC; per-task zones are set on each trigger.
        self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
        self.credit_service = CreditService(db_session_factory)
        self.player_service = player_service
        # String ids of tasks currently executing; used to skip a run that
        # would overlap a still-running execution of the same task.
        self._running_tasks: set[str] = set()

    async def start(self) -> None:
        """Start the scheduler and register the startup and maintenance jobs.

        Loading tasks from the database is deferred to a one-shot job that
        runs about two seconds after the scheduler starts.
        """
        logger.info("Starting enhanced scheduler service...")

        self.scheduler.start()

        # Schedule system tasks initialization for after startup
        self.scheduler.add_job(
            self._initialize_system_tasks,
            "date",
            run_date=datetime.utcnow() + timedelta(seconds=2),
            id="initialize_system_tasks",
            name="Initialize System Tasks",
            replace_existing=True,
        )

        # Schedule periodic cleanup and maintenance
        self.scheduler.add_job(
            self._maintenance_job,
            "interval",
            minutes=5,
            id="scheduler_maintenance",
            name="Scheduler Maintenance",
            replace_existing=True,
        )

        logger.info("Enhanced scheduler service started successfully")

    async def stop(self) -> None:
        """Stop the scheduler.

        The shutdown call is guarded because APScheduler raises
        ``SchedulerNotRunningError`` when shutting down a scheduler that was
        never started (or is already stopped).
        """
        logger.info("Stopping scheduler service...")
        if self.scheduler.running:
            self.scheduler.shutdown(wait=True)
        logger.info("Scheduler service stopped")

    async def create_task(
        self,
        name: str,
        task_type: TaskType,
        scheduled_at: datetime,
        parameters: dict[str, Any] | None = None,
        user_id: int | None = None,
        timezone: str = "UTC",
        recurrence_type: RecurrenceType = RecurrenceType.NONE,
        cron_expression: str | None = None,
        recurrence_count: int | None = None,
        expires_at: datetime | None = None,
    ) -> ScheduledTask:
        """Create a new scheduled task and register it with APScheduler.

        Args:
            name: Human-readable task name (also used as the job name).
            task_type: Kind of task, dispatched by ``TaskHandlerRegistry``.
            scheduled_at: First run time. A naive value is interpreted as
                wall-clock time in ``timezone``; an aware value is converted
                from its own offset. It is stored as naive UTC.
            parameters: Handler parameters; defaults to an empty dict.
            user_id: Owning user, or None.
            timezone: pytz timezone name used for the task's triggers.
            recurrence_type: How the task repeats.
            cron_expression: Crontab expression for ``RecurrenceType.CRON``.
            recurrence_count: Stored on the task as given.
            expires_at: Stored on the task as given.

        Returns:
            The persisted task.

        Raises:
            pytz.UnknownTimeZoneError: If ``timezone`` is not a known zone.

        """
        scheduled_at = self._to_naive_utc(scheduled_at, timezone)

        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)

            task_data = {
                "name": name,
                "task_type": task_type,
                "scheduled_at": scheduled_at,
                "timezone": timezone,
                "parameters": parameters or {},
                "user_id": user_id,
                "recurrence_type": recurrence_type,
                "cron_expression": cron_expression,
                "recurrence_count": recurrence_count,
                "expires_at": expires_at,
            }

            created_task = await repo.create(task_data)
            await self._schedule_apscheduler_job(created_task)

            logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
            return created_task

    async def cancel_task(self, task_id: int) -> bool:
        """Cancel a scheduled task and drop its APScheduler job.

        Returns:
            True if the task existed and was cancelled, False if not found.

        """
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)

            task = await repo.get_by_id(task_id)
            if not task:
                return False

            task.status = TaskStatus.CANCELLED
            task.is_active = False
            await repo.update(task)

            self._remove_job(str(task_id))

            logger.info(f"Cancelled task: {task.name} ({task_id})")
            return True

    async def get_user_tasks(
        self,
        user_id: int,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[ScheduledTask]:
        """Get tasks for a specific user, delegating filters to the repository."""
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)
            return await repo.get_user_tasks(user_id, status, task_type, limit, offset)

    async def _initialize_system_tasks(self) -> None:
        """Create missing system tasks, then load active tasks into the scheduler.

        Runs as an APScheduler job, so failures are logged rather than raised.
        """
        logger.info("Initializing system tasks...")

        try:
            # Create system tasks if they don't exist
            await self._ensure_system_tasks()

            # Load all active tasks from database
            await self._load_active_tasks()

            logger.info("System tasks initialized successfully")
        except Exception:
            logger.exception("Failed to initialize system tasks")

    async def _ensure_system_tasks(self) -> None:
        """Ensure an active daily credit-recharge system task exists.

        When missing, one is created for the next UTC midnight.
        """
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)

            # Check if daily credit recharge task exists
            system_tasks = await repo.get_system_tasks(
                task_type=TaskType.CREDIT_RECHARGE,
            )

            daily_recharge_exists = any(
                task.recurrence_type == RecurrenceType.DAILY and task.is_active
                for task in system_tasks
            )

            if not daily_recharge_exists:
                # Create daily credit recharge task
                tomorrow_midnight = datetime.utcnow().replace(
                    hour=0, minute=0, second=0, microsecond=0,
                ) + timedelta(days=1)

                task_data = {
                    "name": "Daily Credit Recharge",
                    "task_type": TaskType.CREDIT_RECHARGE,
                    "scheduled_at": tomorrow_midnight,
                    "recurrence_type": RecurrenceType.DAILY,
                    "parameters": {},
                }

                await repo.create(task_data)
                logger.info("Created system daily credit recharge task")

    async def _load_active_tasks(self) -> None:
        """Load all active tasks from the database into the scheduler.

        Each task is scheduled independently, so one malformed task (for
        example an invalid cron expression or timezone name) is logged and
        skipped instead of preventing the remaining tasks from loading.
        """
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)
            active_tasks = await repo.get_active_tasks()

            loaded = 0
            for task in active_tasks:
                try:
                    await self._schedule_apscheduler_job(task)
                except Exception:
                    logger.exception(f"Failed to schedule task {task.id}")
                else:
                    loaded += 1

            logger.info(f"Loaded {loaded} active tasks into scheduler")

    def _remove_job(self, job_id: str) -> None:
        """Remove an APScheduler job if present; a missing job is not an error.

        Jobs can legitimately be absent: one-shot date jobs are discarded by
        APScheduler after they fire, and inactive tasks are never scheduled.
        """
        with suppress(JobLookupError):
            self.scheduler.remove_job(job_id)

    @staticmethod
    def _to_naive_utc(value: datetime, timezone: str) -> datetime:
        """Return *value* as a naive UTC datetime (the storage convention).

        A naive *value* is treated as wall-clock time in *timezone*; an aware
        *value* is converted from its own offset.
        """
        if value.tzinfo is None:
            value = pytz.timezone(timezone).localize(value)
        return value.astimezone(pytz.UTC).replace(tzinfo=None)

    @staticmethod
    def _as_aware_utc(value: datetime) -> datetime:
        """Return *value* as an aware UTC datetime; naive input is taken as UTC."""
        if value.tzinfo is None:
            return pytz.UTC.localize(value)
        return value.astimezone(pytz.UTC)

    async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
        """(Re)register *task* in APScheduler under job id ``str(task.id)``.

        Any existing job for the task is removed first, so calling this for an
        inactive or finished task simply stops it from firing.
        """
        job_id = str(task.id)

        self._remove_job(job_id)

        # Don't schedule if task is not active or already completed/failed
        if not task.is_active or task.status in (
            TaskStatus.COMPLETED,
            TaskStatus.FAILED,
            TaskStatus.CANCELLED,
        ):
            return

        # Create trigger based on recurrence type
        trigger = self._create_trigger(task)
        if not trigger:
            logger.warning(f"Could not create trigger for task {task.id}")
            return

        self.scheduler.add_job(
            self._execute_task,
            trigger=trigger,
            args=[task.id],
            id=job_id,
            name=task.name,
            replace_existing=True,
        )

        logger.debug(f"Scheduled APScheduler job for task {task.id}")

    def _create_trigger(
        self,
        task: ScheduledTask,
    ) -> DateTrigger | CronTrigger | IntervalTrigger | None:
        """Create the APScheduler trigger for *task*'s recurrence settings.

        Returns:
            A trigger, or None when the recurrence type is unknown or a CRON
            task has no expression.

        Raises:
            pytz.UnknownTimeZoneError: If ``task.timezone`` is invalid.
            ValueError: If ``task.cron_expression`` cannot be parsed.

        """
        tz = pytz.timezone(task.timezone)
        # scheduled_at is stored as naive UTC, but APScheduler interprets a
        # naive datetime in the trigger's ``timezone``. Mark it explicitly as
        # UTC, otherwise non-UTC tasks would be shifted by their UTC offset.
        start = self._as_aware_utc(task.scheduled_at)
        recurrence = task.recurrence_type

        if recurrence == RecurrenceType.NONE:
            return DateTrigger(run_date=start, timezone=tz)

        if recurrence == RecurrenceType.CRON and task.cron_expression:
            return CronTrigger.from_crontab(task.cron_expression, timezone=tz)

        if recurrence == RecurrenceType.HOURLY:
            return IntervalTrigger(hours=1, start_date=start, timezone=tz)

        if recurrence == RecurrenceType.DAILY:
            return IntervalTrigger(days=1, start_date=start, timezone=tz)

        if recurrence == RecurrenceType.WEEKLY:
            return IntervalTrigger(weeks=1, start_date=start, timezone=tz)

        # Cron fields are wall-clock values in ``tz``, so read them from the
        # scheduled time expressed in the task's zone (not from the UTC value).
        # ``start_date`` keeps the first run from preceding scheduled_at.
        # NOTE: a day of 29-31 only matches months that have that day.
        local = start.astimezone(tz)

        if recurrence == RecurrenceType.MONTHLY:
            # Use cron trigger for monthly (more reliable than interval)
            return CronTrigger(
                day=local.day,
                hour=local.hour,
                minute=local.minute,
                start_date=start,
                timezone=tz,
            )

        if recurrence == RecurrenceType.YEARLY:
            return CronTrigger(
                month=local.month,
                day=local.day,
                hour=local.hour,
                minute=local.minute,
                start_date=start,
                timezone=tz,
            )

        return None

    async def _execute_task(self, task_id: int) -> None:
        """Execute a scheduled task (APScheduler job callback).

        Skips the run when the same task is already executing, the task no
        longer exists, is inactive or not pending, or has expired. Handler
        errors mark the task as failed; successful recurring tasks get their
        next execution time recorded and their job re-registered.
        """
        task_id_str = str(task_id)

        # Prevent concurrent execution of the same task
        if task_id_str in self._running_tasks:
            logger.warning(f"Task {task_id} is already running, skipping execution")
            return

        self._running_tasks.add(task_id_str)

        try:
            async with self.db_session_factory() as session:
                repo = ScheduledTaskRepository(session)

                # Get fresh task data
                task = await repo.get_by_id(task_id)
                if not task:
                    logger.warning(f"Task {task_id} not found")
                    # Stop a recurring job for a deleted task from firing again.
                    self._remove_job(task_id_str)
                    return

                # Check if task is still active and pending
                if not task.is_active or task.status != TaskStatus.PENDING:
                    logger.info(f"Task {task_id} is not active or not pending, skipping")
                    return

                # Check if task has expired
                if task.is_expired():
                    logger.info(f"Task {task_id} has expired, marking as cancelled")
                    task.status = TaskStatus.CANCELLED
                    task.is_active = False
                    await repo.update(task)
                    # A recurring trigger would otherwise keep firing forever.
                    self._remove_job(task_id_str)
                    return

                # Mark task as running
                await repo.mark_as_running(task)

                # Execute the task
                try:
                    handler_registry = TaskHandlerRegistry(
                        session, self.db_session_factory, self.credit_service, self.player_service,
                    )
                    await handler_registry.execute_task(task)

                    # Calculate next execution time for recurring tasks
                    next_execution_at = None
                    if task.should_repeat():
                        next_execution_at = self._calculate_next_execution(task)

                    # Mark as completed
                    await repo.mark_as_completed(task, next_execution_at)

                    # Reschedule if recurring
                    if next_execution_at and task.should_repeat():
                        # Refresh task to get updated data
                        await session.refresh(task)
                        await self._schedule_apscheduler_job(task)

                except Exception as e:
                    await repo.mark_as_failed(task, str(e))
                    logger.exception(f"Task {task_id} execution failed: {e!s}")

        finally:
            self._running_tasks.discard(task_id_str)

    def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
        """Return the next run time of a recurring task as naive UTC, or None.

        The value comes from the same trigger APScheduler uses for the task,
        so the stored ``next_execution_at`` matches when the job will actually
        fire. This covers cron tasks too and gives calendar-correct monthly
        and yearly steps instead of fixed day offsets.

        Returns:
            None for non-recurring tasks, when no trigger can be built, or
            when the trigger has no further fire time.

        """
        if task.recurrence_type == RecurrenceType.NONE:
            return None

        trigger = self._create_trigger(task)
        if trigger is None:
            return None

        now = pytz.UTC.localize(datetime.utcnow())
        next_fire = trigger.get_next_fire_time(None, now)
        if next_fire is None:
            return None

        # Match the storage convention used elsewhere: naive UTC.
        return next_fire.astimezone(pytz.UTC).replace(tzinfo=None)

    async def _maintenance_job(self) -> None:
        """Periodic maintenance job to clean up expired tasks and handle scheduling issues.

        Expired tasks are cancelled and unscheduled. Recurring tasks reported
        as due by the repository are reset to pending at their
        ``next_execution_at`` (or now) and re-registered. Runs as an
        APScheduler job, so failures are logged rather than raised.
        """
        try:
            async with self.db_session_factory() as session:
                repo = ScheduledTaskRepository(session)

                # Handle expired tasks
                expired_tasks = await repo.get_expired_tasks()
                for task in expired_tasks:
                    task.status = TaskStatus.CANCELLED
                    task.is_active = False
                    await repo.update(task)
                    self._remove_job(str(task.id))

                if expired_tasks:
                    logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")

                # Handle any missed recurring tasks
                due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
                rescheduled = 0
                for task in due_recurring:
                    if not task.should_repeat():
                        continue
                    task.status = TaskStatus.PENDING
                    task.scheduled_at = task.next_execution_at or datetime.utcnow()
                    await repo.update(task)
                    await self._schedule_apscheduler_job(task)
                    rescheduled += 1

                # Report only the tasks actually rescheduled.
                if rescheduled:
                    logger.info(f"Rescheduled {rescheduled} recurring tasks")

        except Exception:
            logger.exception("Maintenance job failed")
|