Refactor scheduled task repository and schemas for improved type hints and consistency

- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase.
- Refactored import statements for better organization and clarity.
- Enhanced the ScheduledTaskBase schema to use modern type hints.
- Cleaned up unnecessary comments and whitespace in various files.
- Improved error handling and logging in task execution handlers.
- Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
JSC
2025-08-28 23:38:47 +02:00
parent 96801dc4d6
commit dc89e45675
19 changed files with 292 additions and 291 deletions

View File

@@ -1,7 +1,5 @@
"""API endpoints for scheduled task management."""
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -11,7 +9,7 @@ from app.core.dependencies import (
get_admin_user,
get_current_active_user,
)
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
from app.models.user import User
from app.schemas.scheduler import (
ScheduledTaskCreate,
@@ -54,15 +52,15 @@ async def create_task(
raise HTTPException(status_code=400, detail=str(e))
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
async def get_user_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: int | None = Query(50, description="Maximum number of tasks to return"),
offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_current_active_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service),
) -> List[ScheduledTask]:
) -> list[ScheduledTask]:
"""Get user's scheduled tasks."""
return await scheduler_service.get_user_tasks(
user_id=current_user.id,
@@ -81,17 +79,17 @@ async def get_task(
) -> ScheduledTask:
"""Get a specific scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
return task
@@ -104,22 +102,22 @@ async def update_task(
) -> ScheduledTask:
"""Update a scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
# Update task fields
update_data = task_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(task, field, value)
updated_task = await repo.update(task)
return updated_task
@@ -133,72 +131,72 @@ async def cancel_task(
) -> dict:
"""Cancel a scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id)
if not task:
raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied")
success = await scheduler_service.cancel_task(task_id)
if not success:
raise HTTPException(status_code=400, detail="Failed to cancel task")
return {"message": "Task cancelled successfully"}
# Admin-only endpoints
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
async def get_all_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: int | None = Query(100, description="Maximum number of tasks to return"),
offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]:
) -> list[ScheduledTask]:
"""Get all scheduled tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
# Get all tasks with pagination and filtering
from sqlmodel import select
statement = select(ScheduledTask)
if status:
statement = statement.where(ScheduledTask.status == status)
if task_type:
statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
if offset:
statement = statement.offset(offset)
if limit:
statement = statement.limit(limit)
result = await db_session.exec(statement)
return list(result.all())
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
async def get_system_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: TaskType | None = Query(None, description="Filter by task type"),
current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]:
) -> list[ScheduledTask]:
"""Get system tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session)
return await repo.get_system_tasks(status=status, task_type=task_type)
@@ -225,4 +223,4 @@ async def create_system_task(
)
return task
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -4,11 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
from app.core.config import settings
from app.core.logging import get_logger
from app.core.seeds import seed_all_data
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
engine: AsyncEngine = create_async_engine(
settings.DATABASE_URL,

View File

@@ -11,11 +11,14 @@ from app.core.database import get_session_factory, init_db
from app.core.logging import get_logger, setup_logging
from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service
from app.services.player import (
get_player_service,
initialize_player_service,
shutdown_player_service,
)
from app.services.scheduler import SchedulerService
from app.services.socket import socket_manager
scheduler_service = None
@@ -31,7 +34,7 @@ def get_global_scheduler_service() -> SchedulerService:
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
"""Application lifespan context manager for setup and teardown."""
global scheduler_service
setup_logging()
logger = get_logger(__name__)
logger.info("Starting application")

View File

@@ -20,7 +20,7 @@ __all__ = [
"CreditAction",
"CreditTransaction",
"Extraction",
"Favorite",
"Favorite",
"Plan",
"Playlist",
"PlaylistSound",

View File

@@ -1,11 +1,10 @@
"""Scheduled task model for flexible task scheduling with timezone support."""
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, Optional
from typing import Any
from sqlmodel import JSON, Column, Field, SQLModel
from sqlmodel import JSON, Column, Field
from app.models.base import BaseModel
@@ -57,11 +56,11 @@ class ScheduledTask(BaseModel, table=True):
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
)
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
cron_expression: Optional[str] = Field(
cron_expression: str | None = Field(
default=None,
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
)
recurrence_count: Optional[int] = Field(
recurrence_count: int | None = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
@@ -75,29 +74,29 @@ class ScheduledTask(BaseModel, table=True):
)
# User association (None for system tasks)
user_id: Optional[int] = Field(
user_id: int | None = Field(
default=None,
foreign_key="user.id",
description="User who created the task (None for system tasks)",
)
# Execution tracking
last_executed_at: Optional[datetime] = Field(
last_executed_at: datetime | None = Field(
default=None,
description="When the task was last executed (UTC)",
)
next_execution_at: Optional[datetime] = Field(
next_execution_at: datetime | None = Field(
default=None,
description="When the task should be executed next (UTC, for recurring tasks)",
)
error_message: Optional[str] = Field(
error_message: str | None = Field(
default=None,
description="Error message if execution failed",
)
# Task lifecycle
is_active: bool = Field(default=True, description="Whether the task is active")
expires_at: Optional[datetime] = Field(
expires_at: datetime | None = Field(
default=None,
description="When the task expires (UTC, optional)",
)
@@ -122,4 +121,4 @@ class ScheduledTask(BaseModel, table=True):
def is_system_task(self) -> bool:
"""Check if this is a system task (no user association)."""
return self.user_id is None
return self.user_id is None

View File

@@ -1,12 +1,16 @@
"""Repository for scheduled task operations."""
from datetime import datetime
from typing import List, Optional
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.base import BaseRepository
@@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
"""Initialize the repository."""
super().__init__(ScheduledTask, session)
async def get_pending_tasks(self) -> List[ScheduledTask]:
async def get_pending_tasks(self) -> list[ScheduledTask]:
"""Get all pending tasks that are ready to be executed."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
@@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement)
return list(result.all())
async def get_active_tasks(self) -> List[ScheduledTask]:
async def get_active_tasks(self) -> list[ScheduledTask]:
"""Get all active tasks."""
statement = select(ScheduledTask).where(
ScheduledTask.is_active.is_(True),
@@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
status: TaskStatus | None = None,
task_type: TaskType | None = None,
limit: int | None = None,
offset: int | None = None,
) -> list[ScheduledTask]:
"""Get tasks for a specific user."""
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
@@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_system_tasks(
self,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
) -> List[ScheduledTask]:
status: TaskStatus | None = None,
task_type: TaskType | None = None,
) -> list[ScheduledTask]:
"""Get system tasks (tasks with no user association)."""
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
@@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement)
return list(result.all())
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
"""Get recurring tasks that need their next execution scheduled."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
@@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement)
return list(result.all())
async def get_expired_tasks(self) -> List[ScheduledTask]:
async def get_expired_tasks(self) -> list[ScheduledTask]:
"""Get expired tasks that should be cleaned up."""
now = datetime.utcnow()
statement = select(ScheduledTask).where(
@@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def cancel_user_tasks(
self,
user_id: int,
task_type: Optional[TaskType] = None,
task_type: TaskType | None = None,
) -> int:
"""Cancel all pending/running tasks for a user."""
statement = select(ScheduledTask).where(
@@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def mark_as_completed(
self,
task: ScheduledTask,
next_execution_at: Optional[datetime] = None,
next_execution_at: datetime | None = None,
) -> None:
"""Mark a task as completed and set next execution if recurring."""
task.status = TaskStatus.COMPLETED
@@ -174,4 +178,4 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
self.session.add(task)
await self.session.commit()
await self.session.refresh(task)
await self.session.refresh(task)

View File

@@ -1,7 +1,7 @@
"""Schemas for scheduled task API."""
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
@@ -15,7 +15,7 @@ class ScheduledTaskBase(BaseModel):
task_type: TaskType = Field(description="Type of task to execute")
scheduled_at: datetime = Field(description="When the task should be executed")
timezone: str = Field(default="UTC", description="Timezone for scheduling")
parameters: Dict[str, Any] = Field(
parameters: dict[str, Any] = Field(
default_factory=dict,
description="Task-specific parameters",
)
@@ -23,15 +23,15 @@ class ScheduledTaskBase(BaseModel):
default=RecurrenceType.NONE,
description="Recurrence pattern",
)
cron_expression: Optional[str] = Field(
cron_expression: str | None = Field(
default=None,
description="Cron expression for custom recurrence",
)
recurrence_count: Optional[int] = Field(
recurrence_count: int | None = Field(
default=None,
description="Number of times to repeat (None for infinite)",
)
expires_at: Optional[datetime] = Field(
expires_at: datetime | None = Field(
default=None,
description="When the task expires (optional)",
)
@@ -40,18 +40,17 @@ class ScheduledTaskBase(BaseModel):
class ScheduledTaskCreate(ScheduledTaskBase):
"""Schema for creating a scheduled task."""
pass
class ScheduledTaskUpdate(BaseModel):
"""Schema for updating a scheduled task."""
name: Optional[str] = None
scheduled_at: Optional[datetime] = None
timezone: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None
is_active: Optional[bool] = None
expires_at: Optional[datetime] = None
name: str | None = None
scheduled_at: datetime | None = None
timezone: str | None = None
parameters: dict[str, Any] | None = None
is_active: bool | None = None
expires_at: datetime | None = None
class ScheduledTaskResponse(ScheduledTaskBase):
@@ -59,11 +58,11 @@ class ScheduledTaskResponse(ScheduledTaskBase):
id: int
status: TaskStatus
user_id: Optional[int] = None
user_id: int | None = None
executions_count: int
last_executed_at: Optional[datetime] = None
next_execution_at: Optional[datetime] = None
error_message: Optional[str] = None
last_executed_at: datetime | None = None
next_execution_at: datetime | None = None
error_message: str | None = None
is_active: bool
created_at: datetime
updated_at: datetime
@@ -78,7 +77,7 @@ class ScheduledTaskResponse(ScheduledTaskBase):
class CreditRechargeParameters(BaseModel):
"""Parameters for credit recharge tasks."""
user_id: Optional[int] = Field(
user_id: int | None = Field(
default=None,
description="Specific user ID to recharge (None for all users)",
)
@@ -109,10 +108,10 @@ class CreateCreditRechargeTask(BaseModel):
scheduled_at: datetime
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
user_id: Optional[int] = None
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
user_id: int | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
@@ -137,9 +136,9 @@ class CreatePlaySoundTask(BaseModel):
sound_id: int
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
@@ -166,9 +165,9 @@ class CreatePlayPlaylistTask(BaseModel):
shuffle: bool = False
timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None
recurrence_count: Optional[int] = None
expires_at: Optional[datetime] = None
cron_expression: str | None = None
recurrence_count: int | None = None
expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema."""
@@ -186,4 +185,4 @@ class CreatePlayPlaylistTask(BaseModel):
cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count,
expires_at=self.expires_at,
)
)

View File

@@ -2,7 +2,7 @@
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from typing import Any
import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -52,7 +52,7 @@ class SchedulerService:
logger.info("Starting enhanced scheduler service...")
self.scheduler.start()
# Schedule system tasks initialization for after startup
self.scheduler.add_job(
self._initialize_system_tasks,
@@ -62,7 +62,7 @@ class SchedulerService:
name="Initialize System Tasks",
replace_existing=True,
)
# Schedule periodic cleanup and maintenance
self.scheduler.add_job(
self._maintenance_job,
@@ -86,18 +86,18 @@ class SchedulerService:
name: str,
task_type: TaskType,
scheduled_at: datetime,
parameters: Optional[Dict[str, Any]] = None,
user_id: Optional[int] = None,
parameters: dict[str, Any] | None = None,
user_id: int | None = None,
timezone: str = "UTC",
recurrence_type: RecurrenceType = RecurrenceType.NONE,
cron_expression: Optional[str] = None,
recurrence_count: Optional[int] = None,
expires_at: Optional[datetime] = None,
cron_expression: str | None = None,
recurrence_count: int | None = None,
expires_at: datetime | None = None,
) -> ScheduledTask:
"""Create a new scheduled task."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Convert scheduled_at to UTC if it's in a different timezone
if timezone != "UTC":
tz = pytz.timezone(timezone)
@@ -105,7 +105,7 @@ class SchedulerService:
# Assume the datetime is in the specified timezone
scheduled_at = tz.localize(scheduled_at)
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
task_data = {
"name": name,
"task_type": task_type,
@@ -118,59 +118,59 @@ class SchedulerService:
"recurrence_count": recurrence_count,
"expires_at": expires_at,
}
created_task = await repo.create(task_data)
await self._schedule_apscheduler_job(created_task)
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
return created_task
async def cancel_task(self, task_id: int) -> bool:
"""Cancel a scheduled task."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
task = await repo.get_by_id(task_id)
if not task:
return False
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
# Remove from APScheduler
try:
self.scheduler.remove_job(str(task_id))
except Exception:
pass # Job might not exist in scheduler
logger.info(f"Cancelled task: {task.name} ({task_id})")
return True
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
status: TaskStatus | None = None,
task_type: TaskType | None = None,
limit: int | None = None,
offset: int | None = None,
) -> list[ScheduledTask]:
"""Get tasks for a specific user."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
async def _initialize_system_tasks(self) -> None:
"""Initialize system tasks and load active tasks from database."""
logger.info("Initializing system tasks...")
try:
# Create system tasks if they don't exist
await self._ensure_system_tasks()
# Load all active tasks from database
await self._load_active_tasks()
logger.info("System tasks initialized successfully")
except Exception:
logger.exception("Failed to initialize system tasks")
@@ -179,24 +179,24 @@ class SchedulerService:
"""Ensure required system tasks exist."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Check if daily credit recharge task exists
system_tasks = await repo.get_system_tasks(
task_type=TaskType.CREDIT_RECHARGE
task_type=TaskType.CREDIT_RECHARGE,
)
daily_recharge_exists = any(
task.recurrence_type == RecurrenceType.DAILY
and task.is_active
for task in system_tasks
)
if not daily_recharge_exists:
# Create daily credit recharge task
tomorrow_midnight = datetime.utcnow().replace(
hour=0, minute=0, second=0, microsecond=0
hour=0, minute=0, second=0, microsecond=0,
) + timedelta(days=1)
task_data = {
"name": "Daily Credit Recharge",
"task_type": TaskType.CREDIT_RECHARGE,
@@ -204,41 +204,41 @@ class SchedulerService:
"recurrence_type": RecurrenceType.DAILY,
"parameters": {},
}
await repo.create(task_data)
logger.info("Created system daily credit recharge task")
async def _load_active_tasks(self) -> None:
"""Load all active tasks from database into scheduler."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
active_tasks = await repo.get_active_tasks()
for task in active_tasks:
await self._schedule_apscheduler_job(task)
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
"""Schedule a task in APScheduler."""
job_id = str(task.id)
# Remove existing job if it exists
try:
self.scheduler.remove_job(job_id)
except Exception:
pass
# Don't schedule if task is not active or already completed/failed
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
return
# Create trigger based on recurrence type
trigger = self._create_trigger(task)
if not trigger:
logger.warning(f"Could not create trigger for task {task.id}")
return
# Schedule the job
self.scheduler.add_job(
self._execute_task,
@@ -248,76 +248,76 @@ class SchedulerService:
name=task.name,
replace_existing=True,
)
logger.debug(f"Scheduled APScheduler job for task {task.id}")
def _create_trigger(self, task: ScheduledTask):
"""Create APScheduler trigger based on task configuration."""
tz = pytz.timezone(task.timezone)
if task.recurrence_type == RecurrenceType.NONE:
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
elif task.recurrence_type == RecurrenceType.HOURLY:
if task.recurrence_type == RecurrenceType.HOURLY:
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.DAILY:
if task.recurrence_type == RecurrenceType.DAILY:
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.WEEKLY:
if task.recurrence_type == RecurrenceType.WEEKLY:
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.MONTHLY:
if task.recurrence_type == RecurrenceType.MONTHLY:
# Use cron trigger for monthly (more reliable than interval)
scheduled_time = task.scheduled_at
return CronTrigger(
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
timezone=tz,
)
elif task.recurrence_type == RecurrenceType.YEARLY:
if task.recurrence_type == RecurrenceType.YEARLY:
scheduled_time = task.scheduled_at
return CronTrigger(
month=scheduled_time.month,
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
timezone=tz,
)
return None
async def _execute_task(self, task_id: int) -> None:
"""Execute a scheduled task."""
task_id_str = str(task_id)
# Prevent concurrent execution of the same task
if task_id_str in self._running_tasks:
logger.warning(f"Task {task_id} is already running, skipping execution")
return
self._running_tasks.add(task_id_str)
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Get fresh task data
task = await repo.get_by_id(task_id)
if not task:
logger.warning(f"Task {task_id} not found")
return
# Check if task is still active and pending
if not task.is_active or task.status != TaskStatus.PENDING:
logger.info(f"Task {task_id} is not active or not pending, skipping")
return
# Check if task has expired
if task.is_expired():
logger.info(f"Task {task_id} has expired, marking as cancelled")
@@ -325,78 +325,78 @@ class SchedulerService:
task.is_active = False
await repo.update(task)
return
# Mark task as running
await repo.mark_as_running(task)
# Execute the task
try:
handler_registry = TaskHandlerRegistry(
session, self.db_session_factory, self.credit_service, self.player_service
session, self.db_session_factory, self.credit_service, self.player_service,
)
await handler_registry.execute_task(task)
# Calculate next execution time for recurring tasks
next_execution_at = None
if task.should_repeat():
next_execution_at = self._calculate_next_execution(task)
# Mark as completed
await repo.mark_as_completed(task, next_execution_at)
# Reschedule if recurring
if next_execution_at and task.should_repeat():
# Refresh task to get updated data
await session.refresh(task)
await self._schedule_apscheduler_job(task)
except Exception as e:
await repo.mark_as_failed(task, str(e))
logger.exception(f"Task {task_id} execution failed: {str(e)}")
logger.exception(f"Task {task_id} execution failed: {e!s}")
finally:
self._running_tasks.discard(task_id_str)
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
"""Calculate the next execution time for a recurring task."""
now = datetime.utcnow()
if task.recurrence_type == RecurrenceType.HOURLY:
return now + timedelta(hours=1)
elif task.recurrence_type == RecurrenceType.DAILY:
if task.recurrence_type == RecurrenceType.DAILY:
return now + timedelta(days=1)
elif task.recurrence_type == RecurrenceType.WEEKLY:
if task.recurrence_type == RecurrenceType.WEEKLY:
return now + timedelta(weeks=1)
elif task.recurrence_type == RecurrenceType.MONTHLY:
if task.recurrence_type == RecurrenceType.MONTHLY:
# Add approximately one month
return now + timedelta(days=30)
elif task.recurrence_type == RecurrenceType.YEARLY:
if task.recurrence_type == RecurrenceType.YEARLY:
return now + timedelta(days=365)
return None
async def _maintenance_job(self) -> None:
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Handle expired tasks
expired_tasks = await repo.get_expired_tasks()
for task in expired_tasks:
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
# Remove from scheduler
try:
self.scheduler.remove_job(str(task.id))
except Exception:
pass
if expired_tasks:
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
# Handle any missed recurring tasks
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
for task in due_recurring:
@@ -405,9 +405,9 @@ class SchedulerService:
task.scheduled_at = task.next_execution_at or datetime.utcnow()
await repo.update(task)
await self._schedule_apscheduler_job(task)
if due_recurring:
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
except Exception:
logger.exception("Maintenance job failed")

View File

@@ -1,6 +1,5 @@
"""Task execution handlers for different task types."""
from typing import Any, Dict, Optional
from collections.abc import Callable
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
class TaskExecutionError(Exception):
"""Exception raised when task execution fails."""
pass
class TaskHandlerRegistry:
@@ -58,8 +56,8 @@ class TaskHandlerRegistry:
await handler(task)
logger.info(f"Task {task.id} executed successfully")
except Exception as e:
logger.exception(f"Task {task.id} execution failed: {str(e)}")
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
logger.exception(f"Task {task.id} execution failed: {e!s}")
raise TaskExecutionError(f"Task execution failed: {e!s}") from e
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
"""Handle credit recharge task."""
@@ -72,7 +70,7 @@ class TaskHandlerRegistry:
user_id_int = int(user_id)
except (ValueError, TypeError) as e:
raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e
stats = await self.credit_service.recharge_user_credits(user_id_int)
logger.info(f"Recharged credits for user {user_id}: {stats}")
else:
@@ -105,7 +103,7 @@ class TaskHandlerRegistry:
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
except Exception as e:
# Convert HTTP exceptions or credit errors to task execution errors
raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e
raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
else:
# System task: play without credit deduction
sound = await self.sound_repository.get_by_id(sound_id_int)
@@ -116,10 +114,10 @@ class TaskHandlerRegistry:
vlc_service = VLCPlayerService(self.db_session_factory)
success = await vlc_service.play_sound(sound)
if not success:
raise TaskExecutionError(f"Failed to play sound {sound.filename}")
logger.info(f"Played sound {sound.filename} via scheduled system task")
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
@@ -157,4 +155,4 @@ class TaskHandlerRegistry:
# Start playing
await self.player_service.play()
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
logger.info(f"Started playing playlist {playlist.name} via scheduled task")

View File

@@ -238,13 +238,13 @@ class VLCPlayerService:
return
logger.info("Recording play count for sound %s", sound_id)
# Initialize variables for WebSocket event
old_count = 0
sound = None
admin_user_id = None
admin_user_name = None
try:
async with self.db_session_factory() as session:
sound_repo = SoundRepository(session)