#!/usr/bin/env bash
set -euo pipefail

# Scans lockfiles for known vulnerabilities using Google's OSV database.
# Replaces pnpm audit (broken: npm retired legacy audit endpoints) and
# bundle-audit / ruby-audit with a single multi-ecosystem scanner.
#
# Usage:
#   bin/osv-audit [--level critical|high|medium|low] <lockfile> [lockfile...]
#
# Examples:
#   bin/osv-audit --level critical apps/website/pnpm-lock.yaml apps/website/Gemfile.lock
#   bin/osv-audit packages/ui/pnpm-lock.yaml

# Pinned osv-scanner release; used for both the download URL and the cache check.
OSV_SCANNER_VERSION="2.3.5"

#######################################
# Parses command-line arguments.
# Globals:
#   LEVEL     (written) severity threshold name; defaults to "critical".
#             The value is not validated here.
#   LOCKFILES (written) array of lockfile paths, in the order given.
# Arguments:
#   [--level VALUE | --level=VALUE] [-h|--help] [--] <lockfile>...
# Outputs:
#   Usage text to stdout for -h/--help; error messages to stderr.
# Returns:
#   Exits 0 after printing help, exits 1 on a malformed option.
#######################################
parse_args() {
  LEVEL="critical"
  LOCKFILES=()

  while [[ $# -gt 0 ]]; do
    case "$1" in
      --level)
        if [[ $# -lt 2 ]]; then
          echo "Error: --level requires a value (critical, high, medium, or low)" >&2
          exit 1
        fi
        LEVEL="$2"
        shift 2
        ;;
      --level=*)
        LEVEL="${1#--level=}"
        shift
        ;;
      -h | --help)
        echo "Usage: bin/osv-audit [--level critical|high|medium|low] <lockfile> [lockfile...]"
        exit 0
        ;;
      --)
        # Everything after "--" is a lockfile, even if it starts with "-".
        shift
        LOCKFILES+=("$@")
        break
        ;;
      -*)
        # A mistyped flag must not be silently scanned as if it were a lockfile.
        echo "Error: unknown option: $1" >&2
        exit 1
        ;;
      *)
        LOCKFILES+=("$1")
        shift
        ;;
    esac
  done
}

parse_args "$@"

# At least one lockfile is required. Usage goes to stderr so that callers
# capturing stdout never mistake it for scan output.
if [[ ${#LOCKFILES[@]} -eq 0 ]]; then
  echo "Usage: bin/osv-audit [--level critical|high|medium|low] <lockfile> [lockfile...]" >&2
  exit 1
fi

# Map level to minimum CVSS score. Each level's threshold is the lower bound
# of that band; the scan fails on any score at or above it.
case "$LEVEL" in
  critical) MIN_CVSS=9.0 ;;
  high)     MIN_CVSS=7.0 ;;
  medium)   MIN_CVSS=4.0 ;;
  low)      MIN_CVSS=0.1 ;;
  *)
    echo "Unknown level: $LEVEL (expected critical, high, medium, or low)" >&2
    exit 1
    ;;
esac

#######################################
# Ensures the pinned osv-scanner release is installed and prints its path.
#
# A cached binary is reused when its `--version` output contains
# OSV_SCANNER_VERSION. Otherwise the release asset for this OS/arch is
# downloaded to a temporary file, verified against the release's SHA256SUMS,
# and only then moved into place, so an unverified or partial download never
# sits at the final path.
#
# Globals:
#   OSV_SCANNER_VERSION (read) release version to install.
#   OSV_SCANNER_DIR     (read, optional) install directory;
#                       defaults to /tmp/osv-scanner.
# Outputs:
#   Path of the binary on stdout; progress and errors on stderr.
# Returns:
#   Exits 1 on any download or verification failure. Call it inside $(...)
#   so that the exit only terminates the substitution.
#
# NOTE(review): the default directory is under /tmp and a cached binary is
# trusted on its --version output alone. On multi-user hosts point
# OSV_SCANNER_DIR at a private directory.
#######################################
install_osv_scanner() {
  local install_dir="${OSV_SCANNER_DIR:-/tmp/osv-scanner}"
  local binary="$install_dir/osv-scanner"

  # Capture the output instead of piping into `grep -q`: the version is
  # compared literally (no regex dots) and pipefail cannot see a SIGPIPE.
  local reported_version
  if [[ -x "$binary" ]] \
    && reported_version=$("$binary" --version 2>/dev/null) \
    && [[ "$reported_version" == *"$OSV_SCANNER_VERSION"* ]]; then
    echo "$binary"
    return
  fi

  # errexit is cleared inside command substitutions, so every step that can
  # fail is checked explicitly.
  if ! mkdir -p "$install_dir"; then
    echo "Error: could not create install directory $install_dir" >&2
    exit 1
  fi

  local os arch platform
  os=$(uname -s | tr '[:upper:]' '[:lower:]')
  arch=$(uname -m)
  case "$arch" in
    x86_64) arch="amd64" ;;
    aarch64 | arm64) arch="arm64" ;;
  esac
  platform="${os}_${arch}"

  local base_url="https://github.com/google/osv-scanner/releases/download/v${OSV_SCANNER_VERSION}"
  local binary_name="osv-scanner_${platform}"
  local sums_file="${install_dir}/SHA256SUMS"

  local download
  if ! download=$(mktemp "${install_dir}/osv-scanner.download.XXXXXX"); then
    echo "Error: could not create a temporary file in $install_dir" >&2
    exit 1
  fi

  echo "Installing osv-scanner v${OSV_SCANNER_VERSION}..." >&2
  # -f: fail on HTTP errors instead of saving the error page as the binary.
  if ! curl -fsSL "${base_url}/${binary_name}" -o "$download"; then
    echo "Error: failed to download ${base_url}/${binary_name}" >&2
    rm -f -- "$download"
    exit 1
  fi
  if ! curl -fsSL "${base_url}/osv-scanner_SHA256SUMS" -o "$sums_file"; then
    echo "Error: failed to download ${base_url}/osv-scanner_SHA256SUMS" >&2
    rm -f -- "$download"
    exit 1
  fi

  # Match the file-name column exactly (allowing the "*" binary-mode marker)
  # so assets whose names merely contain $binary_name are ignored.
  local expected_checksum
  expected_checksum=$(awk -v name="$binary_name" \
    '$2 == name || $2 == ("*" name) { print $1; exit }' "$sums_file")
  if [[ -z "$expected_checksum" ]]; then
    echo "Error: no checksum found for $binary_name in release SHA256SUMS" >&2
    rm -f -- "$download"
    exit 1
  fi

  local actual_checksum
  if command -v sha256sum &>/dev/null; then
    actual_checksum=$(sha256sum "$download" | awk '{print $1}')
  else
    actual_checksum=$(shasum -a 256 "$download" | awk '{print $1}')
  fi
  if [[ "$actual_checksum" != "$expected_checksum" ]]; then
    echo "Error: checksum mismatch for osv-scanner binary" >&2
    echo "  expected: $expected_checksum" >&2
    echo "  got:      $actual_checksum" >&2
    rm -f -- "$download"
    exit 1
  fi

  # mktemp creates the file owner-only; publish it with the usual binary mode.
  if ! chmod 755 "$download" || ! mv -f -- "$download" "$binary"; then
    echo "Error: could not install osv-scanner to $binary" >&2
    rm -f -- "$download"
    exit 1
  fi
  echo "$binary"
}

# Resolve the scanner binary (downloaded on first use); abort if that fails.
OSV_SCANNER=$(install_osv_scanner) || exit 1

if ! command -v jq >/dev/null 2>&1; then
  echo "Error: jq is required to parse osv-scanner output but was not found on PATH" >&2
  exit 1
fi

# Build lockfile args
LOCKFILE_ARGS=()
for lf in "${LOCKFILES[@]}"; do
  LOCKFILE_ARGS+=("-L" "$lf")
done

# Run a single JSON scan, then derive both the summary and severity check from it
echo "Scanning: ${LOCKFILES[*]}"
echo ""

# Keep the scanner's exit status instead of discarding it: a scanner that
# crashed or could not read its input must never be reported as a pass.
SCAN_STATUS=0
JSON_OUTPUT=$("$OSV_SCANNER" scan source -f json "${LOCKFILE_ARGS[@]}") || SCAN_STATUS=$?

# Status meanings follow osv-scanner's documented return codes
# (0 = no vulnerabilities, 1 = vulnerabilities found, 128 = no packages found,
# 127 and others = errors) -- TODO confirm when bumping OSV_SCANNER_VERSION.
case "$SCAN_STATUS" in
  0 | 1) ;;
  128)
    echo "Warning: osv-scanner found no packages in: ${LOCKFILES[*]}" >&2
    echo "PASSED: No vulnerabilities at '$LEVEL' level (CVSS >= $MIN_CVSS) or above."
    exit 0
    ;;
  *)
    echo "Error: osv-scanner exited with status $SCAN_STATUS; refusing to report a pass" >&2
    exit 1
    ;;
esac

# Summarise in one jq pass. -e makes jq fail when it produces no result
# (e.g. empty input), so missing or garbled output cannot turn into empty
# counts that compare as zero. `tonumber?` skips groups whose max_severity is
# null or not numeric; such groups are not counted in any band.
if ! COUNTS=$(printf '%s\n' "$JSON_OUTPUT" | jq -er --argjson min "$MIN_CVSS" '
  if type != "object" then error("osv-scanner output is not a JSON object") else . end
  | [.results[]?.packages[]?.groups[]?.max_severity | tonumber?] as $scores
  | [
      ([$scores[] | select(. >= 9.0)] | length),
      ([$scores[] | select(. >= 7.0 and . < 9.0)] | length),
      ([$scores[] | select(. >= 4.0 and . < 7.0)] | length),
      ([$scores[] | select(. > 0 and . < 4.0)] | length),
      ([$scores[] | select(. >= $min)] | length)
    ]
  | map(tostring) | join(" ")
'); then
  echo "Error: failed to parse osv-scanner output as a JSON report" >&2
  exit 1
fi

read -r CRITICAL HIGH MEDIUM LOW FAILING_COUNT <<<"$COUNTS"

# Belt and braces: every count must be a plain integer before it is compared.
for count in "$CRITICAL" "$HIGH" "$MEDIUM" "$LOW" "$FAILING_COUNT"; do
  if ! [[ "$count" =~ ^[0-9]+$ ]]; then
    echo "Error: unexpected summary from jq: '$COUNTS'" >&2
    exit 1
  fi
done

TOTAL=$((CRITICAL + HIGH + MEDIUM + LOW))

echo "Found $TOTAL vulnerability group(s): $CRITICAL critical, $HIGH high, $MEDIUM medium, $LOW low"
echo ""

if ((FAILING_COUNT > 0)); then
  # Print the table for detail on what's failing. The scanner exits non-zero
  # when it finds vulnerabilities, which is expected here, so its status is
  # deliberately ignored.
  "$OSV_SCANNER" scan source "${LOCKFILE_ARGS[@]}" || true
  echo ""
  echo "FAILED: $FAILING_COUNT vulnerability group(s) at '$LEVEL' level (CVSS >= $MIN_CVSS) or above."
  exit 1
fi

echo "PASSED: No vulnerabilities at '$LEVEL' level (CVSS >= $MIN_CVSS) or above."
