426 lines
16 KiB
Python
426 lines
16 KiB
Python
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
|
|
|
|
from collections.abc import Callable
|
|
from contextlib import suppress
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytz
|
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
|
from apscheduler.triggers.cron import CronTrigger
|
|
from apscheduler.triggers.date import DateTrigger
|
|
from apscheduler.triggers.interval import IntervalTrigger
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from app.core.logging import get_logger
|
|
from app.models.scheduled_task import (
|
|
RecurrenceType,
|
|
ScheduledTask,
|
|
TaskStatus,
|
|
TaskType,
|
|
)
|
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
|
from app.schemas.scheduler import ScheduledTaskCreate
|
|
from app.services.credit import CreditService
|
|
from app.services.player import PlayerService
|
|
from app.services.task_handlers import TaskHandlerRegistry
|
|
|
|
# Module-level logger; every SchedulerService method below logs through it.
logger = get_logger(__name__)
|
|
|
|
|
|
class SchedulerService:
    """Enhanced service for managing scheduled tasks with timezone support.

    Each persisted ``ScheduledTask`` is mirrored by an APScheduler job whose id
    is ``str(task.id)``. ``create_task`` stores ``scheduled_at`` as a naive UTC
    datetime, so naive values read back from the database are interpreted as
    UTC when triggers are built.
    """

    def __init__(
        self,
        db_session_factory: Callable[[], AsyncSession],
        player_service: PlayerService,
    ) -> None:
        """Initialize the scheduler service.

        Args:
            db_session_factory: Factory function to create database sessions
            player_service: Player service for audio playback tasks

        """
        self.db_session_factory = db_session_factory
        self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
        self.credit_service = CreditService(db_session_factory)
        self.player_service = player_service
        # String ids of tasks currently executing in this process; used by
        # _execute_task to skip overlapping runs of the same task.
        self._running_tasks: set[str] = set()

    async def start(self) -> None:
        """Start the scheduler and register the bootstrap and maintenance jobs.

        Creating system tasks and loading active tasks is deferred to a
        one-shot job two seconds after startup; a maintenance job then runs
        every five minutes.
        """
        logger.info("Starting enhanced scheduler service...")

        self.scheduler.start()

        # Schedule system tasks initialization for after startup
        self.scheduler.add_job(
            self._initialize_system_tasks,
            "date",
            run_date=datetime.now(tz=UTC) + timedelta(seconds=2),
            id="initialize_system_tasks",
            name="Initialize System Tasks",
            replace_existing=True,
        )

        # Schedule periodic cleanup and maintenance
        self.scheduler.add_job(
            self._maintenance_job,
            "interval",
            minutes=5,
            id="scheduler_maintenance",
            name="Scheduler Maintenance",
            replace_existing=True,
        )

        logger.info("Enhanced scheduler service started successfully")

    async def stop(self) -> None:
        """Stop the scheduler; only logs if it was never started or already stopped."""
        logger.info("Stopping scheduler service...")
        # shutdown() raises when the scheduler is not running, which would turn
        # an application shutdown path into a crash.
        if self.scheduler.running:
            self.scheduler.shutdown(wait=True)
        logger.info("Scheduler service stopped")

    async def create_task(
        self,
        task_data: ScheduledTaskCreate,
        user_id: int | None = None,
    ) -> ScheduledTask:
        """Create a new scheduled task from schema data and schedule it.

        A naive ``task_data.scheduled_at`` is interpreted as wall-clock time in
        ``task_data.timezone``; the value is then persisted as a naive UTC
        datetime for every timezone.

        Args:
            task_data: Task creation payload.
            user_id: Owning user id, or ``None`` when the task has no owner.

        Returns:
            The persisted task (already registered with APScheduler).
        """
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)

            scheduled_at = task_data.scheduled_at
            if scheduled_at is not None:
                tz = pytz.timezone(task_data.timezone)
                if scheduled_at.tzinfo is None:
                    # Assume the datetime is in the specified timezone
                    scheduled_at = tz.localize(scheduled_at)
                # Normalize for every timezone (including "UTC") so an aware
                # value carrying another offset is never stored unconverted.
                scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)

            db_task_data = {
                "name": task_data.name,
                "task_type": task_data.task_type,
                "scheduled_at": scheduled_at,
                "timezone": task_data.timezone,
                "parameters": task_data.parameters,
                "user_id": user_id,
                "recurrence_type": task_data.recurrence_type,
                "cron_expression": task_data.cron_expression,
                "recurrence_count": task_data.recurrence_count,
                "expires_at": task_data.expires_at,
            }

            created_task = await repo.create(db_task_data)
            await self._schedule_apscheduler_job(created_task)

            logger.info(
                "Created scheduled task: %s (%s)",
                created_task.name,
                created_task.id,
            )
            return created_task

    async def cancel_task(self, task_id: int) -> bool:
        """Cancel a scheduled task.

        Marks the task CANCELLED and inactive, then removes its APScheduler job.

        Returns:
            ``True`` if the task existed and was cancelled, ``False`` otherwise.
        """
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)

            task = await repo.get_by_id(task_id)
            if not task:
                return False

            await repo.update(task, {
                "status": TaskStatus.CANCELLED,
                "is_active": False,
            })

            self._remove_scheduler_job(str(task_id))

            logger.info("Cancelled task: %s (%s)", task.name, task_id)
            return True

    async def get_user_tasks(
        self,
        user_id: int,
        status: TaskStatus | None = None,
        task_type: TaskType | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> list[ScheduledTask]:
        """Get tasks for a specific user, optionally filtered and paginated."""
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)
            return await repo.get_user_tasks(user_id, status, task_type, limit, offset)

    async def _initialize_system_tasks(self) -> None:
        """Create missing system tasks, then load active tasks into APScheduler.

        Errors are logged rather than raised, since this runs as a job.
        """
        logger.info("Initializing system tasks...")

        try:
            # Create system tasks if they don't exist
            await self._ensure_system_tasks()

            # Load all active tasks from database
            await self._load_active_tasks()

            logger.info("System tasks initialized successfully")
        except Exception:
            logger.exception("Failed to initialize system tasks")

    async def _ensure_system_tasks(self) -> None:
        """Create the daily credit recharge system task if no active one exists."""
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)

            # Check if daily credit recharge task exists
            system_tasks = await repo.get_system_tasks(
                task_type=TaskType.CREDIT_RECHARGE,
            )

            daily_recharge_exists = any(
                task.recurrence_type == RecurrenceType.DAILY
                and task.is_active
                for task in system_tasks
            )

            if not daily_recharge_exists:
                # First run at the next UTC midnight, then daily.
                tomorrow_midnight = datetime.now(tz=UTC).replace(
                    hour=0, minute=0, second=0, microsecond=0,
                ) + timedelta(days=1)

                task_data = {
                    "name": "Daily Credit Recharge",
                    "task_type": TaskType.CREDIT_RECHARGE,
                    "scheduled_at": tomorrow_midnight,
                    "recurrence_type": RecurrenceType.DAILY,
                    "parameters": {},
                }

                await repo.create(task_data)
                logger.info("Created system daily credit recharge task")

    async def _load_active_tasks(self) -> None:
        """Load all active tasks from the database into the scheduler.

        A task that cannot be scheduled (e.g. invalid cron expression or
        timezone) is logged and skipped so it does not prevent the remaining
        tasks from loading.
        """
        async with self.db_session_factory() as session:
            repo = ScheduledTaskRepository(session)
            active_tasks = await repo.get_active_tasks()

            loaded = 0
            for task in active_tasks:
                if await self._try_schedule_apscheduler_job(task):
                    loaded += 1

            logger.info("Loaded %s active tasks into scheduler", loaded)

    async def _try_schedule_apscheduler_job(self, task: ScheduledTask) -> bool:
        """Schedule ``task``, logging instead of raising on failure.

        Returns:
            ``True`` if scheduling raised no exception, ``False`` otherwise.
        """
        try:
            await self._schedule_apscheduler_job(task)
        except Exception:
            logger.exception("Failed to schedule task %s", task.id)
            return False
        return True

    def _remove_scheduler_job(self, job_id: str) -> None:
        """Remove the APScheduler job ``job_id`` if it exists.

        A missing job is expected (e.g. the task was never scheduled because it
        was inactive), so only ``JobLookupError`` is suppressed; any other error
        propagates instead of being silently swallowed.
        """
        # Imported here: this helper is the only user of the exception type.
        from apscheduler.jobstores.base import JobLookupError

        with suppress(JobLookupError):
            self.scheduler.remove_job(job_id)

    async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
        """(Re)register ``task`` as an APScheduler job with id ``str(task.id)``.

        Any existing job for the task is removed first. Nothing is scheduled
        for inactive tasks, for completed/failed/cancelled tasks, or when no
        trigger can be built from the task's recurrence settings.
        """
        job_id = str(task.id)

        self._remove_scheduler_job(job_id)

        # Don't schedule if task is not active or already completed/failed
        inactive_statuses = (
            TaskStatus.COMPLETED,
            TaskStatus.FAILED,
            TaskStatus.CANCELLED,
        )
        if not task.is_active or task.status in inactive_statuses:
            return

        trigger = self._create_trigger(task)
        if trigger is None:
            logger.warning("Could not create trigger for task %s", task.id)
            return

        self.scheduler.add_job(
            self._execute_task,
            trigger=trigger,
            args=[task.id],
            id=job_id,
            name=task.name,
            replace_existing=True,
        )

        logger.debug("Scheduled APScheduler job for task %s", task.id)

    def _create_trigger(
        self, task: ScheduledTask,
    ) -> DateTrigger | IntervalTrigger | CronTrigger | None:
        """Create an APScheduler trigger from the task's recurrence settings.

        ``task.scheduled_at`` is stored as UTC (naive values are taken to be
        UTC). It is converted to an aware datetime in ``task.timezone`` before
        use. A naive value would otherwise be interpreted by APScheduler in the
        trigger's timezone, shifting the run time by the UTC offset.
        Monthly/yearly fields use the task's local wall-clock time.

        Returns:
            The trigger, or ``None`` if the recurrence settings cannot be
            mapped to one.
        """
        tz = pytz.timezone(task.timezone)

        if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
            return CronTrigger.from_crontab(task.cron_expression, timezone=tz)

        scheduled_local = None
        if task.scheduled_at is not None:
            scheduled_utc = task.scheduled_at
            if scheduled_utc.tzinfo is None:
                scheduled_utc = pytz.UTC.localize(scheduled_utc)
            scheduled_local = scheduled_utc.astimezone(tz)

        if task.recurrence_type == RecurrenceType.NONE:
            return DateTrigger(run_date=scheduled_local, timezone=tz)

        # Handle interval-based recurrence types
        interval_configs = {
            RecurrenceType.HOURLY: {"hours": 1},
            RecurrenceType.DAILY: {"days": 1},
            RecurrenceType.WEEKLY: {"weeks": 1},
        }
        if task.recurrence_type in interval_configs:
            config = interval_configs[task.recurrence_type]
            return IntervalTrigger(start_date=scheduled_local, timezone=tz, **config)

        # Calendar-based types need the anchor time to pick day/hour/minute.
        if scheduled_local is None:
            return None

        cron_configs = {
            RecurrenceType.MONTHLY: {
                "day": scheduled_local.day,
                "hour": scheduled_local.hour,
                "minute": scheduled_local.minute,
            },
            RecurrenceType.YEARLY: {
                "month": scheduled_local.month,
                "day": scheduled_local.day,
                "hour": scheduled_local.hour,
                "minute": scheduled_local.minute,
            },
        }
        if task.recurrence_type in cron_configs:
            config = cron_configs[task.recurrence_type]
            # Don't fire before the scheduled time. Truncate to the minute,
            # because these cron fields fire at second 0 and a start_date with
            # seconds would push the first run to the next period.
            return CronTrigger(
                timezone=tz,
                start_date=scheduled_local.replace(second=0, microsecond=0),
                **config,
            )

        return None

    async def _execute_task(self, task_id: int) -> None:
        """Execute a scheduled task (the APScheduler job callback).

        The run is skipped when the same task is already executing in this
        process, when the task no longer exists, or when it is inactive or not
        ``PENDING``. Expired tasks are cancelled (and their job removed)
        instead of run. Otherwise the task is marked running, dispatched
        through ``TaskHandlerRegistry``, and then marked completed (with the
        next execution time for repeating tasks) or failed.

        Args:
            task_id: Primary key of the task to run.
        """
        task_id_str = str(task_id)

        # Prevent concurrent execution of the same task
        if task_id_str in self._running_tasks:
            logger.warning("Task %s is already running, skipping execution", task_id)
            return

        self._running_tasks.add(task_id_str)

        try:
            async with self.db_session_factory() as session:
                repo = ScheduledTaskRepository(session)

                # Get fresh task data
                task = await repo.get_by_id(task_id)
                if not task:
                    logger.warning("Task %s not found", task_id)
                    return

                # Check if task is still active and pending
                if not task.is_active or task.status != TaskStatus.PENDING:
                    logger.info("Task %s not active or not pending, skipping", task_id)
                    return

                # Check if task has expired
                if task.is_expired():
                    logger.info("Task %s has expired, marking as cancelled", task_id)
                    await repo.update(task, {
                        "status": TaskStatus.CANCELLED,
                        "is_active": False,
                    })
                    # Also drop the job so a recurring trigger stops firing.
                    self._remove_scheduler_job(task_id_str)
                    return

                await repo.mark_as_running(task)

                try:
                    handler_registry = TaskHandlerRegistry(
                        session,
                        self.db_session_factory,
                        self.credit_service,
                        self.player_service,
                    )
                    await handler_registry.execute_task(task)

                    # Calculate next execution time for recurring tasks
                    next_execution_at = None
                    if task.should_repeat():
                        next_execution_at = self._calculate_next_execution(task)

                    await repo.mark_as_completed(task, next_execution_at)

                    if next_execution_at and task.should_repeat():
                        # Refresh to pick up what mark_as_completed persisted.
                        # NOTE(review): _schedule_apscheduler_job skips COMPLETED
                        # tasks (after removing their job), so whether the task
                        # is re-armed here depends on the status written by
                        # mark_as_completed -- confirm in the repository.
                        await session.refresh(task)
                        await self._schedule_apscheduler_job(task)

                except Exception as e:
                    # Log first so the failure is recorded even if persisting
                    # the failed status raises as well.
                    logger.exception("Task %s execution failed", task_id)
                    await repo.mark_as_failed(task, str(e))

        finally:
            self._running_tasks.discard(task_id_str)

    def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
        """Calculate the next execution time for a recurring task.

        Hourly, daily and weekly tasks repeat a fixed interval after now.
        Monthly, yearly and cron tasks use the next fire time of the same
        ``CronTrigger`` that ``_create_trigger`` builds for them. As a result,
        months and years are calendar-exact rather than 30/365-day
        approximations, and cron tasks get a next execution time too.

        Returns:
            An aware UTC datetime, or ``None`` when no next run can be
            determined.
        """
        now = datetime.now(tz=UTC)

        if task.recurrence_type == RecurrenceType.HOURLY:
            return now + timedelta(hours=1)
        if task.recurrence_type == RecurrenceType.DAILY:
            return now + timedelta(days=1)
        if task.recurrence_type == RecurrenceType.WEEKLY:
            return now + timedelta(weeks=1)

        calendar_types = (
            RecurrenceType.MONTHLY,
            RecurrenceType.YEARLY,
            RecurrenceType.CRON,
        )
        if task.recurrence_type not in calendar_types:
            return None

        trigger = self._create_trigger(task)
        if not isinstance(trigger, CronTrigger):
            return None

        next_fire = trigger.get_next_fire_time(None, now)
        return next_fire.astimezone(UTC) if next_fire is not None else None

    async def _maintenance_job(self) -> None:
        """Periodic maintenance: cancel expired tasks and re-arm due recurring ones.

        Registered by ``start`` to run every five minutes. Errors are logged
        rather than raised. A single task that fails to reschedule is logged
        and skipped.
        """
        try:
            async with self.db_session_factory() as session:
                repo = ScheduledTaskRepository(session)

                # Handle expired tasks
                expired_tasks = await repo.get_expired_tasks()
                for task in expired_tasks:
                    await repo.update(task, {
                        "status": TaskStatus.CANCELLED,
                        "is_active": False,
                    })
                    self._remove_scheduler_job(str(task.id))

                if expired_tasks:
                    logger.info("Cleaned up %s expired tasks", len(expired_tasks))

                # Handle any missed recurring tasks
                due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
                rescheduled = 0
                for task in due_recurring:
                    if not task.should_repeat():
                        continue
                    next_scheduled_at = (
                        task.next_execution_at or datetime.now(tz=UTC)
                    )
                    await repo.update(task, {
                        "status": TaskStatus.PENDING,
                        "scheduled_at": next_scheduled_at,
                    })
                    if await self._try_schedule_apscheduler_job(task):
                        rescheduled += 1

                # Count only tasks that were actually re-armed.
                if rescheduled:
                    logger.info("Rescheduled %s recurring tasks", rescheduled)

        except Exception:
            logger.exception("Maintenance job failed")