import asyncio
import os
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import List, Optional

from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ignore_manager import IgnoreManager
from models import (
    File,
    Image,
    IgnoreRule,
    IgnoreType,
    Project,
    ScanJob,
    Vulnerability,
    VulnerabilitySeverity,
    FileImageUsage,
    create_tables,
    get_db,
)
from scanner import DockerImageScanner
from vulnerability_scanner import VulnerabilityScanner
from websocket_manager import websocket_manager

# Load environment variables from a ".env" file next to this module, if present.
env_file = Path(__file__).parent / ".env"
if env_file.exists():
    from dotenv import load_dotenv

    load_dotenv(env_file)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the application starts serving requests.

    Creates the database tables and registers the WebSocket manager's event
    handlers. Nothing is done on shutdown.
    """
    # Startup
    create_tables()
    websocket_manager.setup_events()
    yield
    # Shutdown


app = FastAPI(
    title="GitLab Docker Images Tracker",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Severity labels reported in per-project / per-file vulnerability summaries,
# in the order they appear in the response dictionaries.
_SEVERITY_LEVELS = ("critical", "high", "medium", "low", "unspecified")


# Background task functions
def _run_scan_job(job_id: int, scan_type: str, title: str, run_scan) -> None:
    """Run one scan inside its own DB session and track it on its ScanJob row.

    Args:
        job_id: Primary key of the ScanJob row to update.
        scan_type: Scan kind passed to the WebSocket notifications
            ("discovery" or "vulnerability").
        title: Human-readable name used in messages, e.g. "Project scan".
        run_scan: Callable taking the open session and performing the scan.

    Any exception raised by the scan is caught, recorded on the job (when the
    row exists), broadcast over WebSocket and printed; it is never re-raised
    because this runs as a fire-and-forget background task.
    """
    from models import SessionLocal

    db = SessionLocal()
    job = None
    try:
        # Get the job and update status with started_at timestamp
        job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
        if job:
            job.status = "running"
            job.started_at = datetime.utcnow()
            db.commit()

        websocket_manager.notify_scan_started_sync(scan_type, job_id, f"{title} started")

        run_scan(db)

        # Mark as completed
        if job:
            job.status = "completed"
            job.completed_at = datetime.utcnow()
            db.commit()

        websocket_manager.notify_scan_completed_sync(
            scan_type, job_id, f"{title} completed successfully"
        )
    except Exception as e:
        error_message = str(e)
        try:
            # The scan may have failed in the middle of a flush/commit, which
            # leaves the session unusable until it is rolled back. Without
            # this, the commit below would raise as well and the job would
            # stay "running" forever, blocking every later scan request.
            db.rollback()
            if job:
                job.status = "failed"
                job.error_message = error_message
                job.completed_at = datetime.utcnow()
                db.commit()
        except Exception as status_error:
            # Best effort: still notify and log the original failure below.
            print(f"Could not record failure of scan job {job_id}: {status_error}")

        websocket_manager.notify_scan_failed_sync(scan_type, job_id, error_message)
        print(f"Background {title.lower()} error: {error_message}")
    finally:
        db.close()


def background_project_scan(
    gitlab_token: str, gitlab_url: str, gitlab_groups: List[str], job_id: int
):
    """Background task for project scanning.

    Builds a DockerImageScanner from the GitLab settings and runs
    ``scan_all_projects`` while tracking progress on ScanJob ``job_id``.
    """

    def run_scan(db):
        scanner = DockerImageScanner(gitlab_token, gitlab_url, gitlab_groups)
        scanner.scan_all_projects(db)

    _run_scan_job(job_id, "discovery", "Project scan", run_scan)


def background_vulnerability_scan(job_id: int):
    """Background task for vulnerability scanning.

    Runs ``VulnerabilityScanner.scan_all_images`` while tracking progress on
    ScanJob ``job_id``.
    """

    def run_scan(db):
        vuln_scanner = VulnerabilityScanner()
        vuln_scanner.scan_all_images(db, job_id)

    _run_scan_job(job_id, "vulnerability", "Vulnerability scan", run_scan)


class ProjectResponse(BaseModel):
    """API representation of a project, optionally with vulnerability counts."""

    id: int
    gitlab_id: int
    name: str
    path: str
    web_url: str
    last_scanned: Optional[datetime] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime
    vulnerability_counts: Optional[dict] = None


class FileResponse(BaseModel):
    """API representation of a scanned file, optionally with image/vulnerability counts."""

    id: int
    project_id: int
    file_path: str
    branch: str
    file_type: str
    last_scanned: Optional[datetime] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime
    image_count: Optional[int] = None
    vulnerability_counts: Optional[dict] = None


class ImageResponse(BaseModel):
    """API representation of an image, optionally with usage/vulnerability counts."""

    id: int
    image_name: str
    tag: Optional[str] = None
    registry: Optional[str] = None
    full_image_name: str
    last_seen: datetime
    is_active: bool
    created_at: datetime
    updated_at: datetime
    usage_count: Optional[int] = None
    vulnerability_counts: Optional[dict] = None


class VulnerabilityResponse(BaseModel):
    """API representation of a single vulnerability finding."""

    id: int
    image_id: int
    scan_job_id: Optional[int] = None
    vulnerability_id: str
    severity: str
    title: Optional[str] = None
    description: Optional[str] = None
    cvss_score: Optional[str] = None
    published_date: Optional[datetime] = None
    fixed_version: Optional[str] = None
    scan_date: datetime
    is_active: bool


class IgnoreRuleResponse(BaseModel):
    """API representation of a stored ignore rule."""

    id: int
    project_id: Optional[int] = None
    ignore_type: str
    target: str
    reason: Optional[str] = None
    created_by: Optional[str] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime


class IgnoreRuleCreate(BaseModel):
    """Request body for creating an ignore rule."""

    ignore_type: str
    target: str
    reason: Optional[str] = None
    created_by: Optional[str] = None
    project_id: Optional[int] = None


class DashboardStats(BaseModel):
    """Aggregate counters returned by the dashboard endpoint."""

    total_projects: int
    active_projects: int
    total_images: int
    active_images: int
    total_vulnerabilities: int
    critical_vulnerabilities: int
    high_vulnerabilities: int
    medium_vulnerabilities: int
    low_vulnerabilities: int
    last_scan: Optional[datetime] = None


def _vulnerability_counts_for_images(db: Session, image_ids: List[int]) -> dict:
    """Count active vulnerabilities per severity across the given image ids.

    Returns a dict with one key per entry of ``_SEVERITY_LEVELS`` plus
    ``"total"`` (the sum of the per-severity counts). An empty ``image_ids``
    yields all zeros without touching the database.
    """
    counts = {level: 0 for level in _SEVERITY_LEVELS}
    if image_ids:
        # "== True" is intentional: SQLAlchemy overloads it to build SQL.
        active = db.query(Vulnerability).filter(
            Vulnerability.image_id.in_(image_ids),
            Vulnerability.is_active == True,
        )
        for level in _SEVERITY_LEVELS:
            counts[level] = active.filter(Vulnerability.severity == level).count()
    counts["total"] = sum(counts.values())
    return counts


@app.get("/")
async def root():
    """Return a static message identifying the API."""
    return {"message": "GitLab Docker Images Tracker API"}


@app.get("/dashboard", response_model=DashboardStats)
async def get_dashboard_stats(db: Session = Depends(get_db)):
    """Return project, image and active-vulnerability totals for the dashboard.

    ``last_scan`` is the most recent ``completed_at`` of any scan job that has
    one, or ``None`` when no job has finished yet.
    """
    total_projects = db.query(Project).count()
    active_projects = db.query(Project).filter(Project.is_active == True).count()
    total_images = db.query(Image).count()
    active_images = db.query(Image).filter(Image.is_active == True).count()

    active_vulnerabilities = db.query(Vulnerability).filter(Vulnerability.is_active == True)

    def count_with_severity(severity: VulnerabilitySeverity) -> int:
        return active_vulnerabilities.filter(
            Vulnerability.severity == severity.value
        ).count()

    # Exclude unfinished jobs explicitly: where NULLs sort first in DESC
    # order, a pending/running job (completed_at IS NULL) would otherwise be
    # picked and last_scan would be reported as None.
    last_scan_job = (
        db.query(ScanJob)
        .filter(ScanJob.completed_at.isnot(None))
        .order_by(ScanJob.completed_at.desc())
        .first()
    )
    last_scan = last_scan_job.completed_at if last_scan_job else None

    return DashboardStats(
        total_projects=total_projects,
        active_projects=active_projects,
        total_images=total_images,
        active_images=active_images,
        total_vulnerabilities=active_vulnerabilities.count(),
        critical_vulnerabilities=count_with_severity(VulnerabilitySeverity.CRITICAL),
        high_vulnerabilities=count_with_severity(VulnerabilitySeverity.HIGH),
        medium_vulnerabilities=count_with_severity(VulnerabilitySeverity.MEDIUM),
        low_vulnerabilities=count_with_severity(VulnerabilitySeverity.LOW),
        last_scan=last_scan,
    )


@app.get("/projects", response_model=List[ProjectResponse])
async def get_projects(
    skip: int = 0,
    limit: int = 100,
    active_only: bool = True,
    include_vulnerability_counts: bool = False,
    db: Session = Depends(get_db),
):
    """List projects with offset/limit pagination.

    When ``include_vulnerability_counts`` is set, each project also carries
    the per-severity counts of active vulnerabilities over the active images
    referenced by its files.
    """
    query = db.query(Project)
    if active_only:
        query = query.filter(Project.is_active == True)
    projects = query.offset(skip).limit(limit).all()

    if not include_vulnerability_counts:
        return projects

    # Add vulnerability counts for each project
    result = []
    for project in projects:
        # Get all images for this project
        project_images = (
            db.query(Image)
            .join(FileImageUsage)
            .join(File)
            .filter(
                File.project_id == project.id,
                Image.is_active == True,
                FileImageUsage.is_active == True,
            )
            .distinct()
            .all()
        )
        vulnerability_counts = _vulnerability_counts_for_images(
            db, [image.id for image in project_images]
        )
        result.append(
            {
                "id": project.id,
                "gitlab_id": project.gitlab_id,
                "name": project.name,
                "path": project.path,
                "web_url": project.web_url,
                "last_scanned": project.last_scanned,
                "is_active": project.is_active,
                "created_at": project.created_at,
                "updated_at": project.updated_at,
                "vulnerability_counts": vulnerability_counts,
            }
        )
    return result


@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: int, db: Session = Depends(get_db)):
    """Return one project by id, or respond 404 when it does not exist."""
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    return project


@app.get("/projects/{project_id}/files", response_model=List[FileResponse])
async def get_project_files(
    project_id: int,
    skip: int = 0,
    limit: int = 100,
    include_image_counts: bool = False,
    db: Session = Depends(get_db),
):
    """List the active files of a project with offset/limit pagination.

    When ``include_image_counts`` is set, each file also carries the number of
    distinct active images it uses and their per-severity vulnerability
    counts.
    """
    files = (
        db.query(File)
        .filter(File.project_id == project_id, File.is_active == True)
        .offset(skip)
        .limit(limit)
        .all()
    )

    if not include_image_counts:
        return files

    # Add image and vulnerability counts for each file
    result = []
    for file in files:
        # Distinct images used by this file
        distinct_images = (
            db.query(Image)
            .join(FileImageUsage)
            .filter(
                FileImageUsage.file_id == file.id,
                FileImageUsage.is_active == True,
                Image.is_active == True,
            )
            .distinct()
            .all()
        )
        vulnerability_counts = _vulnerability_counts_for_images(
            db, [image.id for image in distinct_images]
        )
        result.append(
            {
                "id": file.id,
                "project_id": file.project_id,
                "file_path": file.file_path,
                "branch": file.branch,
                "file_type": file.file_type,
                "last_scanned": file.last_scanned,
                "is_active": file.is_active,
                "created_at": file.created_at,
                "updated_at": file.updated_at,
                "image_count": len(distinct_images),
                "vulnerability_counts": vulnerability_counts,
            }
        )
    return result
# Severity labels counted per image, in response-key order.
_IMAGE_SEVERITY_KEYS = ("critical", "high", "medium", "low", "unspecified")


def _image_severity_counts(db: Session, image_id: int) -> dict:
    """Count the active vulnerabilities of one image per severity, plus a total."""
    tally = {}
    for severity_key in _IMAGE_SEVERITY_KEYS:
        tally[severity_key] = (
            db.query(Vulnerability)
            .filter(
                Vulnerability.image_id == image_id,
                Vulnerability.is_active == True,
                Vulnerability.severity == severity_key,
            )
            .count()
        )
    tally["total"] = sum(tally.values())
    return tally


def _image_payload(db: Session, image, with_vulnerability_counts: bool = False) -> dict:
    """Build the response dict for an image.

    Always includes ``usage_count`` (number of active FileImageUsage rows for
    the image); adds ``vulnerability_counts`` only when requested.
    """
    active_usages = (
        db.query(FileImageUsage)
        .filter(
            FileImageUsage.image_id == image.id,
            FileImageUsage.is_active == True,
        )
        .count()
    )
    payload = {
        "id": image.id,
        "image_name": image.image_name,
        "tag": image.tag,
        "registry": image.registry,
        "full_image_name": image.full_image_name,
        "last_seen": image.last_seen,
        "is_active": image.is_active,
        "created_at": image.created_at,
        "updated_at": image.updated_at,
        "usage_count": active_usages,
    }
    if with_vulnerability_counts:
        payload["vulnerability_counts"] = _image_severity_counts(db, image.id)
    return payload


@app.get("/projects/{project_id}/images", response_model=List[ImageResponse])
async def get_project_images(
    project_id: int,
    skip: int = 0,
    limit: int = 100,
    include_vulnerability_counts: bool = False,
    db: Session = Depends(get_db),
):
    """List the distinct active images used by a project's files.

    Every entry carries its usage count; per-severity vulnerability counts
    are added when ``include_vulnerability_counts`` is set.
    """
    image_query = (
        db.query(Image)
        .join(FileImageUsage)
        .join(File)
        .filter(
            File.project_id == project_id,
            Image.is_active == True,
            FileImageUsage.is_active == True,
        )
        .distinct()
    )
    page = image_query.offset(skip).limit(limit).all()
    return [
        _image_payload(db, entry, include_vulnerability_counts) for entry in page
    ]


@app.get("/images", response_model=List[ImageResponse])
async def get_images(
    skip: int = 0,
    limit: int = 100,
    active_only: bool = True,
    db: Session = Depends(get_db),
):
    """List images (active ones only by default), each with its usage count."""
    image_query = db.query(Image)
    if active_only:
        image_query = image_query.filter(Image.is_active == True)
    page = image_query.offset(skip).limit(limit).all()
    return [_image_payload(db, entry) for entry in page]


@app.get("/images/{image_id}", response_model=ImageResponse)
async def get_image(
    image_id: int,
    include_vulnerability_counts: bool = False,
    db: Session = Depends(get_db),
):
    """Return one image by id with its usage count, or respond 404 if missing.

    Per-severity vulnerability counts are added when
    ``include_vulnerability_counts`` is set.
    """
    image = db.query(Image).filter(Image.id == image_id).first()
    if image:
        return _image_payload(db, image, include_vulnerability_counts)
    raise HTTPException(status_code=404, detail="Image not found")


@app.get("/images/{image_id}/vulnerabilities", response_model=List[VulnerabilityResponse])
async def get_image_vulnerabilities(
    image_id: int,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List the active vulnerabilities recorded for one image (paginated)."""
    active_for_image = db.query(Vulnerability).filter(
        Vulnerability.image_id == image_id,
        Vulnerability.is_active == True,
    )
    return active_for_image.offset(skip).limit(limit).all()


@app.get("/vulnerabilities", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
    skip: int = 0,
    limit: int = 100,
    severity: Optional[str] = None,
    db: Session = Depends(get_db),
):
    """List active vulnerabilities, optionally restricted to one severity (paginated)."""
    active = db.query(Vulnerability).filter(Vulnerability.is_active == True)
    if severity:
        active = active.filter(Vulnerability.severity == severity)
    return active.offset(skip).limit(limit).all()


@app.get("/ignore-rules", response_model=List[IgnoreRuleResponse])
async def get_ignore_rules(
    ignore_type: Optional[str] = None,
    project_id: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Return ignore rules from the IgnoreManager, filtered by type and/or project.

    Responds 400 when ``ignore_type`` is given but is not a valid IgnoreType
    value.
    """
    manager = IgnoreManager(db)

    parsed_type = None
    if ignore_type:
        try:
            parsed_type = IgnoreType(ignore_type)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid ignore type")

    return manager.get_ignore_rules(parsed_type, project_id)
def _ensure_no_active_scan(db: Session) -> None:
    """Raise HTTP 409 when any scan job is still pending or running."""
    running_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()
    if running_jobs:
        running_job_types = [job.job_type for job in running_jobs]
        raise HTTPException(
            status_code=409,
            detail=f"Cannot start scan: {', '.join(running_job_types)} scan(s) already running",
        )


def _scan_job_to_dict(job) -> dict:
    """Serialise a ScanJob row into the dict shape returned by the job endpoints."""
    return {
        "id": job.id,
        "job_type": job.job_type,
        "status": job.status,
        "project_id": job.project_id,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "error_message": job.error_message,
        "created_at": job.created_at,
    }


@app.post("/ignore-rules", response_model=IgnoreRuleResponse)
async def create_ignore_rule(rule: IgnoreRuleCreate, db: Session = Depends(get_db)):
    """Create an ignore rule through the IgnoreManager and return it.

    Responds 400 when ``rule.ignore_type`` is not a valid IgnoreType value.
    """
    ignore_manager = IgnoreManager(db)

    try:
        ignore_type_enum = IgnoreType(rule.ignore_type)
    except ValueError as e:
        raise HTTPException(status_code=400, detail="Invalid ignore type") from e

    return ignore_manager.add_ignore_rule(
        ignore_type=ignore_type_enum,
        target=rule.target,
        reason=rule.reason,
        created_by=rule.created_by,
        project_id=rule.project_id,
    )


@app.delete("/ignore-rules/{rule_id}")
async def delete_ignore_rule(rule_id: int, db: Session = Depends(get_db)):
    """Remove an ignore rule by id; respond 404 when the manager reports none removed."""
    ignore_manager = IgnoreManager(db)
    if not ignore_manager.remove_ignore_rule(rule_id):
        raise HTTPException(status_code=404, detail="Ignore rule not found")
    return {"message": "Ignore rule deleted successfully"}


@app.post("/scan/projects")
async def scan_projects(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Queue a project discovery scan as a background task.

    Reads GITLAB_TOKEN, GITLAB_URL and the comma-separated GITLAB_GROUPS from
    the environment. Responds 500 when the token is missing or the job cannot
    be created, and 409 when another scan is pending or running.
    """
    gitlab_token = os.getenv("GITLAB_TOKEN")
    gitlab_url = os.getenv("GITLAB_URL", "https://gitlab.com")
    gitlab_groups_str = os.getenv("GITLAB_GROUPS", "")

    if not gitlab_token:
        raise HTTPException(status_code=500, detail="GitLab token not configured")

    # Check if there are any running or pending scans
    _ensure_no_active_scan(db)

    # Parse groups from environment variable; blank entries are dropped.
    gitlab_groups = [
        group.strip() for group in gitlab_groups_str.split(",") if group.strip()
    ]

    try:
        # Create the scanner just to create the job
        scanner = DockerImageScanner(gitlab_token, gitlab_url, gitlab_groups)
        job = scanner.create_scan_job(db, "discovery")

        # Start background task
        background_tasks.add_task(
            background_project_scan, gitlab_token, gitlab_url, gitlab_groups, job.id
        )

        group_info = (
            f" (groups: {', '.join(gitlab_groups)})" if gitlab_groups else " (all projects)"
        )
        return {
            "message": f"Project scan started{group_info}",
            "job_id": job.id,
            "status": "pending",
        }
    except Exception as e:
        print(f"Scan error: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to start scan: {str(e)}"
        ) from e


@app.post("/scan/vulnerabilities")
async def scan_vulnerabilities(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Queue a vulnerability scan as a background task.

    Responds 409 when another scan is pending or running. If scheduling the
    task fails, the new job is marked failed and a 500 is returned.
    """
    # Check if there are any running or pending scans
    _ensure_no_active_scan(db)

    # Create job record
    job = ScanJob(job_type="vulnerability", status="pending")
    db.add(job)
    db.commit()
    db.refresh(job)

    try:
        # Start background task
        background_tasks.add_task(background_vulnerability_scan, job.id)
        return {
            "message": "Vulnerability scan started",
            "job_id": job.id,
            "status": "pending",
        }
    except Exception as e:
        job.status = "failed"
        job.error_message = str(e)
        job.completed_at = datetime.utcnow()
        db.commit()
        raise HTTPException(
            status_code=500, detail=f"Failed to start vulnerability scan: {str(e)}"
        ) from e


@app.get("/scan/jobs", response_model=List[dict])
async def get_scan_jobs(
    skip: int = 0,
    limit: int = 50,
    db: Session = Depends(get_db),
):
    """List scan jobs, newest first by creation time (paginated)."""
    jobs = (
        db.query(ScanJob)
        .order_by(ScanJob.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    return [_scan_job_to_dict(job) for job in jobs]


@app.get("/scan/jobs/{job_id}")
async def get_scan_job(job_id: int, db: Session = Depends(get_db)):
    """Return one scan job by id, or respond 404 when it does not exist."""
    job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Scan job not found")
    return _scan_job_to_dict(job)


@app.get("/scan/status")
async def get_scan_status(db: Session = Depends(get_db)):
    """Check if there are any running or pending scans"""
    running_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()

    return {
        "has_running_scans": len(running_jobs) > 0,
        "running_jobs": [
            {
                "id": job.id,
                "job_type": job.job_type,
                "status": job.status,
                "started_at": job.started_at,
                "created_at": job.created_at,
            }
            for job in running_jobs
        ],
    }


@app.get("/gitlab/groups")
async def get_gitlab_groups():
    """List the GitLab groups visible to the configured token.

    Reads GITLAB_TOKEN and GITLAB_URL from the environment. Responds 500 when
    the token is missing or when the GitLab request (or importing the
    ``gitlab`` package) fails.
    """
    gitlab_token = os.getenv("GITLAB_TOKEN")
    gitlab_url = os.getenv("GITLAB_URL", "https://gitlab.com")

    if not gitlab_token:
        raise HTTPException(status_code=500, detail="GitLab token not configured")

    def fetch_groups():
        import gitlab

        gl = gitlab.Gitlab(gitlab_url, private_token=gitlab_token)
        groups = gl.groups.list(all=True, simple=True)
        return [
            {
                "id": group.id,
                "name": group.name,
                "path": group.path,
                "full_path": group.full_path,
                "description": getattr(group, 'description', ''),
            }
            for group in groups
        ]

    try:
        # The GitLab client call is synchronous and fetches every page
        # (all=True); run it in the default executor so the event loop keeps
        # serving other requests while it is in flight.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, fetch_groups)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to fetch groups: {str(e)}"
        ) from e


def main():
    """Start uvicorn serving ``main:asgi_app`` on 0.0.0.0:5000 with auto-reload."""
    import uvicorn

    uvicorn.run("main:asgi_app", host="0.0.0.0", port=5000, reload=True)


# Create the ASGI app at module level for uvicorn
asgi_app = websocket_manager.get_asgi_app(app)

if __name__ == "__main__":
    main()