import asyncio
import json
from typing import Dict, Optional, Set
from datetime import datetime
import threading
from concurrent.futures import ThreadPoolExecutor

import socketio
from fastapi import FastAPI


class WebSocketManager:
    """Socket.IO hub that pushes scan-status events to connected clients.

    Wraps a ``socketio.AsyncServer``, tracks the session ids of connected
    clients, and exposes two families of notification helpers:

    * ``notify_scan_*`` coroutines, awaited directly from async code;
    * ``notify_scan_*_sync`` methods, which hand the broadcast over to the
      server's event loop via ``asyncio.run_coroutine_threadsafe`` so they
      can be called from worker threads.
    """

    def __init__(self):
        self.sio = socketio.AsyncServer(
            async_mode="asgi",
            cors_allowed_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
            logger=True
        )
        # Session ids of currently connected clients.
        self.connected_clients: Set[str] = set()
        # Event loop the server runs on; captured when a client connects.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        # Not referenced by any method of this class; kept because code
        # outside this file may rely on the attribute.
        self._executor = ThreadPoolExecutor(max_workers=2)

    def get_asgi_app(self, fastapi_app: FastAPI):
        """Return an ASGI app that serves Socket.IO in front of *fastapi_app*."""
        return socketio.ASGIApp(self.sio, fastapi_app)

    def setup_events(self):
        """Register the Socket.IO ``connect`` and ``disconnect`` handlers."""
        # Store the method reference under the name the connect handler uses.
        self._notify_existing_scans = self._notify_existing_scans_impl

        @self.sio.event
        async def connect(sid, environ, auth):
            # Capture the running loop before any await: if one of the emits
            # below raises, the *_sync notifiers must still be able to reach
            # the loop. A closed loop is replaced by the current one.
            if self._loop is None or self._loop.is_closed():
                self._loop = asyncio.get_running_loop()

            print(f"Client {sid} connected")
            self.connected_clients.add(sid)
            await self.sio.emit(
                'connected',
                {'message': 'Connected to scan notifications'},
                room=sid,
            )

            # Check for any running scans and notify the new client
            await self._notify_existing_scans(sid)

        @self.sio.event
        async def disconnect(sid):
            print(f"Client {sid} disconnected")
            self.connected_clients.discard(sid)

    @staticmethod
    def _build_message(event_type: str, data: Dict) -> Dict:
        """Wrap *data* in the ``scan_update`` envelope sent to clients."""
        return {
            'type': event_type,
            # NOTE: naive UTC timestamp; kept as-is because switching to an
            # aware datetime would change the string clients receive.
            'timestamp': datetime.utcnow().isoformat(),
            'data': data
        }

    @staticmethod
    def _scan_payload(job_type: str, job_id: int, message: str, status: str) -> Dict:
        """Build the ``data`` section shared by all scan notifications."""
        return {
            'job_type': job_type,
            'job_id': job_id,
            'message': message,
            'status': status
        }

    async def _notify_existing_scans_impl(self, sid):
        """Check for running scans and notify the newly connected client.

        Queries ``ScanJob`` rows whose status is ``pending`` or ``running``
        and emits one ``scan_started`` update per job to room *sid*. Any
        exception is reported and swallowed.
        """
        try:
            # Imported lazily, as in the original code.
            from models import SessionLocal, ScanJob

            # Build plain payloads while the session is open, then release it
            # before awaiting on the network.
            # NOTE(review): the query itself is synchronous and runs on the
            # event loop thread.
            db = SessionLocal()
            try:
                running_jobs = db.query(ScanJob).filter(
                    ScanJob.status.in_(["pending", "running"])
                ).all()
                pending_messages = [
                    self._build_message(
                        'scan_started',
                        self._scan_payload(
                            job.job_type,
                            job.id,
                            f"{job.job_type.title()} scan in progress",
                            job.status,
                        ),
                    )
                    for job in running_jobs
                ]
            finally:
                db.close()

            for message in pending_messages:
                # Send scan_started event to the specific client
                await self.sio.emit('scan_update', message, room=sid)
                job_type = message['data']['job_type']
                print(f"Notified new client {sid} about running {job_type} scan")
        except Exception as e:
            print(f"Error notifying existing scans to new client: {e}")

    async def broadcast_scan_status(self, event_type: str, data: Dict):
        """Broadcast scan status to all connected clients.

        Nothing is emitted when no client is currently tracked.
        """
        if not self.connected_clients:
            return

        message = self._build_message(event_type, data)
        await self.sio.emit('scan_update', message)
        print(f"Broadcasted {event_type} to {len(self.connected_clients)} clients")

    def notify_scan_started_sync(self, job_type: str, job_id: int, message: str):
        """Thread-safe version to notify that a scan has started"""
        self._schedule_notification(
            'scan_started',
            self._scan_payload(job_type, job_id, message, 'running'),
        )

    def notify_scan_completed_sync(self, job_type: str, job_id: int, message: str):
        """Thread-safe version to notify that a scan has completed"""
        self._schedule_notification(
            'scan_completed',
            self._scan_payload(job_type, job_id, message, 'completed'),
        )

    def notify_scan_failed_sync(self, job_type: str, job_id: int, error_message: str):
        """Thread-safe version to notify that a scan has failed"""
        self._schedule_notification(
            'scan_failed',
            self._scan_payload(
                job_type, job_id, f"Scan failed: {error_message}", 'failed'
            ),
        )

    @staticmethod
    def _log_notification_failure(future):
        """Report an exception raised by a scheduled broadcast, if any."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            print(f"Error sending WebSocket notification: {error}")

    def _schedule_notification(self, event_type: str, data: Dict):
        """Schedule a notification to be sent via WebSocket.

        Submits ``broadcast_scan_status`` to the captured event loop without
        waiting for the result, so the calling thread is never blocked.
        If no usable loop has been captured yet, the notification is dropped
        with a log line.
        """
        loop = self._loop
        if loop is None or loop.is_closed():
            print(f"No event loop available for WebSocket notification: {event_type}")
            return

        coro = self.broadcast_scan_status(event_type, data)
        try:
            # run_coroutine_threadsafe hands the coroutine to the loop's thread.
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as e:
            # The coroutine was never scheduled; close it so Python does not
            # warn about a coroutine that was never awaited.
            coro.close()
            print(f"Error scheduling WebSocket notification: {e}")
            return

        # Don't wait for the result (that would block the background task),
        # but surface failures instead of silently discarding them.
        future.add_done_callback(self._log_notification_failure)

    # Keep async versions for direct async usage
    async def notify_scan_started(self, job_type: str, job_id: int, message: str):
        """Notify that a scan has started"""
        await self.broadcast_scan_status(
            'scan_started',
            self._scan_payload(job_type, job_id, message, 'running'),
        )

    async def notify_scan_completed(self, job_type: str, job_id: int, message: str):
        """Notify that a scan has completed"""
        await self.broadcast_scan_status(
            'scan_completed',
            self._scan_payload(job_type, job_id, message, 'completed'),
        )

    async def notify_scan_failed(self, job_type: str, job_id: int, error_message: str):
        """Notify that a scan has failed"""
        await self.broadcast_scan_status(
            'scan_failed',
            self._scan_payload(
                job_type, job_id, f"Scan failed: {error_message}", 'failed'
            ),
        )


# Global WebSocket manager instance
websocket_manager = WebSocketManager()