# Files
# gdit-back/main.py
#
# 910 lines
# 29 KiB
# Python

import asyncio
import os
from datetime import datetime
from typing import List, Optional
from contextlib import asynccontextmanager
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ignore_manager import IgnoreManager
from models import (
File,
Image,
IgnoreRule,
IgnoreType,
Project,
ScanJob,
Vulnerability,
VulnerabilitySeverity,
FileImageUsage,
create_tables,
get_db,
)
from scanner import DockerImageScanner
from vulnerability_scanner import VulnerabilityScanner
from websocket_manager import websocket_manager
# Pull configuration from a ".env" file located beside this module, when one exists.
from pathlib import Path

env_file = Path(__file__).parent.joinpath(".env")
if env_file.exists():
    # python-dotenv is imported lazily, so it is only needed when a .env file is present.
    from dotenv import load_dotenv

    load_dotenv(env_file)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to the FastAPI constructor.

    Everything before the ``yield`` runs once at startup: it calls
    ``create_tables()`` and ``websocket_manager.setup_events()``. Nothing is
    executed after the ``yield``, i.e. there is no shutdown work.
    """
    # --- startup ---
    create_tables()
    websocket_manager.setup_events()

    yield
    # --- shutdown: nothing to release at the moment ---
# The FastAPI application; startup work is supplied through `lifespan`.
app = FastAPI(title="GitLab Docker Images Tracker", version="1.0.0", lifespan=lifespan)

# CORS: only the two localhost origins on port 3000 are accepted; for those,
# credentials, every HTTP method and every request header are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],
)
# Background task functions
def background_project_scan(
    gitlab_token: str,
    gitlab_url: str,
    gitlab_groups: List[str],
    job_id: int
):
    """Background task for project scanning.

    Opens its own database session, marks the ``ScanJob`` with ``job_id`` as
    "running", runs ``DockerImageScanner.scan_all_projects`` and finally
    records the job as "completed" or "failed". Each transition is also
    announced through ``websocket_manager`` under the "discovery" scan type.

    Args:
        gitlab_token: Token handed to ``DockerImageScanner``.
        gitlab_url: GitLab base URL handed to ``DockerImageScanner``.
        gitlab_groups: Group names handed to ``DockerImageScanner``.
        job_id: Primary key of the ``ScanJob`` row tracking this run. If no
            such row exists the scan still runs; only the status bookkeeping
            is skipped.

    Exceptions never propagate: they are printed, stored on the job when
    possible, and reported over the WebSocket.
    """
    from models import SessionLocal

    db = SessionLocal()
    job = None
    try:
        # Get the job and update status with started_at timestamp
        job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
        if job:
            job.status = "running"
            job.started_at = datetime.utcnow()
            db.commit()

        # Notify via WebSocket that scan started
        websocket_manager.notify_scan_started_sync("discovery", job_id, "Project scan started")

        scanner = DockerImageScanner(gitlab_token, gitlab_url, gitlab_groups)
        scanner.scan_all_projects(db)

        # Mark as completed
        if job:
            job.status = "completed"
            job.completed_at = datetime.utcnow()
            db.commit()

        # Notify via WebSocket that scan completed
        websocket_manager.notify_scan_completed_sync("discovery", job_id, "Project scan completed successfully")
    except Exception as e:
        error_text = str(e)
        print(f"Background project scan error: {error_text}")
        try:
            # The failure may have left the session inside a broken
            # transaction; without a rollback the commit below would raise
            # and the job would stay "running" forever, blocking new scans.
            db.rollback()
            if job:
                job.status = "failed"
                job.error_message = error_text
                job.completed_at = datetime.utcnow()
                db.commit()
        except Exception as persist_error:
            # Recording the failure is best effort; never mask the notification.
            db.rollback()
            print(f"Background project scan: could not record failure for job {job_id}: {persist_error}")

        # Notify via WebSocket that scan failed
        websocket_manager.notify_scan_failed_sync("discovery", job_id, error_text)
    finally:
        db.close()
def background_vulnerability_scan(job_id: int):
    """Background task for vulnerability scanning.

    Opens its own database session, marks the ``ScanJob`` with ``job_id`` as
    "running", runs ``VulnerabilityScanner.scan_all_images`` and finally
    records the job as "completed" or "failed". Each transition is also
    announced through ``websocket_manager`` under the "vulnerability" scan
    type.

    Args:
        job_id: Primary key of the ``ScanJob`` row tracking this run; it is
            also forwarded to ``scan_all_images``. If no such row exists the
            scan still runs; only the status bookkeeping is skipped.

    Exceptions never propagate: they are printed, stored on the job when
    possible, and reported over the WebSocket.
    """
    from models import SessionLocal

    db = SessionLocal()
    job = None
    try:
        job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
        if job:
            job.status = "running"
            job.started_at = datetime.utcnow()
            db.commit()

        # Notify via WebSocket that scan started
        websocket_manager.notify_scan_started_sync("vulnerability", job_id, "Vulnerability scan started")

        vuln_scanner = VulnerabilityScanner()
        vuln_scanner.scan_all_images(db, job_id)

        if job:
            job.status = "completed"
            job.completed_at = datetime.utcnow()
            db.commit()

        # Notify via WebSocket that scan completed
        websocket_manager.notify_scan_completed_sync("vulnerability", job_id, "Vulnerability scan completed successfully")
    except Exception as e:
        error_text = str(e)
        print(f"Background vulnerability scan error: {error_text}")
        try:
            # The failure may have left the session inside a broken
            # transaction; without a rollback the commit below would raise
            # and the job would stay "running" forever, blocking new scans.
            db.rollback()
            if job:
                job.status = "failed"
                job.error_message = error_text
                job.completed_at = datetime.utcnow()
                db.commit()
        except Exception as persist_error:
            # Recording the failure is best effort; never mask the notification.
            db.rollback()
            print(f"Background vulnerability scan: could not record failure for job {job_id}: {persist_error}")

        # Notify via WebSocket that scan failed
        websocket_manager.notify_scan_failed_sync("vulnerability", job_id, error_text)
    finally:
        db.close()
class ProjectResponse(BaseModel):
    # Response schema for a project, used by GET /projects and GET /projects/{project_id}.
    id: int
    gitlab_id: int
    name: str
    path: str
    web_url: str
    last_scanned: Optional[datetime] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime
    # Only filled by GET /projects when include_vulnerability_counts is requested;
    # keys are critical, high, medium, low, unspecified and total.
    vulnerability_counts: Optional[dict] = None
class FileResponse(BaseModel):
    # Response schema for a file of a project, used by GET /projects/{project_id}/files.
    id: int
    project_id: int
    file_path: str
    branch: str
    file_type: str
    last_scanned: Optional[datetime] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime
    # The two fields below are only filled when include_image_counts is requested.
    # image_count: number of distinct active images referenced by the file.
    image_count: Optional[int] = None
    # keys are critical, high, medium, low, unspecified and total.
    vulnerability_counts: Optional[dict] = None
class ImageResponse(BaseModel):
    # Response schema for an image, used by the /images and
    # /projects/{project_id}/images endpoints.
    id: int
    image_name: str
    tag: Optional[str] = None
    registry: Optional[str] = None
    full_image_name: str
    last_seen: datetime
    is_active: bool
    created_at: datetime
    updated_at: datetime
    # Number of active FileImageUsage rows that reference the image.
    usage_count: Optional[int] = None
    # Only filled when include_vulnerability_counts is requested;
    # keys are critical, high, medium, low, unspecified and total.
    vulnerability_counts: Optional[dict] = None
class VulnerabilityResponse(BaseModel):
    # Response schema for a vulnerability, used by GET /vulnerabilities and
    # GET /images/{image_id}/vulnerabilities.
    id: int
    image_id: int
    scan_job_id: Optional[int] = None
    vulnerability_id: str
    # Plain string; the endpoints in this module compare it against
    # 'critical', 'high', 'medium', 'low' and 'unspecified'.
    severity: str
    title: Optional[str] = None
    description: Optional[str] = None
    # Declared as a string here, not as a number.
    cvss_score: Optional[str] = None
    published_date: Optional[datetime] = None
    fixed_version: Optional[str] = None
    scan_date: datetime
    is_active: bool
class IgnoreRuleResponse(BaseModel):
    # Response schema for an ignore rule, used by GET and POST /ignore-rules.
    id: int
    # No project_id is required; presumably a rule without one is not tied to
    # a single project - confirm against IgnoreManager.
    project_id: Optional[int] = None
    ignore_type: str
    target: str
    reason: Optional[str] = None
    created_by: Optional[str] = None
    is_active: bool
    created_at: datetime
    updated_at: datetime
class IgnoreRuleCreate(BaseModel):
    # Request body of POST /ignore-rules.
    # ignore_type must be a valid IgnoreType value; the endpoint answers 400 otherwise.
    ignore_type: str
    target: str
    reason: Optional[str] = None
    created_by: Optional[str] = None
    project_id: Optional[int] = None
class DashboardStats(BaseModel):
    # Response schema of GET /dashboard.
    total_projects: int
    active_projects: int
    total_images: int
    active_images: int
    # All vulnerability figures count active vulnerabilities only.
    total_vulnerabilities: int
    critical_vulnerabilities: int
    high_vulnerabilities: int
    medium_vulnerabilities: int
    low_vulnerabilities: int
    # completed_at of the latest scan job (by completed_at); None when unavailable.
    last_scan: Optional[datetime] = None
@app.get("/")
async def root():
    # Landing endpoint: identifies the service with a fixed message.
    greeting = {"message": "GitLab Docker Images Tracker API"}
    return greeting
@app.get("/dashboard", response_model=DashboardStats)
async def get_dashboard_stats(db: Session = Depends(get_db)):
    """Return the aggregate figures shown on the dashboard.

    Projects and images are reported both in total and active-only;
    vulnerability figures cover active vulnerabilities only. ``last_scan`` is
    the completion time of the most recently finished scan job, or null when
    no job has finished yet.
    """
    from sqlalchemy import func

    total_projects = db.query(Project).count()
    active_projects = db.query(Project).filter(Project.is_active == True).count()
    total_images = db.query(Image).count()
    active_images = db.query(Image).filter(Image.is_active == True).count()

    # One grouped query instead of a separate COUNT per severity.
    severity_rows = (
        db.query(Vulnerability.severity, func.count())
        .filter(Vulnerability.is_active == True)
        .group_by(Vulnerability.severity)
        .all()
    )
    active_by_severity = {severity: count for severity, count in severity_rows}
    # The total covers every active vulnerability, whatever its severity.
    total_vulnerabilities = sum(active_by_severity.values())

    # Jobs that have not finished have completed_at = NULL, and some databases
    # sort NULL first in descending order; exclude them so a pending or
    # running job cannot hide the last real completion time.
    last_scan_job = (
        db.query(ScanJob)
        .filter(ScanJob.completed_at.isnot(None))
        .order_by(ScanJob.completed_at.desc())
        .first()
    )
    last_scan = last_scan_job.completed_at if last_scan_job else None

    return DashboardStats(
        total_projects=total_projects,
        active_projects=active_projects,
        total_images=total_images,
        active_images=active_images,
        total_vulnerabilities=total_vulnerabilities,
        critical_vulnerabilities=active_by_severity.get(VulnerabilitySeverity.CRITICAL.value, 0),
        high_vulnerabilities=active_by_severity.get(VulnerabilitySeverity.HIGH.value, 0),
        medium_vulnerabilities=active_by_severity.get(VulnerabilitySeverity.MEDIUM.value, 0),
        low_vulnerabilities=active_by_severity.get(VulnerabilitySeverity.LOW.value, 0),
        last_scan=last_scan,
    )
@app.get("/projects", response_model=List[ProjectResponse])
async def get_projects(
    skip: int = 0,
    limit: int = 100,
    active_only: bool = True,
    include_vulnerability_counts: bool = False,
    db: Session = Depends(get_db)
):
    """List projects, paginated with skip/limit.

    With active_only (the default) inactive projects are left out. With
    include_vulnerability_counts each project additionally carries the number
    of active vulnerabilities, per severity and in total, found in the active
    images its files use.
    """
    from sqlalchemy import func

    query = db.query(Project)
    if active_only:
        query = query.filter(Project.is_active == True)
    projects = query.offset(skip).limit(limit).all()
    if not include_vulnerability_counts:
        return projects

    severities = ('critical', 'high', 'medium', 'low', 'unspecified')

    # Add vulnerability counts for each project
    result = []
    for project in projects:
        # Get all images for this project
        project_images = db.query(Image).join(FileImageUsage).join(File).filter(
            File.project_id == project.id,
            Image.is_active == True,
            FileImageUsage.is_active == True
        ).distinct().all()
        image_ids = [image.id for image in project_images]

        # Every severity is always present, even when nothing was found.
        vulnerability_counts = dict.fromkeys(severities, 0)
        if image_ids:
            # One grouped query instead of one COUNT query per severity.
            severity_rows = (
                db.query(Vulnerability.severity, func.count())
                .filter(
                    Vulnerability.image_id.in_(image_ids),
                    Vulnerability.is_active == True,
                    Vulnerability.severity.in_(severities),
                )
                .group_by(Vulnerability.severity)
                .all()
            )
            for severity, count in severity_rows:
                if severity in vulnerability_counts:
                    vulnerability_counts[severity] += count
        # 'total' is the sum of the five listed severities only.
        vulnerability_counts['total'] = sum(vulnerability_counts.values())

        result.append({
            "id": project.id,
            "gitlab_id": project.gitlab_id,
            "name": project.name,
            "path": project.path,
            "web_url": project.web_url,
            "last_scanned": project.last_scanned,
            "is_active": project.is_active,
            "created_at": project.created_at,
            "updated_at": project.updated_at,
            "vulnerability_counts": vulnerability_counts
        })
    return result
@app.get("/projects/{project_id}", response_model=ProjectResponse)
async def get_project(project_id: int, db: Session = Depends(get_db)):
    # Fetch a single project by its primary key; an unknown id yields HTTP 404.
    lookup = db.query(Project).filter(Project.id == project_id)
    found = lookup.first()
    if found:
        return found
    raise HTTPException(status_code=404, detail="Project not found")
@app.get("/projects/{project_id}/files", response_model=List[FileResponse])
async def get_project_files(
    project_id: int,
    skip: int = 0,
    limit: int = 100,
    include_image_counts: bool = False,
    db: Session = Depends(get_db)
):
    """List the active files of a project, paginated with skip/limit.

    With include_image_counts each file additionally carries the number of
    distinct active images it uses and the number of active vulnerabilities,
    per severity and in total, found in those images.
    """
    from sqlalchemy import func

    files = db.query(File).filter(
        File.project_id == project_id,
        File.is_active == True
    ).offset(skip).limit(limit).all()
    if not include_image_counts:
        return files

    severities = ('critical', 'high', 'medium', 'low', 'unspecified')

    # Add image and vulnerability counts for each file
    result = []
    for file in files:
        # Distinct active images referenced by this file
        distinct_images = db.query(Image).join(FileImageUsage).filter(
            FileImageUsage.file_id == file.id,
            FileImageUsage.is_active == True,
            Image.is_active == True
        ).distinct().all()
        image_ids = [image.id for image in distinct_images]

        # Every severity is always present, even when nothing was found.
        vulnerability_counts = dict.fromkeys(severities, 0)
        if image_ids:
            # One grouped query instead of one COUNT query per severity.
            severity_rows = (
                db.query(Vulnerability.severity, func.count())
                .filter(
                    Vulnerability.image_id.in_(image_ids),
                    Vulnerability.is_active == True,
                    Vulnerability.severity.in_(severities),
                )
                .group_by(Vulnerability.severity)
                .all()
            )
            for severity, count in severity_rows:
                if severity in vulnerability_counts:
                    vulnerability_counts[severity] += count
        # 'total' is the sum of the five listed severities only.
        vulnerability_counts['total'] = sum(vulnerability_counts.values())

        result.append({
            "id": file.id,
            "project_id": file.project_id,
            "file_path": file.file_path,
            "branch": file.branch,
            "file_type": file.file_type,
            "last_scanned": file.last_scanned,
            "is_active": file.is_active,
            "created_at": file.created_at,
            "updated_at": file.updated_at,
            "image_count": len(image_ids),
            "vulnerability_counts": vulnerability_counts
        })
    return result
@app.get("/projects/{project_id}/images", response_model=List[ImageResponse])
async def get_project_images(
    project_id: int,
    skip: int = 0,
    limit: int = 100,
    include_vulnerability_counts: bool = False,
    db: Session = Depends(get_db)
):
    """List the distinct active images used by a project's files, paginated with skip/limit.

    Every image carries its usage count. With include_vulnerability_counts it
    additionally carries its number of active vulnerabilities, per severity
    and in total.
    """
    from sqlalchemy import func

    images = db.query(Image).join(FileImageUsage).join(File).filter(
        File.project_id == project_id,
        Image.is_active == True,
        FileImageUsage.is_active == True
    ).distinct().offset(skip).limit(limit).all()

    severities = ('critical', 'high', 'medium', 'low', 'unspecified')
    image_ids = [image.id for image in images]
    usage_by_image = {}
    severity_counts_by_image = {}
    if image_ids:
        # Usage counts for the whole page in one grouped query (was one query
        # per image). As before, usages are counted across all files, not
        # only those of this project.
        usage_rows = (
            db.query(FileImageUsage.image_id, func.count())
            .filter(
                FileImageUsage.image_id.in_(image_ids),
                FileImageUsage.is_active == True,
            )
            .group_by(FileImageUsage.image_id)
            .all()
        )
        usage_by_image = {image_id: count for image_id, count in usage_rows}

        if include_vulnerability_counts:
            # Severity counts for the whole page in one grouped query (was
            # five queries per image).
            severity_rows = (
                db.query(Vulnerability.image_id, Vulnerability.severity, func.count())
                .filter(
                    Vulnerability.image_id.in_(image_ids),
                    Vulnerability.is_active == True,
                    Vulnerability.severity.in_(severities),
                )
                .group_by(Vulnerability.image_id, Vulnerability.severity)
                .all()
            )
            for image_id, severity, count in severity_rows:
                per_image = severity_counts_by_image.setdefault(image_id, dict.fromkeys(severities, 0))
                if severity in per_image:
                    per_image[severity] += count

    # Add usage count and optionally vulnerability counts for each image
    result = []
    for image in images:
        image_dict = {
            "id": image.id,
            "image_name": image.image_name,
            "tag": image.tag,
            "registry": image.registry,
            "full_image_name": image.full_image_name,
            "last_seen": image.last_seen,
            "is_active": image.is_active,
            "created_at": image.created_at,
            "updated_at": image.updated_at,
            "usage_count": usage_by_image.get(image.id, 0)
        }
        if include_vulnerability_counts:
            # Images without findings still report every severity as 0.
            vulnerability_counts = dict(
                severity_counts_by_image.get(image.id) or dict.fromkeys(severities, 0)
            )
            # 'total' is the sum of the five listed severities only.
            vulnerability_counts['total'] = sum(vulnerability_counts.values())
            image_dict["vulnerability_counts"] = vulnerability_counts
        result.append(image_dict)
    return result
@app.get("/images", response_model=List[ImageResponse])
async def get_images(
    skip: int = 0,
    limit: int = 100,
    active_only: bool = True,
    db: Session = Depends(get_db)
):
    """List images, paginated with skip/limit.

    With active_only (the default) inactive images are left out. Every image
    carries its usage count, i.e. the number of active file usages that
    reference it.
    """
    from sqlalchemy import func

    query = db.query(Image)
    if active_only:
        query = query.filter(Image.is_active == True)
    images = query.offset(skip).limit(limit).all()

    # Usage counts for the whole page in one grouped query instead of one
    # COUNT query per image.
    usage_by_image = {}
    image_ids = [image.id for image in images]
    if image_ids:
        usage_rows = (
            db.query(FileImageUsage.image_id, func.count())
            .filter(
                FileImageUsage.image_id.in_(image_ids),
                FileImageUsage.is_active == True,
            )
            .group_by(FileImageUsage.image_id)
            .all()
        )
        usage_by_image = {image_id: count for image_id, count in usage_rows}

    # Images without any active usage are absent from the grouped result: 0.
    return [
        {
            "id": image.id,
            "image_name": image.image_name,
            "tag": image.tag,
            "registry": image.registry,
            "full_image_name": image.full_image_name,
            "last_seen": image.last_seen,
            "is_active": image.is_active,
            "created_at": image.created_at,
            "updated_at": image.updated_at,
            "usage_count": usage_by_image.get(image.id, 0)
        }
        for image in images
    ]
@app.get("/images/{image_id}", response_model=ImageResponse)
async def get_image(
    image_id: int,
    include_vulnerability_counts: bool = False,
    db: Session = Depends(get_db)
):
    """Return one image with its usage count; responds 404 for an unknown id.

    With include_vulnerability_counts the image additionally carries its
    number of active vulnerabilities, per severity and in total.
    """
    image = db.query(Image).filter(Image.id == image_id).first()
    if not image:
        raise HTTPException(status_code=404, detail="Image not found")

    usage_count = db.query(FileImageUsage).filter(
        FileImageUsage.image_id == image.id,
        FileImageUsage.is_active == True
    ).count()
    result = {
        "id": image.id,
        "image_name": image.image_name,
        "tag": image.tag,
        "registry": image.registry,
        "full_image_name": image.full_image_name,
        "last_seen": image.last_seen,
        "is_active": image.is_active,
        "created_at": image.created_at,
        "updated_at": image.updated_at,
        "usage_count": usage_count
    }
    if include_vulnerability_counts:
        from sqlalchemy import func

        severities = ('critical', 'high', 'medium', 'low', 'unspecified')
        # Every severity is always present, even when nothing was found.
        vulnerability_counts = dict.fromkeys(severities, 0)
        # One grouped query instead of one COUNT query per severity.
        severity_rows = (
            db.query(Vulnerability.severity, func.count())
            .filter(
                Vulnerability.image_id == image.id,
                Vulnerability.is_active == True,
                Vulnerability.severity.in_(severities),
            )
            .group_by(Vulnerability.severity)
            .all()
        )
        for severity, count in severity_rows:
            if severity in vulnerability_counts:
                vulnerability_counts[severity] += count
        # 'total' is the sum of the five listed severities only.
        vulnerability_counts['total'] = sum(vulnerability_counts.values())
        result["vulnerability_counts"] = vulnerability_counts
    return result
@app.get("/images/{image_id}/vulnerabilities", response_model=List[VulnerabilityResponse])
async def get_image_vulnerabilities(
    image_id: int,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db)
):
    # Active vulnerabilities recorded for one image, paginated with skip/limit.
    # An unknown image id simply produces an empty list.
    active_for_image = (
        db.query(Vulnerability)
        .filter(Vulnerability.image_id == image_id)
        .filter(Vulnerability.is_active == True)
    )
    page = active_for_image.offset(skip).limit(limit)
    return page.all()
@app.get("/vulnerabilities", response_model=List[VulnerabilityResponse])
async def get_vulnerabilities(
    skip: int = 0,
    limit: int = 100,
    severity: Optional[str] = None,
    db: Session = Depends(get_db)
):
    # Active vulnerabilities across all images, paginated with skip/limit.
    # A non-empty `severity` narrows the result to that exact severity value.
    criteria = [Vulnerability.is_active == True]
    if severity:
        criteria.append(Vulnerability.severity == severity)
    return db.query(Vulnerability).filter(*criteria).offset(skip).limit(limit).all()
@app.get("/ignore-rules", response_model=List[IgnoreRuleResponse])
async def get_ignore_rules(
    ignore_type: Optional[str] = None,
    project_id: Optional[int] = None,
    db: Session = Depends(get_db)
):
    # List ignore rules through IgnoreManager, optionally narrowed by type and project.
    manager = IgnoreManager(db)

    # An empty or missing ignore_type means "no type filter"; anything else
    # must be a valid IgnoreType value, otherwise the request is rejected.
    parsed_type = None
    if ignore_type:
        try:
            parsed_type = IgnoreType(ignore_type)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid ignore type")

    return manager.get_ignore_rules(parsed_type, project_id)
@app.post("/ignore-rules", response_model=IgnoreRuleResponse)
async def create_ignore_rule(rule: IgnoreRuleCreate, db: Session = Depends(get_db)):
    # Create an ignore rule through IgnoreManager and return the stored rule.
    manager = IgnoreManager(db)

    # The textual type from the request body must map onto an IgnoreType member.
    try:
        parsed_type = IgnoreType(rule.ignore_type)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid ignore type")

    rule_fields = {
        "ignore_type": parsed_type,
        "target": rule.target,
        "reason": rule.reason,
        "created_by": rule.created_by,
        "project_id": rule.project_id,
    }
    return manager.add_ignore_rule(**rule_fields)
@app.delete("/ignore-rules/{rule_id}")
async def delete_ignore_rule(rule_id: int, db: Session = Depends(get_db)):
    # Remove an ignore rule; a falsy result from IgnoreManager is reported as 404.
    removed = IgnoreManager(db).remove_ignore_rule(rule_id)
    if removed:
        return {"message": "Ignore rule deleted successfully"}
    raise HTTPException(status_code=404, detail="Ignore rule not found")
@app.post("/scan/projects")
async def scan_projects(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    # Queue a project ("discovery") scan as a background task.
    # Responses: 500 when GITLAB_TOKEN is unset or the scan cannot be queued,
    # 409 while any scan job is still pending or running.
    token = os.getenv("GITLAB_TOKEN")
    base_url = os.getenv("GITLAB_URL", "https://gitlab.com")
    raw_groups = os.getenv("GITLAB_GROUPS", "")

    if not token:
        raise HTTPException(status_code=500, detail="GitLab token not configured")

    # Only one scan at a time: refuse while another job is pending or running.
    unfinished_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()
    if unfinished_jobs:
        busy_types = ', '.join([unfinished.job_type for unfinished in unfinished_jobs])
        raise HTTPException(
            status_code=409,
            detail=f"Cannot start scan: {busy_types} scan(s) already running"
        )

    # GITLAB_GROUPS is a comma-separated list; blank entries are dropped.
    group_names = []
    for piece in raw_groups.split(","):
        cleaned = piece.strip()
        if cleaned:
            group_names.append(cleaned)

    try:
        # This scanner instance is only used to create the job record; the
        # background task builds its own scanner.
        job = DockerImageScanner(token, base_url, group_names).create_scan_job(db, "discovery")

        background_tasks.add_task(background_project_scan, token, base_url, group_names, job.id)

        if group_names:
            group_info = f" (groups: {', '.join(group_names)})"
        else:
            group_info = " (all projects)"
        return {
            "message": f"Project scan started{group_info}",
            "job_id": job.id,
            "status": "pending"
        }
    except Exception as e:
        print(f"Scan error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to start scan: {str(e)}")
@app.post("/scan/vulnerabilities")
async def scan_vulnerabilities(background_tasks: BackgroundTasks, db: Session = Depends(get_db)):
    # Queue a vulnerability scan as a background task.
    # Responses: 409 while any scan job is still pending or running,
    # 500 when the background task cannot be queued.

    # Only one scan at a time: refuse while another job is pending or running.
    unfinished_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()
    if unfinished_jobs:
        busy_types = ', '.join([unfinished.job_type for unfinished in unfinished_jobs])
        raise HTTPException(
            status_code=409,
            detail=f"Cannot start scan: {busy_types} scan(s) already running"
        )

    # Persist the job first so that its id can be handed to the task.
    job = ScanJob(job_type="vulnerability", status="pending")
    db.add(job)
    db.commit()
    db.refresh(job)

    try:
        background_tasks.add_task(background_vulnerability_scan, job.id)
        response = {
            "message": "Vulnerability scan started",
            "job_id": job.id,
            "status": "pending"
        }
        return response
    except Exception as e:
        # Queuing failed: close the job so it does not block future scans.
        job.status = "failed"
        job.error_message = str(e)
        job.completed_at = datetime.utcnow()
        db.commit()
        raise HTTPException(status_code=500, detail=f"Failed to start vulnerability scan: {str(e)}")
@app.get("/scan/jobs", response_model=List[dict])
async def get_scan_jobs(
    skip: int = 0,
    limit: int = 50,
    db: Session = Depends(get_db)
):
    # Scan jobs, newest first by creation time, paginated with skip/limit.
    newest_first = db.query(ScanJob).order_by(ScanJob.created_at.desc())

    # Attributes copied from each job, in the order they appear in the response.
    exposed_fields = (
        "id",
        "job_type",
        "status",
        "project_id",
        "started_at",
        "completed_at",
        "error_message",
        "created_at",
    )
    payload = []
    for job in newest_first.offset(skip).limit(limit).all():
        payload.append({field: getattr(job, field) for field in exposed_fields})
    return payload
@app.get("/scan/jobs/{job_id}")
async def get_scan_job(job_id: int, db: Session = Depends(get_db)):
    # Details of one scan job; an unknown id yields HTTP 404.
    job = db.query(ScanJob).filter(ScanJob.id == job_id).first()
    if not job:
        raise HTTPException(status_code=404, detail="Scan job not found")

    # Attributes copied from the job, in the order they appear in the response.
    exposed_fields = (
        "id",
        "job_type",
        "status",
        "project_id",
        "started_at",
        "completed_at",
        "error_message",
        "created_at",
    )
    return {field: getattr(job, field) for field in exposed_fields}
@app.get("/scan/status")
async def get_scan_status(db: Session = Depends(get_db)):
    """Check if there are any running or pending scans"""
    unfinished_jobs = db.query(ScanJob).filter(
        ScanJob.status.in_(["pending", "running"])
    ).all()

    # Short summary of every job that has not finished yet.
    summaries = []
    for job in unfinished_jobs:
        summaries.append({
            "id": job.id,
            "job_type": job.job_type,
            "status": job.status,
            "started_at": job.started_at,
            "created_at": job.created_at,
        })

    return {
        "has_running_scans": len(unfinished_jobs) > 0,
        "running_jobs": summaries
    }
@app.get("/gitlab/groups")
async def get_gitlab_groups():
    """List the GitLab groups visible to the configured token.

    Responds with 500 when GITLAB_TOKEN is not configured or when the groups
    cannot be fetched from GitLab.
    """
    gitlab_token = os.getenv("GITLAB_TOKEN")
    gitlab_url = os.getenv("GITLAB_URL", "https://gitlab.com")
    if not gitlab_token:
        raise HTTPException(status_code=500, detail="GitLab token not configured")

    def _fetch_groups():
        # Fetch every group and reduce each one to plain, serialisable fields.
        import gitlab

        gl = gitlab.Gitlab(gitlab_url, private_token=gitlab_token)
        return [
            {
                "id": group.id,
                "name": group.name,
                "path": group.path,
                "full_path": group.full_path,
                # 'description' is read defensively; fall back to an empty string.
                "description": getattr(group, 'description', ''),
            }
            for group in gl.groups.list(all=True, simple=True)
        ]

    try:
        # python-gitlab performs blocking HTTP requests (all=True walks every
        # page). Run them in the default executor so the event loop keeps
        # serving other requests in the meantime.
        return await asyncio.get_running_loop().run_in_executor(None, _fetch_groups)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch groups: {str(e)}") from e
def main():
    """Serve the module-level ``asgi_app`` with uvicorn.

    Listens on all interfaces (0.0.0.0), port 5000, with auto-reload enabled.
    """
    # uvicorn is imported lazily: it is only needed when run as a script.
    from uvicorn import run as serve

    serve("main:asgi_app", host="0.0.0.0", port=5000, reload=True)
# Create the ASGI app at module level for uvicorn.
# main() refers to it by the import string "main:asgi_app"; it is the object
# returned by websocket_manager.get_asgi_app(app), not the bare FastAPI `app`.
asgi_app = websocket_manager.get_asgi_app(app)
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()