#!/usr/bin/env python3
"""Description optimization loop with train/test split.

Loads trigger eval queries, splits them into train and test sets,
evaluates the current description against each set, and reports accuracy
metrics together with the misclassified training queries.

Usage:
    python3 run_loop.py --skill-dir .claude/skills/my-skill --queries queries.json
    python3 run_loop.py --skill-dir .claude/skills/my-skill --queries queries.json --use-claude
    python3 run_loop.py --queries queries.json --description "Skill desc"

Input:
    Either a --skill-dir (reads SKILL.md frontmatter) or explicit --description.
    A --queries JSON file with pre-defined trigger queries is required.

Output:
    Prints accuracy metrics. Saves results to description_eval.json in the
    skill dir when --skill-dir is given; otherwise prints them as JSON.
"""

import argparse
import json
import random
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))
from utils import load_json, save_json, now_iso
from run_eval import evaluate_trigger, evaluate_trigger_with_claude


def parse_frontmatter(skill_path):
    """Extract name and description from SKILL.md frontmatter.

    Args:
        skill_path: Directory that contains the ``SKILL.md`` file.

    Returns:
        A ``(name, description)`` tuple of strings. A field that does not
        appear in the frontmatter is returned as an empty string.

    Prints an error to stderr and exits with status 1 when ``SKILL.md`` is
    missing, does not start with ``---``, or its frontmatter is never closed.
    """
    path = Path(skill_path) / "SKILL.md"
    if not path.exists():
        print(f"Error: {path} not found", file=sys.stderr)
        sys.exit(1)

    # Read as UTF-8 explicitly so the result does not depend on the locale.
    content = path.read_text(encoding="utf-8")
    if not content.startswith("---"):
        print("Error: No frontmatter found in SKILL.md", file=sys.stderr)
        sys.exit(1)

    # Prefer a closing fence at the start of a line so that a "---" inside a
    # value does not cut the frontmatter short; fall back to any later "---"
    # to keep accepting files the previous lookup accepted.
    end = content.find("\n---", 3)
    if end == -1:
        end = content.find("---", 3)
    if end == -1:
        print("Error: Unterminated frontmatter in SKILL.md", file=sys.stderr)
        sys.exit(1)
    frontmatter = content[3:end]

    name = ""
    description = ""
    for line in frontmatter.strip().splitlines():
        if line.startswith("name:"):
            name = line.split(":", 1)[1].strip().strip('"').strip("'")
        elif line.startswith("description:"):
            description = line.split(":", 1)[1].strip().strip('"').strip("'")

    return name, description


def load_queries(queries_path):
    """Read the two labelled query collections from a JSON file.

    Args:
        queries_path: Path handed to ``load_json``.

    Returns:
        A ``(should_trigger, should_not_trigger)`` tuple. Each element is the
        value stored under the key of the same name, or an empty list when
        that key is absent.
    """
    payload = load_json(queries_path)
    positives = payload.get("should_trigger", [])
    negatives = payload.get("should_not_trigger", [])
    return positives, negatives


def train_test_split(items, train_ratio=0.7, seed=None):
    """Split items into shuffled train and test lists.

    Args:
        items: Any iterable of items; it is copied and never mutated.
        train_ratio: Fraction of the items assigned to the train list.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` generator is used, so the result follows
            whatever global seeding the caller has done.

    Returns:
        A ``(train, test)`` tuple of lists. The split index is at least 1, so
        a non-empty input always yields a non-empty train list; an empty
        input yields two empty lists.
    """
    # list() rather than items[:] so tuples and other iterables are accepted.
    shuffled = list(items)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled)
    split_idx = max(1, int(len(shuffled) * train_ratio))
    return shuffled[:split_idx], shuffled[split_idx:]


def evaluate_description(description, should_trigger, should_not_trigger, use_claude=False):
    """Score a description against positive and negative trigger queries.

    Every query in ``should_trigger`` and then every query in
    ``should_not_trigger`` is passed, together with ``description``, to the
    selected evaluator (``evaluate_trigger_with_claude`` when ``use_claude``
    is true, otherwise ``evaluate_trigger``). The evaluator's
    ``"would_trigger"`` value decides which confusion-matrix counter the
    query lands in.

    Returns:
        A dict with the four confusion-matrix counts, a ``"details"`` list
        holding one record per query (query, expected, actual, correct,
        confidence), and ``"accuracy"``, ``"precision"`` and ``"recall"``.
        Each ratio is 0 when its denominator is zero.
    """
    if use_claude:
        eval_fn = evaluate_trigger_with_claude
    else:
        eval_fn = evaluate_trigger

    tally = {
        "true_positive": 0,
        "false_positive": 0,
        "true_negative": 0,
        "false_negative": 0,
        "details": [],
    }

    # Positives first, then negatives, so "details" keeps that order.
    for expected, queries in ((True, should_trigger), (False, should_not_trigger)):
        for query in queries:
            outcome = eval_fn(description, query)
            fired = outcome["would_trigger"]
            if expected:
                bucket = "true_positive" if fired else "false_negative"
                is_correct = fired
            else:
                bucket = "false_positive" if fired else "true_negative"
                is_correct = not fired
            tally[bucket] += 1
            tally["details"].append({
                "query": query,
                "expected": expected,
                "actual": fired,
                "correct": is_correct,
                "confidence": outcome.get("confidence", 0),
            })

    hits = tally["true_positive"]
    total = len(should_trigger) + len(should_not_trigger)
    tally["accuracy"] = (hits + tally["true_negative"]) / total if total > 0 else 0

    predicted_positive = hits + tally["false_positive"]
    tally["precision"] = hits / predicted_positive if predicted_positive > 0 else 0

    actual_positive = hits + tally["false_negative"]
    tally["recall"] = hits / actual_positive if actual_positive > 0 else 0

    return tally


def main():
    """Run one train/test evaluation of a skill description from the CLI.

    The description comes from ``--description`` or, failing that, from the
    ``SKILL.md`` frontmatter under ``--skill-dir``. Queries come from the
    required ``--queries`` file. Each query class is split into train and
    test portions, the description is evaluated on the train portion (and on
    the test portion unless ``--train-only`` is set or it is empty), and the
    metrics are printed.

    Results are written to ``description_eval.json`` inside ``--skill-dir``
    when that option is given; otherwise they are printed as JSON.

    Exits with status 1 after printing to stderr when no description source
    is given, the description is empty, ``--queries`` is missing, or the
    queries file contains no queries.
    """
    parser = argparse.ArgumentParser(description="Description optimization loop")
    parser.add_argument("--skill-dir", help="Path to skill directory")
    parser.add_argument("--description", help="Explicit description to test")
    parser.add_argument("--queries", help="Path to queries JSON file")
    parser.add_argument("--use-claude", action="store_true",
                        help="Use claude CLI for evaluation")
    parser.add_argument("--train-only", action="store_true",
                        help="Only evaluate against training set")
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for a reproducible train/test split")
    args = parser.parse_args()

    # Get description
    if args.description:
        description = args.description
        skill_name = "unknown"
    elif args.skill_dir:
        skill_name, description = parse_frontmatter(args.skill_dir)
    else:
        print("Error: Provide --skill-dir or --description", file=sys.stderr)
        sys.exit(1)

    if not description:
        print("Error: No description found", file=sys.stderr)
        sys.exit(1)

    # Get queries
    if args.queries:
        should_trigger, should_not_trigger = load_queries(args.queries)
    else:
        print("Error: Provide --queries with trigger eval queries", file=sys.stderr)
        sys.exit(1)

    # Without any queries every metric would be reported as 0%, which looks
    # like a terrible description rather than a missing input.
    if not should_trigger and not should_not_trigger:
        print(f"Error: No queries found in {args.queries}", file=sys.stderr)
        sys.exit(1)

    # Seed the shared generator so repeated runs evaluate the same split.
    if args.seed is not None:
        random.seed(args.seed)

    # Split into train/test
    train_yes, test_yes = train_test_split(should_trigger)
    train_no, test_no = train_test_split(should_not_trigger)

    # Only show an ellipsis when the preview really was truncated.
    preview = description if len(description) <= 80 else f"{description[:80]}..."

    print(f"Skill: {skill_name}")
    print(f"Description: {preview}")
    print(f"Queries: {len(should_trigger)} should-trigger, {len(should_not_trigger)} should-not-trigger")
    print(f"Train/Test split: {len(train_yes)+len(train_no)} / {len(test_yes)+len(test_no)}")
    print()

    # Evaluate on training set
    print("=== Training Set ===")
    train_results = evaluate_description(description, train_yes, train_no, args.use_claude)
    print(f"Accuracy:  {train_results['accuracy']:.1%}")
    print(f"Precision: {train_results['precision']:.1%}")
    print(f"Recall:    {train_results['recall']:.1%}")

    # Show misclassifications
    misclassified = [d for d in train_results["details"] if not d["correct"]]
    if misclassified:
        print(f"\nMisclassified ({len(misclassified)}):")
        for m in misclassified:
            label = "FN" if m["expected"] else "FP"
            print(f"  [{label}] {m['query']}")

    # Evaluate on test set (unless train-only or there is nothing held out)
    test_results = None
    if not args.train_only and (test_yes or test_no):
        print("\n=== Test Set ===")
        test_results = evaluate_description(description, test_yes, test_no, args.use_claude)
        print(f"Accuracy:  {test_results['accuracy']:.1%}")
        print(f"Precision: {test_results['precision']:.1%}")
        print(f"Recall:    {test_results['recall']:.1%}")

    # Save results
    output = {
        "timestamp": now_iso(),
        "skill_name": skill_name,
        "description": description,
        "train_results": {
            "accuracy": train_results["accuracy"],
            "precision": train_results["precision"],
            "recall": train_results["recall"],
            "misclassified": misclassified,
        },
    }
    if test_results is not None:
        output["test_results"] = {
            "accuracy": test_results["accuracy"],
            "precision": test_results["precision"],
            "recall": test_results["recall"],
        }

    if args.skill_dir:
        save_json(Path(args.skill_dir) / "description_eval.json", output)
    else:
        print(json.dumps(output, indent=2))


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
