import asyncio
import os
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import Callable, Dict, List, Optional

# Load environment variables from an optional .env file next to this module.
# This runs before the project imports below in case any of those modules read
# configuration from the environment at import time.
# NOTE(review): confirm whether models/scanner read env vars on import.
env_file = Path(__file__).parent / ".env"
if env_file.exists():
    from dotenv import load_dotenv

    load_dotenv(env_file)

from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session

from ignore_manager import IgnoreManager
from models import (
    File,
    Image,
    IgnoreRule,
    IgnoreType,
    Project,
    ScanJob,
    Vulnerability,
    VulnerabilitySeverity,
    FileImageUsage,
    create_tables,
    get_db,
)
from scanner import DockerImageScanner
from vulnerability_scanner import VulnerabilityScanner
from websocket_manager import websocket_manager


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create DB tables and register WebSocket events on startup."""
    # Startup
    create_tables()
    websocket_manager.setup_events()
    yield
    # Shutdown


app = FastAPI(
    title="GitLab Docker Images Tracker",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Background task functions
# ---------------------------------------------------------------------------


def _record_job_failure(db: Session, job_id: int, job_type: str, error: Exception) -> None:
    """Mark ScanJob ``job_id`` as failed and broadcast the failure over WebSocket.

    The session is rolled back first. If ``error`` was raised by the database,
    SQLAlchemy refuses further work on the session until it is rolled back, so
    committing the "failed" status without the rollback would itself raise and
    leave the job stuck in "running". A stuck job makes every later scan
    request return 409. The rollback also discards any uncommitted writes made
    by the failed scan.

    Args:
        db: Session the failed scan was using.
        job_id: Primary key of the ScanJob to update.
        job_type: Job type string passed through to the WebSocket notification.
        error: The exception that aborted the scan; its text is stored and broadcast.
    """
    try:
        db.rollback()
        job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
        if job:
            job.status = "failed"
            job.error_message = str(error)
            job.completed_at = datetime.utcnow()
            db.commit()
    except Exception as db_error:
        # Best effort: still notify clients even if the DB update is impossible.
        print(f"Could not mark scan job {job_id} as failed: {db_error}")
    # Notify via WebSocket that scan failed
    websocket_manager.notify_scan_failed_sync(job_type, job_id, str(error))


def _run_scan_job(
    job_id: int,
    job_type: str,
    label: str,
    run_scan: Callable[[Session], None],
) -> None:
    """Drive ScanJob ``job_id`` through running -> completed/failed around ``run_scan``.

    A fresh session is opened for the duration of the task and always closed.
    Status changes stamp ``started_at``/``completed_at`` and are broadcast via
    ``websocket_manager``.

    Args:
        job_id: Primary key of the ScanJob row to update.
        job_type: Job type string used in WebSocket notifications (e.g. "discovery").
        label: Human-readable prefix for messages, e.g. "Project scan".
        run_scan: Callable doing the actual scan work with the task's session.
    """
    # SessionLocal is imported at call time, as the background tasks always did.
    from models import SessionLocal

    db = SessionLocal()
    try:
        job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
        if job:
            job.status = "running"
            job.started_at = datetime.utcnow()
            db.commit()
        # Notify via WebSocket that scan started
        websocket_manager.notify_scan_started_sync(job_type, job_id, f"{label} started")

        run_scan(db)

        if job:
            job.status = "completed"
            job.completed_at = datetime.utcnow()
            db.commit()
        # Notify via WebSocket that scan completed
        websocket_manager.notify_scan_completed_sync(
            job_type, job_id, f"{label} completed successfully"
        )
    except Exception as e:
        _record_job_failure(db, job_id, job_type, e)
        print(f"Background {label.lower()} error: {str(e)}")
    finally:
        db.close()


def background_project_scan(
    gitlab_token: str, gitlab_url: str, gitlab_groups: List[str], job_id: int
):
    """Background task: run ``DockerImageScanner.scan_all_projects`` for ScanJob ``job_id``.

    Args:
        gitlab_token: GitLab private token passed to the scanner.
        gitlab_url: Base URL of the GitLab instance.
        gitlab_groups: Group paths to restrict the scan to (empty list = no restriction).
        job_id: ScanJob row tracking this scan.
    """

    def run(db: Session) -> None:
        scanner = DockerImageScanner(gitlab_token, gitlab_url, gitlab_groups)
        scanner.scan_all_projects(db)

    _run_scan_job(job_id, "discovery", "Project scan", run)


def background_vulnerability_scan(job_id: int):
    """Background task: run ``VulnerabilityScanner.scan_all_images`` for ScanJob ``job_id``."""

    def run(db: Session) -> None:
        VulnerabilityScanner().scan_all_images(db, job_id)

    _run_scan_job(job_id, "vulnerability", "Vulnerability scan", run)


# ---------------------------------------------------------------------------
# Response / request schemas
# ---------------------------------------------------------------------------


class ProjectResponse(BaseModel):
    id: int
    gitlab_id: int
    name: str
    path: str
    web_url: str
    last_scanned: Optional[datetime] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime


class FileResponse(BaseModel):
    id: int
    project_id: int
    file_path: str
    branch: str
    file_type: str
    last_scanned: Optional[datetime] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime


class ImageResponse(BaseModel):
    id: int
    image_name: str
    tag: Optional[str] = None
    registry: Optional[str] = None
    full_image_name: str
    last_seen: datetime
    is_active: bool
    created_at: datetime
    updated_at: datetime
    usage_count: Optional[int] = None


class VulnerabilityResponse(BaseModel):
    id: int
    image_id: int
    scan_job_id: Optional[int] = None
    vulnerability_id: str
    severity: str
    title: Optional[str] = None
    description: Optional[str] = None
    cvss_score: Optional[str] = None
    published_date: Optional[datetime] = None
    fixed_version: Optional[str] = None
    scan_date: datetime
    is_active: bool


class IgnoreRuleResponse(BaseModel):
    id: int
    project_id: Optional[int] = None
    ignore_type: str
    target: str
    reason: Optional[str] = None
    created_by: Optional[str] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime


class IgnoreRuleCreate(BaseModel):
    ignore_type: str
    target: str
    reason: Optional[str] = None
    created_by: Optional[str] = None
    project_id: Optional[int] = None


class DashboardStats(BaseModel):
    total_projects: int
    active_projects: int
    total_images: int
    active_images: int
    total_vulnerabilities: int
    critical_vulnerabilities: int
    high_vulnerabilities: int
    medium_vulnerabilities: int
    low_vulnerabilities: int
    last_scan: Optional[datetime] = None


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.get("/")
async def root():
    return {"message": "GitLab Docker Images Tracker API"}


@app.get("/dashboard", response_model=DashboardStats)
async def get_dashboard_stats(db: Session = Depends(get_db)):
    """Return aggregate counts and the completion time of the most recent finished scan.

    Vulnerability counts (total and per severity) only include active rows.
    """
    active_vulnerabilities = db.query(Vulnerability).filter(Vulnerability.is_active == True)

    def count_by_severity(severity: VulnerabilitySeverity) -> int:
        return active_vulnerabilities.filter(
            Vulnerability.severity == severity.value
        ).count()

    # Exclude jobs that have not finished (completed_at is NULL). Some databases
    # (e.g. PostgreSQL) sort NULLs first under DESC, which would otherwise report
    # last_scan=None whenever a scan is pending or running.
    last_scan_job = (
        db.query(ScanJob)
        .filter(ScanJob.completed_at.isnot(None))
        .order_by(ScanJob.completed_at.desc())
        .first()
    )

    return DashboardStats(
        total_projects=db.query(Project).count(),
        active_projects=db.query(Project).filter(Project.is_active == True).count(),
        total_images=db.query(Image).count(),
        active_images=db.query(Image).filter(Image.is_active == True).count(),
        total_vulnerabilities=active_vulnerabilities.count(),
        critical_vulnerabilities=count_by_severity(VulnerabilitySeverity.CRITICAL),
        high_vulnerabilities=count_by_severity(VulnerabilitySeverity.HIGH),
        medium_vulnerabilities=count_by_severity(VulnerabilitySeverity.MEDIUM),
        low_vulnerabilities=count_by_severity(VulnerabilitySeverity.LOW),
        last_scan=last_scan_job.completed_at if last_scan_job else None,
    )


@app.get("/projects", response_model=List[ProjectResponse])
async def get_projects(
    skip: int = 0, limit: int = 100, active_only: bool = True, db: Session = Depends(get_db)
):
    """List projects with offset/limit paging; only active ones unless ``active_only`` is False."""
    query = db.query(Project)
    if active_only:
        query = query.filter(Project.is_active == True)
    return query.offset(skip).limit(limit).all()


@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: int, db: Session = Depends(get_db)):
    """Return a single project or 404."""
    project = db.query(Project).filter(Project.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    return project


@app.get("/projects/{project_id}/files", response_model=List[FileResponse])
async def get_project_files(
    project_id: int, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
):
    """List the active files of a project (an unknown project yields an empty list)."""
    return (
        db.query(File)
        .filter(File.project_id == project_id, File.is_active == True)
        .offset(skip)
        .limit(limit)
        .all()
    )


def _active_usage_counts(db: Session, image_ids: List[int]) -> Dict[int, int]:
    """Count active FileImageUsage rows per image id using a single grouped query.

    Image ids with no active usage are absent from the result; callers use
    ``.get(image_id, 0)``.
    """
    if not image_ids:
        return {}
    rows = (
        db.query(FileImageUsage.image_id, func.count())
        .filter(
            FileImageUsage.image_id.in_(image_ids),
            FileImageUsage.is_active == True,
        )
        .group_by(FileImageUsage.image_id)
        .all()
    )
    return {image_id: count for image_id, count in rows}


def _image_payload(image: Image, usage_count: int) -> dict:
    """Build the ImageResponse-shaped dict for ``image`` with its usage count."""
    return {
        "id": image.id,
        "image_name": image.image_name,
        "tag": image.tag,
        "registry": image.registry,
        "full_image_name": image.full_image_name,
        "last_seen": image.last_seen,
        "is_active": image.is_active,
        "created_at": image.created_at,
        "updated_at": image.updated_at,
        "usage_count": usage_count,
    }


def _images_with_usage(db: Session, images: List[Image]) -> List[dict]:
    """Attach each image's active usage count (fetched in one query) to its payload."""
    counts = _active_usage_counts(db, [image.id for image in images])
    return [_image_payload(image, counts.get(image.id, 0)) for image in images]


@app.get("/projects/{project_id}/images", response_model=List[ImageResponse])
async def get_project_images(
    project_id: int, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
):
    """List active images referenced by active usages in the project's files.

    ``usage_count`` is the image's total active usage count across all
    projects, not just this one.
    """
    images = (
        db.query(Image)
        .join(FileImageUsage)
        .join(File)
        .filter(
            File.project_id == project_id,
            Image.is_active == True,
            FileImageUsage.is_active == True,
        )
        .distinct()
        .offset(skip)
        .limit(limit)
        .all()
    )
    return _images_with_usage(db, images)


@app.get("/images", response_model=List[ImageResponse])
async def get_images(
    skip: int = 0, limit: int = 100, active_only: bool = True, db: Session = Depends(get_db)
):
    """List images with their active usage counts; only active images unless ``active_only`` is False."""
    query = db.query(Image)
    if active_only:
        query = query.filter(Image.is_active == True)
    images = query.offset(skip).limit(limit).all()
    return _images_with_usage(db, images)


@app.get("/images/{image_id}", response_model=ImageResponse)
async def get_image(image_id: int, db: Session = Depends(get_db)):
    """Return a single image with its active usage count, or 404."""
    image = db.query(Image).filter(Image.id == image_id).first()
    if not image:
        raise HTTPException(status_code=404, detail="Image not found")
    usage_count = _active_usage_counts(db, [image.id]).get(image.id, 0)
    return _image_payload(image, usage_count)


@app.get("/images/{image_id}/vulnerabilities", response_model=List[VulnerabilityResponse])
async def get_image_vulnerabilities(
    image_id: int, skip: int = 0, limit: int = 100, db: Session = Depends(get_db)
):
    """List the active vulnerabilities recorded for an image."""
    return (
        db.query(Vulnerability)
        .filter(Vulnerability.image_id == image_id, Vulnerability.is_active == True)
        .offset(skip)
        .limit(limit)
        .all()
    )


@app.get("/vulnerabilities", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
    skip: int = 0, limit: int = 100, severity: Optional[str] = None, db: Session = Depends(get_db)
):
    """List active vulnerabilities, optionally filtered by exact ``severity`` string."""
    query = db.query(Vulnerability).filter(Vulnerability.is_active == True)
    if severity:
        query = query.filter(Vulnerability.severity == severity)
    return query.offset(skip).limit(limit).all()


@app.get("/ignore-rules", response_model=List[IgnoreRuleResponse])
async def get_ignore_rules(
    ignore_type: Optional[str] = None, project_id: Optional[int] = None, db: Session = Depends(get_db)
):
    """List ignore rules, optionally filtered by type and project; 400 on an unknown type."""
    ignore_manager = IgnoreManager(db)
    ignore_type_enum = None
    if ignore_type:
        try:
            ignore_type_enum = IgnoreType(ignore_type)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid ignore type") from None
    return ignore_manager.get_ignore_rules(ignore_type_enum, project_id)


@app.post("/ignore-rules", response_model=IgnoreRuleResponse)
async def create_ignore_rule(rule: IgnoreRuleCreate, db: Session = Depends(get_db)):
    """Create an ignore rule via IgnoreManager; 400 on an unknown ignore type."""
    ignore_manager = IgnoreManager(db)
    try:
        ignore_type_enum = IgnoreType(rule.ignore_type)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid ignore type") from None

    return ignore_manager.add_ignore_rule(
        ignore_type=ignore_type_enum,
        target=rule.target,
        reason=rule.reason,
        created_by=rule.created_by,
        project_id=rule.project_id,
    )


@app.delete("/ignore-rules/{rule_id}")
async def delete_ignore_rule(rule_id: int, db: Session = Depends(get_db)):
    """Delete an ignore rule; 404 if IgnoreManager reports it was not found."""
    ignore_manager = IgnoreManager(db)
    if not ignore_manager.remove_ignore_rule(rule_id):
        raise HTTPException(status_code=404, detail="Ignore rule not found")
    return {"message": "Ignore rule deleted successfully"}


@app.post("/scan/projects")
async def scan_projects(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Create a "discovery" ScanJob and queue ``background_project_scan`` for it.

    Configuration comes from GITLAB_TOKEN, GITLAB_URL (default
    https://gitlab.com) and GITLAB_GROUPS (comma-separated; empty = all projects).

    Raises:
        HTTPException(500): GITLAB_TOKEN is unset, or the job could not be queued.
        HTTPException(409): another scan is already pending or running.
    """
    gitlab_token = os.getenv("GITLAB_TOKEN")
    gitlab_url = os.getenv("GITLAB_URL", "https://gitlab.com")
    gitlab_groups_str = os.getenv("GITLAB_GROUPS", "")

    if not gitlab_token:
        raise HTTPException(status_code=500, detail="GitLab token not configured")

    # Check if there are any running or pending scans
    running_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()
    if running_jobs:
        running_job_types = [job.job_type for job in running_jobs]
        raise HTTPException(
            status_code=409,
            detail=f"Cannot start scan: {', '.join(running_job_types)} scan(s) already running",
        )

    # Parse groups from environment variable, dropping blanks ("" -> []).
    gitlab_groups = [group.strip() for group in gitlab_groups_str.split(",") if group.strip()]

    job = None
    try:
        # Create the scanner just to create the job
        scanner = DockerImageScanner(gitlab_token, gitlab_url, gitlab_groups)
        job = scanner.create_scan_job(db, "discovery")

        # Start background task
        background_tasks.add_task(
            background_project_scan, gitlab_token, gitlab_url, gitlab_groups, job.id
        )
    except Exception as e:
        print(f"Scan error: {str(e)}")
        if job is not None:
            # A job left in "pending" would make every later scan request fail
            # with 409, so record the failure (best effort).
            try:
                db.rollback()
                job.status = "failed"
                job.error_message = str(e)
                job.completed_at = datetime.utcnow()
                db.commit()
            except Exception:
                db.rollback()
        raise HTTPException(status_code=500, detail=f"Failed to start scan: {str(e)}") from e

    group_info = f" (groups: {', '.join(gitlab_groups)})" if gitlab_groups else " (all projects)"
    return {
        "message": f"Project scan started{group_info}",
        "job_id": job.id,
        "status": "pending",
    }
def _scan_job_to_dict(job: ScanJob) -> dict:
    """Serialize a ScanJob row into the JSON shape returned by the /scan/jobs endpoints."""
    return {
        "id": job.id,
        "job_type": job.job_type,
        "status": job.status,
        "project_id": job.project_id,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "error_message": job.error_message,
        "created_at": job.created_at,
    }


@app.post("/scan/vulnerabilities")
async def scan_vulnerabilities(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    """Create a "vulnerability" ScanJob and queue ``background_vulnerability_scan`` for it.

    Raises:
        HTTPException(409): another scan is already pending or running.
        HTTPException(500): the background task could not be queued; the job
            row is then marked "failed".
    """
    # Check if there are any running or pending scans
    running_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()
    if running_jobs:
        running_job_types = [job.job_type for job in running_jobs]
        raise HTTPException(
            status_code=409,
            detail=f"Cannot start scan: {', '.join(running_job_types)} scan(s) already running",
        )

    # Create job record
    job = ScanJob(job_type="vulnerability", status="pending")
    db.add(job)
    db.commit()
    db.refresh(job)

    try:
        # Start background task
        background_tasks.add_task(background_vulnerability_scan, job.id)
    except Exception as e:
        job.status = "failed"
        job.error_message = str(e)
        job.completed_at = datetime.utcnow()
        db.commit()
        raise HTTPException(
            status_code=500, detail=f"Failed to start vulnerability scan: {str(e)}"
        ) from e

    return {
        "message": "Vulnerability scan started",
        "job_id": job.id,
        "status": "pending",
    }


@app.get("/scan/jobs", response_model=List[dict])
async def get_scan_jobs(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """List scan jobs, newest first (by created_at), with offset/limit paging."""
    jobs = db.query(ScanJob).order_by(ScanJob.created_at.desc()).offset(skip).limit(limit).all()
    return [_scan_job_to_dict(job) for job in jobs]


@app.get("/scan/jobs/{job_id}")
async def get_scan_job(job_id: int, db: Session = Depends(get_db)):
    """Return a single scan job, or 404."""
    job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Scan job not found")
    return _scan_job_to_dict(job)


@app.get("/scan/status")
async def get_scan_status(db: Session = Depends(get_db)):
    """Check if there are any running or pending scans"""
    running_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()
    return {
        "has_running_scans": len(running_jobs) > 0,
        "running_jobs": [
            {
                "id": job.id,
                "job_type": job.job_type,
                "status": job.status,
                "started_at": job.started_at,
                "created_at": job.created_at,
            }
            for job in running_jobs
        ],
    }


@app.get("/gitlab/groups")
async def get_gitlab_groups():
    """List GitLab groups visible to GITLAB_TOKEN on GITLAB_URL (default https://gitlab.com).

    python-gitlab performs blocking HTTP requests, and ``all=True`` may page
    through many results. The work therefore runs in the default thread-pool
    executor instead of blocking the event loop, which also serves the
    WebSocket app.

    Raises:
        HTTPException(500): the token is unset or the GitLab request failed.
    """
    gitlab_token = os.getenv("GITLAB_TOKEN")
    gitlab_url = os.getenv("GITLAB_URL", "https://gitlab.com")

    if not gitlab_token:
        raise HTTPException(status_code=500, detail="GitLab token not configured")

    def fetch_groups() -> List[dict]:
        import gitlab

        gl = gitlab.Gitlab(gitlab_url, private_token=gitlab_token)
        groups = gl.groups.list(all=True, simple=True)
        return [
            {
                "id": group.id,
                "name": group.name,
                "path": group.path,
                "full_path": group.full_path,
                "description": getattr(group, 'description', ''),
            }
            for group in groups
        ]

    try:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, fetch_groups)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch groups: {str(e)}") from e


def main():
    """Run the combined ASGI app (FastAPI + WebSocket) with uvicorn in reload mode."""
    import uvicorn

    uvicorn.run("main:asgi_app", host="0.0.0.0", port=5000, reload=True)


# Create the ASGI app at module level for uvicorn
asgi_app = websocket_manager.get_asgi_app(app)

if __name__ == "__main__":
    main()