- Implemented VulnerabilityScanner class to scan images for vulnerabilities using Trivy and NVD API. - Added methods to parse and store vulnerability data in the database. - Created WebSocketManager class to handle real-time notifications for scan status updates. - Integrated WebSocket notifications for scan start, completion, and failure events.
161 lines
6.1 KiB
Python
161 lines
6.1 KiB
Python
import asyncio
|
|
import json
|
|
from typing import Dict, Set
|
|
from datetime import datetime
|
|
import threading
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
import socketio
|
|
from fastapi import FastAPI
|
|
|
|
|
|
class WebSocketManager:
    """Socket.IO server that pushes scan-status updates to connected clients.

    Async code can await the ``notify_scan_*`` coroutines directly. Code running
    in other threads must use the ``notify_scan_*_sync`` variants. Those hand
    the broadcast to the event loop previously captured from async context.
    """

    def __init__(self):
        self.sio = socketio.AsyncServer(
            async_mode="asgi",
            cors_allowed_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
            logger=True,
        )
        # Socket.IO session ids of currently connected clients.
        self.connected_clients: Set[str] = set()
        # Event loop serving the Socket.IO server. It is captured lazily from
        # async context so that other threads can schedule broadcasts onto it.
        self._loop = None
        # Runs blocking database queries off the event loop.
        self._executor = ThreadPoolExecutor(max_workers=2)

    def get_asgi_app(self, fastapi_app: "FastAPI"):
        """Return a ``socketio.ASGIApp`` combining ``self.sio`` with *fastapi_app*."""
        return socketio.ASGIApp(self.sio, fastapi_app)

    def setup_events(self) -> None:
        """Register the ``connect`` and ``disconnect`` Socket.IO handlers."""

        @self.sio.event
        async def connect(sid, environ, auth):
            # Capture the loop before anything that can fail, so thread-side
            # notifications still work even if the greeting below raises.
            self._remember_loop()
            print(f"Client {sid} connected")
            self.connected_clients.add(sid)
            await self.sio.emit('connected', {'message': 'Connected to scan notifications'}, room=sid)

            # Check for any running scans and notify the new client
            await self._notify_existing_scans_impl(sid)

        @self.sio.event
        async def disconnect(sid):
            print(f"Client {sid} disconnected")
            self.connected_clients.discard(sid)

    def _remember_loop(self) -> None:
        """Record the running event loop for use by the ``*_sync`` notifiers.

        Must be called from a coroutine. A previously stored loop is replaced
        if it has since been closed.
        """
        if self._loop is None or self._loop.is_closed():
            self._loop = asyncio.get_running_loop()

    @staticmethod
    def _timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string with a ``+00:00`` offset."""
        # Local import: this module only imports the ``datetime`` class itself.
        from datetime import timezone
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _scan_payload(job_type, job_id, message, status) -> Dict:
        """Build the ``data`` dict shared by every scan_* notification."""
        return {
            'job_type': job_type,
            'job_id': job_id,
            'message': message,
            'status': status,
        }

    @staticmethod
    def _load_active_jobs():
        """Return plain-dict snapshots of ScanJob rows whose status is pending/running.

        This is blocking and is meant to run in ``self._executor``. Attribute
        values are copied out before the session is closed, so callers never
        touch ORM objects from a closed session.
        """
        from models import SessionLocal, ScanJob

        db = SessionLocal()
        try:
            jobs = db.query(ScanJob).filter(
                ScanJob.status.in_(["pending", "running"])
            ).all()
            return [
                {'job_type': job.job_type, 'job_id': job.id, 'status': job.status}
                for job in jobs
            ]
        finally:
            db.close()

    async def _notify_existing_scans_impl(self, sid):
        """Send a ``scan_started`` update for each pending/running scan to client *sid*.

        This is best-effort. Any error is printed and swallowed, so a failure
        here does not propagate out of the ``connect`` handler.
        """
        try:
            loop = asyncio.get_running_loop()
            # The ORM query blocks; run it on the executor instead of the loop.
            active_jobs = await loop.run_in_executor(self._executor, self._load_active_jobs)
            for job in active_jobs:
                # Send scan_started event to the specific client
                message = {
                    'type': 'scan_started',
                    'timestamp': self._timestamp(),
                    'data': self._scan_payload(
                        job['job_type'],
                        job['job_id'],
                        f"{job['job_type'].title()} scan in progress",
                        job['status'],
                    ),
                }
                await self.sio.emit('scan_update', message, room=sid)
                print(f"Notified new client {sid} about running {job['job_type']} scan")
        except Exception as e:
            print(f"Error notifying existing scans to new client: {e}")

    # Alias kept so existing callers of ``_notify_existing_scans`` keep working.
    _notify_existing_scans = _notify_existing_scans_impl

    async def broadcast_scan_status(self, event_type: str, data: Dict):
        """Emit a ``scan_update`` event of type *event_type* to all connected clients.

        Does nothing when no clients are connected.
        """
        # Any async caller is on the server's loop; remember it for the sync notifiers.
        self._remember_loop()
        if not self.connected_clients:
            return
        message = {
            'type': event_type,
            'timestamp': self._timestamp(),
            'data': data,
        }
        await self.sio.emit('scan_update', message)
        print(f"Broadcasted {event_type} to {len(self.connected_clients)} clients")

    def notify_scan_started_sync(self, job_type: str, job_id: int, message: str):
        """Thread-safe version to notify that a scan has started"""
        self._schedule_notification(
            'scan_started', self._scan_payload(job_type, job_id, message, 'running')
        )

    def notify_scan_completed_sync(self, job_type: str, job_id: int, message: str):
        """Thread-safe version to notify that a scan has completed"""
        self._schedule_notification(
            'scan_completed', self._scan_payload(job_type, job_id, message, 'completed')
        )

    def notify_scan_failed_sync(self, job_type: str, job_id: int, error_message: str):
        """Thread-safe version to notify that a scan has failed"""
        self._schedule_notification(
            'scan_failed',
            self._scan_payload(job_type, job_id, f"Scan failed: {error_message}", 'failed'),
        )

    def _schedule_notification(self, event_type: str, data: Dict):
        """Submit a broadcast to the captured event loop from any thread.

        This is fire-and-forget: the caller does not wait for delivery. If no
        open loop has been captured yet, the notification is dropped and a
        message is printed.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            print(f"No event loop available for WebSocket notification: {event_type}")
            return
        coro = self.broadcast_scan_status(event_type, data)
        try:
            # run_coroutine_threadsafe is the thread-safe way to hand a
            # coroutine to a loop running in another thread.
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as e:
            # The coroutine was never scheduled; close it to avoid a
            # "coroutine was never awaited" warning.
            coro.close()
            print(f"Error scheduling WebSocket notification: {e}")
            return
        # Nobody waits on the future (to avoid blocking the background task),
        # so report broadcast failures here instead of losing them.
        future.add_done_callback(self._report_notification_error)

    @staticmethod
    def _report_notification_error(future) -> None:
        """Print the exception, if any, raised by a scheduled broadcast."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            print(f"WebSocket notification failed: {exc}")

    # Keep async versions for direct async usage
    async def notify_scan_started(self, job_type: str, job_id: int, message: str):
        """Notify that a scan has started"""
        await self.broadcast_scan_status(
            'scan_started', self._scan_payload(job_type, job_id, message, 'running')
        )

    async def notify_scan_completed(self, job_type: str, job_id: int, message: str):
        """Notify that a scan has completed"""
        await self.broadcast_scan_status(
            'scan_completed', self._scan_payload(job_type, job_id, message, 'completed')
        )

    async def notify_scan_failed(self, job_type: str, job_id: int, error_message: str):
        """Notify that a scan has failed"""
        await self.broadcast_scan_status(
            'scan_failed',
            self._scan_payload(job_type, job_id, f"Scan failed: {error_message}", 'failed'),
        )
|
|
|
|
|
|
# Module-level manager created once at import time; every importer of this
# module shares this same object.
websocket_manager: WebSocketManager = WebSocketManager()