Add vulnerability scanner and WebSocket manager for scan notifications
- Implemented VulnerabilityScanner class to scan images for vulnerabilities using Trivy and NVD API. - Added methods to parse and store vulnerability data in the database. - Created WebSocketManager class to handle real-time notifications for scan status updates. - Integrated WebSocket notifications for scan start, completion, and failure events.
This commit is contained in:
300
vulnerability_scanner.py
Normal file
300
vulnerability_scanner.py
Normal file
@@ -0,0 +1,300 @@
|
||||
import json
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from models import Image, Vulnerability, VulnerabilitySeverity
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class VulnerabilityScanner:
|
||||
def __init__(self):
|
||||
self.trivy_available = self._check_trivy_available()
|
||||
self.cve_api_base = "https://services.nvd.nist.gov/rest/json/cves/2.0"
|
||||
|
||||
def _check_trivy_available(self) -> bool:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["trivy", "--version"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return False
|
||||
|
||||
def scan_image_vulnerabilities(
|
||||
self, db: Session, image: Image
|
||||
) -> List[Vulnerability]:
|
||||
vulnerabilities = []
|
||||
|
||||
if self.trivy_available:
|
||||
vulnerabilities.extend(self._scan_with_trivy(db, image))
|
||||
else:
|
||||
vulnerabilities.extend(self._scan_with_api(db, image))
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _scan_with_trivy(self, db: Session, image: Image) -> List[Vulnerability]:
|
||||
vulnerabilities = []
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
"trivy",
|
||||
"image",
|
||||
"--format",
|
||||
"json",
|
||||
"--no-progress",
|
||||
"--quiet",
|
||||
"--scanners",
|
||||
"vuln",
|
||||
image.full_image_name,
|
||||
]
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
|
||||
|
||||
if result.returncode == 0:
|
||||
data = json.loads(result.stdout)
|
||||
vulnerabilities = self._parse_trivy_results(db, image, data)
|
||||
else:
|
||||
print(f"Trivy scan failed for {image.full_image_name}: {result.stderr}")
|
||||
|
||||
except (subprocess.TimeoutExpired, json.JSONDecodeError, Exception) as e:
|
||||
print(f"Error scanning {image.full_image_name} with Trivy: {e}")
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _parse_trivy_results(
|
||||
self, db: Session, image: Image, data: Dict
|
||||
) -> List[Vulnerability]:
|
||||
vulnerabilities = []
|
||||
|
||||
results = data.get("Results", [])
|
||||
for result in results:
|
||||
vulns = result.get("Vulnerabilities", [])
|
||||
for vuln in vulns:
|
||||
vulnerability = self._create_vulnerability_from_trivy(db, image, vuln)
|
||||
if vulnerability:
|
||||
vulnerabilities.append(vulnerability)
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _create_vulnerability_from_trivy(
|
||||
self, db: Session, image: Image, vuln_data: Dict
|
||||
) -> Optional[Vulnerability]:
|
||||
vulnerability_id = vuln_data.get("VulnerabilityID")
|
||||
if not vulnerability_id:
|
||||
return None
|
||||
|
||||
existing = (
|
||||
db.query(Vulnerability)
|
||||
.filter(
|
||||
Vulnerability.image_id == image.id,
|
||||
Vulnerability.vulnerability_id == vulnerability_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.is_active = True
|
||||
existing.scan_date = datetime.utcnow()
|
||||
db.commit()
|
||||
return existing
|
||||
|
||||
severity = self._normalize_severity(vuln_data.get("Severity", "UNKNOWN"))
|
||||
title = vuln_data.get("Title", "")
|
||||
description = vuln_data.get("Description", "")
|
||||
|
||||
cvss_score = None
|
||||
if "CVSS" in vuln_data:
|
||||
cvss_data = vuln_data["CVSS"]
|
||||
if isinstance(cvss_data, dict):
|
||||
for version, score_data in cvss_data.items():
|
||||
if "V3Score" in score_data:
|
||||
cvss_score = str(score_data["V3Score"])
|
||||
break
|
||||
elif "V2Score" in score_data:
|
||||
cvss_score = str(score_data["V2Score"])
|
||||
|
||||
published_date = None
|
||||
if "PublishedDate" in vuln_data:
|
||||
try:
|
||||
published_date = datetime.fromisoformat(
|
||||
vuln_data["PublishedDate"].replace("Z", "+00:00")
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
fixed_version = vuln_data.get("FixedVersion")
|
||||
|
||||
vulnerability = Vulnerability(
|
||||
image_id=image.id,
|
||||
vulnerability_id=vulnerability_id,
|
||||
severity=severity,
|
||||
title=title,
|
||||
description=description,
|
||||
cvss_score=cvss_score,
|
||||
published_date=published_date,
|
||||
fixed_version=fixed_version,
|
||||
scan_date=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(vulnerability)
|
||||
db.commit()
|
||||
db.refresh(vulnerability)
|
||||
|
||||
return vulnerability
|
||||
|
||||
def _scan_with_api(self, db: Session, image: Image) -> List[Vulnerability]:
|
||||
vulnerabilities = []
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.cve_api_base}",
|
||||
params={
|
||||
"keywordSearch": image.image_name,
|
||||
"resultsPerPage": 50,
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
vulnerabilities = self._parse_api_results(db, image, data)
|
||||
else:
|
||||
print(
|
||||
f"API scan failed for {image.full_image_name}: {response.status_code}"
|
||||
)
|
||||
|
||||
except requests.RequestException as e:
|
||||
print(f"Error scanning {image.full_image_name} with API: {e}")
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _parse_api_results(
|
||||
self, db: Session, image: Image, data: Dict
|
||||
) -> List[Vulnerability]:
|
||||
vulnerabilities = []
|
||||
|
||||
cves = data.get("vulnerabilities", [])
|
||||
for cve in cves:
|
||||
cve_data = cve.get("cve", {})
|
||||
vulnerability = self._create_vulnerability_from_api(db, image, cve_data)
|
||||
if vulnerability:
|
||||
vulnerabilities.append(vulnerability)
|
||||
|
||||
return vulnerabilities
|
||||
|
||||
def _create_vulnerability_from_api(
|
||||
self, db: Session, image: Image, cve_data: Dict
|
||||
) -> Optional[Vulnerability]:
|
||||
vulnerability_id = cve_data.get("id")
|
||||
if not vulnerability_id:
|
||||
return None
|
||||
|
||||
existing = (
|
||||
db.query(Vulnerability)
|
||||
.filter(
|
||||
Vulnerability.image_id == image.id,
|
||||
Vulnerability.vulnerability_id == vulnerability_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
existing.is_active = True
|
||||
existing.scan_date = datetime.utcnow()
|
||||
db.commit()
|
||||
return existing
|
||||
|
||||
descriptions = cve_data.get("descriptions", [])
|
||||
title = ""
|
||||
description = ""
|
||||
|
||||
for desc in descriptions:
|
||||
if desc.get("lang") == "en":
|
||||
description = desc.get("value", "")
|
||||
title = (
|
||||
description[:100] + "..." if len(description) > 100 else description
|
||||
)
|
||||
break
|
||||
|
||||
metrics = cve_data.get("metrics", {})
|
||||
cvss_score = None
|
||||
severity = "unspecified"
|
||||
|
||||
for metric_type in ["cvssMetricV31", "cvssMetricV30", "cvssMetricV2"]:
|
||||
if metric_type in metrics:
|
||||
metric_data = metrics[metric_type]
|
||||
if isinstance(metric_data, list) and len(metric_data) > 0:
|
||||
cvss_data = metric_data[0].get("cvssData", {})
|
||||
if "baseScore" in cvss_data:
|
||||
cvss_score = str(cvss_data["baseScore"])
|
||||
severity = self._cvss_score_to_severity(cvss_data["baseScore"])
|
||||
break
|
||||
|
||||
published_date = None
|
||||
if "published" in cve_data:
|
||||
try:
|
||||
published_date = datetime.fromisoformat(
|
||||
cve_data["published"].replace("Z", "+00:00")
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
vulnerability = Vulnerability(
|
||||
image_id=image.id,
|
||||
vulnerability_id=vulnerability_id,
|
||||
severity=severity,
|
||||
title=title,
|
||||
description=description,
|
||||
cvss_score=cvss_score,
|
||||
published_date=published_date,
|
||||
scan_date=datetime.utcnow(),
|
||||
)
|
||||
|
||||
db.add(vulnerability)
|
||||
db.commit()
|
||||
db.refresh(vulnerability)
|
||||
|
||||
return vulnerability
|
||||
|
||||
def _normalize_severity(self, severity: str) -> str:
|
||||
severity_lower = severity.lower()
|
||||
|
||||
if severity_lower in ["critical", "high", "medium", "low"]:
|
||||
return severity_lower
|
||||
elif severity_lower in ["unknown", "negligible", "unspecified"]:
|
||||
return "unspecified"
|
||||
else:
|
||||
return "unspecified"
|
||||
|
||||
def _cvss_score_to_severity(self, score: float) -> str:
|
||||
if score >= 9.0:
|
||||
return "critical"
|
||||
elif score >= 7.0:
|
||||
return "high"
|
||||
elif score >= 4.0:
|
||||
return "medium"
|
||||
elif score >= 0.1:
|
||||
return "low"
|
||||
else:
|
||||
return "unspecified"
|
||||
|
||||
def scan_all_images(self, db: Session) -> None:
|
||||
images = db.query(Image).filter(Image.is_active == True).all()
|
||||
|
||||
for image in images:
|
||||
try:
|
||||
print(f"Scanning vulnerabilities for {image.full_image_name}")
|
||||
vulnerabilities = self.scan_image_vulnerabilities(db, image)
|
||||
print(f"Found {len(vulnerabilities)} vulnerabilities")
|
||||
except Exception as e:
|
||||
print(f"Error scanning {image.full_image_name}: {e}")
|
||||
continue
|
||||
|
||||
def cleanup_old_vulnerabilities(self, db: Session) -> None:
|
||||
db.query(Vulnerability).filter(
|
||||
Vulnerability.scan_date
|
||||
< datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
).update({Vulnerability.is_active: False})
|
||||
db.commit()
|
||||
Reference in New Issue
Block a user