Refactor scheduled task repository and schemas for improved type hints and consistency

- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase.
- Refactored import statements for better organization and clarity.
- Enhanced the ScheduledTaskBase schema to use modern type hints.
- Cleaned up unnecessary comments and whitespace in various files.
- Improved error handling and logging in task execution handlers.
- Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
JSC
2025-08-28 23:38:47 +02:00
parent 96801dc4d6
commit dc89e45675
19 changed files with 292 additions and 291 deletions

View File

@@ -1,7 +1,5 @@
"""API endpoints for scheduled task management.""" """API endpoints for scheduled task management."""
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -11,7 +9,7 @@ from app.core.dependencies import (
get_admin_user, get_admin_user,
get_current_active_user, get_current_active_user,
) )
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
from app.models.user import User from app.models.user import User
from app.schemas.scheduler import ( from app.schemas.scheduler import (
ScheduledTaskCreate, ScheduledTaskCreate,
@@ -54,15 +52,15 @@ async def create_task(
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
@router.get("/tasks", response_model=List[ScheduledTaskResponse]) @router.get("/tasks", response_model=list[ScheduledTaskResponse])
async def get_user_tasks( async def get_user_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"), status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"), task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"), limit: int | None = Query(50, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"), offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service), scheduler_service: SchedulerService = Depends(get_scheduler_service),
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get user's scheduled tasks.""" """Get user's scheduled tasks."""
return await scheduler_service.get_user_tasks( return await scheduler_service.get_user_tasks(
user_id=current_user.id, user_id=current_user.id,
@@ -152,15 +150,15 @@ async def cancel_task(
# Admin-only endpoints # Admin-only endpoints
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse]) @router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
async def get_all_tasks( async def get_all_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"), status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"), task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"), limit: int | None = Query(100, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"), offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_admin_user), current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db), db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get all scheduled tasks (admin only).""" """Get all scheduled tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
@@ -189,13 +187,13 @@ async def get_all_tasks(
return list(result.all()) return list(result.all())
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse]) @router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
async def get_system_tasks( async def get_system_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"), status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"), task_type: TaskType | None = Query(None, description="Filter by task type"),
current_user: User = Depends(get_admin_user), current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db), db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get system tasks (admin only).""" """Get system tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository

View File

@@ -4,11 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
from app.core.config import settings from app.core.config import settings
from app.core.logging import get_logger from app.core.logging import get_logger
from app.core.seeds import seed_all_data from app.core.seeds import seed_all_data
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
engine: AsyncEngine = create_async_engine( engine: AsyncEngine = create_async_engine(
settings.DATABASE_URL, settings.DATABASE_URL,

View File

@@ -11,11 +11,14 @@ from app.core.database import get_session_factory, init_db
from app.core.logging import get_logger, setup_logging from app.core.logging import get_logger, setup_logging
from app.middleware.logging import LoggingMiddleware from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor from app.services.extraction_processor import extraction_processor
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service from app.services.player import (
get_player_service,
initialize_player_service,
shutdown_player_service,
)
from app.services.scheduler import SchedulerService from app.services.scheduler import SchedulerService
from app.services.socket import socket_manager from app.services.socket import socket_manager
scheduler_service = None scheduler_service = None

View File

@@ -1,11 +1,10 @@
"""Scheduled task model for flexible task scheduling with timezone support.""" """Scheduled task model for flexible task scheduling with timezone support."""
import uuid
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any
from sqlmodel import JSON, Column, Field, SQLModel from sqlmodel import JSON, Column, Field
from app.models.base import BaseModel from app.models.base import BaseModel
@@ -57,11 +56,11 @@ class ScheduledTask(BaseModel, table=True):
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')", description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
) )
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE) recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
cron_expression: Optional[str] = Field( cron_expression: str | None = Field(
default=None, default=None,
description="Cron expression for custom recurrence (when recurrence_type is CRON)", description="Cron expression for custom recurrence (when recurrence_type is CRON)",
) )
recurrence_count: Optional[int] = Field( recurrence_count: int | None = Field(
default=None, default=None,
description="Number of times to repeat (None for infinite)", description="Number of times to repeat (None for infinite)",
) )
@@ -75,29 +74,29 @@ class ScheduledTask(BaseModel, table=True):
) )
# User association (None for system tasks) # User association (None for system tasks)
user_id: Optional[int] = Field( user_id: int | None = Field(
default=None, default=None,
foreign_key="user.id", foreign_key="user.id",
description="User who created the task (None for system tasks)", description="User who created the task (None for system tasks)",
) )
# Execution tracking # Execution tracking
last_executed_at: Optional[datetime] = Field( last_executed_at: datetime | None = Field(
default=None, default=None,
description="When the task was last executed (UTC)", description="When the task was last executed (UTC)",
) )
next_execution_at: Optional[datetime] = Field( next_execution_at: datetime | None = Field(
default=None, default=None,
description="When the task should be executed next (UTC, for recurring tasks)", description="When the task should be executed next (UTC, for recurring tasks)",
) )
error_message: Optional[str] = Field( error_message: str | None = Field(
default=None, default=None,
description="Error message if execution failed", description="Error message if execution failed",
) )
# Task lifecycle # Task lifecycle
is_active: bool = Field(default=True, description="Whether the task is active") is_active: bool = Field(default=True, description="Whether the task is active")
expires_at: Optional[datetime] = Field( expires_at: datetime | None = Field(
default=None, default=None,
description="When the task expires (UTC, optional)", description="When the task expires (UTC, optional)",
) )

View File

@@ -1,12 +1,16 @@
"""Repository for scheduled task operations.""" """Repository for scheduled task operations."""
from datetime import datetime from datetime import datetime
from typing import List, Optional
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.base import BaseRepository from app.repositories.base import BaseRepository
@@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
"""Initialize the repository.""" """Initialize the repository."""
super().__init__(ScheduledTask, session) super().__init__(ScheduledTask, session)
async def get_pending_tasks(self) -> List[ScheduledTask]: async def get_pending_tasks(self) -> list[ScheduledTask]:
"""Get all pending tasks that are ready to be executed.""" """Get all pending tasks that are ready to be executed."""
now = datetime.utcnow() now = datetime.utcnow()
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
async def get_active_tasks(self) -> List[ScheduledTask]: async def get_active_tasks(self) -> list[ScheduledTask]:
"""Get all active tasks.""" """Get all active tasks."""
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
ScheduledTask.is_active.is_(True), ScheduledTask.is_active.is_(True),
@@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_user_tasks( async def get_user_tasks(
self, self,
user_id: int, user_id: int,
status: Optional[TaskStatus] = None, status: TaskStatus | None = None,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
limit: Optional[int] = None, limit: int | None = None,
offset: Optional[int] = None, offset: int | None = None,
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get tasks for a specific user.""" """Get tasks for a specific user."""
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id) statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
@@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_system_tasks( async def get_system_tasks(
self, self,
status: Optional[TaskStatus] = None, status: TaskStatus | None = None,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get system tasks (tasks with no user association).""" """Get system tasks (tasks with no user association)."""
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None)) statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
@@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]: async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
"""Get recurring tasks that need their next execution scheduled.""" """Get recurring tasks that need their next execution scheduled."""
now = datetime.utcnow() now = datetime.utcnow()
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
async def get_expired_tasks(self) -> List[ScheduledTask]: async def get_expired_tasks(self) -> list[ScheduledTask]:
"""Get expired tasks that should be cleaned up.""" """Get expired tasks that should be cleaned up."""
now = datetime.utcnow() now = datetime.utcnow()
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def cancel_user_tasks( async def cancel_user_tasks(
self, self,
user_id: int, user_id: int,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
) -> int: ) -> int:
"""Cancel all pending/running tasks for a user.""" """Cancel all pending/running tasks for a user."""
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def mark_as_completed( async def mark_as_completed(
self, self,
task: ScheduledTask, task: ScheduledTask,
next_execution_at: Optional[datetime] = None, next_execution_at: datetime | None = None,
) -> None: ) -> None:
"""Mark a task as completed and set next execution if recurring.""" """Mark a task as completed and set next execution if recurring."""
task.status = TaskStatus.COMPLETED task.status = TaskStatus.COMPLETED

View File

@@ -1,7 +1,7 @@
"""Schemas for scheduled task API.""" """Schemas for scheduled task API."""
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -15,7 +15,7 @@ class ScheduledTaskBase(BaseModel):
task_type: TaskType = Field(description="Type of task to execute") task_type: TaskType = Field(description="Type of task to execute")
scheduled_at: datetime = Field(description="When the task should be executed") scheduled_at: datetime = Field(description="When the task should be executed")
timezone: str = Field(default="UTC", description="Timezone for scheduling") timezone: str = Field(default="UTC", description="Timezone for scheduling")
parameters: Dict[str, Any] = Field( parameters: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Task-specific parameters", description="Task-specific parameters",
) )
@@ -23,15 +23,15 @@ class ScheduledTaskBase(BaseModel):
default=RecurrenceType.NONE, default=RecurrenceType.NONE,
description="Recurrence pattern", description="Recurrence pattern",
) )
cron_expression: Optional[str] = Field( cron_expression: str | None = Field(
default=None, default=None,
description="Cron expression for custom recurrence", description="Cron expression for custom recurrence",
) )
recurrence_count: Optional[int] = Field( recurrence_count: int | None = Field(
default=None, default=None,
description="Number of times to repeat (None for infinite)", description="Number of times to repeat (None for infinite)",
) )
expires_at: Optional[datetime] = Field( expires_at: datetime | None = Field(
default=None, default=None,
description="When the task expires (optional)", description="When the task expires (optional)",
) )
@@ -40,18 +40,17 @@ class ScheduledTaskBase(BaseModel):
class ScheduledTaskCreate(ScheduledTaskBase): class ScheduledTaskCreate(ScheduledTaskBase):
"""Schema for creating a scheduled task.""" """Schema for creating a scheduled task."""
pass
class ScheduledTaskUpdate(BaseModel): class ScheduledTaskUpdate(BaseModel):
"""Schema for updating a scheduled task.""" """Schema for updating a scheduled task."""
name: Optional[str] = None name: str | None = None
scheduled_at: Optional[datetime] = None scheduled_at: datetime | None = None
timezone: Optional[str] = None timezone: str | None = None
parameters: Optional[Dict[str, Any]] = None parameters: dict[str, Any] | None = None
is_active: Optional[bool] = None is_active: bool | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
class ScheduledTaskResponse(ScheduledTaskBase): class ScheduledTaskResponse(ScheduledTaskBase):
@@ -59,11 +58,11 @@ class ScheduledTaskResponse(ScheduledTaskBase):
id: int id: int
status: TaskStatus status: TaskStatus
user_id: Optional[int] = None user_id: int | None = None
executions_count: int executions_count: int
last_executed_at: Optional[datetime] = None last_executed_at: datetime | None = None
next_execution_at: Optional[datetime] = None next_execution_at: datetime | None = None
error_message: Optional[str] = None error_message: str | None = None
is_active: bool is_active: bool
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -78,7 +77,7 @@ class ScheduledTaskResponse(ScheduledTaskBase):
class CreditRechargeParameters(BaseModel): class CreditRechargeParameters(BaseModel):
"""Parameters for credit recharge tasks.""" """Parameters for credit recharge tasks."""
user_id: Optional[int] = Field( user_id: int | None = Field(
default=None, default=None,
description="Specific user ID to recharge (None for all users)", description="Specific user ID to recharge (None for all users)",
) )
@@ -109,10 +108,10 @@ class CreateCreditRechargeTask(BaseModel):
scheduled_at: datetime scheduled_at: datetime
timezone: str = "UTC" timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None cron_expression: str | None = None
recurrence_count: Optional[int] = None recurrence_count: int | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
user_id: Optional[int] = None user_id: int | None = None
def to_task_create(self) -> ScheduledTaskCreate: def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema.""" """Convert to generic task creation schema."""
@@ -137,9 +136,9 @@ class CreatePlaySoundTask(BaseModel):
sound_id: int sound_id: int
timezone: str = "UTC" timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None cron_expression: str | None = None
recurrence_count: Optional[int] = None recurrence_count: int | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate: def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema.""" """Convert to generic task creation schema."""
@@ -166,9 +165,9 @@ class CreatePlayPlaylistTask(BaseModel):
shuffle: bool = False shuffle: bool = False
timezone: str = "UTC" timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None cron_expression: str | None = None
recurrence_count: Optional[int] = None recurrence_count: int | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate: def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema.""" """Convert to generic task creation schema."""

View File

@@ -2,7 +2,7 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional from typing import Any
import pytz import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -86,13 +86,13 @@ class SchedulerService:
name: str, name: str,
task_type: TaskType, task_type: TaskType,
scheduled_at: datetime, scheduled_at: datetime,
parameters: Optional[Dict[str, Any]] = None, parameters: dict[str, Any] | None = None,
user_id: Optional[int] = None, user_id: int | None = None,
timezone: str = "UTC", timezone: str = "UTC",
recurrence_type: RecurrenceType = RecurrenceType.NONE, recurrence_type: RecurrenceType = RecurrenceType.NONE,
cron_expression: Optional[str] = None, cron_expression: str | None = None,
recurrence_count: Optional[int] = None, recurrence_count: int | None = None,
expires_at: Optional[datetime] = None, expires_at: datetime | None = None,
) -> ScheduledTask: ) -> ScheduledTask:
"""Create a new scheduled task.""" """Create a new scheduled task."""
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
@@ -150,11 +150,11 @@ class SchedulerService:
async def get_user_tasks( async def get_user_tasks(
self, self,
user_id: int, user_id: int,
status: Optional[TaskStatus] = None, status: TaskStatus | None = None,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
limit: Optional[int] = None, limit: int | None = None,
offset: Optional[int] = None, offset: int | None = None,
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get tasks for a specific user.""" """Get tasks for a specific user."""
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
@@ -182,7 +182,7 @@ class SchedulerService:
# Check if daily credit recharge task exists # Check if daily credit recharge task exists
system_tasks = await repo.get_system_tasks( system_tasks = await repo.get_system_tasks(
task_type=TaskType.CREDIT_RECHARGE task_type=TaskType.CREDIT_RECHARGE,
) )
daily_recharge_exists = any( daily_recharge_exists = any(
@@ -194,7 +194,7 @@ class SchedulerService:
if not daily_recharge_exists: if not daily_recharge_exists:
# Create daily credit recharge task # Create daily credit recharge task
tomorrow_midnight = datetime.utcnow().replace( tomorrow_midnight = datetime.utcnow().replace(
hour=0, minute=0, second=0, microsecond=0 hour=0, minute=0, second=0, microsecond=0,
) + timedelta(days=1) ) + timedelta(days=1)
task_data = { task_data = {
@@ -258,36 +258,36 @@ class SchedulerService:
if task.recurrence_type == RecurrenceType.NONE: if task.recurrence_type == RecurrenceType.NONE:
return DateTrigger(run_date=task.scheduled_at, timezone=tz) return DateTrigger(run_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression: if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
return CronTrigger.from_crontab(task.cron_expression, timezone=tz) return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
elif task.recurrence_type == RecurrenceType.HOURLY: if task.recurrence_type == RecurrenceType.HOURLY:
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz) return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.DAILY: if task.recurrence_type == RecurrenceType.DAILY:
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz) return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.WEEKLY: if task.recurrence_type == RecurrenceType.WEEKLY:
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz) return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.MONTHLY: if task.recurrence_type == RecurrenceType.MONTHLY:
# Use cron trigger for monthly (more reliable than interval) # Use cron trigger for monthly (more reliable than interval)
scheduled_time = task.scheduled_at scheduled_time = task.scheduled_at
return CronTrigger( return CronTrigger(
day=scheduled_time.day, day=scheduled_time.day,
hour=scheduled_time.hour, hour=scheduled_time.hour,
minute=scheduled_time.minute, minute=scheduled_time.minute,
timezone=tz timezone=tz,
) )
elif task.recurrence_type == RecurrenceType.YEARLY: if task.recurrence_type == RecurrenceType.YEARLY:
scheduled_time = task.scheduled_at scheduled_time = task.scheduled_at
return CronTrigger( return CronTrigger(
month=scheduled_time.month, month=scheduled_time.month,
day=scheduled_time.day, day=scheduled_time.day,
hour=scheduled_time.hour, hour=scheduled_time.hour,
minute=scheduled_time.minute, minute=scheduled_time.minute,
timezone=tz timezone=tz,
) )
return None return None
@@ -332,7 +332,7 @@ class SchedulerService:
# Execute the task # Execute the task
try: try:
handler_registry = TaskHandlerRegistry( handler_registry = TaskHandlerRegistry(
session, self.db_session_factory, self.credit_service, self.player_service session, self.db_session_factory, self.credit_service, self.player_service,
) )
await handler_registry.execute_task(task) await handler_registry.execute_task(task)
@@ -352,25 +352,25 @@ class SchedulerService:
except Exception as e: except Exception as e:
await repo.mark_as_failed(task, str(e)) await repo.mark_as_failed(task, str(e))
logger.exception(f"Task {task_id} execution failed: {str(e)}") logger.exception(f"Task {task_id} execution failed: {e!s}")
finally: finally:
self._running_tasks.discard(task_id_str) self._running_tasks.discard(task_id_str)
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]: def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
"""Calculate the next execution time for a recurring task.""" """Calculate the next execution time for a recurring task."""
now = datetime.utcnow() now = datetime.utcnow()
if task.recurrence_type == RecurrenceType.HOURLY: if task.recurrence_type == RecurrenceType.HOURLY:
return now + timedelta(hours=1) return now + timedelta(hours=1)
elif task.recurrence_type == RecurrenceType.DAILY: if task.recurrence_type == RecurrenceType.DAILY:
return now + timedelta(days=1) return now + timedelta(days=1)
elif task.recurrence_type == RecurrenceType.WEEKLY: if task.recurrence_type == RecurrenceType.WEEKLY:
return now + timedelta(weeks=1) return now + timedelta(weeks=1)
elif task.recurrence_type == RecurrenceType.MONTHLY: if task.recurrence_type == RecurrenceType.MONTHLY:
# Add approximately one month # Add approximately one month
return now + timedelta(days=30) return now + timedelta(days=30)
elif task.recurrence_type == RecurrenceType.YEARLY: if task.recurrence_type == RecurrenceType.YEARLY:
return now + timedelta(days=365) return now + timedelta(days=365)
return None return None

View File

@@ -1,6 +1,5 @@
"""Task execution handlers for different task types.""" """Task execution handlers for different task types."""
from typing import Any, Dict, Optional
from collections.abc import Callable from collections.abc import Callable
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
class TaskExecutionError(Exception): class TaskExecutionError(Exception):
"""Exception raised when task execution fails.""" """Exception raised when task execution fails."""
pass
class TaskHandlerRegistry: class TaskHandlerRegistry:
@@ -58,8 +56,8 @@ class TaskHandlerRegistry:
await handler(task) await handler(task)
logger.info(f"Task {task.id} executed successfully") logger.info(f"Task {task.id} executed successfully")
except Exception as e: except Exception as e:
logger.exception(f"Task {task.id} execution failed: {str(e)}") logger.exception(f"Task {task.id} execution failed: {e!s}")
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e raise TaskExecutionError(f"Task execution failed: {e!s}") from e
async def _handle_credit_recharge(self, task: ScheduledTask) -> None: async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
"""Handle credit recharge task.""" """Handle credit recharge task."""
@@ -105,7 +103,7 @@ class TaskHandlerRegistry:
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})") logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
except Exception as e: except Exception as e:
# Convert HTTP exceptions or credit errors to task execution errors # Convert HTTP exceptions or credit errors to task execution errors
raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
else: else:
# System task: play without credit deduction # System task: play without credit deduction
sound = await self.sound_repository.get_by_id(sound_id_int) sound = await self.sound_repository.get_by_id(sound_id_int)

View File

@@ -7,6 +7,7 @@ from datetime import datetime
from app.core.database import get_session_factory from app.core.database import get_session_factory
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
async def check_tasks(): async def check_tasks():
session_factory = get_session_factory() session_factory = get_session_factory()

View File

@@ -5,8 +5,9 @@ import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.core.database import get_session_factory from app.core.database import get_session_factory
from app.models.scheduled_task import RecurrenceType, TaskType
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
from app.models.scheduled_task import TaskType, RecurrenceType
async def create_future_task(): async def create_future_task():
session_factory = get_session_factory() session_factory = get_session_factory()

View File

@@ -4,9 +4,9 @@
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.core.database import get_session_factory
from app.main import get_global_scheduler_service from app.main import get_global_scheduler_service
from app.models.scheduled_task import TaskType, RecurrenceType from app.models.scheduled_task import RecurrenceType, TaskType
async def test_api_task_creation(): async def test_api_task_creation():
"""Test creating a task through the scheduler service (simulates API call).""" """Test creating a task through the scheduler service (simulates API call)."""

View File

@@ -5,8 +5,9 @@ import asyncio
from datetime import datetime from datetime import datetime
from app.core.database import get_session_factory from app.core.database import get_session_factory
from app.models.scheduled_task import RecurrenceType, TaskType
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
from app.models.scheduled_task import TaskType, RecurrenceType
async def create_test_task(): async def create_test_task():
session_factory = get_session_factory() session_factory = get_session_factory()

View File

@@ -3,8 +3,6 @@
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
import pytest
from app.models.scheduled_task import ( from app.models.scheduled_task import (
RecurrenceType, RecurrenceType,
ScheduledTask, ScheduledTask,

View File

@@ -2,7 +2,6 @@
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession

View File

@@ -51,7 +51,7 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test creating a scheduled task.""" """Test creating a scheduled task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule: with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
assert task.id is not None assert task.id is not None
@@ -68,7 +68,7 @@ class TestSchedulerService:
test_user_id: uuid.UUID, test_user_id: uuid.UUID,
): ):
"""Test creating a user task.""" """Test creating a user task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task( task = await scheduler_service.create_task(
user_id=test_user_id, user_id=test_user_id,
**sample_task_data, **sample_task_data,
@@ -83,7 +83,7 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test creating a system task.""" """Test creating a system task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
assert task.user_id is None assert task.user_id is None
@@ -95,7 +95,7 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test creating a recurring task.""" """Test creating a recurring task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task( task = await scheduler_service.create_task(
recurrence_type=RecurrenceType.DAILY, recurrence_type=RecurrenceType.DAILY,
recurrence_count=5, recurrence_count=5,
@@ -118,7 +118,7 @@ class TestSchedulerService:
sample_task_data["scheduled_at"] = ny_time sample_task_data["scheduled_at"] = ny_time
sample_task_data["timezone"] = "America/New_York" sample_task_data["timezone"] = "America/New_York"
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# The scheduled_at should be converted to UTC # The scheduled_at should be converted to UTC
@@ -134,11 +134,11 @@ class TestSchedulerService:
): ):
"""Test cancelling a task.""" """Test cancelling a task."""
# Create a task first # Create a task first
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Mock the scheduler remove_job method # Mock the scheduler remove_job method
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove: with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
result = await scheduler_service.cancel_task(task.id) result = await scheduler_service.cancel_task(task.id)
assert result is True assert result is True
@@ -167,7 +167,7 @@ class TestSchedulerService:
test_user_id: uuid.UUID, test_user_id: uuid.UUID,
): ):
"""Test getting user tasks.""" """Test getting user tasks."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
# Create user task # Create user task
await scheduler_service.create_task( await scheduler_service.create_task(
user_id=test_user_id, user_id=test_user_id,
@@ -188,8 +188,8 @@ class TestSchedulerService:
): ):
"""Test ensuring system tasks exist.""" """Test ensuring system tasks exist."""
# Mock the repository to return no existing tasks # Mock the repository to return no existing tasks
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
mock_get.return_value = [] mock_get.return_value = []
await scheduler_service._ensure_system_tasks() await scheduler_service._ensure_system_tasks()
@@ -214,8 +214,8 @@ class TestSchedulerService:
is_active=True, is_active=True,
) )
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
mock_get.return_value = [existing_task] mock_get.return_value = [existing_task]
await scheduler_service._ensure_system_tasks() await scheduler_service._ensure_system_tasks()
@@ -312,7 +312,7 @@ class TestSchedulerService:
recurrence_type=recurrence_type, recurrence_type=recurrence_type,
) )
with patch('app.services.scheduler.datetime') as mock_datetime: with patch("app.services.scheduler.datetime") as mock_datetime:
mock_datetime.utcnow.return_value = now mock_datetime.utcnow.return_value = now
next_execution = scheduler_service._calculate_next_execution(task) next_execution = scheduler_service._calculate_next_execution(task)
@@ -335,7 +335,7 @@ class TestSchedulerService:
next_execution = scheduler_service._calculate_next_execution(task) next_execution = scheduler_service._calculate_next_execution(task)
assert next_execution is None assert next_execution is None
@patch('app.services.task_handlers.TaskHandlerRegistry') @patch("app.services.task_handlers.TaskHandlerRegistry")
async def test_execute_task_success( async def test_execute_task_success(
self, self,
mock_handler_class, mock_handler_class,
@@ -344,7 +344,7 @@ class TestSchedulerService:
): ):
"""Test successful task execution.""" """Test successful task execution."""
# Create task # Create task
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Mock handler registry # Mock handler registry
@@ -365,7 +365,7 @@ class TestSchedulerService:
assert updated_task.status == TaskStatus.COMPLETED assert updated_task.status == TaskStatus.COMPLETED
assert updated_task.executions_count == 1 assert updated_task.executions_count == 1
@patch('app.services.task_handlers.TaskHandlerRegistry') @patch("app.services.task_handlers.TaskHandlerRegistry")
async def test_execute_task_failure( async def test_execute_task_failure(
self, self,
mock_handler_class, mock_handler_class,
@@ -374,7 +374,7 @@ class TestSchedulerService:
): ):
"""Test task execution failure.""" """Test task execution failure."""
# Create task # Create task
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Mock handler to raise exception # Mock handler to raise exception
@@ -410,7 +410,7 @@ class TestSchedulerService:
# Create expired task # Create expired task
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1) sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Execute task # Execute task
@@ -430,20 +430,20 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test prevention of concurrent task execution.""" """Test prevention of concurrent task execution."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Add task to running set # Add task to running set
scheduler_service._running_tasks.add(str(task.id)) scheduler_service._running_tasks.add(str(task.id))
# Try to execute - should return without doing anything # Try to execute - should return without doing anything
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class: with patch("app.services.task_handlers.TaskHandlerRegistry") as mock_handler_class:
await scheduler_service._execute_task(task.id) await scheduler_service._execute_task(task.id)
# Handler should not be called # Handler should not be called
mock_handler_class.assert_not_called() mock_handler_class.assert_not_called()
@patch('app.repositories.scheduled_task.ScheduledTaskRepository') @patch("app.repositories.scheduled_task.ScheduledTaskRepository")
async def test_maintenance_job_expired_tasks( async def test_maintenance_job_expired_tasks(
self, self,
mock_repo_class, mock_repo_class,
@@ -459,7 +459,7 @@ class TestSchedulerService:
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [] mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove: with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
await scheduler_service._maintenance_job() await scheduler_service._maintenance_job()
# Should mark as cancelled and remove from scheduler # Should mark as cancelled and remove from scheduler
@@ -468,7 +468,7 @@ class TestSchedulerService:
mock_repo.update.assert_called_with(expired_task) mock_repo.update.assert_called_with(expired_task)
mock_remove.assert_called_once_with(str(expired_task.id)) mock_remove.assert_called_once_with(str(expired_task.id))
@patch('app.repositories.scheduled_task.ScheduledTaskRepository') @patch("app.repositories.scheduled_task.ScheduledTaskRepository")
async def test_maintenance_job_due_recurring_tasks( async def test_maintenance_job_due_recurring_tasks(
self, self,
mock_repo_class, mock_repo_class,
@@ -485,7 +485,7 @@ class TestSchedulerService:
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task] mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule: with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
await scheduler_service._maintenance_job() await scheduler_service._maintenance_job()
# Should reset to pending and reschedule # Should reset to pending and reschedule

View File

@@ -133,8 +133,8 @@ class TestTaskHandlerRegistry:
mock_sound.id = test_sound_id mock_sound.id = test_sound_id
mock_sound.filename = "test_sound.mp3" mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound): with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class: with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_vlc_class.return_value = mock_vlc_service mock_vlc_class.return_value = mock_vlc_service
@@ -186,7 +186,7 @@ class TestTaskHandlerRegistry:
parameters={"sound_id": str(test_sound_id)}, parameters={"sound_id": str(test_sound_id)},
) )
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None): with patch.object(task_registry.sound_repository, "get_by_id", return_value=None):
with pytest.raises(TaskExecutionError, match="Sound not found"): with pytest.raises(TaskExecutionError, match="Sound not found"):
await task_registry.execute_task(task) await task_registry.execute_task(task)
@@ -206,8 +206,8 @@ class TestTaskHandlerRegistry:
mock_sound = MagicMock() mock_sound = MagicMock()
mock_sound.filename = "test_sound.mp3" mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound): with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class: with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_vlc_class.return_value = mock_vlc_service mock_vlc_class.return_value = mock_vlc_service
@@ -238,7 +238,7 @@ class TestTaskHandlerRegistry:
mock_playlist.id = test_playlist_id mock_playlist.id = test_playlist_id
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id) task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
@@ -264,7 +264,7 @@ class TestTaskHandlerRegistry:
mock_playlist = MagicMock() mock_playlist = MagicMock()
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
# Should use default values # Should use default values
@@ -314,7 +314,7 @@ class TestTaskHandlerRegistry:
parameters={"playlist_id": str(test_playlist_id)}, parameters={"playlist_id": str(test_playlist_id)},
) )
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=None):
with pytest.raises(TaskExecutionError, match="Playlist not found"): with pytest.raises(TaskExecutionError, match="Playlist not found"):
await task_registry.execute_task(task) await task_registry.execute_task(task)
@@ -341,7 +341,7 @@ class TestTaskHandlerRegistry:
}, },
) )
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
mock_player_service.set_mode.assert_called_with(mode) mock_player_service.set_mode.assert_called_with(mode)
@@ -368,7 +368,7 @@ class TestTaskHandlerRegistry:
mock_playlist = MagicMock() mock_playlist = MagicMock()
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
# Should not call set_mode for invalid mode # Should not call set_mode for invalid mode