#!/usr/bin/env python3
"""Run trigger evaluation for a skill description.

This is an advanced, Windows-safe adaptation of the upstream script. It keeps
the same core idea but adds:
- cross-platform subprocess streaming (no select.select on pipes)
- richer JSON output and per-run diagnostics
- support for either a raw JSON array or {"queries": [...]} eval files
- clear schema errors when someone accidentally passes task evals.json
"""

from __future__ import annotations

import argparse
import json
import os
import queue
import shutil
import subprocess
import sys
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any

try:
    from scripts.utils import parse_skill_md, write_json
except ImportError:  # pragma: no cover - direct script execution fallback
    from utils import parse_skill_md, write_json


def find_project_root() -> Path:
    """Return the nearest directory, starting at the cwd, that holds a ``.claude/`` folder.

    The search walks from the current working directory up through its
    ancestors. When none of them contains ``.claude/``, the current working
    directory itself is returned.
    """
    start = Path.cwd()
    candidates = (start, *start.parents)
    return next(
        (candidate for candidate in candidates if (candidate / ".claude").is_dir()),
        start,
    )


def _coerce_should_trigger(value: Any, index: int) -> bool:
    """Interpret a ``should_trigger`` JSON value as a strict boolean.

    Plain ``bool()`` turns the string ``"false"`` into ``True``, so strings and
    0/1 numbers are parsed explicitly and anything else is rejected.

    Raises:
        ValueError: If the value is not a recognisable boolean.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)) and value in (0, 1):
        return bool(value)
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in {"true", "yes", "1"}:
            return True
        if lowered in {"false", "no", "0"}:
            return False
    raise ValueError(
        f"Eval item #{index} has an invalid 'should_trigger' value {value!r}; expected true or false"
    )


def normalize_eval_items(data: Any) -> list[dict[str, Any]]:
    """Normalize the supported trigger-eval JSON layouts into a uniform item list.

    Accepted inputs:
      * a JSON array of items;
      * ``{"queries": [...]}``;
      * ``{"evals": [...]}``, but only when every item carries ``should_trigger``.
        Task-style evals without it are rejected with a hint.

    Each returned item has these keys:
      * ``id``: a string. It falls back to the 1-based position when the id is
        missing, null or empty.
      * ``query``: stripped and non-empty.
      * ``should_trigger``: a bool.
      * ``notes``: a string.
      * ``tags``: a list of strings.

    Raises:
        ValueError: On an unsupported layout, a malformed item, or an empty set.
    """
    if isinstance(data, list):
        raw_items = data
    elif isinstance(data, dict) and isinstance(data.get("queries"), list):
        raw_items = data["queries"]
    elif isinstance(data, dict) and isinstance(data.get("evals"), list):
        evals = data["evals"]
        # The isinstance guard matters: `in` on a string item would do a substring test.
        if evals and not all(isinstance(item, dict) and "should_trigger" in item for item in evals):
            raise ValueError(
                "This looks like task evals (assets/evals/evals.json), not trigger evals. "
                "Use a JSON array or {'queries': [...]} with 'query' and 'should_trigger'."
            )
        raw_items = evals
    else:
        raise ValueError(
            "Trigger eval JSON must be either a list of items or an object with a 'queries' list."
        )

    normalized: list[dict[str, Any]] = []
    for index, item in enumerate(raw_items, start=1):
        if not isinstance(item, dict):
            raise ValueError(f"Eval item #{index} must be an object")

        raw_query = item.get("query")
        # Treat JSON null like a missing query, not as the literal text "None".
        query = "" if raw_query is None else str(raw_query).strip()
        if not query:
            raise ValueError(f"Eval item #{index} is missing a non-empty 'query'")

        if "should_trigger" not in item:
            raise ValueError(f"Eval item #{index} is missing 'should_trigger'")

        raw_id = item.get("id")
        # Only missing/null/empty ids fall back to the position, so id 0 is kept
        # instead of colliding with another item's positional id.
        item_id = str(index) if raw_id is None or raw_id == "" else str(raw_id)

        raw_notes = item.get("notes")
        raw_tags = item.get("tags")

        normalized.append(
            {
                "id": item_id,
                "query": query,
                "should_trigger": _coerce_should_trigger(item["should_trigger"], index),
                "notes": "" if raw_notes is None else str(raw_notes).strip(),
                "tags": [str(tag) for tag in raw_tags] if isinstance(raw_tags, list) else [],
            }
        )

    if not normalized:
        raise ValueError("Trigger eval set is empty")
    return normalized


def load_trigger_eval_set(path: Path) -> list[dict[str, Any]]:
    """Read a trigger-eval JSON file and return its normalized items.

    The file is decoded as ``utf-8-sig`` so that a leading UTF-8 byte-order mark
    is stripped. Some Windows editors may write one, and ``json.loads`` rejects it.
    Plain UTF-8 files decode unchanged.

    Raises:
        OSError: If the file cannot be read.
        ValueError: If the content is not valid JSON (``json.JSONDecodeError``)
            or not a valid trigger-eval set.
    """
    return normalize_eval_items(json.loads(path.read_text(encoding="utf-8-sig")))


def _build_command_file(
    project_root: Path,
    skill_name: str,
    skill_description: str,
    unique_suffix: str,
) -> tuple[Path, str]:
    clean_name = f"{skill_name}-skill-{unique_suffix}"
    command_dir = project_root / ".claude" / "commands"
    command_dir.mkdir(parents=True, exist_ok=True)
    command_path = command_dir / f"{clean_name}.md"

    indented_description = "\n  ".join(skill_description.splitlines() or [""])
    command_content = (
        "---\n"
        "description: |\n"
        f"  {indented_description}\n"
        "---\n\n"
        f"# {skill_name}\n\n"
        "Temporary trigger-eval command generated by skill-creator-advanced.\n"
    )
    command_path.write_text(command_content, encoding="utf-8")
    return command_path, clean_name


def _reader_thread(stdout: Any, line_queue: "queue.Queue[str | None]") -> None:
    try:
        for line in iter(stdout.readline, ""):
            line_queue.put(line)
    finally:
        line_queue.put(None)


def _inspect_stream_event(
    event: dict[str, Any],
    clean_name: str,
    pending_tool_name: str | None,
    accumulated_json: str,
    triggered: bool,
) -> tuple[bool | None, str | None, str | None, str]:
    """Inspect one JSON event line and optionally decide the result."""
    event_type = event.get("type")

    if event_type == "stream_event":
        stream_event = event.get("event", {})
        stream_type = stream_event.get("type", "")

        if stream_type == "content_block_start":
            content_block = stream_event.get("content_block", {})
            if content_block.get("type") == "tool_use":
                tool_name = content_block.get("name", "")
                if tool_name in {"Skill", "Read"}:
                    return None, tool_name, "", accumulated_json
                return False, pending_tool_name, f"other_tool:{tool_name or 'unknown'}", accumulated_json

        elif stream_type == "content_block_delta" and pending_tool_name:
            delta = stream_event.get("delta", {})
            if delta.get("type") == "input_json_delta":
                accumulated_json += delta.get("partial_json", "")
                if clean_name in accumulated_json:
                    return True, pending_tool_name, "trigger_detected", accumulated_json
            return None, pending_tool_name, None, accumulated_json

        elif stream_type in {"content_block_stop", "message_stop"}:
            if pending_tool_name:
                if clean_name in accumulated_json:
                    return True, None, "tool_block_stop", accumulated_json
                return False, None, "tool_block_stop", accumulated_json
            if stream_type == "message_stop":
                return triggered, None, "message_stop", accumulated_json

    elif event_type == "assistant":
        message = event.get("message", {})
        for content_item in message.get("content", []):
            if content_item.get("type") != "tool_use":
                continue
            tool_name = content_item.get("name", "")
            tool_input = content_item.get("input", {}) or {}
            if tool_name == "Skill" and clean_name in str(tool_input.get("skill", "")):
                return True, pending_tool_name, "assistant_tool_use", accumulated_json
            if tool_name == "Read" and clean_name in str(tool_input.get("file_path", "")):
                return True, pending_tool_name, "assistant_tool_use", accumulated_json
            return False, pending_tool_name, f"assistant_other_tool:{tool_name or 'unknown'}", accumulated_json

    elif event_type == "result":
        return triggered, pending_tool_name, "result", accumulated_json

    return None, pending_tool_name, None, accumulated_json


def run_single_query(
    query: str,
    skill_name: str,
    skill_description: str,
    timeout: int,
    project_root: Path,
    model: str | None = None,
    keep_command_files: bool = False,
) -> dict[str, Any]:
    """Run one ``claude -p`` query and report whether the skill description triggered.

    Steps:
      1. Write a temporary command file carrying ``skill_description`` under
         ``project_root/.claude/commands``.
      2. Launch the CLI with stream-json output.
      3. Inspect events until one of these happens: a decision is reached,
         the process exits, or ``timeout`` seconds elapse.
      4. Kill a still-running child immediately so the timeout is really
         enforced.

    Args:
        query: Prompt passed to ``claude -p``.
        skill_name: Skill name used to build the unique temporary command name.
        skill_description: Description under test.
        timeout: Wall-clock limit in seconds for streaming events.
        project_root: Working directory for the CLI; ``.claude/commands`` lives here.
        model: Optional ``--model`` value.
        keep_command_files: Keep the temporary command file for debugging.

    Returns:
        Dict with these keys:
          * ``triggered``
          * ``duration_seconds``
          * ``termination``: why the run stopped.
          * ``error``: a message, or ``None``.
          * ``parsed_events``
          * ``raw_lines``

    NOTE(review): concurrent runs share ``.claude/commands``, so a run may see
    sibling runs' temporary commands that have identical descriptions. A tool
    call naming a sibling's command counts as "not triggered". Confirm this
    bias is acceptable for the worker count in use.
    """
    unique_suffix = uuid.uuid4().hex[:8]
    command_path, clean_name = _build_command_file(project_root, skill_name, skill_description, unique_suffix)
    start_time = time.monotonic()
    stdout_queue: "queue.Queue[str | None]" = queue.Queue()
    stderr_queue: "queue.Queue[str | None]" = queue.Queue()
    readers: list[tuple[threading.Thread, Any]] = []
    triggered = False
    pending_tool_name: str | None = None
    accumulated_json = ""
    termination = "process_exit"
    error_message: str | None = None
    parsed_events = 0
    raw_lines = 0
    stopped_early = False

    # Resolve through PATH (and PATHEXT on Windows) so npm's `claude.cmd` shim is
    # found; CreateProcess only appends `.exe` to a bare program name.
    # NOTE(review): a .cmd shim runs via cmd.exe, which may mangle queries that
    # contain cmd metacharacters (&, |, ^, %). Confirm for your eval set.
    cmd = [
        shutil.which("claude") or "claude",
        "-p",
        query,
        "--output-format",
        "stream-json",
        "--verbose",
        "--include-partial-messages",
    ]
    if model:
        cmd.extend(["--model", model])

    # CLAUDECODE is removed from the child's environment.
    # NOTE(review): presumably so the CLI does not treat itself as nested; confirm.
    env = {key: value for key, value in os.environ.items() if key != "CLAUDECODE"}
    process: subprocess.Popen[str] | None = None

    try:
        try:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                cwd=project_root,
                env=env,
                text=True,
                encoding="utf-8",
                errors="replace",
                bufsize=1,
            )
        except FileNotFoundError as exc:
            raise RuntimeError(
                "Cannot find 'claude' CLI. Install/configure Claude Code before running trigger evals."
            ) from exc

        # Drain BOTH pipes on background threads. If stderr were only read after
        # the loop, a child that fills the stderr pipe buffer would block and stop
        # writing stdout, stalling the run until the timeout.
        for stream, sink in ((process.stdout, stdout_queue), (process.stderr, stderr_queue)):
            reader = threading.Thread(target=_reader_thread, args=(stream, sink), daemon=True)
            reader.start()
            readers.append((reader, stream))

        saw_sentinel = False
        while True:
            if time.monotonic() - start_time >= timeout:
                termination = "timeout"
                break

            # After the stdout sentinel, wait only for the process to be reaped.
            if saw_sentinel and process.poll() is not None:
                break

            try:
                line = stdout_queue.get(timeout=0.2)
            except queue.Empty:
                continue

            if line is None:
                saw_sentinel = True
                continue

            raw_lines += 1
            stripped = line.strip()
            if not stripped:
                continue

            try:
                event = json.loads(stripped)
            except json.JSONDecodeError:
                continue
            if not isinstance(event, dict):
                # Valid JSON that is not an event object (e.g. a bare number): skip it.
                continue

            parsed_events += 1
            decision, pending_tool_name, reason, accumulated_json = _inspect_stream_event(
                event=event,
                clean_name=clean_name,
                pending_tool_name=pending_tool_name,
                accumulated_json=accumulated_json,
                triggered=triggered,
            )
            if decision is not None:
                triggered = bool(decision)
                termination = reason or termination
                break

        # A process still running here was cut short by a decision or the timeout.
        # The finally block kills it, so its exit code is not a CLI failure.
        stopped_early = process.poll() is None

    except Exception as exc:  # pragma: no cover - subprocess behavior depends on local env
        error_message = str(exc)
        termination = "exception"
    finally:
        if process is not None:
            if process.poll() is None:
                process.kill()
                try:
                    process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    # Give up waiting; the reader threads are daemons and won't block exit.
                    pass
            for reader, stream in readers:
                reader.join(timeout=5)
                # Closing a pipe while another thread is still blocked reading it can
                # block too, so only close streams whose reader has finished.
                if stream is not None and not reader.is_alive():
                    stream.close()
        if not keep_command_files:
            try:
                command_path.unlink()
            except FileNotFoundError:
                pass
            except OSError as exc:
                # Best effort: a locked file (e.g. on Windows) must not discard this run's result.
                print(f"Warning: could not remove {command_path}: {exc}", file=sys.stderr)

    # A non-zero exit only counts as a CLI error when the process ended on its own.
    if process is not None and termination != "exception" and not stopped_early and not triggered:
        return_code = process.returncode
        if return_code not in (0, None):
            stderr_parts: list[str] = []
            while True:
                try:
                    chunk = stderr_queue.get_nowait()
                except queue.Empty:
                    break
                if chunk is None:
                    break
                stderr_parts.append(chunk)
            error_message = "".join(stderr_parts).strip() or f"claude exited with {return_code}"
            termination = "claude_error"

    duration_seconds = round(time.monotonic() - start_time, 3)
    return {
        "triggered": triggered,
        "duration_seconds": duration_seconds,
        "termination": termination,
        "error": error_message,
        "parsed_events": parsed_events,
        "raw_lines": raw_lines,
    }


def build_summary(results: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-query results into query-level and run-level statistics.

    Each result must provide ``pass``, ``should_trigger``, ``triggers`` and ``runs``.
    At run level, every individual run counts as one classification.
    When a denominator is zero, precision and recall fall back to 1.0 and
    accuracy falls back to 0.0.
    """
    tp = fn = fp = tn = 0
    for result in results:
        hits = result["triggers"]
        misses = result["runs"] - hits
        if result["should_trigger"]:
            tp += hits
            fn += misses
        else:
            fp += hits
            tn += misses

    total = len(results)
    passed = sum(1 for result in results if result["pass"])
    run_total = tp + tn + fp + fn

    predicted_positive = tp + fp
    actual_positive = tp + fn
    precision = tp / predicted_positive if predicted_positive else 1.0
    recall = tp / actual_positive if actual_positive else 1.0
    accuracy = (tp + tn) / run_total if run_total else 0.0

    run_level = {
        "true_positive": tp,
        "false_positive": fp,
        "true_negative": tn,
        "false_negative": fn,
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "accuracy": round(accuracy, 4),
        "total_runs": run_total,
    }
    return {"total": total, "passed": passed, "failed": total - passed, "run_level": run_level}


def run_eval(
    eval_set: list[dict[str, Any]],
    skill_name: str,
    description: str,
    num_workers: int,
    timeout: int,
    project_root: Path,
    runs_per_query: int = 1,
    trigger_threshold: float = 0.5,
    model: str | None = None,
    keep_command_files: bool = False,
) -> dict[str, Any]:
    """Run a full trigger-eval set and score it.

    Every item is queried ``runs_per_query`` times on a thread pool of at most
    ``num_workers`` threads (never fewer than one). Pass rules:
      * should-trigger items pass when their trigger rate is
        ``>= trigger_threshold``;
      * should-not-trigger items pass when it is ``< trigger_threshold``.
    Runs whose details carry an ``error`` are counted per item in
    ``errored_runs``, so failures of the CLI are visible rather than silently
    blended into the trigger rate.

    Returns:
        Dict with these keys:
          * ``skill_name`` and ``description``
          * ``results``: one entry per item, in input order.
          * ``summary``: from ``build_summary``.
          * ``metadata``

    Raises:
        ValueError: If ``runs_per_query`` is less than 1.
    """
    if runs_per_query < 1:
        # With zero runs every trigger rate is 0.0, so every should-not-trigger
        # item would "pass" without ever being queried.
        raise ValueError(f"runs_per_query must be at least 1, got {runs_per_query}")

    items_by_index = dict(enumerate(eval_set))
    details_by_index: dict[int, list[dict[str, Any]]] = {index: [] for index in items_by_index}

    max_workers = max(1, min(num_workers, len(eval_set) * runs_per_query))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {}
        for index, item in items_by_index.items():
            for _ in range(runs_per_query):
                future = executor.submit(
                    run_single_query,
                    query=item["query"],
                    skill_name=skill_name,
                    skill_description=description,
                    timeout=timeout,
                    project_root=project_root,
                    model=model,
                    keep_command_files=keep_command_files,
                )
                future_to_index[future] = index

        for future in as_completed(future_to_index):
            index = future_to_index[future]
            try:
                details_by_index[index].append(future.result())
            except Exception as exc:  # pragma: no cover - executor safety net
                details_by_index[index].append(
                    {
                        "triggered": False,
                        "duration_seconds": 0.0,
                        "termination": "executor_exception",
                        "error": str(exc),
                        "parsed_events": 0,
                        "raw_lines": 0,
                    }
                )

    results: list[dict[str, Any]] = []
    for index in sorted(items_by_index):
        item = items_by_index[index]
        run_details = details_by_index[index]
        runs = len(run_details)
        triggers = sum(1 for detail in run_details if detail["triggered"])
        errored_runs = sum(1 for detail in run_details if detail.get("error"))
        trigger_rate = triggers / runs if runs else 0.0
        should_trigger = item["should_trigger"]
        did_pass = trigger_rate >= trigger_threshold if should_trigger else trigger_rate < trigger_threshold
        avg_duration = (
            round(sum(detail["duration_seconds"] for detail in run_details) / runs, 3) if runs else 0.0
        )

        results.append(
            {
                "id": item["id"],
                "query": item["query"],
                "should_trigger": should_trigger,
                "trigger_rate": round(trigger_rate, 4),
                "triggers": triggers,
                "runs": runs,
                "errored_runs": errored_runs,
                "pass": did_pass,
                "avg_duration_seconds": avg_duration,
                "run_details": run_details,
                "notes": item.get("notes", ""),
                "tags": item.get("tags", []),
            }
        )

    summary = build_summary(results)
    return {
        "skill_name": skill_name,
        "description": description,
        "results": results,
        "summary": summary,
        "metadata": {
            "project_root": str(project_root),
            "runs_per_query": runs_per_query,
            "trigger_threshold": trigger_threshold,
            "model": model,
            "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        },
    }


def main() -> int:
    """Command-line entry point: run trigger evals for one skill and print a JSON report.

    Returns:
        ``0`` on success, or ``1`` when SKILL.md or the eval set cannot be loaded.
        Invalid option values exit with status 2 through ``argparse``.
    """
    parser = argparse.ArgumentParser(description="Run trigger evaluation for a skill description")
    parser.add_argument("--eval-set", required=True, help="Path to trigger eval JSON")
    parser.add_argument("--skill-path", required=True, help="Path to the skill directory")
    parser.add_argument("--description", default=None, help="Override description to test")
    parser.add_argument("--num-workers", type=int, default=6, help="Parallel worker count")
    parser.add_argument("--timeout", type=int, default=45, help="Per-query timeout in seconds")
    parser.add_argument("--runs-per-query", type=int, default=3, help="Repeated runs per query")
    parser.add_argument(
        "--trigger-threshold",
        type=float,
        default=0.5,
        help="Minimum trigger rate for should-trigger queries to count as pass",
    )
    parser.add_argument("--project-root", default=None, help="Override detected project root")
    parser.add_argument("--model", default=None, help="Model passed to claude -p")
    parser.add_argument("--output", default=None, help="Write JSON result to this path")
    parser.add_argument(
        "--keep-command-files",
        action="store_true",
        help="Keep temporary .claude/commands files for debugging",
    )
    parser.add_argument("--verbose", action="store_true", help="Print progress to stderr")
    args = parser.parse_args()

    # Reject option values that silently produce meaningless results:
    #   * zero runs leave every trigger rate at 0.0;
    #   * a non-positive timeout makes every run time out immediately;
    #   * a trigger rate can only lie in [0, 1].
    if args.runs_per_query < 1:
        parser.error("--runs-per-query must be at least 1")
    if args.timeout <= 0:
        parser.error("--timeout must be a positive number of seconds")
    if not 0.0 <= args.trigger_threshold <= 1.0:
        parser.error("--trigger-threshold must be between 0 and 1")

    skill_path = Path(args.skill_path).resolve()
    if not (skill_path / "SKILL.md").exists():
        print(f"Error: No SKILL.md found at {skill_path}", file=sys.stderr)
        return 1

    eval_path = Path(args.eval_set).resolve()
    try:
        eval_set = load_trigger_eval_set(eval_path)
    except Exception as exc:
        print(f"Error loading trigger eval set: {exc}", file=sys.stderr)
        return 1

    # Top-level boundary: report a malformed SKILL.md like a bad eval set, not as a traceback.
    try:
        name, original_description, _ = parse_skill_md(skill_path)
    except Exception as exc:
        print(f"Error parsing SKILL.md: {exc}", file=sys.stderr)
        return 1
    description = args.description or original_description
    project_root = Path(args.project_root).resolve() if args.project_root else find_project_root()

    if args.verbose:
        print(f"Skill: {name}", file=sys.stderr)
        print(f"Project root: {project_root}", file=sys.stderr)
        print(f"Queries: {len(eval_set)}", file=sys.stderr)
        print(f"Description chars: {len(description)}", file=sys.stderr)

    output = run_eval(
        eval_set=eval_set,
        skill_name=name,
        description=description,
        num_workers=args.num_workers,
        timeout=args.timeout,
        project_root=project_root,
        runs_per_query=args.runs_per_query,
        trigger_threshold=args.trigger_threshold,
        model=args.model,
        keep_command_files=args.keep_command_files,
    )

    if args.verbose:
        summary = output["summary"]
        print(
            f"Query pass rate: {summary['passed']}/{summary['total']} | "
            f"run accuracy: {summary['run_level']['accuracy']:.0%} | "
            f"precision: {summary['run_level']['precision']:.0%} | "
            f"recall: {summary['run_level']['recall']:.0%}",
            file=sys.stderr,
        )
        for result in output["results"]:
            status = "PASS" if result["pass"] else "FAIL"
            print(
                f"[{status}] rate={result['triggers']}/{result['runs']} "
                f"expected={result['should_trigger']} :: {result['query'][:100]}",
                file=sys.stderr,
            )

    rendered = json.dumps(output, indent=2, ensure_ascii=False)
    try:
        print(rendered)
    except UnicodeEncodeError:
        # On Windows a redirected stdout uses the ANSI code page, which cannot encode
        # every character. \u escapes keep the JSON valid and lossless.
        print(json.dumps(output, indent=2, ensure_ascii=True))
    if args.output:
        write_json(Path(args.output), output)
        if args.verbose:
            print(f"Wrote: {Path(args.output).resolve()}", file=sys.stderr)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
