#!/usr/bin/env python3
"""
[HolySheep fork v2.1.29 / hs22] Hermes ACP PTY wrapper.

Why: bun 1.3.9's subprocess.spawn with stdio:['pipe','pipe','pipe']
passes sockets to the child. Python's asyncio.connect_write_pipe()
silently drops the 2nd+ message written under bun's event loop routing.
Reproduced: bun delivers id=1 response but not id=2+; node delivers both.

Fix: allocate a PTY master/slave pair. Hand the SLAVE to hermes as its
stdin+stdout, keep stderr on the wrapper's fd 2 so hs web log still
captures Python tracebacks separately. The master side pumps bytes between
our (bun-spawned) stdin/stdout and the PTY, so bun only ever talks to this
wrapper while hermes's asyncio transport talks to a kernel PTY, which it
drives correctly.

CRITICAL: disable ICANON (canonical/line-buffered mode) on the slave.
Default macOS/Linux PTY slaves have ICANON on, which means:
  1. Input is buffered line-by-line and each line is capped at MAX_CANON
     (1024B on macOS, ~4096B on Linux). Longer lines get truncated.
  2. Special chars are interpreted instead of passed through: ^D, erase
     and kill under ICANON, plus ^C/^Z under ISIG (which we disable too).
     Any payload that happens to contain those bytes gets corrupted.

ACP session/prompt JSON can exceed 4KB easily (system prompt + skill
list + assistant rules). Before hs22 we only disabled ECHO + ONLCR, so
id=1 (initialize, ~150B) and id=2/3 (newSession, setSessionMode, <400B)
went through, but id=4 (prompt, ~4-8KB) was truncated by MAX_CANON —
hermes got invalid JSON, returned an SDK error with empty data, our
AcpAgentManager 60s fallback masked it. User saw "hermes thinks forever".

Also disable ICRNL and IXON so we don't have \\r<->\\n rewriting or
flow-control pauses on large bursts.

Usage: python3 pty-hermes-wrapper.py
HERMES_BIN is auto-resolved from $PATH; override via $HOLYSHEEP_HERMES_BIN.
"""
import os, sys, pty, termios, select, signal
import shutil


def _resolve_hermes_bin() -> str:
    """Return the path of the ``hermes`` executable to launch.

    Resolution order:
      1. ``$HOLYSHEEP_HERMES_BIN`` when it names an executable regular file.
      2. The first executable regular file called ``hermes`` found by
         ``shutil.which`` (i.e. on ``$PATH``).

    Exits the process with status 127 (the shell's "command not found"
    code) when neither yields a usable binary.
    """
    override = os.environ.get('HOLYSHEEP_HERMES_BIN')
    if override:
        # os.access(X_OK) alone is also true for directories (they are
        # "searchable"), so insist on a regular file.
        if os.path.isfile(override) and os.access(override, os.X_OK):
            return override
        # Falling back silently would launch a different hermes than the one
        # the operator asked for; make that visible in the log.
        sys.stderr.write(
            f'[pty-hermes-wrapper] ignoring HOLYSHEEP_HERMES_BIN={override!r}: '
            'not an executable file\n'
        )
        sys.stderr.flush()
    # shutil.which skips directories named "hermes" and non-executable files.
    found = shutil.which('hermes')
    if found:
        return found
    sys.stderr.write('[pty-hermes-wrapper] hermes binary not found in $PATH\n')
    sys.exit(127)


_CHUNK_SIZE = 65536  # max bytes per os.read() on either side of the pump


def _log(message: str) -> None:
    """Best-effort one-line diagnostic on the wrapper's stderr (hs web log)."""
    try:
        sys.stderr.write(f'[pty-hermes-wrapper] {message}\n')
        sys.stderr.flush()
    except OSError:
        pass


def _configure_raw_slave(slave_fd: int) -> None:
    """Switch the PTY slave to raw mode so bytes pass through untouched.

    Same intent as cfmakeraw(3): no canonical line buffering (MAX_CANON
    truncation), echo, signal characters, CR/NL rewriting, flow control or
    output post-processing.
    """
    attrs = termios.tcgetattr(slave_fd)
    # attrs = [iflag, oflag, cflag, lflag, ispeed, ospeed, cc]
    # lflag (attrs[3]): disable canonical-mode line buffering, ECHO*, signal
    # chars. IEXTEN is cleared too: on BSD-derived kernels (macOS) its VLNEXT
    # (^V) and VDISCARD (^O) handling applies even with ICANON off.
    attrs[3] &= ~(termios.ICANON | termios.ECHO | termios.ECHOE | termios.ECHOK
                  | termios.ECHONL | termios.ISIG | termios.IEXTEN)
    # iflag (attrs[0]): disable \r<->\n translation, XON/XOFF flow control,
    # parity checking and high-bit stripping (we carry raw bytes only).
    attrs[0] &= ~(termios.ICRNL | termios.INLCR | termios.IGNCR | termios.IXON
                  | termios.IXOFF | termios.ISTRIP | termios.IGNBRK | termios.BRKINT
                  | termios.INPCK | termios.PARMRK)
    # oflag (attrs[1]): disable post-processing (ONLCR \n -> \r\n, tabs, etc.)
    attrs[1] &= ~termios.OPOST
    # cflag (attrs[2]): 8-bit characters, no parity, as cfmakeraw(3) sets.
    attrs[2] = (attrs[2] & ~(termios.CSIZE | termios.PARENB)) | termios.CS8
    # Reads on the slave block until at least 1 byte is available, then
    # return what is buffered (no inter-byte timer).
    attrs[6][termios.VMIN] = 1
    attrs[6][termios.VTIME] = 0
    termios.tcsetattr(slave_fd, termios.TCSANOW, attrs)


def _spawn_hermes(hermes_bin: str, master_fd: int, slave_fd: int) -> int:
    """Fork and exec ``<hermes_bin> acp`` with the PTY slave as stdin+stdout.

    stderr is inherited unchanged so Python tracebacks still reach the
    wrapper's log. Returns the child pid in the parent; never returns in
    the child.
    """
    pid = os.fork()
    if pid:
        return pid
    try:
        os.setsid()
        os.dup2(slave_fd, 0)
        os.dup2(slave_fd, 1)
        os.close(master_fd)
        os.close(slave_fd)
        os.execvp(hermes_bin, [hermes_bin, 'acp'])
    except OSError as exc:
        # Report via the raw fd: the forked child must not fall back into
        # the parent's code path or flush buffers it inherited.
        try:
            os.write(2, f'[pty-hermes-wrapper] exec {hermes_bin} failed: {exc}\n'.encode())
        except OSError:
            pass
    finally:
        # Only reached if exec did not replace the process image.
        os._exit(127)


def _terminate(pid: int) -> None:
    """Send SIGTERM to *pid*, ignoring a process that is already gone."""
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        pass


def _write_stdout(data: bytes) -> bool:
    """Forward *data* to our stdout; return False if it can't be written."""
    try:
        sys.stdout.buffer.write(data)
        sys.stdout.buffer.flush()
    except OSError:  # includes BrokenPipeError
        return False
    return True


def _write_all(fd: int, data: bytes) -> bool:
    """Write every byte of *data* to *fd*, looping over short writes."""
    view = memoryview(data)
    while view:
        try:
            written = os.write(fd, view)
        except OSError:
            return False
        if written <= 0:
            return False
        view = view[written:]
    return True


def _drain_master(master_fd: int) -> None:
    """Copy output hermes left in the PTY buffer before it exited."""
    while True:
        try:
            ready, _, _ = select.select([master_fd], [], [], 0)
        except OSError:
            return
        if not ready:
            return
        try:
            chunk = os.read(master_fd, _CHUNK_SIZE)
        except OSError:  # e.g. EIO once the slave side is fully closed
            return
        if not chunk or not _write_stdout(chunk):
            return


def _pump(pid: int, master_fd: int, trace: bool) -> 'int | None':
    """Shuttle bytes stdin -> PTY master and PTY master -> stdout.

    Returns hermes's raw wait status if this loop reaped it, otherwise None
    (the caller then reaps it with a blocking waitpid).
    """
    while True:
        try:
            reaped, status = os.waitpid(pid, os.WNOHANG)
        except ChildProcessError:
            return None
        if reaped:
            # hermes may have written its final response just before
            # exiting; forward it instead of dropping it.
            _drain_master(master_fd)
            return status

        try:
            ready, _, _ = select.select([master_fd, 0], [], [], 0.5)
        except InterruptedError:
            continue
        except OSError:
            # Not transient (e.g. EBADF): retrying would spin forever.
            _terminate(pid)
            return None

        if master_fd in ready:
            try:
                chunk = os.read(master_fd, _CHUNK_SIZE)
            except OSError:
                return None
            if not chunk:
                return None
            if not _write_stdout(chunk):
                # Our reader is gone; stop hermes.
                _terminate(pid)
                return None

        if 0 in ready:
            try:
                chunk = os.read(0, _CHUNK_SIZE)
            except OSError:
                return None
            if not chunk:
                # EOF from bun. With ICANON off the PTY has no EOF character
                # to forward, so ask hermes to shut down instead.
                _terminate(pid)
                return None
            # [HolySheep fork v2.1.29 / hs22] Diagnostic trace when enabled —
            # one line per stdin chunk so we can see exactly what bun pushed
            # to us and verify hermes received the full payload.
            if trace:
                _log(f'stdin chunk len={len(chunk)}')
            if not _write_all(master_fd, chunk):
                return None


def _exit_code(status: 'int | None') -> int:
    """Translate a waitpid() status into a shell-style exit code."""
    if status is None:
        # Status unavailable: keep the wrapper's historical exit code of 0.
        return 0
    if os.WIFEXITED(status):
        return os.WEXITSTATUS(status)
    if os.WIFSIGNALED(status):
        return 128 + os.WTERMSIG(status)
    return 1


def main() -> int:
    """Run ``hermes acp`` behind a raw PTY and pump bytes until it exits.

    Returns an exit code for the wrapper: hermes's own exit status,
    128 + signal number if it was killed by a signal, or 0 if its status
    could not be collected.
    """
    hermes_bin = _resolve_hermes_bin()
    master_fd, slave_fd = pty.openpty()
    _configure_raw_slave(slave_fd)
    # Print the applied flags once so future regressions are easy to
    # diagnose from hs web logs.
    _log('slave configured raw: ICANON off, ECHO off, OPOST off')

    pid = _spawn_hermes(hermes_bin, master_fd, slave_fd)
    # Drop the parent's slave copy; otherwise reads on the master never
    # observe hermes closing its end of the PTY.
    os.close(slave_fd)

    def _forward(signum, _frame):
        try:
            os.kill(pid, signum)
        except ProcessLookupError:
            pass

    signal.signal(signal.SIGTERM, _forward)
    signal.signal(signal.SIGINT, _forward)

    # Only active when HOLYSHEEP_PTY_TRACE=1 to avoid log flood; read once
    # instead of per chunk.
    trace = os.environ.get('HOLYSHEEP_PTY_TRACE') == '1'
    status = None
    try:
        status = _pump(pid, master_fd, trace)
    finally:
        try:
            os.close(master_fd)
        except OSError:
            pass
        if status is None:
            try:
                _, status = os.waitpid(pid, 0)
            except OSError:
                status = None
    return _exit_code(status)


if __name__ == '__main__':
    # Use main()'s return value as the process exit status; sys.exit(None)
    # exits with 0, same as falling off the end of the script.
    sys.exit(main())
