#!/usr/bin/env python3
"""
trace_referrers.py — 反向引用追溯（自实现的 path-to-GC-root）。

从某个"嫌疑类"的所有实例出发，逐跳向上寻找 referrer（谁引用了它们）：
实例字段引用 / 数组元素 / **静态字段**。一直追到 GC root 或聚集容器，
从而在字段级别定位"哪个集合/字段持有了泄漏对象、且没有释放"。

这是直方图回答不了的关键问题：直方图告诉你"什么对象多"，
本脚本告诉你"谁持有它们、通过哪个字段"。等价于 Eclipse MAT 的
"Path to GC Roots / merge shortest paths"，但纯标准库、不依赖 MAT。

用法:
    python3 trace_referrers.py <dump.hprof> <class-name> [--hops N]

参数:
    class-name   嫌疑类全名，点或斜杠分隔均可
                 例: com.corundumstudio.socketio.handler.ClientHead
    --hops N     向上追溯的最大跳数（默认 6）

每一跳输出:
    - [referrer 类 -> referrer 对象数]      谁引用了当前层对象
    - [referrer 字段 -> 指向目标的边数]      通过哪个字段名/数组下标
    - [★静态字段持有(GC-root 级)]           被某类的 static 字段直接持有（强信号）
    - [★referrer 自身就是 GC root]          referrer 本身是 GC root（线程/JNI/sticky class 等）

阅读技巧:
    - 注意对象图常有环（如 ClientHead.clientsBox -> ClientsBox -> map -> ClientHead），
      反向 BFS 的 next 集合到后期会膨胀，这是正常现象。
    - 关注每一跳"收敛"出的单一容器（如某个 ConcurrentHashMap 的 table 数组、
      某个单例 Holder 对象），以及 ★ 标记的静态字段/GC root 锚点 —— 那就是泄漏的持有链。
"""
import struct
import argparse
from collections import defaultdict


def main():
    """CLI entry point: trace who references the instances of one class.

    Arguments come from ``sys.argv``: the ``.hprof`` path, the target class
    name (dot- or slash-separated) and ``--hops`` (maximum hops, default 6).

    Pass 0 (``build``) reads the dump once and collects UTF-8 strings,
    LOAD_CLASS name ids, per-class superclass / instance-field layouts,
    GC-root object ids, and the ids of all instances of the target class.
    Every class object whose name matches is used, so a class loaded by
    several class loaders is covered completely.

    Each hop (``scan_referrers``) then re-reads the whole dump. It finds every
    instance field, object-array element and static field that points into
    the current object set, and prints aggregated statistics. The referring
    instances and object arrays become the next hop's set. The loop stops
    early when a hop finds no such referrers.

    Raises:
        SystemExit: if the file is not an HPROF dump or its header is
            truncated, if no loaded class has the requested name, or if a
            heap-dump sub-record carries an unrecognised tag.
    """
    ap = argparse.ArgumentParser(description="HPROF 反向引用追溯")
    ap.add_argument("hprof")
    ap.add_argument("class_name", help="嫌疑类全名（点或斜杠分隔）")
    ap.add_argument("--hops", type=int, default=6, help="最大追溯跳数")
    args = ap.parse_args()

    path = args.hprof
    # Class names inside the dump use the slash-separated internal form.
    target_name = args.class_name.replace(".", "/")
    MAX_HOPS = args.hops

    def read_header(f):
        """Consume the HPROF file header from ``f`` and return the id size.

        Layout: a NUL-terminated format string starting with "JAVA PROFILE",
        a big-endian u4 identifier size, then 8 bytes of timestamp (skipped).
        """
        buf = bytearray()
        while True:
            c = f.read(1)
            if c in (b"\x00", b""):
                break
            buf += c
        if not bytes(buf).startswith(b"JAVA PROFILE"):
            raise SystemExit(f"[err] {path} 不是 HPROF 堆转储（缺少 'JAVA PROFILE' 魔数）")
        head = f.read(4)
        if len(head) < 4:
            raise SystemExit(f"[err] {path} 文件过短或损坏，无法读取 id_size")
        ids = struct.unpack(">I", head)[0]
        f.read(8)
        return ids

    with open(path, "rb") as f0:
        id_size = read_header(f0)
    ID = ">Q" if id_size == 8 else ">I"
    # HPROF basic type code -> value size in bytes (2 = object reference).
    TS = {2: id_size, 4: 1, 5: 2, 6: 4, 7: 8, 8: 1, 9: 2, 10: 4, 11: 8}

    def rid(mv, off):
        """Read one identifier at ``off``; return ``(value, next_offset)``."""
        return struct.unpack_from(ID, mv, off)[0], off + id_size

    # Tables filled by build() and read by the per-hop scans.
    strings = {}        # string id -> decoded text
    loadclass = {}      # class object id -> class-name string id
    class_super = {}    # class object id -> superclass object id
    class_ifields = {}  # class object id -> [(type code, field-name id), ...]
    roots = {}          # GC-root object id -> root sub-record tag

    def parse_class_dump(mv, off, collect):
        """Parse a CLASS DUMP sub-record whose body starts at ``off``.

        ``off`` points just past the 0x20 tag byte. Returns
        ``(next_offset, class_id, static_refs)``, where ``static_refs`` lists
        ``(field-name id, referenced object id)`` for every object-typed
        static field. When ``collect`` is true, the superclass id and the
        instance-field list are also stored in ``class_super`` /
        ``class_ifields``.
        """
        cls_id, off = rid(mv, off)
        off += 4                      # stack trace serial
        super_id, off = rid(mv, off)
        off += id_size * 5            # loader, signers, protection domain, 2 reserved
        off += 4                      # instance size
        cp = struct.unpack_from(">H", mv, off)[0]; off += 2
        for _ in range(cp):
            off += 2
            t = mv[off]; off += 1
            off += TS[t]
        sf = struct.unpack_from(">H", mv, off)[0]; off += 2
        srefs = []
        for _ in range(sf):
            nid, off = rid(mv, off)
            t = mv[off]; off += 1
            if t == 2:
                val, off = rid(mv, off)
                srefs.append((nid, val))
            else:
                off += TS[t]
        iff = struct.unpack_from(">H", mv, off)[0]; off += 2
        ifs = []
        for _ in range(iff):
            nid, off = rid(mv, off)
            t = mv[off]; off += 1
            ifs.append((t, nid))
        if collect:
            class_super[cls_id] = super_id
            class_ifields[cls_id] = ifs
        return off, cls_id, srefs

    # Bytes following the object id in each fixed-size GC-root sub-record.
    # 0x01 (JNI global) is handled separately: its extra field is an id.
    ROOT_FIX = {0xFF: 0, 0x05: 0, 0x07: 0, 0x02: 8, 0x03: 8, 0x08: 8, 0x04: 4, 0x06: 4}

    def skip_root(mv, off, sub):
        """Return the offset just past a GC-root sub-record.

        ``off`` points at the root's object id; the tag byte ``sub`` has
        already been consumed. An unrecognised tag aborts with a clear
        message. Without a known length, the rest of the segment cannot be
        parsed, and a bare KeyError would hide the cause.
        """
        if sub == 0x01:
            return off + 2 * id_size
        extra = ROOT_FIX.get(sub)
        if extra is None:
            raise SystemExit(
                f"[err] 未知的 heap dump 子记录类型 0x{sub:02X}"
                f"（可能是 Android HPROF，需先用 hprof-conv 转换；或文件已损坏）"
            )
        return off + id_size + extra

    # ---------- Pass 0: class metadata + target instance ids + GC roots ----------
    def build():
        """First pass: fill the metadata tables and collect target instances.

        Populates ``strings``, ``loadclass``, ``class_super``,
        ``class_ifields`` and ``roots`` as a side effect.

        Returns:
            ``(target_ids, S0)``: ``target_ids`` is the set of class-object ids
            whose name equals ``target_name``. It can hold several ids when
            the class was loaded by more than one class loader. ``S0`` is the
            set of ids of those classes' instances.
        """
        target_ids = set()
        resolved_key = None  # (len(loadclass), len(strings)) at last resolution
        S0 = set()
        with open(path, "rb") as f:
            read_header(f)
            read = f.read
            while True:
                hdr = read(9)  # u1 tag, u4 time offset, u4 body length
                if len(hdr) < 9:
                    break
                tag = hdr[0]
                length = struct.unpack_from(">I", hdr, 5)[0]
                if tag == 0x01:  # STRING IN UTF8
                    body = read(length)
                    sid = struct.unpack_from(ID, body, 0)[0]
                    strings[sid] = bytes(body[id_size:]).decode("utf-8", "replace")
                elif tag == 0x02:  # LOAD CLASS
                    body = read(length)
                    off = 4
                    cls_id, off = rid(body, off)
                    off += 4
                    nid, off = rid(body, off)
                    loadclass[cls_id] = nid
                elif tag in (0x0C, 0x1C):  # HEAP DUMP / HEAP DUMP SEGMENT
                    body = read(length); mv = memoryview(body); n = len(mv); off = 0
                    # Re-resolve whenever new class/string records have appeared,
                    # so every class object carrying the target name is matched.
                    key = (len(loadclass), len(strings))
                    if key != resolved_key:
                        resolved_key = key
                        target_ids = {
                            cid for cid, nid in loadclass.items()
                            if strings.get(nid) == target_name
                        }
                    while off < n:
                        sub = mv[off]; off += 1
                        if sub == 0x21:  # INSTANCE DUMP
                            oid, off = rid(mv, off); off += 4
                            cls_id, off = rid(mv, off)
                            nb = struct.unpack_from(">I", mv, off)[0]; off += 4
                            off += nb
                            if cls_id in target_ids:
                                S0.add(oid)
                        elif sub == 0x20:  # CLASS DUMP
                            off, _, _ = parse_class_dump(mv, off, True)
                        elif sub == 0x22:  # OBJECT ARRAY DUMP
                            _, off = rid(mv, off); off += 4
                            num = struct.unpack_from(">I", mv, off)[0]; off += 4
                            _, off = rid(mv, off)
                            off += num * id_size
                        elif sub == 0x23:  # PRIMITIVE ARRAY DUMP
                            _, off = rid(mv, off); off += 4
                            num = struct.unpack_from(">I", mv, off)[0]; off += 4
                            et = mv[off]; off += 1
                            off += num * TS[et]
                        else:  # GC root sub-records
                            oid = struct.unpack_from(ID, mv, off)[0]
                            off = skip_root(mv, off, sub)
                            roots[oid] = sub
                else:
                    f.seek(length, 1)
        return target_ids, S0

    refoffs_cache = {}

    def get_refoffs(cls_id):
        """Return ``[(byte offset, field-name id)]`` for reference fields.

        Offsets are relative to the start of an instance's field data. The
        class's own fields come first, then each superclass's in turn,
        matching the order in which the walk accumulates offsets. Results
        are cached per class.
        """
        r = refoffs_cache.get(cls_id)
        if r is not None:
            return r
        res = []
        off = 0
        cid = cls_id
        while cid in class_ifields:
            for (t, nid) in class_ifields[cid]:
                if t == 2:
                    res.append((off, nid))
                    off += id_size
                else:
                    off += TS[t]
            cid = class_super.get(cid, 0)
        refoffs_cache[cls_id] = res
        return res

    def cname(cid):
        """Dotted class name for a class-object id, or a ``<cls@id>`` placeholder."""
        nid = loadclass.get(cid)
        return strings.get(nid, f"<cls@{cid}>").replace("/", ".") if nid else f"<cls@{cid}>"

    def fname(nid):
        """Field name for a string id, or a ``<f@id>`` placeholder."""
        return strings.get(nid, f"<f@{nid}>")

    def scan_referrers(S):
        """Re-read the dump and find everything that references a member of ``S``.

        Returns:
            ``(next_set, field_hits, cls_hits, static_hits, root_ref)``:
            ``next_set`` holds ids of referring instances and object arrays.
            ``field_hits`` counts edges per ``(class, field)``, with ``"[]"``
            standing for array elements. ``cls_hits`` counts referrer objects
            per class. ``static_hits`` counts edges from static fields per
            ``(class, field)``. ``root_ref`` counts referrers that are
            themselves GC roots, per class.
        """
        next_set = set()
        field_hits = defaultdict(int)
        cls_hits = defaultdict(int)
        static_hits = defaultdict(int)
        root_ref = defaultdict(int)
        Sdisj = S.isdisjoint
        with open(path, "rb") as f:
            read_header(f)
            read = f.read
            while True:
                hdr = read(9)
                if len(hdr) < 9:
                    break
                tag = hdr[0]
                length = struct.unpack_from(">I", hdr, 5)[0]
                if tag in (0x0C, 0x1C):
                    body = read(length); mv = memoryview(body); n = len(mv); off = 0
                    while off < n:
                        sub = mv[off]; off += 1
                        if sub == 0x21:  # INSTANCE DUMP: test each reference field
                            oid, off = rid(mv, off); off += 4
                            cls_id, off = rid(mv, off)
                            nb = struct.unpack_from(">I", mv, off)[0]; off += 4
                            d = off; off += nb
                            hit = False
                            for (fo, nid) in get_refoffs(cls_id):
                                rv = struct.unpack_from(ID, mv, d + fo)[0]
                                if rv in S:
                                    hit = True
                                    field_hits[(cname(cls_id), fname(nid))] += 1
                            if hit:
                                next_set.add(oid)
                                cls_hits[cname(cls_id)] += 1
                                if oid in roots:
                                    root_ref[cname(cls_id)] += 1
                        elif sub == 0x20:  # CLASS DUMP: test static reference fields
                            off, cls_id, srefs = parse_class_dump(mv, off, False)
                            for (nid, val) in srefs:
                                if val in S:
                                    static_hits[(cname(cls_id), fname(nid))] += 1
                        elif sub == 0x22:  # OBJECT ARRAY DUMP: test elements
                            oid, off = rid(mv, off); off += 4
                            num = struct.unpack_from(">I", mv, off)[0]; off += 4
                            arr_cls, off = rid(mv, off)
                            if num:
                                elems = struct.unpack_from(">%d%s" % (num, ID[1]), mv, off)
                                off += num * id_size
                                if not Sdisj(elems):
                                    next_set.add(oid)
                                    cls_hits[cname(arr_cls)] += 1
                                    for e in elems:
                                        if e in S:
                                            field_hits[(cname(arr_cls), "[]")] += 1
                                    if oid in roots:
                                        root_ref[cname(arr_cls)] += 1
                        elif sub == 0x23:  # PRIMITIVE ARRAY DUMP: no references
                            _, off = rid(mv, off); off += 4
                            num = struct.unpack_from(">I", mv, off)[0]; off += 4
                            et = mv[off]; off += 1
                            off += num * TS[et]
                        else:
                            off = skip_root(mv, off, sub)
                else:
                    f.seek(length, 1)
        return next_set, field_hits, cls_hits, static_hits, root_ref

    def show(title, d, topn=15):
        """Print ``title`` followed by the ``topn`` largest entries of ``d``."""
        print(title)
        for k, v in sorted(d.items(), key=lambda x: x[1], reverse=True)[:topn]:
            print(f"    {v:>12,}   {k}")

    print(f"[build] id_size={id_size} target={target_name}", flush=True)
    target_ids, S = build()
    if not target_ids:
        raise SystemExit(f"[err] 未在 dump 中找到类 {target_name}（检查类名拼写/包路径）")
    target_ids_txt = ",".join(str(cid) for cid in sorted(target_ids))
    print(f"[build] target_cls_id={target_ids_txt}  instances={len(S):,}  "
          f"roots={len(roots):,}  classes={len(class_ifields):,}", flush=True)

    for hop in range(1, MAX_HOPS + 1):
        print("\n" + "#" * 90, flush=True)
        print(f"# HOP {hop}: 寻找引用 {len(S):,} 个对象的 referrer", flush=True)
        print("#" * 90, flush=True)
        nxt, field_hits, cls_hits, static_hits, root_ref = scan_referrers(S)
        show("[referrer 类 -> referrer 对象数]", cls_hits)
        show("[referrer 字段 -> 指向目标的边数]", field_hits)
        if static_hits:
            show("[★静态字段持有(GC-root 级) -> 边数]", static_hits)
        if root_ref:
            show("[★referrer 自身就是 GC root -> 个数]", root_ref)
        print(f"[next] referrer 对象总数 = {len(nxt):,}", flush=True)
        if not nxt:
            print("[stop] 无更多 referrer（已达根）", flush=True)
            break
        S = nxt
    print("\n[done]", flush=True)


if __name__ == "__main__":
    # Script entry: main() returns None, so the process exits with status 0
    # unless main() itself raises SystemExit with an error message.
    raise SystemExit(main())
