Refactor code structure for improved readability and maintainability
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
"""API endpoints for scheduled task management."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
@@ -9,12 +11,15 @@ from app.core.dependencies import (
|
||||
get_admin_user,
|
||||
get_current_active_user,
|
||||
)
|
||||
from app.core.services import get_global_scheduler_service
|
||||
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.user import User
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
from app.schemas.scheduler import (
|
||||
ScheduledTaskCreate,
|
||||
ScheduledTaskResponse,
|
||||
ScheduledTaskUpdate,
|
||||
TaskFilterParams,
|
||||
)
|
||||
from app.services.scheduler import SchedulerService
|
||||
|
||||
@@ -23,47 +28,21 @@ router = APIRouter(prefix="/scheduler")
|
||||
|
||||
def get_scheduler_service() -> SchedulerService:
|
||||
"""Get the global scheduler service instance."""
|
||||
from app.main import get_global_scheduler_service
|
||||
return get_global_scheduler_service()
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=ScheduledTaskResponse)
|
||||
async def create_task(
|
||||
task_data: ScheduledTaskCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
try:
|
||||
task = await scheduler_service.create_task(
|
||||
name=task_data.name,
|
||||
task_type=task_data.task_type,
|
||||
scheduled_at=task_data.scheduled_at,
|
||||
parameters=task_data.parameters,
|
||||
user_id=current_user.id,
|
||||
timezone=task_data.timezone,
|
||||
recurrence_type=task_data.recurrence_type,
|
||||
cron_expression=task_data.cron_expression,
|
||||
recurrence_count=task_data.recurrence_count,
|
||||
expires_at=task_data.expires_at,
|
||||
)
|
||||
return task
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_user_tasks(
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
limit: int | None = Query(50, description="Maximum number of tasks to return"),
|
||||
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get user's scheduled tasks."""
|
||||
return await scheduler_service.get_user_tasks(
|
||||
user_id=current_user.id,
|
||||
def get_task_filters(
|
||||
status: Annotated[
|
||||
TaskStatus | None, Query(description="Filter by task status"),
|
||||
] = None,
|
||||
task_type: Annotated[
|
||||
TaskType | None, Query(description="Filter by task type"),
|
||||
] = None,
|
||||
limit: Annotated[int, Query(description="Maximum number of tasks to return")] = 50,
|
||||
offset: Annotated[int, Query(description="Number of tasks to skip")] = 0,
|
||||
) -> TaskFilterParams:
|
||||
"""Create task filter parameters from query parameters."""
|
||||
return TaskFilterParams(
|
||||
status=status,
|
||||
task_type=task_type,
|
||||
limit=limit,
|
||||
@@ -71,15 +50,45 @@ async def get_user_tasks(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=ScheduledTaskResponse)
|
||||
async def create_task(
|
||||
task_data: ScheduledTaskCreate,
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
scheduler_service: Annotated[SchedulerService, Depends(get_scheduler_service)],
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
try:
|
||||
return await scheduler_service.create_task(
|
||||
task_data=task_data,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_user_tasks(
|
||||
filters: Annotated[TaskFilterParams, Depends(get_task_filters)],
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
scheduler_service: Annotated[SchedulerService, Depends(get_scheduler_service)],
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get user's scheduled tasks."""
|
||||
return await scheduler_service.get_user_tasks(
|
||||
user_id=current_user.id,
|
||||
status=filters.status,
|
||||
task_type=filters.task_type,
|
||||
limit=filters.limit,
|
||||
offset=filters.offset,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=ScheduledTaskResponse)
|
||||
async def get_task(
|
||||
task_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> ScheduledTask:
|
||||
"""Get a specific scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
@@ -97,12 +106,10 @@ async def get_task(
|
||||
async def update_task(
|
||||
task_id: int,
|
||||
task_update: ScheduledTaskUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> ScheduledTask:
|
||||
"""Update a scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
@@ -118,20 +125,19 @@ async def update_task(
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
|
||||
updated_task = await repo.update(task)
|
||||
return updated_task
|
||||
return await repo.update(task)
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}")
|
||||
async def cancel_task(
|
||||
task_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
current_user: Annotated[User, Depends(get_current_active_user)] = ...,
|
||||
scheduler_service: Annotated[
|
||||
SchedulerService, Depends(get_scheduler_service),
|
||||
] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> dict:
|
||||
"""Cancel a scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
@@ -152,20 +158,23 @@ async def cancel_task(
|
||||
# Admin-only endpoints
|
||||
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_all_tasks(
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
limit: int | None = Query(100, description="Maximum number of tasks to return"),
|
||||
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
status: Annotated[
|
||||
TaskStatus | None, Query(description="Filter by task status"),
|
||||
] = None,
|
||||
task_type: Annotated[
|
||||
TaskType | None, Query(description="Filter by task type"),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int | None, Query(description="Maximum number of tasks to return"),
|
||||
] = 100,
|
||||
offset: Annotated[
|
||||
int | None, Query(description="Number of tasks to skip"),
|
||||
] = 0,
|
||||
_: Annotated[User, Depends(get_admin_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get all scheduled tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
|
||||
# Get all tasks with pagination and filtering
|
||||
from sqlmodel import select
|
||||
# Build query with pagination and filtering
|
||||
|
||||
statement = select(ScheduledTask)
|
||||
|
||||
@@ -189,14 +198,16 @@ async def get_all_tasks(
|
||||
|
||||
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_system_tasks(
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
status: Annotated[
|
||||
TaskStatus | None, Query(description="Filter by task status"),
|
||||
] = None,
|
||||
task_type: Annotated[
|
||||
TaskType | None, Query(description="Filter by task type"),
|
||||
] = None,
|
||||
_: Annotated[User, Depends(get_admin_user)] = ...,
|
||||
db_session: Annotated[AsyncSession, Depends(get_db)] = ...,
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get system tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
return await repo.get_system_tasks(status=status, task_type=task_type)
|
||||
|
||||
@@ -204,23 +215,16 @@ async def get_system_tasks(
|
||||
@router.post("/admin/system-tasks", response_model=ScheduledTaskResponse)
|
||||
async def create_system_task(
|
||||
task_data: ScheduledTaskCreate,
|
||||
current_user: User = Depends(get_admin_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
_: Annotated[User, Depends(get_admin_user)] = ...,
|
||||
scheduler_service: Annotated[
|
||||
SchedulerService, Depends(get_scheduler_service),
|
||||
] = ...,
|
||||
) -> ScheduledTask:
|
||||
"""Create a system task (admin only)."""
|
||||
try:
|
||||
task = await scheduler_service.create_task(
|
||||
name=task_data.name,
|
||||
task_type=task_data.task_type,
|
||||
scheduled_at=task_data.scheduled_at,
|
||||
parameters=task_data.parameters,
|
||||
return await scheduler_service.create_task(
|
||||
task_data=task_data,
|
||||
user_id=None, # System task
|
||||
timezone=task_data.timezone,
|
||||
recurrence_type=task_data.recurrence_type,
|
||||
cron_expression=task_data.cron_expression,
|
||||
recurrence_count=task_data.recurrence_count,
|
||||
expires_at=task_data.expires_at,
|
||||
)
|
||||
return task
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
23
app/core/services.py
Normal file
23
app/core/services.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Global services container to avoid circular imports."""
|
||||
|
||||
from app.services.scheduler import SchedulerService
|
||||
|
||||
|
||||
class AppServices:
|
||||
"""Container for application services."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the application services container."""
|
||||
self.scheduler_service: SchedulerService | None = None
|
||||
|
||||
|
||||
# Global service container
|
||||
app_services = AppServices()
|
||||
|
||||
|
||||
def get_global_scheduler_service() -> SchedulerService:
|
||||
"""Get the global scheduler service instance."""
|
||||
if app_services.scheduler_service is None:
|
||||
msg = "Scheduler service not initialized"
|
||||
raise RuntimeError(msg)
|
||||
return app_services.scheduler_service
|
||||
25
app/main.py
25
app/main.py
@@ -9,6 +9,7 @@ from app.api import api_router
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_session_factory, init_db
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.core.services import app_services
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
from app.services.player import (
|
||||
@@ -19,22 +20,10 @@ from app.services.player import (
|
||||
from app.services.scheduler import SchedulerService
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
scheduler_service = None
|
||||
|
||||
|
||||
def get_global_scheduler_service() -> SchedulerService:
|
||||
"""Get the global scheduler service instance."""
|
||||
global scheduler_service
|
||||
if scheduler_service is None:
|
||||
raise RuntimeError("Scheduler service not initialized")
|
||||
return scheduler_service
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
||||
"""Application lifespan context manager for setup and teardown."""
|
||||
global scheduler_service
|
||||
|
||||
setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting application")
|
||||
@@ -53,20 +42,22 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
||||
# Start the scheduler service
|
||||
try:
|
||||
player_service = get_player_service() # Get the initialized player service
|
||||
scheduler_service = SchedulerService(get_session_factory(), player_service)
|
||||
await scheduler_service.start()
|
||||
app_services.scheduler_service = SchedulerService(
|
||||
get_session_factory(), player_service,
|
||||
)
|
||||
await app_services.scheduler_service.start()
|
||||
logger.info("Enhanced scheduler service started")
|
||||
except Exception:
|
||||
logger.exception("Failed to start scheduler service - continuing without it")
|
||||
scheduler_service = None
|
||||
app_services.scheduler_service = None
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down application")
|
||||
|
||||
# Stop the scheduler service
|
||||
if scheduler_service:
|
||||
await scheduler_service.stop()
|
||||
if app_services.scheduler_service:
|
||||
await app_services.scheduler_service.stop()
|
||||
logger.info("Scheduler service stopped")
|
||||
|
||||
# Stop the player service
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Scheduled task model for flexible task scheduling with timezone support."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
@@ -42,7 +42,7 @@ class RecurrenceType(str, Enum):
|
||||
class ScheduledTask(BaseModel, table=True):
|
||||
"""Model for scheduled tasks with timezone support."""
|
||||
|
||||
__tablename__ = "scheduled_tasks"
|
||||
__tablename__ = "scheduled_task"
|
||||
|
||||
id: int | None = Field(primary_key=True, default=None)
|
||||
name: str = Field(max_length=255, description="Human-readable task name")
|
||||
@@ -53,12 +53,12 @@ class ScheduledTask(BaseModel, table=True):
|
||||
scheduled_at: datetime = Field(description="When the task should be executed (UTC)")
|
||||
timezone: str = Field(
|
||||
default="UTC",
|
||||
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
|
||||
description="Timezone for scheduling (e.g., 'America/New_York')",
|
||||
)
|
||||
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
|
||||
cron_expression: str | None = Field(
|
||||
default=None,
|
||||
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
|
||||
description="Cron expression for custom recurrence",
|
||||
)
|
||||
recurrence_count: int | None = Field(
|
||||
default=None,
|
||||
@@ -105,7 +105,7 @@ class ScheduledTask(BaseModel, table=True):
|
||||
"""Check if the task has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.utcnow() > self.expires_at
|
||||
return datetime.now(tz=UTC).replace(tzinfo=None) > self.expires_at
|
||||
|
||||
def is_recurring(self) -> bool:
|
||||
"""Check if the task is recurring."""
|
||||
|
||||
@@ -72,18 +72,22 @@ class BaseRepository[ModelType]:
|
||||
logger.exception("Failed to get all %s", self.model.__name__)
|
||||
raise
|
||||
|
||||
async def create(self, entity_data: dict[str, Any]) -> ModelType:
|
||||
async def create(self, entity_data: dict[str, Any] | ModelType) -> ModelType:
|
||||
"""Create a new entity.
|
||||
|
||||
Args:
|
||||
entity_data: Dictionary of entity data
|
||||
entity_data: Dictionary of entity data or model instance
|
||||
|
||||
Returns:
|
||||
The created entity
|
||||
|
||||
"""
|
||||
try:
|
||||
entity = self.model(**entity_data)
|
||||
if isinstance(entity_data, dict):
|
||||
entity = self.model(**entity_data)
|
||||
else:
|
||||
# Already a model instance
|
||||
entity = entity_data
|
||||
self.session.add(entity)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(entity)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Repository for scheduled task operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -23,7 +23,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_pending_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get all pending tasks that are ready to be executed."""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.status == TaskStatus.PENDING,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
@@ -90,7 +90,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
|
||||
"""Get recurring tasks that need their next execution scheduled."""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.recurrence_type != RecurrenceType.NONE,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
@@ -102,7 +102,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
|
||||
async def get_expired_tasks(self) -> list[ScheduledTask]:
|
||||
"""Get expired tasks that should be cleaned up."""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(tz=UTC)
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.expires_at.is_not(None),
|
||||
ScheduledTask.expires_at <= now,
|
||||
@@ -152,7 +152,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
) -> None:
|
||||
"""Mark a task as completed and set next execution if recurring."""
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
task.last_executed_at = datetime.now(tz=UTC)
|
||||
task.executions_count += 1
|
||||
task.error_message = None
|
||||
|
||||
@@ -170,7 +170,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
"""Mark a task as failed with error message."""
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error_message = error_message
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
task.last_executed_at = datetime.now(tz=UTC)
|
||||
|
||||
# For non-recurring tasks, deactivate on failure
|
||||
if not task.is_recurring():
|
||||
|
||||
@@ -8,6 +8,15 @@ from pydantic import BaseModel, Field
|
||||
from app.models.scheduled_task import RecurrenceType, TaskStatus, TaskType
|
||||
|
||||
|
||||
class TaskFilterParams(BaseModel):
|
||||
"""Query parameters for filtering tasks."""
|
||||
|
||||
status: TaskStatus | None = Field(default=None, description="Filter by task status")
|
||||
task_type: TaskType | None = Field(default=None, description="Filter by task type")
|
||||
limit: int = Field(default=50, description="Maximum number of tasks to return")
|
||||
offset: int = Field(default=0, description="Number of tasks to skip")
|
||||
|
||||
|
||||
class ScheduledTaskBase(BaseModel):
|
||||
"""Base schema for scheduled tasks."""
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class PlayerState:
|
||||
"""Convert player state to dictionary for serialization."""
|
||||
return {
|
||||
"status": self.status.value,
|
||||
"mode": self.mode.value,
|
||||
"mode": self.mode.value if isinstance(self.mode, PlayerMode) else self.mode,
|
||||
"volume": self.volume,
|
||||
"previous_volume": self.previous_volume,
|
||||
"position": self.current_sound_position or 0,
|
||||
@@ -401,8 +401,16 @@ class PlayerService:
|
||||
if self.state.volume == 0 and self.state.previous_volume > 0:
|
||||
await self.set_volume(self.state.previous_volume)
|
||||
|
||||
async def set_mode(self, mode: PlayerMode) -> None:
|
||||
async def set_mode(self, mode: PlayerMode | str) -> None:
|
||||
"""Set playback mode."""
|
||||
if isinstance(mode, str):
|
||||
# Convert string to PlayerMode enum
|
||||
try:
|
||||
mode = PlayerMode(mode)
|
||||
except ValueError:
|
||||
logger.error("Invalid player mode: %s", mode)
|
||||
return
|
||||
|
||||
self.state.mode = mode
|
||||
await self._broadcast_state()
|
||||
logger.info("Playback mode set to: %s", mode.value)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
from contextlib import suppress
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytz
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
@@ -19,6 +19,7 @@ from app.models.scheduled_task import (
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
from app.schemas.scheduler import ScheduledTaskCreate
|
||||
from app.services.credit import CreditService
|
||||
from app.services.player import PlayerService
|
||||
from app.services.task_handlers import TaskHandlerRegistry
|
||||
@@ -57,7 +58,7 @@ class SchedulerService:
|
||||
self.scheduler.add_job(
|
||||
self._initialize_system_tasks,
|
||||
"date",
|
||||
run_date=datetime.utcnow() + timedelta(seconds=2),
|
||||
run_date=datetime.now(tz=UTC) + timedelta(seconds=2),
|
||||
id="initialize_system_tasks",
|
||||
name="Initialize System Tasks",
|
||||
replace_existing=True,
|
||||
@@ -83,46 +84,43 @@ class SchedulerService:
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
name: str,
|
||||
task_type: TaskType,
|
||||
scheduled_at: datetime,
|
||||
parameters: dict[str, Any] | None = None,
|
||||
task_data: ScheduledTaskCreate,
|
||||
user_id: int | None = None,
|
||||
timezone: str = "UTC",
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE,
|
||||
cron_expression: str | None = None,
|
||||
recurrence_count: int | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
"""Create a new scheduled task from schema data."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Convert scheduled_at to UTC if it's in a different timezone
|
||||
if timezone != "UTC":
|
||||
tz = pytz.timezone(timezone)
|
||||
scheduled_at = task_data.scheduled_at
|
||||
if task_data.timezone != "UTC":
|
||||
tz = pytz.timezone(task_data.timezone)
|
||||
if scheduled_at.tzinfo is None:
|
||||
# Assume the datetime is in the specified timezone
|
||||
scheduled_at = tz.localize(scheduled_at)
|
||||
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
|
||||
|
||||
task_data = {
|
||||
"name": name,
|
||||
"task_type": task_type,
|
||||
db_task_data = {
|
||||
"name": task_data.name,
|
||||
"task_type": task_data.task_type,
|
||||
"scheduled_at": scheduled_at,
|
||||
"timezone": timezone,
|
||||
"parameters": parameters or {},
|
||||
"timezone": task_data.timezone,
|
||||
"parameters": task_data.parameters,
|
||||
"user_id": user_id,
|
||||
"recurrence_type": recurrence_type,
|
||||
"cron_expression": cron_expression,
|
||||
"recurrence_count": recurrence_count,
|
||||
"expires_at": expires_at,
|
||||
"recurrence_type": task_data.recurrence_type,
|
||||
"cron_expression": task_data.cron_expression,
|
||||
"recurrence_count": task_data.recurrence_count,
|
||||
"expires_at": task_data.expires_at,
|
||||
}
|
||||
|
||||
created_task = await repo.create(task_data)
|
||||
created_task = await repo.create(db_task_data)
|
||||
await self._schedule_apscheduler_job(created_task)
|
||||
|
||||
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
|
||||
logger.info(
|
||||
"Created scheduled task: %s (%s)",
|
||||
created_task.name,
|
||||
created_task.id,
|
||||
)
|
||||
return created_task
|
||||
|
||||
async def cancel_task(self, task_id: int) -> bool:
|
||||
@@ -134,17 +132,16 @@ class SchedulerService:
|
||||
if not task:
|
||||
return False
|
||||
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.CANCELLED,
|
||||
"is_active": False,
|
||||
})
|
||||
|
||||
# Remove from APScheduler
|
||||
try:
|
||||
# Remove from APScheduler (job might not exist in scheduler)
|
||||
with suppress(Exception):
|
||||
self.scheduler.remove_job(str(task_id))
|
||||
except Exception:
|
||||
pass # Job might not exist in scheduler
|
||||
|
||||
logger.info(f"Cancelled task: {task.name} ({task_id})")
|
||||
logger.info("Cancelled task: %s (%s)", task.name, task_id)
|
||||
return True
|
||||
|
||||
async def get_user_tasks(
|
||||
@@ -193,7 +190,7 @@ class SchedulerService:
|
||||
|
||||
if not daily_recharge_exists:
|
||||
# Create daily credit recharge task
|
||||
tomorrow_midnight = datetime.utcnow().replace(
|
||||
tomorrow_midnight = datetime.now(tz=UTC).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0,
|
||||
) + timedelta(days=1)
|
||||
|
||||
@@ -217,26 +214,29 @@ class SchedulerService:
|
||||
for task in active_tasks:
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
|
||||
logger.info("Loaded %s active tasks into scheduler", len(active_tasks))
|
||||
|
||||
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
|
||||
"""Schedule a task in APScheduler."""
|
||||
job_id = str(task.id)
|
||||
|
||||
# Remove existing job if it exists
|
||||
try:
|
||||
with suppress(Exception):
|
||||
self.scheduler.remove_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Don't schedule if task is not active or already completed/failed
|
||||
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
||||
inactive_statuses = [
|
||||
TaskStatus.COMPLETED,
|
||||
TaskStatus.FAILED,
|
||||
TaskStatus.CANCELLED,
|
||||
]
|
||||
if not task.is_active or task.status in inactive_statuses:
|
||||
return
|
||||
|
||||
# Create trigger based on recurrence type
|
||||
trigger = self._create_trigger(task)
|
||||
if not trigger:
|
||||
logger.warning(f"Could not create trigger for task {task.id}")
|
||||
logger.warning("Could not create trigger for task %s", task.id)
|
||||
return
|
||||
|
||||
# Schedule the job
|
||||
@@ -249,46 +249,51 @@ class SchedulerService:
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
logger.debug(f"Scheduled APScheduler job for task {task.id}")
|
||||
logger.debug("Scheduled APScheduler job for task %s", task.id)
|
||||
|
||||
def _create_trigger(self, task: ScheduledTask):
|
||||
def _create_trigger(
|
||||
self, task: ScheduledTask,
|
||||
) -> DateTrigger | IntervalTrigger | CronTrigger | None:
|
||||
"""Create APScheduler trigger based on task configuration."""
|
||||
tz = pytz.timezone(task.timezone)
|
||||
scheduled_time = task.scheduled_at
|
||||
|
||||
# Handle special cases first
|
||||
if task.recurrence_type == RecurrenceType.NONE:
|
||||
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
|
||||
return DateTrigger(run_date=scheduled_time, timezone=tz)
|
||||
|
||||
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
||||
|
||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
|
||||
# Handle interval-based recurrence types
|
||||
interval_configs = {
|
||||
RecurrenceType.HOURLY: {"hours": 1},
|
||||
RecurrenceType.DAILY: {"days": 1},
|
||||
RecurrenceType.WEEKLY: {"weeks": 1},
|
||||
}
|
||||
|
||||
if task.recurrence_type == RecurrenceType.DAILY:
|
||||
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
|
||||
if task.recurrence_type in interval_configs:
|
||||
config = interval_configs[task.recurrence_type]
|
||||
return IntervalTrigger(start_date=scheduled_time, timezone=tz, **config)
|
||||
|
||||
if task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
|
||||
# Handle cron-based recurrence types
|
||||
cron_configs = {
|
||||
RecurrenceType.MONTHLY: {
|
||||
"day": scheduled_time.day,
|
||||
"hour": scheduled_time.hour,
|
||||
"minute": scheduled_time.minute,
|
||||
},
|
||||
RecurrenceType.YEARLY: {
|
||||
"month": scheduled_time.month,
|
||||
"day": scheduled_time.day,
|
||||
"hour": scheduled_time.hour,
|
||||
"minute": scheduled_time.minute,
|
||||
},
|
||||
}
|
||||
|
||||
if task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
# Use cron trigger for monthly (more reliable than interval)
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz,
|
||||
)
|
||||
|
||||
if task.recurrence_type == RecurrenceType.YEARLY:
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
month=scheduled_time.month,
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz,
|
||||
)
|
||||
if task.recurrence_type in cron_configs:
|
||||
config = cron_configs[task.recurrence_type]
|
||||
return CronTrigger(timezone=tz, **config)
|
||||
|
||||
return None
|
||||
|
||||
@@ -298,7 +303,7 @@ class SchedulerService:
|
||||
|
||||
# Prevent concurrent execution of the same task
|
||||
if task_id_str in self._running_tasks:
|
||||
logger.warning(f"Task {task_id} is already running, skipping execution")
|
||||
logger.warning("Task %s is already running, skipping execution", task_id)
|
||||
return
|
||||
|
||||
self._running_tasks.add(task_id_str)
|
||||
@@ -310,20 +315,21 @@ class SchedulerService:
|
||||
# Get fresh task data
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found")
|
||||
logger.warning("Task %s not found", task_id)
|
||||
return
|
||||
|
||||
# Check if task is still active and pending
|
||||
if not task.is_active or task.status != TaskStatus.PENDING:
|
||||
logger.info(f"Task {task_id} is not active or not pending, skipping")
|
||||
logger.info("Task %s not active or not pending, skipping", task_id)
|
||||
return
|
||||
|
||||
# Check if task has expired
|
||||
if task.is_expired():
|
||||
logger.info(f"Task {task_id} has expired, marking as cancelled")
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
logger.info("Task %s has expired, marking as cancelled", task_id)
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.CANCELLED,
|
||||
"is_active": False,
|
||||
})
|
||||
return
|
||||
|
||||
# Mark task as running
|
||||
@@ -332,7 +338,10 @@ class SchedulerService:
|
||||
# Execute the task
|
||||
try:
|
||||
handler_registry = TaskHandlerRegistry(
|
||||
session, self.db_session_factory, self.credit_service, self.player_service,
|
||||
session,
|
||||
self.db_session_factory,
|
||||
self.credit_service,
|
||||
self.player_service,
|
||||
)
|
||||
await handler_registry.execute_task(task)
|
||||
|
||||
@@ -352,14 +361,14 @@ class SchedulerService:
|
||||
|
||||
except Exception as e:
|
||||
await repo.mark_as_failed(task, str(e))
|
||||
logger.exception(f"Task {task_id} execution failed: {e!s}")
|
||||
logger.exception("Task %s execution failed", task_id)
|
||||
|
||||
finally:
|
||||
self._running_tasks.discard(task_id_str)
|
||||
|
||||
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
|
||||
"""Calculate the next execution time for a recurring task."""
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(tz=UTC)
|
||||
|
||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return now + timedelta(hours=1)
|
||||
@@ -376,7 +385,7 @@ class SchedulerService:
|
||||
return None
|
||||
|
||||
async def _maintenance_job(self) -> None:
|
||||
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
|
||||
"""Periodic maintenance job to clean up expired tasks and handle scheduling."""
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
@@ -384,30 +393,33 @@ class SchedulerService:
|
||||
# Handle expired tasks
|
||||
expired_tasks = await repo.get_expired_tasks()
|
||||
for task in expired_tasks:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.CANCELLED,
|
||||
"is_active": False,
|
||||
})
|
||||
|
||||
# Remove from scheduler
|
||||
try:
|
||||
with suppress(Exception):
|
||||
self.scheduler.remove_job(str(task.id))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if expired_tasks:
|
||||
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
|
||||
logger.info("Cleaned up %s expired tasks", len(expired_tasks))
|
||||
|
||||
# Handle any missed recurring tasks
|
||||
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
|
||||
for task in due_recurring:
|
||||
if task.should_repeat():
|
||||
task.status = TaskStatus.PENDING
|
||||
task.scheduled_at = task.next_execution_at or datetime.utcnow()
|
||||
await repo.update(task)
|
||||
next_scheduled_at = (
|
||||
task.next_execution_at or datetime.now(tz=UTC)
|
||||
)
|
||||
await repo.update(task, {
|
||||
"status": TaskStatus.PENDING,
|
||||
"scheduled_at": next_scheduled_at,
|
||||
})
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
if due_recurring:
|
||||
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
|
||||
logger.info("Rescheduled %s recurring tasks", len(due_recurring))
|
||||
|
||||
except Exception:
|
||||
logger.exception("Maintenance job failed")
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.credit import CreditService
|
||||
from app.services.player import PlayerService
|
||||
from app.services.vlc_player import VLCPlayerService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -48,16 +49,23 @@ class TaskHandlerRegistry:
|
||||
"""Execute a task based on its type."""
|
||||
handler = self._handlers.get(task.task_type)
|
||||
if not handler:
|
||||
raise TaskExecutionError(f"No handler registered for task type: {task.task_type}")
|
||||
msg = f"No handler registered for task type: {task.task_type}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
logger.info(f"Executing task {task.id} ({task.task_type.value}): {task.name}")
|
||||
logger.info(
|
||||
"Executing task %s (%s): %s",
|
||||
task.id,
|
||||
task.task_type.value,
|
||||
task.name,
|
||||
)
|
||||
|
||||
try:
|
||||
await handler(task)
|
||||
logger.info(f"Task {task.id} executed successfully")
|
||||
logger.info("Task %s executed successfully", task.id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Task {task.id} execution failed: {e!s}")
|
||||
raise TaskExecutionError(f"Task execution failed: {e!s}") from e
|
||||
logger.exception("Task %s execution failed", task.id)
|
||||
msg = f"Task execution failed: {e!s}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
||||
"""Handle credit recharge task."""
|
||||
@@ -69,14 +77,15 @@ class TaskHandlerRegistry:
|
||||
try:
|
||||
user_id_int = int(user_id)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e
|
||||
msg = f"Invalid user_id format: {user_id}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
stats = await self.credit_service.recharge_user_credits(user_id_int)
|
||||
logger.info(f"Recharged credits for user {user_id}: {stats}")
|
||||
logger.info("Recharged credits for user %s: %s", user_id, stats)
|
||||
else:
|
||||
# Recharge all users (system task)
|
||||
stats = await self.credit_service.recharge_all_users_credits()
|
||||
logger.info(f"Recharged credits for all users: {stats}")
|
||||
logger.info("Recharged credits for all users: %s", stats)
|
||||
|
||||
async def _handle_play_sound(self, task: ScheduledTask) -> None:
|
||||
"""Handle play sound task."""
|
||||
@@ -84,41 +93,54 @@ class TaskHandlerRegistry:
|
||||
sound_id = parameters.get("sound_id")
|
||||
|
||||
if not sound_id:
|
||||
raise TaskExecutionError("sound_id parameter is required for PLAY_SOUND tasks")
|
||||
msg = "sound_id parameter is required for PLAY_SOUND tasks"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
try:
|
||||
# Handle both integer and string sound IDs
|
||||
sound_id_int = int(sound_id)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TaskExecutionError(f"Invalid sound_id format: {sound_id}") from e
|
||||
msg = f"Invalid sound_id format: {sound_id}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
# Check if this is a user task (has user_id)
|
||||
if task.user_id:
|
||||
# User task: use credit-aware playback
|
||||
from app.services.vlc_player import VLCPlayerService
|
||||
|
||||
vlc_service = VLCPlayerService(self.db_session_factory)
|
||||
try:
|
||||
result = await vlc_service.play_sound_with_credits(sound_id_int, task.user_id)
|
||||
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
|
||||
result = await vlc_service.play_sound_with_credits(
|
||||
sound_id_int, task.user_id,
|
||||
)
|
||||
logger.info(
|
||||
(
|
||||
"Played sound %s via scheduled task for user %s "
|
||||
"(credits deducted: %s)"
|
||||
),
|
||||
result.get("sound_name", sound_id),
|
||||
task.user_id,
|
||||
result.get("credits_deducted", 0),
|
||||
)
|
||||
except Exception as e:
|
||||
# Convert HTTP exceptions or credit errors to task execution errors
|
||||
raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
|
||||
msg = f"Failed to play sound with credits: {e!s}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
else:
|
||||
# System task: play without credit deduction
|
||||
sound = await self.sound_repository.get_by_id(sound_id_int)
|
||||
if not sound:
|
||||
raise TaskExecutionError(f"Sound not found: {sound_id}")
|
||||
msg = f"Sound not found: {sound_id}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
from app.services.vlc_player import VLCPlayerService
|
||||
|
||||
vlc_service = VLCPlayerService(self.db_session_factory)
|
||||
success = await vlc_service.play_sound(sound)
|
||||
|
||||
if not success:
|
||||
raise TaskExecutionError(f"Failed to play sound {sound.filename}")
|
||||
msg = f"Failed to play sound {sound.filename}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
logger.info(f"Played sound {sound.filename} via scheduled system task")
|
||||
logger.info("Played sound %s via scheduled system task", sound.filename)
|
||||
|
||||
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
|
||||
"""Handle play playlist task."""
|
||||
@@ -128,31 +150,34 @@ class TaskHandlerRegistry:
|
||||
shuffle = parameters.get("shuffle", False)
|
||||
|
||||
if not playlist_id:
|
||||
raise TaskExecutionError("playlist_id parameter is required for PLAY_PLAYLIST tasks")
|
||||
msg = "playlist_id parameter is required for PLAY_PLAYLIST tasks"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
try:
|
||||
# Handle both integer and string playlist IDs
|
||||
playlist_id_int = int(playlist_id)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TaskExecutionError(f"Invalid playlist_id format: {playlist_id}") from e
|
||||
msg = f"Invalid playlist_id format: {playlist_id}"
|
||||
raise TaskExecutionError(msg) from e
|
||||
|
||||
# Get the playlist from database
|
||||
playlist = await self.playlist_repository.get_by_id(playlist_id_int)
|
||||
if not playlist:
|
||||
raise TaskExecutionError(f"Playlist not found: {playlist_id}")
|
||||
msg = f"Playlist not found: {playlist_id}"
|
||||
raise TaskExecutionError(msg)
|
||||
|
||||
# Load playlist in player
|
||||
await self.player_service.load_playlist(playlist_id_int)
|
||||
|
||||
# Set play mode if specified
|
||||
if play_mode in ["continuous", "loop", "loop_one", "random", "single"]:
|
||||
self.player_service.set_mode(play_mode)
|
||||
await self.player_service.set_mode(play_mode)
|
||||
|
||||
# Enable shuffle if requested
|
||||
if shuffle:
|
||||
self.player_service.set_shuffle(True)
|
||||
await self.player_service.set_shuffle(shuffle=True)
|
||||
|
||||
# Start playing
|
||||
await self.player_service.play()
|
||||
|
||||
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
|
||||
logger.info("Started playing playlist %s via scheduled task", playlist.name)
|
||||
|
||||
Reference in New Issue
Block a user