#!/usr/bin/env python3
"""Aggregate grading.json and timing.json files into a benchmark.json summary.

Usage:
    python3 aggregate_benchmark.py <iteration-dir>
    python3 aggregate_benchmark.py <workspace-dir> --multi-run

Single iteration mode:
    Reads all grading.json and timing.json from an iteration directory,
    produces benchmark.json with pass rates, tokens, and timing per config.

Multi-run mode (--multi-run):
    Reads benchmark.json from multiple iteration directories,
    computes mean ± stddev across runs.
"""

import argparse
import sys
from pathlib import Path

# Add script directory to path for utils import
sys.path.insert(0, str(Path(__file__).parent))
from utils import (
    load_json, save_json, collect_grading_files, collect_timing_files,
    mean, stddev, now_iso
)


def _iteration_number(dir_name):
    """Return the iteration number encoded in a directory name like ``iteration-3``.

    The number is taken from the second ``-``-separated field of *dir_name*.
    Falls back to 1 when the name contains no ``-`` or when that field is not
    an integer (e.g. ``my-run``), instead of raising ``ValueError``.
    """
    fields = dir_name.split("-")
    if len(fields) < 2:
        return 1
    try:
        return int(fields[1])
    except ValueError:
        return 1


def _overhead_percent(value, baseline):
    """Return the relative increase of *value* over *baseline* in percent.

    Returns 0 when *baseline* is not positive, so a missing or zero baseline
    never causes a division by zero.
    """
    if baseline > 0:
        return (value - baseline) / baseline * 100
    return 0


def aggregate_iteration(iteration_dir):
    """Aggregate results from a single iteration into benchmark.json.

    Gradings and timings are collected from *iteration_dir* with
    ``collect_grading_files`` / ``collect_timing_files`` and grouped by their
    ``"config"`` field (``"unknown"`` when absent). Timings whose config has no
    grading are ignored. For every config the summary holds the mean pass
    rate, mean/stddev of tokens and duration, and one entry per graded eval.
    When both ``with_skill`` and ``without_skill`` are present, a
    ``comparison`` section is added as well.

    Args:
        iteration_dir: Path (str or Path) of the iteration directory. Its name
            supplies the iteration number; see ``_iteration_number``.

    Returns:
        The benchmark dict, which is also written to
        ``<iteration_dir>/benchmark.json`` via ``save_json``.

    Raises:
        SystemExit: With status 1 when no grading files are found.
        KeyError: When a grading or timing record lacks a required field.
    """
    iteration_dir = Path(iteration_dir)
    gradings = collect_grading_files(iteration_dir)
    timings = collect_timing_files(iteration_dir)

    if not gradings:
        print(f"No grading.json files found in {iteration_dir}", file=sys.stderr)
        sys.exit(1)

    # Group records by config. Only gradings may introduce a config; a timing
    # without any grading for its config has nothing to be reported against.
    configs = {}
    for g in gradings:
        bucket = configs.setdefault(
            g.get("config", "unknown"), {"gradings": [], "timings": []}
        )
        bucket["gradings"].append(g)

    for t in timings:
        bucket = configs.get(t.get("config", "unknown"))
        if bucket is not None:
            bucket["timings"].append(t)

    benchmark = {
        "iteration": _iteration_number(iteration_dir.name),
        "timestamp": now_iso(),
        "configs": {},
        "comparison": {},
    }

    for config_name, data in configs.items():
        pass_rates = [g["pass_rate"] for g in data["gradings"]]
        tokens = [t["total_tokens"] for t in data["timings"]]
        durations = [t["duration_ms"] for t in data["timings"]]

        # Index timings by eval_id once instead of rescanning the list for
        # every grading. setdefault keeps the first timing seen per eval_id.
        timing_by_eval = {}
        for t in data["timings"]:
            timing_by_eval.setdefault(t["eval_id"], t)

        evals = []
        for g in data["gradings"]:
            eval_entry = {
                "eval_id": g["eval_id"],
                "pass_rate": g["pass_rate"],
                "pass_count": g["pass_count"],
                "fail_count": g["fail_count"],
            }
            timing = timing_by_eval.get(g["eval_id"])
            if timing is not None:
                eval_entry["tokens"] = timing["total_tokens"]
                eval_entry["duration_ms"] = timing["duration_ms"]
            evals.append(eval_entry)

        benchmark["configs"][config_name] = {
            "overall_pass_rate": mean(pass_rates),
            # A config may have gradings but no timings; report 0 in that case.
            "total_tokens_mean": mean(tokens) if tokens else 0,
            "total_tokens_stddev": stddev(tokens) if tokens else 0,
            "duration_ms_mean": mean(durations) if durations else 0,
            "duration_ms_stddev": stddev(durations) if durations else 0,
            "evals": evals,
        }

    # Compute comparison if both configs exist
    summaries = benchmark["configs"]
    if "with_skill" in summaries and "without_skill" in summaries:
        ws = summaries["with_skill"]
        wo = summaries["without_skill"]

        # Evals that pass fully under both configs do not discriminate
        # between them. Listed in with_skill order.
        ws_evals = {e["eval_id"]: e for e in ws["evals"]}
        wo_evals = {e["eval_id"]: e for e in wo["evals"]}
        non_disc = [
            eval_id
            for eval_id, entry in ws_evals.items()
            if eval_id in wo_evals
            and entry["pass_rate"] == 1.0
            and wo_evals[eval_id]["pass_rate"] == 1.0
        ]

        benchmark["comparison"] = {
            "pass_rate_delta": ws["overall_pass_rate"] - wo["overall_pass_rate"],
            "token_overhead_percent": _overhead_percent(
                ws["total_tokens_mean"], wo["total_tokens_mean"]
            ),
            "time_overhead_percent": _overhead_percent(
                ws["duration_ms_mean"], wo["duration_ms_mean"]
            ),
            "non_discriminating_evals": non_disc,
        }

    output_path = iteration_dir / "benchmark.json"
    save_json(output_path, benchmark)
    return benchmark


def aggregate_multi_run(workspace_dir):
    """Aggregate benchmark.json files from multiple iterations for statistical analysis.

    Loads ``benchmark.json`` from every ``iteration-*`` subdirectory of
    *workspace_dir* that has one. For each config name found in those files it
    computes mean and stddev of ``overall_pass_rate``, ``total_tokens_mean``
    and ``duration_ms_mean`` across the runs that contain that config.

    Every config name present in the benchmarks is summarised, not only
    ``with_skill`` / ``without_skill``, so the result stays consistent with
    the per-iteration files, which may carry arbitrary config names.

    Args:
        workspace_dir: Path (str or Path) of the directory holding the
            ``iteration-*`` subdirectories.

    Returns:
        The summary dict, which is also written to
        ``<workspace_dir>/benchmark.json`` via ``save_json``.

    Raises:
        SystemExit: With status 1 when fewer than two iteration benchmarks
            are found.
    """
    workspace_dir = Path(workspace_dir)

    benchmarks = []
    for d in sorted(workspace_dir.iterdir()):
        if d.is_dir() and d.name.startswith("iteration-"):
            bm_path = d / "benchmark.json"
            if bm_path.exists():
                benchmarks.append(load_json(bm_path))

    if len(benchmarks) < 2:
        print(f"Need at least 2 iteration benchmarks for multi-run analysis, found {len(benchmarks)}", file=sys.stderr)
        sys.exit(1)

    # The two standard configs come first so existing output keeps its key
    # order; any other config follows in first-seen order.
    config_names = ["with_skill", "without_skill"]
    for b in benchmarks:
        for name in b.get("configs", {}):
            if name not in config_names:
                config_names.append(name)

    result = {
        "timestamp": now_iso(),
        "num_runs": len(benchmarks),
        "configs": {},
    }

    # Aggregate per config, using only the runs that actually contain it.
    for config_name in config_names:
        runs = [
            b["configs"][config_name]
            for b in benchmarks
            if config_name in b.get("configs", {})
        ]
        if not runs:
            continue

        pass_rates = [r["overall_pass_rate"] for r in runs]
        tokens = [r["total_tokens_mean"] for r in runs]
        durations = [r["duration_ms_mean"] for r in runs]

        result["configs"][config_name] = {
            "pass_rate": {"mean": mean(pass_rates), "stddev": stddev(pass_rates)},
            "tokens": {"mean": mean(tokens), "stddev": stddev(tokens)},
            "duration_ms": {"mean": mean(durations), "stddev": stddev(durations)},
        }

    output_path = workspace_dir / "benchmark.json"
    save_json(output_path, result)
    return result


def main():
    """Parse command-line arguments, run the aggregation, and print a summary.

    With ``--multi-run`` the path is treated as a workspace directory and
    passed to ``aggregate_multi_run``; otherwise it is treated as a single
    iteration directory and passed to ``aggregate_iteration``.
    """
    parser = argparse.ArgumentParser(description="Aggregate eval results into benchmark summary")
    parser.add_argument("path", help="Iteration directory or workspace directory (with --multi-run)")
    parser.add_argument("--multi-run", action="store_true", help="Aggregate across multiple iterations")
    args = parser.parse_args()

    if args.multi_run:
        result = aggregate_multi_run(args.path)
        print(f"\nMulti-run benchmark ({result['num_runs']} runs):")
        for config, stats in result["configs"].items():
            pr = stats["pass_rate"]
            print(f"  {config}: pass_rate={pr['mean']:.2f} ± {pr['stddev']:.2f}")
        return

    result = aggregate_iteration(args.path)
    print(f"\nIteration {result['iteration']} benchmark:")
    for config, stats in result["configs"].items():
        print(f"  {config}: pass_rate={stats['overall_pass_rate']:.2f}, "
              f"tokens={stats['total_tokens_mean']:.0f}, "
              f"time={stats['duration_ms_mean']:.0f}ms")

    # The comparison dict is empty unless both configs were present.
    comparison = result.get("comparison")
    if comparison:
        print("\n  Comparison:")
        # The "+" format flag emits the correct sign; a hard-coded "+" prefix
        # would render a negative delta as "+-0.10".
        print(f"    Pass rate delta: {comparison['pass_rate_delta']:+.2f}")
        print(f"    Token overhead: {comparison['token_overhead_percent']:.1f}%")
        print(f"    Time overhead: {comparison['time_overhead_percent']:.1f}%")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
