Refactor scheduled task repository and schemas for improved type hints and consistency

- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase.
- Refactored import statements for better organization and clarity.
- Enhanced the ScheduledTaskBase schema to use modern type hints.
- Cleaned up unnecessary comments and whitespace in various files.
- Improved error handling and logging in task execution handlers.
- Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
JSC
2025-08-28 23:38:47 +02:00
parent 96801dc4d6
commit dc89e45675
19 changed files with 292 additions and 291 deletions

View File

@@ -1,7 +1,5 @@
"""API endpoints for scheduled task management.""" """API endpoints for scheduled task management."""
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -11,7 +9,7 @@ from app.core.dependencies import (
get_admin_user, get_admin_user,
get_current_active_user, get_current_active_user,
) )
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
from app.models.user import User from app.models.user import User
from app.schemas.scheduler import ( from app.schemas.scheduler import (
ScheduledTaskCreate, ScheduledTaskCreate,
@@ -54,15 +52,15 @@ async def create_task(
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
@router.get("/tasks", response_model=List[ScheduledTaskResponse]) @router.get("/tasks", response_model=list[ScheduledTaskResponse])
async def get_user_tasks( async def get_user_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"), status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"), task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"), limit: int | None = Query(50, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"), offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user),
scheduler_service: SchedulerService = Depends(get_scheduler_service), scheduler_service: SchedulerService = Depends(get_scheduler_service),
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get user's scheduled tasks.""" """Get user's scheduled tasks."""
return await scheduler_service.get_user_tasks( return await scheduler_service.get_user_tasks(
user_id=current_user.id, user_id=current_user.id,
@@ -81,17 +79,17 @@ async def get_task(
) -> ScheduledTask: ) -> ScheduledTask:
"""Get a specific scheduled task.""" """Get a specific scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session) repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id) task = await repo.get_by_id(task_id)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin # Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin: if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
return task return task
@@ -104,22 +102,22 @@ async def update_task(
) -> ScheduledTask: ) -> ScheduledTask:
"""Update a scheduled task.""" """Update a scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session) repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id) task = await repo.get_by_id(task_id)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin # Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin: if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
# Update task fields # Update task fields
update_data = task_update.model_dump(exclude_unset=True) update_data = task_update.model_dump(exclude_unset=True)
for field, value in update_data.items(): for field, value in update_data.items():
setattr(task, field, value) setattr(task, field, value)
updated_task = await repo.update(task) updated_task = await repo.update(task)
return updated_task return updated_task
@@ -133,72 +131,72 @@ async def cancel_task(
) -> dict: ) -> dict:
"""Cancel a scheduled task.""" """Cancel a scheduled task."""
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session) repo = ScheduledTaskRepository(db_session)
task = await repo.get_by_id(task_id) task = await repo.get_by_id(task_id)
if not task: if not task:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
# Check if user owns the task or is admin # Check if user owns the task or is admin
if task.user_id != current_user.id and not current_user.is_admin: if task.user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
success = await scheduler_service.cancel_task(task_id) success = await scheduler_service.cancel_task(task_id)
if not success: if not success:
raise HTTPException(status_code=400, detail="Failed to cancel task") raise HTTPException(status_code=400, detail="Failed to cancel task")
return {"message": "Task cancelled successfully"} return {"message": "Task cancelled successfully"}
# Admin-only endpoints # Admin-only endpoints
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse]) @router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
async def get_all_tasks( async def get_all_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"), status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"), task_type: TaskType | None = Query(None, description="Filter by task type"),
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"), limit: int | None = Query(100, description="Maximum number of tasks to return"),
offset: Optional[int] = Query(0, description="Number of tasks to skip"), offset: int | None = Query(0, description="Number of tasks to skip"),
current_user: User = Depends(get_admin_user), current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db), db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get all scheduled tasks (admin only).""" """Get all scheduled tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session) repo = ScheduledTaskRepository(db_session)
# Get all tasks with pagination and filtering # Get all tasks with pagination and filtering
from sqlmodel import select from sqlmodel import select
statement = select(ScheduledTask) statement = select(ScheduledTask)
if status: if status:
statement = statement.where(ScheduledTask.status == status) statement = statement.where(ScheduledTask.status == status)
if task_type: if task_type:
statement = statement.where(ScheduledTask.task_type == task_type) statement = statement.where(ScheduledTask.task_type == task_type)
statement = statement.order_by(ScheduledTask.scheduled_at.desc()) statement = statement.order_by(ScheduledTask.scheduled_at.desc())
if offset: if offset:
statement = statement.offset(offset) statement = statement.offset(offset)
if limit: if limit:
statement = statement.limit(limit) statement = statement.limit(limit)
result = await db_session.exec(statement) result = await db_session.exec(statement)
return list(result.all()) return list(result.all())
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse]) @router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
async def get_system_tasks( async def get_system_tasks(
status: Optional[TaskStatus] = Query(None, description="Filter by task status"), status: TaskStatus | None = Query(None, description="Filter by task status"),
task_type: Optional[TaskType] = Query(None, description="Filter by task type"), task_type: TaskType | None = Query(None, description="Filter by task type"),
current_user: User = Depends(get_admin_user), current_user: User = Depends(get_admin_user),
db_session: AsyncSession = Depends(get_db), db_session: AsyncSession = Depends(get_db),
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get system tasks (admin only).""" """Get system tasks (admin only)."""
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
repo = ScheduledTaskRepository(db_session) repo = ScheduledTaskRepository(db_session)
return await repo.get_system_tasks(status=status, task_type=task_type) return await repo.get_system_tasks(status=status, task_type=task_type)
@@ -225,4 +223,4 @@ async def create_system_task(
) )
return task return task
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))

View File

@@ -4,11 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import SQLModel from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
from app.core.config import settings from app.core.config import settings
from app.core.logging import get_logger from app.core.logging import get_logger
from app.core.seeds import seed_all_data from app.core.seeds import seed_all_data
# Import all models to ensure SQLModel metadata discovery
import app.models # noqa: F401
engine: AsyncEngine = create_async_engine( engine: AsyncEngine = create_async_engine(
settings.DATABASE_URL, settings.DATABASE_URL,

View File

@@ -11,11 +11,14 @@ from app.core.database import get_session_factory, init_db
from app.core.logging import get_logger, setup_logging from app.core.logging import get_logger, setup_logging
from app.middleware.logging import LoggingMiddleware from app.middleware.logging import LoggingMiddleware
from app.services.extraction_processor import extraction_processor from app.services.extraction_processor import extraction_processor
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service from app.services.player import (
get_player_service,
initialize_player_service,
shutdown_player_service,
)
from app.services.scheduler import SchedulerService from app.services.scheduler import SchedulerService
from app.services.socket import socket_manager from app.services.socket import socket_manager
scheduler_service = None scheduler_service = None
@@ -31,7 +34,7 @@ def get_global_scheduler_service() -> SchedulerService:
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]: async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
"""Application lifespan context manager for setup and teardown.""" """Application lifespan context manager for setup and teardown."""
global scheduler_service global scheduler_service
setup_logging() setup_logging()
logger = get_logger(__name__) logger = get_logger(__name__)
logger.info("Starting application") logger.info("Starting application")

View File

@@ -20,7 +20,7 @@ __all__ = [
"CreditAction", "CreditAction",
"CreditTransaction", "CreditTransaction",
"Extraction", "Extraction",
"Favorite", "Favorite",
"Plan", "Plan",
"Playlist", "Playlist",
"PlaylistSound", "PlaylistSound",

View File

@@ -1,11 +1,10 @@
"""Scheduled task model for flexible task scheduling with timezone support.""" """Scheduled task model for flexible task scheduling with timezone support."""
import uuid
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any
from sqlmodel import JSON, Column, Field, SQLModel from sqlmodel import JSON, Column, Field
from app.models.base import BaseModel from app.models.base import BaseModel
@@ -57,11 +56,11 @@ class ScheduledTask(BaseModel, table=True):
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')", description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
) )
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE) recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
cron_expression: Optional[str] = Field( cron_expression: str | None = Field(
default=None, default=None,
description="Cron expression for custom recurrence (when recurrence_type is CRON)", description="Cron expression for custom recurrence (when recurrence_type is CRON)",
) )
recurrence_count: Optional[int] = Field( recurrence_count: int | None = Field(
default=None, default=None,
description="Number of times to repeat (None for infinite)", description="Number of times to repeat (None for infinite)",
) )
@@ -75,29 +74,29 @@ class ScheduledTask(BaseModel, table=True):
) )
# User association (None for system tasks) # User association (None for system tasks)
user_id: Optional[int] = Field( user_id: int | None = Field(
default=None, default=None,
foreign_key="user.id", foreign_key="user.id",
description="User who created the task (None for system tasks)", description="User who created the task (None for system tasks)",
) )
# Execution tracking # Execution tracking
last_executed_at: Optional[datetime] = Field( last_executed_at: datetime | None = Field(
default=None, default=None,
description="When the task was last executed (UTC)", description="When the task was last executed (UTC)",
) )
next_execution_at: Optional[datetime] = Field( next_execution_at: datetime | None = Field(
default=None, default=None,
description="When the task should be executed next (UTC, for recurring tasks)", description="When the task should be executed next (UTC, for recurring tasks)",
) )
error_message: Optional[str] = Field( error_message: str | None = Field(
default=None, default=None,
description="Error message if execution failed", description="Error message if execution failed",
) )
# Task lifecycle # Task lifecycle
is_active: bool = Field(default=True, description="Whether the task is active") is_active: bool = Field(default=True, description="Whether the task is active")
expires_at: Optional[datetime] = Field( expires_at: datetime | None = Field(
default=None, default=None,
description="When the task expires (UTC, optional)", description="When the task expires (UTC, optional)",
) )
@@ -122,4 +121,4 @@ class ScheduledTask(BaseModel, table=True):
def is_system_task(self) -> bool: def is_system_task(self) -> bool:
"""Check if this is a system task (no user association).""" """Check if this is a system task (no user association)."""
return self.user_id is None return self.user_id is None

View File

@@ -1,12 +1,16 @@
"""Repository for scheduled task operations.""" """Repository for scheduled task operations."""
from datetime import datetime from datetime import datetime
from typing import List, Optional
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.base import BaseRepository from app.repositories.base import BaseRepository
@@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
"""Initialize the repository.""" """Initialize the repository."""
super().__init__(ScheduledTask, session) super().__init__(ScheduledTask, session)
async def get_pending_tasks(self) -> List[ScheduledTask]: async def get_pending_tasks(self) -> list[ScheduledTask]:
"""Get all pending tasks that are ready to be executed.""" """Get all pending tasks that are ready to be executed."""
now = datetime.utcnow() now = datetime.utcnow()
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
async def get_active_tasks(self) -> List[ScheduledTask]: async def get_active_tasks(self) -> list[ScheduledTask]:
"""Get all active tasks.""" """Get all active tasks."""
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
ScheduledTask.is_active.is_(True), ScheduledTask.is_active.is_(True),
@@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_user_tasks( async def get_user_tasks(
self, self,
user_id: int, user_id: int,
status: Optional[TaskStatus] = None, status: TaskStatus | None = None,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
limit: Optional[int] = None, limit: int | None = None,
offset: Optional[int] = None, offset: int | None = None,
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get tasks for a specific user.""" """Get tasks for a specific user."""
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id) statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
@@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def get_system_tasks( async def get_system_tasks(
self, self,
status: Optional[TaskStatus] = None, status: TaskStatus | None = None,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get system tasks (tasks with no user association).""" """Get system tasks (tasks with no user association)."""
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None)) statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
@@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]: async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]:
"""Get recurring tasks that need their next execution scheduled.""" """Get recurring tasks that need their next execution scheduled."""
now = datetime.utcnow() now = datetime.utcnow()
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
result = await self.session.exec(statement) result = await self.session.exec(statement)
return list(result.all()) return list(result.all())
async def get_expired_tasks(self) -> List[ScheduledTask]: async def get_expired_tasks(self) -> list[ScheduledTask]:
"""Get expired tasks that should be cleaned up.""" """Get expired tasks that should be cleaned up."""
now = datetime.utcnow() now = datetime.utcnow()
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def cancel_user_tasks( async def cancel_user_tasks(
self, self,
user_id: int, user_id: int,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
) -> int: ) -> int:
"""Cancel all pending/running tasks for a user.""" """Cancel all pending/running tasks for a user."""
statement = select(ScheduledTask).where( statement = select(ScheduledTask).where(
@@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
async def mark_as_completed( async def mark_as_completed(
self, self,
task: ScheduledTask, task: ScheduledTask,
next_execution_at: Optional[datetime] = None, next_execution_at: datetime | None = None,
) -> None: ) -> None:
"""Mark a task as completed and set next execution if recurring.""" """Mark a task as completed and set next execution if recurring."""
task.status = TaskStatus.COMPLETED task.status = TaskStatus.COMPLETED
@@ -174,4 +178,4 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
self.session.add(task) self.session.add(task)
await self.session.commit() await self.session.commit()
await self.session.refresh(task) await self.session.refresh(task)

View File

@@ -1,7 +1,7 @@
"""Schemas for scheduled task API.""" """Schemas for scheduled task API."""
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -15,7 +15,7 @@ class ScheduledTaskBase(BaseModel):
task_type: TaskType = Field(description="Type of task to execute") task_type: TaskType = Field(description="Type of task to execute")
scheduled_at: datetime = Field(description="When the task should be executed") scheduled_at: datetime = Field(description="When the task should be executed")
timezone: str = Field(default="UTC", description="Timezone for scheduling") timezone: str = Field(default="UTC", description="Timezone for scheduling")
parameters: Dict[str, Any] = Field( parameters: dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Task-specific parameters", description="Task-specific parameters",
) )
@@ -23,15 +23,15 @@ class ScheduledTaskBase(BaseModel):
default=RecurrenceType.NONE, default=RecurrenceType.NONE,
description="Recurrence pattern", description="Recurrence pattern",
) )
cron_expression: Optional[str] = Field( cron_expression: str | None = Field(
default=None, default=None,
description="Cron expression for custom recurrence", description="Cron expression for custom recurrence",
) )
recurrence_count: Optional[int] = Field( recurrence_count: int | None = Field(
default=None, default=None,
description="Number of times to repeat (None for infinite)", description="Number of times to repeat (None for infinite)",
) )
expires_at: Optional[datetime] = Field( expires_at: datetime | None = Field(
default=None, default=None,
description="When the task expires (optional)", description="When the task expires (optional)",
) )
@@ -40,18 +40,17 @@ class ScheduledTaskBase(BaseModel):
class ScheduledTaskCreate(ScheduledTaskBase): class ScheduledTaskCreate(ScheduledTaskBase):
"""Schema for creating a scheduled task.""" """Schema for creating a scheduled task."""
pass
class ScheduledTaskUpdate(BaseModel): class ScheduledTaskUpdate(BaseModel):
"""Schema for updating a scheduled task.""" """Schema for updating a scheduled task."""
name: Optional[str] = None name: str | None = None
scheduled_at: Optional[datetime] = None scheduled_at: datetime | None = None
timezone: Optional[str] = None timezone: str | None = None
parameters: Optional[Dict[str, Any]] = None parameters: dict[str, Any] | None = None
is_active: Optional[bool] = None is_active: bool | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
class ScheduledTaskResponse(ScheduledTaskBase): class ScheduledTaskResponse(ScheduledTaskBase):
@@ -59,11 +58,11 @@ class ScheduledTaskResponse(ScheduledTaskBase):
id: int id: int
status: TaskStatus status: TaskStatus
user_id: Optional[int] = None user_id: int | None = None
executions_count: int executions_count: int
last_executed_at: Optional[datetime] = None last_executed_at: datetime | None = None
next_execution_at: Optional[datetime] = None next_execution_at: datetime | None = None
error_message: Optional[str] = None error_message: str | None = None
is_active: bool is_active: bool
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@@ -78,7 +77,7 @@ class ScheduledTaskResponse(ScheduledTaskBase):
class CreditRechargeParameters(BaseModel): class CreditRechargeParameters(BaseModel):
"""Parameters for credit recharge tasks.""" """Parameters for credit recharge tasks."""
user_id: Optional[int] = Field( user_id: int | None = Field(
default=None, default=None,
description="Specific user ID to recharge (None for all users)", description="Specific user ID to recharge (None for all users)",
) )
@@ -109,10 +108,10 @@ class CreateCreditRechargeTask(BaseModel):
scheduled_at: datetime scheduled_at: datetime
timezone: str = "UTC" timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None cron_expression: str | None = None
recurrence_count: Optional[int] = None recurrence_count: int | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
user_id: Optional[int] = None user_id: int | None = None
def to_task_create(self) -> ScheduledTaskCreate: def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema.""" """Convert to generic task creation schema."""
@@ -137,9 +136,9 @@ class CreatePlaySoundTask(BaseModel):
sound_id: int sound_id: int
timezone: str = "UTC" timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None cron_expression: str | None = None
recurrence_count: Optional[int] = None recurrence_count: int | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate: def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema.""" """Convert to generic task creation schema."""
@@ -166,9 +165,9 @@ class CreatePlayPlaylistTask(BaseModel):
shuffle: bool = False shuffle: bool = False
timezone: str = "UTC" timezone: str = "UTC"
recurrence_type: RecurrenceType = RecurrenceType.NONE recurrence_type: RecurrenceType = RecurrenceType.NONE
cron_expression: Optional[str] = None cron_expression: str | None = None
recurrence_count: Optional[int] = None recurrence_count: int | None = None
expires_at: Optional[datetime] = None expires_at: datetime | None = None
def to_task_create(self) -> ScheduledTaskCreate: def to_task_create(self) -> ScheduledTaskCreate:
"""Convert to generic task creation schema.""" """Convert to generic task creation schema."""
@@ -186,4 +185,4 @@ class CreatePlayPlaylistTask(BaseModel):
cron_expression=self.cron_expression, cron_expression=self.cron_expression,
recurrence_count=self.recurrence_count, recurrence_count=self.recurrence_count,
expires_at=self.expires_at, expires_at=self.expires_at,
) )

View File

@@ -2,7 +2,7 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional from typing import Any
import pytz import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -52,7 +52,7 @@ class SchedulerService:
logger.info("Starting enhanced scheduler service...") logger.info("Starting enhanced scheduler service...")
self.scheduler.start() self.scheduler.start()
# Schedule system tasks initialization for after startup # Schedule system tasks initialization for after startup
self.scheduler.add_job( self.scheduler.add_job(
self._initialize_system_tasks, self._initialize_system_tasks,
@@ -62,7 +62,7 @@ class SchedulerService:
name="Initialize System Tasks", name="Initialize System Tasks",
replace_existing=True, replace_existing=True,
) )
# Schedule periodic cleanup and maintenance # Schedule periodic cleanup and maintenance
self.scheduler.add_job( self.scheduler.add_job(
self._maintenance_job, self._maintenance_job,
@@ -86,18 +86,18 @@ class SchedulerService:
name: str, name: str,
task_type: TaskType, task_type: TaskType,
scheduled_at: datetime, scheduled_at: datetime,
parameters: Optional[Dict[str, Any]] = None, parameters: dict[str, Any] | None = None,
user_id: Optional[int] = None, user_id: int | None = None,
timezone: str = "UTC", timezone: str = "UTC",
recurrence_type: RecurrenceType = RecurrenceType.NONE, recurrence_type: RecurrenceType = RecurrenceType.NONE,
cron_expression: Optional[str] = None, cron_expression: str | None = None,
recurrence_count: Optional[int] = None, recurrence_count: int | None = None,
expires_at: Optional[datetime] = None, expires_at: datetime | None = None,
) -> ScheduledTask: ) -> ScheduledTask:
"""Create a new scheduled task.""" """Create a new scheduled task."""
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
# Convert scheduled_at to UTC if it's in a different timezone # Convert scheduled_at to UTC if it's in a different timezone
if timezone != "UTC": if timezone != "UTC":
tz = pytz.timezone(timezone) tz = pytz.timezone(timezone)
@@ -105,7 +105,7 @@ class SchedulerService:
# Assume the datetime is in the specified timezone # Assume the datetime is in the specified timezone
scheduled_at = tz.localize(scheduled_at) scheduled_at = tz.localize(scheduled_at)
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None) scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
task_data = { task_data = {
"name": name, "name": name,
"task_type": task_type, "task_type": task_type,
@@ -118,59 +118,59 @@ class SchedulerService:
"recurrence_count": recurrence_count, "recurrence_count": recurrence_count,
"expires_at": expires_at, "expires_at": expires_at,
} }
created_task = await repo.create(task_data) created_task = await repo.create(task_data)
await self._schedule_apscheduler_job(created_task) await self._schedule_apscheduler_job(created_task)
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})") logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
return created_task return created_task
async def cancel_task(self, task_id: int) -> bool: async def cancel_task(self, task_id: int) -> bool:
"""Cancel a scheduled task.""" """Cancel a scheduled task."""
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
task = await repo.get_by_id(task_id) task = await repo.get_by_id(task_id)
if not task: if not task:
return False return False
task.status = TaskStatus.CANCELLED task.status = TaskStatus.CANCELLED
task.is_active = False task.is_active = False
await repo.update(task) await repo.update(task)
# Remove from APScheduler # Remove from APScheduler
try: try:
self.scheduler.remove_job(str(task_id)) self.scheduler.remove_job(str(task_id))
except Exception: except Exception:
pass # Job might not exist in scheduler pass # Job might not exist in scheduler
logger.info(f"Cancelled task: {task.name} ({task_id})") logger.info(f"Cancelled task: {task.name} ({task_id})")
return True return True
async def get_user_tasks( async def get_user_tasks(
self, self,
user_id: int, user_id: int,
status: Optional[TaskStatus] = None, status: TaskStatus | None = None,
task_type: Optional[TaskType] = None, task_type: TaskType | None = None,
limit: Optional[int] = None, limit: int | None = None,
offset: Optional[int] = None, offset: int | None = None,
) -> List[ScheduledTask]: ) -> list[ScheduledTask]:
"""Get tasks for a specific user.""" """Get tasks for a specific user."""
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
return await repo.get_user_tasks(user_id, status, task_type, limit, offset) return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
async def _initialize_system_tasks(self) -> None: async def _initialize_system_tasks(self) -> None:
"""Initialize system tasks and load active tasks from database.""" """Initialize system tasks and load active tasks from database."""
logger.info("Initializing system tasks...") logger.info("Initializing system tasks...")
try: try:
# Create system tasks if they don't exist # Create system tasks if they don't exist
await self._ensure_system_tasks() await self._ensure_system_tasks()
# Load all active tasks from database # Load all active tasks from database
await self._load_active_tasks() await self._load_active_tasks()
logger.info("System tasks initialized successfully") logger.info("System tasks initialized successfully")
except Exception: except Exception:
logger.exception("Failed to initialize system tasks") logger.exception("Failed to initialize system tasks")
@@ -179,24 +179,24 @@ class SchedulerService:
"""Ensure required system tasks exist.""" """Ensure required system tasks exist."""
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
# Check if daily credit recharge task exists # Check if daily credit recharge task exists
system_tasks = await repo.get_system_tasks( system_tasks = await repo.get_system_tasks(
task_type=TaskType.CREDIT_RECHARGE task_type=TaskType.CREDIT_RECHARGE,
) )
daily_recharge_exists = any( daily_recharge_exists = any(
task.recurrence_type == RecurrenceType.DAILY task.recurrence_type == RecurrenceType.DAILY
and task.is_active and task.is_active
for task in system_tasks for task in system_tasks
) )
if not daily_recharge_exists: if not daily_recharge_exists:
# Create daily credit recharge task # Create daily credit recharge task
tomorrow_midnight = datetime.utcnow().replace( tomorrow_midnight = datetime.utcnow().replace(
hour=0, minute=0, second=0, microsecond=0 hour=0, minute=0, second=0, microsecond=0,
) + timedelta(days=1) ) + timedelta(days=1)
task_data = { task_data = {
"name": "Daily Credit Recharge", "name": "Daily Credit Recharge",
"task_type": TaskType.CREDIT_RECHARGE, "task_type": TaskType.CREDIT_RECHARGE,
@@ -204,41 +204,41 @@ class SchedulerService:
"recurrence_type": RecurrenceType.DAILY, "recurrence_type": RecurrenceType.DAILY,
"parameters": {}, "parameters": {},
} }
await repo.create(task_data) await repo.create(task_data)
logger.info("Created system daily credit recharge task") logger.info("Created system daily credit recharge task")
async def _load_active_tasks(self) -> None: async def _load_active_tasks(self) -> None:
"""Load all active tasks from database into scheduler.""" """Load all active tasks from database into scheduler."""
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
active_tasks = await repo.get_active_tasks() active_tasks = await repo.get_active_tasks()
for task in active_tasks: for task in active_tasks:
await self._schedule_apscheduler_job(task) await self._schedule_apscheduler_job(task)
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler") logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None: async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
"""Schedule a task in APScheduler.""" """Schedule a task in APScheduler."""
job_id = str(task.id) job_id = str(task.id)
# Remove existing job if it exists # Remove existing job if it exists
try: try:
self.scheduler.remove_job(job_id) self.scheduler.remove_job(job_id)
except Exception: except Exception:
pass pass
# Don't schedule if task is not active or already completed/failed # Don't schedule if task is not active or already completed/failed
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]: if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
return return
# Create trigger based on recurrence type # Create trigger based on recurrence type
trigger = self._create_trigger(task) trigger = self._create_trigger(task)
if not trigger: if not trigger:
logger.warning(f"Could not create trigger for task {task.id}") logger.warning(f"Could not create trigger for task {task.id}")
return return
# Schedule the job # Schedule the job
self.scheduler.add_job( self.scheduler.add_job(
self._execute_task, self._execute_task,
@@ -248,76 +248,76 @@ class SchedulerService:
name=task.name, name=task.name,
replace_existing=True, replace_existing=True,
) )
logger.debug(f"Scheduled APScheduler job for task {task.id}") logger.debug(f"Scheduled APScheduler job for task {task.id}")
def _create_trigger(self, task: ScheduledTask): def _create_trigger(self, task: ScheduledTask):
"""Create APScheduler trigger based on task configuration.""" """Create APScheduler trigger based on task configuration."""
tz = pytz.timezone(task.timezone) tz = pytz.timezone(task.timezone)
if task.recurrence_type == RecurrenceType.NONE: if task.recurrence_type == RecurrenceType.NONE:
return DateTrigger(run_date=task.scheduled_at, timezone=tz) return DateTrigger(run_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression: if task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
return CronTrigger.from_crontab(task.cron_expression, timezone=tz) return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
elif task.recurrence_type == RecurrenceType.HOURLY: if task.recurrence_type == RecurrenceType.HOURLY:
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz) return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.DAILY: if task.recurrence_type == RecurrenceType.DAILY:
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz) return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.WEEKLY: if task.recurrence_type == RecurrenceType.WEEKLY:
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz) return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.MONTHLY: if task.recurrence_type == RecurrenceType.MONTHLY:
# Use cron trigger for monthly (more reliable than interval) # Use cron trigger for monthly (more reliable than interval)
scheduled_time = task.scheduled_at scheduled_time = task.scheduled_at
return CronTrigger( return CronTrigger(
day=scheduled_time.day, day=scheduled_time.day,
hour=scheduled_time.hour, hour=scheduled_time.hour,
minute=scheduled_time.minute, minute=scheduled_time.minute,
timezone=tz timezone=tz,
) )
elif task.recurrence_type == RecurrenceType.YEARLY: if task.recurrence_type == RecurrenceType.YEARLY:
scheduled_time = task.scheduled_at scheduled_time = task.scheduled_at
return CronTrigger( return CronTrigger(
month=scheduled_time.month, month=scheduled_time.month,
day=scheduled_time.day, day=scheduled_time.day,
hour=scheduled_time.hour, hour=scheduled_time.hour,
minute=scheduled_time.minute, minute=scheduled_time.minute,
timezone=tz timezone=tz,
) )
return None return None
async def _execute_task(self, task_id: int) -> None: async def _execute_task(self, task_id: int) -> None:
"""Execute a scheduled task.""" """Execute a scheduled task."""
task_id_str = str(task_id) task_id_str = str(task_id)
# Prevent concurrent execution of the same task # Prevent concurrent execution of the same task
if task_id_str in self._running_tasks: if task_id_str in self._running_tasks:
logger.warning(f"Task {task_id} is already running, skipping execution") logger.warning(f"Task {task_id} is already running, skipping execution")
return return
self._running_tasks.add(task_id_str) self._running_tasks.add(task_id_str)
try: try:
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
# Get fresh task data # Get fresh task data
task = await repo.get_by_id(task_id) task = await repo.get_by_id(task_id)
if not task: if not task:
logger.warning(f"Task {task_id} not found") logger.warning(f"Task {task_id} not found")
return return
# Check if task is still active and pending # Check if task is still active and pending
if not task.is_active or task.status != TaskStatus.PENDING: if not task.is_active or task.status != TaskStatus.PENDING:
logger.info(f"Task {task_id} is not active or not pending, skipping") logger.info(f"Task {task_id} is not active or not pending, skipping")
return return
# Check if task has expired # Check if task has expired
if task.is_expired(): if task.is_expired():
logger.info(f"Task {task_id} has expired, marking as cancelled") logger.info(f"Task {task_id} has expired, marking as cancelled")
@@ -325,78 +325,78 @@ class SchedulerService:
task.is_active = False task.is_active = False
await repo.update(task) await repo.update(task)
return return
# Mark task as running # Mark task as running
await repo.mark_as_running(task) await repo.mark_as_running(task)
# Execute the task # Execute the task
try: try:
handler_registry = TaskHandlerRegistry( handler_registry = TaskHandlerRegistry(
session, self.db_session_factory, self.credit_service, self.player_service session, self.db_session_factory, self.credit_service, self.player_service,
) )
await handler_registry.execute_task(task) await handler_registry.execute_task(task)
# Calculate next execution time for recurring tasks # Calculate next execution time for recurring tasks
next_execution_at = None next_execution_at = None
if task.should_repeat(): if task.should_repeat():
next_execution_at = self._calculate_next_execution(task) next_execution_at = self._calculate_next_execution(task)
# Mark as completed # Mark as completed
await repo.mark_as_completed(task, next_execution_at) await repo.mark_as_completed(task, next_execution_at)
# Reschedule if recurring # Reschedule if recurring
if next_execution_at and task.should_repeat(): if next_execution_at and task.should_repeat():
# Refresh task to get updated data # Refresh task to get updated data
await session.refresh(task) await session.refresh(task)
await self._schedule_apscheduler_job(task) await self._schedule_apscheduler_job(task)
except Exception as e: except Exception as e:
await repo.mark_as_failed(task, str(e)) await repo.mark_as_failed(task, str(e))
logger.exception(f"Task {task_id} execution failed: {str(e)}") logger.exception(f"Task {task_id} execution failed: {e!s}")
finally: finally:
self._running_tasks.discard(task_id_str) self._running_tasks.discard(task_id_str)
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]: def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None:
"""Calculate the next execution time for a recurring task.""" """Calculate the next execution time for a recurring task."""
now = datetime.utcnow() now = datetime.utcnow()
if task.recurrence_type == RecurrenceType.HOURLY: if task.recurrence_type == RecurrenceType.HOURLY:
return now + timedelta(hours=1) return now + timedelta(hours=1)
elif task.recurrence_type == RecurrenceType.DAILY: if task.recurrence_type == RecurrenceType.DAILY:
return now + timedelta(days=1) return now + timedelta(days=1)
elif task.recurrence_type == RecurrenceType.WEEKLY: if task.recurrence_type == RecurrenceType.WEEKLY:
return now + timedelta(weeks=1) return now + timedelta(weeks=1)
elif task.recurrence_type == RecurrenceType.MONTHLY: if task.recurrence_type == RecurrenceType.MONTHLY:
# Add approximately one month # Add approximately one month
return now + timedelta(days=30) return now + timedelta(days=30)
elif task.recurrence_type == RecurrenceType.YEARLY: if task.recurrence_type == RecurrenceType.YEARLY:
return now + timedelta(days=365) return now + timedelta(days=365)
return None return None
async def _maintenance_job(self) -> None: async def _maintenance_job(self) -> None:
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues.""" """Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
try: try:
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
# Handle expired tasks # Handle expired tasks
expired_tasks = await repo.get_expired_tasks() expired_tasks = await repo.get_expired_tasks()
for task in expired_tasks: for task in expired_tasks:
task.status = TaskStatus.CANCELLED task.status = TaskStatus.CANCELLED
task.is_active = False task.is_active = False
await repo.update(task) await repo.update(task)
# Remove from scheduler # Remove from scheduler
try: try:
self.scheduler.remove_job(str(task.id)) self.scheduler.remove_job(str(task.id))
except Exception: except Exception:
pass pass
if expired_tasks: if expired_tasks:
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks") logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
# Handle any missed recurring tasks # Handle any missed recurring tasks
due_recurring = await repo.get_recurring_tasks_due_for_next_execution() due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
for task in due_recurring: for task in due_recurring:
@@ -405,9 +405,9 @@ class SchedulerService:
task.scheduled_at = task.next_execution_at or datetime.utcnow() task.scheduled_at = task.next_execution_at or datetime.utcnow()
await repo.update(task) await repo.update(task)
await self._schedule_apscheduler_job(task) await self._schedule_apscheduler_job(task)
if due_recurring: if due_recurring:
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks") logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
except Exception: except Exception:
logger.exception("Maintenance job failed") logger.exception("Maintenance job failed")

View File

@@ -1,6 +1,5 @@
"""Task execution handlers for different task types.""" """Task execution handlers for different task types."""
from typing import Any, Dict, Optional
from collections.abc import Callable from collections.abc import Callable
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -18,7 +17,6 @@ logger = get_logger(__name__)
class TaskExecutionError(Exception): class TaskExecutionError(Exception):
"""Exception raised when task execution fails.""" """Exception raised when task execution fails."""
pass
class TaskHandlerRegistry: class TaskHandlerRegistry:
@@ -58,8 +56,8 @@ class TaskHandlerRegistry:
await handler(task) await handler(task)
logger.info(f"Task {task.id} executed successfully") logger.info(f"Task {task.id} executed successfully")
except Exception as e: except Exception as e:
logger.exception(f"Task {task.id} execution failed: {str(e)}") logger.exception(f"Task {task.id} execution failed: {e!s}")
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e raise TaskExecutionError(f"Task execution failed: {e!s}") from e
async def _handle_credit_recharge(self, task: ScheduledTask) -> None: async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
"""Handle credit recharge task.""" """Handle credit recharge task."""
@@ -72,7 +70,7 @@ class TaskHandlerRegistry:
user_id_int = int(user_id) user_id_int = int(user_id)
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e
stats = await self.credit_service.recharge_user_credits(user_id_int) stats = await self.credit_service.recharge_user_credits(user_id_int)
logger.info(f"Recharged credits for user {user_id}: {stats}") logger.info(f"Recharged credits for user {user_id}: {stats}")
else: else:
@@ -105,7 +103,7 @@ class TaskHandlerRegistry:
logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})") logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})")
except Exception as e: except Exception as e:
# Convert HTTP exceptions or credit errors to task execution errors # Convert HTTP exceptions or credit errors to task execution errors
raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e
else: else:
# System task: play without credit deduction # System task: play without credit deduction
sound = await self.sound_repository.get_by_id(sound_id_int) sound = await self.sound_repository.get_by_id(sound_id_int)
@@ -116,10 +114,10 @@ class TaskHandlerRegistry:
vlc_service = VLCPlayerService(self.db_session_factory) vlc_service = VLCPlayerService(self.db_session_factory)
success = await vlc_service.play_sound(sound) success = await vlc_service.play_sound(sound)
if not success: if not success:
raise TaskExecutionError(f"Failed to play sound {sound.filename}") raise TaskExecutionError(f"Failed to play sound {sound.filename}")
logger.info(f"Played sound {sound.filename} via scheduled system task") logger.info(f"Played sound {sound.filename} via scheduled system task")
async def _handle_play_playlist(self, task: ScheduledTask) -> None: async def _handle_play_playlist(self, task: ScheduledTask) -> None:
@@ -157,4 +155,4 @@ class TaskHandlerRegistry:
# Start playing # Start playing
await self.player_service.play() await self.player_service.play()
logger.info(f"Started playing playlist {playlist.name} via scheduled task") logger.info(f"Started playing playlist {playlist.name} via scheduled task")

View File

@@ -238,13 +238,13 @@ class VLCPlayerService:
return return
logger.info("Recording play count for sound %s", sound_id) logger.info("Recording play count for sound %s", sound_id)
# Initialize variables for WebSocket event # Initialize variables for WebSocket event
old_count = 0 old_count = 0
sound = None sound = None
admin_user_id = None admin_user_id = None
admin_user_name = None admin_user_name = None
try: try:
async with self.db_session_factory() as session: async with self.db_session_factory() as session:
sound_repo = SoundRepository(session) sound_repo = SoundRepository(session)

View File

@@ -7,15 +7,16 @@ from datetime import datetime
from app.core.database import get_session_factory from app.core.database import get_session_factory
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
async def check_tasks(): async def check_tasks():
session_factory = get_session_factory() session_factory = get_session_factory()
async with session_factory() as session: async with session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
# Get all tasks # Get all tasks
all_tasks = await repo.get_all(limit=20) all_tasks = await repo.get_all(limit=20)
print("All tasks in database:") print("All tasks in database:")
print("=" * 80) print("=" * 80)
for task in all_tasks: for task in all_tasks:
@@ -32,14 +33,14 @@ async def check_tasks():
print(f"Error: {task.error_message}") print(f"Error: {task.error_message}")
print(f"Parameters: {task.parameters}") print(f"Parameters: {task.parameters}")
print("-" * 40) print("-" * 40)
# Check specifically for pending tasks # Check specifically for pending tasks
print(f"\nCurrent time: {datetime.utcnow()}") print(f"\nCurrent time: {datetime.utcnow()}")
print("\nPending tasks:") print("\nPending tasks:")
from app.models.scheduled_task import TaskStatus from app.models.scheduled_task import TaskStatus
pending_tasks = await repo.get_all(limit=10) pending_tasks = await repo.get_all(limit=10)
pending_tasks = [t for t in pending_tasks if t.status == TaskStatus.PENDING and t.is_active] pending_tasks = [t for t in pending_tasks if t.status == TaskStatus.PENDING and t.is_active]
if not pending_tasks: if not pending_tasks:
print("No pending tasks found") print("No pending tasks found")
else: else:
@@ -48,4 +49,4 @@ async def check_tasks():
print(f"- {task.name} (ID: {task.id}): scheduled for {task.scheduled_at} (in {time_diff})") print(f"- {task.name} (ID: {task.id}): scheduled for {task.scheduled_at} (in {time_diff})")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(check_tasks()) asyncio.run(check_tasks())

View File

@@ -5,18 +5,19 @@ import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.core.database import get_session_factory from app.core.database import get_session_factory
from app.models.scheduled_task import RecurrenceType, TaskType
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
from app.models.scheduled_task import TaskType, RecurrenceType
async def create_future_task(): async def create_future_task():
session_factory = get_session_factory() session_factory = get_session_factory()
# Create a task for 2 minutes from now # Create a task for 2 minutes from now
future_time = datetime.utcnow() + timedelta(minutes=2) future_time = datetime.utcnow() + timedelta(minutes=2)
async with session_factory() as session: async with session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
task_data = { task_data = {
"name": f"Future Task {future_time.strftime('%H:%M:%S')}", "name": f"Future Task {future_time.strftime('%H:%M:%S')}",
"task_type": TaskType.PLAY_SOUND, "task_type": TaskType.PLAY_SOUND,
@@ -26,11 +27,11 @@ async def create_future_task():
"user_id": 1, "user_id": 1,
"recurrence_type": RecurrenceType.NONE, "recurrence_type": RecurrenceType.NONE,
} }
task = await repo.create(task_data) task = await repo.create(task_data)
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}") print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
print(f"Current time: {datetime.utcnow()}") print(f"Current time: {datetime.utcnow()}")
print(f"Task will execute in: {future_time - datetime.utcnow()}") print(f"Task will execute in: {future_time - datetime.utcnow()}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(create_future_task()) asyncio.run(create_future_task())

View File

@@ -4,21 +4,21 @@
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from app.core.database import get_session_factory
from app.main import get_global_scheduler_service from app.main import get_global_scheduler_service
from app.models.scheduled_task import TaskType, RecurrenceType from app.models.scheduled_task import RecurrenceType, TaskType
async def test_api_task_creation(): async def test_api_task_creation():
"""Test creating a task through the scheduler service (simulates API call).""" """Test creating a task through the scheduler service (simulates API call)."""
try: try:
scheduler_service = get_global_scheduler_service() scheduler_service = get_global_scheduler_service()
# Create a task for 2 minutes from now # Create a task for 2 minutes from now
future_time = datetime.utcnow() + timedelta(minutes=2) future_time = datetime.utcnow() + timedelta(minutes=2)
print(f"Creating task scheduled for: {future_time}") print(f"Creating task scheduled for: {future_time}")
print(f"Current time: {datetime.utcnow()}") print(f"Current time: {datetime.utcnow()}")
task = await scheduler_service.create_task( task = await scheduler_service.create_task(
name=f"API Test Task {future_time.strftime('%H:%M:%S')}", name=f"API Test Task {future_time.strftime('%H:%M:%S')}",
task_type=TaskType.PLAY_SOUND, task_type=TaskType.PLAY_SOUND,
@@ -28,13 +28,13 @@ async def test_api_task_creation():
timezone="UTC", timezone="UTC",
recurrence_type=RecurrenceType.NONE, recurrence_type=RecurrenceType.NONE,
) )
print(f"Created task: {task.name} (ID: {task.id})") print(f"Created task: {task.name} (ID: {task.id})")
print(f"Task will execute in: {future_time - datetime.utcnow()}") print(f"Task will execute in: {future_time - datetime.utcnow()}")
print("Task should be automatically scheduled in APScheduler!") print("Task should be automatically scheduled in APScheduler!")
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(test_api_task_creation()) asyncio.run(test_api_task_creation())

View File

@@ -5,15 +5,16 @@ import asyncio
from datetime import datetime from datetime import datetime
from app.core.database import get_session_factory from app.core.database import get_session_factory
from app.models.scheduled_task import RecurrenceType, TaskType
from app.repositories.scheduled_task import ScheduledTaskRepository from app.repositories.scheduled_task import ScheduledTaskRepository
from app.models.scheduled_task import TaskType, RecurrenceType
async def create_test_task(): async def create_test_task():
session_factory = get_session_factory() session_factory = get_session_factory()
async with session_factory() as session: async with session_factory() as session:
repo = ScheduledTaskRepository(session) repo = ScheduledTaskRepository(session)
task_data = { task_data = {
"name": "Live Test Task", "name": "Live Test Task",
"task_type": TaskType.PLAY_SOUND, "task_type": TaskType.PLAY_SOUND,
@@ -23,9 +24,9 @@ async def create_test_task():
"user_id": 1, "user_id": 1,
"recurrence_type": RecurrenceType.NONE, "recurrence_type": RecurrenceType.NONE,
} }
task = await repo.create(task_data) task = await repo.create(task_data)
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}") print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(create_test_task()) asyncio.run(create_test_task())

View File

@@ -351,11 +351,11 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
@pytest.fixture @pytest.fixture
def test_user_id(test_user: User): def test_user_id(test_user: User):
"""Get test user ID.""" """Get test user ID."""
return test_user.id return test_user.id
@pytest.fixture @pytest.fixture
def test_sound_id(): def test_sound_id():
"""Create a test sound ID.""" """Create a test sound ID."""
import uuid import uuid
@@ -364,7 +364,7 @@ def test_sound_id():
@pytest.fixture @pytest.fixture
def test_playlist_id(): def test_playlist_id():
"""Create a test playlist ID.""" """Create a test playlist ID."""
import uuid import uuid
return uuid.uuid4() return uuid.uuid4()

View File

@@ -3,8 +3,6 @@
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
import pytest
from app.models.scheduled_task import ( from app.models.scheduled_task import (
RecurrenceType, RecurrenceType,
ScheduledTask, ScheduledTask,
@@ -217,4 +215,4 @@ class TestScheduledTaskModel:
assert RecurrenceType.WEEKLY == "weekly" assert RecurrenceType.WEEKLY == "weekly"
assert RecurrenceType.MONTHLY == "monthly" assert RecurrenceType.MONTHLY == "monthly"
assert RecurrenceType.YEARLY == "yearly" assert RecurrenceType.YEARLY == "yearly"
assert RecurrenceType.CRON == "cron" assert RecurrenceType.CRON == "cron"

View File

@@ -2,7 +2,6 @@
import uuid import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -491,4 +490,4 @@ class TestScheduledTaskRepository:
updated_task = await repository.get_by_id(sample_task.id) updated_task = await repository.get_by_id(sample_task.id)
assert updated_task.status == TaskStatus.FAILED assert updated_task.status == TaskStatus.FAILED
# Non-recurring tasks should be deactivated on failure # Non-recurring tasks should be deactivated on failure
assert updated_task.is_active is False assert updated_task.is_active is False

View File

@@ -51,7 +51,7 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test creating a scheduled task.""" """Test creating a scheduled task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule: with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
assert task.id is not None assert task.id is not None
@@ -68,7 +68,7 @@ class TestSchedulerService:
test_user_id: uuid.UUID, test_user_id: uuid.UUID,
): ):
"""Test creating a user task.""" """Test creating a user task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task( task = await scheduler_service.create_task(
user_id=test_user_id, user_id=test_user_id,
**sample_task_data, **sample_task_data,
@@ -83,7 +83,7 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test creating a system task.""" """Test creating a system task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
assert task.user_id is None assert task.user_id is None
@@ -95,7 +95,7 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test creating a recurring task.""" """Test creating a recurring task."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task( task = await scheduler_service.create_task(
recurrence_type=RecurrenceType.DAILY, recurrence_type=RecurrenceType.DAILY,
recurrence_count=5, recurrence_count=5,
@@ -114,11 +114,11 @@ class TestSchedulerService:
"""Test creating task with timezone conversion.""" """Test creating task with timezone conversion."""
# Use a specific datetime for testing # Use a specific datetime for testing
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY
sample_task_data["scheduled_at"] = ny_time sample_task_data["scheduled_at"] = ny_time
sample_task_data["timezone"] = "America/New_York" sample_task_data["timezone"] = "America/New_York"
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# The scheduled_at should be converted to UTC # The scheduled_at should be converted to UTC
@@ -134,11 +134,11 @@ class TestSchedulerService:
): ):
"""Test cancelling a task.""" """Test cancelling a task."""
# Create a task first # Create a task first
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Mock the scheduler remove_job method # Mock the scheduler remove_job method
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove: with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
result = await scheduler_service.cancel_task(task.id) result = await scheduler_service.cancel_task(task.id)
assert result is True assert result is True
@@ -167,7 +167,7 @@ class TestSchedulerService:
test_user_id: uuid.UUID, test_user_id: uuid.UUID,
): ):
"""Test getting user tasks.""" """Test getting user tasks."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
# Create user task # Create user task
await scheduler_service.create_task( await scheduler_service.create_task(
user_id=test_user_id, user_id=test_user_id,
@@ -188,12 +188,12 @@ class TestSchedulerService:
): ):
"""Test ensuring system tasks exist.""" """Test ensuring system tasks exist."""
# Mock the repository to return no existing tasks # Mock the repository to return no existing tasks
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
mock_get.return_value = [] mock_get.return_value = []
await scheduler_service._ensure_system_tasks() await scheduler_service._ensure_system_tasks()
# Should create daily credit recharge task # Should create daily credit recharge task
mock_create.assert_called_once() mock_create.assert_called_once()
created_task = mock_create.call_args[0][0] created_task = mock_create.call_args[0][0]
@@ -213,13 +213,13 @@ class TestSchedulerService:
recurrence_type=RecurrenceType.DAILY, recurrence_type=RecurrenceType.DAILY,
is_active=True, is_active=True,
) )
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get:
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create: with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create:
mock_get.return_value = [existing_task] mock_get.return_value = [existing_task]
await scheduler_service._ensure_system_tasks() await scheduler_service._ensure_system_tasks()
# Should not create new task # Should not create new task
mock_create.assert_not_called() mock_create.assert_not_called()
@@ -294,7 +294,7 @@ class TestSchedulerService:
): ):
"""Test calculating next execution time.""" """Test calculating next execution time."""
now = datetime.utcnow() now = datetime.utcnow()
# Test different recurrence types # Test different recurrence types
test_cases = [ test_cases = [
(RecurrenceType.HOURLY, timedelta(hours=1)), (RecurrenceType.HOURLY, timedelta(hours=1)),
@@ -312,7 +312,7 @@ class TestSchedulerService:
recurrence_type=recurrence_type, recurrence_type=recurrence_type,
) )
with patch('app.services.scheduler.datetime') as mock_datetime: with patch("app.services.scheduler.datetime") as mock_datetime:
mock_datetime.utcnow.return_value = now mock_datetime.utcnow.return_value = now
next_execution = scheduler_service._calculate_next_execution(task) next_execution = scheduler_service._calculate_next_execution(task)
@@ -335,7 +335,7 @@ class TestSchedulerService:
next_execution = scheduler_service._calculate_next_execution(task) next_execution = scheduler_service._calculate_next_execution(task)
assert next_execution is None assert next_execution is None
@patch('app.services.task_handlers.TaskHandlerRegistry') @patch("app.services.task_handlers.TaskHandlerRegistry")
async def test_execute_task_success( async def test_execute_task_success(
self, self,
mock_handler_class, mock_handler_class,
@@ -344,7 +344,7 @@ class TestSchedulerService:
): ):
"""Test successful task execution.""" """Test successful task execution."""
# Create task # Create task
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Mock handler registry # Mock handler registry
@@ -365,7 +365,7 @@ class TestSchedulerService:
assert updated_task.status == TaskStatus.COMPLETED assert updated_task.status == TaskStatus.COMPLETED
assert updated_task.executions_count == 1 assert updated_task.executions_count == 1
@patch('app.services.task_handlers.TaskHandlerRegistry') @patch("app.services.task_handlers.TaskHandlerRegistry")
async def test_execute_task_failure( async def test_execute_task_failure(
self, self,
mock_handler_class, mock_handler_class,
@@ -374,7 +374,7 @@ class TestSchedulerService:
): ):
"""Test task execution failure.""" """Test task execution failure."""
# Create task # Create task
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Mock handler to raise exception # Mock handler to raise exception
@@ -409,8 +409,8 @@ class TestSchedulerService:
"""Test executing expired task.""" """Test executing expired task."""
# Create expired task # Create expired task
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1) sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Execute task # Execute task
@@ -430,20 +430,20 @@ class TestSchedulerService:
sample_task_data: dict, sample_task_data: dict,
): ):
"""Test prevention of concurrent task execution.""" """Test prevention of concurrent task execution."""
with patch.object(scheduler_service, '_schedule_apscheduler_job'): with patch.object(scheduler_service, "_schedule_apscheduler_job"):
task = await scheduler_service.create_task(**sample_task_data) task = await scheduler_service.create_task(**sample_task_data)
# Add task to running set # Add task to running set
scheduler_service._running_tasks.add(str(task.id)) scheduler_service._running_tasks.add(str(task.id))
# Try to execute - should return without doing anything # Try to execute - should return without doing anything
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class: with patch("app.services.task_handlers.TaskHandlerRegistry") as mock_handler_class:
await scheduler_service._execute_task(task.id) await scheduler_service._execute_task(task.id)
# Handler should not be called # Handler should not be called
mock_handler_class.assert_not_called() mock_handler_class.assert_not_called()
@patch('app.repositories.scheduled_task.ScheduledTaskRepository') @patch("app.repositories.scheduled_task.ScheduledTaskRepository")
async def test_maintenance_job_expired_tasks( async def test_maintenance_job_expired_tasks(
self, self,
mock_repo_class, mock_repo_class,
@@ -453,22 +453,22 @@ class TestSchedulerService:
# Mock expired task # Mock expired task
expired_task = MagicMock() expired_task = MagicMock()
expired_task.id = uuid.uuid4() expired_task.id = uuid.uuid4()
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo.get_expired_tasks.return_value = [expired_task] mock_repo.get_expired_tasks.return_value = [expired_task]
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [] mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove: with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove:
await scheduler_service._maintenance_job() await scheduler_service._maintenance_job()
# Should mark as cancelled and remove from scheduler # Should mark as cancelled and remove from scheduler
assert expired_task.status == TaskStatus.CANCELLED assert expired_task.status == TaskStatus.CANCELLED
assert expired_task.is_active is False assert expired_task.is_active is False
mock_repo.update.assert_called_with(expired_task) mock_repo.update.assert_called_with(expired_task)
mock_remove.assert_called_once_with(str(expired_task.id)) mock_remove.assert_called_once_with(str(expired_task.id))
@patch('app.repositories.scheduled_task.ScheduledTaskRepository') @patch("app.repositories.scheduled_task.ScheduledTaskRepository")
async def test_maintenance_job_due_recurring_tasks( async def test_maintenance_job_due_recurring_tasks(
self, self,
mock_repo_class, mock_repo_class,
@@ -479,17 +479,17 @@ class TestSchedulerService:
due_task = MagicMock() due_task = MagicMock()
due_task.should_repeat.return_value = True due_task.should_repeat.return_value = True
due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5) due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5)
mock_repo = AsyncMock() mock_repo = AsyncMock()
mock_repo.get_expired_tasks.return_value = [] mock_repo.get_expired_tasks.return_value = []
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task] mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
mock_repo_class.return_value = mock_repo mock_repo_class.return_value = mock_repo
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule: with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule:
await scheduler_service._maintenance_job() await scheduler_service._maintenance_job()
# Should reset to pending and reschedule # Should reset to pending and reschedule
assert due_task.status == TaskStatus.PENDING assert due_task.status == TaskStatus.PENDING
assert due_task.scheduled_at == due_task.next_execution_at assert due_task.scheduled_at == due_task.next_execution_at
mock_repo.update.assert_called_with(due_task) mock_repo.update.assert_called_with(due_task)
mock_schedule.assert_called_once_with(due_task) mock_schedule.assert_called_once_with(due_task)

View File

@@ -133,8 +133,8 @@ class TestTaskHandlerRegistry:
mock_sound.id = test_sound_id mock_sound.id = test_sound_id
mock_sound.filename = "test_sound.mp3" mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound): with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class: with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_vlc_class.return_value = mock_vlc_service mock_vlc_class.return_value = mock_vlc_service
@@ -186,7 +186,7 @@ class TestTaskHandlerRegistry:
parameters={"sound_id": str(test_sound_id)}, parameters={"sound_id": str(test_sound_id)},
) )
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None): with patch.object(task_registry.sound_repository, "get_by_id", return_value=None):
with pytest.raises(TaskExecutionError, match="Sound not found"): with pytest.raises(TaskExecutionError, match="Sound not found"):
await task_registry.execute_task(task) await task_registry.execute_task(task)
@@ -206,8 +206,8 @@ class TestTaskHandlerRegistry:
mock_sound = MagicMock() mock_sound = MagicMock()
mock_sound.filename = "test_sound.mp3" mock_sound.filename = "test_sound.mp3"
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound): with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound):
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class: with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class:
mock_vlc_service = AsyncMock() mock_vlc_service = AsyncMock()
mock_vlc_class.return_value = mock_vlc_service mock_vlc_class.return_value = mock_vlc_service
@@ -238,7 +238,7 @@ class TestTaskHandlerRegistry:
mock_playlist.id = test_playlist_id mock_playlist.id = test_playlist_id
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id) task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
@@ -264,7 +264,7 @@ class TestTaskHandlerRegistry:
mock_playlist = MagicMock() mock_playlist = MagicMock()
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
# Should use default values # Should use default values
@@ -314,7 +314,7 @@ class TestTaskHandlerRegistry:
parameters={"playlist_id": str(test_playlist_id)}, parameters={"playlist_id": str(test_playlist_id)},
) )
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=None):
with pytest.raises(TaskExecutionError, match="Playlist not found"): with pytest.raises(TaskExecutionError, match="Playlist not found"):
await task_registry.execute_task(task) await task_registry.execute_task(task)
@@ -327,7 +327,7 @@ class TestTaskHandlerRegistry:
"""Test play playlist task with various valid play modes.""" """Test play playlist task with various valid play modes."""
mock_playlist = MagicMock() mock_playlist = MagicMock()
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
valid_modes = ["continuous", "loop", "loop_one", "random", "single"] valid_modes = ["continuous", "loop", "loop_one", "random", "single"]
for mode in valid_modes: for mode in valid_modes:
@@ -341,7 +341,7 @@ class TestTaskHandlerRegistry:
}, },
) )
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
mock_player_service.set_mode.assert_called_with(mode) mock_player_service.set_mode.assert_called_with(mode)
@@ -368,7 +368,7 @@ class TestTaskHandlerRegistry:
mock_playlist = MagicMock() mock_playlist = MagicMock()
mock_playlist.name = "Test Playlist" mock_playlist.name = "Test Playlist"
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist):
await task_registry.execute_task(task) await task_registry.execute_task(task)
# Should not call set_mode for invalid mode # Should not call set_mode for invalid mode
@@ -421,4 +421,4 @@ class TestTaskHandlerRegistry:
TaskType.PLAY_SOUND, TaskType.PLAY_SOUND,
TaskType.PLAY_PLAYLIST, TaskType.PLAY_PLAYLIST,
} }
assert set(registry._handlers.keys()) == expected_handlers assert set(registry._handlers.keys()) == expected_handlers