#!/usr/bin/env python3
"""
hprof_histogram.py — 流式 HPROF 类直方图（实例数 / 浅大小）。

纯标准库、单遍流式解析，内存占用与对象总数无关（只缓存字符串表/类元数据，外加当前正在解析的单个 heap segment），
可处理数 GB、数千万对象的 dump。用于内存泄漏分析第一步：
快速看出"哪些类实例最多 / 占内存最多"，尤其是异常增长的业务/第三方类。

用法:
    python3 hprof_histogram.py <dump.hprof> [--top N] [--biz-only] [--no-jdk]

参数:
    --top N      每个榜单显示前 N 行（默认 50）
    --biz-only   只显示业务/第三方类（排除 java./javax./jdk./sun./com.sun./jakarta.）

输出: 四个榜单 —— 全部按浅大小 / 全部按实例数 / 业务类按浅大小 / 业务类按实例数（--biz-only 时只输出后两个）。

注意:
    - 浅大小是估算（对象头按 16 字节近似），用于排名而非精确字节会计。
    - 真正定位"谁持有这些对象"需要配合 trace_referrers.py（反向引用追溯）。
"""
import sys
import struct
import argparse
from collections import defaultdict


# Upper bound on the header's NUL-terminated format string. Real values are
# "JAVA PROFILE 1.0.x"; the cap only prevents scanning a large non-HPROF file
# byte by byte while looking for a terminator.
_MAX_FORMAT_LEN = 64

_U2 = struct.Struct(">H")
_U4 = struct.Struct(">I")

# HPROF primitive type code -> Java type name (labels primitive arrays).
_PRIM_NAMES = {4: "boolean", 5: "char", 6: "float", 7: "double",
               8: "byte", 9: "short", 10: "int", 11: "long"}

# JVM descriptor letters of primitive element types, as used in array class
# names such as "[I" or "[[J".
_DESC_PRIM_LETTERS = frozenset("ZCFDBSIJ")

# Package prefixes treated as JDK; "<" also excludes unresolved
# "<class@..>" / "<str@..>" placeholder names.
_JDK_PREFIXES = ("java.", "javax.", "jdk.", "sun.", "com.sun.", "jakarta.", "<")


def _type_sizes(id_size):
    """Return {HPROF basic type code: value size in bytes}; code 2 is an object id."""
    return {2: id_size, 4: 1, 5: 2, 6: 4, 7: 8, 8: 1, 9: 2, 10: 4, 11: 8}


class _HeapStats:
    """Counters filled while streaming heap dump segments."""

    __slots__ = ("inst_count", "inst_bytes", "objarr_count", "objarr_bytes",
                 "primarr_count", "primarr_bytes", "total_objs")

    def __init__(self):
        # instance class id -> number of instances / estimated shallow bytes
        self.inst_count = defaultdict(int)
        self.inst_bytes = defaultdict(int)
        # array class id -> number of object arrays / estimated shallow bytes
        self.objarr_count = defaultdict(int)
        self.objarr_bytes = defaultdict(int)
        # primitive type code -> number of primitive arrays / estimated bytes
        self.primarr_count = defaultdict(int)
        self.primarr_bytes = defaultdict(int)
        self.total_objs = 0


def _read_header(f, path):
    """Read the HPROF header from *f* and return (format_string, id_size).

    Layout: NUL-terminated ASCII format string, u4 identifier size, u8
    timestamp (skipped). Raises SystemExit for a non-HPROF or too-short file,
    or for an identifier size other than 4 or 8.
    """
    buf = bytearray()
    while len(buf) < _MAX_FORMAT_LEN:
        c = f.read(1)
        if c in (b"\x00", b""):
            break
        buf += c
    fmt = bytes(buf).decode("ascii", "replace")
    if not fmt.startswith("JAVA PROFILE"):
        raise SystemExit(f"[err] {path} 不是 HPROF 堆转储（缺少 'JAVA PROFILE' 魔数）")
    head = f.read(4)
    if len(head) < 4:
        raise SystemExit(f"[err] {path} 文件过短或损坏，无法读取 id_size")
    id_size = _U4.unpack(head)[0]
    if id_size not in (4, 8):
        # Any other width would be silently misparsed by the id unpacker.
        raise SystemExit(f"[err] {path} 不支持的 id_size={id_size}（仅支持 4 或 8）")
    f.read(8)  # u8 timestamp, not needed
    return fmt, id_size


def _skip_class_dump(mv, off, id_size, ts):
    """Return the offset just past a CLASS DUMP sub-record whose body starts at *off*.

    Class objects are not counted in the histogram; only the length matters.
    """
    unpack_u2 = _U2.unpack_from
    # class id, u4 stack serial, 6 ids (super, loader, signers, protection
    # domain, 2 reserved), u4 instance size
    off += id_size + 4 + 6 * id_size + 4
    n_const = unpack_u2(mv, off)[0]
    off += 2
    for _ in range(n_const):
        # u2 pool index, u1 type, value
        off += 3 + ts[mv[off + 2]]
    n_static = unpack_u2(mv, off)[0]
    off += 2
    for _ in range(n_static):
        # name string id, u1 type, value
        off += id_size + 1 + ts[mv[off + id_size]]
    n_fields = unpack_u2(mv, off)[0]
    # each instance field descriptor: name string id + u1 type
    return off + 2 + n_fields * (id_size + 1)


def _parse_heap(body, id_size, stats):
    """Walk the sub-records of one HEAP DUMP / HEAP DUMP SEGMENT *body*.

    Instances and arrays are counted into *stats* with an estimated shallow
    size of a 16-byte header plus payload. Raises struct.error / IndexError
    if *body* ends inside a sub-record; counts made before that point are
    kept. Raises SystemExit on an unknown sub-record tag, because the stream
    cannot be resynchronised past it.
    """
    mv = memoryview(body)
    n = len(mv)
    unpack_id = struct.Struct(">Q" if id_size == 8 else ">I").unpack_from
    unpack_u4 = _U4.unpack_from
    ts = _type_sizes(id_size)
    # GC-root sub-records have fixed-size bodies: tag -> bytes to skip.
    root_sizes = {
        0xFF: id_size,        # ROOT UNKNOWN: object id
        0x01: 2 * id_size,    # ROOT JNI GLOBAL: object id, JNI global ref id
        0x02: id_size + 8,    # ROOT JNI LOCAL: object id, u4, u4
        0x03: id_size + 8,    # ROOT JAVA FRAME: object id, u4, u4
        0x04: id_size + 4,    # ROOT NATIVE STACK: object id, u4
        0x05: id_size,        # ROOT STICKY CLASS: object id
        0x06: id_size + 4,    # ROOT THREAD BLOCK: object id, u4
        0x07: id_size,        # ROOT MONITOR USED: object id
        0x08: id_size + 8,    # ROOT THREAD OBJECT: object id, u4, u4
    }
    inst_count, inst_bytes = stats.inst_count, stats.inst_bytes
    objarr_count, objarr_bytes = stats.objarr_count, stats.objarr_bytes
    primarr_count, primarr_bytes = stats.primarr_count, stats.primarr_bytes
    objs = 0
    off = 0
    try:
        while off < n:
            sub = mv[off]
            off += 1
            if sub == 0x21:      # INSTANCE DUMP: id, u4 stack, class id, u4 nbytes, data
                cls_id = unpack_id(mv, off + id_size + 4)[0]
                off += 2 * id_size + 4
                nb = unpack_u4(mv, off)[0]
                off += 4 + nb
                inst_count[cls_id] += 1
                inst_bytes[cls_id] += 16 + nb
                objs += 1
            elif sub == 0x20:    # CLASS DUMP
                off = _skip_class_dump(mv, off, id_size, ts)
            elif sub == 0x22:    # OBJECT ARRAY DUMP: id, u4 stack, u4 num, class id, ids
                off += id_size + 4
                num = unpack_u4(mv, off)[0]
                arr_cls = unpack_id(mv, off + 4)[0]
                off += 4 + id_size + num * id_size
                objarr_count[arr_cls] += 1
                objarr_bytes[arr_cls] += 16 + num * id_size
                objs += 1
            elif sub == 0x23:    # PRIMITIVE ARRAY DUMP: id, u4 stack, u4 num, u1 type, data
                off += id_size + 4
                num = unpack_u4(mv, off)[0]
                et = mv[off + 4]
                off += 5 + num * ts[et]
                primarr_count[et] += 1
                primarr_bytes[et] += 16 + num * ts[et]
                objs += 1
            else:
                skip = root_sizes.get(sub)
                if skip is None:
                    raise SystemExit(f"[err] unknown heap subrecord 0x{sub:02x} at off {off-1}")
                off += skip
    finally:
        # Keep the running total consistent even when a truncated body aborts the walk.
        stats.total_objs += objs


def _scan_records(f, id_size, stats):
    """Stream the top-level records that follow the header.

    Returns (strings, loadclass): string id -> text, and class object id ->
    name string id. HEAP DUMP / HEAP DUMP SEGMENT bodies are read one at a
    time and passed to _parse_heap, which updates *stats*. A truncated file
    stops the scan with a warning on stderr and keeps what was parsed,
    instead of aborting with a traceback.
    """
    unpack_id = struct.Struct(">Q" if id_size == 8 else ">I").unpack_from
    unpack_u4 = _U4.unpack_from
    strings = {}
    loadclass = {}
    read = f.read
    seg = 0
    while True:
        hdr = read(9)  # u1 tag, u4 timestamp delta, u4 body length
        if len(hdr) < 9:
            break
        tag = hdr[0]
        length = unpack_u4(hdr, 5)[0]
        if tag not in (0x01, 0x02, 0x0C, 0x1C):
            f.seek(length, 1)  # record kind not needed for the histogram
            continue
        body = read(length)
        truncated = len(body) < length
        if tag in (0x0C, 0x1C):  # HEAP DUMP / HEAP DUMP SEGMENT
            seg += 1
            try:
                _parse_heap(body, id_size, stats)
            except (struct.error, IndexError):
                if not truncated:
                    raise  # malformed data inside a complete record: do not hide it
            if truncated:
                sys.stderr.write(f"[warn] heap segment #{seg} 被截断"
                                 f"（{len(body):,}/{length:,} 字节），结果不完整\n")
                sys.stderr.flush()
                break
            sys.stderr.write(f"[info] heap segment #{seg} parsed, total_objs={stats.total_objs:,}\n")
            sys.stderr.flush()
        elif truncated:
            sys.stderr.write(f"[warn] 文件在 tag 0x{tag:02x} 记录处被截断，结果不完整\n")
            sys.stderr.flush()
            break
        elif tag == 0x01:  # STRING: string id, UTF-8 bytes
            strings[unpack_id(body, 0)[0]] = body[id_size:].decode("utf-8", "replace")
        else:              # LOAD CLASS: u4 serial, class object id, u4 stack serial, name id
            loadclass[unpack_id(body, 4)[0]] = unpack_id(body, 8 + id_size)[0]
    return strings, loadclass


def _is_biz(name):
    """Return True if dotted class *name* is a business / third-party class.

    Array names are reduced to their element type first: "int[]" (primitive
    array label) as well as JVM descriptors like "[Lcom.foo.Bar;" and "[[I".
    Primitive element types and JDK packages count as non-business.
    """
    base = name
    while base.endswith("[]"):
        base = base[:-2]
    elem = base.lstrip("[")
    if elem != base:  # JVM array descriptor
        if elem in _DESC_PRIM_LETTERS:
            return False  # "[I", "[[J", ...: (multi-dim) primitive arrays are JDK
        if elem.startswith("L") and elem.endswith(";"):
            elem = elem[1:-1]
        base = elem
    if base in _PRIM_NAMES.values():
        return False  # primitive arrays count as JDK, not business classes
    return not base.startswith(_JDK_PREFIXES)


def _format_report(fmt, id_size, stats, strings, loadclass, top, biz_only):
    """Build the report lines: a summary header followed by the ranking sections."""

    def cname(cls_id):
        nid = loadclass.get(cls_id)
        if nid is None:
            return f"<class@{cls_id}>"
        return strings.get(nid, f"<str@{nid}>").replace("/", ".")

    rows = [(cname(cid), c, stats.inst_bytes[cid]) for cid, c in stats.inst_count.items()]
    rows += [(cname(cid), c, stats.objarr_bytes[cid]) for cid, c in stats.objarr_count.items()]
    rows += [(f"{_PRIM_NAMES.get(et, f'<type{et}>')}[]", c, stats.primarr_bytes[et])
             for et, c in stats.primarr_count.items()]

    total_bytes = sum(r[2] for r in rows)
    out = [
        f"format={fmt!r} id_size={id_size}",
        f"total_objects={stats.total_objs:,}  total_shallow_bytes={total_bytes:,} "
        f"({total_bytes/1024/1024/1024:.2f} GiB)",
    ]

    def section(title, key, biz):
        out.append("\n" + "=" * 100)
        out.append(title)
        out.append("=" * 100)
        out.append(f"{'shallow_MiB':>12} {'count':>14}   class")
        sel = [r for r in rows if not biz or _is_biz(r[0])]
        for name, c, b in sorted(sel, key=key, reverse=True)[:top]:
            out.append(f"{b/1024/1024:12.1f} {c:14,}   {name}")

    if not biz_only:
        section("TOP by SHALLOW SIZE (全部)", lambda r: r[2], False)
        section("TOP by INSTANCE COUNT (全部)", lambda r: r[1], False)
    section("TOP by SHALLOW SIZE (业务/第三方类，排除 JDK)", lambda r: r[2], True)
    section("TOP by INSTANCE COUNT (业务/第三方类，排除 JDK)", lambda r: r[1], True)
    return out


def main():
    """CLI entry point: parse arguments, stream the dump, print the histogram.

    Progress and warnings go to stderr; the report goes to stdout.
    Exits via SystemExit with an "[err] ..." message on unreadable or
    non-HPROF input (argparse errors also exit via SystemExit).
    """
    ap = argparse.ArgumentParser(description="HPROF 类直方图")
    ap.add_argument("hprof", help="heap dump (.hprof) 路径")
    ap.add_argument("--top", type=int, default=50, help="每个榜单显示前 N 行")
    # "--no-jdk" is the spelling in the module usage line; accept both.
    ap.add_argument("--biz-only", "--no-jdk", dest="biz_only", action="store_true",
                    help="只显示业务/第三方类")
    args = ap.parse_args()
    if args.top < 0:
        # A negative slice bound would silently drop rows from the tail.
        ap.error("--top 必须是非负整数")

    try:
        f = open(args.hprof, "rb")
    except OSError as e:
        raise SystemExit(f"[err] 无法打开 {args.hprof}: {e}") from None
    stats = _HeapStats()
    with f:  # closed even when parsing aborts with SystemExit
        fmt, id_size = _read_header(f, args.hprof)
        sys.stderr.write(f"[info] format={fmt!r} id_size={id_size}\n")
        sys.stderr.flush()
        strings, loadclass = _scan_records(f, id_size, stats)

    out = _format_report(fmt, id_size, stats, strings, loadclass, args.top, args.biz_only)
    print("\n".join(out))
    print("\n[done]", flush=True)


if __name__ == "__main__":
    # Script entry point: main() reads sys.argv and prints the report to stdout.
    main()
