import type { AgentTool, AgentToolResult, AgentToolUpdateCallback } from "@oh-my-pi/pi-agent-core";
import type { Static, TSchema } from "@oh-my-pi/pi-ai";
import { Snowflake } from "@oh-my-pi/pi-utils";
import { applyToolProxy } from "../../extensibility/tool-proxy";
import type { Theme } from "../../modes/theme/theme";
import type {
  RpcHostToolCallRequest,
  RpcHostToolCancelRequest,
  RpcHostToolDefinition,
  RpcHostToolResult,
  RpcHostToolUpdate,
} from "./rpc-types";

/** Sink for frames the bridge emits (call requests and cancel requests). */
type RpcHostToolOutput = (frame: RpcHostToolCallRequest | RpcHostToolCancelRequest) => void;

/** Bookkeeping for one in-flight host tool call, keyed by request id in the bridge. */
type PendingHostToolCall = {
  resolve: (result: AgentToolResult<unknown>) => void;
  reject: (error: Error) => void;
  onUpdate?: AgentToolUpdateCallback<unknown>;
};

/**
 * Shallow structural check: a non-null object whose `content` is an array.
 * The array's elements are NOT validated.
 */
function isAgentToolResult(value: unknown): value is AgentToolResult<unknown> {
  if (!value || typeof value !== "object") return false;
  const content = (value as { content?: unknown }).content;
  return Array.isArray(content);
}

/**
 * Returns true when `value` is a `host_tool_result` frame with a string `id`
 * and a `result` that passes the shallow tool-result check.
 */
export function isRpcHostToolResult(value: unknown): value is RpcHostToolResult {
  if (!value || typeof value !== "object") return false;
  const frame = value as { type?: unknown; id?: unknown; result?: unknown };
  return frame.type === "host_tool_result" && typeof frame.id === "string" && isAgentToolResult(frame.result);
}

/**
 * Returns true when `value` is a `host_tool_update` frame with a string `id`
 * and a `partialResult` that passes the shallow tool-result check.
 */
export function isRpcHostToolUpdate(value: unknown): value is RpcHostToolUpdate {
  if (!value || typeof value !== "object") return false;
  const frame = value as { type?: unknown; id?: unknown; partialResult?: unknown };
  return (
    frame.type === "host_tool_update" && typeof frame.id === "string" && isAgentToolResult(frame.partialResult)
  );
}

/**
 * Agent tool whose execution is forwarded to the host through a bridge.
 *
 * NOTE(review): the generic parameter lists were missing from this file and are
 * reconstructed from the imported `TSchema`/`Theme` types and the `TParams`
 * usages — confirm against the `AgentTool` declaration.
 */
class RpcHostToolAdapter<TParams extends TSchema = TSchema, TTheme extends Theme = Theme>
  implements AgentTool<TParams, unknown, TTheme>
{
  // Declared only; presumably populated by applyToolProxy in the constructor — verify.
  declare name: string;
  declare label: string;
  declare description: string;
  declare parameters: TParams;
  readonly strict = true;
  concurrency: "shared" | "exclusive" = "shared";

  #bridge: RpcHostToolBridge;
  #definition: RpcHostToolDefinition;

  constructor(definition: RpcHostToolDefinition, bridge: RpcHostToolBridge) {
    this.#definition = definition;
    this.#bridge = bridge;
    applyToolProxy(definition, this);
  }

  /**
   * Delegates the call to the bridge and returns its promise.
   *
   * @param toolCallId Identifier of the tool call, forwarded in the request frame.
   * @param params Tool arguments, forwarded as a plain record.
   * @param signal Optional abort signal for the call.
   * @param onUpdate Optional callback invoked with partial results.
   */
  execute(
    toolCallId: string,
    params: Static<TParams>,
    signal?: AbortSignal,
    onUpdate?: AgentToolUpdateCallback<unknown>,
  ): Promise<AgentToolResult<unknown>> {
    return this.#bridge.requestExecution(
      this.#definition,
      toolCallId,
      params as Record<string, unknown>,
      signal,
      onUpdate,
    );
  }
}

/**
 * Tracks host-provided tool definitions and in-flight host tool calls, emitting
 * request/cancel frames through the output callback and settling calls when
 * result frames are handed back in.
 */
export class RpcHostToolBridge {
  #output: RpcHostToolOutput;
  #definitions = new Map<string, RpcHostToolDefinition>();
  #pendingCalls = new Map<string, PendingHostToolCall>();

  constructor(output: RpcHostToolOutput) {
    this.#output = output;
  }

  /** Names of the currently registered tool definitions. */
  getToolNames(): string[] {
    return Array.from(this.#definitions.keys());
  }

  /**
   * Replaces the registered definitions and returns one adapter per definition.
   * Calls already pending are left untouched.
   */
  setTools(tools: RpcHostToolDefinition[]): AgentTool[] {
    this.#definitions = new Map(tools.map(tool => [tool.name, tool]));
    return tools.map(tool => new RpcHostToolAdapter(tool, this));
  }

  /**
   * Settles the pending call matching `frame.id`.
   *
   * When `frame.isError` is set the call is rejected with the joined text items
   * of the result (or a generic message when there is no text); otherwise it is
   * resolved with `frame.result`.
   *
   * @returns false when no pending call matches the frame id, true otherwise.
   */
  handleResult(frame: RpcHostToolResult): boolean {
    const pending = this.#pendingCalls.get(frame.id);
    if (!pending) return false;
    this.#pendingCalls.delete(frame.id);

    if (frame.isError) {
      const text = frame.result.content
        .filter(
          // Frame validation only checks that `content` is an array, so skip
          // nullish items instead of throwing after the entry was removed
          // (which would leave the call unsettled forever).
          (item): item is { type: "text"; text: string } =>
            item != null && item.type === "text" && typeof item.text === "string",
        )
        .map(item => item.text)
        .join("\n")
        .trim();
      pending.reject(new Error(text || "Host tool execution failed"));
      return true;
    }

    pending.resolve(frame.result);
    return true;
  }

  /**
   * Forwards a partial result to the pending call's update callback, if any.
   *
   * @returns false when no pending call matches the frame id, true otherwise.
   */
  handleUpdate(frame: RpcHostToolUpdate): boolean {
    const pending = this.#pendingCalls.get(frame.id);
    if (!pending) return false;
    pending.onUpdate?.(frame.partialResult);
    return true;
  }

  /**
   * Emits a `host_tool_call` frame and returns a promise that settles when the
   * matching result is handled, the signal aborts, or all pending calls are
   * rejected.
   *
   * @param definition Definition of the tool to run; its name goes into the frame.
   * @param toolCallId Identifier forwarded in the request frame.
   * @param args Arguments forwarded in the request frame.
   * @param signal Optional abort signal; aborting emits a `host_tool_cancel`
   *   frame and rejects the promise.
   * @param onUpdate Optional callback for partial results.
   * @throws Rethrows whatever the output callback throws while sending the
   *   request frame (after removing the pending entry and abort listener).
   */
  requestExecution(
    definition: RpcHostToolDefinition,
    toolCallId: string,
    args: Record<string, unknown>,
    signal?: AbortSignal,
    onUpdate?: AgentToolUpdateCallback<unknown>,
  ): Promise<AgentToolResult<unknown>> {
    if (signal?.aborted) {
      return Promise.reject(new Error(`Host tool "${definition.name}" was aborted`));
    }

    const id = Snowflake.next() as string;
    const { promise, resolve, reject } = Promise.withResolvers<AgentToolResult<unknown>>();
    // Guards against settling twice (e.g. a result arriving after an abort).
    let settled = false;

    const cleanup = () => {
      signal?.removeEventListener("abort", onAbort);
      this.#pendingCalls.delete(id);
    };

    const onAbort = () => {
      if (settled) return;
      settled = true;
      cleanup();
      try {
        this.#output({
          type: "host_tool_cancel",
          id: Snowflake.next() as string,
          targetId: id,
        });
      } finally {
        // Reject even if sending the cancel frame throws, so the caller never hangs.
        reject(new Error(`Host tool "${definition.name}" was aborted`));
      }
    };

    signal?.addEventListener("abort", onAbort, { once: true });
    this.#pendingCalls.set(id, {
      resolve: result => {
        if (settled) return;
        settled = true;
        cleanup();
        resolve(result);
      },
      reject: error => {
        if (settled) return;
        settled = true;
        cleanup();
        reject(error);
      },
      onUpdate,
    });

    try {
      this.#output({
        type: "host_tool_call",
        id,
        toolCallId,
        toolName: definition.name,
        arguments: args,
      });
    } catch (error) {
      // The request never left: drop the pending entry and abort listener so
      // they do not leak, then surface the original error unchanged.
      settled = true;
      cleanup();
      throw error;
    }

    return promise;
  }

  /** Rejects every pending call with a single error built from `message`. */
  rejectAllPending(message: string): void {
    const error = new Error(message);
    // Snapshot first: each reject handler deletes its own entry from the map.
    const pendingCalls = Array.from(this.#pendingCalls.values());
    this.#pendingCalls.clear();
    for (const pending of pendingCalls) {
      pending.reject(error);
    }
  }
}