"""Container-image vulnerability scanning via the Trivy CLI, with an NVD API fallback."""

import json
import subprocess
from datetime import datetime
from typing import Dict, List, Optional

import requests
from sqlalchemy.orm import Session

from models import Image, Vulnerability, VulnerabilitySeverity


class VulnerabilityScanner:
    """Scan images for vulnerabilities and persist findings as ``Vulnerability`` rows.

    The ``trivy`` CLI is used when ``trivy --version`` succeeds at construction
    time; otherwise a keyword search against the NVD CVE API 2.0 is used.
    """

    def __init__(self):
        self.trivy_available = self._check_trivy_available()
        self.cve_api_base = "https://services.nvd.nist.gov/rest/json/cves/2.0"

    def _check_trivy_available(self) -> bool:
        """Return True if ``trivy --version`` runs and exits with status 0."""
        try:
            result = subprocess.run(
                ["trivy", "--version"], capture_output=True, text=True, timeout=10
            )
        # OSError covers FileNotFoundError (binary missing) and also e.g.
        # PermissionError (binary not executable), which would otherwise
        # escape and break __init__.
        except (subprocess.TimeoutExpired, OSError):
            return False
        return result.returncode == 0

    def scan_image_vulnerabilities(
        self, db: Session, image: Image, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Scan *image* and return the vulnerability rows created or refreshed.

        Args:
            db: Session used to query, insert and commit ``Vulnerability`` rows.
            image: Image to scan. Trivy scans ``image.full_image_name``; the API
                fallback searches NVD for ``image.image_name``.
            scan_job_id: Optional job id recorded on created/refreshed rows.

        Returns:
            The persisted rows (possibly empty).
        """
        if self.trivy_available:
            return self._scan_with_trivy(db, image, scan_job_id)
        return self._scan_with_api(db, image, scan_job_id)

    def _scan_with_trivy(
        self, db: Session, image: Image, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Run ``trivy image`` on *image* and persist the reported findings.

        Best-effort: any failure is printed and an empty list is returned.
        """
        try:
            cmd = [
                "trivy",
                "image",
                "--format",
                "json",
                "--no-progress",
                "--quiet",
                "--scanners",
                "vuln",
                image.full_image_name,
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode != 0:
                print(f"Trivy scan failed for {image.full_image_name}: {result.stderr}")
                return []
            data = json.loads(result.stdout)
        except Exception as e:  # best-effort: a failed scan must not abort the caller
            print(f"Error scanning {image.full_image_name} with Trivy: {e}")
            return []

        try:
            return self._parse_trivy_results(db, image, data, scan_job_id)
        except Exception as e:
            # A failed flush/commit leaves the Session unusable until it is
            # rolled back; without this every later query on ``db`` raises.
            db.rollback()
            print(f"Error scanning {image.full_image_name} with Trivy: {e}")
            return []

    def _parse_trivy_results(
        self, db: Session, image: Image, data: Dict, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Persist every entry of ``data["Results"][*]["Vulnerabilities"]``.

        Missing or null ``Results``/``Vulnerabilities`` are treated as empty.
        """
        vulnerabilities = []
        for result in data.get("Results") or []:
            for vuln in result.get("Vulnerabilities") or []:
                vulnerability = self._create_vulnerability_from_trivy(
                    db, image, vuln, scan_job_id
                )
                if vulnerability:
                    vulnerabilities.append(vulnerability)
        return vulnerabilities

    def _refresh_existing(
        self,
        db: Session,
        image: Image,
        vulnerability_id: str,
        scan_job_id: Optional[int],
    ) -> Optional[Vulnerability]:
        """Re-activate and commit the stored row for (image, vulnerability_id).

        Returns the refreshed row, or None if no such row exists yet.
        """
        existing = (
            db.query(Vulnerability)
            .filter(
                Vulnerability.image_id == image.id,
                Vulnerability.vulnerability_id == vulnerability_id,
            )
            .first()
        )
        if existing is None:
            return None
        existing.is_active = True
        existing.scan_date = datetime.utcnow()
        if scan_job_id:
            existing.scan_job_id = scan_job_id
        db.commit()
        return existing

    def _persist(self, db: Session, vulnerability: Vulnerability) -> Vulnerability:
        """Add, commit and refresh a new row.

        Each finding is committed individually, so rows persisted before a
        mid-scan failure are kept.
        """
        db.add(vulnerability)
        db.commit()
        db.refresh(vulnerability)
        return vulnerability

    @staticmethod
    def _parse_iso_datetime(value: object) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp (a trailing ``Z`` is accepted).

        Returns None for missing, non-string or unparseable values instead of
        raising, so one malformed date cannot abort a whole image scan.
        """
        if not isinstance(value, str):
            return None
        try:
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError:
            return None

    @staticmethod
    def _extract_trivy_cvss(cvss_data: object) -> Optional[str]:
        """Pick a CVSS score (as a string) from Trivy's per-source ``CVSS`` mapping.

        The first source with a ``V3Score`` wins immediately. Otherwise the
        ``V2Score`` of the last source that has one is used. Non-dict entries
        are skipped.
        """
        if not isinstance(cvss_data, dict):
            return None
        v2_score = None
        for score_data in cvss_data.values():
            if not isinstance(score_data, dict):
                continue
            if "V3Score" in score_data:
                return str(score_data["V3Score"])
            if "V2Score" in score_data:
                v2_score = str(score_data["V2Score"])
        return v2_score

    def _create_vulnerability_from_trivy(
        self, db: Session, image: Image, vuln_data: Dict, scan_job_id: Optional[int] = None
    ) -> Optional[Vulnerability]:
        """Create (or refresh) the row for one Trivy vulnerability entry.

        Returns None when the entry has no ``VulnerabilityID``.
        """
        vulnerability_id = vuln_data.get("VulnerabilityID")
        if not vulnerability_id:
            return None

        # NOTE(review): rows are keyed on (image_id, vulnerability_id); if Trivy
        # reports the same ID for several packages, later occurrences refresh
        # the first row and the same object appears again in the results.
        existing = self._refresh_existing(db, image, vulnerability_id, scan_job_id)
        if existing is not None:
            return existing

        vulnerability = Vulnerability(
            image_id=image.id,
            scan_job_id=scan_job_id,
            vulnerability_id=vulnerability_id,
            severity=self._normalize_severity(vuln_data.get("Severity", "UNKNOWN")),
            title=vuln_data.get("Title", ""),
            description=vuln_data.get("Description", ""),
            cvss_score=self._extract_trivy_cvss(vuln_data.get("CVSS")),
            published_date=self._parse_iso_datetime(vuln_data.get("PublishedDate")),
            fixed_version=vuln_data.get("FixedVersion"),
            scan_date=datetime.utcnow(),
        )
        return self._persist(db, vulnerability)

    def _scan_with_api(
        self, db: Session, image: Image, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Fallback scan: keyword-search NVD for ``image.image_name``.

        HTTP and invalid-JSON failures are printed and yield an empty list.
        Errors raised while persisting results propagate to the caller.
        NOTE(review): a keyword search matches any CVE mentioning the name; it
        is not specific to the image's contents or version.
        """
        try:
            response = requests.get(
                self.cve_api_base,
                params={
                    "keywordSearch": image.image_name,
                    "resultsPerPage": 50,
                },
                timeout=30,
            )
            if response.status_code != 200:
                print(
                    f"API scan failed for {image.full_image_name}: {response.status_code}"
                )
                return []
            # Older ``requests`` versions raise a plain ValueError on invalid JSON.
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            print(f"Error scanning {image.full_image_name} with API: {e}")
            return []
        return self._parse_api_results(db, image, data, scan_job_id)

    def _parse_api_results(
        self, db: Session, image: Image, data: Dict, scan_job_id: Optional[int] = None
    ) -> List[Vulnerability]:
        """Persist every ``data["vulnerabilities"][*]["cve"]`` entry."""
        vulnerabilities = []
        for cve in data.get("vulnerabilities") or []:
            vulnerability = self._create_vulnerability_from_api(
                db, image, cve.get("cve", {}), scan_job_id
            )
            if vulnerability:
                vulnerabilities.append(vulnerability)
        return vulnerabilities

    def _create_vulnerability_from_api(
        self, db: Session, image: Image, cve_data: Dict, scan_job_id: Optional[int] = None
    ) -> Optional[Vulnerability]:
        """Create (or refresh) the row for one NVD CVE record.

        Returns None when the record has no ``id``.
        """
        vulnerability_id = cve_data.get("id")
        if not vulnerability_id:
            return None

        existing = self._refresh_existing(db, image, vulnerability_id, scan_job_id)
        if existing is not None:
            return existing

        # The first English description is used; the title is that text,
        # truncated to 100 characters plus an ellipsis.
        description = next(
            (
                desc.get("value", "") or ""
                for desc in cve_data.get("descriptions", [])
                if desc.get("lang") == "en"
            ),
            "",
        )
        title = description[:100] + "..." if len(description) > 100 else description

        # Prefer the newest CVSS version that actually carries a base score.
        cvss_score = None
        severity = "unspecified"
        metrics = cve_data.get("metrics", {})
        for metric_type in ("cvssMetricV31", "cvssMetricV30", "cvssMetricV2"):
            metric_data = metrics.get(metric_type)
            if isinstance(metric_data, list) and metric_data:
                cvss_data = metric_data[0].get("cvssData", {})
                if "baseScore" in cvss_data:
                    cvss_score = str(cvss_data["baseScore"])
                    severity = self._cvss_score_to_severity(cvss_data["baseScore"])
                    break

        vulnerability = Vulnerability(
            image_id=image.id,
            scan_job_id=scan_job_id,
            vulnerability_id=vulnerability_id,
            severity=severity,
            title=title,
            description=description,
            cvss_score=cvss_score,
            published_date=self._parse_iso_datetime(cve_data.get("published")),
            scan_date=datetime.utcnow(),
        )
        return self._persist(db, vulnerability)

    def _normalize_severity(self, severity: Optional[str]) -> str:
        """Map a scanner severity label to critical/high/medium/low/unspecified.

        Any other value, including None or non-strings, maps to "unspecified".
        """
        if not isinstance(severity, str):
            return "unspecified"
        severity_lower = severity.lower()
        if severity_lower in ("critical", "high", "medium", "low"):
            return severity_lower
        return "unspecified"

    def _cvss_score_to_severity(self, score: float) -> str:
        """Bucket a CVSS base score using the CVSS v3 rating thresholds.

        NOTE(review): the same thresholds are applied to v2 scores, so a v2
        score >= 9.0 is labelled "critical".
        """
        if score >= 9.0:
            return "critical"
        elif score >= 7.0:
            return "high"
        elif score >= 4.0:
            return "medium"
        elif score >= 0.1:
            return "low"
        else:
            return "unspecified"

    def scan_all_images(self, db: Session, scan_job_id: Optional[int] = None) -> None:
        """Scan every active image, printing progress and per-image errors.

        A failure for one image is rolled back and reported, and the loop
        continues with the next image.
        """
        # ``== True`` is required: SQLAlchemy builds the SQL comparison from it.
        images = db.query(Image).filter(Image.is_active == True).all()  # noqa: E712
        for image in images:
            try:
                print(f"Scanning vulnerabilities for {image.full_image_name}")
                vulnerabilities = self.scan_image_vulnerabilities(db, image, scan_job_id)
                print(f"Found {len(vulnerabilities)} vulnerabilities")
            except Exception as e:
                # Restore the Session first. Otherwise the attribute access
                # below (instances are expired by the per-row commits) and every
                # later image would fail on the poisoned transaction.
                db.rollback()
                print(f"Error scanning {image.full_image_name}: {e}")

    def cleanup_old_vulnerabilities(self, db: Session) -> None:
        """Mark inactive every row whose ``scan_date`` is before today's 00:00.

        This applies to rows for all images. The cutoff is computed from
        ``datetime.utcnow()``, matching how this class writes ``scan_date``.
        NOTE(review): assumes ``scan_date`` is stored as naive UTC; confirm.
        """
        start_of_today = datetime.utcnow().replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        db.query(Vulnerability).filter(
            Vulnerability.scan_date < start_of_today
        ).update({Vulnerability.is_active: False})
        db.commit()