Add comprehensive tests for scheduled task repository, scheduler service, and task handlers
- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates. - Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs. - Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback. - Ensured proper error handling and edge cases in task execution scenarios. - Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
@@ -12,6 +12,7 @@ from app.api.v1 import (
|
||||
main,
|
||||
player,
|
||||
playlists,
|
||||
scheduler,
|
||||
socket,
|
||||
sounds,
|
||||
)
|
||||
@@ -28,6 +29,7 @@ api_router.include_router(files.router, tags=["files"])
|
||||
api_router.include_router(main.router, tags=["main"])
|
||||
api_router.include_router(player.router, tags=["player"])
|
||||
api_router.include_router(playlists.router, tags=["playlists"])
|
||||
api_router.include_router(scheduler.router, tags=["scheduler"])
|
||||
api_router.include_router(socket.router, tags=["socket"])
|
||||
api_router.include_router(sounds.router, tags=["sounds"])
|
||||
api_router.include_router(admin.router)
|
||||
|
||||
228
app/api/v1/scheduler.py
Normal file
228
app/api/v1/scheduler.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""API endpoints for scheduled task management."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.dependencies import (
|
||||
get_admin_user,
|
||||
get_current_active_user,
|
||||
)
|
||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.user import User
|
||||
from app.schemas.scheduler import (
|
||||
ScheduledTaskCreate,
|
||||
ScheduledTaskResponse,
|
||||
ScheduledTaskUpdate,
|
||||
)
|
||||
from app.services.scheduler import SchedulerService
|
||||
|
||||
router = APIRouter(prefix="/scheduler")
|
||||
|
||||
|
||||
def get_scheduler_service() -> SchedulerService:
|
||||
"""Get the global scheduler service instance."""
|
||||
from app.main import get_global_scheduler_service
|
||||
return get_global_scheduler_service()
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=ScheduledTaskResponse)
|
||||
async def create_task(
|
||||
task_data: ScheduledTaskCreate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
try:
|
||||
task = await scheduler_service.create_task(
|
||||
name=task_data.name,
|
||||
task_type=task_data.task_type,
|
||||
scheduled_at=task_data.scheduled_at,
|
||||
parameters=task_data.parameters,
|
||||
user_id=current_user.id,
|
||||
timezone=task_data.timezone,
|
||||
recurrence_type=task_data.recurrence_type,
|
||||
cron_expression=task_data.cron_expression,
|
||||
recurrence_count=task_data.recurrence_count,
|
||||
expires_at=task_data.expires_at,
|
||||
)
|
||||
return task
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
|
||||
async def get_user_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
|
||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get user's scheduled tasks."""
|
||||
return await scheduler_service.get_user_tasks(
|
||||
user_id=current_user.id,
|
||||
status=status,
|
||||
task_type=task_type,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=ScheduledTaskResponse)
|
||||
async def get_task(
|
||||
task_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> ScheduledTask:
|
||||
"""Get a specific scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@router.patch("/tasks/{task_id}", response_model=ScheduledTaskResponse)
|
||||
async def update_task(
|
||||
task_id: int,
|
||||
task_update: ScheduledTaskUpdate,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> ScheduledTask:
|
||||
"""Update a scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# Update task fields
|
||||
update_data = task_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
|
||||
updated_task = await repo.update(task)
|
||||
return updated_task
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}")
|
||||
async def cancel_task(
|
||||
task_id: int,
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Cancel a scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
success = await scheduler_service.cancel_task(task_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to cancel task")
|
||||
|
||||
return {"message": "Task cancelled successfully"}
|
||||
|
||||
|
||||
# Admin-only endpoints
|
||||
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
|
||||
async def get_all_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
|
||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get all scheduled tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
|
||||
# Get all tasks with pagination and filtering
|
||||
from sqlmodel import select
|
||||
|
||||
statement = select(ScheduledTask)
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await db_session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
|
||||
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
|
||||
async def get_system_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get system tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
return await repo.get_system_tasks(status=status, task_type=task_type)
|
||||
|
||||
|
||||
@router.post("/admin/system-tasks", response_model=ScheduledTaskResponse)
|
||||
async def create_system_task(
|
||||
task_data: ScheduledTaskCreate,
|
||||
current_user: User = Depends(get_admin_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
) -> ScheduledTask:
|
||||
"""Create a system task (admin only)."""
|
||||
try:
|
||||
task = await scheduler_service.create_task(
|
||||
name=task_data.name,
|
||||
task_type=task_data.task_type,
|
||||
scheduled_at=task_data.scheduled_at,
|
||||
parameters=task_data.parameters,
|
||||
user_id=None, # System task
|
||||
timezone=task_data.timezone,
|
||||
recurrence_type=task_data.recurrence_type,
|
||||
cron_expression=task_data.cron_expression,
|
||||
recurrence_count=task_data.recurrence_count,
|
||||
expires_at=task_data.expires_at,
|
||||
)
|
||||
return task
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -7,17 +7,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from app.core.config import settings
|
||||
from app.core.logging import get_logger
|
||||
from app.core.seeds import seed_all_data
|
||||
from app.models import ( # noqa: F401
|
||||
extraction,
|
||||
favorite,
|
||||
plan,
|
||||
playlist,
|
||||
playlist_sound,
|
||||
sound,
|
||||
sound_played,
|
||||
user,
|
||||
user_oauth,
|
||||
)
|
||||
# Import all models to ensure SQLModel metadata discovery
|
||||
import app.models # noqa: F401
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
|
||||
31
app/main.py
31
app/main.py
@@ -11,14 +11,27 @@ from app.core.database import get_session_factory, init_db
|
||||
from app.core.logging import get_logger, setup_logging
|
||||
from app.middleware.logging import LoggingMiddleware
|
||||
from app.services.extraction_processor import extraction_processor
|
||||
from app.services.player import initialize_player_service, shutdown_player_service
|
||||
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service
|
||||
from app.services.scheduler import SchedulerService
|
||||
from app.services.socket import socket_manager
|
||||
|
||||
|
||||
scheduler_service = None
|
||||
|
||||
|
||||
def get_global_scheduler_service() -> SchedulerService:
|
||||
"""Get the global scheduler service instance."""
|
||||
global scheduler_service
|
||||
if scheduler_service is None:
|
||||
raise RuntimeError("Scheduler service not initialized")
|
||||
return scheduler_service
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
||||
"""Application lifespan context manager for setup and teardown."""
|
||||
global scheduler_service
|
||||
|
||||
setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Starting application")
|
||||
@@ -35,17 +48,23 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
||||
logger.info("Player service started")
|
||||
|
||||
# Start the scheduler service
|
||||
scheduler_service = SchedulerService(get_session_factory())
|
||||
await scheduler_service.start()
|
||||
logger.info("Scheduler service started")
|
||||
try:
|
||||
player_service = get_player_service() # Get the initialized player service
|
||||
scheduler_service = SchedulerService(get_session_factory(), player_service)
|
||||
await scheduler_service.start()
|
||||
logger.info("Enhanced scheduler service started")
|
||||
except Exception:
|
||||
logger.exception("Failed to start scheduler service - continuing without it")
|
||||
scheduler_service = None
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down application")
|
||||
|
||||
# Stop the scheduler service
|
||||
await scheduler_service.stop()
|
||||
logger.info("Scheduler service stopped")
|
||||
if scheduler_service:
|
||||
await scheduler_service.stop()
|
||||
logger.info("Scheduler service stopped")
|
||||
|
||||
# Stop the player service
|
||||
await shutdown_player_service()
|
||||
|
||||
@@ -1 +1,32 @@
|
||||
"""Models package."""
|
||||
|
||||
# Import all models for SQLAlchemy metadata discovery
|
||||
from .base import BaseModel
|
||||
from .credit_action import CreditAction
|
||||
from .credit_transaction import CreditTransaction
|
||||
from .extraction import Extraction
|
||||
from .favorite import Favorite
|
||||
from .plan import Plan
|
||||
from .playlist import Playlist
|
||||
from .playlist_sound import PlaylistSound
|
||||
from .scheduled_task import ScheduledTask
|
||||
from .sound import Sound
|
||||
from .sound_played import SoundPlayed
|
||||
from .user import User
|
||||
from .user_oauth import UserOauth
|
||||
|
||||
__all__ = [
|
||||
"BaseModel",
|
||||
"CreditAction",
|
||||
"CreditTransaction",
|
||||
"Extraction",
|
||||
"Favorite",
|
||||
"Plan",
|
||||
"Playlist",
|
||||
"PlaylistSound",
|
||||
"ScheduledTask",
|
||||
"Sound",
|
||||
"SoundPlayed",
|
||||
"User",
|
||||
"UserOauth",
|
||||
]
|
||||
|
||||
125
app/models/scheduled_task.py
Normal file
125
app/models/scheduled_task.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Scheduled task model for flexible task scheduling with timezone support."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlmodel import JSON, Column, Field, SQLModel
|
||||
|
||||
from app.models.base import BaseModel
|
||||
|
||||
|
||||
class TaskType(str, Enum):
|
||||
"""Available task types."""
|
||||
|
||||
CREDIT_RECHARGE = "credit_recharge"
|
||||
PLAY_SOUND = "play_sound"
|
||||
PLAY_PLAYLIST = "play_playlist"
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Task execution status."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class RecurrenceType(str, Enum):
|
||||
"""Recurrence patterns."""
|
||||
|
||||
NONE = "none" # One-shot task
|
||||
HOURLY = "hourly"
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
YEARLY = "yearly"
|
||||
CRON = "cron" # Custom cron expression
|
||||
|
||||
|
||||
class ScheduledTask(BaseModel, table=True):
|
||||
"""Model for scheduled tasks with timezone support."""
|
||||
|
||||
__tablename__ = "scheduled_tasks"
|
||||
|
||||
id: int | None = Field(primary_key=True, default=None)
|
||||
name: str = Field(max_length=255, description="Human-readable task name")
|
||||
task_type: TaskType = Field(description="Type of task to execute")
|
||||
status: TaskStatus = Field(default=TaskStatus.PENDING)
|
||||
|
||||
# Scheduling fields with timezone support
|
||||
scheduled_at: datetime = Field(description="When the task should be executed (UTC)")
|
||||
timezone: str = Field(
|
||||
default="UTC",
|
||||
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
|
||||
)
|
||||
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
|
||||
cron_expression: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
|
||||
)
|
||||
recurrence_count: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Number of times to repeat (None for infinite)",
|
||||
)
|
||||
executions_count: int = Field(default=0, description="Number of times executed")
|
||||
|
||||
# Task parameters
|
||||
parameters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
sa_column=Column(JSON),
|
||||
description="Task-specific parameters",
|
||||
)
|
||||
|
||||
# User association (None for system tasks)
|
||||
user_id: Optional[int] = Field(
|
||||
default=None,
|
||||
foreign_key="user.id",
|
||||
description="User who created the task (None for system tasks)",
|
||||
)
|
||||
|
||||
# Execution tracking
|
||||
last_executed_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
description="When the task was last executed (UTC)",
|
||||
)
|
||||
next_execution_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
description="When the task should be executed next (UTC, for recurring tasks)",
|
||||
)
|
||||
error_message: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Error message if execution failed",
|
||||
)
|
||||
|
||||
# Task lifecycle
|
||||
is_active: bool = Field(default=True, description="Whether the task is active")
|
||||
expires_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
description="When the task expires (UTC, optional)",
|
||||
)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if the task has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.utcnow() > self.expires_at
|
||||
|
||||
def is_recurring(self) -> bool:
|
||||
"""Check if the task is recurring."""
|
||||
return self.recurrence_type != RecurrenceType.NONE
|
||||
|
||||
def should_repeat(self) -> bool:
|
||||
"""Check if the task should be repeated."""
|
||||
if not self.is_recurring():
|
||||
return False
|
||||
if self.recurrence_count is None:
|
||||
return True
|
||||
return self.executions_count < self.recurrence_count
|
||||
|
||||
def is_system_task(self) -> bool:
|
||||
"""Check if this is a system task (no user association)."""
|
||||
return self.user_id is None
|
||||
177
app/repositories/scheduled_task.py
Normal file
177
app/repositories/scheduled_task.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Repository for scheduled task operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||
"""Repository for scheduled task database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize the repository."""
|
||||
super().__init__(ScheduledTask, session)
|
||||
|
||||
async def get_pending_tasks(self) -> List[ScheduledTask]:
|
||||
"""Get all pending tasks that are ready to be executed."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.status == TaskStatus.PENDING,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.scheduled_at <= now,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_active_tasks(self) -> List[ScheduledTask]:
|
||||
"""Get all active tasks."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_system_tasks(
|
||||
self,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get system tasks (tasks with no user association)."""
|
||||
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
|
||||
"""Get recurring tasks that need their next execution scheduled."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.recurrence_type != RecurrenceType.NONE,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
ScheduledTask.status == TaskStatus.COMPLETED,
|
||||
ScheduledTask.next_execution_at <= now,
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def get_expired_tasks(self) -> List[ScheduledTask]:
|
||||
"""Get expired tasks that should be cleaned up."""
|
||||
now = datetime.utcnow()
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.expires_at.is_not(None),
|
||||
ScheduledTask.expires_at <= now,
|
||||
ScheduledTask.is_active.is_(True),
|
||||
)
|
||||
result = await self.session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
async def cancel_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
task_type: Optional[TaskType] = None,
|
||||
) -> int:
|
||||
"""Cancel all pending/running tasks for a user."""
|
||||
statement = select(ScheduledTask).where(
|
||||
ScheduledTask.user_id == user_id,
|
||||
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||
)
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
result = await self.session.exec(statement)
|
||||
tasks = list(result.all())
|
||||
|
||||
count = 0
|
||||
for task in tasks:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
self.session.add(task)
|
||||
count += 1
|
||||
|
||||
await self.session.commit()
|
||||
return count
|
||||
|
||||
async def mark_as_running(self, task: ScheduledTask) -> None:
|
||||
"""Mark a task as running."""
|
||||
task.status = TaskStatus.RUNNING
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
|
||||
async def mark_as_completed(
|
||||
self,
|
||||
task: ScheduledTask,
|
||||
next_execution_at: Optional[datetime] = None,
|
||||
) -> None:
|
||||
"""Mark a task as completed and set next execution if recurring."""
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
task.executions_count += 1
|
||||
task.error_message = None
|
||||
|
||||
if next_execution_at and task.should_repeat():
|
||||
task.next_execution_at = next_execution_at
|
||||
task.status = TaskStatus.PENDING
|
||||
elif not task.should_repeat():
|
||||
task.is_active = False
|
||||
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
|
||||
async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
|
||||
"""Mark a task as failed with error message."""
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error_message = error_message
|
||||
task.last_executed_at = datetime.utcnow()
|
||||
|
||||
# For non-recurring tasks, deactivate on failure
|
||||
if not task.is_recurring():
|
||||
task.is_active = False
|
||||
|
||||
self.session.add(task)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(task)
|
||||
189
app/schemas/scheduler.py
Normal file
189
app/schemas/scheduler.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Schemas for scheduled task API."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.models.scheduled_task import RecurrenceType, TaskStatus, TaskType
|
||||
|
||||
|
||||
class ScheduledTaskBase(BaseModel):
|
||||
"""Base schema for scheduled tasks."""
|
||||
|
||||
name: str = Field(description="Human-readable task name")
|
||||
task_type: TaskType = Field(description="Type of task to execute")
|
||||
scheduled_at: datetime = Field(description="When the task should be executed")
|
||||
timezone: str = Field(default="UTC", description="Timezone for scheduling")
|
||||
parameters: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Task-specific parameters",
|
||||
)
|
||||
recurrence_type: RecurrenceType = Field(
|
||||
default=RecurrenceType.NONE,
|
||||
description="Recurrence pattern",
|
||||
)
|
||||
cron_expression: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Cron expression for custom recurrence",
|
||||
)
|
||||
recurrence_count: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Number of times to repeat (None for infinite)",
|
||||
)
|
||||
expires_at: Optional[datetime] = Field(
|
||||
default=None,
|
||||
description="When the task expires (optional)",
|
||||
)
|
||||
|
||||
|
||||
class ScheduledTaskCreate(ScheduledTaskBase):
|
||||
"""Schema for creating a scheduled task."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ScheduledTaskUpdate(BaseModel):
|
||||
"""Schema for updating a scheduled task."""
|
||||
|
||||
name: Optional[str] = None
|
||||
scheduled_at: Optional[datetime] = None
|
||||
timezone: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
is_active: Optional[bool] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class ScheduledTaskResponse(ScheduledTaskBase):
|
||||
"""Schema for scheduled task responses."""
|
||||
|
||||
id: int
|
||||
status: TaskStatus
|
||||
user_id: Optional[int] = None
|
||||
executions_count: int
|
||||
last_executed_at: Optional[datetime] = None
|
||||
next_execution_at: Optional[datetime] = None
|
||||
error_message: Optional[str] = None
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
"""Pydantic configuration."""
|
||||
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# Task-specific parameter schemas
|
||||
class CreditRechargeParameters(BaseModel):
|
||||
"""Parameters for credit recharge tasks."""
|
||||
|
||||
user_id: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Specific user ID to recharge (None for all users)",
|
||||
)
|
||||
|
||||
|
||||
class PlaySoundParameters(BaseModel):
|
||||
"""Parameters for play sound tasks."""
|
||||
|
||||
sound_id: int = Field(description="ID of the sound to play")
|
||||
|
||||
|
||||
class PlayPlaylistParameters(BaseModel):
|
||||
"""Parameters for play playlist tasks."""
|
||||
|
||||
playlist_id: int = Field(description="ID of the playlist to play")
|
||||
play_mode: str = Field(
|
||||
default="continuous",
|
||||
description="Play mode (continuous, loop, loop_one, random, single)",
|
||||
)
|
||||
shuffle: bool = Field(default=False, description="Whether to shuffle the playlist")
|
||||
|
||||
|
||||
# Convenience schemas for creating specific task types
|
||||
class CreateCreditRechargeTask(BaseModel):
|
||||
"""Schema for creating credit recharge tasks."""
|
||||
|
||||
name: str = "Credit Recharge"
|
||||
scheduled_at: datetime
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: Optional[str] = None
|
||||
recurrence_count: Optional[int] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
user_id: Optional[int] = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
return ScheduledTaskCreate(
|
||||
name=self.name,
|
||||
task_type=TaskType.CREDIT_RECHARGE,
|
||||
scheduled_at=self.scheduled_at,
|
||||
timezone=self.timezone,
|
||||
parameters={"user_id": self.user_id},
|
||||
recurrence_type=self.recurrence_type,
|
||||
cron_expression=self.cron_expression,
|
||||
recurrence_count=self.recurrence_count,
|
||||
expires_at=self.expires_at,
|
||||
)
|
||||
|
||||
|
||||
class CreatePlaySoundTask(BaseModel):
|
||||
"""Schema for creating play sound tasks."""
|
||||
|
||||
name: str
|
||||
scheduled_at: datetime
|
||||
sound_id: int
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: Optional[str] = None
|
||||
recurrence_count: Optional[int] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
return ScheduledTaskCreate(
|
||||
name=self.name,
|
||||
task_type=TaskType.PLAY_SOUND,
|
||||
scheduled_at=self.scheduled_at,
|
||||
timezone=self.timezone,
|
||||
parameters={"sound_id": self.sound_id},
|
||||
recurrence_type=self.recurrence_type,
|
||||
cron_expression=self.cron_expression,
|
||||
recurrence_count=self.recurrence_count,
|
||||
expires_at=self.expires_at,
|
||||
)
|
||||
|
||||
|
||||
class CreatePlayPlaylistTask(BaseModel):
|
||||
"""Schema for creating play playlist tasks."""
|
||||
|
||||
name: str
|
||||
scheduled_at: datetime
|
||||
playlist_id: int
|
||||
play_mode: str = "continuous"
|
||||
shuffle: bool = False
|
||||
timezone: str = "UTC"
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||
cron_expression: Optional[str] = None
|
||||
recurrence_count: Optional[int] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
|
||||
def to_task_create(self) -> ScheduledTaskCreate:
|
||||
"""Convert to generic task creation schema."""
|
||||
return ScheduledTaskCreate(
|
||||
name=self.name,
|
||||
task_type=TaskType.PLAY_PLAYLIST,
|
||||
scheduled_at=self.scheduled_at,
|
||||
timezone=self.timezone,
|
||||
parameters={
|
||||
"playlist_id": self.playlist_id,
|
||||
"play_mode": self.play_mode,
|
||||
"shuffle": self.shuffle,
|
||||
},
|
||||
recurrence_type=self.recurrence_type,
|
||||
cron_expression=self.cron_expression,
|
||||
recurrence_count=self.recurrence_count,
|
||||
expires_at=self.expires_at,
|
||||
)
|
||||
@@ -1,63 +1,413 @@
|
||||
"""Scheduler service for periodic tasks."""
|
||||
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytz
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.scheduled_task import (
|
||||
RecurrenceType,
|
||||
ScheduledTask,
|
||||
TaskStatus,
|
||||
TaskType,
|
||||
)
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
from app.services.credit import CreditService
|
||||
from app.services.player import PlayerService
|
||||
from app.services.task_handlers import TaskHandlerRegistry
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SchedulerService:
|
||||
"""Service for managing scheduled tasks."""
|
||||
"""Enhanced service for managing scheduled tasks with timezone support."""
|
||||
|
||||
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
db_session_factory: Callable[[], AsyncSession],
|
||||
player_service: PlayerService,
|
||||
) -> None:
|
||||
"""Initialize the scheduler service.
|
||||
|
||||
Args:
|
||||
db_session_factory: Factory function to create database sessions
|
||||
player_service: Player service for audio playback tasks
|
||||
|
||||
"""
|
||||
self.db_session_factory = db_session_factory
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
|
||||
self.credit_service = CreditService(db_session_factory)
|
||||
self.player_service = player_service
|
||||
self._running_tasks: set[str] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the scheduler and register all tasks."""
|
||||
logger.info("Starting scheduler service...")
|
||||
"""Start the scheduler and load all active tasks."""
|
||||
logger.info("Starting enhanced scheduler service...")
|
||||
|
||||
# Add daily credit recharge job (runs at midnight UTC)
|
||||
self.scheduler.start()
|
||||
|
||||
# Schedule system tasks initialization for after startup
|
||||
self.scheduler.add_job(
|
||||
self._daily_credit_recharge,
|
||||
"cron",
|
||||
hour=0,
|
||||
minute=0,
|
||||
id="daily_credit_recharge",
|
||||
name="Daily Credit Recharge",
|
||||
self._initialize_system_tasks,
|
||||
"date",
|
||||
run_date=datetime.utcnow() + timedelta(seconds=2),
|
||||
id="initialize_system_tasks",
|
||||
name="Initialize System Tasks",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
# Schedule periodic cleanup and maintenance
|
||||
self.scheduler.add_job(
|
||||
self._maintenance_job,
|
||||
"interval",
|
||||
minutes=5,
|
||||
id="scheduler_maintenance",
|
||||
name="Scheduler Maintenance",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
self.scheduler.start()
|
||||
logger.info("Scheduler service started successfully")
|
||||
logger.info("Enhanced scheduler service started successfully")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the scheduler."""
|
||||
logger.info("Stopping scheduler service...")
|
||||
self.scheduler.shutdown()
|
||||
self.scheduler.shutdown(wait=True)
|
||||
logger.info("Scheduler service stopped")
|
||||
|
||||
async def _daily_credit_recharge(self) -> None:
|
||||
"""Execute daily credit recharge for all users."""
|
||||
logger.info("Starting daily credit recharge task...")
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
name: str,
|
||||
task_type: TaskType,
|
||||
scheduled_at: datetime,
|
||||
parameters: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[int] = None,
|
||||
timezone: str = "UTC",
|
||||
recurrence_type: RecurrenceType = RecurrenceType.NONE,
|
||||
cron_expression: Optional[str] = None,
|
||||
recurrence_count: Optional[int] = None,
|
||||
expires_at: Optional[datetime] = None,
|
||||
) -> ScheduledTask:
|
||||
"""Create a new scheduled task."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Convert scheduled_at to UTC if it's in a different timezone
|
||||
if timezone != "UTC":
|
||||
tz = pytz.timezone(timezone)
|
||||
if scheduled_at.tzinfo is None:
|
||||
# Assume the datetime is in the specified timezone
|
||||
scheduled_at = tz.localize(scheduled_at)
|
||||
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
|
||||
|
||||
task_data = {
|
||||
"name": name,
|
||||
"task_type": task_type,
|
||||
"scheduled_at": scheduled_at,
|
||||
"timezone": timezone,
|
||||
"parameters": parameters or {},
|
||||
"user_id": user_id,
|
||||
"recurrence_type": recurrence_type,
|
||||
"cron_expression": cron_expression,
|
||||
"recurrence_count": recurrence_count,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
created_task = await repo.create(task_data)
|
||||
await self._schedule_apscheduler_job(created_task)
|
||||
|
||||
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
|
||||
return created_task
|
||||
|
||||
async def cancel_task(self, task_id: int) -> bool:
|
||||
"""Cancel a scheduled task."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
return False
|
||||
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
|
||||
# Remove from APScheduler
|
||||
try:
|
||||
self.scheduler.remove_job(str(task_id))
|
||||
except Exception:
|
||||
pass # Job might not exist in scheduler
|
||||
|
||||
logger.info(f"Cancelled task: {task.name} ({task_id})")
|
||||
return True
|
||||
|
||||
async def get_user_tasks(
|
||||
self,
|
||||
user_id: int,
|
||||
status: Optional[TaskStatus] = None,
|
||||
task_type: Optional[TaskType] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
) -> List[ScheduledTask]:
|
||||
"""Get tasks for a specific user."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
|
||||
|
||||
async def _initialize_system_tasks(self) -> None:
|
||||
"""Initialize system tasks and load active tasks from database."""
|
||||
logger.info("Initializing system tasks...")
|
||||
|
||||
try:
|
||||
stats = await self.credit_service.recharge_all_users_credits()
|
||||
logger.info(
|
||||
"Daily credit recharge completed successfully: %s",
|
||||
stats,
|
||||
)
|
||||
# Create system tasks if they don't exist
|
||||
await self._ensure_system_tasks()
|
||||
|
||||
# Load all active tasks from database
|
||||
await self._load_active_tasks()
|
||||
|
||||
logger.info("System tasks initialized successfully")
|
||||
except Exception:
|
||||
logger.exception("Daily credit recharge task failed")
|
||||
logger.exception("Failed to initialize system tasks")
|
||||
|
||||
async def _ensure_system_tasks(self) -> None:
|
||||
"""Ensure required system tasks exist."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Check if daily credit recharge task exists
|
||||
system_tasks = await repo.get_system_tasks(
|
||||
task_type=TaskType.CREDIT_RECHARGE
|
||||
)
|
||||
|
||||
daily_recharge_exists = any(
|
||||
task.recurrence_type == RecurrenceType.DAILY
|
||||
and task.is_active
|
||||
for task in system_tasks
|
||||
)
|
||||
|
||||
if not daily_recharge_exists:
|
||||
# Create daily credit recharge task
|
||||
tomorrow_midnight = datetime.utcnow().replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
) + timedelta(days=1)
|
||||
|
||||
task_data = {
|
||||
"name": "Daily Credit Recharge",
|
||||
"task_type": TaskType.CREDIT_RECHARGE,
|
||||
"scheduled_at": tomorrow_midnight,
|
||||
"recurrence_type": RecurrenceType.DAILY,
|
||||
"parameters": {},
|
||||
}
|
||||
|
||||
await repo.create(task_data)
|
||||
logger.info("Created system daily credit recharge task")
|
||||
|
||||
async def _load_active_tasks(self) -> None:
|
||||
"""Load all active tasks from database into scheduler."""
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
active_tasks = await repo.get_active_tasks()
|
||||
|
||||
for task in active_tasks:
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
|
||||
|
||||
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
|
||||
"""Schedule a task in APScheduler."""
|
||||
job_id = str(task.id)
|
||||
|
||||
# Remove existing job if it exists
|
||||
try:
|
||||
self.scheduler.remove_job(job_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Don't schedule if task is not active or already completed/failed
|
||||
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
||||
return
|
||||
|
||||
# Create trigger based on recurrence type
|
||||
trigger = self._create_trigger(task)
|
||||
if not trigger:
|
||||
logger.warning(f"Could not create trigger for task {task.id}")
|
||||
return
|
||||
|
||||
# Schedule the job
|
||||
self.scheduler.add_job(
|
||||
self._execute_task,
|
||||
trigger=trigger,
|
||||
args=[task.id],
|
||||
id=job_id,
|
||||
name=task.name,
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
logger.debug(f"Scheduled APScheduler job for task {task.id}")
|
||||
|
||||
def _create_trigger(self, task: ScheduledTask):
|
||||
"""Create APScheduler trigger based on task configuration."""
|
||||
tz = pytz.timezone(task.timezone)
|
||||
|
||||
if task.recurrence_type == RecurrenceType.NONE:
|
||||
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
# Use cron trigger for monthly (more reliable than interval)
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz
|
||||
)
|
||||
|
||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||
scheduled_time = task.scheduled_at
|
||||
return CronTrigger(
|
||||
month=scheduled_time.month,
|
||||
day=scheduled_time.day,
|
||||
hour=scheduled_time.hour,
|
||||
minute=scheduled_time.minute,
|
||||
timezone=tz
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def _execute_task(self, task_id: int) -> None:
|
||||
"""Execute a scheduled task."""
|
||||
task_id_str = str(task_id)
|
||||
|
||||
# Prevent concurrent execution of the same task
|
||||
if task_id_str in self._running_tasks:
|
||||
logger.warning(f"Task {task_id} is already running, skipping execution")
|
||||
return
|
||||
|
||||
self._running_tasks.add(task_id_str)
|
||||
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Get fresh task data
|
||||
task = await repo.get_by_id(task_id)
|
||||
if not task:
|
||||
logger.warning(f"Task {task_id} not found")
|
||||
return
|
||||
|
||||
# Check if task is still active and pending
|
||||
if not task.is_active or task.status != TaskStatus.PENDING:
|
||||
logger.info(f"Task {task_id} is not active or not pending, skipping")
|
||||
return
|
||||
|
||||
# Check if task has expired
|
||||
if task.is_expired():
|
||||
logger.info(f"Task {task_id} has expired, marking as cancelled")
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
return
|
||||
|
||||
# Mark task as running
|
||||
await repo.mark_as_running(task)
|
||||
|
||||
# Execute the task
|
||||
try:
|
||||
handler_registry = TaskHandlerRegistry(
|
||||
session, self.credit_service, self.player_service
|
||||
)
|
||||
await handler_registry.execute_task(task)
|
||||
|
||||
# Calculate next execution time for recurring tasks
|
||||
next_execution_at = None
|
||||
if task.should_repeat():
|
||||
next_execution_at = self._calculate_next_execution(task)
|
||||
|
||||
# Mark as completed
|
||||
await repo.mark_as_completed(task, next_execution_at)
|
||||
|
||||
# Reschedule if recurring
|
||||
if next_execution_at and task.should_repeat():
|
||||
# Refresh task to get updated data
|
||||
await session.refresh(task)
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
except Exception as e:
|
||||
await repo.mark_as_failed(task, str(e))
|
||||
logger.exception(f"Task {task_id} execution failed: {str(e)}")
|
||||
|
||||
finally:
|
||||
self._running_tasks.discard(task_id_str)
|
||||
|
||||
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
|
||||
"""Calculate the next execution time for a recurring task."""
|
||||
now = datetime.utcnow()
|
||||
|
||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||
return now + timedelta(hours=1)
|
||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||
return now + timedelta(days=1)
|
||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||
return now + timedelta(weeks=1)
|
||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||
# Add approximately one month
|
||||
return now + timedelta(days=30)
|
||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||
return now + timedelta(days=365)
|
||||
|
||||
return None
|
||||
|
||||
async def _maintenance_job(self) -> None:
|
||||
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
|
||||
try:
|
||||
async with self.db_session_factory() as session:
|
||||
repo = ScheduledTaskRepository(session)
|
||||
|
||||
# Handle expired tasks
|
||||
expired_tasks = await repo.get_expired_tasks()
|
||||
for task in expired_tasks:
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.is_active = False
|
||||
await repo.update(task)
|
||||
|
||||
# Remove from scheduler
|
||||
try:
|
||||
self.scheduler.remove_job(str(task.id))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if expired_tasks:
|
||||
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
|
||||
|
||||
# Handle any missed recurring tasks
|
||||
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
|
||||
for task in due_recurring:
|
||||
if task.should_repeat():
|
||||
task.status = TaskStatus.PENDING
|
||||
task.scheduled_at = task.next_execution_at or datetime.utcnow()
|
||||
await repo.update(task)
|
||||
await self._schedule_apscheduler_job(task)
|
||||
|
||||
if due_recurring:
|
||||
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Maintenance job failed")
|
||||
|
||||
137
app/services/task_handlers.py
Normal file
137
app/services/task_handlers.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Task execution handlers for different task types."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from app.core.logging import get_logger
|
||||
from app.models.scheduled_task import ScheduledTask, TaskType
|
||||
from app.repositories.playlist import PlaylistRepository
|
||||
from app.repositories.sound import SoundRepository
|
||||
from app.services.credit import CreditService
|
||||
from app.services.player import PlayerService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TaskExecutionError(Exception):
|
||||
"""Exception raised when task execution fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TaskHandlerRegistry:
|
||||
"""Registry for task execution handlers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_session: AsyncSession,
|
||||
credit_service: CreditService,
|
||||
player_service: PlayerService,
|
||||
) -> None:
|
||||
"""Initialize the task handler registry."""
|
||||
self.db_session = db_session
|
||||
self.credit_service = credit_service
|
||||
self.player_service = player_service
|
||||
self.sound_repository = SoundRepository(db_session)
|
||||
self.playlist_repository = PlaylistRepository(db_session)
|
||||
|
||||
# Register handlers
|
||||
self._handlers = {
|
||||
TaskType.CREDIT_RECHARGE: self._handle_credit_recharge,
|
||||
TaskType.PLAY_SOUND: self._handle_play_sound,
|
||||
TaskType.PLAY_PLAYLIST: self._handle_play_playlist,
|
||||
}
|
||||
|
||||
async def execute_task(self, task: ScheduledTask) -> None:
|
||||
"""Execute a task based on its type."""
|
||||
handler = self._handlers.get(task.task_type)
|
||||
if not handler:
|
||||
raise TaskExecutionError(f"No handler registered for task type: {task.task_type}")
|
||||
|
||||
logger.info(f"Executing task {task.id} ({task.task_type.value}): {task.name}")
|
||||
|
||||
try:
|
||||
await handler(task)
|
||||
logger.info(f"Task {task.id} executed successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"Task {task.id} execution failed: {str(e)}")
|
||||
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
|
||||
|
||||
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
||||
"""Handle credit recharge task."""
|
||||
parameters = task.parameters
|
||||
user_id = parameters.get("user_id")
|
||||
|
||||
if user_id:
|
||||
# Recharge specific user
|
||||
user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
stats = await self.credit_service.recharge_user_credits(user_uuid)
|
||||
logger.info(f"Recharged credits for user {user_id}: {stats}")
|
||||
else:
|
||||
# Recharge all users (system task)
|
||||
stats = await self.credit_service.recharge_all_users_credits()
|
||||
logger.info(f"Recharged credits for all users: {stats}")
|
||||
|
||||
async def _handle_play_sound(self, task: ScheduledTask) -> None:
|
||||
"""Handle play sound task."""
|
||||
parameters = task.parameters
|
||||
sound_id = parameters.get("sound_id")
|
||||
|
||||
if not sound_id:
|
||||
raise TaskExecutionError("sound_id parameter is required for PLAY_SOUND tasks")
|
||||
|
||||
try:
|
||||
sound_uuid = uuid.UUID(sound_id) if isinstance(sound_id, str) else sound_id
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TaskExecutionError(f"Invalid sound_id format: {sound_id}") from e
|
||||
|
||||
# Get the sound from database
|
||||
sound = await self.sound_repository.get_by_id(sound_uuid)
|
||||
if not sound:
|
||||
raise TaskExecutionError(f"Sound not found: {sound_id}")
|
||||
|
||||
# Play the sound through VLC
|
||||
from app.services.vlc_player import VLCPlayerService
|
||||
|
||||
vlc_service = VLCPlayerService(lambda: self.db_session)
|
||||
await vlc_service.play_sound(sound)
|
||||
|
||||
logger.info(f"Played sound {sound.filename} via scheduled task")
|
||||
|
||||
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
|
||||
"""Handle play playlist task."""
|
||||
parameters = task.parameters
|
||||
playlist_id = parameters.get("playlist_id")
|
||||
play_mode = parameters.get("play_mode", "continuous")
|
||||
shuffle = parameters.get("shuffle", False)
|
||||
|
||||
if not playlist_id:
|
||||
raise TaskExecutionError("playlist_id parameter is required for PLAY_PLAYLIST tasks")
|
||||
|
||||
try:
|
||||
playlist_uuid = uuid.UUID(playlist_id) if isinstance(playlist_id, str) else playlist_id
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TaskExecutionError(f"Invalid playlist_id format: {playlist_id}") from e
|
||||
|
||||
# Get the playlist from database
|
||||
playlist = await self.playlist_repository.get_by_id(playlist_uuid)
|
||||
if not playlist:
|
||||
raise TaskExecutionError(f"Playlist not found: {playlist_id}")
|
||||
|
||||
# Load playlist in player
|
||||
await self.player_service.load_playlist(playlist_uuid)
|
||||
|
||||
# Set play mode if specified
|
||||
if play_mode in ["continuous", "loop", "loop_one", "random", "single"]:
|
||||
self.player_service.set_mode(play_mode)
|
||||
|
||||
# Enable shuffle if requested
|
||||
if shuffle:
|
||||
self.player_service.set_shuffle(True)
|
||||
|
||||
# Start playing
|
||||
self.player_service.play()
|
||||
|
||||
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
|
||||
@@ -238,75 +238,76 @@ class VLCPlayerService:
|
||||
return
|
||||
|
||||
logger.info("Recording play count for sound %s", sound_id)
|
||||
session = self.db_session_factory()
|
||||
|
||||
# Initialize variables for WebSocket event
|
||||
old_count = 0
|
||||
sound = None
|
||||
admin_user_id = None
|
||||
admin_user_name = None
|
||||
|
||||
try:
|
||||
sound_repo = SoundRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
async with self.db_session_factory() as session:
|
||||
sound_repo = SoundRepository(session)
|
||||
user_repo = UserRepository(session)
|
||||
|
||||
# Update sound play count
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
old_count = 0
|
||||
if sound:
|
||||
old_count = sound.play_count
|
||||
await sound_repo.update(
|
||||
sound,
|
||||
{"play_count": sound.play_count + 1},
|
||||
# Update sound play count
|
||||
sound = await sound_repo.get_by_id(sound_id)
|
||||
if sound:
|
||||
old_count = sound.play_count
|
||||
# Update the sound's play count using direct attribute modification
|
||||
sound.play_count = sound.play_count + 1
|
||||
session.add(sound)
|
||||
await session.commit()
|
||||
await session.refresh(sound)
|
||||
logger.info(
|
||||
"Updated sound %s play_count: %s -> %s",
|
||||
sound_id,
|
||||
old_count,
|
||||
old_count + 1,
|
||||
)
|
||||
else:
|
||||
logger.warning("Sound %s not found for play count update", sound_id)
|
||||
|
||||
# Record play history for admin user (ID 1) as placeholder
|
||||
# This could be refined to track per-user play history
|
||||
admin_user = await user_repo.get_by_id(1)
|
||||
if admin_user:
|
||||
admin_user_id = admin_user.id
|
||||
admin_user_name = admin_user.name
|
||||
|
||||
# Always create a new SoundPlayed record for each play event
|
||||
sound_played = SoundPlayed(
|
||||
user_id=admin_user_id, # Can be None for player-based plays
|
||||
sound_id=sound_id,
|
||||
)
|
||||
session.add(sound_played)
|
||||
logger.info(
|
||||
"Updated sound %s play_count: %s -> %s",
|
||||
sound_id,
|
||||
old_count,
|
||||
old_count + 1,
|
||||
)
|
||||
else:
|
||||
logger.warning("Sound %s not found for play count update", sound_id)
|
||||
|
||||
# Record play history for admin user (ID 1) as placeholder
|
||||
# This could be refined to track per-user play history
|
||||
admin_user = await user_repo.get_by_id(1)
|
||||
admin_user_id = None
|
||||
admin_user_name = None
|
||||
if admin_user:
|
||||
admin_user_id = admin_user.id
|
||||
admin_user_name = admin_user.name
|
||||
|
||||
# Always create a new SoundPlayed record for each play event
|
||||
sound_played = SoundPlayed(
|
||||
user_id=admin_user_id, # Can be None for player-based plays
|
||||
sound_id=sound_id,
|
||||
)
|
||||
session.add(sound_played)
|
||||
logger.info(
|
||||
"Created SoundPlayed record for user %s, sound %s",
|
||||
admin_user_id,
|
||||
sound_id,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Successfully recorded play count for sound %s", sound_id)
|
||||
|
||||
# Emit sound_played event via WebSocket
|
||||
try:
|
||||
event_data = {
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound_name,
|
||||
"user_id": admin_user_id,
|
||||
"user_name": admin_user_name,
|
||||
"play_count": (old_count + 1) if sound else None,
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_played", event_data)
|
||||
logger.info("Broadcasted sound_played event for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to broadcast sound_played event for sound %s",
|
||||
"Created SoundPlayed record for user %s, sound %s",
|
||||
admin_user_id,
|
||||
sound_id,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
logger.info("Successfully recorded play count for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception("Error recording play count for sound %s", sound_id)
|
||||
await session.rollback()
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
# Emit sound_played event via WebSocket (outside session context)
|
||||
try:
|
||||
event_data = {
|
||||
"sound_id": sound_id,
|
||||
"sound_name": sound_name,
|
||||
"user_id": admin_user_id,
|
||||
"user_name": admin_user_name,
|
||||
"play_count": (old_count + 1) if sound else None,
|
||||
}
|
||||
await socket_manager.broadcast_to_all("sound_played", event_data)
|
||||
logger.info("Broadcasted sound_played event for sound %s", sound_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to broadcast sound_played event for sound %s",
|
||||
sound_id,
|
||||
)
|
||||
|
||||
async def play_sound_with_credits(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user