307 lines
10 KiB
Python
307 lines
10 KiB
Python
import json
|
|
import subprocess
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional
|
|
|
|
import requests
|
|
from models import Image, Vulnerability, VulnerabilitySeverity
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
class VulnerabilityScanner:
    """Find vulnerabilities in container images and persist them as Vulnerability rows.

    Uses the Trivy CLI when ``trivy --version`` succeeds on this host; otherwise
    falls back to a keyword search of the NVD CVE API on ``image.image_name``.
    The NVD fallback matches any CVE whose text mentions the name, regardless of
    image version, so its results are approximate.

    Every created or re-activated row is committed immediately. A failed DB
    operation rolls the session back before re-raising, so one bad row cannot
    leave the shared Session unusable for later images.
    """

    # Severities passed through unchanged by _normalize_severity.
    _KNOWN_SEVERITIES = ("critical", "high", "medium", "low")

    # NVD metric keys, checked in order. V4.0 is only a last resort: it supplies a
    # score for CVEs carrying no older CVSS metric without changing which metric
    # wins for CVEs that do.
    _NVD_METRIC_KEYS = ("cvssMetricV31", "cvssMetricV30", "cvssMetricV2", "cvssMetricV40")

    def __init__(self):
        self.trivy_available = self._check_trivy_available()
        self.cve_api_base = "https://services.nvd.nist.gov/rest/json/cves/2.0"

    def _check_trivy_available(self) -> bool:
        """Return True if ``trivy --version`` runs and exits 0 within 10 seconds."""
        try:
            result = subprocess.run(
                ["trivy", "--version"], capture_output=True, text=True, timeout=10
            )
        except (subprocess.TimeoutExpired, OSError):
            # OSError covers a missing binary (FileNotFoundError) as well as
            # permission problems, both of which mean "not usable".
            return False
        return result.returncode == 0

    def scan_image_vulnerabilities(
        self, db: Session, image: Image, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Scan ``image`` with Trivy if available, else the NVD API.

        Returns the Vulnerability rows created or re-activated by this scan.
        """
        scanner = self._scan_with_trivy if self.trivy_available else self._scan_with_api
        return list(scanner(db, image, scan_job_id))

    def _scan_with_trivy(self, db: Session, image: Image, scan_job_id: Optional[int] = None) -> List[Vulnerability]:
        """Run ``trivy image`` on the image and persist its findings.

        Best-effort: any failure (non-zero exit, timeout, bad JSON, DB error) is
        printed and yields an empty list instead of raising.
        """
        image_name = image.full_image_name
        cmd = [
            "trivy",
            "image",
            "--format",
            "json",
            "--no-progress",
            "--quiet",
            "--scanners",
            "vuln",
            image_name,
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode != 0:
                print(f"Trivy scan failed for {image_name}: {result.stderr}")
                return []
            data = json.loads(result.stdout)
            return self._parse_trivy_results(db, image, data, scan_job_id)
        except Exception as e:
            # Deliberate boundary: this method never raised before. DB failures
            # were already rolled back by the persistence helpers.
            print(f"Error scanning {image_name} with Trivy: {e}")
            return []

    def _parse_trivy_results(
        self, db: Session, image: Image, data: Dict, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Persist every distinct vulnerability found in Trivy JSON output ``data``."""
        vulnerabilities = []
        # Trivy repeats a CVE for each affected package/target; keep one per ID so
        # the returned list (and any count derived from it) is not inflated.
        seen_ids = set()
        for result in data.get("Results") or []:
            for vuln in result.get("Vulnerabilities") or []:
                vuln_id = vuln.get("VulnerabilityID")
                if vuln_id in seen_ids:
                    continue
                vulnerability = self._create_vulnerability_from_trivy(db, image, vuln, scan_job_id)
                if vulnerability is not None:
                    seen_ids.add(vuln_id)
                    vulnerabilities.append(vulnerability)
        return vulnerabilities

    def _create_vulnerability_from_trivy(
        self, db: Session, image: Image, vuln_data: Dict, scan_job_id: Optional[int] = None
    ) -> Optional[Vulnerability]:
        """Re-activate or create the row for one Trivy vulnerability entry.

        Returns None when the entry has no ``VulnerabilityID``.
        """
        vulnerability_id = vuln_data.get("VulnerabilityID")
        if not vulnerability_id:
            return None

        existing = self._reactivate_existing(db, image, vulnerability_id, scan_job_id)
        if existing is not None:
            return existing

        vulnerability = Vulnerability(
            image_id=image.id,
            scan_job_id=scan_job_id,
            vulnerability_id=vulnerability_id,
            severity=self._normalize_severity(vuln_data.get("Severity", "UNKNOWN")),
            title=vuln_data.get("Title", ""),
            description=vuln_data.get("Description", ""),
            cvss_score=self._extract_trivy_cvss_score(vuln_data.get("CVSS")),
            published_date=self._parse_timestamp(vuln_data.get("PublishedDate")),
            fixed_version=vuln_data.get("FixedVersion"),
            scan_date=datetime.utcnow(),
        )
        return self._save_new(db, vulnerability)

    @staticmethod
    def _extract_trivy_cvss_score(cvss_data: object) -> Optional[str]:
        """Pick a CVSS score (as a string) from Trivy's per-source ``CVSS`` mapping.

        The first source with a V3 score wins. Otherwise the V2 score of the
        last source that has one is used. Returns None if no score is present.
        """
        if not isinstance(cvss_data, dict):
            return None
        v2_score = None
        for source_scores in cvss_data.values():
            if not isinstance(source_scores, dict):
                continue
            if "V3Score" in source_scores:
                return str(source_scores["V3Score"])
            if "V2Score" in source_scores:
                v2_score = str(source_scores["V2Score"])
        return v2_score

    def _scan_with_api(self, db: Session, image: Image, scan_job_id: Optional[int] = None) -> List[Vulnerability]:
        """Query the NVD API by image name and persist the CVEs it returns.

        HTTP and decoding failures are printed and yield an empty list.
        """
        try:
            response = requests.get(
                self.cve_api_base,
                params={
                    "keywordSearch": image.image_name,
                    "resultsPerPage": 50,
                },
                timeout=30,
            )
        except requests.RequestException as e:
            print(f"Error scanning {image.full_image_name} with API: {e}")
            return []

        if response.status_code != 200:
            print(
                f"API scan failed for {image.full_image_name}: {response.status_code}"
            )
            return []

        try:
            data = response.json()
        except ValueError as e:
            # Older `requests` releases raise a bare ValueError on invalid JSON.
            print(f"Error scanning {image.full_image_name} with API: {e}")
            return []

        return self._parse_api_results(db, image, data, scan_job_id)

    def _parse_api_results(
        self, db: Session, image: Image, data: Dict, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Persist every CVE listed under ``vulnerabilities`` in an NVD response."""
        vulnerabilities = []
        for cve in data.get("vulnerabilities") or []:
            vulnerability = self._create_vulnerability_from_api(
                db, image, cve.get("cve", {}), scan_job_id
            )
            if vulnerability is not None:
                vulnerabilities.append(vulnerability)
        return vulnerabilities

    def _create_vulnerability_from_api(
        self, db: Session, image: Image, cve_data: Dict, scan_job_id: Optional[int] = None
    ) -> Optional[Vulnerability]:
        """Re-activate or create the row for one NVD ``cve`` object.

        Returns None when the object has no ``id``.
        """
        vulnerability_id = cve_data.get("id")
        if not vulnerability_id:
            return None

        existing = self._reactivate_existing(db, image, vulnerability_id, scan_job_id)
        if existing is not None:
            return existing

        # The first English description; the title is that text cut to 100 chars.
        description = next(
            (
                entry.get("value") or ""
                for entry in cve_data.get("descriptions") or []
                if entry.get("lang") == "en"
            ),
            "",
        )
        title = description[:100] + "..." if len(description) > 100 else description

        cvss_score = None
        severity = "unspecified"
        metrics = cve_data.get("metrics") or {}
        for metric_type in self._NVD_METRIC_KEYS:
            entries = metrics.get(metric_type)
            if not isinstance(entries, list) or not entries:
                continue
            # Prefer NVD's own ("Primary") assessment over a CNA's ("Secondary").
            entry = next(
                (e for e in entries if isinstance(e, dict) and e.get("type") == "Primary"),
                entries[0],
            )
            base_score = (entry.get("cvssData") or {}).get("baseScore")
            if base_score is not None:
                cvss_score = str(base_score)
                severity = self._cvss_score_to_severity(base_score)
                break

        vulnerability = Vulnerability(
            image_id=image.id,
            scan_job_id=scan_job_id,
            vulnerability_id=vulnerability_id,
            severity=severity,
            title=title,
            description=description,
            cvss_score=cvss_score,
            published_date=self._parse_timestamp(cve_data.get("published")),
            scan_date=datetime.utcnow(),
        )
        return self._save_new(db, vulnerability)

    def _reactivate_existing(
        self, db: Session, image: Image, vulnerability_id: str, scan_job_id: Optional[int]
    ) -> Optional[Vulnerability]:
        """Mark an already-stored finding for this image active again and commit.

        Only ``is_active``, ``scan_date`` and (when given) ``scan_job_id`` are
        updated. Returns None if the image has no row for ``vulnerability_id``.
        Rolls the session back and re-raises if the DB work fails.
        """
        try:
            existing = (
                db.query(Vulnerability)
                .filter(
                    Vulnerability.image_id == image.id,
                    Vulnerability.vulnerability_id == vulnerability_id,
                )
                .first()
            )
            if existing is None:
                return None
            existing.is_active = True
            existing.scan_date = datetime.utcnow()
            if scan_job_id is not None:
                existing.scan_job_id = scan_job_id
            db.commit()
        except Exception:
            db.rollback()
            raise
        return existing

    def _save_new(self, db: Session, vulnerability: Vulnerability) -> Vulnerability:
        """Insert and commit ``vulnerability``, then reload it from the DB.

        Rolls the session back and re-raises if the insert/commit fails.
        """
        try:
            db.add(vulnerability)
            db.commit()
        except Exception:
            db.rollback()
            raise
        db.refresh(vulnerability)
        return vulnerability

    @staticmethod
    def _parse_timestamp(value: object) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp string, or return None if it cannot be parsed.

        A ``Z`` suffix becomes ``+00:00``, and fractional seconds are padded or
        truncated to six digits. ``datetime.fromisoformat`` before Python 3.11
        accepts neither form (Go emits e.g. ``.14Z``). Non-string input yields None.
        """
        if not isinstance(value, str):
            return None
        import re

        text = value.replace("Z", "+00:00")
        text = re.sub(
            r"\.(\d+)",
            lambda match: "." + match.group(1)[:6].ljust(6, "0"),
            text,
            count=1,
        )
        try:
            return datetime.fromisoformat(text)
        except ValueError:
            return None

    def _normalize_severity(self, severity: str) -> str:
        """Lower-case a scanner severity; anything outside critical/high/medium/low becomes "unspecified"."""
        normalized = str(severity or "").lower()
        return normalized if normalized in self._KNOWN_SEVERITIES else "unspecified"

    def _cvss_score_to_severity(self, score: float) -> str:
        """Map a CVSS base score to a severity bucket (scores below 0.1 → "unspecified")."""
        if score >= 9.0:
            return "critical"
        if score >= 7.0:
            return "high"
        if score >= 4.0:
            return "medium"
        if score >= 0.1:
            return "low"
        return "unspecified"

    def scan_all_images(self, db: Session, scan_job_id: Optional[int] = None) -> None:
        """Scan every active image, printing progress.

        A failure on one image is printed, and the session is rolled back so the
        remaining images can still be scanned.
        """
        images = db.query(Image).filter(Image.is_active == True).all()  # noqa: E712
        # Capture names up front: commits expire loaded objects, and the error
        # message must not need a DB round-trip after a failure.
        targets = [(image, image.full_image_name) for image in images]

        for image, image_name in targets:
            try:
                print(f"Scanning vulnerabilities for {image_name}")
                vulnerabilities = self.scan_image_vulnerabilities(db, image, scan_job_id)
                print(f"Found {len(vulnerabilities)} vulnerabilities")
            except Exception as e:
                db.rollback()
                print(f"Error scanning {image_name}: {e}")

    def cleanup_old_vulnerabilities(self, db: Session) -> None:
        """Deactivate active findings whose ``scan_date`` is before today's UTC midnight.

        Rolls the session back and re-raises if the update or commit fails.
        """
        cutoff = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
        try:
            db.query(Vulnerability).filter(
                Vulnerability.is_active == True,  # noqa: E712
                Vulnerability.scan_date < cutoff,
            ).update({Vulnerability.is_active: False})
            db.commit()
        except Exception:
            db.rollback()
            raise