Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
@@ -1,7 +1,5 @@
|
|||||||
"""API endpoints for scheduled task management."""
|
"""API endpoints for scheduled task management."""
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -11,7 +9,7 @@ from app.core.dependencies import (
|
|||||||
get_admin_user,
|
get_admin_user,
|
||||||
get_current_active_user,
|
get_current_active_user,
|
||||||
)
|
)
|
||||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.scheduler import (
|
from app.schemas.scheduler import (
|
||||||
ScheduledTaskCreate,
|
ScheduledTaskCreate,
|
||||||
@@ -54,15 +52,15 @@ async def create_task(
|
|||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
|
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
|
||||||
async def get_user_tasks(
|
async def get_user_tasks(
|
||||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||||
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
|
limit: int | None = Query(50, description="Maximum number of tasks to return"),
|
||||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||||
current_user: User = Depends(get_current_active_user),
|
current_user: User = Depends(get_current_active_user),
|
||||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||||
) -> List[ScheduledTask]:
|
) -> list[ScheduledTask]:
|
||||||
"""Get user's scheduled tasks."""
|
"""Get user's scheduled tasks."""
|
||||||
return await scheduler_service.get_user_tasks(
|
return await scheduler_service.get_user_tasks(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -81,17 +79,17 @@ async def get_task(
|
|||||||
) -> ScheduledTask:
|
) -> ScheduledTask:
|
||||||
"""Get a specific scheduled task."""
|
"""Get a specific scheduled task."""
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
repo = ScheduledTaskRepository(db_session)
|
repo = ScheduledTaskRepository(db_session)
|
||||||
task = await repo.get_by_id(task_id)
|
task = await repo.get_by_id(task_id)
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
# Check if user owns the task or is admin
|
# Check if user owns the task or is admin
|
||||||
if task.user_id != current_user.id and not current_user.is_admin:
|
if task.user_id != current_user.id and not current_user.is_admin:
|
||||||
raise HTTPException(status_code=403, detail="Access denied")
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
|
||||||
@@ -104,22 +102,22 @@ async def update_task(
|
|||||||
) -> ScheduledTask:
|
) -> ScheduledTask:
|
||||||
"""Update a scheduled task."""
|
"""Update a scheduled task."""
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
repo = ScheduledTaskRepository(db_session)
|
repo = ScheduledTaskRepository(db_session)
|
||||||
task = await repo.get_by_id(task_id)
|
task = await repo.get_by_id(task_id)
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
# Check if user owns the task or is admin
|
# Check if user owns the task or is admin
|
||||||
if task.user_id != current_user.id and not current_user.is_admin:
|
if task.user_id != current_user.id and not current_user.is_admin:
|
||||||
raise HTTPException(status_code=403, detail="Access denied")
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
# Update task fields
|
# Update task fields
|
||||||
update_data = task_update.model_dump(exclude_unset=True)
|
update_data = task_update.model_dump(exclude_unset=True)
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(task, field, value)
|
setattr(task, field, value)
|
||||||
|
|
||||||
updated_task = await repo.update(task)
|
updated_task = await repo.update(task)
|
||||||
return updated_task
|
return updated_task
|
||||||
|
|
||||||
@@ -133,72 +131,72 @@ async def cancel_task(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Cancel a scheduled task."""
|
"""Cancel a scheduled task."""
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
repo = ScheduledTaskRepository(db_session)
|
repo = ScheduledTaskRepository(db_session)
|
||||||
task = await repo.get_by_id(task_id)
|
task = await repo.get_by_id(task_id)
|
||||||
|
|
||||||
if not task:
|
if not task:
|
||||||
raise HTTPException(status_code=404, detail="Task not found")
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
# Check if user owns the task or is admin
|
# Check if user owns the task or is admin
|
||||||
if task.user_id != current_user.id and not current_user.is_admin:
|
if task.user_id != current_user.id and not current_user.is_admin:
|
||||||
raise HTTPException(status_code=403, detail="Access denied")
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
success = await scheduler_service.cancel_task(task_id)
|
success = await scheduler_service.cancel_task(task_id)
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(status_code=400, detail="Failed to cancel task")
|
raise HTTPException(status_code=400, detail="Failed to cancel task")
|
||||||
|
|
||||||
return {"message": "Task cancelled successfully"}
|
return {"message": "Task cancelled successfully"}
|
||||||
|
|
||||||
|
|
||||||
# Admin-only endpoints
|
# Admin-only endpoints
|
||||||
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
|
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
|
||||||
async def get_all_tasks(
|
async def get_all_tasks(
|
||||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||||
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
|
limit: int | None = Query(100, description="Maximum number of tasks to return"),
|
||||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||||
current_user: User = Depends(get_admin_user),
|
current_user: User = Depends(get_admin_user),
|
||||||
db_session: AsyncSession = Depends(get_db),
|
db_session: AsyncSession = Depends(get_db),
|
||||||
) -> List[ScheduledTask]:
|
) -> list[ScheduledTask]:
|
||||||
"""Get all scheduled tasks (admin only)."""
|
"""Get all scheduled tasks (admin only)."""
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
repo = ScheduledTaskRepository(db_session)
|
repo = ScheduledTaskRepository(db_session)
|
||||||
|
|
||||||
# Get all tasks with pagination and filtering
|
# Get all tasks with pagination and filtering
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
|
|
||||||
statement = select(ScheduledTask)
|
statement = select(ScheduledTask)
|
||||||
|
|
||||||
if status:
|
if status:
|
||||||
statement = statement.where(ScheduledTask.status == status)
|
statement = statement.where(ScheduledTask.status == status)
|
||||||
|
|
||||||
if task_type:
|
if task_type:
|
||||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||||
|
|
||||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||||
|
|
||||||
if offset:
|
if offset:
|
||||||
statement = statement.offset(offset)
|
statement = statement.offset(offset)
|
||||||
|
|
||||||
if limit:
|
if limit:
|
||||||
statement = statement.limit(limit)
|
statement = statement.limit(limit)
|
||||||
|
|
||||||
result = await db_session.exec(statement)
|
result = await db_session.exec(statement)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
|
|
||||||
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
|
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
|
||||||
async def get_system_tasks(
|
async def get_system_tasks(
|
||||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||||
current_user: User = Depends(get_admin_user),
|
current_user: User = Depends(get_admin_user),
|
||||||
db_session: AsyncSession = Depends(get_db),
|
db_session: AsyncSession = Depends(get_db),
|
||||||
) -> List[ScheduledTask]:
|
) -> list[ScheduledTask]:
|
||||||
"""Get system tasks (admin only)."""
|
"""Get system tasks (admin only)."""
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
repo = ScheduledTaskRepository(db_session)
|
repo = ScheduledTaskRepository(db_session)
|
||||||
return await repo.get_system_tasks(status=status, task_type=task_type)
|
return await repo.get_system_tasks(status=status, task_type=task_type)
|
||||||
|
|
||||||
@@ -225,4 +223,4 @@ async def create_system_task(
|
|||||||
)
|
)
|
||||||
return task
|
return task
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
|||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
# Import all models to ensure SQLModel metadata discovery
|
||||||
|
import app.models # noqa: F401
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.core.seeds import seed_all_data
|
from app.core.seeds import seed_all_data
|
||||||
# Import all models to ensure SQLModel metadata discovery
|
|
||||||
import app.models # noqa: F401
|
|
||||||
|
|
||||||
engine: AsyncEngine = create_async_engine(
|
engine: AsyncEngine = create_async_engine(
|
||||||
settings.DATABASE_URL,
|
settings.DATABASE_URL,
|
||||||
|
|||||||
@@ -11,11 +11,14 @@ from app.core.database import get_session_factory, init_db
|
|||||||
from app.core.logging import get_logger, setup_logging
|
from app.core.logging import get_logger, setup_logging
|
||||||
from app.middleware.logging import LoggingMiddleware
|
from app.middleware.logging import LoggingMiddleware
|
||||||
from app.services.extraction_processor import extraction_processor
|
from app.services.extraction_processor import extraction_processor
|
||||||
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service
|
from app.services.player import (
|
||||||
|
get_player_service,
|
||||||
|
initialize_player_service,
|
||||||
|
shutdown_player_service,
|
||||||
|
)
|
||||||
from app.services.scheduler import SchedulerService
|
from app.services.scheduler import SchedulerService
|
||||||
from app.services.socket import socket_manager
|
from app.services.socket import socket_manager
|
||||||
|
|
||||||
|
|
||||||
scheduler_service = None
|
scheduler_service = None
|
||||||
|
|
||||||
|
|
||||||
@@ -31,7 +34,7 @@ def get_global_scheduler_service() -> SchedulerService:
|
|||||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
||||||
"""Application lifespan context manager for setup and teardown."""
|
"""Application lifespan context manager for setup and teardown."""
|
||||||
global scheduler_service
|
global scheduler_service
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.info("Starting application")
|
logger.info("Starting application")
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ __all__ = [
|
|||||||
"CreditAction",
|
"CreditAction",
|
||||||
"CreditTransaction",
|
"CreditTransaction",
|
||||||
"Extraction",
|
"Extraction",
|
||||||
"Favorite",
|
"Favorite",
|
||||||
"Plan",
|
"Plan",
|
||||||
"Playlist",
|
"Playlist",
|
||||||
"PlaylistSound",
|
"PlaylistSound",
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
"""Scheduled task model for flexible task scheduling with timezone support."""
|
"""Scheduled task model for flexible task scheduling with timezone support."""
|
||||||
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from sqlmodel import JSON, Column, Field, SQLModel
|
from sqlmodel import JSON, Column, Field
|
||||||
|
|
||||||
from app.models.base import BaseModel
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
@@ -57,11 +56,11 @@ class ScheduledTask(BaseModel, table=True):
|
|||||||
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
|
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
|
||||||
)
|
)
|
||||||
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
|
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
|
||||||
cron_expression: Optional[str] = Field(
|
cron_expression: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
|
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
|
||||||
)
|
)
|
||||||
recurrence_count: Optional[int] = Field(
|
recurrence_count: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Number of times to repeat (None for infinite)",
|
description="Number of times to repeat (None for infinite)",
|
||||||
)
|
)
|
||||||
@@ -75,29 +74,29 @@ class ScheduledTask(BaseModel, table=True):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# User association (None for system tasks)
|
# User association (None for system tasks)
|
||||||
user_id: Optional[int] = Field(
|
user_id: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
foreign_key="user.id",
|
foreign_key="user.id",
|
||||||
description="User who created the task (None for system tasks)",
|
description="User who created the task (None for system tasks)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execution tracking
|
# Execution tracking
|
||||||
last_executed_at: Optional[datetime] = Field(
|
last_executed_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="When the task was last executed (UTC)",
|
description="When the task was last executed (UTC)",
|
||||||
)
|
)
|
||||||
next_execution_at: Optional[datetime] = Field(
|
next_execution_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="When the task should be executed next (UTC, for recurring tasks)",
|
description="When the task should be executed next (UTC, for recurring tasks)",
|
||||||
)
|
)
|
||||||
error_message: Optional[str] = Field(
|
error_message: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Error message if execution failed",
|
description="Error message if execution failed",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Task lifecycle
|
# Task lifecycle
|
||||||
is_active: bool = Field(default=True, description="Whether the task is active")
|
is_active: bool = Field(default=True, description="Whether the task is active")
|
||||||
expires_at: Optional[datetime] = Field(
|
expires_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="When the task expires (UTC, optional)",
|
description="When the task expires (UTC, optional)",
|
||||||
)
|
)
|
||||||
@@ -122,4 +121,4 @@ class ScheduledTask(BaseModel, table=True):
|
|||||||
|
|
||||||
def is_system_task(self) -> bool:
|
def is_system_task(self) -> bool:
|
||||||
"""Check if this is a system task (no user association)."""
|
"""Check if this is a system task (no user association)."""
|
||||||
return self.user_id is None
|
return self.user_id is None
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
"""Repository for scheduled task operations."""
|
"""Repository for scheduled task operations."""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
from app.models.scheduled_task import (
|
||||||
|
RecurrenceType,
|
||||||
|
ScheduledTask,
|
||||||
|
TaskStatus,
|
||||||
|
TaskType,
|
||||||
|
)
|
||||||
from app.repositories.base import BaseRepository
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
@@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
"""Initialize the repository."""
|
"""Initialize the repository."""
|
||||||
super().__init__(ScheduledTask, session)
|
super().__init__(ScheduledTask, session)
|
||||||
|
|
||||||
async def get_pending_tasks(self) -> List[ScheduledTask]:
|
async def get_pending_tasks(self) -> list[ScheduledTask]:
|
||||||
"""Get all pending tasks that are ready to be executed."""
|
"""Get all pending tasks that are ready to be executed."""
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
statement = select(ScheduledTask).where(
|
statement = select(ScheduledTask).where(
|
||||||
@@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
async def get_active_tasks(self) -> List[ScheduledTask]:
|
async def get_active_tasks(self) -> list[ScheduledTask]:
|
||||||
"""Get all active tasks."""
|
"""Get all active tasks."""
|
||||||
statement = select(ScheduledTask).where(
|
statement = select(ScheduledTask).where(
|
||||||
ScheduledTask.is_active.is_(True),
|
ScheduledTask.is_active.is_(True),
|
||||||
@@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
async def get_user_tasks(
|
async def get_user_tasks(
|
||||||
self,
|
self,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
status: Optional[TaskStatus] = None,
|
status: TaskStatus | None = None,
|
||||||
task_type: Optional[TaskType] = None,
|
task_type: TaskType | None = None,
|
||||||
limit: Optional[int] = None,
|
limit: int | None = None,
|
||||||
offset: Optional[int] = None,
|
offset: int | None = None,
|
||||||
) -> List[ScheduledTask]:
|
) -> list[ScheduledTask]:
|
||||||
"""Get tasks for a specific user."""
|
"""Get tasks for a specific user."""
|
||||||
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
||||||
|
|
||||||
@@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
|
|
||||||
async def get_system_tasks(
|
async def get_system_tasks(
|
||||||
self,
|
self,
|
||||||
status: Optional[TaskStatus] = None,
|
status: TaskStatus | None = None,
|
||||||
task_type: Optional[TaskType] = None,
|
task_type: TaskType | None = None,
|
||||||
) -> List[ScheduledTask]:
|
) -> list[ScheduledTask]:
|
||||||
"""Get system tasks (tasks with no user association)."""
|
"""Get system tasks (tasks with no user association)."""
|
||||||
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
||||||
|
|
||||||
@@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
|
async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
|
||||||
"""Get recurring tasks that need their next execution scheduled."""
|
"""Get recurring tasks that need their next execution scheduled."""
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
statement = select(ScheduledTask).where(
|
statement = select(ScheduledTask).where(
|
||||||
@@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
result = await self.session.exec(statement)
|
result = await self.session.exec(statement)
|
||||||
return list(result.all())
|
return list(result.all())
|
||||||
|
|
||||||
async def get_expired_tasks(self) -> List[ScheduledTask]:
|
async def get_expired_tasks(self) -> list[ScheduledTask]:
|
||||||
"""Get expired tasks that should be cleaned up."""
|
"""Get expired tasks that should be cleaned up."""
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
statement = select(ScheduledTask).where(
|
statement = select(ScheduledTask).where(
|
||||||
@@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
async def cancel_user_tasks(
|
async def cancel_user_tasks(
|
||||||
self,
|
self,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
task_type: Optional[TaskType] = None,
|
task_type: TaskType | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Cancel all pending/running tasks for a user."""
|
"""Cancel all pending/running tasks for a user."""
|
||||||
statement = select(ScheduledTask).where(
|
statement = select(ScheduledTask).where(
|
||||||
@@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
async def mark_as_completed(
|
async def mark_as_completed(
|
||||||
self,
|
self,
|
||||||
task: ScheduledTask,
|
task: ScheduledTask,
|
||||||
next_execution_at: Optional[datetime] = None,
|
next_execution_at: datetime | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Mark a task as completed and set next execution if recurring."""
|
"""Mark a task as completed and set next execution if recurring."""
|
||||||
task.status = TaskStatus.COMPLETED
|
task.status = TaskStatus.COMPLETED
|
||||||
@@ -174,4 +178,4 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
|||||||
|
|
||||||
self.session.add(task)
|
self.session.add(task)
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
await self.session.refresh(task)
|
await self.session.refresh(task)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Schemas for scheduled task API."""
|
"""Schemas for scheduled task API."""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ class ScheduledTaskBase(BaseModel):
|
|||||||
task_type: TaskType = Field(description="Type of task to execute")
|
task_type: TaskType = Field(description="Type of task to execute")
|
||||||
scheduled_at: datetime = Field(description="When the task should be executed")
|
scheduled_at: datetime = Field(description="When the task should be executed")
|
||||||
timezone: str = Field(default="UTC", description="Timezone for scheduling")
|
timezone: str = Field(default="UTC", description="Timezone for scheduling")
|
||||||
parameters: Dict[str, Any] = Field(
|
parameters: dict[str, Any] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="Task-specific parameters",
|
description="Task-specific parameters",
|
||||||
)
|
)
|
||||||
@@ -23,15 +23,15 @@ class ScheduledTaskBase(BaseModel):
|
|||||||
default=RecurrenceType.NONE,
|
default=RecurrenceType.NONE,
|
||||||
description="Recurrence pattern",
|
description="Recurrence pattern",
|
||||||
)
|
)
|
||||||
cron_expression: Optional[str] = Field(
|
cron_expression: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Cron expression for custom recurrence",
|
description="Cron expression for custom recurrence",
|
||||||
)
|
)
|
||||||
recurrence_count: Optional[int] = Field(
|
recurrence_count: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Number of times to repeat (None for infinite)",
|
description="Number of times to repeat (None for infinite)",
|
||||||
)
|
)
|
||||||
expires_at: Optional[datetime] = Field(
|
expires_at: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="When the task expires (optional)",
|
description="When the task expires (optional)",
|
||||||
)
|
)
|
||||||
@@ -40,18 +40,17 @@ class ScheduledTaskBase(BaseModel):
|
|||||||
class ScheduledTaskCreate(ScheduledTaskBase):
|
class ScheduledTaskCreate(ScheduledTaskBase):
|
||||||
"""Schema for creating a scheduled task."""
|
"""Schema for creating a scheduled task."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ScheduledTaskUpdate(BaseModel):
|
class ScheduledTaskUpdate(BaseModel):
|
||||||
"""Schema for updating a scheduled task."""
|
"""Schema for updating a scheduled task."""
|
||||||
|
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
scheduled_at: Optional[datetime] = None
|
scheduled_at: datetime | None = None
|
||||||
timezone: Optional[str] = None
|
timezone: str | None = None
|
||||||
parameters: Optional[Dict[str, Any]] = None
|
parameters: dict[str, Any] | None = None
|
||||||
is_active: Optional[bool] = None
|
is_active: bool | None = None
|
||||||
expires_at: Optional[datetime] = None
|
expires_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class ScheduledTaskResponse(ScheduledTaskBase):
|
class ScheduledTaskResponse(ScheduledTaskBase):
|
||||||
@@ -59,11 +58,11 @@ class ScheduledTaskResponse(ScheduledTaskBase):
|
|||||||
|
|
||||||
id: int
|
id: int
|
||||||
status: TaskStatus
|
status: TaskStatus
|
||||||
user_id: Optional[int] = None
|
user_id: int | None = None
|
||||||
executions_count: int
|
executions_count: int
|
||||||
last_executed_at: Optional[datetime] = None
|
last_executed_at: datetime | None = None
|
||||||
next_execution_at: Optional[datetime] = None
|
next_execution_at: datetime | None = None
|
||||||
error_message: Optional[str] = None
|
error_message: str | None = None
|
||||||
is_active: bool
|
is_active: bool
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
@@ -78,7 +77,7 @@ class ScheduledTaskResponse(ScheduledTaskBase):
|
|||||||
class CreditRechargeParameters(BaseModel):
|
class CreditRechargeParameters(BaseModel):
|
||||||
"""Parameters for credit recharge tasks."""
|
"""Parameters for credit recharge tasks."""
|
||||||
|
|
||||||
user_id: Optional[int] = Field(
|
user_id: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Specific user ID to recharge (None for all users)",
|
description="Specific user ID to recharge (None for all users)",
|
||||||
)
|
)
|
||||||
@@ -109,10 +108,10 @@ class CreateCreditRechargeTask(BaseModel):
|
|||||||
scheduled_at: datetime
|
scheduled_at: datetime
|
||||||
timezone: str = "UTC"
|
timezone: str = "UTC"
|
||||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||||
cron_expression: Optional[str] = None
|
cron_expression: str | None = None
|
||||||
recurrence_count: Optional[int] = None
|
recurrence_count: int | None = None
|
||||||
expires_at: Optional[datetime] = None
|
expires_at: datetime | None = None
|
||||||
user_id: Optional[int] = None
|
user_id: int | None = None
|
||||||
|
|
||||||
def to_task_create(self) -> ScheduledTaskCreate:
|
def to_task_create(self) -> ScheduledTaskCreate:
|
||||||
"""Convert to generic task creation schema."""
|
"""Convert to generic task creation schema."""
|
||||||
@@ -137,9 +136,9 @@ class CreatePlaySoundTask(BaseModel):
|
|||||||
sound_id: int
|
sound_id: int
|
||||||
timezone: str = "UTC"
|
timezone: str = "UTC"
|
||||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||||
cron_expression: Optional[str] = None
|
cron_expression: str | None = None
|
||||||
recurrence_count: Optional[int] = None
|
recurrence_count: int | None = None
|
||||||
expires_at: Optional[datetime] = None
|
expires_at: datetime | None = None
|
||||||
|
|
||||||
def to_task_create(self) -> ScheduledTaskCreate:
|
def to_task_create(self) -> ScheduledTaskCreate:
|
||||||
"""Convert to generic task creation schema."""
|
"""Convert to generic task creation schema."""
|
||||||
@@ -166,9 +165,9 @@ class CreatePlayPlaylistTask(BaseModel):
|
|||||||
shuffle: bool = False
|
shuffle: bool = False
|
||||||
timezone: str = "UTC"
|
timezone: str = "UTC"
|
||||||
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||||
cron_expression: Optional[str] = None
|
cron_expression: str | None = None
|
||||||
recurrence_count: Optional[int] = None
|
recurrence_count: int | None = None
|
||||||
expires_at: Optional[datetime] = None
|
expires_at: datetime | None = None
|
||||||
|
|
||||||
def to_task_create(self) -> ScheduledTaskCreate:
|
def to_task_create(self) -> ScheduledTaskCreate:
|
||||||
"""Convert to generic task creation schema."""
|
"""Convert to generic task creation schema."""
|
||||||
@@ -186,4 +185,4 @@ class CreatePlayPlaylistTask(BaseModel):
|
|||||||
cron_expression=self.cron_expression,
|
cron_expression=self.cron_expression,
|
||||||
recurrence_count=self.recurrence_count,
|
recurrence_count=self.recurrence_count,
|
||||||
expires_at=self.expires_at,
|
expires_at=self.expires_at,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import pytz
|
import pytz
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
@@ -52,7 +52,7 @@ class SchedulerService:
|
|||||||
logger.info("Starting enhanced scheduler service...")
|
logger.info("Starting enhanced scheduler service...")
|
||||||
|
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
|
|
||||||
# Schedule system tasks initialization for after startup
|
# Schedule system tasks initialization for after startup
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._initialize_system_tasks,
|
self._initialize_system_tasks,
|
||||||
@@ -62,7 +62,7 @@ class SchedulerService:
|
|||||||
name="Initialize System Tasks",
|
name="Initialize System Tasks",
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Schedule periodic cleanup and maintenance
|
# Schedule periodic cleanup and maintenance
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._maintenance_job,
|
self._maintenance_job,
|
||||||
@@ -86,18 +86,18 @@ class SchedulerService:
|
|||||||
name: str,
|
name: str,
|
||||||
task_type: TaskType,
|
task_type: TaskType,
|
||||||
scheduled_at: datetime,
|
scheduled_at: datetime,
|
||||||
parameters: Optional[Dict[str, Any]] = None,
|
parameters: dict[str, Any] | None = None,
|
||||||
user_id: Optional[int] = None,
|
user_id: int | None = None,
|
||||||
timezone: str = "UTC",
|
timezone: str = "UTC",
|
||||||
recurrence_type: RecurrenceType = RecurrenceType.NONE,
|
recurrence_type: RecurrenceType = RecurrenceType.NONE,
|
||||||
cron_expression: Optional[str] = None,
|
cron_expression: str | None = None,
|
||||||
recurrence_count: Optional[int] = None,
|
recurrence_count: int | None = None,
|
||||||
expires_at: Optional[datetime] = None,
|
expires_at: datetime | None = None,
|
||||||
) -> ScheduledTask:
|
) -> ScheduledTask:
|
||||||
"""Create a new scheduled task."""
|
"""Create a new scheduled task."""
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
# Convert scheduled_at to UTC if it's in a different timezone
|
# Convert scheduled_at to UTC if it's in a different timezone
|
||||||
if timezone != "UTC":
|
if timezone != "UTC":
|
||||||
tz = pytz.timezone(timezone)
|
tz = pytz.timezone(timezone)
|
||||||
@@ -105,7 +105,7 @@ class SchedulerService:
|
|||||||
# Assume the datetime is in the specified timezone
|
# Assume the datetime is in the specified timezone
|
||||||
scheduled_at = tz.localize(scheduled_at)
|
scheduled_at = tz.localize(scheduled_at)
|
||||||
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
|
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
task_data = {
|
task_data = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"task_type": task_type,
|
"task_type": task_type,
|
||||||
@@ -118,59 +118,59 @@ class SchedulerService:
|
|||||||
"recurrence_count": recurrence_count,
|
"recurrence_count": recurrence_count,
|
||||||
"expires_at": expires_at,
|
"expires_at": expires_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
created_task = await repo.create(task_data)
|
created_task = await repo.create(task_data)
|
||||||
await self._schedule_apscheduler_job(created_task)
|
await self._schedule_apscheduler_job(created_task)
|
||||||
|
|
||||||
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
|
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
|
||||||
return created_task
|
return created_task
|
||||||
|
|
||||||
async def cancel_task(self, task_id: int) -> bool:
|
async def cancel_task(self, task_id: int) -> bool:
|
||||||
"""Cancel a scheduled task."""
|
"""Cancel a scheduled task."""
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
task = await repo.get_by_id(task_id)
|
task = await repo.get_by_id(task_id)
|
||||||
if not task:
|
if not task:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
task.status = TaskStatus.CANCELLED
|
task.status = TaskStatus.CANCELLED
|
||||||
task.is_active = False
|
task.is_active = False
|
||||||
await repo.update(task)
|
await repo.update(task)
|
||||||
|
|
||||||
# Remove from APScheduler
|
# Remove from APScheduler
|
||||||
try:
|
try:
|
||||||
self.scheduler.remove_job(str(task_id))
|
self.scheduler.remove_job(str(task_id))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Job might not exist in scheduler
|
pass # Job might not exist in scheduler
|
||||||
|
|
||||||
logger.info(f"Cancelled task: {task.name} ({task_id})")
|
logger.info(f"Cancelled task: {task.name} ({task_id})")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_user_tasks(
|
async def get_user_tasks(
|
||||||
self,
|
self,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
status: Optional[TaskStatus] = None,
|
status: TaskStatus | None = None,
|
||||||
task_type: Optional[TaskType] = None,
|
task_type: TaskType | None = None,
|
||||||
limit: Optional[int] = None,
|
limit: int | None = None,
|
||||||
offset: Optional[int] = None,
|
offset: int | None = None,
|
||||||
) -> List[ScheduledTask]:
|
) -> list[ScheduledTask]:
|
||||||
"""Get tasks for a specific user."""
|
"""Get tasks for a specific user."""
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
|
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
|
||||||
|
|
||||||
async def _initialize_system_tasks(self) -> None:
|
async def _initialize_system_tasks(self) -> None:
|
||||||
"""Initialize system tasks and load active tasks from database."""
|
"""Initialize system tasks and load active tasks from database."""
|
||||||
logger.info("Initializing system tasks...")
|
logger.info("Initializing system tasks...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create system tasks if they don't exist
|
# Create system tasks if they don't exist
|
||||||
await self._ensure_system_tasks()
|
await self._ensure_system_tasks()
|
||||||
|
|
||||||
# Load all active tasks from database
|
# Load all active tasks from database
|
||||||
await self._load_active_tasks()
|
await self._load_active_tasks()
|
||||||
|
|
||||||
logger.info("System tasks initialized successfully")
|
logger.info("System tasks initialized successfully")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to initialize system tasks")
|
logger.exception("Failed to initialize system tasks")
|
||||||
@@ -179,24 +179,24 @@ class SchedulerService:
|
|||||||
"""Ensure required system tasks exist."""
|
"""Ensure required system tasks exist."""
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
# Check if daily credit recharge task exists
|
# Check if daily credit recharge task exists
|
||||||
system_tasks = await repo.get_system_tasks(
|
system_tasks = await repo.get_system_tasks(
|
||||||
task_type=TaskType.CREDIT_RECHARGE
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
)
|
)
|
||||||
|
|
||||||
daily_recharge_exists = any(
|
daily_recharge_exists = any(
|
||||||
task.recurrence_type == RecurrenceType.DAILY
|
task.recurrence_type == RecurrenceType.DAILY
|
||||||
and task.is_active
|
and task.is_active
|
||||||
for task in system_tasks
|
for task in system_tasks
|
||||||
)
|
)
|
||||||
|
|
||||||
if not daily_recharge_exists:
|
if not daily_recharge_exists:
|
||||||
# Create daily credit recharge task
|
# Create daily credit recharge task
|
||||||
tomorrow_midnight = datetime.utcnow().replace(
|
tomorrow_midnight = datetime.utcnow().replace(
|
||||||
hour=0, minute=0, second=0, microsecond=0
|
hour=0, minute=0, second=0, microsecond=0,
|
||||||
) + timedelta(days=1)
|
) + timedelta(days=1)
|
||||||
|
|
||||||
task_data = {
|
task_data = {
|
||||||
"name": "Daily Credit Recharge",
|
"name": "Daily Credit Recharge",
|
||||||
"task_type": TaskType.CREDIT_RECHARGE,
|
"task_type": TaskType.CREDIT_RECHARGE,
|
||||||
@@ -204,41 +204,41 @@ class SchedulerService:
|
|||||||
"recurrence_type": RecurrenceType.DAILY,
|
"recurrence_type": RecurrenceType.DAILY,
|
||||||
"parameters": {},
|
"parameters": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
await repo.create(task_data)
|
await repo.create(task_data)
|
||||||
logger.info("Created system daily credit recharge task")
|
logger.info("Created system daily credit recharge task")
|
||||||
|
|
||||||
async def _load_active_tasks(self) -> None:
|
async def _load_active_tasks(self) -> None:
|
||||||
"""Load all active tasks from database into scheduler."""
|
"""Load all active tasks from database into scheduler."""
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
active_tasks = await repo.get_active_tasks()
|
active_tasks = await repo.get_active_tasks()
|
||||||
|
|
||||||
for task in active_tasks:
|
for task in active_tasks:
|
||||||
await self._schedule_apscheduler_job(task)
|
await self._schedule_apscheduler_job(task)
|
||||||
|
|
||||||
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
|
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
|
||||||
|
|
||||||
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
|
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
|
||||||
"""Schedule a task in APScheduler."""
|
"""Schedule a task in APScheduler."""
|
||||||
job_id = str(task.id)
|
job_id = str(task.id)
|
||||||
|
|
||||||
# Remove existing job if it exists
|
# Remove existing job if it exists
|
||||||
try:
|
try:
|
||||||
self.scheduler.remove_job(job_id)
|
self.scheduler.remove_job(job_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Don't schedule if task is not active or already completed/failed
|
# Don't schedule if task is not active or already completed/failed
|
||||||
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Create trigger based on recurrence type
|
# Create trigger based on recurrence type
|
||||||
trigger = self._create_trigger(task)
|
trigger = self._create_trigger(task)
|
||||||
if not trigger:
|
if not trigger:
|
||||||
logger.warning(f"Could not create trigger for task {task.id}")
|
logger.warning(f"Could not create trigger for task {task.id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Schedule the job
|
# Schedule the job
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._execute_task,
|
self._execute_task,
|
||||||
@@ -248,76 +248,76 @@ class SchedulerService:
|
|||||||
name=task.name,
|
name=task.name,
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Scheduled APScheduler job for task {task.id}")
|
logger.debug(f"Scheduled APScheduler job for task {task.id}")
|
||||||
|
|
||||||
def _create_trigger(self, task: ScheduledTask):
|
def _create_trigger(self, task: ScheduledTask):
|
||||||
"""Create APScheduler trigger based on task configuration."""
|
"""Create APScheduler trigger based on task configuration."""
|
||||||
tz = pytz.timezone(task.timezone)
|
tz = pytz.timezone(task.timezone)
|
||||||
|
|
||||||
if task.recurrence_type == RecurrenceType.NONE:
|
if task.recurrence_type == RecurrenceType.NONE:
|
||||||
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
|
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||||
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
||||||
|
|
||||||
elif task.recurrence_type == RecurrenceType.HOURLY:
|
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||||
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
|
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
if task.recurrence_type == RecurrenceType.DAILY:
|
||||||
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
|
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
if task.recurrence_type == RecurrenceType.WEEKLY:
|
||||||
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
|
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
if task.recurrence_type == RecurrenceType.MONTHLY:
|
||||||
# Use cron trigger for monthly (more reliable than interval)
|
# Use cron trigger for monthly (more reliable than interval)
|
||||||
scheduled_time = task.scheduled_at
|
scheduled_time = task.scheduled_at
|
||||||
return CronTrigger(
|
return CronTrigger(
|
||||||
day=scheduled_time.day,
|
day=scheduled_time.day,
|
||||||
hour=scheduled_time.hour,
|
hour=scheduled_time.hour,
|
||||||
minute=scheduled_time.minute,
|
minute=scheduled_time.minute,
|
||||||
timezone=tz
|
timezone=tz,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
if task.recurrence_type == RecurrenceType.YEARLY:
|
||||||
scheduled_time = task.scheduled_at
|
scheduled_time = task.scheduled_at
|
||||||
return CronTrigger(
|
return CronTrigger(
|
||||||
month=scheduled_time.month,
|
month=scheduled_time.month,
|
||||||
day=scheduled_time.day,
|
day=scheduled_time.day,
|
||||||
hour=scheduled_time.hour,
|
hour=scheduled_time.hour,
|
||||||
minute=scheduled_time.minute,
|
minute=scheduled_time.minute,
|
||||||
timezone=tz
|
timezone=tz,
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _execute_task(self, task_id: int) -> None:
|
async def _execute_task(self, task_id: int) -> None:
|
||||||
"""Execute a scheduled task."""
|
"""Execute a scheduled task."""
|
||||||
task_id_str = str(task_id)
|
task_id_str = str(task_id)
|
||||||
|
|
||||||
# Prevent concurrent execution of the same task
|
# Prevent concurrent execution of the same task
|
||||||
if task_id_str in self._running_tasks:
|
if task_id_str in self._running_tasks:
|
||||||
logger.warning(f"Task {task_id} is already running, skipping execution")
|
logger.warning(f"Task {task_id} is already running, skipping execution")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._running_tasks.add(task_id_str)
|
self._running_tasks.add(task_id_str)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
# Get fresh task data
|
# Get fresh task data
|
||||||
task = await repo.get_by_id(task_id)
|
task = await repo.get_by_id(task_id)
|
||||||
if not task:
|
if not task:
|
||||||
logger.warning(f"Task {task_id} not found")
|
logger.warning(f"Task {task_id} not found")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if task is still active and pending
|
# Check if task is still active and pending
|
||||||
if not task.is_active or task.status != TaskStatus.PENDING:
|
if not task.is_active or task.status != TaskStatus.PENDING:
|
||||||
logger.info(f"Task {task_id} is not active or not pending, skipping")
|
logger.info(f"Task {task_id} is not active or not pending, skipping")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if task has expired
|
# Check if task has expired
|
||||||
if task.is_expired():
|
if task.is_expired():
|
||||||
logger.info(f"Task {task_id} has expired, marking as cancelled")
|
logger.info(f"Task {task_id} has expired, marking as cancelled")
|
||||||
@@ -325,78 +325,78 @@ class SchedulerService:
|
|||||||
task.is_active = False
|
task.is_active = False
|
||||||
await repo.update(task)
|
await repo.update(task)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Mark task as running
|
# Mark task as running
|
||||||
await repo.mark_as_running(task)
|
await repo.mark_as_running(task)
|
||||||
|
|
||||||
# Execute the task
|
# Execute the task
|
||||||
try:
|
try:
|
||||||
handler_registry = TaskHandlerRegistry(
|
handler_registry = TaskHandlerRegistry(
|
||||||
session, self.db_session_factory, self.credit_service, self.player_service
|
session, self.db_session_factory, self.credit_service, self.player_service,
|
||||||
)
|
)
|
||||||
await handler_registry.execute_task(task)
|
await handler_registry.execute_task(task)
|
||||||
|
|
||||||
# Calculate next execution time for recurring tasks
|
# Calculate next execution time for recurring tasks
|
||||||
next_execution_at = None
|
next_execution_at = None
|
||||||
if task.should_repeat():
|
if task.should_repeat():
|
||||||
next_execution_at = self._calculate_next_execution(task)
|
next_execution_at = self._calculate_next_execution(task)
|
||||||
|
|
||||||
# Mark as completed
|
# Mark as completed
|
||||||
await repo.mark_as_completed(task, next_execution_at)
|
await repo.mark_as_completed(task, next_execution_at)
|
||||||
|
|
||||||
# Reschedule if recurring
|
# Reschedule if recurring
|
||||||
if next_execution_at and task.should_repeat():
|
if next_execution_at and task.should_repeat():
|
||||||
# Refresh task to get updated data
|
# Refresh task to get updated data
|
||||||
await session.refresh(task)
|
await session.refresh(task)
|
||||||
await self._schedule_apscheduler_job(task)
|
await self._schedule_apscheduler_job(task)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await repo.mark_as_failed(task, str(e))
|
await repo.mark_as_failed(task, str(e))
|
||||||
logger.exception(f"Task {task_id} execution failed: {str(e)}")
|
logger.exception(f"Task {task_id} execution failed: {e!s}")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self._running_tasks.discard(task_id_str)
|
self._running_tasks.discard(task_id_str)
|
||||||
|
|
||||||
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
|
def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
|
||||||
"""Calculate the next execution time for a recurring task."""
|
"""Calculate the next execution time for a recurring task."""
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
|
||||||
if task.recurrence_type == RecurrenceType.HOURLY:
|
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||||
return now + timedelta(hours=1)
|
return now + timedelta(hours=1)
|
||||||
elif task.recurrence_type == RecurrenceType.DAILY:
|
if task.recurrence_type == RecurrenceType.DAILY:
|
||||||
return now + timedelta(days=1)
|
return now + timedelta(days=1)
|
||||||
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
if task.recurrence_type == RecurrenceType.WEEKLY:
|
||||||
return now + timedelta(weeks=1)
|
return now + timedelta(weeks=1)
|
||||||
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
if task.recurrence_type == RecurrenceType.MONTHLY:
|
||||||
# Add approximately one month
|
# Add approximately one month
|
||||||
return now + timedelta(days=30)
|
return now + timedelta(days=30)
|
||||||
elif task.recurrence_type == RecurrenceType.YEARLY:
|
if task.recurrence_type == RecurrenceType.YEARLY:
|
||||||
return now + timedelta(days=365)
|
return now + timedelta(days=365)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _maintenance_job(self) -> None:
|
async def _maintenance_job(self) -> None:
|
||||||
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
|
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
|
||||||
try:
|
try:
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
# Handle expired tasks
|
# Handle expired tasks
|
||||||
expired_tasks = await repo.get_expired_tasks()
|
expired_tasks = await repo.get_expired_tasks()
|
||||||
for task in expired_tasks:
|
for task in expired_tasks:
|
||||||
task.status = TaskStatus.CANCELLED
|
task.status = TaskStatus.CANCELLED
|
||||||
task.is_active = False
|
task.is_active = False
|
||||||
await repo.update(task)
|
await repo.update(task)
|
||||||
|
|
||||||
# Remove from scheduler
|
# Remove from scheduler
|
||||||
try:
|
try:
|
||||||
self.scheduler.remove_job(str(task.id))
|
self.scheduler.remove_job(str(task.id))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if expired_tasks:
|
if expired_tasks:
|
||||||
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
|
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
|
||||||
|
|
||||||
# Handle any missed recurring tasks
|
# Handle any missed recurring tasks
|
||||||
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
|
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
|
||||||
for task in due_recurring:
|
for task in due_recurring:
|
||||||
@@ -405,9 +405,9 @@ class SchedulerService:
|
|||||||
task.scheduled_at = task.next_execution_at or datetime.utcnow()
|
task.scheduled_at = task.next_execution_at or datetime.utcnow()
|
||||||
await repo.update(task)
|
await repo.update(task)
|
||||||
await self._schedule_apscheduler_job(task)
|
await self._schedule_apscheduler_job(task)
|
||||||
|
|
||||||
if due_recurring:
|
if due_recurring:
|
||||||
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
|
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Maintenance job failed")
|
logger.exception("Maintenance job failed")
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Task execution handlers for different task types."""
|
"""Task execution handlers for different task types."""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
|
|||||||
class TaskExecutionError(Exception):
|
class TaskExecutionError(Exception):
|
||||||
"""Exception raised when task execution fails."""
|
"""Exception raised when task execution fails."""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TaskHandlerRegistry:
|
class TaskHandlerRegistry:
|
||||||
@@ -58,8 +56,8 @@ class TaskHandlerRegistry:
|
|||||||
await handler(task)
|
await handler(task)
|
||||||
logger.info(f"Task {task.id} executed successfully")
|
logger.info(f"Task {task.id} executed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Task {task.id} execution failed: {str(e)}")
|
logger.exception(f"Task {task.id} execution failed: {e!s}")
|
||||||
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
|
raise TaskExecutionError(f"Task execution failed: {e!s}") from e
|
||||||
|
|
||||||
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
||||||
"""Handle credit recharge task."""
|
"""Handle credit recharge task."""
|
||||||
@@ -72,7 +70,7 @@ class TaskHandlerRegistry:
|
|||||||
user_id_int = int(user_id)
|
user_id_int = int(user_id)
|
||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e
|
raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e
|
||||||
|
|
||||||
stats = await self.credit_service.recharge_user_credits(user_id_int)
|
stats = await self.credit_service.recharge_user_credits(user_id_int)
|
||||||
logger.info(f"Recharged credits for user {user_id}: {stats}")
|
logger.info(f"Recharged credits for user {user_id}: {stats}")
|
||||||
else:
|
else:
|
||||||
@@ -105,7 +103,7 @@ class TaskHandlerRegistry:
|
|||||||
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
|
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Convert HTTP exceptions or credit errors to task execution errors
|
# Convert HTTP exceptions or credit errors to task execution errors
|
||||||
raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e
|
raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
|
||||||
else:
|
else:
|
||||||
# System task: play without credit deduction
|
# System task: play without credit deduction
|
||||||
sound = await self.sound_repository.get_by_id(sound_id_int)
|
sound = await self.sound_repository.get_by_id(sound_id_int)
|
||||||
@@ -116,10 +114,10 @@ class TaskHandlerRegistry:
|
|||||||
|
|
||||||
vlc_service = VLCPlayerService(self.db_session_factory)
|
vlc_service = VLCPlayerService(self.db_session_factory)
|
||||||
success = await vlc_service.play_sound(sound)
|
success = await vlc_service.play_sound(sound)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
raise TaskExecutionError(f"Failed to play sound {sound.filename}")
|
raise TaskExecutionError(f"Failed to play sound {sound.filename}")
|
||||||
|
|
||||||
logger.info(f"Played sound {sound.filename} via scheduled system task")
|
logger.info(f"Played sound {sound.filename} via scheduled system task")
|
||||||
|
|
||||||
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
|
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
|
||||||
@@ -157,4 +155,4 @@ class TaskHandlerRegistry:
|
|||||||
# Start playing
|
# Start playing
|
||||||
await self.player_service.play()
|
await self.player_service.play()
|
||||||
|
|
||||||
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
|
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
|
||||||
|
|||||||
@@ -238,13 +238,13 @@ class VLCPlayerService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Recording play count for sound %s", sound_id)
|
logger.info("Recording play count for sound %s", sound_id)
|
||||||
|
|
||||||
# Initialize variables for WebSocket event
|
# Initialize variables for WebSocket event
|
||||||
old_count = 0
|
old_count = 0
|
||||||
sound = None
|
sound = None
|
||||||
admin_user_id = None
|
admin_user_id = None
|
||||||
admin_user_name = None
|
admin_user_name = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.db_session_factory() as session:
|
async with self.db_session_factory() as session:
|
||||||
sound_repo = SoundRepository(session)
|
sound_repo = SoundRepository(session)
|
||||||
|
|||||||
@@ -7,15 +7,16 @@ from datetime import datetime
|
|||||||
from app.core.database import get_session_factory
|
from app.core.database import get_session_factory
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
|
||||||
async def check_tasks():
|
async def check_tasks():
|
||||||
session_factory = get_session_factory()
|
session_factory = get_session_factory()
|
||||||
|
|
||||||
async with session_factory() as session:
|
async with session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
# Get all tasks
|
# Get all tasks
|
||||||
all_tasks = await repo.get_all(limit=20)
|
all_tasks = await repo.get_all(limit=20)
|
||||||
|
|
||||||
print("All tasks in database:")
|
print("All tasks in database:")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
for task in all_tasks:
|
for task in all_tasks:
|
||||||
@@ -32,14 +33,14 @@ async def check_tasks():
|
|||||||
print(f"Error: {task.error_message}")
|
print(f"Error: {task.error_message}")
|
||||||
print(f"Parameters: {task.parameters}")
|
print(f"Parameters: {task.parameters}")
|
||||||
print("-" * 40)
|
print("-" * 40)
|
||||||
|
|
||||||
# Check specifically for pending tasks
|
# Check specifically for pending tasks
|
||||||
print(f"\nCurrent time: {datetime.utcnow()}")
|
print(f"\nCurrent time: {datetime.utcnow()}")
|
||||||
print("\nPending tasks:")
|
print("\nPending tasks:")
|
||||||
from app.models.scheduled_task import TaskStatus
|
from app.models.scheduled_task import TaskStatus
|
||||||
pending_tasks = await repo.get_all(limit=10)
|
pending_tasks = await repo.get_all(limit=10)
|
||||||
pending_tasks = [t for t in pending_tasks if t.status == TaskStatus.PENDING and t.is_active]
|
pending_tasks = [t for t in pending_tasks if t.status == TaskStatus.PENDING and t.is_active]
|
||||||
|
|
||||||
if not pending_tasks:
|
if not pending_tasks:
|
||||||
print("No pending tasks found")
|
print("No pending tasks found")
|
||||||
else:
|
else:
|
||||||
@@ -48,4 +49,4 @@ async def check_tasks():
|
|||||||
print(f"- {task.name} (ID: {task.id}): scheduled for {task.scheduled_at} (in {time_diff})")
|
print(f"- {task.name} (ID: {task.id}): scheduled for {task.scheduled_at} (in {time_diff})")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(check_tasks())
|
asyncio.run(check_tasks())
|
||||||
|
|||||||
@@ -5,18 +5,19 @@ import asyncio
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from app.core.database import get_session_factory
|
from app.core.database import get_session_factory
|
||||||
|
from app.models.scheduled_task import RecurrenceType, TaskType
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
from app.models.scheduled_task import TaskType, RecurrenceType
|
|
||||||
|
|
||||||
async def create_future_task():
|
async def create_future_task():
|
||||||
session_factory = get_session_factory()
|
session_factory = get_session_factory()
|
||||||
|
|
||||||
# Create a task for 2 minutes from now
|
# Create a task for 2 minutes from now
|
||||||
future_time = datetime.utcnow() + timedelta(minutes=2)
|
future_time = datetime.utcnow() + timedelta(minutes=2)
|
||||||
|
|
||||||
async with session_factory() as session:
|
async with session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
task_data = {
|
task_data = {
|
||||||
"name": f"Future Task {future_time.strftime('%H:%M:%S')}",
|
"name": f"Future Task {future_time.strftime('%H:%M:%S')}",
|
||||||
"task_type": TaskType.PLAY_SOUND,
|
"task_type": TaskType.PLAY_SOUND,
|
||||||
@@ -26,11 +27,11 @@ async def create_future_task():
|
|||||||
"user_id": 1,
|
"user_id": 1,
|
||||||
"recurrence_type": RecurrenceType.NONE,
|
"recurrence_type": RecurrenceType.NONE,
|
||||||
}
|
}
|
||||||
|
|
||||||
task = await repo.create(task_data)
|
task = await repo.create(task_data)
|
||||||
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
|
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
|
||||||
print(f"Current time: {datetime.utcnow()}")
|
print(f"Current time: {datetime.utcnow()}")
|
||||||
print(f"Task will execute in: {future_time - datetime.utcnow()}")
|
print(f"Task will execute in: {future_time - datetime.utcnow()}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(create_future_task())
|
asyncio.run(create_future_task())
|
||||||
|
|||||||
@@ -4,21 +4,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from app.core.database import get_session_factory
|
|
||||||
from app.main import get_global_scheduler_service
|
from app.main import get_global_scheduler_service
|
||||||
from app.models.scheduled_task import TaskType, RecurrenceType
|
from app.models.scheduled_task import RecurrenceType, TaskType
|
||||||
|
|
||||||
|
|
||||||
async def test_api_task_creation():
|
async def test_api_task_creation():
|
||||||
"""Test creating a task through the scheduler service (simulates API call)."""
|
"""Test creating a task through the scheduler service (simulates API call)."""
|
||||||
try:
|
try:
|
||||||
scheduler_service = get_global_scheduler_service()
|
scheduler_service = get_global_scheduler_service()
|
||||||
|
|
||||||
# Create a task for 2 minutes from now
|
# Create a task for 2 minutes from now
|
||||||
future_time = datetime.utcnow() + timedelta(minutes=2)
|
future_time = datetime.utcnow() + timedelta(minutes=2)
|
||||||
|
|
||||||
print(f"Creating task scheduled for: {future_time}")
|
print(f"Creating task scheduled for: {future_time}")
|
||||||
print(f"Current time: {datetime.utcnow()}")
|
print(f"Current time: {datetime.utcnow()}")
|
||||||
|
|
||||||
task = await scheduler_service.create_task(
|
task = await scheduler_service.create_task(
|
||||||
name=f"API Test Task {future_time.strftime('%H:%M:%S')}",
|
name=f"API Test Task {future_time.strftime('%H:%M:%S')}",
|
||||||
task_type=TaskType.PLAY_SOUND,
|
task_type=TaskType.PLAY_SOUND,
|
||||||
@@ -28,13 +28,13 @@ async def test_api_task_creation():
|
|||||||
timezone="UTC",
|
timezone="UTC",
|
||||||
recurrence_type=RecurrenceType.NONE,
|
recurrence_type=RecurrenceType.NONE,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Created task: {task.name} (ID: {task.id})")
|
print(f"Created task: {task.name} (ID: {task.id})")
|
||||||
print(f"Task will execute in: {future_time - datetime.utcnow()}")
|
print(f"Task will execute in: {future_time - datetime.utcnow()}")
|
||||||
print("Task should be automatically scheduled in APScheduler!")
|
print("Task should be automatically scheduled in APScheduler!")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(test_api_task_creation())
|
asyncio.run(test_api_task_creation())
|
||||||
|
|||||||
11
test_task.py
11
test_task.py
@@ -5,15 +5,16 @@ import asyncio
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.core.database import get_session_factory
|
from app.core.database import get_session_factory
|
||||||
|
from app.models.scheduled_task import RecurrenceType, TaskType
|
||||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
from app.models.scheduled_task import TaskType, RecurrenceType
|
|
||||||
|
|
||||||
async def create_test_task():
|
async def create_test_task():
|
||||||
session_factory = get_session_factory()
|
session_factory = get_session_factory()
|
||||||
|
|
||||||
async with session_factory() as session:
|
async with session_factory() as session:
|
||||||
repo = ScheduledTaskRepository(session)
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
task_data = {
|
task_data = {
|
||||||
"name": "Live Test Task",
|
"name": "Live Test Task",
|
||||||
"task_type": TaskType.PLAY_SOUND,
|
"task_type": TaskType.PLAY_SOUND,
|
||||||
@@ -23,9 +24,9 @@ async def create_test_task():
|
|||||||
"user_id": 1,
|
"user_id": 1,
|
||||||
"recurrence_type": RecurrenceType.NONE,
|
"recurrence_type": RecurrenceType.NONE,
|
||||||
}
|
}
|
||||||
|
|
||||||
task = await repo.create(task_data)
|
task = await repo.create(task_data)
|
||||||
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
|
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(create_test_task())
|
asyncio.run(create_test_task())
|
||||||
|
|||||||
@@ -351,11 +351,11 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_user_id(test_user: User):
|
def test_user_id(test_user: User):
|
||||||
"""Get test user ID."""
|
"""Get test user ID."""
|
||||||
return test_user.id
|
return test_user.id
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_sound_id():
|
def test_sound_id():
|
||||||
"""Create a test sound ID."""
|
"""Create a test sound ID."""
|
||||||
import uuid
|
import uuid
|
||||||
@@ -364,7 +364,7 @@ def test_sound_id():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_playlist_id():
|
def test_playlist_id():
|
||||||
"""Create a test playlist ID."""
|
"""Create a test playlist ID."""
|
||||||
import uuid
|
import uuid
|
||||||
return uuid.uuid4()
|
return uuid.uuid4()
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.models.scheduled_task import (
|
from app.models.scheduled_task import (
|
||||||
RecurrenceType,
|
RecurrenceType,
|
||||||
ScheduledTask,
|
ScheduledTask,
|
||||||
@@ -217,4 +215,4 @@ class TestScheduledTaskModel:
|
|||||||
assert RecurrenceType.WEEKLY == "weekly"
|
assert RecurrenceType.WEEKLY == "weekly"
|
||||||
assert RecurrenceType.MONTHLY == "monthly"
|
assert RecurrenceType.MONTHLY == "monthly"
|
||||||
assert RecurrenceType.YEARLY == "yearly"
|
assert RecurrenceType.YEARLY == "yearly"
|
||||||
assert RecurrenceType.CRON == "cron"
|
assert RecurrenceType.CRON == "cron"
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
@@ -491,4 +490,4 @@ class TestScheduledTaskRepository:
|
|||||||
updated_task = await repository.get_by_id(sample_task.id)
|
updated_task = await repository.get_by_id(sample_task.id)
|
||||||
assert updated_task.status == TaskStatus.FAILED
|
assert updated_task.status == TaskStatus.FAILED
|
||||||
# Non-recurring tasks should be deactivated on failure
|
# Non-recurring tasks should be deactivated on failure
|
||||||
assert updated_task.is_active is False
|
assert updated_task.is_active is False
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class TestSchedulerService:
|
|||||||
sample_task_data: dict,
|
sample_task_data: dict,
|
||||||
):
|
):
|
||||||
"""Test creating a scheduled task."""
|
"""Test creating a scheduled task."""
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
assert task.id is not None
|
assert task.id is not None
|
||||||
@@ -68,7 +68,7 @@ class TestSchedulerService:
|
|||||||
test_user_id: uuid.UUID,
|
test_user_id: uuid.UUID,
|
||||||
):
|
):
|
||||||
"""Test creating a user task."""
|
"""Test creating a user task."""
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(
|
task = await scheduler_service.create_task(
|
||||||
user_id=test_user_id,
|
user_id=test_user_id,
|
||||||
**sample_task_data,
|
**sample_task_data,
|
||||||
@@ -83,7 +83,7 @@ class TestSchedulerService:
|
|||||||
sample_task_data: dict,
|
sample_task_data: dict,
|
||||||
):
|
):
|
||||||
"""Test creating a system task."""
|
"""Test creating a system task."""
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
assert task.user_id is None
|
assert task.user_id is None
|
||||||
@@ -95,7 +95,7 @@ class TestSchedulerService:
|
|||||||
sample_task_data: dict,
|
sample_task_data: dict,
|
||||||
):
|
):
|
||||||
"""Test creating a recurring task."""
|
"""Test creating a recurring task."""
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(
|
task = await scheduler_service.create_task(
|
||||||
recurrence_type=RecurrenceType.DAILY,
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
recurrence_count=5,
|
recurrence_count=5,
|
||||||
@@ -114,11 +114,11 @@ class TestSchedulerService:
|
|||||||
"""Test creating task with timezone conversion."""
|
"""Test creating task with timezone conversion."""
|
||||||
# Use a specific datetime for testing
|
# Use a specific datetime for testing
|
||||||
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY
|
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY
|
||||||
|
|
||||||
sample_task_data["scheduled_at"] = ny_time
|
sample_task_data["scheduled_at"] = ny_time
|
||||||
sample_task_data["timezone"] = "America/New_York"
|
sample_task_data["timezone"] = "America/New_York"
|
||||||
|
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
# The scheduled_at should be converted to UTC
|
# The scheduled_at should be converted to UTC
|
||||||
@@ -134,11 +134,11 @@ class TestSchedulerService:
|
|||||||
):
|
):
|
||||||
"""Test cancelling a task."""
|
"""Test cancelling a task."""
|
||||||
# Create a task first
|
# Create a task first
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
# Mock the scheduler remove_job method
|
# Mock the scheduler remove_job method
|
||||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
|
||||||
result = await scheduler_service.cancel_task(task.id)
|
result = await scheduler_service.cancel_task(task.id)
|
||||||
|
|
||||||
assert result is True
|
assert result is True
|
||||||
@@ -167,7 +167,7 @@ class TestSchedulerService:
|
|||||||
test_user_id: uuid.UUID,
|
test_user_id: uuid.UUID,
|
||||||
):
|
):
|
||||||
"""Test getting user tasks."""
|
"""Test getting user tasks."""
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
# Create user task
|
# Create user task
|
||||||
await scheduler_service.create_task(
|
await scheduler_service.create_task(
|
||||||
user_id=test_user_id,
|
user_id=test_user_id,
|
||||||
@@ -188,12 +188,12 @@ class TestSchedulerService:
|
|||||||
):
|
):
|
||||||
"""Test ensuring system tasks exist."""
|
"""Test ensuring system tasks exist."""
|
||||||
# Mock the repository to return no existing tasks
|
# Mock the repository to return no existing tasks
|
||||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
|
||||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
|
||||||
mock_get.return_value = []
|
mock_get.return_value = []
|
||||||
|
|
||||||
await scheduler_service._ensure_system_tasks()
|
await scheduler_service._ensure_system_tasks()
|
||||||
|
|
||||||
# Should create daily credit recharge task
|
# Should create daily credit recharge task
|
||||||
mock_create.assert_called_once()
|
mock_create.assert_called_once()
|
||||||
created_task = mock_create.call_args[0][0]
|
created_task = mock_create.call_args[0][0]
|
||||||
@@ -213,13 +213,13 @@ class TestSchedulerService:
|
|||||||
recurrence_type=RecurrenceType.DAILY,
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
|
||||||
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
|
||||||
mock_get.return_value = [existing_task]
|
mock_get.return_value = [existing_task]
|
||||||
|
|
||||||
await scheduler_service._ensure_system_tasks()
|
await scheduler_service._ensure_system_tasks()
|
||||||
|
|
||||||
# Should not create new task
|
# Should not create new task
|
||||||
mock_create.assert_not_called()
|
mock_create.assert_not_called()
|
||||||
|
|
||||||
@@ -294,7 +294,7 @@ class TestSchedulerService:
|
|||||||
):
|
):
|
||||||
"""Test calculating next execution time."""
|
"""Test calculating next execution time."""
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
|
||||||
# Test different recurrence types
|
# Test different recurrence types
|
||||||
test_cases = [
|
test_cases = [
|
||||||
(RecurrenceType.HOURLY, timedelta(hours=1)),
|
(RecurrenceType.HOURLY, timedelta(hours=1)),
|
||||||
@@ -312,7 +312,7 @@ class TestSchedulerService:
|
|||||||
recurrence_type=recurrence_type,
|
recurrence_type=recurrence_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch('app.services.scheduler.datetime') as mock_datetime:
|
with patch("app.services.scheduler.datetime") as mock_datetime:
|
||||||
mock_datetime.utcnow.return_value = now
|
mock_datetime.utcnow.return_value = now
|
||||||
next_execution = scheduler_service._calculate_next_execution(task)
|
next_execution = scheduler_service._calculate_next_execution(task)
|
||||||
|
|
||||||
@@ -335,7 +335,7 @@ class TestSchedulerService:
|
|||||||
next_execution = scheduler_service._calculate_next_execution(task)
|
next_execution = scheduler_service._calculate_next_execution(task)
|
||||||
assert next_execution is None
|
assert next_execution is None
|
||||||
|
|
||||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
@patch("app.services.task_handlers.TaskHandlerRegistry")
|
||||||
async def test_execute_task_success(
|
async def test_execute_task_success(
|
||||||
self,
|
self,
|
||||||
mock_handler_class,
|
mock_handler_class,
|
||||||
@@ -344,7 +344,7 @@ class TestSchedulerService:
|
|||||||
):
|
):
|
||||||
"""Test successful task execution."""
|
"""Test successful task execution."""
|
||||||
# Create task
|
# Create task
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
# Mock handler registry
|
# Mock handler registry
|
||||||
@@ -365,7 +365,7 @@ class TestSchedulerService:
|
|||||||
assert updated_task.status == TaskStatus.COMPLETED
|
assert updated_task.status == TaskStatus.COMPLETED
|
||||||
assert updated_task.executions_count == 1
|
assert updated_task.executions_count == 1
|
||||||
|
|
||||||
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
@patch("app.services.task_handlers.TaskHandlerRegistry")
|
||||||
async def test_execute_task_failure(
|
async def test_execute_task_failure(
|
||||||
self,
|
self,
|
||||||
mock_handler_class,
|
mock_handler_class,
|
||||||
@@ -374,7 +374,7 @@ class TestSchedulerService:
|
|||||||
):
|
):
|
||||||
"""Test task execution failure."""
|
"""Test task execution failure."""
|
||||||
# Create task
|
# Create task
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
# Mock handler to raise exception
|
# Mock handler to raise exception
|
||||||
@@ -409,8 +409,8 @@ class TestSchedulerService:
|
|||||||
"""Test executing expired task."""
|
"""Test executing expired task."""
|
||||||
# Create expired task
|
# Create expired task
|
||||||
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
|
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
|
||||||
|
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
# Execute task
|
# Execute task
|
||||||
@@ -430,20 +430,20 @@ class TestSchedulerService:
|
|||||||
sample_task_data: dict,
|
sample_task_data: dict,
|
||||||
):
|
):
|
||||||
"""Test prevention of concurrent task execution."""
|
"""Test prevention of concurrent task execution."""
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
with patch.object(scheduler_service, "_schedule_apscheduler_job"):
|
||||||
task = await scheduler_service.create_task(**sample_task_data)
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
# Add task to running set
|
# Add task to running set
|
||||||
scheduler_service._running_tasks.add(str(task.id))
|
scheduler_service._running_tasks.add(str(task.id))
|
||||||
|
|
||||||
# Try to execute - should return without doing anything
|
# Try to execute - should return without doing anything
|
||||||
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class:
|
with patch("app.services.task_handlers.TaskHandlerRegistry") as mock_handler_class:
|
||||||
await scheduler_service._execute_task(task.id)
|
await scheduler_service._execute_task(task.id)
|
||||||
|
|
||||||
# Handler should not be called
|
# Handler should not be called
|
||||||
mock_handler_class.assert_not_called()
|
mock_handler_class.assert_not_called()
|
||||||
|
|
||||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
|
||||||
async def test_maintenance_job_expired_tasks(
|
async def test_maintenance_job_expired_tasks(
|
||||||
self,
|
self,
|
||||||
mock_repo_class,
|
mock_repo_class,
|
||||||
@@ -453,22 +453,22 @@ class TestSchedulerService:
|
|||||||
# Mock expired task
|
# Mock expired task
|
||||||
expired_task = MagicMock()
|
expired_task = MagicMock()
|
||||||
expired_task.id = uuid.uuid4()
|
expired_task.id = uuid.uuid4()
|
||||||
|
|
||||||
mock_repo = AsyncMock()
|
mock_repo = AsyncMock()
|
||||||
mock_repo.get_expired_tasks.return_value = [expired_task]
|
mock_repo.get_expired_tasks.return_value = [expired_task]
|
||||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
|
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
|
||||||
mock_repo_class.return_value = mock_repo
|
mock_repo_class.return_value = mock_repo
|
||||||
|
|
||||||
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
|
||||||
await scheduler_service._maintenance_job()
|
await scheduler_service._maintenance_job()
|
||||||
|
|
||||||
# Should mark as cancelled and remove from scheduler
|
# Should mark as cancelled and remove from scheduler
|
||||||
assert expired_task.status == TaskStatus.CANCELLED
|
assert expired_task.status == TaskStatus.CANCELLED
|
||||||
assert expired_task.is_active is False
|
assert expired_task.is_active is False
|
||||||
mock_repo.update.assert_called_with(expired_task)
|
mock_repo.update.assert_called_with(expired_task)
|
||||||
mock_remove.assert_called_once_with(str(expired_task.id))
|
mock_remove.assert_called_once_with(str(expired_task.id))
|
||||||
|
|
||||||
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
@patch("app.repositories.scheduled_task.ScheduledTaskRepository")
|
||||||
async def test_maintenance_job_due_recurring_tasks(
|
async def test_maintenance_job_due_recurring_tasks(
|
||||||
self,
|
self,
|
||||||
mock_repo_class,
|
mock_repo_class,
|
||||||
@@ -479,17 +479,17 @@ class TestSchedulerService:
|
|||||||
due_task = MagicMock()
|
due_task = MagicMock()
|
||||||
due_task.should_repeat.return_value = True
|
due_task.should_repeat.return_value = True
|
||||||
due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5)
|
due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5)
|
||||||
|
|
||||||
mock_repo = AsyncMock()
|
mock_repo = AsyncMock()
|
||||||
mock_repo.get_expired_tasks.return_value = []
|
mock_repo.get_expired_tasks.return_value = []
|
||||||
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
|
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
|
||||||
mock_repo_class.return_value = mock_repo
|
mock_repo_class.return_value = mock_repo
|
||||||
|
|
||||||
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
|
||||||
await scheduler_service._maintenance_job()
|
await scheduler_service._maintenance_job()
|
||||||
|
|
||||||
# Should reset to pending and reschedule
|
# Should reset to pending and reschedule
|
||||||
assert due_task.status == TaskStatus.PENDING
|
assert due_task.status == TaskStatus.PENDING
|
||||||
assert due_task.scheduled_at == due_task.next_execution_at
|
assert due_task.scheduled_at == due_task.next_execution_at
|
||||||
mock_repo.update.assert_called_with(due_task)
|
mock_repo.update.assert_called_with(due_task)
|
||||||
mock_schedule.assert_called_once_with(due_task)
|
mock_schedule.assert_called_once_with(due_task)
|
||||||
|
|||||||
@@ -133,8 +133,8 @@ class TestTaskHandlerRegistry:
|
|||||||
mock_sound.id = test_sound_id
|
mock_sound.id = test_sound_id
|
||||||
mock_sound.filename = "test_sound.mp3"
|
mock_sound.filename = "test_sound.mp3"
|
||||||
|
|
||||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
|
||||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
|
||||||
mock_vlc_service = AsyncMock()
|
mock_vlc_service = AsyncMock()
|
||||||
mock_vlc_class.return_value = mock_vlc_service
|
mock_vlc_class.return_value = mock_vlc_service
|
||||||
|
|
||||||
@@ -186,7 +186,7 @@ class TestTaskHandlerRegistry:
|
|||||||
parameters={"sound_id": str(test_sound_id)},
|
parameters={"sound_id": str(test_sound_id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None):
|
with patch.object(task_registry.sound_repository, "get_by_id", return_value=None):
|
||||||
with pytest.raises(TaskExecutionError, match="Sound not found"):
|
with pytest.raises(TaskExecutionError, match="Sound not found"):
|
||||||
await task_registry.execute_task(task)
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
@@ -206,8 +206,8 @@ class TestTaskHandlerRegistry:
|
|||||||
mock_sound = MagicMock()
|
mock_sound = MagicMock()
|
||||||
mock_sound.filename = "test_sound.mp3"
|
mock_sound.filename = "test_sound.mp3"
|
||||||
|
|
||||||
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
|
||||||
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
|
||||||
mock_vlc_service = AsyncMock()
|
mock_vlc_service = AsyncMock()
|
||||||
mock_vlc_class.return_value = mock_vlc_service
|
mock_vlc_class.return_value = mock_vlc_service
|
||||||
|
|
||||||
@@ -238,7 +238,7 @@ class TestTaskHandlerRegistry:
|
|||||||
mock_playlist.id = test_playlist_id
|
mock_playlist.id = test_playlist_id
|
||||||
mock_playlist.name = "Test Playlist"
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||||
await task_registry.execute_task(task)
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
|
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
|
||||||
@@ -264,7 +264,7 @@ class TestTaskHandlerRegistry:
|
|||||||
mock_playlist = MagicMock()
|
mock_playlist = MagicMock()
|
||||||
mock_playlist.name = "Test Playlist"
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||||
await task_registry.execute_task(task)
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
# Should use default values
|
# Should use default values
|
||||||
@@ -314,7 +314,7 @@ class TestTaskHandlerRegistry:
|
|||||||
parameters={"playlist_id": str(test_playlist_id)},
|
parameters={"playlist_id": str(test_playlist_id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None):
|
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=None):
|
||||||
with pytest.raises(TaskExecutionError, match="Playlist not found"):
|
with pytest.raises(TaskExecutionError, match="Playlist not found"):
|
||||||
await task_registry.execute_task(task)
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
@@ -327,7 +327,7 @@ class TestTaskHandlerRegistry:
|
|||||||
"""Test play playlist task with various valid play modes."""
|
"""Test play playlist task with various valid play modes."""
|
||||||
mock_playlist = MagicMock()
|
mock_playlist = MagicMock()
|
||||||
mock_playlist.name = "Test Playlist"
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
valid_modes = ["continuous", "loop", "loop_one", "random", "single"]
|
valid_modes = ["continuous", "loop", "loop_one", "random", "single"]
|
||||||
|
|
||||||
for mode in valid_modes:
|
for mode in valid_modes:
|
||||||
@@ -341,7 +341,7 @@ class TestTaskHandlerRegistry:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||||
await task_registry.execute_task(task)
|
await task_registry.execute_task(task)
|
||||||
mock_player_service.set_mode.assert_called_with(mode)
|
mock_player_service.set_mode.assert_called_with(mode)
|
||||||
|
|
||||||
@@ -368,7 +368,7 @@ class TestTaskHandlerRegistry:
|
|||||||
mock_playlist = MagicMock()
|
mock_playlist = MagicMock()
|
||||||
mock_playlist.name = "Test Playlist"
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
|
||||||
await task_registry.execute_task(task)
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
# Should not call set_mode for invalid mode
|
# Should not call set_mode for invalid mode
|
||||||
@@ -421,4 +421,4 @@ class TestTaskHandlerRegistry:
|
|||||||
TaskType.PLAY_SOUND,
|
TaskType.PLAY_SOUND,
|
||||||
TaskType.PLAY_PLAYLIST,
|
TaskType.PLAY_PLAYLIST,
|
||||||
}
|
}
|
||||||
assert set(registry._handlers.keys()) == expected_handlers
|
assert set(registry._handlers.keys()) == expected_handlers
|
||||||
|
|||||||
Reference in New Issue
Block a user