#!/usr/bin/env python3
"""
inspect_objects.py — 精确实测某个类的实例字段 / 静态字段值，以及它持有的 Map 的真实条目数。

用于内存泄漏分析的"钉死根因"阶段：当反向追溯/MAT 指出某个 holder 的某个集合字段
是泄漏聚集点时，用本脚本读出确凿数字 —— 例如：
  - 某 ConcurrentHashMap 字段实际装了多少条目（容量 / 非空桶）
  - 某配置对象的关键字段实际值（如心跳/超时配置是否被改坏）
  - 某工具类的 *静态* 缓存 Map 实际有多大（静态字段是 GC-root 级持有，最常见的泄漏点之一）

四种模式:
  1) 列字段（默认）: 只给 --class，打印该类的全部实例字段与静态字段（名/类型），
                     帮助你决定接下来读哪个字段。
  2) 读实例字段 --fields a,b,c:
                     打印该类每个实例上这些字段的实际值（int/long/bool/ref…）。
  3) 读静态字段 --static-fields a,b,c:
                     打印该类这些 *静态* 字段的值；若静态字段指向 HashMap/ConcurrentHashMap，
                     一并实测其条目数。适合排查静态缓存/注册表泄漏。
  4) 测实例 Map --map-fields x,y:
                     把这些实例字段当 java.util.HashMap / ConcurrentHashMap，
                     输出 size 字段值、table 容量、非空桶数（条目数可靠下界）。

用法:
    python3 inspect_objects.py <dump.hprof> --class <holder-class>
    python3 inspect_objects.py <dump.hprof> --class <holder-class> --fields f1,f2
    python3 inspect_objects.py <dump.hprof> --class <holder-class> --static-fields s1,s2
    python3 inspect_objects.py <dump.hprof> --class <holder-class> --map-fields m1,m2
    [--limit N]  最多打印前 N 个实例（默认 20）

注意:
    ConcurrentHashMap 的 baseCount/size 在高并发下可能偏小（计数分散到 counterCells），
    本脚本同时给出"非空桶数"，它是条目数更可靠的下界；当 size 字段与非空桶量级差 >10×
    时会主动追加 ⚠️ 提示。结合 table 容量（2 的幂）即可判断真实规模量级。
"""
import sys
import struct
import argparse


def main():
    ap = argparse.ArgumentParser(description="HPROF 对象字段 / 静态字段 / Map 条目实测")
    ap.add_argument("hprof")
    ap.add_argument("--class", dest="cls", required=True, help="holder 类全名（点或斜杠）")
    ap.add_argument("--fields", default="", help="逗号分隔的实例字段名，读其实际值")
    ap.add_argument("--static-fields", dest="static_fields", default="",
                    help="逗号分隔的静态字段名，读其值（指向 Map 则一并测条目数）")
    ap.add_argument("--map-fields", dest="map_fields", default="",
                    help="逗号分隔的实例字段名，按 Map 实测条目数")
    ap.add_argument("--limit", type=int, default=20, help="最多打印实例数")
    args = ap.parse_args()

    path = args.hprof
    holder_name = args.cls.replace(".", "/")
    want_fields = [s.strip() for s in args.fields.split(",") if s.strip()]
    want_static = [s.strip() for s in args.static_fields.split(",") if s.strip()]
    want_maps = [s.strip() for s in args.map_fields.split(",") if s.strip()]

    def read_header(f):
        b = bytearray()
        while True:
            c = f.read(1)
            if c in (b"\x00", b""):
                break
            b += c
        if not bytes(b).startswith(b"JAVA PROFILE"):
            raise SystemExit(f"[err] {path} 不是 HPROF 堆转储（缺少 'JAVA PROFILE' 魔数）")
        head = f.read(4)
        if len(head) < 4:
            raise SystemExit(f"[err] {path} 文件过短或损坏，无法读取 id_size")
        ids = struct.unpack(">I", head)[0]
        f.read(8)
        return ids

    with open(path, "rb") as f0:
        id_size = read_header(f0)
    ID = ">Q" if id_size == 8 else ">I"
    TS = {2: id_size, 4: 1, 5: 2, 6: 4, 7: 8, 8: 1, 9: 2, 10: 4, 11: 8}
    TNAME = {2: "ref", 4: "boolean", 5: "char", 6: "float", 7: "double",
             8: "byte", 9: "short", 10: "int", 11: "long"}

    def rid(mv, off):
        return struct.unpack_from(ID, mv, off)[0], off + id_size

    def read_val(mv, base, off, t):
        """读 base+off 处一个 type=t 的值。"""
        p = base + off
        if t == 2:
            return struct.unpack_from(ID, mv, p)[0]
        if t == 10:
            return struct.unpack_from(">i", mv, p)[0]
        if t == 11:
            return struct.unpack_from(">q", mv, p)[0]
        if t == 4:
            return bool(mv[p])
        if t == 9:
            return struct.unpack_from(">h", mv, p)[0]
        if t == 5:
            return struct.unpack_from(">H", mv, p)[0]
        if t == 8:
            return struct.unpack_from(">b", mv, p)[0]
        if t == 6:
            return struct.unpack_from(">f", mv, p)[0]
        if t == 7:
            return struct.unpack_from(">d", mv, p)[0]
        return "?"

    strings = {}
    loadclass = {}
    class_super = {}
    class_ifields = {}
    class_statics = {}        # cid -> [(name_id, type, value), ...]

    def parse_cd(mv, off, collect):
        cid, off = rid(mv, off); off += 4
        sup, off = rid(mv, off); off += id_size * 5; off += 4
        cp = struct.unpack_from(">H", mv, off)[0]; off += 2
        for _ in range(cp):
            off += 2; t = mv[off]; off += 1; off += TS[t]
        sf = struct.unpack_from(">H", mv, off)[0]; off += 2
        srefs = []
        for _ in range(sf):
            nid, off = rid(mv, off); t = mv[off]; off += 1
            v = read_val(mv, 0, off, t)
            off += id_size if t == 2 else TS[t]
            srefs.append((nid, t, v))
        iff = struct.unpack_from(">H", mv, off)[0]; off += 2
        ifs = []
        for _ in range(iff):
            nid, off = rid(mv, off); t = mv[off]; off += 1
            ifs.append((t, nid))
        if collect:
            class_super[cid] = sup
            class_ifields[cid] = ifs
            if srefs:
                class_statics[cid] = srefs
        return off, cid

    ROOT_FIX = {0xFF: 0, 0x05: 0, 0x07: 0, 0x02: 8, 0x03: 8, 0x08: 8, 0x04: 4, 0x06: 4}

    def skip_root(mv, off, sub):
        _, off = rid(mv, off)
        if sub == 0x01:
            _, off = rid(mv, off)
        else:
            off += ROOT_FIX[sub]
        return off

    def walk(collect_classes, instance_cb, array_cb, label=""):
        f = open(path, "rb"); read_header(f); read = f.read
        seg = 0
        while True:
            h = read(9)
            if len(h) < 9:
                break
            tag = h[0]; L = struct.unpack_from(">I", h, 5)[0]
            if tag == 0x01:
                b = read(L); sid = struct.unpack_from(ID, b, 0)[0]
                strings[sid] = bytes(b[id_size:]).decode("utf-8", "replace")
            elif tag == 0x02:
                b = read(L); off = 4; cid, off = rid(b, off); off += 4
                nid, off = rid(b, off); loadclass[cid] = nid
            elif tag in (0x0C, 0x1C):
                b = read(L); mv = memoryview(b); n = len(mv); off = 0
                while off < n:
                    s = mv[off]; off += 1
                    if s == 0x21:
                        oid, off = rid(mv, off); off += 4
                        ccid, off = rid(mv, off)
                        nb = struct.unpack_from(">I", mv, off)[0]; off += 4
                        d = off; off += nb
                        if instance_cb:
                            instance_cb(oid, ccid, mv, d)
                    elif s == 0x20:
                        off, _ = parse_cd(mv, off, collect_classes)
                    elif s == 0x22:
                        oid, off = rid(mv, off); off += 4
                        num = struct.unpack_from(">I", mv, off)[0]; off += 4
                        acid, off = rid(mv, off)
                        if array_cb:
                            array_cb(oid, acid, num, mv, off)
                        off += num * id_size
                    elif s == 0x23:
                        _, off = rid(mv, off); off += 4
                        num = struct.unpack_from(">I", mv, off)[0]; off += 4
                        et = mv[off]; off += 1; off += num * TS[et]
                    else:
                        off = skip_root(mv, off, s)
                seg += 1
                if label:
                    sys.stderr.write(f"[{label}] heap segment #{seg} scanned\n")
                    sys.stderr.flush()
            else:
                f.seek(L, 1)
        f.close()

    # ---------- Pass A: 类布局 + 静态字段 ----------
    sys.stderr.write("[passA] 扫描类元数据...\n"); sys.stderr.flush()
    walk(True, None, None)

    def cid_by_name(name):
        name = name.replace(".", "/")
        for cid, nid in loadclass.items():
            if strings.get(nid) == name:
                return cid
        return None

    def fname(nid):
        return strings.get(nid, f"<f@{nid}>")

    def layout(cls_id):
        res = {}
        off = 0
        cid = cls_id
        while cid in class_ifields:
            for (t, nid) in class_ifields[cid]:
                nm = strings.get(nid, "?")
                if nm not in res:
                    res[nm] = (off, t)
                off += id_size if t == 2 else TS[t]
            cid = class_super.get(cid, 0)
        return res

    HOLDER = cid_by_name(holder_name)
    if HOLDER is None:
        raise SystemExit(f"[err] 未找到类 {holder_name}")
    hl = layout(HOLDER)
    statics = {fname(nid): (t, v) for (nid, t, v) in class_statics.get(HOLDER, [])}

    # 模式 1：仅列字段（实例 + 静态）
    if not want_fields and not want_static and not want_maps:
        print(f"[class] {holder_name}")
        print("  实例字段 (name : type @offset):")
        for nm, (o, t) in sorted(hl.items(), key=lambda x: x[1][0]):
            print(f"    {nm:30s} : {TNAME.get(t, t):8s} @{o}")
        if statics:
            print("  静态字段 (name : type):")
            for nm, (t, v) in statics.items():
                vs = hex(v) if t == 2 and isinstance(v, int) else v
                print(f"    {nm:30s} : {TNAME.get(t, t):8s} = {vs}")
        print("\n用 --fields 读实例字段，--static-fields 读静态字段，--map-fields 测实例 Map。")
        return

    # Map 类布局
    CHM = cid_by_name("java/util/concurrent/ConcurrentHashMap")
    HM = cid_by_name("java/util/HashMap")
    chm_l = layout(CHM) if CHM else {}
    hm_l = layout(HM) if HM else {}
    NODE_CHM = cid_by_name("[Ljava/util/concurrent/ConcurrentHashMap$Node;")
    NODE_HM = cid_by_name("[Ljava/util/HashMap$Node;")

    holder_rows = []          # (oid, {field: (value, tname)}, {mapfield: map_id})
    map_info = {}             # map_id -> (size_or_baseCount, table_id)
    arr_info = {}             # nodearr_id -> (capacity, nonnull)
    need_passB = bool(want_fields or want_maps) or any(
        statics.get(s, (None, None))[0] == 2 for s in want_static)

    def inst_cb(oid, ccid, mv, d):
        if ccid == HOLDER and (want_fields or want_maps):
            fv = {}
            for fn in want_fields:
                if fn in hl:
                    o, t = hl[fn]
                    fv[fn] = (read_val(mv, d, o, t), TNAME.get(t, t))
                else:
                    fv[fn] = ("<无此字段>", "")
            mids = {}
            for fn in want_maps:
                mids[fn] = struct.unpack_from(ID, mv, d + hl[fn][0])[0] if fn in hl else 0
            holder_rows.append((oid, fv, mids))
        elif ccid == CHM:
            bc = struct.unpack_from(">q", mv, d + chm_l["baseCount"][0])[0] if "baseCount" in chm_l else -1
            tid = struct.unpack_from(ID, mv, d + chm_l["table"][0])[0] if "table" in chm_l else 0
            map_info[oid] = (bc, tid)
        elif ccid == HM:
            sz = struct.unpack_from(">i", mv, d + hm_l["size"][0])[0] if "size" in hm_l else -1
            tid = struct.unpack_from(ID, mv, d + hm_l["table"][0])[0] if "table" in hm_l else 0
            map_info[oid] = (sz, tid)

    def arr_cb(oid, acid, num, mv, off):
        if acid in (NODE_CHM, NODE_HM):
            if num:
                elems = struct.unpack_from(">%d%s" % (num, ID[1]), mv, off)
                nn = sum(1 for e in elems if e != 0)
            else:
                nn = 0
            arr_info[oid] = (num, nn)

    if need_passB:
        sys.stderr.write("[passB] 扫描实例 / Map / 数组...\n"); sys.stderr.flush()
        walk(False, inst_cb, arr_cb)

    def report_map(label, mid):
        if not mid:
            print(f"    {label} = <null>"); return
        info = map_info.get(mid)
        if not info:
            print(f"    {label} = {hex(mid)} (非 HashMap/ConcurrentHashMap 或未解析)"); return
        szc, tid = info
        cap, nn = arr_info.get(tid, ("?", "?"))
        warn = ""
        if isinstance(nn, int) and szc >= 0 and nn > szc * 10:
            warn = "  ⚠️ size 字段疑似失真(并发计数分散到 counterCells)，以非空桶/容量为准"
        print(f"    {label}: size字段={szc:,}  table容量={cap}  非空桶={nn}{warn}")

    # 模式 3：静态字段
    if want_static:
        print(f"[class] {holder_name}  静态字段:")
        for sn in want_static:
            if sn not in statics:
                print(f"    {sn} = <无此静态字段>"); continue
            t, v = statics[sn]
            if t == 2:
                report_map(sn, v)
            else:
                print(f"    {sn} = {v}  ({TNAME.get(t, t)})")
        print()

    # 模式 2/4：实例字段 / 实例 Map
    if want_fields or want_maps:
        print(f"[class] {holder_name}  实例数 = {len(holder_rows)}\n")
        for i, (oid, fv, mids) in enumerate(holder_rows[:args.limit]):
            print(f"实例#{i} @ {hex(oid)}")
            for fn, (v, t) in fv.items():
                vs = hex(v) if t == "ref" and isinstance(v, int) else v
                print(f"    [field] {fn} = {vs}  ({t})")
            for fn, mid in mids.items():
                report_map(f"[map]   {fn}", mid)
            print()
        if len(holder_rows) > args.limit:
            print(f"... 其余 {len(holder_rows) - args.limit} 个实例已省略（--limit 调整）")
    print("[done]")


if __name__ == "__main__":
    main()
