diff --git a/app/api/v1/scheduler.py b/app/api/v1/scheduler.py index cd4f85a..c935ece 100644 --- a/app/api/v1/scheduler.py +++ b/app/api/v1/scheduler.py @@ -1,7 +1,5 @@ """API endpoints for scheduled task management.""" -from datetime import datetime -from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query from sqlmodel.ext.asyncio.session import AsyncSession @@ -11,7 +9,7 @@ from app.core.dependencies import ( get_admin_user, get_current_active_user, ) -from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType +from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType from app.models.user import User from app.schemas.scheduler import ( ScheduledTaskCreate, @@ -54,15 +52,15 @@ async def create_task( raise HTTPException(status_code=400, detail=str(e)) -@router.get("/tasks", response_model=List[ScheduledTaskResponse]) +@router.get("/tasks", response_model=list[ScheduledTaskResponse]) async def get_user_tasks( - status: Optional[TaskStatus] = Query(None, description="Filter by task status"), - task_type: Optional[TaskType] = Query(None, description="Filter by task type"), - limit: Optional[int] = Query(50, description="Maximum number of tasks to return"), - offset: Optional[int] = Query(0, description="Number of tasks to skip"), + status: TaskStatus | None = Query(None, description="Filter by task status"), + task_type: TaskType | None = Query(None, description="Filter by task type"), + limit: int | None = Query(50, description="Maximum number of tasks to return"), + offset: int | None = Query(0, description="Number of tasks to skip"), current_user: User = Depends(get_current_active_user), scheduler_service: SchedulerService = Depends(get_scheduler_service), -) -> List[ScheduledTask]: +) -> list[ScheduledTask]: """Get user's scheduled tasks.""" return await scheduler_service.get_user_tasks( user_id=current_user.id, @@ -81,17 +79,17 @@ async def get_task( ) -> ScheduledTask: """Get a specific scheduled task.""" from app.repositories.scheduled_task import ScheduledTaskRepository - + repo = ScheduledTaskRepository(db_session) task = await repo.get_by_id(task_id) - + if not task: raise HTTPException(status_code=404, detail="Task not found") - + # Check if user owns the task or is admin if task.user_id != current_user.id and not current_user.is_admin: raise HTTPException(status_code=403, detail="Access denied") - + return task @@ -104,22 +102,22 @@ async def update_task( ) -> ScheduledTask: """Update a scheduled task.""" from app.repositories.scheduled_task import ScheduledTaskRepository - + repo = ScheduledTaskRepository(db_session) task = await repo.get_by_id(task_id) - + if not task: raise HTTPException(status_code=404, detail="Task not found") - + # Check if user owns the task or is admin if task.user_id != current_user.id and not current_user.is_admin: raise HTTPException(status_code=403, detail="Access denied") - + # Update task fields update_data = task_update.model_dump(exclude_unset=True) for field, value in update_data.items(): setattr(task, field, value) - + updated_task = await repo.update(task) return updated_task @@ -133,72 +131,72 @@ async def cancel_task( ) -> dict: """Cancel a scheduled task.""" from app.repositories.scheduled_task import ScheduledTaskRepository - + repo = ScheduledTaskRepository(db_session) task = await repo.get_by_id(task_id) - + if not task: raise HTTPException(status_code=404, detail="Task not found") - + # Check if user owns the task or is admin if task.user_id != current_user.id and not current_user.is_admin: raise HTTPException(status_code=403, detail="Access denied") - + success = await scheduler_service.cancel_task(task_id) if not success: raise HTTPException(status_code=400, detail="Failed to cancel task") - + return {"message": "Task cancelled successfully"} # Admin-only endpoints -@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse]) +@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse]) async def get_all_tasks( - status: Optional[TaskStatus] = Query(None, description="Filter by task status"), - task_type: Optional[TaskType] = Query(None, description="Filter by task type"), - limit: Optional[int] = Query(100, description="Maximum number of tasks to return"), - offset: Optional[int] = Query(0, description="Number of tasks to skip"), + status: TaskStatus | None = Query(None, description="Filter by task status"), + task_type: TaskType | None = Query(None, description="Filter by task type"), + limit: int | None = Query(100, description="Maximum number of tasks to return"), + offset: int | None = Query(0, description="Number of tasks to skip"), current_user: User = Depends(get_admin_user), db_session: AsyncSession = Depends(get_db), -) -> List[ScheduledTask]: +) -> list[ScheduledTask]: """Get all scheduled tasks (admin only).""" from app.repositories.scheduled_task import ScheduledTaskRepository - + repo = ScheduledTaskRepository(db_session) - + # Get all tasks with pagination and filtering from sqlmodel import select - + statement = select(ScheduledTask) - + if status: statement = statement.where(ScheduledTask.status == status) - + if task_type: statement = statement.where(ScheduledTask.task_type == task_type) - + statement = statement.order_by(ScheduledTask.scheduled_at.desc()) - + if offset: statement = statement.offset(offset) - + if limit: statement = statement.limit(limit) - + result = await db_session.exec(statement) return list(result.all()) -@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse]) +@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse]) async def get_system_tasks( - status: Optional[TaskStatus] = Query(None, description="Filter by task status"), - task_type: Optional[TaskType] = Query(None, description="Filter by task type"), + status: TaskStatus | None = Query(None, description="Filter by task status"), + task_type: TaskType | None = Query(None, description="Filter by task type"), current_user: User = Depends(get_admin_user), db_session: AsyncSession = Depends(get_db), -) -> List[ScheduledTask]: +) -> list[ScheduledTask]: """Get system tasks (admin only).""" from app.repositories.scheduled_task import ScheduledTaskRepository - + repo = ScheduledTaskRepository(db_session) return await repo.get_system_tasks(status=status, task_type=task_type) @@ -225,4 +223,4 @@ async def create_system_task( ) return task except Exception as e: - raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=400, detail=str(e)) diff --git a/app/core/database.py b/app/core/database.py index 24d9a1b..8763f4a 100644 --- a/app/core/database.py +++ b/app/core/database.py @@ -4,11 +4,11 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession +# Import all models to ensure SQLModel metadata discovery +import app.models # noqa: F401 from app.core.config import settings from app.core.logging import get_logger from app.core.seeds import seed_all_data -# Import all models to ensure SQLModel metadata discovery -import app.models # noqa: F401 engine: AsyncEngine = create_async_engine( settings.DATABASE_URL, diff --git a/app/main.py b/app/main.py index 2acd480..a7c1b59 100644 --- a/app/main.py +++ b/app/main.py @@ -11,11 +11,14 @@ from app.core.database import get_session_factory, init_db from app.core.logging import get_logger, setup_logging from app.middleware.logging import LoggingMiddleware from app.services.extraction_processor import extraction_processor -from app.services.player import initialize_player_service, shutdown_player_service, get_player_service +from app.services.player import ( + get_player_service, + initialize_player_service, + shutdown_player_service, +) from app.services.scheduler import SchedulerService from app.services.socket import socket_manager - scheduler_service = None @@ -31,7 +34,7 @@ def get_global_scheduler_service() -> SchedulerService: async def lifespan(_app: FastAPI) -> AsyncGenerator[None]: """Application lifespan context manager for setup and teardown.""" global scheduler_service - + setup_logging() logger = get_logger(__name__) logger.info("Starting application") diff --git a/app/models/__init__.py b/app/models/__init__.py index 40db630..1e3c2c7 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -20,7 +20,7 @@ __all__ = [ "CreditAction", "CreditTransaction", "Extraction", - "Favorite", + "Favorite", "Plan", "Playlist", "PlaylistSound", diff --git a/app/models/scheduled_task.py b/app/models/scheduled_task.py index f0ddbc8..6e9096d 100644 --- a/app/models/scheduled_task.py +++ b/app/models/scheduled_task.py @@ -1,11 +1,10 @@ """Scheduled task model for flexible task scheduling with timezone support.""" -import uuid from datetime import datetime from enum import Enum -from typing import Any, Optional +from typing import Any -from sqlmodel import JSON, Column, Field, SQLModel +from sqlmodel import JSON, Column, Field from app.models.base import BaseModel @@ -57,11 +56,11 @@ class ScheduledTask(BaseModel, table=True): description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')", ) recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE) - cron_expression: Optional[str] = Field( + cron_expression: str | None = Field( default=None, description="Cron expression for custom recurrence (when recurrence_type is CRON)", ) - recurrence_count: Optional[int] = Field( + recurrence_count: int | None = Field( default=None, description="Number of times to repeat (None for infinite)", ) @@ -75,29 +74,29 @@ class ScheduledTask(BaseModel, table=True): ) # User association (None for system tasks) - user_id: Optional[int] = Field( + user_id: int | None = Field( default=None, foreign_key="user.id", description="User who created the task (None for system tasks)", ) # Execution tracking - last_executed_at: Optional[datetime] = Field( + last_executed_at: datetime | None = Field( default=None, description="When the task was last executed (UTC)", ) - next_execution_at: Optional[datetime] = Field( + next_execution_at: datetime | None = Field( default=None, description="When the task should be executed next (UTC, for recurring tasks)", ) - error_message: Optional[str] = Field( + error_message: str | None = Field( default=None, description="Error message if execution failed", ) # Task lifecycle is_active: bool = Field(default=True, description="Whether the task is active") - expires_at: Optional[datetime] = Field( + expires_at: datetime | None = Field( default=None, description="When the task expires (UTC, optional)", ) @@ -122,4 +121,4 @@ class ScheduledTask(BaseModel, table=True): def is_system_task(self) -> bool: """Check if this is a system task (no user association).""" - return self.user_id is None \ No newline at end of file + return self.user_id is None diff --git a/app/repositories/scheduled_task.py b/app/repositories/scheduled_task.py index f81b6da..5204f5e 100644 --- a/app/repositories/scheduled_task.py +++ b/app/repositories/scheduled_task.py @@ -1,12 +1,16 @@ """Repository for scheduled task operations.""" from datetime import datetime -from typing import List, Optional from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType +from app.models.scheduled_task import ( + RecurrenceType, + ScheduledTask, + TaskStatus, + TaskType, +) from app.repositories.base import BaseRepository @@ -17,7 +21,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): """Initialize the repository.""" super().__init__(ScheduledTask, session) - async def get_pending_tasks(self) -> List[ScheduledTask]: + async def get_pending_tasks(self) -> list[ScheduledTask]: """Get all pending tasks that are ready to be executed.""" now = datetime.utcnow() statement = select(ScheduledTask).where( @@ -28,7 +32,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): result = await self.session.exec(statement) return list(result.all()) - async def get_active_tasks(self) -> List[ScheduledTask]: + async def get_active_tasks(self) -> list[ScheduledTask]: """Get all active tasks.""" statement = select(ScheduledTask).where( ScheduledTask.is_active.is_(True), @@ -40,11 +44,11 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): async def get_user_tasks( self, user_id: int, - status: Optional[TaskStatus] = None, - task_type: Optional[TaskType] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - ) -> List[ScheduledTask]: + status: TaskStatus | None = None, + task_type: TaskType | None = None, + limit: int | None = None, + offset: int | None = None, + ) -> list[ScheduledTask]: """Get tasks for a specific user.""" statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id) @@ -67,9 +71,9 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): async def get_system_tasks( self, - status: Optional[TaskStatus] = None, - task_type: Optional[TaskType] = None, - ) -> List[ScheduledTask]: + status: TaskStatus | None = None, + task_type: TaskType | None = None, + ) -> list[ScheduledTask]: """Get system tasks (tasks with no user association).""" statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None)) @@ -84,7 +88,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): result = await self.session.exec(statement) return list(result.all()) - async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]: + async def get_recurring_tasks_due_for_next_execution(self) -> list[ScheduledTask]: """Get recurring tasks that need their next execution scheduled.""" now = datetime.utcnow() statement = select(ScheduledTask).where( @@ -96,7 +100,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): result = await self.session.exec(statement) return list(result.all()) - async def get_expired_tasks(self) -> List[ScheduledTask]: + async def get_expired_tasks(self) -> list[ScheduledTask]: """Get expired tasks that should be cleaned up.""" now = datetime.utcnow() statement = select(ScheduledTask).where( @@ -110,7 +114,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): async def cancel_user_tasks( self, user_id: int, - task_type: Optional[TaskType] = None, + task_type: TaskType | None = None, ) -> int: """Cancel all pending/running tasks for a user.""" statement = select(ScheduledTask).where( @@ -144,7 +148,7 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): async def mark_as_completed( self, task: ScheduledTask, - next_execution_at: Optional[datetime] = None, + next_execution_at: datetime | None = None, ) -> None: """Mark a task as completed and set next execution if recurring.""" task.status = TaskStatus.COMPLETED @@ -174,4 +178,4 @@ class ScheduledTaskRepository(BaseRepository[ScheduledTask]): self.session.add(task) await self.session.commit() - await self.session.refresh(task) \ No newline at end of file + await self.session.refresh(task) diff --git a/app/schemas/scheduler.py b/app/schemas/scheduler.py index 2bc9d6f..5daa831 100644 --- a/app/schemas/scheduler.py +++ b/app/schemas/scheduler.py @@ -1,7 +1,7 @@ """Schemas for scheduled task API.""" from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -15,7 +15,7 @@ class ScheduledTaskBase(BaseModel): task_type: TaskType = Field(description="Type of task to execute") scheduled_at: datetime = Field(description="When the task should be executed") timezone: str = Field(default="UTC", description="Timezone for scheduling") - parameters: Dict[str, Any] = Field( + parameters: dict[str, Any] = Field( default_factory=dict, description="Task-specific parameters", ) @@ -23,15 +23,15 @@ class ScheduledTaskBase(BaseModel): default=RecurrenceType.NONE, description="Recurrence pattern", ) - cron_expression: Optional[str] = Field( + cron_expression: str | None = Field( default=None, description="Cron expression for custom recurrence", ) - recurrence_count: Optional[int] = Field( + recurrence_count: int | None = Field( default=None, description="Number of times to repeat (None for infinite)", ) - expires_at: Optional[datetime] = Field( + expires_at: datetime | None = Field( default=None, description="When the task expires (optional)", ) @@ -40,18 +40,17 @@ class ScheduledTaskBase(BaseModel): class ScheduledTaskCreate(ScheduledTaskBase): """Schema for creating a scheduled task.""" - pass class ScheduledTaskUpdate(BaseModel): """Schema for updating a scheduled task.""" - name: Optional[str] = None - scheduled_at: Optional[datetime] = None - timezone: Optional[str] = None - parameters: Optional[Dict[str, Any]] = None - is_active: Optional[bool] = None - expires_at: Optional[datetime] = None + name: str | None = None + scheduled_at: datetime | None = None + timezone: str | None = None + parameters: dict[str, Any] | None = None + is_active: bool | None = None + expires_at: datetime | None = None class ScheduledTaskResponse(ScheduledTaskBase): @@ -59,11 +58,11 @@ class ScheduledTaskResponse(ScheduledTaskBase): id: int status: TaskStatus - user_id: Optional[int] = None + user_id: int | None = None executions_count: int - last_executed_at: Optional[datetime] = None - next_execution_at: Optional[datetime] = None - error_message: Optional[str] = None + last_executed_at: datetime | None = None + next_execution_at: datetime | None = None + error_message: str | None = None is_active: bool created_at: datetime updated_at: datetime @@ -78,7 +77,7 @@ class ScheduledTaskResponse(ScheduledTaskBase): class CreditRechargeParameters(BaseModel): """Parameters for credit recharge tasks.""" - user_id: Optional[int] = Field( + user_id: int | None = Field( default=None, description="Specific user ID to recharge (None for all users)", ) @@ -109,10 +108,10 @@ class CreateCreditRechargeTask(BaseModel): scheduled_at: datetime timezone: str = "UTC" recurrence_type: RecurrenceType = RecurrenceType.NONE - cron_expression: Optional[str] = None - recurrence_count: Optional[int] = None - expires_at: Optional[datetime] = None - user_id: Optional[int] = None + cron_expression: str | None = None + recurrence_count: int | None = None + expires_at: datetime | None = None + user_id: int | None = None def to_task_create(self) -> ScheduledTaskCreate: """Convert to generic task creation schema.""" @@ -137,9 +136,9 @@ class CreatePlaySoundTask(BaseModel): sound_id: int timezone: str = "UTC" recurrence_type: RecurrenceType = RecurrenceType.NONE - cron_expression: Optional[str] = None - recurrence_count: Optional[int] = None - expires_at: Optional[datetime] = None + cron_expression: str | None = None + recurrence_count: int | None = None + expires_at: datetime | None = None def to_task_create(self) -> ScheduledTaskCreate: """Convert to generic task creation schema.""" @@ -166,9 +165,9 @@ class CreatePlayPlaylistTask(BaseModel): shuffle: bool = False timezone: str = "UTC" recurrence_type: RecurrenceType = RecurrenceType.NONE - cron_expression: Optional[str] = None - recurrence_count: Optional[int] = None - expires_at: Optional[datetime] = None + cron_expression: str | None = None + recurrence_count: int | None = None + expires_at: datetime | None = None def to_task_create(self) -> ScheduledTaskCreate: """Convert to generic task creation schema.""" @@ -186,4 +185,4 @@ class CreatePlayPlaylistTask(BaseModel): cron_expression=self.cron_expression, recurrence_count=self.recurrence_count, expires_at=self.expires_at, - ) \ No newline at end of file + ) diff --git a/app/services/scheduler.py b/app/services/scheduler.py index 8feb226..fb337cd 100644 --- a/app/services/scheduler.py +++ b/app/services/scheduler.py @@ -2,7 +2,7 @@ from collections.abc import Callable from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any import pytz from apscheduler.schedulers.asyncio import AsyncIOScheduler @@ -52,7 +52,7 @@ class SchedulerService: logger.info("Starting enhanced scheduler service...") self.scheduler.start() - + # Schedule system tasks initialization for after startup self.scheduler.add_job( self._initialize_system_tasks, @@ -62,7 +62,7 @@ class SchedulerService: name="Initialize System Tasks", replace_existing=True, ) - + # Schedule periodic cleanup and maintenance self.scheduler.add_job( self._maintenance_job, @@ -86,18 +86,18 @@ class SchedulerService: name: str, task_type: TaskType, scheduled_at: datetime, - parameters: Optional[Dict[str, Any]] = None, - user_id: Optional[int] = None, + parameters: dict[str, Any] | None = None, + user_id: int | None = None, timezone: str = "UTC", recurrence_type: RecurrenceType = RecurrenceType.NONE, - cron_expression: Optional[str] = None, - recurrence_count: Optional[int] = None, - expires_at: Optional[datetime] = None, + cron_expression: str | None = None, + recurrence_count: int | None = None, + expires_at: datetime | None = None, ) -> ScheduledTask: """Create a new scheduled task.""" async with self.db_session_factory() as session: repo = ScheduledTaskRepository(session) - + # Convert scheduled_at to UTC if it's in a different timezone if timezone != "UTC": tz = pytz.timezone(timezone) @@ -105,7 +105,7 @@ class SchedulerService: # Assume the datetime is in the specified timezone scheduled_at = tz.localize(scheduled_at) scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None) - + task_data = { "name": name, "task_type": task_type, @@ -118,59 +118,59 @@ class SchedulerService: "recurrence_count": recurrence_count, "expires_at": expires_at, } - + created_task = await repo.create(task_data) await self._schedule_apscheduler_job(created_task) - + logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})") return created_task - + async def cancel_task(self, task_id: int) -> bool: """Cancel a scheduled task.""" async with self.db_session_factory() as session: repo = ScheduledTaskRepository(session) - + task = await repo.get_by_id(task_id) if not task: return False - + task.status = TaskStatus.CANCELLED task.is_active = False await repo.update(task) - + # Remove from APScheduler try: self.scheduler.remove_job(str(task_id)) except Exception: pass # Job might not exist in scheduler - + logger.info(f"Cancelled task: {task.name} ({task_id})") return True - + async def get_user_tasks( self, user_id: int, - status: Optional[TaskStatus] = None, - task_type: Optional[TaskType] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - ) -> List[ScheduledTask]: + status: TaskStatus | None = None, + task_type: TaskType | None = None, + limit: int | None = None, + offset: int | None = None, + ) -> list[ScheduledTask]: """Get tasks for a specific user.""" async with self.db_session_factory() as session: repo = ScheduledTaskRepository(session) return await repo.get_user_tasks(user_id, status, task_type, limit, offset) - + async def _initialize_system_tasks(self) -> None: """Initialize system tasks and load active tasks from database.""" logger.info("Initializing system tasks...") - + try: # Create system tasks if they don't exist await self._ensure_system_tasks() - + # Load all active tasks from database await self._load_active_tasks() - + logger.info("System tasks initialized successfully") except Exception: logger.exception("Failed to initialize system tasks") @@ -179,24 +179,24 @@ class SchedulerService: """Ensure required system tasks exist.""" async with self.db_session_factory() as session: repo = ScheduledTaskRepository(session) - + # Check if daily credit recharge task exists system_tasks = await repo.get_system_tasks( - task_type=TaskType.CREDIT_RECHARGE + task_type=TaskType.CREDIT_RECHARGE, ) - + daily_recharge_exists = any( task.recurrence_type == RecurrenceType.DAILY and task.is_active for task in system_tasks ) - + if not daily_recharge_exists: # Create daily credit recharge task tomorrow_midnight = datetime.utcnow().replace( - hour=0, minute=0, second=0, microsecond=0 + hour=0, minute=0, second=0, microsecond=0, ) + timedelta(days=1) - + task_data = { "name": "Daily Credit Recharge", "task_type": TaskType.CREDIT_RECHARGE, @@ -204,41 +204,41 @@ class SchedulerService: "recurrence_type": RecurrenceType.DAILY, "parameters": {}, } - + await repo.create(task_data) logger.info("Created system daily credit recharge task") - + async def _load_active_tasks(self) -> None: """Load all active tasks from database into scheduler.""" async with self.db_session_factory() as session: repo = ScheduledTaskRepository(session) active_tasks = await repo.get_active_tasks() - + for task in active_tasks: await self._schedule_apscheduler_job(task) - + logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler") - + async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None: """Schedule a task in APScheduler.""" job_id = str(task.id) - + # Remove existing job if it exists try: self.scheduler.remove_job(job_id) except Exception: pass - + # Don't schedule if task is not active or already completed/failed if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]: return - + # Create trigger based on recurrence type trigger = self._create_trigger(task) if not trigger: logger.warning(f"Could not create trigger for task {task.id}") return - + # Schedule the job self.scheduler.add_job( self._execute_task, @@ -248,76 +248,76 @@ class SchedulerService: name=task.name, replace_existing=True, ) - + logger.debug(f"Scheduled APScheduler job for task {task.id}") - + def _create_trigger(self, task: ScheduledTask): """Create APScheduler trigger based on task configuration.""" tz = pytz.timezone(task.timezone) - + if task.recurrence_type == RecurrenceType.NONE: return DateTrigger(run_date=task.scheduled_at, timezone=tz) - - elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression: + + if task.recurrence_type == RecurrenceType.CRON and task.cron_expression: return CronTrigger.from_crontab(task.cron_expression, timezone=tz) - - elif task.recurrence_type == RecurrenceType.HOURLY: + + if task.recurrence_type == RecurrenceType.HOURLY: return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz) - - elif task.recurrence_type == RecurrenceType.DAILY: + + if task.recurrence_type == RecurrenceType.DAILY: return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz) - - elif task.recurrence_type == RecurrenceType.WEEKLY: + + if task.recurrence_type == RecurrenceType.WEEKLY: return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz) - - elif task.recurrence_type == RecurrenceType.MONTHLY: + + if task.recurrence_type == RecurrenceType.MONTHLY: # Use cron trigger for monthly (more reliable than interval) scheduled_time = task.scheduled_at return CronTrigger( day=scheduled_time.day, hour=scheduled_time.hour, minute=scheduled_time.minute, - timezone=tz + timezone=tz, ) - - elif task.recurrence_type == RecurrenceType.YEARLY: + + if task.recurrence_type == RecurrenceType.YEARLY: scheduled_time = task.scheduled_at return CronTrigger( month=scheduled_time.month, day=scheduled_time.day, hour=scheduled_time.hour, minute=scheduled_time.minute, - timezone=tz + timezone=tz, ) - + return None - + async def _execute_task(self, task_id: int) -> None: """Execute a scheduled task.""" task_id_str = str(task_id) - + # Prevent concurrent execution of the same task if task_id_str in self._running_tasks: logger.warning(f"Task {task_id} is already running, skipping execution") return - + self._running_tasks.add(task_id_str) - + try: async with self.db_session_factory() as session: repo = ScheduledTaskRepository(session) - + # Get fresh task data task = await repo.get_by_id(task_id) if not task: logger.warning(f"Task {task_id} not found") return - + # Check if task is still active and pending if not task.is_active or task.status != TaskStatus.PENDING: logger.info(f"Task {task_id} is not active or not pending, skipping") return - + # Check if task has expired if task.is_expired(): logger.info(f"Task {task_id} has expired, marking as cancelled") @@ -325,78 +325,78 @@ class SchedulerService: task.is_active = False await repo.update(task) return - + # Mark task as running await repo.mark_as_running(task) - + # Execute the task try: handler_registry = TaskHandlerRegistry( - session, self.db_session_factory, self.credit_service, self.player_service + session, self.db_session_factory, self.credit_service, self.player_service, ) await handler_registry.execute_task(task) - + # Calculate next execution time for recurring tasks next_execution_at = None if task.should_repeat(): next_execution_at = self._calculate_next_execution(task) - + # Mark as completed await repo.mark_as_completed(task, next_execution_at) - + # Reschedule if recurring if next_execution_at and task.should_repeat(): # Refresh task to get updated data await session.refresh(task) await self._schedule_apscheduler_job(task) - + except Exception as e: await repo.mark_as_failed(task, str(e)) - logger.exception(f"Task {task_id} execution failed: {str(e)}") - + logger.exception(f"Task {task_id} execution failed: {e!s}") + finally: self._running_tasks.discard(task_id_str) - - def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]: + + def _calculate_next_execution(self, task: ScheduledTask) -> datetime | None: """Calculate the next execution time for a recurring task.""" now = datetime.utcnow() - + if task.recurrence_type == RecurrenceType.HOURLY: return now + timedelta(hours=1) - elif task.recurrence_type == RecurrenceType.DAILY: + if task.recurrence_type == RecurrenceType.DAILY: return now + timedelta(days=1) - elif task.recurrence_type == RecurrenceType.WEEKLY: + if task.recurrence_type == RecurrenceType.WEEKLY: return now + timedelta(weeks=1) - elif task.recurrence_type == RecurrenceType.MONTHLY: + if task.recurrence_type == RecurrenceType.MONTHLY: # Add approximately one month return now + timedelta(days=30) - elif task.recurrence_type == RecurrenceType.YEARLY: + if task.recurrence_type == RecurrenceType.YEARLY: return now + timedelta(days=365) - + return None - + async def _maintenance_job(self) -> None: """Periodic maintenance job to clean up expired tasks and handle scheduling issues.""" try: async with self.db_session_factory() as session: repo = ScheduledTaskRepository(session) - + # Handle expired tasks expired_tasks = await repo.get_expired_tasks() for task in expired_tasks: task.status = TaskStatus.CANCELLED task.is_active = False await repo.update(task) - + # Remove from scheduler try: self.scheduler.remove_job(str(task.id)) except Exception: pass - + if expired_tasks: logger.info(f"Cleaned up {len(expired_tasks)} expired tasks") - + # Handle any missed recurring tasks due_recurring = await repo.get_recurring_tasks_due_for_next_execution() for task in due_recurring: @@ -405,9 +405,9 @@ class SchedulerService: task.scheduled_at = task.next_execution_at or datetime.utcnow() await repo.update(task) await self._schedule_apscheduler_job(task) - + if due_recurring: logger.info(f"Rescheduled {len(due_recurring)} recurring tasks") - + except Exception: logger.exception("Maintenance job failed") diff --git a/app/services/task_handlers.py b/app/services/task_handlers.py index 017fac1..93541ee 100644 --- a/app/services/task_handlers.py +++ b/app/services/task_handlers.py @@ -1,6 +1,5 @@ """Task execution handlers for different task types.""" -from typing import Any, Dict, Optional from collections.abc import Callable from sqlmodel.ext.asyncio.session import AsyncSession @@ -18,7 +17,6 @@ logger = get_logger(__name__) class TaskExecutionError(Exception): """Exception raised when task execution fails.""" - pass class TaskHandlerRegistry: @@ -58,8 +56,8 @@ class TaskHandlerRegistry: await handler(task) logger.info(f"Task {task.id} executed successfully") except Exception as e: - logger.exception(f"Task {task.id} execution failed: {str(e)}") - raise TaskExecutionError(f"Task execution failed: {str(e)}") from e + logger.exception(f"Task {task.id} execution failed: {e!s}") + raise TaskExecutionError(f"Task execution failed: {e!s}") from e async def _handle_credit_recharge(self, task: ScheduledTask) -> None: """Handle credit recharge task.""" @@ -72,7 +70,7 @@ class TaskHandlerRegistry: user_id_int = int(user_id) except (ValueError, TypeError) as e: raise TaskExecutionError(f"Invalid user_id format: {user_id}") from e - + stats = await self.credit_service.recharge_user_credits(user_id_int) logger.info(f"Recharged credits for user {user_id}: {stats}") else: @@ -105,7 +103,7 @@ class TaskHandlerRegistry: logger.info(f"Played sound {result.get('sound_name', sound_id)} via scheduled task for user {task.user_id} (credits deducted: {result.get('credits_deducted', 0)})") except Exception as e: # Convert HTTP exceptions or credit errors to task execution errors - raise TaskExecutionError(f"Failed to play sound with credits: {str(e)}") from e + raise TaskExecutionError(f"Failed to play sound with credits: {e!s}") from e else: # System task: play without credit deduction sound = await self.sound_repository.get_by_id(sound_id_int) @@ -116,10 +114,10 @@ class TaskHandlerRegistry: vlc_service = VLCPlayerService(self.db_session_factory) success = await vlc_service.play_sound(sound) - + if not success: raise TaskExecutionError(f"Failed to play sound {sound.filename}") - + logger.info(f"Played sound {sound.filename} via scheduled system task") async def _handle_play_playlist(self, task: ScheduledTask) -> None: @@ -157,4 +155,4 @@ class TaskHandlerRegistry: # Start playing await self.player_service.play() - logger.info(f"Started playing playlist {playlist.name} via scheduled task") \ No newline at end of file + logger.info(f"Started playing playlist {playlist.name} via scheduled task") diff --git a/app/services/vlc_player.py b/app/services/vlc_player.py index 60b7f07..29a3a1d 100644 --- a/app/services/vlc_player.py +++ b/app/services/vlc_player.py @@ -238,13 +238,13 @@ class VLCPlayerService: return logger.info("Recording play count for sound %s", sound_id) - + # Initialize variables for WebSocket event old_count = 0 sound = None admin_user_id = None admin_user_name = None - + try: async with self.db_session_factory() as session: sound_repo = SoundRepository(session) diff --git a/check_tasks.py b/check_tasks.py index fadb698..bcfad17 100644 --- a/check_tasks.py +++ b/check_tasks.py @@ -7,15 +7,16 @@ from datetime import datetime from app.core.database import get_session_factory from app.repositories.scheduled_task import ScheduledTaskRepository + async def check_tasks(): session_factory = get_session_factory() - + async with session_factory() as session: repo = ScheduledTaskRepository(session) - + # Get all tasks all_tasks = await repo.get_all(limit=20) - + print("All tasks in database:") print("=" * 80) for task in all_tasks: @@ -32,14 +33,14 @@ async def check_tasks(): print(f"Error: {task.error_message}") print(f"Parameters: {task.parameters}") print("-" * 40) - + # Check specifically for pending tasks print(f"\nCurrent time: {datetime.utcnow()}") print("\nPending tasks:") from app.models.scheduled_task import TaskStatus pending_tasks = await repo.get_all(limit=10) pending_tasks = [t for t in pending_tasks if t.status == TaskStatus.PENDING and t.is_active] - + if not pending_tasks: print("No pending tasks found") else: @@ -48,4 +49,4 @@ async def check_tasks(): print(f"- {task.name} (ID: {task.id}): scheduled for {task.scheduled_at} (in {time_diff})") if __name__ == "__main__": - asyncio.run(check_tasks()) \ No newline at end of file + asyncio.run(check_tasks()) diff --git a/create_future_task.py b/create_future_task.py index bd1973c..80f372e 100644 --- a/create_future_task.py +++ b/create_future_task.py @@ -5,18 +5,19 @@ import asyncio from datetime import datetime, timedelta from app.core.database import get_session_factory +from app.models.scheduled_task import RecurrenceType, TaskType from app.repositories.scheduled_task import ScheduledTaskRepository -from app.models.scheduled_task import TaskType, RecurrenceType + async def create_future_task(): session_factory = get_session_factory() - + # Create a task for 2 minutes from now future_time = datetime.utcnow() + timedelta(minutes=2) - + async with session_factory() as session: repo = ScheduledTaskRepository(session) - + task_data = { "name": f"Future Task {future_time.strftime('%H:%M:%S')}", "task_type": TaskType.PLAY_SOUND, @@ -26,11 +27,11 @@ async def create_future_task(): "user_id": 1, "recurrence_type": RecurrenceType.NONE, } - + task = await repo.create(task_data) print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}") print(f"Current time: {datetime.utcnow()}") print(f"Task will execute in: {future_time - datetime.utcnow()}") if __name__ == "__main__": - asyncio.run(create_future_task()) \ No newline at end of file + asyncio.run(create_future_task()) diff --git a/test_api_task.py b/test_api_task.py index d7348b1..056a1da 100644 --- a/test_api_task.py +++ b/test_api_task.py @@ -4,21 +4,21 @@ import asyncio from datetime import datetime, timedelta -from app.core.database import get_session_factory from app.main import get_global_scheduler_service -from app.models.scheduled_task import TaskType, RecurrenceType +from app.models.scheduled_task import RecurrenceType, TaskType + async def test_api_task_creation(): """Test creating a task through the scheduler service (simulates API call).""" try: scheduler_service = get_global_scheduler_service() - + # Create a task for 2 minutes from now future_time = datetime.utcnow() + timedelta(minutes=2) - + print(f"Creating task scheduled for: {future_time}") print(f"Current time: {datetime.utcnow()}") - + task = await scheduler_service.create_task( name=f"API Test Task {future_time.strftime('%H:%M:%S')}", task_type=TaskType.PLAY_SOUND, @@ -28,13 +28,13 @@ async def test_api_task_creation(): timezone="UTC", recurrence_type=RecurrenceType.NONE, ) - + print(f"Created task: {task.name} (ID: {task.id})") print(f"Task will execute in: {future_time - datetime.utcnow()}") print("Task should be automatically scheduled in APScheduler!") - + except Exception as e: print(f"Error: {e}") if __name__ == "__main__": - asyncio.run(test_api_task_creation()) \ No newline at end of file + asyncio.run(test_api_task_creation()) diff --git a/test_task.py b/test_task.py index a7b8945..9717a09 100644 --- a/test_task.py +++ b/test_task.py @@ -5,15 +5,16 @@ import asyncio from datetime import datetime from app.core.database import get_session_factory +from app.models.scheduled_task import RecurrenceType, TaskType from app.repositories.scheduled_task import ScheduledTaskRepository -from app.models.scheduled_task import TaskType, RecurrenceType + async def create_test_task(): session_factory = get_session_factory() - + async with session_factory() as session: repo = ScheduledTaskRepository(session) - + task_data = { "name": "Live Test Task", "task_type": TaskType.PLAY_SOUND, @@ -23,9 +24,9 @@ async def create_test_task(): "user_id": 1, "recurrence_type": RecurrenceType.NONE, } - + task = await repo.create(task_data) print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}") if __name__ == "__main__": - asyncio.run(create_test_task()) \ No newline at end of file + asyncio.run(create_test_task()) diff --git a/tests/conftest.py b/tests/conftest.py index bb47bd4..6fb634a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -351,11 +351,11 @@ async def admin_cookies(admin_user: User) -> dict[str, str]: @pytest.fixture def test_user_id(test_user: User): - """Get test user ID.""" + """Get test user ID.""" return test_user.id -@pytest.fixture +@pytest.fixture def test_sound_id(): """Create a test sound ID.""" import uuid @@ -364,7 +364,7 @@ def test_sound_id(): @pytest.fixture def test_playlist_id(): - """Create a test playlist ID.""" + """Create a test playlist ID.""" import uuid return uuid.uuid4() diff --git a/tests/test_scheduled_task_model.py b/tests/test_scheduled_task_model.py index cf47e8a..b4dd16b 100644 --- a/tests/test_scheduled_task_model.py +++ b/tests/test_scheduled_task_model.py @@ -3,8 +3,6 @@ import uuid from datetime import datetime, timedelta -import pytest - from app.models.scheduled_task import ( RecurrenceType, ScheduledTask, @@ -217,4 +215,4 @@ class TestScheduledTaskModel: assert RecurrenceType.WEEKLY == "weekly" assert RecurrenceType.MONTHLY == "monthly" assert RecurrenceType.YEARLY == "yearly" - assert RecurrenceType.CRON == "cron" \ No newline at end of file + assert RecurrenceType.CRON == "cron" diff --git a/tests/test_scheduled_task_repository.py b/tests/test_scheduled_task_repository.py index 6308735..f37b8e5 100644 --- a/tests/test_scheduled_task_repository.py +++ b/tests/test_scheduled_task_repository.py @@ -2,7 +2,6 @@ import uuid from datetime import datetime, timedelta -from typing import List import pytest from sqlmodel.ext.asyncio.session import AsyncSession @@ -491,4 +490,4 @@ class TestScheduledTaskRepository: updated_task = await repository.get_by_id(sample_task.id) assert updated_task.status == TaskStatus.FAILED # Non-recurring tasks should be deactivated on failure - assert updated_task.is_active is False \ No newline at end of file + assert updated_task.is_active is False diff --git a/tests/test_scheduler_service.py b/tests/test_scheduler_service.py index 6f0d711..8475382 100644 --- a/tests/test_scheduler_service.py +++ b/tests/test_scheduler_service.py @@ -51,7 +51,7 @@ class TestSchedulerService: sample_task_data: dict, ): """Test creating a scheduled task.""" - with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule: + with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule: task = await scheduler_service.create_task(**sample_task_data) assert task.id is not None @@ -68,7 +68,7 @@ class TestSchedulerService: test_user_id: uuid.UUID, ): """Test creating a user task.""" - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task( user_id=test_user_id, **sample_task_data, @@ -83,7 +83,7 @@ class TestSchedulerService: sample_task_data: dict, ): """Test creating a system task.""" - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task(**sample_task_data) assert task.user_id is None @@ -95,7 +95,7 @@ class TestSchedulerService: sample_task_data: dict, ): """Test creating a recurring task.""" - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task( recurrence_type=RecurrenceType.DAILY, recurrence_count=5, @@ -114,11 +114,11 @@ class TestSchedulerService: """Test creating task with timezone conversion.""" # Use a specific datetime for testing ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY - + sample_task_data["scheduled_at"] = ny_time sample_task_data["timezone"] = "America/New_York" - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task(**sample_task_data) # The scheduled_at should be converted to UTC @@ -134,11 +134,11 @@ class TestSchedulerService: ): """Test cancelling a task.""" # Create a task first - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task(**sample_task_data) # Mock the scheduler remove_job method - with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove: + with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove: result = await scheduler_service.cancel_task(task.id) assert result is True @@ -167,7 +167,7 @@ class TestSchedulerService: test_user_id: uuid.UUID, ): """Test getting user tasks.""" - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): # Create user task await scheduler_service.create_task( user_id=test_user_id, @@ -188,12 +188,12 @@ class TestSchedulerService: ): """Test ensuring system tasks exist.""" # Mock the repository to return no existing tasks - with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get: - with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create: + with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get: + with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create: mock_get.return_value = [] - + await scheduler_service._ensure_system_tasks() - + # Should create daily credit recharge task mock_create.assert_called_once() created_task = mock_create.call_args[0][0] @@ -213,13 +213,13 @@ class TestSchedulerService: recurrence_type=RecurrenceType.DAILY, is_active=True, ) - - with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get: - with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create: + + with patch("app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks") as mock_get: + with patch("app.repositories.scheduled_task.ScheduledTaskRepository.create") as mock_create: mock_get.return_value = [existing_task] - + await scheduler_service._ensure_system_tasks() - + # Should not create new task mock_create.assert_not_called() @@ -294,7 +294,7 @@ class TestSchedulerService: ): """Test calculating next execution time.""" now = datetime.utcnow() - + # Test different recurrence types test_cases = [ (RecurrenceType.HOURLY, timedelta(hours=1)), @@ -312,7 +312,7 @@ class TestSchedulerService: recurrence_type=recurrence_type, ) - with patch('app.services.scheduler.datetime') as mock_datetime: + with patch("app.services.scheduler.datetime") as mock_datetime: mock_datetime.utcnow.return_value = now next_execution = scheduler_service._calculate_next_execution(task) @@ -335,7 +335,7 @@ class TestSchedulerService: next_execution = scheduler_service._calculate_next_execution(task) assert next_execution is None - @patch('app.services.task_handlers.TaskHandlerRegistry') + @patch("app.services.task_handlers.TaskHandlerRegistry") async def test_execute_task_success( self, mock_handler_class, @@ -344,7 +344,7 @@ class TestSchedulerService: ): """Test successful task execution.""" # Create task - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task(**sample_task_data) # Mock handler registry @@ -365,7 +365,7 @@ class TestSchedulerService: assert updated_task.status == TaskStatus.COMPLETED assert updated_task.executions_count == 1 - @patch('app.services.task_handlers.TaskHandlerRegistry') + @patch("app.services.task_handlers.TaskHandlerRegistry") async def test_execute_task_failure( self, mock_handler_class, @@ -374,7 +374,7 @@ class TestSchedulerService: ): """Test task execution failure.""" # Create task - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task(**sample_task_data) # Mock handler to raise exception @@ -409,8 +409,8 @@ class TestSchedulerService: """Test executing expired task.""" # Create expired task sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1) - - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task(**sample_task_data) # Execute task @@ -430,20 +430,20 @@ class TestSchedulerService: sample_task_data: dict, ): """Test prevention of concurrent task execution.""" - with patch.object(scheduler_service, '_schedule_apscheduler_job'): + with patch.object(scheduler_service, "_schedule_apscheduler_job"): task = await scheduler_service.create_task(**sample_task_data) # Add task to running set scheduler_service._running_tasks.add(str(task.id)) # Try to execute - should return without doing anything - with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class: + with patch("app.services.task_handlers.TaskHandlerRegistry") as mock_handler_class: await scheduler_service._execute_task(task.id) - + # Handler should not be called mock_handler_class.assert_not_called() - @patch('app.repositories.scheduled_task.ScheduledTaskRepository') + @patch("app.repositories.scheduled_task.ScheduledTaskRepository") async def test_maintenance_job_expired_tasks( self, mock_repo_class, @@ -453,22 +453,22 @@ class TestSchedulerService: # Mock expired task expired_task = MagicMock() expired_task.id = uuid.uuid4() - + mock_repo = AsyncMock() mock_repo.get_expired_tasks.return_value = [expired_task] mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [] mock_repo_class.return_value = mock_repo - with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove: + with patch.object(scheduler_service.scheduler, "remove_job") as mock_remove: await scheduler_service._maintenance_job() - + # Should mark as cancelled and remove from scheduler assert expired_task.status == TaskStatus.CANCELLED assert expired_task.is_active is False mock_repo.update.assert_called_with(expired_task) mock_remove.assert_called_once_with(str(expired_task.id)) - @patch('app.repositories.scheduled_task.ScheduledTaskRepository') + @patch("app.repositories.scheduled_task.ScheduledTaskRepository") async def test_maintenance_job_due_recurring_tasks( self, mock_repo_class, @@ -479,17 +479,17 @@ class TestSchedulerService: due_task = MagicMock() due_task.should_repeat.return_value = True due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5) - + mock_repo = AsyncMock() mock_repo.get_expired_tasks.return_value = [] mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task] mock_repo_class.return_value = mock_repo - with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule: + with patch.object(scheduler_service, "_schedule_apscheduler_job") as mock_schedule: await scheduler_service._maintenance_job() - + # Should reset to pending and reschedule assert due_task.status == TaskStatus.PENDING assert due_task.scheduled_at == due_task.next_execution_at mock_repo.update.assert_called_with(due_task) - mock_schedule.assert_called_once_with(due_task) \ No newline at end of file + mock_schedule.assert_called_once_with(due_task) diff --git a/tests/test_task_handlers.py b/tests/test_task_handlers.py index ecce493..5d50aee 100644 --- a/tests/test_task_handlers.py +++ b/tests/test_task_handlers.py @@ -133,8 +133,8 @@ class TestTaskHandlerRegistry: mock_sound.id = test_sound_id mock_sound.filename = "test_sound.mp3" - with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound): - with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class: + with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound): + with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class: mock_vlc_service = AsyncMock() mock_vlc_class.return_value = mock_vlc_service @@ -186,7 +186,7 @@ class TestTaskHandlerRegistry: parameters={"sound_id": str(test_sound_id)}, ) - with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None): + with patch.object(task_registry.sound_repository, "get_by_id", return_value=None): with pytest.raises(TaskExecutionError, match="Sound not found"): await task_registry.execute_task(task) @@ -206,8 +206,8 @@ class TestTaskHandlerRegistry: mock_sound = MagicMock() mock_sound.filename = "test_sound.mp3" - with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound): - with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class: + with patch.object(task_registry.sound_repository, "get_by_id", return_value=mock_sound): + with patch("app.services.vlc_player.VLCPlayerService") as mock_vlc_class: mock_vlc_service = AsyncMock() mock_vlc_class.return_value = mock_vlc_service @@ -238,7 +238,7 @@ class TestTaskHandlerRegistry: mock_playlist.id = test_playlist_id mock_playlist.name = "Test Playlist" - with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): + with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist): await task_registry.execute_task(task) task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id) @@ -264,7 +264,7 @@ class TestTaskHandlerRegistry: mock_playlist = MagicMock() mock_playlist.name = "Test Playlist" - with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): + with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist): await task_registry.execute_task(task) # Should use default values @@ -314,7 +314,7 @@ class TestTaskHandlerRegistry: parameters={"playlist_id": str(test_playlist_id)}, ) - with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None): + with patch.object(task_registry.playlist_repository, "get_by_id", return_value=None): with pytest.raises(TaskExecutionError, match="Playlist not found"): await task_registry.execute_task(task) @@ -327,7 +327,7 @@ class TestTaskHandlerRegistry: """Test play playlist task with various valid play modes.""" mock_playlist = MagicMock() mock_playlist.name = "Test Playlist" - + valid_modes = ["continuous", "loop", "loop_one", "random", "single"] for mode in valid_modes: @@ -341,7 +341,7 @@ class TestTaskHandlerRegistry: }, ) - with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): + with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist): await task_registry.execute_task(task) mock_player_service.set_mode.assert_called_with(mode) @@ -368,7 +368,7 @@ class TestTaskHandlerRegistry: mock_playlist = MagicMock() mock_playlist.name = "Test Playlist" - with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist): + with patch.object(task_registry.playlist_repository, "get_by_id", return_value=mock_playlist): await task_registry.execute_task(task) # Should not call set_mode for invalid mode @@ -421,4 +421,4 @@ class TestTaskHandlerRegistry: TaskType.PLAY_SOUND, TaskType.PLAY_PLAYLIST, } - assert set(registry._handlers.keys()) == expected_handlers \ No newline at end of file + assert set(registry._handlers.keys()) == expected_handlers