Refactor scheduled task repository and schemas for improved type hints and consistency
- Updated type hints from List/Optional to list/None for better readability and consistency across the codebase. - Refactored import statements for better organization and clarity. - Enhanced the ScheduledTaskBase schema to use modern type hints. - Cleaned up unnecessary comments and whitespace in various files. - Improved error handling and logging in task execution handlers. - Updated test cases to reflect changes in type hints and ensure compatibility with the new structure.
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
"""API endpoints for scheduled task management."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -11,7 +9,7 @@ from app.core.dependencies import (
|
||||
get_admin_user,
|
||||
get_current_active_user,
|
||||
)
|
||||
from app.models.scheduled_task import RecurrenceType, ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.scheduled_task import ScheduledTask, TaskStatus, TaskType
|
||||
from app.models.user import User
|
||||
from app.schemas.scheduler import (
|
||||
ScheduledTaskCreate,
|
||||
@@ -54,15 +52,15 @@ async def create_task(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=List[ScheduledTaskResponse])
|
||||
@router.get("/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_user_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
limit: Optional[int] = Query(50, description="Maximum number of tasks to return"),
|
||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
limit: int | None = Query(50, description="Maximum number of tasks to return"),
|
||||
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_current_active_user),
|
||||
scheduler_service: SchedulerService = Depends(get_scheduler_service),
|
||||
) -> List[ScheduledTask]:
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get user's scheduled tasks."""
|
||||
return await scheduler_service.get_user_tasks(
|
||||
user_id=current_user.id,
|
||||
@@ -81,17 +79,17 @@ async def get_task(
|
||||
) -> ScheduledTask:
|
||||
"""Get a specific scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
|
||||
return task
|
||||
|
||||
|
||||
@@ -104,22 +102,22 @@ async def update_task(
|
||||
) -> ScheduledTask:
|
||||
"""Update a scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
|
||||
# Update task fields
|
||||
update_data = task_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
|
||||
|
||||
updated_task = await repo.update(task)
|
||||
return updated_task
|
||||
|
||||
@@ -133,72 +131,72 @@ async def cancel_task(
|
||||
) -> dict:
|
||||
"""Cancel a scheduled task."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
task = await repo.get_by_id(task_id)
|
||||
|
||||
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
|
||||
# Check if user owns the task or is admin
|
||||
if task.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
|
||||
success = await scheduler_service.cancel_task(task_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="Failed to cancel task")
|
||||
|
||||
|
||||
return {"message": "Task cancelled successfully"}
|
||||
|
||||
|
||||
# Admin-only endpoints
|
||||
@router.get("/admin/tasks", response_model=List[ScheduledTaskResponse])
|
||||
@router.get("/admin/tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_all_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
limit: Optional[int] = Query(100, description="Maximum number of tasks to return"),
|
||||
offset: Optional[int] = Query(0, description="Number of tasks to skip"),
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
limit: int | None = Query(100, description="Maximum number of tasks to return"),
|
||||
offset: int | None = Query(0, description="Number of tasks to skip"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> List[ScheduledTask]:
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get all scheduled tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
|
||||
|
||||
# Get all tasks with pagination and filtering
|
||||
from sqlmodel import select
|
||||
|
||||
|
||||
statement = select(ScheduledTask)
|
||||
|
||||
|
||||
if status:
|
||||
statement = statement.where(ScheduledTask.status == status)
|
||||
|
||||
|
||||
if task_type:
|
||||
statement = statement.where(ScheduledTask.task_type == task_type)
|
||||
|
||||
|
||||
statement = statement.order_by(ScheduledTask.scheduled_at.desc())
|
||||
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
|
||||
result = await db_session.exec(statement)
|
||||
return list(result.all())
|
||||
|
||||
|
||||
@router.get("/admin/system-tasks", response_model=List[ScheduledTaskResponse])
|
||||
@router.get("/admin/system-tasks", response_model=list[ScheduledTaskResponse])
|
||||
async def get_system_tasks(
|
||||
status: Optional[TaskStatus] = Query(None, description="Filter by task status"),
|
||||
task_type: Optional[TaskType] = Query(None, description="Filter by task type"),
|
||||
status: TaskStatus | None = Query(None, description="Filter by task status"),
|
||||
task_type: TaskType | None = Query(None, description="Filter by task type"),
|
||||
current_user: User = Depends(get_admin_user),
|
||||
db_session: AsyncSession = Depends(get_db),
|
||||
) -> List[ScheduledTask]:
|
||||
) -> list[ScheduledTask]:
|
||||
"""Get system tasks (admin only)."""
|
||||
from app.repositories.scheduled_task import ScheduledTaskRepository
|
||||
|
||||
|
||||
repo = ScheduledTaskRepository(db_session)
|
||||
return await repo.get_system_tasks(status=status, task_type=task_type)
|
||||
|
||||
@@ -225,4 +223,4 @@ async def create_system_task(
|
||||
)
|
||||
return task
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
Reference in New Issue
Block a user