Add comprehensive tests for scheduled task repository, scheduler service, and task handlers
- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates. - Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs. - Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback. - Ensured proper error handling and edge cases in task execution scenarios. - Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
232
SCHEDULER_EXAMPLE.md
Normal file
232
SCHEDULER_EXAMPLE.md
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
# Enhanced Scheduler System - Usage Examples
|
||||||
|
|
||||||
|
This document demonstrates how to use the new comprehensive scheduled task system.
|
||||||
|
|
||||||
|
## Features Overview
|
||||||
|
|
||||||
|
### ✨ **Task Types**
|
||||||
|
- **Credit Recharge**: Automatic or scheduled credit replenishment
|
||||||
|
- **Play Sound**: Schedule individual sound playback
|
||||||
|
- **Play Playlist**: Schedule playlist playback with modes
|
||||||
|
|
||||||
|
### 🌍 **Timezone Support**
|
||||||
|
- Full timezone support with automatic UTC conversion
|
||||||
|
- Specify any IANA timezone (e.g., "America/New_York", "Europe/Paris")
|
||||||
|
|
||||||
|
### 🔄 **Scheduling Options**
|
||||||
|
- **One-shot**: Execute once at specific date/time
|
||||||
|
- **Recurring**: Hourly, daily, weekly, monthly, yearly intervals
|
||||||
|
- **Cron**: Custom cron expressions for complex scheduling
|
||||||
|
|
||||||
|
## API Usage Examples
|
||||||
|
|
||||||
|
### Create a One-Shot Task
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Schedule a sound to play in 2 hours
|
||||||
|
curl -X POST "http://localhost:8000/api/v1/scheduler/tasks" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Cookie: access_token=YOUR_JWT_TOKEN" \
|
||||||
|
-d '{
|
||||||
|
"name": "Play Morning Alarm",
|
||||||
|
"task_type": "play_sound",
|
||||||
|
"scheduled_at": "2024-01-01T10:00:00",
|
||||||
|
"timezone": "America/New_York",
|
||||||
|
"parameters": {
|
||||||
|
"sound_id": "sound-uuid-here"
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Create a Recurring Task
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Daily credit recharge at midnight UTC
|
||||||
|
curl -X POST "http://localhost:8000/api/v1/scheduler/admin/system-tasks" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Cookie: access_token=ADMIN_JWT_TOKEN" \
|
||||||
|
-d '{
|
||||||
|
"name": "Daily Credit Recharge",
|
||||||
|
"task_type": "credit_recharge",
|
||||||
|
"scheduled_at": "2024-01-01T00:00:00",
|
||||||
|
"timezone": "UTC",
|
||||||
|
"recurrence_type": "daily",
|
||||||
|
"parameters": {}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Create a Cron-Based Task
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Play playlist every weekday at 9 AM
|
||||||
|
curl -X POST "http://localhost:8000/api/v1/scheduler/tasks" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Cookie: access_token=YOUR_JWT_TOKEN" \
|
||||||
|
-d '{
|
||||||
|
"name": "Workday Playlist",
|
||||||
|
"task_type": "play_playlist",
|
||||||
|
"scheduled_at": "2024-01-01T09:00:00",
|
||||||
|
"timezone": "America/New_York",
|
||||||
|
"recurrence_type": "cron",
|
||||||
|
"cron_expression": "0 9 * * 1-5",
|
||||||
|
"parameters": {
|
||||||
|
"playlist_id": "playlist-uuid-here",
|
||||||
|
"play_mode": "loop",
|
||||||
|
"shuffle": true
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Python Service Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from app.services.scheduler import SchedulerService
|
||||||
|
from app.models.scheduled_task import TaskType, RecurrenceType
|
||||||
|
|
||||||
|
# Initialize scheduler service
|
||||||
|
scheduler_service = SchedulerService(db_session_factory, player_service)
|
||||||
|
|
||||||
|
# Create a one-shot task
|
||||||
|
task = await scheduler_service.create_task(
|
||||||
|
name="Test Sound",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=2),
|
||||||
|
timezone="America/New_York",
|
||||||
|
parameters={"sound_id": "sound-uuid-here"},
|
||||||
|
user_id=user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a recurring task
|
||||||
|
recurring_task = await scheduler_service.create_task(
|
||||||
|
name="Weekly Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(days=1),
|
||||||
|
recurrence_type=RecurrenceType.WEEKLY,
|
||||||
|
recurrence_count=10, # Run 10 times then stop
|
||||||
|
parameters={
|
||||||
|
"playlist_id": "playlist-uuid",
|
||||||
|
"play_mode": "continuous",
|
||||||
|
"shuffle": False
|
||||||
|
},
|
||||||
|
user_id=user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cancel a task
|
||||||
|
success = await scheduler_service.cancel_task(task.id)
|
||||||
|
|
||||||
|
# Get user's tasks
|
||||||
|
user_tasks = await scheduler_service.get_user_tasks(
|
||||||
|
user_id=user.id,
|
||||||
|
status=TaskStatus.PENDING,
|
||||||
|
limit=20
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Task Parameters
|
||||||
|
|
||||||
|
### Credit Recharge Parameters
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"user_id": "uuid-string-or-null" // null for all users (system task)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Play Sound Parameters
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sound_id": "uuid-string" // Required: sound to play
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Play Playlist Parameters
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"playlist_id": "uuid-string", // Required: playlist to play
|
||||||
|
"play_mode": "continuous", // Optional: continuous, loop, loop_one, random, single
|
||||||
|
"shuffle": false // Optional: shuffle playlist
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Recurrence Types
|
||||||
|
|
||||||
|
| Type | Description | Example |
|
||||||
|
|------|-------------|---------|
|
||||||
|
| `none` | One-shot execution | Single alarm |
|
||||||
|
| `hourly` | Every hour | Hourly reminders |
|
||||||
|
| `daily` | Every day | Daily credit recharge |
|
||||||
|
| `weekly` | Every week | Weekly reports |
|
||||||
|
| `monthly` | Every month | Monthly maintenance |
|
||||||
|
| `yearly` | Every year | Annual renewals |
|
||||||
|
| `cron` | Custom cron expression | Complex schedules |
|
||||||
|
|
||||||
|
## Cron Expression Examples
|
||||||
|
|
||||||
|
| Expression | Description |
|
||||||
|
|------------|-------------|
|
||||||
|
| `0 9 * * *` | Daily at 9 AM |
|
||||||
|
| `0 9 * * 1-5` | Weekdays at 9 AM |
|
||||||
|
| `30 14 1 * *` | 1st of month at 2:30 PM |
|
||||||
|
| `0 0 * * 0` | Every Sunday at midnight |
|
||||||
|
| `*/15 * * * *` | Every 15 minutes |
|
||||||
|
|
||||||
|
## System Tasks vs User Tasks
|
||||||
|
|
||||||
|
### System Tasks
|
||||||
|
- Created by administrators
|
||||||
|
- No user association (`user_id` is null)
|
||||||
|
- Typically for maintenance operations
|
||||||
|
- Accessible via admin endpoints
|
||||||
|
|
||||||
|
### User Tasks
|
||||||
|
- Created by regular users
|
||||||
|
- Associated with specific user
|
||||||
|
- User can only manage their own tasks
|
||||||
|
- Accessible via regular user endpoints
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The system provides comprehensive error handling:
|
||||||
|
|
||||||
|
- **Invalid Parameters**: Validation errors for missing or invalid task parameters
|
||||||
|
- **Scheduling Conflicts**: Prevention of resource conflicts
|
||||||
|
- **Timezone Errors**: Invalid timezone specifications handled gracefully
|
||||||
|
- **Execution Failures**: Failed tasks marked with error messages and retry logic
|
||||||
|
- **Expired Tasks**: Automatic cleanup of expired tasks
|
||||||
|
|
||||||
|
## Monitoring and Management
|
||||||
|
|
||||||
|
### Get Task Status
|
||||||
|
```bash
|
||||||
|
curl "http://localhost:8000/api/v1/scheduler/tasks/{task-id}" \
|
||||||
|
-H "Cookie: access_token=YOUR_JWT_TOKEN"
|
||||||
|
```
|
||||||
|
|
||||||
|
### List User Tasks
|
||||||
|
```bash
|
||||||
|
curl "http://localhost:8000/api/v1/scheduler/tasks?status=pending&limit=10" \
|
||||||
|
-H "Cookie: access_token=YOUR_JWT_TOKEN"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Admin: View All Tasks
|
||||||
|
```bash
|
||||||
|
curl "http://localhost:8000/api/v1/scheduler/admin/tasks?limit=50" \
|
||||||
|
-H "Cookie: access_token=ADMIN_JWT_TOKEN"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cancel Task
|
||||||
|
```bash
|
||||||
|
curl -X DELETE "http://localhost:8000/api/v1/scheduler/tasks/{task-id}" \
|
||||||
|
-H "Cookie: access_token=YOUR_JWT_TOKEN"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration from Old Scheduler
|
||||||
|
|
||||||
|
The new system automatically:
|
||||||
|
|
||||||
|
1. **Creates system tasks**: Daily credit recharge task created on startup
|
||||||
|
2. **Maintains compatibility**: Existing credit recharge functionality preserved
|
||||||
|
3. **Enhances functionality**: Adds user tasks and new task types
|
||||||
|
4. **Improves reliability**: Better error handling and timezone support
|
||||||
|
|
||||||
|
The old scheduler is completely replaced - no migration needed for existing functionality.
|
||||||
@@ -12,6 +12,7 @@ from app.api.v1 import (
|
|||||||
main,
|
main,
|
||||||
player,
|
player,
|
||||||
playlists,
|
playlists,
|
||||||
|
scheduler,
|
||||||
socket,
|
socket,
|
||||||
sounds,
|
sounds,
|
||||||
)
|
)
|
||||||
@@ -28,6 +29,7 @@ api_router.include_router(files.router, tags=["files"])
|
|||||||
api_router.include_router(main.router, tags=["main"])
|
api_router.include_router(main.router, tags=["main"])
|
||||||
api_router.include_router(player.router, tags=["player"])
|
api_router.include_router(player.router, tags=["player"])
|
||||||
api_router.include_router(playlists.router, tags=["playlists"])
|
api_router.include_router(playlists.router, tags=["playlists"])
|
||||||
|
api_router.include_router(scheduler.router, tags=["scheduler"])
|
||||||
api_router.include_router(socket.router, tags=["socket"])
|
api_router.include_router(socket.router, tags=["socket"])
|
||||||
api_router.include_router(sounds.router, tags=["sounds"])
|
api_router.include_router(sounds.router, tags=["sounds"])
|
||||||
api_router.include_router(admin.router)
|
api_router.include_router(admin.router)
|
||||||
|
|||||||
228
app/api/v1/scheduler.py
Normal file
228
app/api/v1/scheduler.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""API endpoints for scheduled task management."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.database import get_db
|
||||||
|
from app.core.dependencies import (
|
||||||
|
get_admin_user,
|
||||||
|
get_current_active_user,
|
||||||
|
)
|
||||||
|
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||||
|
from app.models.user import User
|
||||||
|
from app.schemas.scheduler import (
|
||||||
|
ScheduledTaskCreate,
|
||||||
|
ScheduledTaskResponse,
|
||||||
|
ScheduledTaskUpdate,
|
||||||
|
)
|
||||||
|
from app.services.scheduler import SchedulerService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/scheduler")
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler_service() -> SchedulerService:
|
||||||
|
"""Get the global scheduler service instance."""
|
||||||
|
from app.main import get_global_scheduler_service
|
||||||
|
return get_global_scheduler_service()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tasks", response_model=ScheduledTaskResponse)
|
||||||
|
async def create_task(
|
||||||
|
task_data: ScheduledTaskCreate,
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||||
|
) -> ScheduledTask:
|
||||||
|
"""Create a new scheduled task."""
|
||||||
|
try:
|
||||||
|
task = await scheduler_service.create_task(
|
||||||
|
name=task_data.name,
|
||||||
|
task_type=task_data.task_type,
|
||||||
|
scheduled_at=task_data.scheduled_at,
|
||||||
|
parameters=task_data.parameters,
|
||||||
|
user_id=current_user.id,
|
||||||
|
timezone=task_data.timezone,
|
||||||
|
recurrence_type=task_data.recurrence_type,
|
||||||
|
cron_expression=task_data.cron_expression,
|
||||||
|
recurrence_count=task_data.recurrence_count,
|
||||||
|
expires_at=task_data.expires_at,
|
||||||
|
)
|
||||||
|
return task
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
|
||||||
|
async def get_user_tasks(
|
||||||
|
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||||
|
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||||
|
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
|
||||||
|
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||||
|
) -> List[ScheduledTask]:
|
||||||
|
"""Get user's scheduled tasks."""
|
||||||
|
return await scheduler_service.get_user_tasks(
|
||||||
|
user_id=current_user.id,
|
||||||
|
status=status,
|
||||||
|
task_type=task_type,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tasks/{task_id}", response_model=ScheduledTaskResponse)
|
||||||
|
async def get_task(
|
||||||
|
task_id: int,
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
db_session: AsyncSession = Depends(get_db),
|
||||||
|
) -> ScheduledTask:
|
||||||
|
"""Get a specific scheduled task."""
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
repo = ScheduledTaskRepository(db_session)
|
||||||
|
task = await repo.get_by_id(task_id)
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
# Check if user owns the task or is admin
|
||||||
|
if task.user_id != current_user.id and not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/tasks/{task_id}", response_model=ScheduledTaskResponse)
|
||||||
|
async def update_task(
|
||||||
|
task_id: int,
|
||||||
|
task_update: ScheduledTaskUpdate,
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
db_session: AsyncSession = Depends(get_db),
|
||||||
|
) -> ScheduledTask:
|
||||||
|
"""Update a scheduled task."""
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
repo = ScheduledTaskRepository(db_session)
|
||||||
|
task = await repo.get_by_id(task_id)
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
# Check if user owns the task or is admin
|
||||||
|
if task.user_id != current_user.id and not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
# Update task fields
|
||||||
|
update_data = task_update.model_dump(exclude_unset=True)
|
||||||
|
for field, value in update_data.items():
|
||||||
|
setattr(task, field, value)
|
||||||
|
|
||||||
|
updated_task = await repo.update(task)
|
||||||
|
return updated_task
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/tasks/{task_id}")
|
||||||
|
async def cancel_task(
|
||||||
|
task_id: int,
|
||||||
|
current_user: User = Depends(get_current_active_user),
|
||||||
|
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||||
|
db_session: AsyncSession = Depends(get_db),
|
||||||
|
) -> dict:
|
||||||
|
"""Cancel a scheduled task."""
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
repo = ScheduledTaskRepository(db_session)
|
||||||
|
task = await repo.get_by_id(task_id)
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
raise HTTPException(status_code=404, detail="Task not found")
|
||||||
|
|
||||||
|
# Check if user owns the task or is admin
|
||||||
|
if task.user_id != current_user.id and not current_user.is_admin:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
success = await scheduler_service.cancel_task(task_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=400, detail="Failed to cancel task")
|
||||||
|
|
||||||
|
return {"message": "Task cancelled successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
# Admin-only endpoints
|
||||||
|
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
|
||||||
|
async def get_all_tasks(
|
||||||
|
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||||
|
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||||
|
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
|
||||||
|
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||||
|
current_user: User = Depends(get_admin_user),
|
||||||
|
db_session: AsyncSession = Depends(get_db),
|
||||||
|
) -> List[ScheduledTask]:
|
||||||
|
"""Get all scheduled tasks (admin only)."""
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
repo = ScheduledTaskRepository(db_session)
|
||||||
|
|
||||||
|
# Get all tasks with pagination and filtering
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
statement = select(ScheduledTask)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
statement = statement.where(ScheduledTask.status == status)
|
||||||
|
|
||||||
|
if task_type:
|
||||||
|
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||||
|
|
||||||
|
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||||
|
|
||||||
|
if offset:
|
||||||
|
statement = statement.offset(offset)
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
statement = statement.limit(limit)
|
||||||
|
|
||||||
|
result = await db_session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
|
||||||
|
async def get_system_tasks(
|
||||||
|
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||||
|
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||||
|
current_user: User = Depends(get_admin_user),
|
||||||
|
db_session: AsyncSession = Depends(get_db),
|
||||||
|
) -> List[ScheduledTask]:
|
||||||
|
"""Get system tasks (admin only)."""
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
repo = ScheduledTaskRepository(db_session)
|
||||||
|
return await repo.get_system_tasks(status=status, task_type=task_type)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/admin/system-tasks", response_model=ScheduledTaskResponse)
|
||||||
|
async def create_system_task(
|
||||||
|
task_data: ScheduledTaskCreate,
|
||||||
|
current_user: User = Depends(get_admin_user),
|
||||||
|
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||||
|
) -> ScheduledTask:
|
||||||
|
"""Create a system task (admin only)."""
|
||||||
|
try:
|
||||||
|
task = await scheduler_service.create_task(
|
||||||
|
name=task_data.name,
|
||||||
|
task_type=task_data.task_type,
|
||||||
|
scheduled_at=task_data.scheduled_at,
|
||||||
|
parameters=task_data.parameters,
|
||||||
|
user_id=None, # System task
|
||||||
|
timezone=task_data.timezone,
|
||||||
|
recurrence_type=task_data.recurrence_type,
|
||||||
|
cron_expression=task_data.cron_expression,
|
||||||
|
recurrence_count=task_data.recurrence_count,
|
||||||
|
expires_at=task_data.expires_at,
|
||||||
|
)
|
||||||
|
return task
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
@@ -7,17 +7,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
from app.core.seeds import seed_all_data
|
from app.core.seeds import seed_all_data
|
||||||
from app.models import ( # noqa: F401
|
# Import all models to ensure SQLModel metadata discovery
|
||||||
extraction,
|
import app.models # noqa: F401
|
||||||
favorite,
|
|
||||||
plan,
|
|
||||||
playlist,
|
|
||||||
playlist_sound,
|
|
||||||
sound,
|
|
||||||
sound_played,
|
|
||||||
user,
|
|
||||||
user_oauth,
|
|
||||||
)
|
|
||||||
|
|
||||||
engine: AsyncEngine = create_async_engine(
|
engine: AsyncEngine = create_async_engine(
|
||||||
settings.DATABASE_URL,
|
settings.DATABASE_URL,
|
||||||
|
|||||||
31
app/main.py
31
app/main.py
@@ -11,14 +11,27 @@ from app.core.database import get_session_factory, init_db
|
|||||||
from app.core.logging import get_logger, setup_logging
|
from app.core.logging import get_logger, setup_logging
|
||||||
from app.middleware.logging import LoggingMiddleware
|
from app.middleware.logging import LoggingMiddleware
|
||||||
from app.services.extraction_processor import extraction_processor
|
from app.services.extraction_processor import extraction_processor
|
||||||
from app.services.player import initialize_player_service, shutdown_player_service
|
from app.services.player import initialize_player_service, shutdown_player_service, get_player_service
|
||||||
from app.services.scheduler import SchedulerService
|
from app.services.scheduler import SchedulerService
|
||||||
from app.services.socket import socket_manager
|
from app.services.socket import socket_manager
|
||||||
|
|
||||||
|
|
||||||
|
scheduler_service = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_scheduler_service() -> SchedulerService:
|
||||||
|
"""Get the global scheduler service instance."""
|
||||||
|
global scheduler_service
|
||||||
|
if scheduler_service is None:
|
||||||
|
raise RuntimeError("Scheduler service not initialized")
|
||||||
|
return scheduler_service
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
||||||
"""Application lifespan context manager for setup and teardown."""
|
"""Application lifespan context manager for setup and teardown."""
|
||||||
|
global scheduler_service
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.info("Starting application")
|
logger.info("Starting application")
|
||||||
@@ -35,17 +48,23 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None]:
|
|||||||
logger.info("Player service started")
|
logger.info("Player service started")
|
||||||
|
|
||||||
# Start the scheduler service
|
# Start the scheduler service
|
||||||
scheduler_service = SchedulerService(get_session_factory())
|
try:
|
||||||
await scheduler_service.start()
|
player_service = get_player_service() # Get the initialized player service
|
||||||
logger.info("Scheduler service started")
|
scheduler_service = SchedulerService(get_session_factory(), player_service)
|
||||||
|
await scheduler_service.start()
|
||||||
|
logger.info("Enhanced scheduler service started")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to start scheduler service - continuing without it")
|
||||||
|
scheduler_service = None
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
logger.info("Shutting down application")
|
logger.info("Shutting down application")
|
||||||
|
|
||||||
# Stop the scheduler service
|
# Stop the scheduler service
|
||||||
await scheduler_service.stop()
|
if scheduler_service:
|
||||||
logger.info("Scheduler service stopped")
|
await scheduler_service.stop()
|
||||||
|
logger.info("Scheduler service stopped")
|
||||||
|
|
||||||
# Stop the player service
|
# Stop the player service
|
||||||
await shutdown_player_service()
|
await shutdown_player_service()
|
||||||
|
|||||||
@@ -1 +1,32 @@
|
|||||||
"""Models package."""
|
"""Models package."""
|
||||||
|
|
||||||
|
# Import all models for SQLAlchemy metadata discovery
|
||||||
|
from .base import BaseModel
|
||||||
|
from .credit_action import CreditAction
|
||||||
|
from .credit_transaction import CreditTransaction
|
||||||
|
from .extraction import Extraction
|
||||||
|
from .favorite import Favorite
|
||||||
|
from .plan import Plan
|
||||||
|
from .playlist import Playlist
|
||||||
|
from .playlist_sound import PlaylistSound
|
||||||
|
from .scheduled_task import ScheduledTask
|
||||||
|
from .sound import Sound
|
||||||
|
from .sound_played import SoundPlayed
|
||||||
|
from .user import User
|
||||||
|
from .user_oauth import UserOauth
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseModel",
|
||||||
|
"CreditAction",
|
||||||
|
"CreditTransaction",
|
||||||
|
"Extraction",
|
||||||
|
"Favorite",
|
||||||
|
"Plan",
|
||||||
|
"Playlist",
|
||||||
|
"PlaylistSound",
|
||||||
|
"ScheduledTask",
|
||||||
|
"Sound",
|
||||||
|
"SoundPlayed",
|
||||||
|
"User",
|
||||||
|
"UserOauth",
|
||||||
|
]
|
||||||
|
|||||||
125
app/models/scheduled_task.py
Normal file
125
app/models/scheduled_task.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Scheduled task model for flexible task scheduling with timezone support."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlmodel import JSON, Column, Field, SQLModel
|
||||||
|
|
||||||
|
from app.models.base import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TaskType(str, Enum):
|
||||||
|
"""Available task types."""
|
||||||
|
|
||||||
|
CREDIT_RECHARGE = "credit_recharge"
|
||||||
|
PLAY_SOUND = "play_sound"
|
||||||
|
PLAY_PLAYLIST = "play_playlist"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
"""Task execution status."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
class RecurrenceType(str, Enum):
|
||||||
|
"""Recurrence patterns."""
|
||||||
|
|
||||||
|
NONE = "none" # One-shot task
|
||||||
|
HOURLY = "hourly"
|
||||||
|
DAILY = "daily"
|
||||||
|
WEEKLY = "weekly"
|
||||||
|
MONTHLY = "monthly"
|
||||||
|
YEARLY = "yearly"
|
||||||
|
CRON = "cron" # Custom cron expression
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledTask(BaseModel, table=True):
|
||||||
|
"""Model for scheduled tasks with timezone support."""
|
||||||
|
|
||||||
|
__tablename__ = "scheduled_tasks"
|
||||||
|
|
||||||
|
id: int | None = Field(primary_key=True, default=None)
|
||||||
|
name: str = Field(max_length=255, description="Human-readable task name")
|
||||||
|
task_type: TaskType = Field(description="Type of task to execute")
|
||||||
|
status: TaskStatus = Field(default=TaskStatus.PENDING)
|
||||||
|
|
||||||
|
# Scheduling fields with timezone support
|
||||||
|
scheduled_at: datetime = Field(description="When the task should be executed (UTC)")
|
||||||
|
timezone: str = Field(
|
||||||
|
default="UTC",
|
||||||
|
description="Timezone for scheduling (e.g., 'America/New_York', 'Europe/Paris')",
|
||||||
|
)
|
||||||
|
recurrence_type: RecurrenceType = Field(default=RecurrenceType.NONE)
|
||||||
|
cron_expression: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Cron expression for custom recurrence (when recurrence_type is CRON)",
|
||||||
|
)
|
||||||
|
recurrence_count: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of times to repeat (None for infinite)",
|
||||||
|
)
|
||||||
|
executions_count: int = Field(default=0, description="Number of times executed")
|
||||||
|
|
||||||
|
# Task parameters
|
||||||
|
parameters: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
sa_column=Column(JSON),
|
||||||
|
description="Task-specific parameters",
|
||||||
|
)
|
||||||
|
|
||||||
|
# User association (None for system tasks)
|
||||||
|
user_id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
foreign_key="user.id",
|
||||||
|
description="User who created the task (None for system tasks)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execution tracking
|
||||||
|
last_executed_at: Optional[datetime] = Field(
|
||||||
|
default=None,
|
||||||
|
description="When the task was last executed (UTC)",
|
||||||
|
)
|
||||||
|
next_execution_at: Optional[datetime] = Field(
|
||||||
|
default=None,
|
||||||
|
description="When the task should be executed next (UTC, for recurring tasks)",
|
||||||
|
)
|
||||||
|
error_message: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Error message if execution failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Task lifecycle
|
||||||
|
is_active: bool = Field(default=True, description="Whether the task is active")
|
||||||
|
expires_at: Optional[datetime] = Field(
|
||||||
|
default=None,
|
||||||
|
description="When the task expires (UTC, optional)",
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if the task has expired."""
|
||||||
|
if self.expires_at is None:
|
||||||
|
return False
|
||||||
|
return datetime.utcnow() > self.expires_at
|
||||||
|
|
||||||
|
def is_recurring(self) -> bool:
|
||||||
|
"""Check if the task is recurring."""
|
||||||
|
return self.recurrence_type != RecurrenceType.NONE
|
||||||
|
|
||||||
|
def should_repeat(self) -> bool:
|
||||||
|
"""Check if the task should be repeated."""
|
||||||
|
if not self.is_recurring():
|
||||||
|
return False
|
||||||
|
if self.recurrence_count is None:
|
||||||
|
return True
|
||||||
|
return self.executions_count < self.recurrence_count
|
||||||
|
|
||||||
|
def is_system_task(self) -> bool:
|
||||||
|
"""Check if this is a system task (no user association)."""
|
||||||
|
return self.user_id is None
|
||||||
177
app/repositories/scheduled_task.py
Normal file
177
app/repositories/scheduled_task.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""Repository for scheduled task operations."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||||
|
from app.repositories.base import BaseRepository
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledTaskRepository(BaseRepository[ScheduledTask]):
|
||||||
|
"""Repository for scheduled task database operations."""
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession) -> None:
|
||||||
|
"""Initialize the repository."""
|
||||||
|
super().__init__(ScheduledTask, session)
|
||||||
|
|
||||||
|
async def get_pending_tasks(self) -> List[ScheduledTask]:
|
||||||
|
"""Get all pending tasks that are ready to be executed."""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
statement = select(ScheduledTask).where(
|
||||||
|
ScheduledTask.status == TaskStatus.PENDING,
|
||||||
|
ScheduledTask.is_active.is_(True),
|
||||||
|
ScheduledTask.scheduled_at <= now,
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_active_tasks(self) -> List[ScheduledTask]:
|
||||||
|
"""Get all active tasks."""
|
||||||
|
statement = select(ScheduledTask).where(
|
||||||
|
ScheduledTask.is_active.is_(True),
|
||||||
|
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_user_tasks(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
status: Optional[TaskStatus] = None,
|
||||||
|
task_type: Optional[TaskType] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: Optional[int] = None,
|
||||||
|
) -> List[ScheduledTask]:
|
||||||
|
"""Get tasks for a specific user."""
|
||||||
|
statement = select(ScheduledTask).where(ScheduledTask.user_id == user_id)
|
||||||
|
|
||||||
|
if status:
|
||||||
|
statement = statement.where(ScheduledTask.status == status)
|
||||||
|
|
||||||
|
if task_type:
|
||||||
|
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||||
|
|
||||||
|
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||||
|
|
||||||
|
if offset:
|
||||||
|
statement = statement.offset(offset)
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
statement = statement.limit(limit)
|
||||||
|
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_system_tasks(
|
||||||
|
self,
|
||||||
|
status: Optional[TaskStatus] = None,
|
||||||
|
task_type: Optional[TaskType] = None,
|
||||||
|
) -> List[ScheduledTask]:
|
||||||
|
"""Get system tasks (tasks with no user association)."""
|
||||||
|
statement = select(ScheduledTask).where(ScheduledTask.user_id.is_(None))
|
||||||
|
|
||||||
|
if status:
|
||||||
|
statement = statement.where(ScheduledTask.status == status)
|
||||||
|
|
||||||
|
if task_type:
|
||||||
|
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||||
|
|
||||||
|
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||||
|
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_recurring_tasks_due_for_next_execution(self) -> List[ScheduledTask]:
|
||||||
|
"""Get recurring tasks that need their next execution scheduled."""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
statement = select(ScheduledTask).where(
|
||||||
|
ScheduledTask.recurrence_type != RecurrenceType.NONE,
|
||||||
|
ScheduledTask.is_active.is_(True),
|
||||||
|
ScheduledTask.status == TaskStatus.COMPLETED,
|
||||||
|
ScheduledTask.next_execution_at <= now,
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def get_expired_tasks(self) -> List[ScheduledTask]:
|
||||||
|
"""Get expired tasks that should be cleaned up."""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
statement = select(ScheduledTask).where(
|
||||||
|
ScheduledTask.expires_at.is_not(None),
|
||||||
|
ScheduledTask.expires_at <= now,
|
||||||
|
ScheduledTask.is_active.is_(True),
|
||||||
|
)
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
return list(result.all())
|
||||||
|
|
||||||
|
async def cancel_user_tasks(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
task_type: Optional[TaskType] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Cancel all pending/running tasks for a user."""
|
||||||
|
statement = select(ScheduledTask).where(
|
||||||
|
ScheduledTask.user_id == user_id,
|
||||||
|
ScheduledTask.status.in_([TaskStatus.PENDING, TaskStatus.RUNNING]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if task_type:
|
||||||
|
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||||
|
|
||||||
|
result = await self.session.exec(statement)
|
||||||
|
tasks = list(result.all())
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for task in tasks:
|
||||||
|
task.status = TaskStatus.CANCELLED
|
||||||
|
task.is_active = False
|
||||||
|
self.session.add(task)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
await self.session.commit()
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def mark_as_running(self, task: ScheduledTask) -> None:
|
||||||
|
"""Mark a task as running."""
|
||||||
|
task.status = TaskStatus.RUNNING
|
||||||
|
self.session.add(task)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(task)
|
||||||
|
|
||||||
|
async def mark_as_completed(
|
||||||
|
self,
|
||||||
|
task: ScheduledTask,
|
||||||
|
next_execution_at: Optional[datetime] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Mark a task as completed and set next execution if recurring."""
|
||||||
|
task.status = TaskStatus.COMPLETED
|
||||||
|
task.last_executed_at = datetime.utcnow()
|
||||||
|
task.executions_count += 1
|
||||||
|
task.error_message = None
|
||||||
|
|
||||||
|
if next_execution_at and task.should_repeat():
|
||||||
|
task.next_execution_at = next_execution_at
|
||||||
|
task.status = TaskStatus.PENDING
|
||||||
|
elif not task.should_repeat():
|
||||||
|
task.is_active = False
|
||||||
|
|
||||||
|
self.session.add(task)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(task)
|
||||||
|
|
||||||
|
async def mark_as_failed(self, task: ScheduledTask, error_message: str) -> None:
|
||||||
|
"""Mark a task as failed with error message."""
|
||||||
|
task.status = TaskStatus.FAILED
|
||||||
|
task.error_message = error_message
|
||||||
|
task.last_executed_at = datetime.utcnow()
|
||||||
|
|
||||||
|
# For non-recurring tasks, deactivate on failure
|
||||||
|
if not task.is_recurring():
|
||||||
|
task.is_active = False
|
||||||
|
|
||||||
|
self.session.add(task)
|
||||||
|
await self.session.commit()
|
||||||
|
await self.session.refresh(task)
|
||||||
189
app/schemas/scheduler.py
Normal file
189
app/schemas/scheduler.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""Schemas for scheduled task API."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.scheduled_task import RecurrenceType, TaskStatus, TaskType
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledTaskBase(BaseModel):
|
||||||
|
"""Base schema for scheduled tasks."""
|
||||||
|
|
||||||
|
name: str = Field(description="Human-readable task name")
|
||||||
|
task_type: TaskType = Field(description="Type of task to execute")
|
||||||
|
scheduled_at: datetime = Field(description="When the task should be executed")
|
||||||
|
timezone: str = Field(default="UTC", description="Timezone for scheduling")
|
||||||
|
parameters: Dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Task-specific parameters",
|
||||||
|
)
|
||||||
|
recurrence_type: RecurrenceType = Field(
|
||||||
|
default=RecurrenceType.NONE,
|
||||||
|
description="Recurrence pattern",
|
||||||
|
)
|
||||||
|
cron_expression: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Cron expression for custom recurrence",
|
||||||
|
)
|
||||||
|
recurrence_count: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of times to repeat (None for infinite)",
|
||||||
|
)
|
||||||
|
expires_at: Optional[datetime] = Field(
|
||||||
|
default=None,
|
||||||
|
description="When the task expires (optional)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledTaskCreate(ScheduledTaskBase):
|
||||||
|
"""Schema for creating a scheduled task."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledTaskUpdate(BaseModel):
|
||||||
|
"""Schema for updating a scheduled task."""
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
|
scheduled_at: Optional[datetime] = None
|
||||||
|
timezone: Optional[str] = None
|
||||||
|
parameters: Optional[Dict[str, Any]] = None
|
||||||
|
is_active: Optional[bool] = None
|
||||||
|
expires_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduledTaskResponse(ScheduledTaskBase):
|
||||||
|
"""Schema for scheduled task responses."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
status: TaskStatus
|
||||||
|
user_id: Optional[int] = None
|
||||||
|
executions_count: int
|
||||||
|
last_executed_at: Optional[datetime] = None
|
||||||
|
next_execution_at: Optional[datetime] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
is_active: bool
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Pydantic configuration."""
|
||||||
|
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
# Task-specific parameter schemas
|
||||||
|
class CreditRechargeParameters(BaseModel):
|
||||||
|
"""Parameters for credit recharge tasks."""
|
||||||
|
|
||||||
|
user_id: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Specific user ID to recharge (None for all users)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PlaySoundParameters(BaseModel):
|
||||||
|
"""Parameters for play sound tasks."""
|
||||||
|
|
||||||
|
sound_id: int = Field(description="ID of the sound to play")
|
||||||
|
|
||||||
|
|
||||||
|
class PlayPlaylistParameters(BaseModel):
|
||||||
|
"""Parameters for play playlist tasks."""
|
||||||
|
|
||||||
|
playlist_id: int = Field(description="ID of the playlist to play")
|
||||||
|
play_mode: str = Field(
|
||||||
|
default="continuous",
|
||||||
|
description="Play mode (continuous, loop, loop_one, random, single)",
|
||||||
|
)
|
||||||
|
shuffle: bool = Field(default=False, description="Whether to shuffle the playlist")
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience schemas for creating specific task types
|
||||||
|
class CreateCreditRechargeTask(BaseModel):
|
||||||
|
"""Schema for creating credit recharge tasks."""
|
||||||
|
|
||||||
|
name: str = "Credit Recharge"
|
||||||
|
scheduled_at: datetime
|
||||||
|
timezone: str = "UTC"
|
||||||
|
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||||
|
cron_expression: Optional[str] = None
|
||||||
|
recurrence_count: Optional[int] = None
|
||||||
|
expires_at: Optional[datetime] = None
|
||||||
|
user_id: Optional[int] = None
|
||||||
|
|
||||||
|
def to_task_create(self) -> ScheduledTaskCreate:
|
||||||
|
"""Convert to generic task creation schema."""
|
||||||
|
return ScheduledTaskCreate(
|
||||||
|
name=self.name,
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=self.scheduled_at,
|
||||||
|
timezone=self.timezone,
|
||||||
|
parameters={"user_id": self.user_id},
|
||||||
|
recurrence_type=self.recurrence_type,
|
||||||
|
cron_expression=self.cron_expression,
|
||||||
|
recurrence_count=self.recurrence_count,
|
||||||
|
expires_at=self.expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreatePlaySoundTask(BaseModel):
|
||||||
|
"""Schema for creating play sound tasks."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
scheduled_at: datetime
|
||||||
|
sound_id: int
|
||||||
|
timezone: str = "UTC"
|
||||||
|
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||||
|
cron_expression: Optional[str] = None
|
||||||
|
recurrence_count: Optional[int] = None
|
||||||
|
expires_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
def to_task_create(self) -> ScheduledTaskCreate:
|
||||||
|
"""Convert to generic task creation schema."""
|
||||||
|
return ScheduledTaskCreate(
|
||||||
|
name=self.name,
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=self.scheduled_at,
|
||||||
|
timezone=self.timezone,
|
||||||
|
parameters={"sound_id": self.sound_id},
|
||||||
|
recurrence_type=self.recurrence_type,
|
||||||
|
cron_expression=self.cron_expression,
|
||||||
|
recurrence_count=self.recurrence_count,
|
||||||
|
expires_at=self.expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreatePlayPlaylistTask(BaseModel):
|
||||||
|
"""Schema for creating play playlist tasks."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
scheduled_at: datetime
|
||||||
|
playlist_id: int
|
||||||
|
play_mode: str = "continuous"
|
||||||
|
shuffle: bool = False
|
||||||
|
timezone: str = "UTC"
|
||||||
|
recurrence_type: RecurrenceType = RecurrenceType.NONE
|
||||||
|
cron_expression: Optional[str] = None
|
||||||
|
recurrence_count: Optional[int] = None
|
||||||
|
expires_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
def to_task_create(self) -> ScheduledTaskCreate:
|
||||||
|
"""Convert to generic task creation schema."""
|
||||||
|
return ScheduledTaskCreate(
|
||||||
|
name=self.name,
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=self.scheduled_at,
|
||||||
|
timezone=self.timezone,
|
||||||
|
parameters={
|
||||||
|
"playlist_id": self.playlist_id,
|
||||||
|
"play_mode": self.play_mode,
|
||||||
|
"shuffle": self.shuffle,
|
||||||
|
},
|
||||||
|
recurrence_type=self.recurrence_type,
|
||||||
|
cron_expression=self.cron_expression,
|
||||||
|
recurrence_count=self.recurrence_count,
|
||||||
|
expires_at=self.expires_at,
|
||||||
|
)
|
||||||
@@ -1,63 +1,413 @@
|
|||||||
"""Scheduler service for periodic tasks."""
|
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import pytz
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
|
from apscheduler.triggers.date import DateTrigger
|
||||||
|
from apscheduler.triggers.interval import IntervalTrigger
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from app.core.logging import get_logger
|
from app.core.logging import get_logger
|
||||||
|
from app.models.scheduled_task import (
|
||||||
|
RecurrenceType,
|
||||||
|
ScheduledTask,
|
||||||
|
TaskStatus,
|
||||||
|
TaskType,
|
||||||
|
)
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
from app.services.credit import CreditService
|
from app.services.credit import CreditService
|
||||||
|
from app.services.player import PlayerService
|
||||||
|
from app.services.task_handlers import TaskHandlerRegistry
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SchedulerService:
|
class SchedulerService:
|
||||||
"""Service for managing scheduled tasks."""
|
"""Enhanced service for managing scheduled tasks with timezone support."""
|
||||||
|
|
||||||
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_session_factory: Callable[[], AsyncSession],
|
||||||
|
player_service: PlayerService,
|
||||||
|
) -> None:
|
||||||
"""Initialize the scheduler service.
|
"""Initialize the scheduler service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_session_factory: Factory function to create database sessions
|
db_session_factory: Factory function to create database sessions
|
||||||
|
player_service: Player service for audio playback tasks
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.db_session_factory = db_session_factory
|
self.db_session_factory = db_session_factory
|
||||||
self.scheduler = AsyncIOScheduler()
|
self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
|
||||||
self.credit_service = CreditService(db_session_factory)
|
self.credit_service = CreditService(db_session_factory)
|
||||||
|
self.player_service = player_service
|
||||||
|
self._running_tasks: set[str] = set()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the scheduler and register all tasks."""
|
"""Start the scheduler and load all active tasks."""
|
||||||
logger.info("Starting scheduler service...")
|
logger.info("Starting enhanced scheduler service...")
|
||||||
|
|
||||||
# Add daily credit recharge job (runs at midnight UTC)
|
self.scheduler.start()
|
||||||
|
|
||||||
|
# Schedule system tasks initialization for after startup
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self._daily_credit_recharge,
|
self._initialize_system_tasks,
|
||||||
"cron",
|
"date",
|
||||||
hour=0,
|
run_date=datetime.utcnow() + timedelta(seconds=2),
|
||||||
minute=0,
|
id="initialize_system_tasks",
|
||||||
id="daily_credit_recharge",
|
name="Initialize System Tasks",
|
||||||
name="Daily Credit Recharge",
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Schedule periodic cleanup and maintenance
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self._maintenance_job,
|
||||||
|
"interval",
|
||||||
|
minutes=5,
|
||||||
|
id="scheduler_maintenance",
|
||||||
|
name="Scheduler Maintenance",
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scheduler.start()
|
logger.info("Enhanced scheduler service started successfully")
|
||||||
logger.info("Scheduler service started successfully")
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop the scheduler."""
|
"""Stop the scheduler."""
|
||||||
logger.info("Stopping scheduler service...")
|
logger.info("Stopping scheduler service...")
|
||||||
self.scheduler.shutdown()
|
self.scheduler.shutdown(wait=True)
|
||||||
logger.info("Scheduler service stopped")
|
logger.info("Scheduler service stopped")
|
||||||
|
|
||||||
async def _daily_credit_recharge(self) -> None:
|
async def create_task(
|
||||||
"""Execute daily credit recharge for all users."""
|
self,
|
||||||
logger.info("Starting daily credit recharge task...")
|
name: str,
|
||||||
|
task_type: TaskType,
|
||||||
|
scheduled_at: datetime,
|
||||||
|
parameters: Optional[Dict[str, Any]] = None,
|
||||||
|
user_id: Optional[int] = None,
|
||||||
|
timezone: str = "UTC",
|
||||||
|
recurrence_type: RecurrenceType = RecurrenceType.NONE,
|
||||||
|
cron_expression: Optional[str] = None,
|
||||||
|
recurrence_count: Optional[int] = None,
|
||||||
|
expires_at: Optional[datetime] = None,
|
||||||
|
) -> ScheduledTask:
|
||||||
|
"""Create a new scheduled task."""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
# Convert scheduled_at to UTC if it's in a different timezone
|
||||||
|
if timezone != "UTC":
|
||||||
|
tz = pytz.timezone(timezone)
|
||||||
|
if scheduled_at.tzinfo is None:
|
||||||
|
# Assume the datetime is in the specified timezone
|
||||||
|
scheduled_at = tz.localize(scheduled_at)
|
||||||
|
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
|
task_data = {
|
||||||
|
"name": name,
|
||||||
|
"task_type": task_type,
|
||||||
|
"scheduled_at": scheduled_at,
|
||||||
|
"timezone": timezone,
|
||||||
|
"parameters": parameters or {},
|
||||||
|
"user_id": user_id,
|
||||||
|
"recurrence_type": recurrence_type,
|
||||||
|
"cron_expression": cron_expression,
|
||||||
|
"recurrence_count": recurrence_count,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
created_task = await repo.create(task_data)
|
||||||
|
await self._schedule_apscheduler_job(created_task)
|
||||||
|
|
||||||
|
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
|
||||||
|
return created_task
|
||||||
|
|
||||||
|
async def cancel_task(self, task_id: int) -> bool:
|
||||||
|
"""Cancel a scheduled task."""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
task = await repo.get_by_id(task_id)
|
||||||
|
if not task:
|
||||||
|
return False
|
||||||
|
|
||||||
|
task.status = TaskStatus.CANCELLED
|
||||||
|
task.is_active = False
|
||||||
|
await repo.update(task)
|
||||||
|
|
||||||
|
# Remove from APScheduler
|
||||||
|
try:
|
||||||
|
self.scheduler.remove_job(str(task_id))
|
||||||
|
except Exception:
|
||||||
|
pass # Job might not exist in scheduler
|
||||||
|
|
||||||
|
logger.info(f"Cancelled task: {task.name} ({task_id})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_user_tasks(
|
||||||
|
self,
|
||||||
|
user_id: int,
|
||||||
|
status: Optional[TaskStatus] = None,
|
||||||
|
task_type: Optional[TaskType] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
offset: Optional[int] = None,
|
||||||
|
) -> List[ScheduledTask]:
|
||||||
|
"""Get tasks for a specific user."""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
|
||||||
|
|
||||||
|
async def _initialize_system_tasks(self) -> None:
|
||||||
|
"""Initialize system tasks and load active tasks from database."""
|
||||||
|
logger.info("Initializing system tasks...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stats = await self.credit_service.recharge_all_users_credits()
|
# Create system tasks if they don't exist
|
||||||
logger.info(
|
await self._ensure_system_tasks()
|
||||||
"Daily credit recharge completed successfully: %s",
|
|
||||||
stats,
|
# Load all active tasks from database
|
||||||
)
|
await self._load_active_tasks()
|
||||||
|
|
||||||
|
logger.info("System tasks initialized successfully")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Daily credit recharge task failed")
|
logger.exception("Failed to initialize system tasks")
|
||||||
|
|
||||||
|
async def _ensure_system_tasks(self) -> None:
|
||||||
|
"""Ensure required system tasks exist."""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
# Check if daily credit recharge task exists
|
||||||
|
system_tasks = await repo.get_system_tasks(
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE
|
||||||
|
)
|
||||||
|
|
||||||
|
daily_recharge_exists = any(
|
||||||
|
task.recurrence_type == RecurrenceType.DAILY
|
||||||
|
and task.is_active
|
||||||
|
for task in system_tasks
|
||||||
|
)
|
||||||
|
|
||||||
|
if not daily_recharge_exists:
|
||||||
|
# Create daily credit recharge task
|
||||||
|
tomorrow_midnight = datetime.utcnow().replace(
|
||||||
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
) + timedelta(days=1)
|
||||||
|
|
||||||
|
task_data = {
|
||||||
|
"name": "Daily Credit Recharge",
|
||||||
|
"task_type": TaskType.CREDIT_RECHARGE,
|
||||||
|
"scheduled_at": tomorrow_midnight,
|
||||||
|
"recurrence_type": RecurrenceType.DAILY,
|
||||||
|
"parameters": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
await repo.create(task_data)
|
||||||
|
logger.info("Created system daily credit recharge task")
|
||||||
|
|
||||||
|
async def _load_active_tasks(self) -> None:
|
||||||
|
"""Load all active tasks from database into scheduler."""
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
active_tasks = await repo.get_active_tasks()
|
||||||
|
|
||||||
|
for task in active_tasks:
|
||||||
|
await self._schedule_apscheduler_job(task)
|
||||||
|
|
||||||
|
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
|
||||||
|
|
||||||
|
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
|
||||||
|
"""Schedule a task in APScheduler."""
|
||||||
|
job_id = str(task.id)
|
||||||
|
|
||||||
|
# Remove existing job if it exists
|
||||||
|
try:
|
||||||
|
self.scheduler.remove_job(job_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Don't schedule if task is not active or already completed/failed
|
||||||
|
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create trigger based on recurrence type
|
||||||
|
trigger = self._create_trigger(task)
|
||||||
|
if not trigger:
|
||||||
|
logger.warning(f"Could not create trigger for task {task.id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Schedule the job
|
||||||
|
self.scheduler.add_job(
|
||||||
|
self._execute_task,
|
||||||
|
trigger=trigger,
|
||||||
|
args=[task.id],
|
||||||
|
id=job_id,
|
||||||
|
name=task.name,
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Scheduled APScheduler job for task {task.id}")
|
||||||
|
|
||||||
|
def _create_trigger(self, task: ScheduledTask):
|
||||||
|
"""Create APScheduler trigger based on task configuration."""
|
||||||
|
tz = pytz.timezone(task.timezone)
|
||||||
|
|
||||||
|
if task.recurrence_type == RecurrenceType.NONE:
|
||||||
|
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
|
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
|
||||||
|
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
|
||||||
|
|
||||||
|
elif task.recurrence_type == RecurrenceType.HOURLY:
|
||||||
|
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
|
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||||
|
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
|
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||||
|
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
|
||||||
|
|
||||||
|
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||||
|
# Use cron trigger for monthly (more reliable than interval)
|
||||||
|
scheduled_time = task.scheduled_at
|
||||||
|
return CronTrigger(
|
||||||
|
day=scheduled_time.day,
|
||||||
|
hour=scheduled_time.hour,
|
||||||
|
minute=scheduled_time.minute,
|
||||||
|
timezone=tz
|
||||||
|
)
|
||||||
|
|
||||||
|
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||||
|
scheduled_time = task.scheduled_at
|
||||||
|
return CronTrigger(
|
||||||
|
month=scheduled_time.month,
|
||||||
|
day=scheduled_time.day,
|
||||||
|
hour=scheduled_time.hour,
|
||||||
|
minute=scheduled_time.minute,
|
||||||
|
timezone=tz
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _execute_task(self, task_id: int) -> None:
|
||||||
|
"""Execute a scheduled task."""
|
||||||
|
task_id_str = str(task_id)
|
||||||
|
|
||||||
|
# Prevent concurrent execution of the same task
|
||||||
|
if task_id_str in self._running_tasks:
|
||||||
|
logger.warning(f"Task {task_id} is already running, skipping execution")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._running_tasks.add(task_id_str)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
# Get fresh task data
|
||||||
|
task = await repo.get_by_id(task_id)
|
||||||
|
if not task:
|
||||||
|
logger.warning(f"Task {task_id} not found")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if task is still active and pending
|
||||||
|
if not task.is_active or task.status != TaskStatus.PENDING:
|
||||||
|
logger.info(f"Task {task_id} is not active or not pending, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if task has expired
|
||||||
|
if task.is_expired():
|
||||||
|
logger.info(f"Task {task_id} has expired, marking as cancelled")
|
||||||
|
task.status = TaskStatus.CANCELLED
|
||||||
|
task.is_active = False
|
||||||
|
await repo.update(task)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Mark task as running
|
||||||
|
await repo.mark_as_running(task)
|
||||||
|
|
||||||
|
# Execute the task
|
||||||
|
try:
|
||||||
|
handler_registry = TaskHandlerRegistry(
|
||||||
|
session, self.credit_service, self.player_service
|
||||||
|
)
|
||||||
|
await handler_registry.execute_task(task)
|
||||||
|
|
||||||
|
# Calculate next execution time for recurring tasks
|
||||||
|
next_execution_at = None
|
||||||
|
if task.should_repeat():
|
||||||
|
next_execution_at = self._calculate_next_execution(task)
|
||||||
|
|
||||||
|
# Mark as completed
|
||||||
|
await repo.mark_as_completed(task, next_execution_at)
|
||||||
|
|
||||||
|
# Reschedule if recurring
|
||||||
|
if next_execution_at and task.should_repeat():
|
||||||
|
# Refresh task to get updated data
|
||||||
|
await session.refresh(task)
|
||||||
|
await self._schedule_apscheduler_job(task)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await repo.mark_as_failed(task, str(e))
|
||||||
|
logger.exception(f"Task {task_id} execution failed: {str(e)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._running_tasks.discard(task_id_str)
|
||||||
|
|
||||||
|
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
|
||||||
|
"""Calculate the next execution time for a recurring task."""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
if task.recurrence_type == RecurrenceType.HOURLY:
|
||||||
|
return now + timedelta(hours=1)
|
||||||
|
elif task.recurrence_type == RecurrenceType.DAILY:
|
||||||
|
return now + timedelta(days=1)
|
||||||
|
elif task.recurrence_type == RecurrenceType.WEEKLY:
|
||||||
|
return now + timedelta(weeks=1)
|
||||||
|
elif task.recurrence_type == RecurrenceType.MONTHLY:
|
||||||
|
# Add approximately one month
|
||||||
|
return now + timedelta(days=30)
|
||||||
|
elif task.recurrence_type == RecurrenceType.YEARLY:
|
||||||
|
return now + timedelta(days=365)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _maintenance_job(self) -> None:
|
||||||
|
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
|
||||||
|
try:
|
||||||
|
async with self.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
# Handle expired tasks
|
||||||
|
expired_tasks = await repo.get_expired_tasks()
|
||||||
|
for task in expired_tasks:
|
||||||
|
task.status = TaskStatus.CANCELLED
|
||||||
|
task.is_active = False
|
||||||
|
await repo.update(task)
|
||||||
|
|
||||||
|
# Remove from scheduler
|
||||||
|
try:
|
||||||
|
self.scheduler.remove_job(str(task.id))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if expired_tasks:
|
||||||
|
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
|
||||||
|
|
||||||
|
# Handle any missed recurring tasks
|
||||||
|
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
|
||||||
|
for task in due_recurring:
|
||||||
|
if task.should_repeat():
|
||||||
|
task.status = TaskStatus.PENDING
|
||||||
|
task.scheduled_at = task.next_execution_at or datetime.utcnow()
|
||||||
|
await repo.update(task)
|
||||||
|
await self._schedule_apscheduler_job(task)
|
||||||
|
|
||||||
|
if due_recurring:
|
||||||
|
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Maintenance job failed")
|
||||||
|
|||||||
137
app/services/task_handlers.py
Normal file
137
app/services/task_handlers.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Task execution handlers for different task types."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.core.logging import get_logger
|
||||||
|
from app.models.scheduled_task import ScheduledTask, TaskType
|
||||||
|
from app.repositories.playlist import PlaylistRepository
|
||||||
|
from app.repositories.sound import SoundRepository
|
||||||
|
from app.services.credit import CreditService
|
||||||
|
from app.services.player import PlayerService
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutionError(Exception):
|
||||||
|
"""Exception raised when task execution fails."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TaskHandlerRegistry:
|
||||||
|
"""Registry for task execution handlers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
credit_service: CreditService,
|
||||||
|
player_service: PlayerService,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the task handler registry."""
|
||||||
|
self.db_session = db_session
|
||||||
|
self.credit_service = credit_service
|
||||||
|
self.player_service = player_service
|
||||||
|
self.sound_repository = SoundRepository(db_session)
|
||||||
|
self.playlist_repository = PlaylistRepository(db_session)
|
||||||
|
|
||||||
|
# Register handlers
|
||||||
|
self._handlers = {
|
||||||
|
TaskType.CREDIT_RECHARGE: self._handle_credit_recharge,
|
||||||
|
TaskType.PLAY_SOUND: self._handle_play_sound,
|
||||||
|
TaskType.PLAY_PLAYLIST: self._handle_play_playlist,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute_task(self, task: ScheduledTask) -> None:
|
||||||
|
"""Execute a task based on its type."""
|
||||||
|
handler = self._handlers.get(task.task_type)
|
||||||
|
if not handler:
|
||||||
|
raise TaskExecutionError(f"No handler registered for task type: {task.task_type}")
|
||||||
|
|
||||||
|
logger.info(f"Executing task {task.id} ({task.task_type.value}): {task.name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await handler(task)
|
||||||
|
logger.info(f"Task {task.id} executed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Task {task.id} execution failed: {str(e)}")
|
||||||
|
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
|
||||||
|
|
||||||
|
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
|
||||||
|
"""Handle credit recharge task."""
|
||||||
|
parameters = task.parameters
|
||||||
|
user_id = parameters.get("user_id")
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
# Recharge specific user
|
||||||
|
user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
|
||||||
|
stats = await self.credit_service.recharge_user_credits(user_uuid)
|
||||||
|
logger.info(f"Recharged credits for user {user_id}: {stats}")
|
||||||
|
else:
|
||||||
|
# Recharge all users (system task)
|
||||||
|
stats = await self.credit_service.recharge_all_users_credits()
|
||||||
|
logger.info(f"Recharged credits for all users: {stats}")
|
||||||
|
|
||||||
|
async def _handle_play_sound(self, task: ScheduledTask) -> None:
|
||||||
|
"""Handle play sound task."""
|
||||||
|
parameters = task.parameters
|
||||||
|
sound_id = parameters.get("sound_id")
|
||||||
|
|
||||||
|
if not sound_id:
|
||||||
|
raise TaskExecutionError("sound_id parameter is required for PLAY_SOUND tasks")
|
||||||
|
|
||||||
|
try:
|
||||||
|
sound_uuid = uuid.UUID(sound_id) if isinstance(sound_id, str) else sound_id
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
raise TaskExecutionError(f"Invalid sound_id format: {sound_id}") from e
|
||||||
|
|
||||||
|
# Get the sound from database
|
||||||
|
sound = await self.sound_repository.get_by_id(sound_uuid)
|
||||||
|
if not sound:
|
||||||
|
raise TaskExecutionError(f"Sound not found: {sound_id}")
|
||||||
|
|
||||||
|
# Play the sound through VLC
|
||||||
|
from app.services.vlc_player import VLCPlayerService
|
||||||
|
|
||||||
|
vlc_service = VLCPlayerService(lambda: self.db_session)
|
||||||
|
await vlc_service.play_sound(sound)
|
||||||
|
|
||||||
|
logger.info(f"Played sound {sound.filename} via scheduled task")
|
||||||
|
|
||||||
|
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
|
||||||
|
"""Handle play playlist task."""
|
||||||
|
parameters = task.parameters
|
||||||
|
playlist_id = parameters.get("playlist_id")
|
||||||
|
play_mode = parameters.get("play_mode", "continuous")
|
||||||
|
shuffle = parameters.get("shuffle", False)
|
||||||
|
|
||||||
|
if not playlist_id:
|
||||||
|
raise TaskExecutionError("playlist_id parameter is required for PLAY_PLAYLIST tasks")
|
||||||
|
|
||||||
|
try:
|
||||||
|
playlist_uuid = uuid.UUID(playlist_id) if isinstance(playlist_id, str) else playlist_id
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
raise TaskExecutionError(f"Invalid playlist_id format: {playlist_id}") from e
|
||||||
|
|
||||||
|
# Get the playlist from database
|
||||||
|
playlist = await self.playlist_repository.get_by_id(playlist_uuid)
|
||||||
|
if not playlist:
|
||||||
|
raise TaskExecutionError(f"Playlist not found: {playlist_id}")
|
||||||
|
|
||||||
|
# Load playlist in player
|
||||||
|
await self.player_service.load_playlist(playlist_uuid)
|
||||||
|
|
||||||
|
# Set play mode if specified
|
||||||
|
if play_mode in ["continuous", "loop", "loop_one", "random", "single"]:
|
||||||
|
self.player_service.set_mode(play_mode)
|
||||||
|
|
||||||
|
# Enable shuffle if requested
|
||||||
|
if shuffle:
|
||||||
|
self.player_service.set_shuffle(True)
|
||||||
|
|
||||||
|
# Start playing
|
||||||
|
self.player_service.play()
|
||||||
|
|
||||||
|
logger.info(f"Started playing playlist {playlist.name} via scheduled task")
|
||||||
@@ -238,75 +238,76 @@ class VLCPlayerService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Recording play count for sound %s", sound_id)
|
logger.info("Recording play count for sound %s", sound_id)
|
||||||
session = self.db_session_factory()
|
|
||||||
|
# Initialize variables for WebSocket event
|
||||||
|
old_count = 0
|
||||||
|
sound = None
|
||||||
|
admin_user_id = None
|
||||||
|
admin_user_name = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sound_repo = SoundRepository(session)
|
async with self.db_session_factory() as session:
|
||||||
user_repo = UserRepository(session)
|
sound_repo = SoundRepository(session)
|
||||||
|
user_repo = UserRepository(session)
|
||||||
|
|
||||||
# Update sound play count
|
# Update sound play count
|
||||||
sound = await sound_repo.get_by_id(sound_id)
|
sound = await sound_repo.get_by_id(sound_id)
|
||||||
old_count = 0
|
if sound:
|
||||||
if sound:
|
old_count = sound.play_count
|
||||||
old_count = sound.play_count
|
# Update the sound's play count using direct attribute modification
|
||||||
await sound_repo.update(
|
sound.play_count = sound.play_count + 1
|
||||||
sound,
|
session.add(sound)
|
||||||
{"play_count": sound.play_count + 1},
|
await session.commit()
|
||||||
|
await session.refresh(sound)
|
||||||
|
logger.info(
|
||||||
|
"Updated sound %s play_count: %s -> %s",
|
||||||
|
sound_id,
|
||||||
|
old_count,
|
||||||
|
old_count + 1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning("Sound %s not found for play count update", sound_id)
|
||||||
|
|
||||||
|
# Record play history for admin user (ID 1) as placeholder
|
||||||
|
# This could be refined to track per-user play history
|
||||||
|
admin_user = await user_repo.get_by_id(1)
|
||||||
|
if admin_user:
|
||||||
|
admin_user_id = admin_user.id
|
||||||
|
admin_user_name = admin_user.name
|
||||||
|
|
||||||
|
# Always create a new SoundPlayed record for each play event
|
||||||
|
sound_played = SoundPlayed(
|
||||||
|
user_id=admin_user_id, # Can be None for player-based plays
|
||||||
|
sound_id=sound_id,
|
||||||
)
|
)
|
||||||
|
session.add(sound_played)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Updated sound %s play_count: %s -> %s",
|
"Created SoundPlayed record for user %s, sound %s",
|
||||||
sound_id,
|
admin_user_id,
|
||||||
old_count,
|
|
||||||
old_count + 1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning("Sound %s not found for play count update", sound_id)
|
|
||||||
|
|
||||||
# Record play history for admin user (ID 1) as placeholder
|
|
||||||
# This could be refined to track per-user play history
|
|
||||||
admin_user = await user_repo.get_by_id(1)
|
|
||||||
admin_user_id = None
|
|
||||||
admin_user_name = None
|
|
||||||
if admin_user:
|
|
||||||
admin_user_id = admin_user.id
|
|
||||||
admin_user_name = admin_user.name
|
|
||||||
|
|
||||||
# Always create a new SoundPlayed record for each play event
|
|
||||||
sound_played = SoundPlayed(
|
|
||||||
user_id=admin_user_id, # Can be None for player-based plays
|
|
||||||
sound_id=sound_id,
|
|
||||||
)
|
|
||||||
session.add(sound_played)
|
|
||||||
logger.info(
|
|
||||||
"Created SoundPlayed record for user %s, sound %s",
|
|
||||||
admin_user_id,
|
|
||||||
sound_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
logger.info("Successfully recorded play count for sound %s", sound_id)
|
|
||||||
|
|
||||||
# Emit sound_played event via WebSocket
|
|
||||||
try:
|
|
||||||
event_data = {
|
|
||||||
"sound_id": sound_id,
|
|
||||||
"sound_name": sound_name,
|
|
||||||
"user_id": admin_user_id,
|
|
||||||
"user_name": admin_user_name,
|
|
||||||
"play_count": (old_count + 1) if sound else None,
|
|
||||||
}
|
|
||||||
await socket_manager.broadcast_to_all("sound_played", event_data)
|
|
||||||
logger.info("Broadcasted sound_played event for sound %s", sound_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to broadcast sound_played event for sound %s",
|
|
||||||
sound_id,
|
sound_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
logger.info("Successfully recorded play count for sound %s", sound_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error recording play count for sound %s", sound_id)
|
logger.exception("Error recording play count for sound %s", sound_id)
|
||||||
await session.rollback()
|
|
||||||
finally:
|
# Emit sound_played event via WebSocket (outside session context)
|
||||||
await session.close()
|
try:
|
||||||
|
event_data = {
|
||||||
|
"sound_id": sound_id,
|
||||||
|
"sound_name": sound_name,
|
||||||
|
"user_id": admin_user_id,
|
||||||
|
"user_name": admin_user_name,
|
||||||
|
"play_count": (old_count + 1) if sound else None,
|
||||||
|
}
|
||||||
|
await socket_manager.broadcast_to_all("sound_played", event_data)
|
||||||
|
logger.info("Broadcasted sound_played event for sound %s", sound_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to broadcast sound_played event for sound %s",
|
||||||
|
sound_id,
|
||||||
|
)
|
||||||
|
|
||||||
async def play_sound_with_credits(
|
async def play_sound_with_credits(
|
||||||
self,
|
self,
|
||||||
|
|||||||
51
check_tasks.py
Normal file
51
check_tasks.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Check current tasks in the database."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.core.database import get_session_factory
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
async def check_tasks():
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
|
||||||
|
async with session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
# Get all tasks
|
||||||
|
all_tasks = await repo.get_all(limit=20)
|
||||||
|
|
||||||
|
print("All tasks in database:")
|
||||||
|
print("=" * 80)
|
||||||
|
for task in all_tasks:
|
||||||
|
print(f"ID: {task.id}")
|
||||||
|
print(f"Name: {task.name}")
|
||||||
|
print(f"Type: {task.task_type}")
|
||||||
|
print(f"Status: {task.status}")
|
||||||
|
print(f"Scheduled: {task.scheduled_at}")
|
||||||
|
print(f"Timezone: {task.timezone}")
|
||||||
|
print(f"Active: {task.is_active}")
|
||||||
|
print(f"User ID: {task.user_id}")
|
||||||
|
print(f"Executions: {task.executions_count}")
|
||||||
|
print(f"Last executed: {task.last_executed_at}")
|
||||||
|
print(f"Error: {task.error_message}")
|
||||||
|
print(f"Parameters: {task.parameters}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
# Check specifically for pending tasks
|
||||||
|
print(f"\nCurrent time: {datetime.utcnow()}")
|
||||||
|
print("\nPending tasks:")
|
||||||
|
from app.models.scheduled_task import TaskStatus
|
||||||
|
pending_tasks = await repo.get_all(limit=10)
|
||||||
|
pending_tasks = [t for t in pending_tasks if t.status == TaskStatus.PENDING and t.is_active]
|
||||||
|
|
||||||
|
if not pending_tasks:
|
||||||
|
print("No pending tasks found")
|
||||||
|
else:
|
||||||
|
for task in pending_tasks:
|
||||||
|
time_diff = task.scheduled_at - datetime.utcnow()
|
||||||
|
print(f"- {task.name} (ID: {task.id}): scheduled for {task.scheduled_at} (in {time_diff})")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(check_tasks())
|
||||||
36
create_future_task.py
Normal file
36
create_future_task.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Create a test task with a future execution time."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from app.core.database import get_session_factory
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
from app.models.scheduled_task import TaskType, RecurrenceType
|
||||||
|
|
||||||
|
async def create_future_task():
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
|
||||||
|
# Create a task for 2 minutes from now
|
||||||
|
future_time = datetime.utcnow() + timedelta(minutes=2)
|
||||||
|
|
||||||
|
async with session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
task_data = {
|
||||||
|
"name": f"Future Task {future_time.strftime('%H:%M:%S')}",
|
||||||
|
"task_type": TaskType.PLAY_SOUND,
|
||||||
|
"scheduled_at": future_time,
|
||||||
|
"timezone": "UTC",
|
||||||
|
"parameters": {"sound_id": 1},
|
||||||
|
"user_id": 1,
|
||||||
|
"recurrence_type": RecurrenceType.NONE,
|
||||||
|
}
|
||||||
|
|
||||||
|
task = await repo.create(task_data)
|
||||||
|
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
|
||||||
|
print(f"Current time: {datetime.utcnow()}")
|
||||||
|
print(f"Task will execute in: {future_time - datetime.utcnow()}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(create_future_task())
|
||||||
@@ -15,6 +15,7 @@ dependencies = [
|
|||||||
"pydantic-settings==2.10.1",
|
"pydantic-settings==2.10.1",
|
||||||
"pyjwt==2.10.1",
|
"pyjwt==2.10.1",
|
||||||
"python-socketio==5.13.0",
|
"python-socketio==5.13.0",
|
||||||
|
"pytz==2024.1",
|
||||||
"python-vlc==3.0.21203",
|
"python-vlc==3.0.21203",
|
||||||
"sqlmodel==0.0.24",
|
"sqlmodel==0.0.24",
|
||||||
"uvicorn[standard]==0.35.0",
|
"uvicorn[standard]==0.35.0",
|
||||||
|
|||||||
40
test_api_task.py
Normal file
40
test_api_task.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Test creating a task via the scheduler service to simulate API call."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from app.core.database import get_session_factory
|
||||||
|
from app.main import get_global_scheduler_service
|
||||||
|
from app.models.scheduled_task import TaskType, RecurrenceType
|
||||||
|
|
||||||
|
async def test_api_task_creation():
|
||||||
|
"""Test creating a task through the scheduler service (simulates API call)."""
|
||||||
|
try:
|
||||||
|
scheduler_service = get_global_scheduler_service()
|
||||||
|
|
||||||
|
# Create a task for 2 minutes from now
|
||||||
|
future_time = datetime.utcnow() + timedelta(minutes=2)
|
||||||
|
|
||||||
|
print(f"Creating task scheduled for: {future_time}")
|
||||||
|
print(f"Current time: {datetime.utcnow()}")
|
||||||
|
|
||||||
|
task = await scheduler_service.create_task(
|
||||||
|
name=f"API Test Task {future_time.strftime('%H:%M:%S')}",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=future_time,
|
||||||
|
parameters={"sound_id": 1},
|
||||||
|
user_id=1,
|
||||||
|
timezone="UTC",
|
||||||
|
recurrence_type=RecurrenceType.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Created task: {task.name} (ID: {task.id})")
|
||||||
|
print(f"Task will execute in: {future_time - datetime.utcnow()}")
|
||||||
|
print("Task should be automatically scheduled in APScheduler!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_api_task_creation())
|
||||||
31
test_task.py
Normal file
31
test_task.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Create a test task for scheduler testing."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.core.database import get_session_factory
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
from app.models.scheduled_task import TaskType, RecurrenceType
|
||||||
|
|
||||||
|
async def create_test_task():
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
|
||||||
|
async with session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
|
||||||
|
task_data = {
|
||||||
|
"name": "Live Test Task",
|
||||||
|
"task_type": TaskType.PLAY_SOUND,
|
||||||
|
"scheduled_at": datetime(2025, 8, 28, 15, 21, 0), # 15:21:00 UTC
|
||||||
|
"timezone": "UTC",
|
||||||
|
"parameters": {"sound_id": 1},
|
||||||
|
"user_id": 1,
|
||||||
|
"recurrence_type": RecurrenceType.NONE,
|
||||||
|
}
|
||||||
|
|
||||||
|
task = await repo.create(task_data)
|
||||||
|
print(f"Created task: {task.name} (ID: {task.id}) scheduled for {task.scheduled_at}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(create_test_task())
|
||||||
@@ -25,6 +25,7 @@ from app.models.favorite import Favorite # noqa: F401
|
|||||||
from app.models.plan import Plan
|
from app.models.plan import Plan
|
||||||
from app.models.playlist import Playlist # noqa: F401
|
from app.models.playlist import Playlist # noqa: F401
|
||||||
from app.models.playlist_sound import PlaylistSound # noqa: F401
|
from app.models.playlist_sound import PlaylistSound # noqa: F401
|
||||||
|
from app.models.scheduled_task import ScheduledTask # noqa: F401
|
||||||
from app.models.sound import Sound # noqa: F401
|
from app.models.sound import Sound # noqa: F401
|
||||||
from app.models.sound_played import SoundPlayed # noqa: F401
|
from app.models.sound_played import SoundPlayed # noqa: F401
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
@@ -346,3 +347,29 @@ async def admin_cookies(admin_user: User) -> dict[str, str]:
|
|||||||
access_token = JWTUtils.create_access_token(token_data)
|
access_token = JWTUtils.create_access_token(token_data)
|
||||||
|
|
||||||
return {"access_token": access_token}
|
return {"access_token": access_token}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_user_id(test_user: User):
|
||||||
|
"""Get test user ID."""
|
||||||
|
return test_user.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_sound_id():
|
||||||
|
"""Create a test sound ID."""
|
||||||
|
import uuid
|
||||||
|
return uuid.uuid4()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_playlist_id():
|
||||||
|
"""Create a test playlist ID."""
|
||||||
|
import uuid
|
||||||
|
return uuid.uuid4()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_session(test_session: AsyncSession) -> AsyncSession:
|
||||||
|
"""Alias for test_session to match test expectations."""
|
||||||
|
return test_session
|
||||||
|
|||||||
220
tests/test_scheduled_task_model.py
Normal file
220
tests/test_scheduled_task_model.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Tests for scheduled task model."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.scheduled_task import (
|
||||||
|
RecurrenceType,
|
||||||
|
ScheduledTask,
|
||||||
|
TaskStatus,
|
||||||
|
TaskType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScheduledTaskModel:
|
||||||
|
"""Test cases for scheduled task model."""
|
||||||
|
|
||||||
|
def test_task_creation(self):
|
||||||
|
"""Test basic task creation."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Test Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.name == "Test Task"
|
||||||
|
assert task.task_type == TaskType.CREDIT_RECHARGE
|
||||||
|
assert task.status == TaskStatus.PENDING
|
||||||
|
assert task.timezone == "UTC"
|
||||||
|
assert task.recurrence_type == RecurrenceType.NONE
|
||||||
|
assert task.parameters == {}
|
||||||
|
assert task.user_id is None
|
||||||
|
assert task.executions_count == 0
|
||||||
|
assert task.is_active is True
|
||||||
|
|
||||||
|
def test_task_with_user(self):
|
||||||
|
"""Test task creation with user association."""
|
||||||
|
user_id = uuid.uuid4()
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="User Task",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.user_id == user_id
|
||||||
|
assert not task.is_system_task()
|
||||||
|
|
||||||
|
def test_system_task(self):
|
||||||
|
"""Test system task (no user association)."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="System Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.user_id is None
|
||||||
|
assert task.is_system_task()
|
||||||
|
|
||||||
|
def test_recurring_task(self):
|
||||||
|
"""Test recurring task properties."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Recurring Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
recurrence_count=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.is_recurring()
|
||||||
|
assert task.should_repeat()
|
||||||
|
|
||||||
|
def test_non_recurring_task(self):
|
||||||
|
"""Test non-recurring task properties."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="One-shot Task",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not task.is_recurring()
|
||||||
|
assert not task.should_repeat()
|
||||||
|
|
||||||
|
def test_infinite_recurring_task(self):
|
||||||
|
"""Test infinitely recurring task."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Infinite Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
recurrence_count=None, # Infinite
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.is_recurring()
|
||||||
|
assert task.should_repeat()
|
||||||
|
|
||||||
|
# Even after many executions
|
||||||
|
task.executions_count = 100
|
||||||
|
assert task.should_repeat()
|
||||||
|
|
||||||
|
def test_recurring_task_execution_limit(self):
|
||||||
|
"""Test recurring task with execution limit."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Limited Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
recurrence_count=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.should_repeat()
|
||||||
|
|
||||||
|
# After 3 executions, should not repeat
|
||||||
|
task.executions_count = 3
|
||||||
|
assert not task.should_repeat()
|
||||||
|
|
||||||
|
# After more than limit, still should not repeat
|
||||||
|
task.executions_count = 5
|
||||||
|
assert not task.should_repeat()
|
||||||
|
|
||||||
|
def test_task_expiration(self):
|
||||||
|
"""Test task expiration."""
|
||||||
|
# Non-expired task
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Valid Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
expires_at=datetime.utcnow() + timedelta(hours=2),
|
||||||
|
)
|
||||||
|
assert not task.is_expired()
|
||||||
|
|
||||||
|
# Expired task
|
||||||
|
expired_task = ScheduledTask(
|
||||||
|
name="Expired Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
)
|
||||||
|
assert expired_task.is_expired()
|
||||||
|
|
||||||
|
# Task with no expiration
|
||||||
|
no_expiry_task = ScheduledTask(
|
||||||
|
name="No Expiry Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
assert not no_expiry_task.is_expired()
|
||||||
|
|
||||||
|
def test_task_with_parameters(self):
|
||||||
|
"""Test task with custom parameters."""
|
||||||
|
parameters = {
|
||||||
|
"sound_id": str(uuid.uuid4()),
|
||||||
|
"volume": 80,
|
||||||
|
"repeat": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Parametrized Task",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
parameters=parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.parameters == parameters
|
||||||
|
assert task.parameters["sound_id"] == parameters["sound_id"]
|
||||||
|
assert task.parameters["volume"] == 80
|
||||||
|
assert task.parameters["repeat"] is True
|
||||||
|
|
||||||
|
def test_task_with_timezone(self):
|
||||||
|
"""Test task with custom timezone."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="NY Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
timezone="America/New_York",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.timezone == "America/New_York"
|
||||||
|
|
||||||
|
def test_task_with_cron_expression(self):
|
||||||
|
"""Test task with cron expression."""
|
||||||
|
cron_expr = "0 9 * * 1-5" # 9 AM on weekdays
|
||||||
|
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Cron Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.CRON,
|
||||||
|
cron_expression=cron_expr,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.recurrence_type == RecurrenceType.CRON
|
||||||
|
assert task.cron_expression == cron_expr
|
||||||
|
assert task.is_recurring()
|
||||||
|
|
||||||
|
def test_task_status_enum_values(self):
|
||||||
|
"""Test all task status enum values."""
|
||||||
|
assert TaskStatus.PENDING == "pending"
|
||||||
|
assert TaskStatus.RUNNING == "running"
|
||||||
|
assert TaskStatus.COMPLETED == "completed"
|
||||||
|
assert TaskStatus.FAILED == "failed"
|
||||||
|
assert TaskStatus.CANCELLED == "cancelled"
|
||||||
|
|
||||||
|
def test_task_type_enum_values(self):
|
||||||
|
"""Test all task type enum values."""
|
||||||
|
assert TaskType.CREDIT_RECHARGE == "credit_recharge"
|
||||||
|
assert TaskType.PLAY_SOUND == "play_sound"
|
||||||
|
assert TaskType.PLAY_PLAYLIST == "play_playlist"
|
||||||
|
|
||||||
|
def test_recurrence_type_enum_values(self):
|
||||||
|
"""Test all recurrence type enum values."""
|
||||||
|
assert RecurrenceType.NONE == "none"
|
||||||
|
assert RecurrenceType.HOURLY == "hourly"
|
||||||
|
assert RecurrenceType.DAILY == "daily"
|
||||||
|
assert RecurrenceType.WEEKLY == "weekly"
|
||||||
|
assert RecurrenceType.MONTHLY == "monthly"
|
||||||
|
assert RecurrenceType.YEARLY == "yearly"
|
||||||
|
assert RecurrenceType.CRON == "cron"
|
||||||
494
tests/test_scheduled_task_repository.py
Normal file
494
tests/test_scheduled_task_repository.py
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
"""Tests for scheduled task repository."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.scheduled_task import (
|
||||||
|
RecurrenceType,
|
||||||
|
ScheduledTask,
|
||||||
|
TaskStatus,
|
||||||
|
TaskType,
|
||||||
|
)
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
|
||||||
|
|
||||||
|
class TestScheduledTaskRepository:
|
||||||
|
"""Test cases for scheduled task repository."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def repository(self, db_session: AsyncSession) -> ScheduledTaskRepository:
|
||||||
|
"""Create repository fixture."""
|
||||||
|
return ScheduledTaskRepository(db_session)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def sample_task(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
) -> ScheduledTask:
|
||||||
|
"""Create a sample scheduled task."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Test Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
parameters={"test": "value"},
|
||||||
|
)
|
||||||
|
return await repository.create(task)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def user_task(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
) -> ScheduledTask:
|
||||||
|
"""Create a user task."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="User Task",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=2),
|
||||||
|
user_id=test_user_id,
|
||||||
|
parameters={"sound_id": str(uuid.uuid4())},
|
||||||
|
)
|
||||||
|
return await repository.create(task)
|
||||||
|
|
||||||
|
async def test_create_task(self, repository: ScheduledTaskRepository):
|
||||||
|
"""Test creating a scheduled task."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Test Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
timezone="America/New_York",
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
parameters={"test": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
created_task = await repository.create(task)
|
||||||
|
|
||||||
|
assert created_task.id is not None
|
||||||
|
assert created_task.name == "Test Task"
|
||||||
|
assert created_task.task_type == TaskType.CREDIT_RECHARGE
|
||||||
|
assert created_task.status == TaskStatus.PENDING
|
||||||
|
assert created_task.timezone == "America/New_York"
|
||||||
|
assert created_task.recurrence_type == RecurrenceType.DAILY
|
||||||
|
assert created_task.parameters == {"test": "value"}
|
||||||
|
assert created_task.is_active is True
|
||||||
|
assert created_task.executions_count == 0
|
||||||
|
|
||||||
|
async def test_get_pending_tasks(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
):
|
||||||
|
"""Test getting pending tasks."""
|
||||||
|
# Create tasks with different statuses and times
|
||||||
|
past_pending = ScheduledTask(
|
||||||
|
name="Past Pending",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
status=TaskStatus.PENDING,
|
||||||
|
)
|
||||||
|
await repository.create(past_pending)
|
||||||
|
|
||||||
|
future_pending = ScheduledTask(
|
||||||
|
name="Future Pending",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
status=TaskStatus.PENDING,
|
||||||
|
)
|
||||||
|
await repository.create(future_pending)
|
||||||
|
|
||||||
|
completed_task = ScheduledTask(
|
||||||
|
name="Completed",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
status=TaskStatus.COMPLETED,
|
||||||
|
)
|
||||||
|
await repository.create(completed_task)
|
||||||
|
|
||||||
|
inactive_task = ScheduledTask(
|
||||||
|
name="Inactive",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
status=TaskStatus.PENDING,
|
||||||
|
is_active=False,
|
||||||
|
)
|
||||||
|
await repository.create(inactive_task)
|
||||||
|
|
||||||
|
pending_tasks = await repository.get_pending_tasks()
|
||||||
|
task_names = [task.name for task in pending_tasks]
|
||||||
|
|
||||||
|
# Only the past pending task should be returned
|
||||||
|
assert len(pending_tasks) == 1
|
||||||
|
assert "Past Pending" in task_names
|
||||||
|
|
||||||
|
async def test_get_user_tasks(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
user_task: ScheduledTask,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test getting tasks for a specific user."""
|
||||||
|
# Create another user's task
|
||||||
|
other_user_id = uuid.uuid4()
|
||||||
|
other_task = ScheduledTask(
|
||||||
|
name="Other User Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
user_id=other_user_id,
|
||||||
|
)
|
||||||
|
await repository.create(other_task)
|
||||||
|
|
||||||
|
# Create system task (no user)
|
||||||
|
system_task = ScheduledTask(
|
||||||
|
name="System Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
await repository.create(system_task)
|
||||||
|
|
||||||
|
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||||
|
|
||||||
|
assert len(user_tasks) == 1
|
||||||
|
assert user_tasks[0].name == "User Task"
|
||||||
|
assert user_tasks[0].user_id == test_user_id
|
||||||
|
|
||||||
|
async def test_get_user_tasks_with_filters(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test getting user tasks with status and type filters."""
|
||||||
|
# Create tasks with different statuses and types
|
||||||
|
tasks_data = [
|
||||||
|
("Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
|
||||||
|
("Task 2", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
|
||||||
|
("Task 3", TaskStatus.PENDING, TaskType.PLAY_SOUND),
|
||||||
|
("Task 4", TaskStatus.FAILED, TaskType.PLAY_PLAYLIST),
|
||||||
|
]
|
||||||
|
|
||||||
|
for name, status, task_type in tasks_data:
|
||||||
|
task = ScheduledTask(
|
||||||
|
name=name,
|
||||||
|
task_type=task_type,
|
||||||
|
status=status,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
await repository.create(task)
|
||||||
|
|
||||||
|
# Test status filter
|
||||||
|
pending_tasks = await repository.get_user_tasks(
|
||||||
|
test_user_id,
|
||||||
|
status=TaskStatus.PENDING,
|
||||||
|
)
|
||||||
|
assert len(pending_tasks) == 2
|
||||||
|
assert all(task.status == TaskStatus.PENDING for task in pending_tasks)
|
||||||
|
|
||||||
|
# Test type filter
|
||||||
|
credit_tasks = await repository.get_user_tasks(
|
||||||
|
test_user_id,
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
)
|
||||||
|
assert len(credit_tasks) == 2
|
||||||
|
assert all(task.task_type == TaskType.CREDIT_RECHARGE for task in credit_tasks)
|
||||||
|
|
||||||
|
# Test combined filters
|
||||||
|
pending_credit_tasks = await repository.get_user_tasks(
|
||||||
|
test_user_id,
|
||||||
|
status=TaskStatus.PENDING,
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
)
|
||||||
|
assert len(pending_credit_tasks) == 1
|
||||||
|
assert pending_credit_tasks[0].name == "Task 1"
|
||||||
|
|
||||||
|
async def test_get_system_tasks(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
sample_task: ScheduledTask,
|
||||||
|
user_task: ScheduledTask,
|
||||||
|
):
|
||||||
|
"""Test getting system tasks."""
|
||||||
|
system_tasks = await repository.get_system_tasks()
|
||||||
|
|
||||||
|
assert len(system_tasks) == 1
|
||||||
|
assert system_tasks[0].name == "Test Task"
|
||||||
|
assert system_tasks[0].user_id is None
|
||||||
|
|
||||||
|
async def test_get_recurring_tasks_due_for_next_execution(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
):
|
||||||
|
"""Test getting recurring tasks due for next execution."""
|
||||||
|
# Create completed recurring task that should be re-executed
|
||||||
|
due_task = ScheduledTask(
|
||||||
|
name="Due Recurring",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
status=TaskStatus.COMPLETED,
|
||||||
|
next_execution_at=datetime.utcnow() - timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
await repository.create(due_task)
|
||||||
|
|
||||||
|
# Create completed recurring task not due yet
|
||||||
|
not_due_task = ScheduledTask(
|
||||||
|
name="Not Due Recurring",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
status=TaskStatus.COMPLETED,
|
||||||
|
next_execution_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
await repository.create(not_due_task)
|
||||||
|
|
||||||
|
# Create non-recurring completed task
|
||||||
|
non_recurring = ScheduledTask(
|
||||||
|
name="Non-recurring",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.NONE,
|
||||||
|
status=TaskStatus.COMPLETED,
|
||||||
|
)
|
||||||
|
await repository.create(non_recurring)
|
||||||
|
|
||||||
|
due_tasks = await repository.get_recurring_tasks_due_for_next_execution()
|
||||||
|
|
||||||
|
assert len(due_tasks) == 1
|
||||||
|
assert due_tasks[0].name == "Due Recurring"
|
||||||
|
|
||||||
|
async def test_get_expired_tasks(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
):
|
||||||
|
"""Test getting expired tasks."""
|
||||||
|
# Create expired task
|
||||||
|
expired_task = ScheduledTask(
|
||||||
|
name="Expired Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
)
|
||||||
|
await repository.create(expired_task)
|
||||||
|
|
||||||
|
# Create non-expired task
|
||||||
|
valid_task = ScheduledTask(
|
||||||
|
name="Valid Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
expires_at=datetime.utcnow() + timedelta(hours=2),
|
||||||
|
)
|
||||||
|
await repository.create(valid_task)
|
||||||
|
|
||||||
|
# Create task with no expiry
|
||||||
|
no_expiry_task = ScheduledTask(
|
||||||
|
name="No Expiry",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
await repository.create(no_expiry_task)
|
||||||
|
|
||||||
|
expired_tasks = await repository.get_expired_tasks()
|
||||||
|
|
||||||
|
assert len(expired_tasks) == 1
|
||||||
|
assert expired_tasks[0].name == "Expired Task"
|
||||||
|
|
||||||
|
async def test_cancel_user_tasks(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test cancelling user tasks."""
|
||||||
|
# Create multiple user tasks
|
||||||
|
tasks_data = [
|
||||||
|
("Pending Task 1", TaskStatus.PENDING, TaskType.CREDIT_RECHARGE),
|
||||||
|
("Running Task", TaskStatus.RUNNING, TaskType.PLAY_SOUND),
|
||||||
|
("Completed Task", TaskStatus.COMPLETED, TaskType.CREDIT_RECHARGE),
|
||||||
|
]
|
||||||
|
|
||||||
|
for name, status, task_type in tasks_data:
|
||||||
|
task = ScheduledTask(
|
||||||
|
name=name,
|
||||||
|
task_type=task_type,
|
||||||
|
status=status,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
await repository.create(task)
|
||||||
|
|
||||||
|
# Cancel all user tasks
|
||||||
|
cancelled_count = await repository.cancel_user_tasks(test_user_id)
|
||||||
|
|
||||||
|
assert cancelled_count == 2 # Only pending and running tasks
|
||||||
|
|
||||||
|
# Verify tasks are cancelled
|
||||||
|
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||||
|
pending_or_running = [
|
||||||
|
task for task in user_tasks
|
||||||
|
if task.status in [TaskStatus.PENDING, TaskStatus.RUNNING]
|
||||||
|
]
|
||||||
|
cancelled_tasks = [
|
||||||
|
task for task in user_tasks
|
||||||
|
if task.status == TaskStatus.CANCELLED
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(pending_or_running) == 0
|
||||||
|
assert len(cancelled_tasks) == 2
|
||||||
|
|
||||||
|
async def test_cancel_user_tasks_by_type(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test cancelling user tasks by type."""
|
||||||
|
# Create tasks of different types
|
||||||
|
credit_task = ScheduledTask(
|
||||||
|
name="Credit Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
await repository.create(credit_task)
|
||||||
|
|
||||||
|
sound_task = ScheduledTask(
|
||||||
|
name="Sound Task",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
await repository.create(sound_task)
|
||||||
|
|
||||||
|
# Cancel only credit tasks
|
||||||
|
cancelled_count = await repository.cancel_user_tasks(
|
||||||
|
test_user_id,
|
||||||
|
TaskType.CREDIT_RECHARGE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cancelled_count == 1
|
||||||
|
|
||||||
|
# Verify only credit task is cancelled
|
||||||
|
user_tasks = await repository.get_user_tasks(test_user_id)
|
||||||
|
credit_tasks = [
|
||||||
|
task for task in user_tasks
|
||||||
|
if task.task_type == TaskType.CREDIT_RECHARGE
|
||||||
|
]
|
||||||
|
sound_tasks = [
|
||||||
|
task for task in user_tasks
|
||||||
|
if task.task_type == TaskType.PLAY_SOUND
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(credit_tasks) == 1
|
||||||
|
assert credit_tasks[0].status == TaskStatus.CANCELLED
|
||||||
|
assert len(sound_tasks) == 1
|
||||||
|
assert sound_tasks[0].status == TaskStatus.PENDING
|
||||||
|
|
||||||
|
async def test_mark_as_running(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
sample_task: ScheduledTask,
|
||||||
|
):
|
||||||
|
"""Test marking task as running."""
|
||||||
|
await repository.mark_as_running(sample_task)
|
||||||
|
|
||||||
|
updated_task = await repository.get_by_id(sample_task.id)
|
||||||
|
assert updated_task.status == TaskStatus.RUNNING
|
||||||
|
|
||||||
|
async def test_mark_as_completed(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
sample_task: ScheduledTask,
|
||||||
|
):
|
||||||
|
"""Test marking task as completed."""
|
||||||
|
initial_count = sample_task.executions_count
|
||||||
|
next_execution = datetime.utcnow() + timedelta(days=1)
|
||||||
|
|
||||||
|
await repository.mark_as_completed(sample_task, next_execution)
|
||||||
|
|
||||||
|
updated_task = await repository.get_by_id(sample_task.id)
|
||||||
|
assert updated_task.status == TaskStatus.COMPLETED
|
||||||
|
assert updated_task.executions_count == initial_count + 1
|
||||||
|
assert updated_task.last_executed_at is not None
|
||||||
|
assert updated_task.error_message is None
|
||||||
|
|
||||||
|
async def test_mark_as_completed_recurring_task(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
):
|
||||||
|
"""Test marking recurring task as completed."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Recurring Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
)
|
||||||
|
created_task = await repository.create(task)
|
||||||
|
|
||||||
|
next_execution = datetime.utcnow() + timedelta(days=1)
|
||||||
|
await repository.mark_as_completed(created_task, next_execution)
|
||||||
|
|
||||||
|
updated_task = await repository.get_by_id(created_task.id)
|
||||||
|
# Should be set back to pending for next execution
|
||||||
|
assert updated_task.status == TaskStatus.PENDING
|
||||||
|
assert updated_task.next_execution_at == next_execution
|
||||||
|
assert updated_task.is_active is True
|
||||||
|
|
||||||
|
async def test_mark_as_completed_non_recurring_task(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
sample_task: ScheduledTask,
|
||||||
|
):
|
||||||
|
"""Test marking non-recurring task as completed."""
|
||||||
|
await repository.mark_as_completed(sample_task, None)
|
||||||
|
|
||||||
|
updated_task = await repository.get_by_id(sample_task.id)
|
||||||
|
assert updated_task.status == TaskStatus.COMPLETED
|
||||||
|
assert updated_task.is_active is False
|
||||||
|
|
||||||
|
async def test_mark_as_failed(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
sample_task: ScheduledTask,
|
||||||
|
):
|
||||||
|
"""Test marking task as failed."""
|
||||||
|
error_message = "Task execution failed"
|
||||||
|
|
||||||
|
await repository.mark_as_failed(sample_task, error_message)
|
||||||
|
|
||||||
|
updated_task = await repository.get_by_id(sample_task.id)
|
||||||
|
assert updated_task.status == TaskStatus.FAILED
|
||||||
|
assert updated_task.error_message == error_message
|
||||||
|
assert updated_task.last_executed_at is not None
|
||||||
|
|
||||||
|
async def test_mark_as_failed_recurring_task(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
):
|
||||||
|
"""Test marking recurring task as failed."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Recurring Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
)
|
||||||
|
created_task = await repository.create(task)
|
||||||
|
|
||||||
|
await repository.mark_as_failed(created_task, "Failed")
|
||||||
|
|
||||||
|
updated_task = await repository.get_by_id(created_task.id)
|
||||||
|
assert updated_task.status == TaskStatus.FAILED
|
||||||
|
# Recurring tasks should remain active even after failure
|
||||||
|
assert updated_task.is_active is True
|
||||||
|
|
||||||
|
async def test_mark_as_failed_non_recurring_task(
|
||||||
|
self,
|
||||||
|
repository: ScheduledTaskRepository,
|
||||||
|
sample_task: ScheduledTask,
|
||||||
|
):
|
||||||
|
"""Test marking non-recurring task as failed."""
|
||||||
|
await repository.mark_as_failed(sample_task, "Failed")
|
||||||
|
|
||||||
|
updated_task = await repository.get_by_id(sample_task.id)
|
||||||
|
assert updated_task.status == TaskStatus.FAILED
|
||||||
|
# Non-recurring tasks should be deactivated on failure
|
||||||
|
assert updated_task.is_active is False
|
||||||
495
tests/test_scheduler_service.py
Normal file
495
tests/test_scheduler_service.py
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
"""Tests for scheduler service."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.scheduled_task import (
|
||||||
|
RecurrenceType,
|
||||||
|
ScheduledTask,
|
||||||
|
TaskStatus,
|
||||||
|
TaskType,
|
||||||
|
)
|
||||||
|
from app.services.scheduler import SchedulerService
|
||||||
|
|
||||||
|
|
||||||
|
class TestSchedulerService:
|
||||||
|
"""Test cases for scheduler service."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_player_service(self):
|
||||||
|
"""Create mock player service."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def scheduler_service(
|
||||||
|
self,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
mock_player_service,
|
||||||
|
) -> SchedulerService:
|
||||||
|
"""Create scheduler service fixture."""
|
||||||
|
session_factory = lambda: db_session
|
||||||
|
return SchedulerService(session_factory, mock_player_service)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_task_data(self) -> dict:
|
||||||
|
"""Sample task data for testing."""
|
||||||
|
return {
|
||||||
|
"name": "Test Task",
|
||||||
|
"task_type": TaskType.CREDIT_RECHARGE,
|
||||||
|
"scheduled_at": datetime.utcnow() + timedelta(hours=1),
|
||||||
|
"parameters": {"test": "value"},
|
||||||
|
"timezone": "UTC",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def test_create_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test creating a scheduled task."""
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
assert task.id is not None
|
||||||
|
assert task.name == sample_task_data["name"]
|
||||||
|
assert task.task_type == sample_task_data["task_type"]
|
||||||
|
assert task.status == TaskStatus.PENDING
|
||||||
|
assert task.parameters == sample_task_data["parameters"]
|
||||||
|
mock_schedule.assert_called_once_with(task)
|
||||||
|
|
||||||
|
async def test_create_user_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test creating a user task."""
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(
|
||||||
|
user_id=test_user_id,
|
||||||
|
**sample_task_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.user_id == test_user_id
|
||||||
|
assert not task.is_system_task()
|
||||||
|
|
||||||
|
async def test_create_system_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test creating a system task."""
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
assert task.user_id is None
|
||||||
|
assert task.is_system_task()
|
||||||
|
|
||||||
|
async def test_create_recurring_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test creating a recurring task."""
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
recurrence_count=5,
|
||||||
|
**sample_task_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert task.recurrence_type == RecurrenceType.DAILY
|
||||||
|
assert task.recurrence_count == 5
|
||||||
|
assert task.is_recurring()
|
||||||
|
|
||||||
|
async def test_create_task_with_timezone_conversion(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test creating task with timezone conversion."""
|
||||||
|
# Use a specific datetime for testing
|
||||||
|
ny_time = datetime(2024, 1, 1, 12, 0, 0) # Noon in NY
|
||||||
|
|
||||||
|
sample_task_data["scheduled_at"] = ny_time
|
||||||
|
sample_task_data["timezone"] = "America/New_York"
|
||||||
|
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
# The scheduled_at should be converted to UTC
|
||||||
|
assert task.timezone == "America/New_York"
|
||||||
|
# In winter, EST is UTC-5, so noon EST becomes 5 PM UTC
|
||||||
|
# Note: This test might need adjustment based on DST
|
||||||
|
assert task.scheduled_at.hour in [16, 17] # Account for DST
|
||||||
|
|
||||||
|
async def test_cancel_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test cancelling a task."""
|
||||||
|
# Create a task first
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
# Mock the scheduler remove_job method
|
||||||
|
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||||
|
result = await scheduler_service.cancel_task(task.id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_remove.assert_called_once_with(str(task.id))
|
||||||
|
|
||||||
|
# Check task is cancelled in database
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
async with scheduler_service.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
updated_task = await repo.get_by_id(task.id)
|
||||||
|
assert updated_task.status == TaskStatus.CANCELLED
|
||||||
|
assert updated_task.is_active is False
|
||||||
|
|
||||||
|
async def test_cancel_nonexistent_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test cancelling a non-existent task."""
|
||||||
|
result = await scheduler_service.cancel_task(uuid.uuid4())
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
async def test_get_user_tasks(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test getting user tasks."""
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
# Create user task
|
||||||
|
await scheduler_service.create_task(
|
||||||
|
user_id=test_user_id,
|
||||||
|
**sample_task_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create system task
|
||||||
|
await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
user_tasks = await scheduler_service.get_user_tasks(test_user_id)
|
||||||
|
|
||||||
|
assert len(user_tasks) == 1
|
||||||
|
assert user_tasks[0].user_id == test_user_id
|
||||||
|
|
||||||
|
async def test_ensure_system_tasks(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test ensuring system tasks exist."""
|
||||||
|
# Mock the repository to return no existing tasks
|
||||||
|
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||||
|
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||||
|
mock_get.return_value = []
|
||||||
|
|
||||||
|
await scheduler_service._ensure_system_tasks()
|
||||||
|
|
||||||
|
# Should create daily credit recharge task
|
||||||
|
mock_create.assert_called_once()
|
||||||
|
created_task = mock_create.call_args[0][0]
|
||||||
|
assert created_task.name == "Daily Credit Recharge"
|
||||||
|
assert created_task.task_type == TaskType.CREDIT_RECHARGE
|
||||||
|
assert created_task.recurrence_type == RecurrenceType.DAILY
|
||||||
|
|
||||||
|
async def test_ensure_system_tasks_already_exist(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test ensuring system tasks when they already exist."""
|
||||||
|
existing_task = ScheduledTask(
|
||||||
|
name="Existing Daily Credit Recharge",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
is_active=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.get_system_tasks') as mock_get:
|
||||||
|
with patch('app.repositories.scheduled_task.ScheduledTaskRepository.create') as mock_create:
|
||||||
|
mock_get.return_value = [existing_task]
|
||||||
|
|
||||||
|
await scheduler_service._ensure_system_tasks()
|
||||||
|
|
||||||
|
# Should not create new task
|
||||||
|
mock_create.assert_not_called()
|
||||||
|
|
||||||
|
def test_create_trigger_one_shot(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test creating one-shot trigger."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="One Shot",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
trigger = scheduler_service._create_trigger(task)
|
||||||
|
assert trigger is not None
|
||||||
|
assert trigger.__class__.__name__ == "DateTrigger"
|
||||||
|
|
||||||
|
def test_create_trigger_daily(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test creating daily interval trigger."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Daily",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.DAILY,
|
||||||
|
)
|
||||||
|
|
||||||
|
trigger = scheduler_service._create_trigger(task)
|
||||||
|
assert trigger is not None
|
||||||
|
assert trigger.__class__.__name__ == "IntervalTrigger"
|
||||||
|
|
||||||
|
def test_create_trigger_cron(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test creating cron trigger."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Cron",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
recurrence_type=RecurrenceType.CRON,
|
||||||
|
cron_expression="0 9 * * *", # 9 AM daily
|
||||||
|
)
|
||||||
|
|
||||||
|
trigger = scheduler_service._create_trigger(task)
|
||||||
|
assert trigger is not None
|
||||||
|
assert trigger.__class__.__name__ == "CronTrigger"
|
||||||
|
|
||||||
|
def test_create_trigger_monthly(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test creating monthly cron trigger."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Monthly",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime(2024, 1, 15, 10, 30, 0), # 15th at 10:30 AM
|
||||||
|
recurrence_type=RecurrenceType.MONTHLY,
|
||||||
|
)
|
||||||
|
|
||||||
|
trigger = scheduler_service._create_trigger(task)
|
||||||
|
assert trigger is not None
|
||||||
|
assert trigger.__class__.__name__ == "CronTrigger"
|
||||||
|
|
||||||
|
def test_calculate_next_execution(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test calculating next execution time."""
|
||||||
|
now = datetime.utcnow()
|
||||||
|
|
||||||
|
# Test different recurrence types
|
||||||
|
test_cases = [
|
||||||
|
(RecurrenceType.HOURLY, timedelta(hours=1)),
|
||||||
|
(RecurrenceType.DAILY, timedelta(days=1)),
|
||||||
|
(RecurrenceType.WEEKLY, timedelta(weeks=1)),
|
||||||
|
(RecurrenceType.MONTHLY, timedelta(days=30)),
|
||||||
|
(RecurrenceType.YEARLY, timedelta(days=365)),
|
||||||
|
]
|
||||||
|
|
||||||
|
for recurrence_type, expected_delta in test_cases:
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Test",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=now,
|
||||||
|
recurrence_type=recurrence_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch('app.services.scheduler.datetime') as mock_datetime:
|
||||||
|
mock_datetime.utcnow.return_value = now
|
||||||
|
next_execution = scheduler_service._calculate_next_execution(task)
|
||||||
|
|
||||||
|
assert next_execution is not None
|
||||||
|
# Allow some tolerance for execution time
|
||||||
|
assert abs((next_execution - now) - expected_delta) < timedelta(seconds=1)
|
||||||
|
|
||||||
|
def test_calculate_next_execution_none_recurrence(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test calculating next execution for non-recurring task."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="One Shot",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
recurrence_type=RecurrenceType.NONE,
|
||||||
|
)
|
||||||
|
|
||||||
|
next_execution = scheduler_service._calculate_next_execution(task)
|
||||||
|
assert next_execution is None
|
||||||
|
|
||||||
|
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||||
|
async def test_execute_task_success(
|
||||||
|
self,
|
||||||
|
mock_handler_class,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test successful task execution."""
|
||||||
|
# Create task
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
# Mock handler registry
|
||||||
|
mock_handler = AsyncMock()
|
||||||
|
mock_handler_class.return_value = mock_handler
|
||||||
|
|
||||||
|
# Execute task
|
||||||
|
await scheduler_service._execute_task(task.id)
|
||||||
|
|
||||||
|
# Verify handler was called
|
||||||
|
mock_handler.execute_task.assert_called_once()
|
||||||
|
|
||||||
|
# Check task is marked as completed
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
async with scheduler_service.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
updated_task = await repo.get_by_id(task.id)
|
||||||
|
assert updated_task.status == TaskStatus.COMPLETED
|
||||||
|
assert updated_task.executions_count == 1
|
||||||
|
|
||||||
|
@patch('app.services.task_handlers.TaskHandlerRegistry')
|
||||||
|
async def test_execute_task_failure(
|
||||||
|
self,
|
||||||
|
mock_handler_class,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test task execution failure."""
|
||||||
|
# Create task
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
# Mock handler to raise exception
|
||||||
|
mock_handler = AsyncMock()
|
||||||
|
mock_handler.execute_task.side_effect = Exception("Task failed")
|
||||||
|
mock_handler_class.return_value = mock_handler
|
||||||
|
|
||||||
|
# Execute task
|
||||||
|
await scheduler_service._execute_task(task.id)
|
||||||
|
|
||||||
|
# Check task is marked as failed
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
async with scheduler_service.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
updated_task = await repo.get_by_id(task.id)
|
||||||
|
assert updated_task.status == TaskStatus.FAILED
|
||||||
|
assert "Task failed" in updated_task.error_message
|
||||||
|
|
||||||
|
async def test_execute_nonexistent_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test executing non-existent task."""
|
||||||
|
# Should handle gracefully
|
||||||
|
await scheduler_service._execute_task(uuid.uuid4())
|
||||||
|
|
||||||
|
async def test_execute_expired_task(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test executing expired task."""
|
||||||
|
# Create expired task
|
||||||
|
sample_task_data["expires_at"] = datetime.utcnow() - timedelta(hours=1)
|
||||||
|
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
# Execute task
|
||||||
|
await scheduler_service._execute_task(task.id)
|
||||||
|
|
||||||
|
# Check task is cancelled
|
||||||
|
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||||
|
async with scheduler_service.db_session_factory() as session:
|
||||||
|
repo = ScheduledTaskRepository(session)
|
||||||
|
updated_task = await repo.get_by_id(task.id)
|
||||||
|
assert updated_task.status == TaskStatus.CANCELLED
|
||||||
|
assert updated_task.is_active is False
|
||||||
|
|
||||||
|
async def test_concurrent_task_execution_prevention(
|
||||||
|
self,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
sample_task_data: dict,
|
||||||
|
):
|
||||||
|
"""Test prevention of concurrent task execution."""
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job'):
|
||||||
|
task = await scheduler_service.create_task(**sample_task_data)
|
||||||
|
|
||||||
|
# Add task to running set
|
||||||
|
scheduler_service._running_tasks.add(str(task.id))
|
||||||
|
|
||||||
|
# Try to execute - should return without doing anything
|
||||||
|
with patch('app.services.task_handlers.TaskHandlerRegistry') as mock_handler_class:
|
||||||
|
await scheduler_service._execute_task(task.id)
|
||||||
|
|
||||||
|
# Handler should not be called
|
||||||
|
mock_handler_class.assert_not_called()
|
||||||
|
|
||||||
|
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||||
|
async def test_maintenance_job_expired_tasks(
|
||||||
|
self,
|
||||||
|
mock_repo_class,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test maintenance job handling expired tasks."""
|
||||||
|
# Mock expired task
|
||||||
|
expired_task = MagicMock()
|
||||||
|
expired_task.id = uuid.uuid4()
|
||||||
|
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo.get_expired_tasks.return_value = [expired_task]
|
||||||
|
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = []
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
|
||||||
|
with patch.object(scheduler_service.scheduler, 'remove_job') as mock_remove:
|
||||||
|
await scheduler_service._maintenance_job()
|
||||||
|
|
||||||
|
# Should mark as cancelled and remove from scheduler
|
||||||
|
assert expired_task.status == TaskStatus.CANCELLED
|
||||||
|
assert expired_task.is_active is False
|
||||||
|
mock_repo.update.assert_called_with(expired_task)
|
||||||
|
mock_remove.assert_called_once_with(str(expired_task.id))
|
||||||
|
|
||||||
|
@patch('app.repositories.scheduled_task.ScheduledTaskRepository')
|
||||||
|
async def test_maintenance_job_due_recurring_tasks(
|
||||||
|
self,
|
||||||
|
mock_repo_class,
|
||||||
|
scheduler_service: SchedulerService,
|
||||||
|
):
|
||||||
|
"""Test maintenance job handling due recurring tasks."""
|
||||||
|
# Mock due recurring task
|
||||||
|
due_task = MagicMock()
|
||||||
|
due_task.should_repeat.return_value = True
|
||||||
|
due_task.next_execution_at = datetime.utcnow() - timedelta(minutes=5)
|
||||||
|
|
||||||
|
mock_repo = AsyncMock()
|
||||||
|
mock_repo.get_expired_tasks.return_value = []
|
||||||
|
mock_repo.get_recurring_tasks_due_for_next_execution.return_value = [due_task]
|
||||||
|
mock_repo_class.return_value = mock_repo
|
||||||
|
|
||||||
|
with patch.object(scheduler_service, '_schedule_apscheduler_job') as mock_schedule:
|
||||||
|
await scheduler_service._maintenance_job()
|
||||||
|
|
||||||
|
# Should reset to pending and reschedule
|
||||||
|
assert due_task.status == TaskStatus.PENDING
|
||||||
|
assert due_task.scheduled_at == due_task.next_execution_at
|
||||||
|
mock_repo.update.assert_called_with(due_task)
|
||||||
|
mock_schedule.assert_called_once_with(due_task)
|
||||||
424
tests/test_task_handlers.py
Normal file
424
tests/test_task_handlers.py
Normal file
@@ -0,0 +1,424 @@
|
|||||||
|
"""Tests for task handlers."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from app.models.scheduled_task import ScheduledTask, TaskType
|
||||||
|
from app.services.task_handlers import TaskExecutionError, TaskHandlerRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskHandlerRegistry:
|
||||||
|
"""Test cases for task handler registry."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_credit_service(self):
|
||||||
|
"""Create mock credit service."""
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_player_service(self):
|
||||||
|
"""Create mock player service."""
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task_registry(
|
||||||
|
self,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
mock_credit_service,
|
||||||
|
mock_player_service,
|
||||||
|
) -> TaskHandlerRegistry:
|
||||||
|
"""Create task handler registry fixture."""
|
||||||
|
return TaskHandlerRegistry(
|
||||||
|
db_session,
|
||||||
|
mock_credit_service,
|
||||||
|
mock_player_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_execute_task_unknown_type(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
):
|
||||||
|
"""Test executing task with unknown type."""
|
||||||
|
# Create task with invalid type
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Unknown Task",
|
||||||
|
task_type="UNKNOWN_TYPE", # Invalid type
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TaskExecutionError, match="No handler registered"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_handle_credit_recharge_all_users(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_credit_service,
|
||||||
|
):
|
||||||
|
"""Test handling credit recharge for all users."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Daily Credit Recharge",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_credit_service.recharge_all_users_credits.return_value = {
|
||||||
|
"users_recharged": 10,
|
||||||
|
"total_credits": 1000,
|
||||||
|
}
|
||||||
|
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
mock_credit_service.recharge_all_users_credits.assert_called_once()
|
||||||
|
|
||||||
|
async def test_handle_credit_recharge_specific_user(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_credit_service,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test handling credit recharge for specific user."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="User Credit Recharge",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"user_id": str(test_user_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_credit_service.recharge_user_credits.return_value = {
|
||||||
|
"user_id": str(test_user_id),
|
||||||
|
"credits_added": 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
mock_credit_service.recharge_user_credits.assert_called_once_with(test_user_id)
|
||||||
|
|
||||||
|
async def test_handle_credit_recharge_uuid_user_id(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_credit_service,
|
||||||
|
test_user_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test handling credit recharge with UUID user_id parameter."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="User Credit Recharge",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"user_id": test_user_id}, # UUID object instead of string
|
||||||
|
)
|
||||||
|
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
mock_credit_service.recharge_user_credits.assert_called_once_with(test_user_id)
|
||||||
|
|
||||||
|
async def test_handle_play_sound_success(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
test_sound_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test successful play sound task handling."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Sound",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"sound_id": str(test_sound_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock sound repository
|
||||||
|
mock_sound = MagicMock()
|
||||||
|
mock_sound.id = test_sound_id
|
||||||
|
mock_sound.filename = "test_sound.mp3"
|
||||||
|
|
||||||
|
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||||
|
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||||
|
mock_vlc_service = AsyncMock()
|
||||||
|
mock_vlc_class.return_value = mock_vlc_service
|
||||||
|
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
task_registry.sound_repository.get_by_id.assert_called_once_with(test_sound_id)
|
||||||
|
mock_vlc_service.play_sound.assert_called_once_with(mock_sound)
|
||||||
|
|
||||||
|
async def test_handle_play_sound_missing_sound_id(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
):
|
||||||
|
"""Test play sound task with missing sound_id parameter."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Sound",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={}, # Missing sound_id
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TaskExecutionError, match="sound_id parameter is required"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_handle_play_sound_invalid_sound_id(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
):
|
||||||
|
"""Test play sound task with invalid sound_id parameter."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Sound",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"sound_id": "invalid-uuid"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TaskExecutionError, match="Invalid sound_id format"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_handle_play_sound_not_found(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
test_sound_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test play sound task with non-existent sound."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Sound",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"sound_id": str(test_sound_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=None):
|
||||||
|
with pytest.raises(TaskExecutionError, match="Sound not found"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_handle_play_sound_uuid_parameter(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
test_sound_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test play sound task with UUID parameter (not string)."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Sound",
|
||||||
|
task_type=TaskType.PLAY_SOUND,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"sound_id": test_sound_id}, # UUID object
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_sound = MagicMock()
|
||||||
|
mock_sound.filename = "test_sound.mp3"
|
||||||
|
|
||||||
|
with patch.object(task_registry.sound_repository, 'get_by_id', return_value=mock_sound):
|
||||||
|
with patch('app.services.vlc_player.VLCPlayerService') as mock_vlc_class:
|
||||||
|
mock_vlc_service = AsyncMock()
|
||||||
|
mock_vlc_class.return_value = mock_vlc_service
|
||||||
|
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
task_registry.sound_repository.get_by_id.assert_called_once_with(test_sound_id)
|
||||||
|
|
||||||
|
async def test_handle_play_playlist_success(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_player_service,
|
||||||
|
test_playlist_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test successful play playlist task handling."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={
|
||||||
|
"playlist_id": str(test_playlist_id),
|
||||||
|
"play_mode": "loop",
|
||||||
|
"shuffle": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock playlist repository
|
||||||
|
mock_playlist = MagicMock()
|
||||||
|
mock_playlist.id = test_playlist_id
|
||||||
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
|
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
task_registry.playlist_repository.get_by_id.assert_called_once_with(test_playlist_id)
|
||||||
|
mock_player_service.load_playlist.assert_called_once_with(test_playlist_id)
|
||||||
|
mock_player_service.set_mode.assert_called_once_with("loop")
|
||||||
|
mock_player_service.set_shuffle.assert_called_once_with(True)
|
||||||
|
mock_player_service.play.assert_called_once()
|
||||||
|
|
||||||
|
async def test_handle_play_playlist_minimal_parameters(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_player_service,
|
||||||
|
test_playlist_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test play playlist task with minimal parameters."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"playlist_id": str(test_playlist_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_playlist = MagicMock()
|
||||||
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
|
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
# Should use default values
|
||||||
|
mock_player_service.set_mode.assert_called_once_with("continuous")
|
||||||
|
mock_player_service.set_shuffle.assert_called_once_with(False)
|
||||||
|
|
||||||
|
async def test_handle_play_playlist_missing_playlist_id(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
):
|
||||||
|
"""Test play playlist task with missing playlist_id parameter."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={}, # Missing playlist_id
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TaskExecutionError, match="playlist_id parameter is required"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_handle_play_playlist_invalid_playlist_id(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
):
|
||||||
|
"""Test play playlist task with invalid playlist_id parameter."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"playlist_id": "invalid-uuid"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(TaskExecutionError, match="Invalid playlist_id format"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_handle_play_playlist_not_found(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
test_playlist_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test play playlist task with non-existent playlist."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={"playlist_id": str(test_playlist_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=None):
|
||||||
|
with pytest.raises(TaskExecutionError, match="Playlist not found"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_handle_play_playlist_valid_play_modes(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_player_service,
|
||||||
|
test_playlist_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test play playlist task with various valid play modes."""
|
||||||
|
mock_playlist = MagicMock()
|
||||||
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
|
valid_modes = ["continuous", "loop", "loop_one", "random", "single"]
|
||||||
|
|
||||||
|
for mode in valid_modes:
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={
|
||||||
|
"playlist_id": str(test_playlist_id),
|
||||||
|
"play_mode": mode,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
mock_player_service.set_mode.assert_called_with(mode)
|
||||||
|
|
||||||
|
# Reset mock for next iteration
|
||||||
|
mock_player_service.reset_mock()
|
||||||
|
|
||||||
|
async def test_handle_play_playlist_invalid_play_mode(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_player_service,
|
||||||
|
test_playlist_id: uuid.UUID,
|
||||||
|
):
|
||||||
|
"""Test play playlist task with invalid play mode."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Play Playlist",
|
||||||
|
task_type=TaskType.PLAY_PLAYLIST,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={
|
||||||
|
"playlist_id": str(test_playlist_id),
|
||||||
|
"play_mode": "invalid_mode",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_playlist = MagicMock()
|
||||||
|
mock_playlist.name = "Test Playlist"
|
||||||
|
|
||||||
|
with patch.object(task_registry.playlist_repository, 'get_by_id', return_value=mock_playlist):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
# Should not call set_mode for invalid mode
|
||||||
|
mock_player_service.set_mode.assert_not_called()
|
||||||
|
# But should still load playlist and play
|
||||||
|
mock_player_service.load_playlist.assert_called_once()
|
||||||
|
mock_player_service.play.assert_called_once()
|
||||||
|
|
||||||
|
async def test_task_execution_exception_handling(
|
||||||
|
self,
|
||||||
|
task_registry: TaskHandlerRegistry,
|
||||||
|
mock_credit_service,
|
||||||
|
):
|
||||||
|
"""Test exception handling during task execution."""
|
||||||
|
task = ScheduledTask(
|
||||||
|
name="Failing Task",
|
||||||
|
task_type=TaskType.CREDIT_RECHARGE,
|
||||||
|
scheduled_at=datetime.utcnow(),
|
||||||
|
parameters={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make credit service raise an exception
|
||||||
|
mock_credit_service.recharge_all_users_credits.side_effect = Exception("Service error")
|
||||||
|
|
||||||
|
with pytest.raises(TaskExecutionError, match="Task execution failed: Service error"):
|
||||||
|
await task_registry.execute_task(task)
|
||||||
|
|
||||||
|
async def test_task_registry_initialization(
|
||||||
|
self,
|
||||||
|
db_session: AsyncSession,
|
||||||
|
mock_credit_service,
|
||||||
|
mock_player_service,
|
||||||
|
):
|
||||||
|
"""Test task registry initialization."""
|
||||||
|
registry = TaskHandlerRegistry(
|
||||||
|
db_session,
|
||||||
|
mock_credit_service,
|
||||||
|
mock_player_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert registry.db_session == db_session
|
||||||
|
assert registry.credit_service == mock_credit_service
|
||||||
|
assert registry.player_service == mock_player_service
|
||||||
|
assert registry.sound_repository is not None
|
||||||
|
assert registry.playlist_repository is not None
|
||||||
|
|
||||||
|
# Check all handlers are registered
|
||||||
|
expected_handlers = {
|
||||||
|
TaskType.CREDIT_RECHARGE,
|
||||||
|
TaskType.PLAY_SOUND,
|
||||||
|
TaskType.PLAY_PLAYLIST,
|
||||||
|
}
|
||||||
|
assert set(registry._handlers.keys()) == expected_handlers
|
||||||
11
uv.lock
generated
11
uv.lock
generated
@@ -742,6 +742,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/5b/ee/7d76eb3b50ccb1397621f32ede0fb4d17aa55a9aa2251bc34e6b9929fdce/python_vlc-3.0.21203-py3-none-any.whl", hash = "sha256:1613451a31b692ec276296ceeae0c0ba82bfc2d094dabf9aceb70f58944a6320", size = 87651 },
|
{ url = "https://files.pythonhosted.org/packages/5b/ee/7d76eb3b50ccb1397621f32ede0fb4d17aa55a9aa2251bc34e6b9929fdce/python_vlc-3.0.21203-py3-none-any.whl", hash = "sha256:1613451a31b692ec276296ceeae0c0ba82bfc2d094dabf9aceb70f58944a6320", size = 87651 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytz"
|
||||||
|
version = "2024.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/90/26/9f1f00a5d021fff16dee3de13d43e5e978f3d58928e129c3a62cf7eb9738/pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812", size = 316214 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/9c/3d/a121f284241f08268b21359bd425f7d4825cffc5ac5cd0e1b3d82ffd2b10/pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319", size = 505474 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyyaml"
|
name = "pyyaml"
|
||||||
version = "6.0.2"
|
version = "6.0.2"
|
||||||
@@ -883,6 +892,7 @@ dependencies = [
|
|||||||
{ name = "pyjwt" },
|
{ name = "pyjwt" },
|
||||||
{ name = "python-socketio" },
|
{ name = "python-socketio" },
|
||||||
{ name = "python-vlc" },
|
{ name = "python-vlc" },
|
||||||
|
{ name = "pytz" },
|
||||||
{ name = "sqlmodel" },
|
{ name = "sqlmodel" },
|
||||||
{ name = "uvicorn", extra = ["standard"] },
|
{ name = "uvicorn", extra = ["standard"] },
|
||||||
{ name = "yt-dlp" },
|
{ name = "yt-dlp" },
|
||||||
@@ -912,6 +922,7 @@ requires-dist = [
|
|||||||
{ name = "pyjwt", specifier = "==2.10.1" },
|
{ name = "pyjwt", specifier = "==2.10.1" },
|
||||||
{ name = "python-socketio", specifier = "==5.13.0" },
|
{ name = "python-socketio", specifier = "==5.13.0" },
|
||||||
{ name = "python-vlc", specifier = "==3.0.21203" },
|
{ name = "python-vlc", specifier = "==3.0.21203" },
|
||||||
|
{ name = "pytz", specifier = "==2024.1" },
|
||||||
{ name = "sqlmodel", specifier = "==0.0.24" },
|
{ name = "sqlmodel", specifier = "==0.0.24" },
|
||||||
{ name = "uvicorn", extras = ["standard"], specifier = "==0.35.0" },
|
{ name = "uvicorn", extras = ["standard"], specifier = "==0.35.0" },
|
||||||
{ name = "yt-dlp", specifier = "==2025.8.20" },
|
{ name = "yt-dlp", specifier = "==2025.8.20" },
|
||||||
|
|||||||
Reference in New Issue
Block a user