diff --git a/main.py b/main.py index cbdb46c..5958272 100644 --- a/main.py +++ b/main.py @@ -122,7 +122,7 @@ def background_vulnerability_scan(job_id: int): websocket_manager.notify_scan_started_sync("vulnerability", job_id, "Vulnerability scan started") vuln_scanner = VulnerabilityScanner() - vuln_scanner.scan_all_images(db) + vuln_scanner.scan_all_images(db, job_id) if job: job.status = "completed" @@ -185,6 +185,7 @@ class ImageResponse(BaseModel): class VulnerabilityResponse(BaseModel): id: int image_id: int + scan_job_id: Optional[int] = None vulnerability_id: str severity: str title: Optional[str] = None diff --git a/models.py b/models.py index d7796bc..615d6f1 100644 --- a/models.py +++ b/models.py @@ -119,6 +119,7 @@ class Vulnerability(Base): id = Column(Integer, primary_key=True) image_id = Column(Integer, ForeignKey("images.id"), nullable=False) + scan_job_id = Column(Integer, ForeignKey("scan_jobs.id"), nullable=True) vulnerability_id = Column(String(100), nullable=False) severity = Column(String(20), nullable=False) title = Column(String(500), nullable=True) @@ -132,6 +133,7 @@ class Vulnerability(Base): updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) image = relationship("Image", back_populates="vulnerabilities") + scan_job = relationship("ScanJob", back_populates="vulnerabilities") __table_args__ = ( UniqueConstraint( @@ -170,6 +172,8 @@ class ScanJob(Base): created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + vulnerabilities = relationship("Vulnerability", back_populates="scan_job", cascade="all, delete-orphan") + DATABASE_URL = "sqlite:///./gitlab_docker_tracker.db" engine = create_engine(DATABASE_URL, echo=False) diff --git a/vulnerability_scanner.py b/vulnerability_scanner.py index 62f92c7..b2c83ff 100644 --- a/vulnerability_scanner.py +++ b/vulnerability_scanner.py @@ -23,18 +23,18 @@ class VulnerabilityScanner: return False def scan_image_vulnerabilities( - self, db: Session, image: Image + self, db: Session, image: Image, scan_job_id: Optional[int] = None ) -> List[Vulnerability]: vulnerabilities = [] if self.trivy_available: - vulnerabilities.extend(self._scan_with_trivy(db, image)) + vulnerabilities.extend(self._scan_with_trivy(db, image, scan_job_id)) else: - vulnerabilities.extend(self._scan_with_api(db, image)) + vulnerabilities.extend(self._scan_with_api(db, image, scan_job_id)) return vulnerabilities - def _scan_with_trivy(self, db: Session, image: Image) -> List[Vulnerability]: + def _scan_with_trivy(self, db: Session, image: Image, scan_job_id: Optional[int] = None) -> List[Vulnerability]: vulnerabilities = [] try: @@ -54,7 +54,7 @@ class VulnerabilityScanner: if result.returncode == 0: data = json.loads(result.stdout) - vulnerabilities = self._parse_trivy_results(db, image, data) + vulnerabilities = self._parse_trivy_results(db, image, data, scan_job_id) else: print(f"Trivy scan failed for {image.full_image_name}: {result.stderr}") @@ -64,7 +64,7 @@ class VulnerabilityScanner: return vulnerabilities def _parse_trivy_results( - self, db: Session, image: Image, data: Dict + self, db: Session, image: Image, data: Dict, scan_job_id: Optional[int] = None ) -> List[Vulnerability]: vulnerabilities = [] @@ -72,14 +72,14 @@ class VulnerabilityScanner: for result in results: vulns = result.get("Vulnerabilities", []) for vuln in vulns: - vulnerability = self._create_vulnerability_from_trivy(db, image, vuln) + vulnerability = self._create_vulnerability_from_trivy(db, image, vuln, scan_job_id) if vulnerability: vulnerabilities.append(vulnerability) return vulnerabilities def _create_vulnerability_from_trivy( - self, db: Session, image: Image, vuln_data: Dict + self, db: Session, image: Image, vuln_data: Dict, scan_job_id: Optional[int] = None ) -> Optional[Vulnerability]: vulnerability_id = vuln_data.get("VulnerabilityID") if not vulnerability_id: @@ -97,6 +97,8 @@ class VulnerabilityScanner: if existing: existing.is_active = True existing.scan_date = datetime.utcnow() + if scan_job_id: + existing.scan_job_id = scan_job_id db.commit() return existing @@ -128,6 +130,7 @@ class VulnerabilityScanner: vulnerability = Vulnerability( image_id=image.id, + scan_job_id=scan_job_id, vulnerability_id=vulnerability_id, severity=severity, title=title, @@ -144,7 +147,7 @@ class VulnerabilityScanner: return vulnerability - def _scan_with_api(self, db: Session, image: Image) -> List[Vulnerability]: + def _scan_with_api(self, db: Session, image: Image, scan_job_id: Optional[int] = None) -> List[Vulnerability]: vulnerabilities = [] try: @@ -159,7 +162,7 @@ class VulnerabilityScanner: if response.status_code == 200: data = response.json() - vulnerabilities = self._parse_api_results(db, image, data) + vulnerabilities = self._parse_api_results(db, image, data, scan_job_id) else: print( f"API scan failed for {image.full_image_name}: {response.status_code}" @@ -171,21 +174,21 @@ class VulnerabilityScanner: return vulnerabilities def _parse_api_results( - self, db: Session, image: Image, data: Dict + self, db: Session, image: Image, data: Dict, scan_job_id: Optional[int] = None ) -> List[Vulnerability]: vulnerabilities = [] cves = data.get("vulnerabilities", []) for cve in cves: cve_data = cve.get("cve", {}) - vulnerability = self._create_vulnerability_from_api(db, image, cve_data) + vulnerability = self._create_vulnerability_from_api(db, image, cve_data, scan_job_id) if vulnerability: vulnerabilities.append(vulnerability) return vulnerabilities def _create_vulnerability_from_api( - self, db: Session, image: Image, cve_data: Dict + self, db: Session, image: Image, cve_data: Dict, scan_job_id: Optional[int] = None ) -> Optional[Vulnerability]: vulnerability_id = cve_data.get("id") if not vulnerability_id: @@ -203,6 +206,8 @@ class VulnerabilityScanner: if existing: existing.is_active = True existing.scan_date = datetime.utcnow() + if scan_job_id: + existing.scan_job_id = scan_job_id db.commit() return existing @@ -243,6 +248,7 @@ class VulnerabilityScanner: vulnerability = Vulnerability( image_id=image.id, + scan_job_id=scan_job_id, vulnerability_id=vulnerability_id, severity=severity, title=title, @@ -280,13 +286,13 @@ class VulnerabilityScanner: else: return "unspecified" - def scan_all_images(self, db: Session) -> None: + def scan_all_images(self, db: Session, scan_job_id: Optional[int] = None) -> None: images = db.query(Image).filter(Image.is_active == True).all() for image in images: try: print(f"Scanning vulnerabilities for {image.full_image_name}") - vulnerabilities = self.scan_image_vulnerabilities(db, image) + vulnerabilities = self.scan_image_vulnerabilities(db, image, scan_job_id) print(f"Found {len(vulnerabilities)} vulnerabilities") except Exception as e: print(f"Error scanning {image.full_image_name}: {e}")