Add comprehensive tests for scheduled task repository, scheduler service, and task handlers

- Implemented tests for ScheduledTaskRepository covering task creation, retrieval, filtering, and status updates.
- Developed tests for SchedulerService including task creation, cancellation, user task retrieval, and maintenance jobs.
- Created tests for TaskHandlerRegistry to validate task execution for various types, including credit recharge and sound playback.
- Ensured proper error handling and edge cases in task execution scenarios.
- Added fixtures and mocks to facilitate isolated testing of services and repositories.
This commit is contained in:
JSC
2025-08-28 22:37:43 +02:00
parent 7dee6e320e
commit 03abed6d39
23 changed files with 3415 additions and 103 deletions

View File

@@ -1,63 +1,413 @@
"""Scheduler service for periodic tasks."""
"""Enhanced scheduler service for flexible task scheduling with timezone support."""
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
import pytz
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.date import DateTrigger
from apscheduler.triggers.interval import IntervalTrigger
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.scheduled_task import (
RecurrenceType,
ScheduledTask,
TaskStatus,
TaskType,
)
from app.repositories.scheduled_task import ScheduledTaskRepository
from app.services.credit import CreditService
from app.services.player import PlayerService
from app.services.task_handlers import TaskHandlerRegistry
logger = get_logger(__name__)
class SchedulerService:
"""Service for managing scheduled tasks."""
"""Enhanced service for managing scheduled tasks with timezone support."""
def __init__(self, db_session_factory: Callable[[], AsyncSession]) -> None:
def __init__(
self,
db_session_factory: Callable[[], AsyncSession],
player_service: PlayerService,
) -> None:
"""Initialize the scheduler service.
Args:
db_session_factory: Factory function to create database sessions
player_service: Player service for audio playback tasks
"""
self.db_session_factory = db_session_factory
self.scheduler = AsyncIOScheduler()
self.scheduler = AsyncIOScheduler(timezone=pytz.UTC)
self.credit_service = CreditService(db_session_factory)
self.player_service = player_service
self._running_tasks: set[str] = set()
async def start(self) -> None:
"""Start the scheduler and register all tasks."""
logger.info("Starting scheduler service...")
"""Start the scheduler and load all active tasks."""
logger.info("Starting enhanced scheduler service...")
# Add daily credit recharge job (runs at midnight UTC)
self.scheduler.start()
# Schedule system tasks initialization for after startup
self.scheduler.add_job(
self._daily_credit_recharge,
"cron",
hour=0,
minute=0,
id="daily_credit_recharge",
name="Daily Credit Recharge",
self._initialize_system_tasks,
"date",
run_date=datetime.utcnow() + timedelta(seconds=2),
id="initialize_system_tasks",
name="Initialize System Tasks",
replace_existing=True,
)
# Schedule periodic cleanup and maintenance
self.scheduler.add_job(
self._maintenance_job,
"interval",
minutes=5,
id="scheduler_maintenance",
name="Scheduler Maintenance",
replace_existing=True,
)
self.scheduler.start()
logger.info("Scheduler service started successfully")
logger.info("Enhanced scheduler service started successfully")
async def stop(self) -> None:
"""Stop the scheduler."""
logger.info("Stopping scheduler service...")
self.scheduler.shutdown()
self.scheduler.shutdown(wait=True)
logger.info("Scheduler service stopped")
async def _daily_credit_recharge(self) -> None:
"""Execute daily credit recharge for all users."""
logger.info("Starting daily credit recharge task...")
async def create_task(
self,
name: str,
task_type: TaskType,
scheduled_at: datetime,
parameters: Optional[Dict[str, Any]] = None,
user_id: Optional[int] = None,
timezone: str = "UTC",
recurrence_type: RecurrenceType = RecurrenceType.NONE,
cron_expression: Optional[str] = None,
recurrence_count: Optional[int] = None,
expires_at: Optional[datetime] = None,
) -> ScheduledTask:
"""Create a new scheduled task."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Convert scheduled_at to UTC if it's in a different timezone
if timezone != "UTC":
tz = pytz.timezone(timezone)
if scheduled_at.tzinfo is None:
# Assume the datetime is in the specified timezone
scheduled_at = tz.localize(scheduled_at)
scheduled_at = scheduled_at.astimezone(pytz.UTC).replace(tzinfo=None)
task_data = {
"name": name,
"task_type": task_type,
"scheduled_at": scheduled_at,
"timezone": timezone,
"parameters": parameters or {},
"user_id": user_id,
"recurrence_type": recurrence_type,
"cron_expression": cron_expression,
"recurrence_count": recurrence_count,
"expires_at": expires_at,
}
created_task = await repo.create(task_data)
await self._schedule_apscheduler_job(created_task)
logger.info(f"Created scheduled task: {created_task.name} ({created_task.id})")
return created_task
async def cancel_task(self, task_id: int) -> bool:
"""Cancel a scheduled task."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
task = await repo.get_by_id(task_id)
if not task:
return False
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
# Remove from APScheduler
try:
self.scheduler.remove_job(str(task_id))
except Exception:
pass # Job might not exist in scheduler
logger.info(f"Cancelled task: {task.name} ({task_id})")
return True
async def get_user_tasks(
self,
user_id: int,
status: Optional[TaskStatus] = None,
task_type: Optional[TaskType] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> List[ScheduledTask]:
"""Get tasks for a specific user."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
return await repo.get_user_tasks(user_id, status, task_type, limit, offset)
async def _initialize_system_tasks(self) -> None:
"""Initialize system tasks and load active tasks from database."""
logger.info("Initializing system tasks...")
try:
stats = await self.credit_service.recharge_all_users_credits()
logger.info(
"Daily credit recharge completed successfully: %s",
stats,
)
# Create system tasks if they don't exist
await self._ensure_system_tasks()
# Load all active tasks from database
await self._load_active_tasks()
logger.info("System tasks initialized successfully")
except Exception:
logger.exception("Daily credit recharge task failed")
logger.exception("Failed to initialize system tasks")
async def _ensure_system_tasks(self) -> None:
"""Ensure required system tasks exist."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Check if daily credit recharge task exists
system_tasks = await repo.get_system_tasks(
task_type=TaskType.CREDIT_RECHARGE
)
daily_recharge_exists = any(
task.recurrence_type == RecurrenceType.DAILY
and task.is_active
for task in system_tasks
)
if not daily_recharge_exists:
# Create daily credit recharge task
tomorrow_midnight = datetime.utcnow().replace(
hour=0, minute=0, second=0, microsecond=0
) + timedelta(days=1)
task_data = {
"name": "Daily Credit Recharge",
"task_type": TaskType.CREDIT_RECHARGE,
"scheduled_at": tomorrow_midnight,
"recurrence_type": RecurrenceType.DAILY,
"parameters": {},
}
await repo.create(task_data)
logger.info("Created system daily credit recharge task")
async def _load_active_tasks(self) -> None:
"""Load all active tasks from database into scheduler."""
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
active_tasks = await repo.get_active_tasks()
for task in active_tasks:
await self._schedule_apscheduler_job(task)
logger.info(f"Loaded {len(active_tasks)} active tasks into scheduler")
async def _schedule_apscheduler_job(self, task: ScheduledTask) -> None:
"""Schedule a task in APScheduler."""
job_id = str(task.id)
# Remove existing job if it exists
try:
self.scheduler.remove_job(job_id)
except Exception:
pass
# Don't schedule if task is not active or already completed/failed
if not task.is_active or task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
return
# Create trigger based on recurrence type
trigger = self._create_trigger(task)
if not trigger:
logger.warning(f"Could not create trigger for task {task.id}")
return
# Schedule the job
self.scheduler.add_job(
self._execute_task,
trigger=trigger,
args=[task.id],
id=job_id,
name=task.name,
replace_existing=True,
)
logger.debug(f"Scheduled APScheduler job for task {task.id}")
def _create_trigger(self, task: ScheduledTask):
"""Create APScheduler trigger based on task configuration."""
tz = pytz.timezone(task.timezone)
if task.recurrence_type == RecurrenceType.NONE:
return DateTrigger(run_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.CRON and task.cron_expression:
return CronTrigger.from_crontab(task.cron_expression, timezone=tz)
elif task.recurrence_type == RecurrenceType.HOURLY:
return IntervalTrigger(hours=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.DAILY:
return IntervalTrigger(days=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.WEEKLY:
return IntervalTrigger(weeks=1, start_date=task.scheduled_at, timezone=tz)
elif task.recurrence_type == RecurrenceType.MONTHLY:
# Use cron trigger for monthly (more reliable than interval)
scheduled_time = task.scheduled_at
return CronTrigger(
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
)
elif task.recurrence_type == RecurrenceType.YEARLY:
scheduled_time = task.scheduled_at
return CronTrigger(
month=scheduled_time.month,
day=scheduled_time.day,
hour=scheduled_time.hour,
minute=scheduled_time.minute,
timezone=tz
)
return None
async def _execute_task(self, task_id: int) -> None:
"""Execute a scheduled task."""
task_id_str = str(task_id)
# Prevent concurrent execution of the same task
if task_id_str in self._running_tasks:
logger.warning(f"Task {task_id} is already running, skipping execution")
return
self._running_tasks.add(task_id_str)
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Get fresh task data
task = await repo.get_by_id(task_id)
if not task:
logger.warning(f"Task {task_id} not found")
return
# Check if task is still active and pending
if not task.is_active or task.status != TaskStatus.PENDING:
logger.info(f"Task {task_id} is not active or not pending, skipping")
return
# Check if task has expired
if task.is_expired():
logger.info(f"Task {task_id} has expired, marking as cancelled")
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
return
# Mark task as running
await repo.mark_as_running(task)
# Execute the task
try:
handler_registry = TaskHandlerRegistry(
session, self.credit_service, self.player_service
)
await handler_registry.execute_task(task)
# Calculate next execution time for recurring tasks
next_execution_at = None
if task.should_repeat():
next_execution_at = self._calculate_next_execution(task)
# Mark as completed
await repo.mark_as_completed(task, next_execution_at)
# Reschedule if recurring
if next_execution_at and task.should_repeat():
# Refresh task to get updated data
await session.refresh(task)
await self._schedule_apscheduler_job(task)
except Exception as e:
await repo.mark_as_failed(task, str(e))
logger.exception(f"Task {task_id} execution failed: {str(e)}")
finally:
self._running_tasks.discard(task_id_str)
def _calculate_next_execution(self, task: ScheduledTask) -> Optional[datetime]:
"""Calculate the next execution time for a recurring task."""
now = datetime.utcnow()
if task.recurrence_type == RecurrenceType.HOURLY:
return now + timedelta(hours=1)
elif task.recurrence_type == RecurrenceType.DAILY:
return now + timedelta(days=1)
elif task.recurrence_type == RecurrenceType.WEEKLY:
return now + timedelta(weeks=1)
elif task.recurrence_type == RecurrenceType.MONTHLY:
# Add approximately one month
return now + timedelta(days=30)
elif task.recurrence_type == RecurrenceType.YEARLY:
return now + timedelta(days=365)
return None
async def _maintenance_job(self) -> None:
"""Periodic maintenance job to clean up expired tasks and handle scheduling issues."""
try:
async with self.db_session_factory() as session:
repo = ScheduledTaskRepository(session)
# Handle expired tasks
expired_tasks = await repo.get_expired_tasks()
for task in expired_tasks:
task.status = TaskStatus.CANCELLED
task.is_active = False
await repo.update(task)
# Remove from scheduler
try:
self.scheduler.remove_job(str(task.id))
except Exception:
pass
if expired_tasks:
logger.info(f"Cleaned up {len(expired_tasks)} expired tasks")
# Handle any missed recurring tasks
due_recurring = await repo.get_recurring_tasks_due_for_next_execution()
for task in due_recurring:
if task.should_repeat():
task.status = TaskStatus.PENDING
task.scheduled_at = task.next_execution_at or datetime.utcnow()
await repo.update(task)
await self._schedule_apscheduler_job(task)
if due_recurring:
logger.info(f"Rescheduled {len(due_recurring)} recurring tasks")
except Exception:
logger.exception("Maintenance job failed")

View File

@@ -0,0 +1,137 @@
"""Task execution handlers for different task types."""
import uuid
from typing import Any, Dict, Optional
from sqlmodel.ext.asyncio.session import AsyncSession
from app.core.logging import get_logger
from app.models.scheduled_task import ScheduledTask, TaskType
from app.repositories.playlist import PlaylistRepository
from app.repositories.sound import SoundRepository
from app.services.credit import CreditService
from app.services.player import PlayerService
logger = get_logger(__name__)
class TaskExecutionError(Exception):
"""Exception raised when task execution fails."""
pass
class TaskHandlerRegistry:
"""Registry for task execution handlers."""
def __init__(
self,
db_session: AsyncSession,
credit_service: CreditService,
player_service: PlayerService,
) -> None:
"""Initialize the task handler registry."""
self.db_session = db_session
self.credit_service = credit_service
self.player_service = player_service
self.sound_repository = SoundRepository(db_session)
self.playlist_repository = PlaylistRepository(db_session)
# Register handlers
self._handlers = {
TaskType.CREDIT_RECHARGE: self._handle_credit_recharge,
TaskType.PLAY_SOUND: self._handle_play_sound,
TaskType.PLAY_PLAYLIST: self._handle_play_playlist,
}
async def execute_task(self, task: ScheduledTask) -> None:
"""Execute a task based on its type."""
handler = self._handlers.get(task.task_type)
if not handler:
raise TaskExecutionError(f"No handler registered for task type: {task.task_type}")
logger.info(f"Executing task {task.id} ({task.task_type.value}): {task.name}")
try:
await handler(task)
logger.info(f"Task {task.id} executed successfully")
except Exception as e:
logger.exception(f"Task {task.id} execution failed: {str(e)}")
raise TaskExecutionError(f"Task execution failed: {str(e)}") from e
async def _handle_credit_recharge(self, task: ScheduledTask) -> None:
"""Handle credit recharge task."""
parameters = task.parameters
user_id = parameters.get("user_id")
if user_id:
# Recharge specific user
user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
stats = await self.credit_service.recharge_user_credits(user_uuid)
logger.info(f"Recharged credits for user {user_id}: {stats}")
else:
# Recharge all users (system task)
stats = await self.credit_service.recharge_all_users_credits()
logger.info(f"Recharged credits for all users: {stats}")
async def _handle_play_sound(self, task: ScheduledTask) -> None:
"""Handle play sound task."""
parameters = task.parameters
sound_id = parameters.get("sound_id")
if not sound_id:
raise TaskExecutionError("sound_id parameter is required for PLAY_SOUND tasks")
try:
sound_uuid = uuid.UUID(sound_id) if isinstance(sound_id, str) else sound_id
except (ValueError, TypeError) as e:
raise TaskExecutionError(f"Invalid sound_id format: {sound_id}") from e
# Get the sound from database
sound = await self.sound_repository.get_by_id(sound_uuid)
if not sound:
raise TaskExecutionError(f"Sound not found: {sound_id}")
# Play the sound through VLC
from app.services.vlc_player import VLCPlayerService
vlc_service = VLCPlayerService(lambda: self.db_session)
await vlc_service.play_sound(sound)
logger.info(f"Played sound {sound.filename} via scheduled task")
async def _handle_play_playlist(self, task: ScheduledTask) -> None:
"""Handle play playlist task."""
parameters = task.parameters
playlist_id = parameters.get("playlist_id")
play_mode = parameters.get("play_mode", "continuous")
shuffle = parameters.get("shuffle", False)
if not playlist_id:
raise TaskExecutionError("playlist_id parameter is required for PLAY_PLAYLIST tasks")
try:
playlist_uuid = uuid.UUID(playlist_id) if isinstance(playlist_id, str) else playlist_id
except (ValueError, TypeError) as e:
raise TaskExecutionError(f"Invalid playlist_id format: {playlist_id}") from e
# Get the playlist from database
playlist = await self.playlist_repository.get_by_id(playlist_uuid)
if not playlist:
raise TaskExecutionError(f"Playlist not found: {playlist_id}")
# Load playlist in player
await self.player_service.load_playlist(playlist_uuid)
# Set play mode if specified
if play_mode in ["continuous", "loop", "loop_one", "random", "single"]:
self.player_service.set_mode(play_mode)
# Enable shuffle if requested
if shuffle:
self.player_service.set_shuffle(True)
# Start playing
self.player_service.play()
logger.info(f"Started playing playlist {playlist.name} via scheduled task")

View File

@@ -238,75 +238,76 @@ class VLCPlayerService:
return
logger.info("Recording play count for sound %s", sound_id)
session = self.db_session_factory()
# Initialize variables for WebSocket event
old_count = 0
sound = None
admin_user_id = None
admin_user_name = None
try:
sound_repo = SoundRepository(session)
user_repo = UserRepository(session)
async with self.db_session_factory() as session:
sound_repo = SoundRepository(session)
user_repo = UserRepository(session)
# Update sound play count
sound = await sound_repo.get_by_id(sound_id)
old_count = 0
if sound:
old_count = sound.play_count
await sound_repo.update(
sound,
{"play_count": sound.play_count + 1},
# Update sound play count
sound = await sound_repo.get_by_id(sound_id)
if sound:
old_count = sound.play_count
# Update the sound's play count using direct attribute modification
sound.play_count = sound.play_count + 1
session.add(sound)
await session.commit()
await session.refresh(sound)
logger.info(
"Updated sound %s play_count: %s -> %s",
sound_id,
old_count,
old_count + 1,
)
else:
logger.warning("Sound %s not found for play count update", sound_id)
# Record play history for admin user (ID 1) as placeholder
# This could be refined to track per-user play history
admin_user = await user_repo.get_by_id(1)
if admin_user:
admin_user_id = admin_user.id
admin_user_name = admin_user.name
# Always create a new SoundPlayed record for each play event
sound_played = SoundPlayed(
user_id=admin_user_id, # Can be None for player-based plays
sound_id=sound_id,
)
session.add(sound_played)
logger.info(
"Updated sound %s play_count: %s -> %s",
sound_id,
old_count,
old_count + 1,
)
else:
logger.warning("Sound %s not found for play count update", sound_id)
# Record play history for admin user (ID 1) as placeholder
# This could be refined to track per-user play history
admin_user = await user_repo.get_by_id(1)
admin_user_id = None
admin_user_name = None
if admin_user:
admin_user_id = admin_user.id
admin_user_name = admin_user.name
# Always create a new SoundPlayed record for each play event
sound_played = SoundPlayed(
user_id=admin_user_id, # Can be None for player-based plays
sound_id=sound_id,
)
session.add(sound_played)
logger.info(
"Created SoundPlayed record for user %s, sound %s",
admin_user_id,
sound_id,
)
await session.commit()
logger.info("Successfully recorded play count for sound %s", sound_id)
# Emit sound_played event via WebSocket
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": admin_user_id,
"user_name": admin_user_name,
"play_count": (old_count + 1) if sound else None,
}
await socket_manager.broadcast_to_all("sound_played", event_data)
logger.info("Broadcasted sound_played event for sound %s", sound_id)
except Exception:
logger.exception(
"Failed to broadcast sound_played event for sound %s",
"Created SoundPlayed record for user %s, sound %s",
admin_user_id,
sound_id,
)
await session.commit()
logger.info("Successfully recorded play count for sound %s", sound_id)
except Exception:
logger.exception("Error recording play count for sound %s", sound_id)
await session.rollback()
finally:
await session.close()
# Emit sound_played event via WebSocket (outside session context)
try:
event_data = {
"sound_id": sound_id,
"sound_name": sound_name,
"user_id": admin_user_id,
"user_name": admin_user_name,
"play_count": (old_count + 1) if sound else None,
}
await socket_manager.broadcast_to_all("sound_played", event_data)
logger.info("Broadcasted sound_played event for sound %s", sound_id)
except Exception:
logger.exception(
"Failed to broadcast sound_played event for sound %s",
sound_id,
)
async def play_sound_with_credits(
self,