Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
import pytz
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
@@ -52,7 +52,7 @@ class SchedulerService:
|
||||
logger.info("Starting enhanced scheduler service...")
|
||||
|
||||
self.scheduler.start()
|
||||
|
||||
|
||||
# Schedule system tasks initialization for after startup
|
||||
self.scheduler.add_job(
|
||||
self._initialize_system_tasks,
|
||||
@@ -62,7 +62,7 @@ class SchedulerService:
|
||||
name="Initialize System Tasks",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
|
||||
# Schedule periodic cleanup and maintenance
|
||||
self.scheduler.add_job(
|
||||
self._maintenance_job,
|
||||
@@ -86,18 +86,18 @@ class SchedulerService:
|
||||
name: str,
|
||||
task_type: TaskType,
|
||||
scheduled_at: datetime,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[int] = None,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
user_id: int | None = None,
|
||||
timezone: str = "UTC",
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE,
|
||||
cron_expression: Optional[str] = None,
|
||||
recurrence_count: Optional[int] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
cron_expression: str | None = None,
|
||||
recurrence_count: int | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
|
||||
# Convert scheduled_at to UTC if it's in a different timezone
|
||||
if timezone != "UTC":
|
||||
tz = pytz.timezone(timezone)
|
||||
@@ -105,7 +105,7 @@ class SchedulerService:
|
||||
# Assume the datetime is in the specified timezone
|
||||
scheduled_at = tz.localize(scheduled_at)
|
||||
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
|
||||
|
||||
|
||||
task_data = {
|
||||
"name": name,
|
||||
"task_type": task_type,
|
||||
@@ -118,59 +118,59 @@ class SchedulerService:
|
||||
"recurrence_count": recurrence_count,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
|
||||
created_task = await repo.create(task_data)
|
||||
await self._schedule_apscheduler_job(created_task)
|
||||
|
||||
|
||||
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
|
||||
return created_task
|
||||
|
||||
|
||||
async def cancel_task(self, task_id: int) -> bool:
|
||||
"""Cancel a scheduled task."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
return False
|
||||
|
||||
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
|
||||
|
||||
# Remove from APScheduler
|
||||
try:
|
||||
self.scheduler.remove_job(str(task_id))
|
||||
except Exception:
|
||||
pass # Job might not exist in scheduler
|
||||
|
||||
|
||||
logger.info(f"Cancelled task: {task.name} ({task_id})")
|
||||
return True
|
||||
|
||||
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
status: TaskStatus | None = None,
|
||||
task_type: TaskType | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
|
||||
|
||||
|
||||
async def _initialize_system_tasks(self) -> None:
|
||||
"""Initialize system tasks and load active tasks from database."""
|
||||
logger.info("Initializing system tasks...")
|
||||
|
||||
|
||||
try:
|
||||
# Create system tasks if they don't exist
|
||||
await self._ensure_system_tasks()
|
||||
|
||||
|
||||
# Load all active tasks from database
|
||||
await self._load_active_tasks()
|
||||
|
||||
|
||||
logger.info("System tasks initialized successfully")
|
||||
except Exception:
|
||||
logger.exception("Failed to initialize system tasks")
|
||||
@@ -179,24 +179,24 @@ class SchedulerService:
|
||||
"""Ensure required system tasks exist."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
|
||||
# Check if daily credit recharge task exists
|
||||
system_tasks = await repo.get_system_tasks(
|
||||
task_type=TaskType.CREDIT_RECHARGE
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
)
|
||||
|
||||
|
||||
daily_recharge_exists = any(
|
||||
task.recurrence_type == RecurrenceType.DAILY
|
||||
and task.is_active
|
||||
for task in system_tasks
|
||||
)
|
||||
|
||||
|
||||
if not daily_recharge_exists:
|
||||
# Create daily credit recharge task
|
||||
tomorrow_midnight = datetime.utcnow().replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
hour=0, minute=0, second=0, microsecond=0,
|
||||
) + timedelta(days=1)
|
||||
|
||||
|
||||
task_data = {
|
||||
"name": "Daily Credit Recharge",
|
||||
"task_type": TaskType.CREDIT_RECHARGE,
|
||||
@@ -204,41 +204,41 @@ class SchedulerService:
|
||||
"recurrence_type": RecurrenceType.DAILY,
|
||||
"parameters": {},
|
||||
}
|
||||
|
||||
|
||||
await repo.create(task_data)
|
||||
logger.info("Created system daily credit recharge task")
|
||||
|
||||
|
||||
async def _load_active_tasks(self) -> None:
|
||||
"""Load all active tasks from database into scheduler."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
active_tasks = await repo.get_active_tasks()
|
||||
|
||||
|
||||
for task in active_tasks:
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
|
||||
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
|
||||
|
||||
|
||||
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
|
||||
"""Schedule a task in APScheduler."""
|
||||
job_id = str(task.id)
|
||||
|
||||
|
||||
# Remove existing job if it exists
|
||||
try:
|
||||
self.scheduler.remove_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Don't schedule if task is not active or already completed/failed
|
||||
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
||||
return
|
||||
|
||||
|
||||
# Create trigger based on recurrence type
|
||||
trigger = self._create_trigger(task)
|
||||
if not trigger:
|
||||
logger.warning(f"Could not create trigger for task {task.id}")
|
||||
return
|
||||
|
||||
|
||||
# Schedule the job
|
||||
self.scheduler.add_job(
|
||||
self._execute_task,
|
||||
@@ -248,76 +248,76 @@ class SchedulerService:
|
||||
name=task.name,
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"Scheduled APScheduler job for task {task.id}")
|
||||
|
||||
|
||||
def _create_trigger(self, task: ScheduledTask):
|
||||
"""Create APScheduler trigger based on task configuration."""
|
||||
tz = pytz.timezone(task.timezone)
|
||||
|
||||
|
||||
if task.recurrence_type == RecurrenceType.NONE:
|
||||
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
|
||||
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.HOURLY:
|
||||
|
||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||
|
||||
if task.recurrence_type == RecurrenceType.DAILY:
|
||||
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
|
||||
if task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
|
||||
if task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
# Use cron trigger for monthly (more reliable than interval)
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz
|
||||
timezone=tz,
|
||||
)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||
|
||||
if task.recurrence_type == RecurrenceType.YEARLY:
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
month=scheduled_time.month,
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz
|
||||
timezone=tz,
|
||||
)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _execute_task(self, task_id: int) -> None:
|
||||
"""Execute a scheduled task."""
|
||||
task_id_str = str(task_id)
|
||||
|
||||
|
||||
# Prevent concurrent execution of the same task
|
||||
if task_id_str in self._running_tasks:
|
||||
logger.warning(f"Task {task_id} is already running, skipping execution")
|
||||
return
|
||||
|
||||
|
||||
self._running_tasks.add(task_id_str)
|
||||
|
||||
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
|
||||
# Get fresh task data
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found")
|
||||
return
|
||||
|
||||
|
||||
# Check if task is still active and pending
|
||||
if not task.is_active or task.status != TaskStatus.PENDING:
|
||||
logger.info(f"Task {task_id} is not active or not pending, skipping")
|
||||
return
|
||||
|
||||
|
||||
# Check if task has expired
|
||||
if task.is_expired():
|
||||
logger.info(f"Task {task_id} has expired, marking as cancelled")
|
||||
@@ -325,78 +325,78 @@ class SchedulerService:
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
return
|
||||
|
||||
|
||||
# Mark task as running
|
||||
await repo.mark_as_running(task)
|
||||
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
handler_registry = TaskHandlerRegistry(
|
||||
session, self.db_session_factory, self.credit_service, self.player_service
|
||||
session, self.db_session_factory, self.credit_service, self.player_service,
|
||||
)
|
||||
await handler_registry.execute_task(task)
|
||||
|
||||
|
||||
# Calculate next execution time for recurring tasks
|
||||
next_execution_at = None
|
||||
if task.should_repeat():
|
||||
next_execution_at = self._calculate_next_execution(task)
|
||||
|
||||
|
||||
# Mark as completed
|
||||
await repo.mark_as_completed(task, next_execution_at)
|
||||
|
||||
|
||||
# Reschedule if recurring
|
||||
if next_execution_at and task.should_repeat():
|
||||
# Refresh task to get updated data
|
||||
await session.refresh(task)
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
await repo.mark_as_failed(task, str(e))
|
||||
logger.exception(f"Task {task_id} execution failed: {str(e)}")
|
||||
|
||||
logger.exception(f"Task {task_id} execution failed: {e!s}")
|
||||
|
||||
finally:
|
||||
self._running_tasks.discard(task_id_str)
|
||||
|
||||
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
|
||||
|
||||
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
|
||||
"""Calculate the next execution time for a recurring task."""
|
||||
now = datetime.utcnow()
|
||||
|
||||
|
||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return now + timedelta(hours=1)
|
||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||
if task.recurrence_type == RecurrenceType.DAILY:
|
||||
return now + timedelta(days=1)
|
||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
if task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
return now + timedelta(weeks=1)
|
||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
if task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
# Add approximately one month
|
||||
return now + timedelta(days=30)
|
||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||
if task.recurrence_type == RecurrenceType.YEARLY:
|
||||
return now + timedelta(days=365)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _maintenance_job(self) -> None:
|
||||
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
|
||||
# Handle expired tasks
|
||||
expired_tasks = await repo.get_expired_tasks()
|
||||
for task in expired_tasks:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
|
||||
|
||||
# Remove from scheduler
|
||||
try:
|
||||
self.scheduler.remove_job(str(task.id))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if expired_tasks:
|
||||
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
|
||||
|
||||
|
||||
# Handle any missed recurring tasks
|
||||
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
|
||||
for task in due_recurring:
|
||||
@@ -405,9 +405,9 @@ class SchedulerService:
|
||||
task.scheduled_at = task.next_execution_at or datetime.utcnow()
|
||||
await repo.update(task)
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
|
||||
if due_recurring:
|
||||
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
|
||||
|
||||
|
||||
except Exception:
|
||||
logger.exception("Maintenance job failed")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Task execution handlers for different task types."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
|
||||
class TaskExecutionError(Exception):
|
||||
"""Exception raised when task execution fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TaskHandlerRegistry:
|
||||
@@ -58,8 +56,8 @@ class TaskHandlerRegistry:
|
||||
await handler(task)
|
||||
logger.info(f"Task {task.id} executed successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"Task {task.id} execution failed: {str(e)}")
|
||||
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
|
||||
logger.exception(f"Task {task.id} execution failed: {e!s}")
|
||||
raise TaskExecutionError(f"Task execution failed: {e!s}") from e
|
||||
|
||||
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
||||
"""Handle credit recharge task."""
|
||||
@@ -72,7 +70,7 @@ class TaskHandlerRegistry:
|
||||
user_id_int = int(user_id)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e
|
||||
|
||||
|
||||
stats = await self.credit_service.recharge_user_credits(user_id_int)
|
||||
logger.info(f"Recharged credits for user {user_id}: {stats}")
|
||||
else:
|
||||
@@ -105,7 +103,7 @@ class TaskHandlerRegistry:
|
||||
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
|
||||
except Exception as e:
|
||||
# Convert HTTP exceptions or credit errors to task execution errors
|
||||
raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e
|
||||
raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
|
||||
else:
|
||||
# System task: play without credit deduction
|
||||
sound = await self.sound_repository.get_by_id(sound_id_int)
|
||||
@@ -116,10 +114,10 @@ class TaskHandlerRegistry:
|
||||
|
||||
vlc_service = VLCPlayerService(self.db_session_factory)
|
||||
success = await vlc_service.play_sound(sound)
|
||||
|
||||
|
||||
if not success:
|
||||
raise TaskExecutionError(f"Failed to play sound {sound.filename}")
|
||||
|
||||
|
||||
logger.info(f"Played sound {sound.filename} via scheduled system task")
|
||||
|
||||
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
|
||||
@@ -157,4 +155,4 @@ class TaskHandlerRegistry:
|
||||
# Start playing
|
||||
await self.player_service.play()
|
||||
|
||||
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
|
||||
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
|
||||
|
||||
@@ -238,13 +238,13 @@ class VLCPlayerService:
|
||||
return
|
||||
|
||||
logger.info("Recording play count for sound %s", sound_id)
|
||||
|
||||
|
||||
# Initialize variables for WebSocket event
|
||||
old_count = 0
|
||||
sound = None
|
||||
admin_user_id = None
|
||||
admin_user_name = None
|
||||
|
||||
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
sound_repo = SoundRepository(session)
|
||||
|
||||
Reference in New Issue
Block a user