Add comprehensive tests for scheduled task repository, scheduler service, and task handlers

- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates.
- Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs.
- Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback.
- Ensured proper error handling and edge cases in task execution scenarios.
- Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
JSC
2025-08-28 22:37:43 +02:00
parent 7dee6e320e
commit 03abed6d39
23 changed files with 3415 additions and 103 deletions

View File

@@ -12,6 +12,7 @@ from app.api.v1 import (
main,
player,
playlists,
scheduler,
socket,
sounds,
)
@@ -28,6 +29,7 @@ api_router.include_router(files.router, tags=["files"])
api_router.include_router(main.router, tags=["main"])
api_router.include_router(player.router, tags=["player"])
api_router.include_router(playlists.router, tags=["playlists"])
api_router.include_router(scheduler.router, tags=["scheduler"])
api_router.include_router(socket.router, tags=["socket"])
api_router.include_router(sounds.router, tags=["sounds"])
api_router.include_router(admin.router)

228
app/api/v1/scheduler.py Normal file
View File

@@ -0,0 +1,228 @@
"""API endpoints for scheduled task management."""
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.database import get_db
from app.core.dependencies import (
get_admin_user,
get_current_active_user,
)
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
from app.models.user import User
from app.schemas.scheduler import (
ScheduledTaskCreate,
ScheduledTaskResponse,
ScheduledTaskUpdate,
)
from app.services.scheduler import SchedulerService
router = APIRouter(prefix="/scheduler")
def get_scheduler_service() -> SchedulerService:
"""Get the global scheduler service instance."""
from app.main import get_global_scheduler_service
return get_global_scheduler_service()
@router.post("/tasks", response_model=ScheduledTaskResponse)
async def create_task(
task_data: ScheduledTaskCreate,
current_user: User = Depends(get_current_active_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
) -> ScheduledTask:
"""Create a new scheduled task."""
try:
task = await scheduler_service.create_task(
name=task_data.name,
task_type=task_data.task_type,
scheduled_at=task_data.scheduled_at,
parameters=task_data.parameters,
user_id=current_user.id,
timezone=task_data.timezone,
recurrence_type=task_data.recurrence_type,
cron_expression=task_data.cron_expression,
recurrence_count=task_data.recurrence_count,
expires_at=task_data.expires_at,
)
return task
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
async def get_user_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_current_active_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
) -> List[ScheduledTask]:
"""Get user's scheduled tasks."""
return await scheduler_service.get_user_tasks(
user_id=current_user.id,
status=status,
task_type=task_type,
limit=limit,
offset=offset,
)
@router.get("/tasks/{task_id}", response_model=ScheduledTaskResponse)
async def get_task(
task_id: int,
current_user: User = Depends(get_current_active_user),
db_session: AsyncSession = Depends(get_db),
) -> ScheduledTask:
"""Get a specific scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
return task
@router.patch("/tasks/{task_id}", response_model=ScheduledTaskResponse)
async def update_task(
task_id: int,
task_update: ScheduledTaskUpdate,
current_user: User = Depends(get_current_active_user),
db_session: AsyncSession = Depends(get_db),
) -> ScheduledTask:
"""Update a scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
# Update task fields
update_data = task_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(task, field, value)
updated_task = await repo.update(task)
return updated_task
@router.delete("/tasks/{task_id}")
async def cancel_task(
task_id: int,
current_user: User = Depends(get_current_active_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
db_session: AsyncSession = Depends(get_db),
) -> dict:
"""Cancel a scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
success = await scheduler_service.cancel_task(task_id)
if not success:
raise HTTPException(status_code=400, detail="Failed to cancel task")
return {"message": "Task cancelled successfully"}
# Admin-only endpoints
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
async def get_all_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]:
"""Get all scheduled tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
# Get all tasks with pagination and filtering
from sqlmodel import select
statement = select(ScheduledTask)
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await db_session.exec(statement)
return list(result.all())
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
async def get_system_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]:
"""Get system tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
return await repo.get_system_tasks(status=status, task_type=task_type)
@router.post("/admin/system-tasks", response_model=ScheduledTaskResponse)
async def create_system_task(
task_data: ScheduledTaskCreate,
current_user: User = Depends(get_admin_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
) -> ScheduledTask:
"""Create a system task (admin only)."""
try:
task = await scheduler_service.create_task(
name=task_data.name,
task_type=task_data.task_type,
scheduled_at=task_data.scheduled_at,
parameters=task_data.parameters,
user_id=None, # System task
timezone=task_data.timezone,
recurrence_type=task_data.recurrence_type,
cron_expression=task_data.cron_expression,
recurrence_count=task_data.recurrence_count,
expires_at=task_data.expires_at,
)
return task
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -7,17 +7,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.config import settings
from app.core.logging import get_logger
from app.core.seeds import seed_all_data
from app.models import ( # noqa: F401
extraction,
favorite,
plan,
playlist,
playlist_sound,
sound,
sound_played,
user,
user_oauth,
)
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
engine: AsyncEngine = create_async_engine(
settings.DATABASE_URL,

View File

@@ -11,14 +11,27 @@ from app.core.database import get_session_factory, init_db
from app.core.logging import get_logger, setup_logging
from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor
from app.services.player import initialize_player_service, shutdown_player_service
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service
from app.services.scheduler import SchedulerService
from app.services.socket import socket_manager
scheduler_service = None
def get_global_scheduler_service() -> SchedulerService:
"""Get the global scheduler service instance."""
global scheduler_service
if scheduler_service is None:
raise RuntimeError("Scheduler service not initialized")
return scheduler_service
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
"""Application lifespan context manager for setup and teardown."""
global scheduler_service
setup_logging()
logger = get_logger(__name__)
logger.info("Starting application")
@@ -35,17 +48,23 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
logger.info("Player service started")
# Start the scheduler service
scheduler_service = SchedulerService(get_session_factory())
await scheduler_service.start()
logger.info("Scheduler service started")
try:
player_service = get_player_service() # Get the initialized player service
scheduler_service = SchedulerService(get_session_factory(), player_service)
await scheduler_service.start()
logger.info("Enhanced scheduler service started")
except Exception:
logger.exception("Failed to start scheduler service - continuing without it")
scheduler_service = None
yield
logger.info("Shutting down application")
# Stop the scheduler service
await scheduler_service.stop()
logger.info("Scheduler service stopped")
if scheduler_service:
await scheduler_service.stop()
logger.info("Scheduler service stopped")
# Stop the player service
await shutdown_player_service()

View File

@@ -1 +1,32 @@
"""Models package."""
# Import all models for SQLAlchemy metadata discovery
from .base import BaseModel
from .credit_action import CreditAction
from .credit_transaction import CreditTransaction
from .extraction import Extraction
from .favorite import Favorite
from .plan import Plan
from .playlist import Playlist
from .playlist_sound import PlaylistSound
from .scheduled_task import ScheduledTask
from .sound import Sound
from .sound_played import SoundPlayed
from .user import User
from .user_oauth import UserOauth
__all__ = [
"BaseModel",
"CreditAction",
"CreditTransaction",
"Extraction",
"Favorite",
"Plan",
"Playlist",
"PlaylistSound",
"ScheduledTask",
"Sound",
"SoundPlayed",
"User",
"UserOauth",
]

View File

@@ -0,0 +1,125 @@
"""Scheduled task model for flexible task scheduling with timezone support."""
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Optional
from sqlmodel import JSON, Column, Field, SQLModel
from app.models.base import BaseModel
class TaskType(str, Enum):
"""Available task types."""
CREDIT_RECHARGE = "credit_recharge"
PLAY_SOUND = "play_sound"
PLAY_PLAYLIST = "play_playlist"
class TaskStatus(str, Enum):
"""Task execution status."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class RecurrenceType(str, Enum):
"""Recurrence patterns."""
NONE = "none" # One-shot task
HOURLY = "hourly"
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
YEARLY = "yearly"
CRON = "cron" # Custom cron expression
class ScheduledTask(BaseModel, table=True):
"""Model for scheduled tasks with timezone support."""
__tablename__ = "scheduled_tasks"
id: int | None = Field(primary_key=True, default=None)
name: str = Field(max_length=255, description="Human-readable task name")
task_type: TaskType = Field(description="Type of task to execute")
status: TaskStatus = Field(default=TaskStatus.PENDING)
# Scheduling fields with timezone support
scheduled_at: datetime = Field(description="When the task should be executed (UTC)")
timezone: str = Field(
default="UTC",
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
)
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
cron_expression: Optional[str] = Field(
default=None,
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
)
recurrence_count: Optional[int] = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
executions_count: int = Field(default=0, description="Number of times executed")
# Task parameters
parameters: dict[str, Any] = Field(
default_factory=dict,
sa_column=Column(JSON),
description="Task-specific parameters",
)
# User association (None for system tasks)
user_id: Optional[int] = Field(
default=None,
foreign_key="user.id",
description="User who created the task (None for system tasks)",
)
# Execution tracking
last_executed_at: Optional[datetime] = Field(
default=None,
description="When the task was last executed (UTC)",
)
next_execution_at: Optional[datetime] = Field(
default=None,
description="When the task should be executed next (UTC, for recurring tasks)",
)
error_message: Optional[str] = Field(
default=None,
description="Error message if execution failed",
)
# Task lifecycle
is_active: bool = Field(default=True, description="Whether the task is active")
expires_at: Optional[datetime] = Field(
default=None,
description="When the task expires (UTC, optional)",
)
def is_expired(self) -> bool:
"""Check if the task has expired."""
if self.expires_at is None:
return False
return datetime.utcnow() > self.expires_at
def is_recurring(self) -> bool:
"""Check if the task is recurring."""
return self.recurrence_type != RecurrenceType.NONE
def should_repeat(self) -> bool:
"""Check if the task should be repeated."""
if not self.is_recurring():
return False
if self.recurrence_count is None:
return True
return self.executions_count < self.recurrence_count
def is_system_task(self) -> bool:
"""Check if this is a system task (no user association)."""
return self.user_id is None

View File

@@ -0,0 +1,177 @@
"""Repository for scheduled task operations."""
from datetime import datetime
from typing import List, Optional
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
from app.repositories.base import BaseRepository
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
"""Repository for scheduled task database operations."""
def __init__(self, session: AsyncSession) -> None:
"""Initialize the repository."""
super().__init__(ScheduledTask, session)
async def get_pending_tasks(self) -> List[ScheduledTask]:
"""Get all pending tasks that are ready to be executed."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
ScheduledTask.status == TaskStatus.PENDING,
ScheduledTask.is_active.is_(True),
ScheduledTask.scheduled_at <= now,
)
result = await self.session.exec(statement)
return list(result.all())
async def get_active_tasks(self) -> List[ScheduledTask]:
"""Get all active tasks."""
statement = select(ScheduledTask).where(
ScheduledTask.is_active.is_(True),
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
)
result = await self.session.exec(statement)
return list(result.all())
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get tasks for a specific user."""
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await self.session.exec(statement)
return list(result.all())
async def get_system_tasks(
self,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
) -> List[ScheduledTask]:
"""Get system tasks (tasks with no user association)."""
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
result = await self.session.exec(statement)
return list(result.all())
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
"""Get recurring tasks that need their next execution scheduled."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
ScheduledTask.recurrence_type != RecurrenceType.NONE,
ScheduledTask.is_active.is_(True),
ScheduledTask.status == TaskStatus.COMPLETED,
ScheduledTask.next_execution_at <= now,
)
result = await self.session.exec(statement)
return list(result.all())
async def get_expired_tasks(self) -> List[ScheduledTask]:
"""Get expired tasks that should be cleaned up."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
ScheduledTask.expires_at.is_not(None),
ScheduledTask.expires_at <= now,
ScheduledTask.is_active.is_(True),
)
result = await self.session.exec(statement)
return list(result.all())
async def cancel_user_tasks(
self,
user_id: int,
task_type: Optional[TaskType] = None,
) -> int:
"""Cancel all pending/running tasks for a user."""
statement = select(ScheduledTask).where(
ScheduledTask.user_id == user_id,
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
result = await self.session.exec(statement)
tasks = list(result.all())
count = 0
for task in tasks:
task.status = TaskStatus.CANCELLED
task.is_active = False
self.session.add(task)
count += 1
await self.session.commit()
return count
async def mark_as_running(self, task: ScheduledTask) -> None:
"""Mark a task as running."""
task.status = TaskStatus.RUNNING
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)
async def mark_as_completed(
self,
task: ScheduledTask,
next_execution_at: Optional[datetime] = None,
) -> None:
"""Mark a task as completed and set next execution if recurring."""
task.status = TaskStatus.COMPLETED
task.last_executed_at = datetime.utcnow()
task.executions_count += 1
task.error_message = None
if next_execution_at and task.should_repeat():
task.next_execution_at = next_execution_at
task.status = TaskStatus.PENDING
elif not task.should_repeat():
task.is_active = False
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)
async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
"""Mark a task as failed with error message."""
task.status = TaskStatus.FAILED
task.error_message = error_message
task.last_executed_at = datetime.utcnow()
# For non-recurring tasks, deactivate on failure
if not task.is_recurring():
task.is_active = False
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)

189
app/schemas/scheduler.py Normal file
View File

@@ -0,0 +1,189 @@
"""Schemas for scheduled task API."""
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from app.models.scheduled_task import RecurrenceType, TaskStatus, TaskType
class ScheduledTaskBase(BaseModel):
"""Base schema for scheduled tasks."""
name: str = Field(description="Human-readable task name")
task_type: TaskType = Field(description="Type of task to execute")
scheduled_at: datetime = Field(description="When the task should be executed")
timezone: str = Field(default="UTC", description="Timezone for scheduling")
parameters: Dict[str, Any] = Field(
default_factory=dict,
description="Task-specific parameters",
)
recurrence_type: RecurrenceType = Field(
default=RecurrenceType.NONE,
description="Recurrence pattern",
)
cron_expression: Optional[str] = Field(
default=None,
description="Cron expression for custom recurrence",
)
recurrence_count: Optional[int] = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
expires_at: Optional[datetime] = Field(
default=None,
description="When the task expires (optional)",
)
class ScheduledTaskCreate(ScheduledTaskBase):
"""Schema for creating a scheduled task."""
pass
class ScheduledTaskUpdate(BaseModel):
"""Schema for updating a scheduled task."""
name: Optional[str] = None
scheduled_at: Optional[datetime] = None
timezone: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = None
expires_at: Optional[datetime] = None
class ScheduledTaskResponse(ScheduledTaskBase):
"""Schema for scheduled task responses."""
id: int
status: TaskStatus
user_id: Optional[int] = None
executions_count: int
last_executed_at: Optional[datetime] = None
next_execution_at: Optional[datetime] = None
error_message: Optional[str] = None
is_active: bool
created_at: datetime
updated_at: datetime
class Config:
"""Pydantic configuration."""
from_attributes = True
# Task-specific parameter schemas
class CreditRechargeParameters(BaseModel):
"""Parameters for credit recharge tasks."""
user_id: Optional[int] = Field(
default=None,
description="Specific user ID to recharge (None for all users)",
)
class PlaySoundParameters(BaseModel):
"""Parameters for play sound tasks."""
sound_id: int = Field(description="ID of the sound to play")
class PlayPlaylistParameters(BaseModel):
"""Parameters for play playlist tasks."""
playlist_id: int = Field(description="ID of the playlist to play")
play_mode: str = Field(
default="continuous",
description="Play mode (continuous, loop, loop_one, random, single)",
)
shuffle: bool = Field(default=False, description="Whether to shuffle the playlist")
# Convenience schemas for creating specific task types
class CreateCreditRechargeTask(BaseModel):
"""Schema for creating credit recharge tasks."""
name: str = "Credit Recharge"
scheduled_at: datetime
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
user_id: Optional[int] = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
return ScheduledTaskCreate(
name=self.name,
task_type=TaskType.CREDIT_RECHARGE,
scheduled_at=self.scheduled_at,
timezone=self.timezone,
parameters={"user_id": self.user_id},
recurrence_type=self.recurrence_type,
cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count,
expires_at=self.expires_at,
)
class CreatePlaySoundTask(BaseModel):
"""Schema for creating play sound tasks."""
name: str
scheduled_at: datetime
sound_id: int
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
return ScheduledTaskCreate(
name=self.name,
task_type=TaskType.PLAY_SOUND,
scheduled_at=self.scheduled_at,
timezone=self.timezone,
parameters={"sound_id": self.sound_id},
recurrence_type=self.recurrence_type,
cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count,
expires_at=self.expires_at,
)
class CreatePlayPlaylistTask(BaseModel):
"""Schema for creating play playlist tasks."""
name: str
scheduled_at: datetime
playlist_id: int
play_mode: str = "continuous"
shuffle: bool = False
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
return ScheduledTaskCreate(
name=self.name,
task_type=TaskType.PLAY_PLAYLIST,
scheduled_at=self.scheduled_at,
timezone=self.timezone,
parameters={
"playlist_id": self.playlist_id,
"play_mode": self.play_mode,
"shuffle": self.shuffle,
},
recurrence_type=self.recurrence_type,
cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count,
expires_at=self.expires_at,
)

View File

@@ -1,63 +1,413 @@
"""Scheduler service for periodic tasks."""
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.scheduled_task import ScheduledTaskRepository
from app.services.credit import CreditService
from app.services.player import PlayerService
from app.services.task_handlers import TaskHandlerRegistry
logger = get_logger(__name__)
class SchedulerService:
"""Service for managing scheduled tasks."""
"""Enhanced service for managing scheduled tasks with timezone support."""
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
def __init__(
self,
db_session_factory: Callable[[], AsyncSession],
player_service: PlayerService,
) -> None:
"""Initialize the scheduler service.
Args:
db_session_factory: Factory function to create database sessions
player_service: Player service for audio playback tasks
"""
self.db_session_factory = db_session_factory
self.scheduler = AsyncIOScheduler()
self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
self.credit_service = CreditService(db_session_factory)
self.player_service = player_service
self._running_tasks: set[str] = set()
async def start(self) -> None:
"""Start the scheduler and register all tasks."""
logger.info("Starting scheduler service...")
"""Start the scheduler and load all active tasks."""
logger.info("Starting enhanced scheduler service...")
# Add daily credit recharge job (runs at midnight UTC)
self.scheduler.start()
# Schedule system tasks initialization for after startup
self.scheduler.add_job(
self._daily_credit_recharge,
"cron",
hour=0,
minute=0,
id="daily_credit_recharge",
name="Daily Credit Recharge",
self._initialize_system_tasks,
"date",
run_date=datetime.utcnow() + timedelta(seconds=2),
id="initialize_system_tasks",
name="Initialize System Tasks",
replace_existing=True,
)
# Schedule periodic cleanup and maintenance
self.scheduler.add_job(
self._maintenance_job,
"interval",
minutes=5,
id="scheduler_maintenance",
name="Scheduler Maintenance",
replace_existing=True,
)
self.scheduler.start()
logger.info("Scheduler service started successfully")
logger.info("Enhanced scheduler service started successfully")
async def stop(self) -> None:
"""Stop the scheduler."""
logger.info("Stopping scheduler service...")
self.scheduler.shutdown()
self.scheduler.shutdown(wait=True)
logger.info("Scheduler service stopped")
async def _daily_credit_recharge(self) -> None:
"""Execute daily credit recharge for all users."""
logger.info("Starting daily credit recharge task...")
async def create_task(
self,
name: str,
task_type: TaskType,
scheduled_at: datetime,
parameters: Optional[Dict[str, Any]] = None,
user_id: Optional[int] = None,
timezone: str = "UTC",
recurrence_type: RecurrenceType = RecurrenceType.NONE,
cron_expression: Optional[str] = None,
recurrence_count: Optional[int] = None,
expires_at: Optional[datetime] = None,
) -> ScheduledTask:
"""Create a new scheduled task."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Convert scheduled_at to UTC if it's in a different timezone
if timezone != "UTC":
tz = pytz.timezone(timezone)
if scheduled_at.tzinfo is None:
# Assume the datetime is in the specified timezone
scheduled_at = tz.localize(scheduled_at)
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
task_data = {
"name": name,
"task_type": task_type,
"scheduled_at": scheduled_at,
"timezone": timezone,
"parameters": parameters or {},
"user_id": user_id,
"recurrence_type": recurrence_type,
"cron_expression": cron_expression,
"recurrence_count": recurrence_count,
"expires_at": expires_at,
}
created_task = await repo.create(task_data)
await self._schedule_apscheduler_job(created_task)
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
return created_task
async def cancel_task(self, task_id: int) -> bool:
"""Cancel a scheduled task."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
task = await repo.get_by_id(task_id)
if not task:
return False
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
# Remove from APScheduler
try:
self.scheduler.remove_job(str(task_id))
except Exception:
pass # Job might not exist in scheduler
logger.info(f"Cancelled task: {task.name} ({task_id})")
return True
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get tasks for a specific user."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
async def _initialize_system_tasks(self) -> None:
"""Initialize system tasks and load active tasks from database."""
logger.info("Initializing system tasks...")
try:
stats = await self.credit_service.recharge_all_users_credits()
logger.info(
"Daily credit recharge completed successfully: %s",
stats,
)
# Create system tasks if they don't exist
await self._ensure_system_tasks()
# Load all active tasks from database
await self._load_active_tasks()
logger.info("System tasks initialized successfully")
except Exception:
logger.exception("Daily credit recharge task failed")
logger.exception("Failed to initialize system tasks")
async def _ensure_system_tasks(self) -> None:
"""Ensure required system tasks exist."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Check if daily credit recharge task exists
system_tasks = await repo.get_system_tasks(
task_type=TaskType.CREDIT_RECHARGE
)
daily_recharge_exists = any(
task.recurrence_type == RecurrenceType.DAILY
and task.is_active
for task in system_tasks
)
if not daily_recharge_exists:
# Create daily credit recharge task
tomorrow_midnight = datetime.utcnow().replace(
hour=0, minute=0, second=0, microsecond=0
) + timedelta(days=1)
task_data = {
"name": "Daily Credit Recharge",
"task_type": TaskType.CREDIT_RECHARGE,
"scheduled_at": tomorrow_midnight,
"recurrence_type": RecurrenceType.DAILY,
"parameters": {},
}
await repo.create(task_data)
logger.info("Created system daily credit recharge task")
async def _load_active_tasks(self) -> None:
"""Load all active tasks from database into scheduler."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
active_tasks = await repo.get_active_tasks()
for task in active_tasks:
await self._schedule_apscheduler_job(task)
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
"""Schedule a task in APScheduler."""
job_id = str(task.id)
# Remove existing job if it exists
try:
self.scheduler.remove_job(job_id)
except Exception:
pass
# Don't schedule if task is not active or already completed/failed
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
return
# Create trigger based on recurrence type
trigger = self._create_trigger(task)
if not trigger:
logger.warning(f"Could not create trigger for task {task.id}")
return
# Schedule the job
self.scheduler.add_job(
self._execute_task,
trigger=trigger,
args=[task.id],
id=job_id,
name=task.name,
replace_existing=True,
)
logger.debug(f"Scheduled APScheduler job for task {task.id}")
def _create_trigger(self, task: ScheduledTask):
"""Create APScheduler trigger based on task configuration."""
tz = pytz.timezone(task.timezone)
if task.recurrence_type == RecurrenceType.NONE:
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
elif task.recurrence_type == RecurrenceType.HOURLY:
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.DAILY:
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.WEEKLY:
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.MONTHLY:
# Use cron trigger for monthly (more reliable than interval)
scheduled_time = task.scheduled_at
return CronTrigger(
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
)
elif task.recurrence_type == RecurrenceType.YEARLY:
scheduled_time = task.scheduled_at
return CronTrigger(
month=scheduled_time.month,
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
)
return None
async def _execute_task(self, task_id: int) -> None:
"""Execute a scheduled task."""
task_id_str = str(task_id)
# Prevent concurrent execution of the same task
if task_id_str in self._running_tasks:
logger.warning(f"Task {task_id} is already running, skipping execution")
return
self._running_tasks.add(task_id_str)
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Get fresh task data
task = await repo.get_by_id(task_id)
if not task:
logger.warning(f"Task {task_id} not found")
return
# Check if task is still active and pending
if not task.is_active or task.status != TaskStatus.PENDING:
logger.info(f"Task {task_id} is not active or not pending, skipping")
return
# Check if task has expired
if task.is_expired():
logger.info(f"Task {task_id} has expired, marking as cancelled")
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
return
# Mark task as running
await repo.mark_as_running(task)
# Execute the task
try:
handler_registry = TaskHandlerRegistry(
session, self.credit_service, self.player_service
)
await handler_registry.execute_task(task)
# Calculate next execution time for recurring tasks
next_execution_at = None
if task.should_repeat():
next_execution_at = self._calculate_next_execution(task)
# Mark as completed
await repo.mark_as_completed(task, next_execution_at)
# Reschedule if recurring
if next_execution_at and task.should_repeat():
# Refresh task to get updated data
await session.refresh(task)
await self._schedule_apscheduler_job(task)
except Exception as e:
await repo.mark_as_failed(task, str(e))
logger.exception(f"Task {task_id} execution failed: {str(e)}")
finally:
self._running_tasks.discard(task_id_str)
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
"""Calculate the next execution time for a recurring task."""
now = datetime.utcnow()
if task.recurrence_type == RecurrenceType.HOURLY:
return now + timedelta(hours=1)
elif task.recurrence_type == RecurrenceType.DAILY:
return now + timedelta(days=1)
elif task.recurrence_type == RecurrenceType.WEEKLY:
return now + timedelta(weeks=1)
elif task.recurrence_type == RecurrenceType.MONTHLY:
# Add approximately one month
return now + timedelta(days=30)
elif task.recurrence_type == RecurrenceType.YEARLY:
return now + timedelta(days=365)
return None
async def _maintenance_job(self) -> None:
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Handle expired tasks
expired_tasks = await repo.get_expired_tasks()
for task in expired_tasks:
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
# Remove from scheduler
try:
self.scheduler.remove_job(str(task.id))
except Exception:
pass
if expired_tasks:
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
# Handle any missed recurring tasks
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
for task in due_recurring:
if task.should_repeat():
task.status = TaskStatus.PENDING
task.scheduled_at = task.next_execution_at or datetime.utcnow()
await repo.update(task)
await self._schedule_apscheduler_job(task)
if due_recurring:
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
except Exception:
logger.exception("Maintenance job failed")

View File

@@ -0,0 +1,137 @@
"""Task execution handlers for different task types."""
import uuid
from typing import Any, Dict, Optional
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.scheduled_task import ScheduledTask, TaskType
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.services.credit import CreditService
from app.services.player import PlayerService
logger = get_logger(__name__)
class TaskExecutionError(Exception):
"""Exception raised when task execution fails."""
pass
class TaskHandlerRegistry:
"""Registry for task execution handlers."""
def __init__(
self,
db_session: AsyncSession,
credit_service: CreditService,
player_service: PlayerService,
) -> None:
"""Initialize the task handler registry."""
self.db_session = db_session
self.credit_service = credit_service
self.player_service = player_service
self.sound_repository = SoundRepository(db_session)
self.playlist_repository = PlaylistRepository(db_session)
# Register handlers
self._handlers = {
TaskType.CREDIT_RECHARGE: self._handle_credit_recharge,
TaskType.PLAY_SOUND: self._handle_play_sound,
TaskType.PLAY_PLAYLIST: self._handle_play_playlist,
}
async def execute_task(self, task: ScheduledTask) -> None:
"""Execute a task based on its type."""
handler = self._handlers.get(task.task_type)
if not handler:
raise TaskExecutionError(f"No handler registered for task type: {task.task_type}")
logger.info(f"Executing task {task.id} ({task.task_type.value}): {task.name}")
try:
await handler(task)
logger.info(f"Task {task.id} executed successfully")
except Exception as e:
logger.exception(f"Task {task.id} execution failed: {str(e)}")
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
"""Handle credit recharge task."""
parameters = task.parameters
user_id = parameters.get("user_id")
if user_id:
# Recharge specific user
user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
stats = await self.credit_service.recharge_user_credits(user_uuid)
logger.info(f"Recharged credits for user {user_id}: {stats}")
else:
# Recharge all users (system task)
stats = await self.credit_service.recharge_all_users_credits()
logger.info(f"Recharged credits for all users: {stats}")
async def _handle_play_sound(self, task: ScheduledTask) -> None:
"""Handle play sound task."""
parameters = task.parameters
sound_id = parameters.get("sound_id")
if not sound_id:
raise TaskExecutionError("sound_id parameter is required for PLAY_SOUND tasks")
try:
sound_uuid = uuid.UUID(sound_id) if isinstance(sound_id, str) else sound_id
except (ValueError, TypeError) as e:
raise TaskExecutionError(f"Invalid sound_id format: {sound_id}") from e
# Get the sound from database
sound = await self.sound_repository.get_by_id(sound_uuid)
if not sound:
raise TaskExecutionError(f"Sound not found: {sound_id}")
# Play the sound through VLC
from app.services.vlc_player import VLCPlayerService
vlc_service = VLCPlayerService(lambda: self.db_session)
await vlc_service.play_sound(sound)
logger.info(f"Played sound {sound.filename} via scheduled task")
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
"""Handle play playlist task."""
parameters = task.parameters
playlist_id = parameters.get("playlist_id")
play_mode = parameters.get("play_mode", "continuous")
shuffle = parameters.get("shuffle", False)
if not playlist_id:
raise TaskExecutionError("playlist_id parameter is required for PLAY_PLAYLIST tasks")
try:
playlist_uuid = uuid.UUID(playlist_id) if isinstance(playlist_id, str) else playlist_id
except (ValueError, TypeError) as e:
raise TaskExecutionError(f"Invalid playlist_id format: {playlist_id}") from e
# Get the playlist from database
playlist = await self.playlist_repository.get_by_id(playlist_uuid)
if not playlist:
raise TaskExecutionError(f"Playlist not found: {playlist_id}")
# Load playlist in player
await self.player_service.load_playlist(playlist_uuid)
# Set play mode if specified
if play_mode in ["continuous", "loop", "loop_one", "random", "single"]:
self.player_service.set_mode(play_mode)
# Enable shuffle if requested
if shuffle:
self.player_service.set_shuffle(True)
# Start playing
self.player_service.play()
logger.info(f"Started playing playlist {playlist.name} via scheduled task")

View File

@@ -238,75 +238,76 @@ class VLCPlayerService:
return
logger.info("Recording play count for sound %s", sound_id)
session = self.db_session_factory()
# Initialize variables for WebSocket event
old_count = 0
sound = None
admin_user_id = None
admin_user_name = None
try:
sound_repo = SoundRepository(session)
user_repo = UserRepository(session)
async with self.db_session_factory() as session:
sound_repo = SoundRepository(session)
user_repo = UserRepository(session)
# Update sound play count
sound = await sound_repo.get_by_id(sound_id)
old_count = 0
if sound:
old_count = sound.play_count
await sound_repo.update(
sound,
{"play_count": sound.play_count + 1},
# Update sound play count
sound = await sound_repo.get_by_id(sound_id)
if sound:
old_count = sound.play_count
# Update the sound's play count using direct attribute modification
sound.play_count = sound.play_count + 1
session.add(sound)
await session.commit()
await session.refresh(sound)
logger.info(
"Updated sound %s play_count: %s -> %s",
sound_id,
old_count,
old_count + 1,
)
else:
logger.warning("Sound %s not found for play count update", sound_id)
# Record play history for admin user (ID 1) as placeholder
# This could be refined to track per-user play history
admin_user = await user_repo.get_by_id(1)
if admin_user:
admin_user_id = admin_user.id
admin_user_name = admin_user.name
# Always create a new SoundPlayed record for each play event
sound_played = SoundPlayed(
user_id=admin_user_id, # Can be None for player-based plays
sound_id=sound_id,
)
session.add(sound_played)
logger.info(
"Updated sound %s play_count: %s -> %s",
sound_id,
old_count,
old_count + 1,
)
else:
logger.warning("Sound %s not found for play count update", sound_id)
# Record play history for admin user (ID 1) as placeholder
# This could be refined to track per-user play history
admin_user = await user_repo.get_by_id(1)
admin_user_id = None
admin_user_name = None
if admin_user:
admin_user_id = admin_user.id
admin_user_name = admin_user.name
# Always create a new SoundPlayed record for each play event
sound_played = SoundPlayed(
user_id=admin_user_id, # Can be None for player-based plays
sound_id=sound_id,
)
session.add(sound_played)
logger.info(
"Created SoundPlayed record for user %s, sound %s",
admin_user_id,
sound_id,
)
await session.commit()
logger.info("Successfully recorded play count for sound %s", sound_id)
# Emit sound_played event via WebSocket
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": admin_user_id,
"user_name": admin_user_name,
"play_count": (old_count + 1) if sound else None,
}
await socket_manager.broadcast_to_all("sound_played", event_data)
logger.info("Broadcasted sound_played event for sound %s", sound_id)
except Exception:
logger.exception(
"Failed to broadcast sound_played event for sound %s",
"Created SoundPlayed record for user %s, sound %s",
admin_user_id,
sound_id,
)
await session.commit()
logger.info("Successfully recorded play count for sound %s", sound_id)
except Exception:
logger.exception("Error recording play count for sound %s", sound_id)
await session.rollback()
finally:
await session.close()
# Emit sound_played event via WebSocket (outside session context)
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": admin_user_id,
"user_name": admin_user_name,
"play_count": (old_count + 1) if sound else None,
}
await socket_manager.broadcast_to_all("sound_played", event_data)
logger.info("Broadcasted sound_played event for sound %s", sound_id)
except Exception:
logger.exception(
"Failed to broadcast sound_played event for sound %s",
sound_id,
)
async def play_sound_with_credits(
self,