/** * Extension runner - executes extensions and manages their lifecycle. */ import type { AgentMessage } from "@oh-my-pi/pi-agent-core"; import type { CredentialDisabledEvent, ImageContent, Model, ProviderResponseMetadata } from "@oh-my-pi/pi-ai"; import type { KeyId } from "@oh-my-pi/pi-tui"; import { logger } from "@oh-my-pi/pi-utils"; import type { ModelRegistry } from "../../config/model-registry"; import { type Theme, theme } from "../../modes/theme/theme"; import type { SessionManager } from "../../session/session-manager"; import type { AfterProviderResponseEvent, BeforeAgentStartEvent, BeforeAgentStartEventResult, BeforeProviderRequestEvent, BeforeProviderRequestEventResult, CompactOptions, ContextEvent, ContextEventResult, ContextUsage, Extension, ExtensionActions, ExtensionCommandContext, ExtensionCommandContextActions, ExtensionContext, ExtensionContextActions, ExtensionError, ExtensionEvent, ExtensionFlag, ExtensionRuntime, ExtensionShortcut, ExtensionUIContext, InputEvent, InputEventResult, MessageRenderer, RegisteredCommand, RegisteredTool, ResourcesDiscoverEvent, ResourcesDiscoverResult, SessionBeforeBranchResult, SessionBeforeCompactResult, SessionBeforeSwitchResult, SessionBeforeTreeResult, SessionCompactingResult, ToolCallEvent, ToolCallEventResult, ToolResultEvent, ToolResultEventResult, UserBashEvent, UserBashEventResult, UserPythonEvent, UserPythonEventResult, } from "./types"; /** Combined result from all before_agent_start handlers */ interface BeforeAgentStartCombinedResult { messages?: NonNullable[]; systemPrompt?: string[]; } export type ExtensionErrorListener = (error: ExtensionError) => void; export const EXTENSION_HANDLER_TIMEOUT_MS = 30_000; let extensionHandlerTimeoutMs = EXTENSION_HANDLER_TIMEOUT_MS; export function testSetExtensionHandlerTimeoutMs(timeoutMs: number): void { extensionHandlerTimeoutMs = timeoutMs; } const EXTENSION_HANDLER_TIMEOUT = Symbol("extensionHandlerTimeout"); const MAX_PENDING_CREDENTIAL_DISABLED = 32; /** * Events handled by the generic emit() method. * Events with dedicated emitXxx() methods are excluded for stronger type safety. */ type RunnerEmitEvent = Exclude< ExtensionEvent, | ToolCallEvent | ToolResultEvent | UserBashEvent | ContextEvent | BeforeProviderRequestEvent | AfterProviderResponseEvent | BeforeAgentStartEvent | ResourcesDiscoverEvent | InputEvent >; type SessionBeforeEvent = Extract< RunnerEmitEvent, { type: "session_before_switch" | "session_before_branch" | "session_before_compact" | "session_before_tree" } >; type SessionBeforeEventResult = | SessionBeforeSwitchResult | SessionBeforeBranchResult | SessionBeforeCompactResult | SessionBeforeTreeResult; type RunnerEmitResult = TEvent extends { type: "session_before_switch" } ? SessionBeforeSwitchResult | undefined : TEvent extends { type: "session_before_branch" } ? SessionBeforeBranchResult | undefined : TEvent extends { type: "session_before_compact" } ? SessionBeforeCompactResult | undefined : TEvent extends { type: "session_before_tree" } ? SessionBeforeTreeResult | undefined : TEvent extends { type: "session.compacting" } ? SessionCompactingResult | undefined : undefined; export type NewSessionHandler = (options?: { parentSession?: string; setup?: (sessionManager: SessionManager) => Promise; }) => Promise<{ cancelled: boolean }>; export type BranchHandler = (entryId: string) => Promise<{ cancelled: boolean }>; export type NavigateTreeHandler = ( targetId: string, options?: { summarize?: boolean }, ) => Promise<{ cancelled: boolean }>; export type SwitchSessionHandler = (sessionPath: string) => Promise<{ cancelled: boolean }>; export type ShutdownHandler = () => void; /** * Helper function to emit session_shutdown event to extensions. * Returns true if the event was emitted, false if there were no handlers. */ export async function emitSessionShutdownEvent(extensionRunner: ExtensionRunner | undefined): Promise { if (extensionRunner?.hasHandlers("session_shutdown")) { await extensionRunner.emit({ type: "session_shutdown", }); return true; } return false; } const noOpUIContext: ExtensionUIContext = { select: async (_title, _options, _dialogOptions) => undefined, confirm: async (_title, _message, _dialogOptions) => false, input: async (_title, _placeholder, _dialogOptions) => undefined, notify: () => {}, onTerminalInput: () => () => {}, setStatus: () => {}, setWorkingMessage: () => {}, setWidget: () => {}, setFooter: () => {}, setHeader: () => {}, setTitle: () => {}, custom: async () => undefined as never, setEditorText: () => {}, pasteToEditor: () => {}, getEditorText: () => "", editor: async () => undefined, setEditorComponent: () => {}, get theme() { return theme; }, getAllThemes: () => Promise.resolve([]), getTheme: () => Promise.resolve(undefined), setTheme: (_theme: string | Theme) => Promise.resolve({ success: false, error: "UI not available" }), getToolsExpanded: () => false, setToolsExpanded: () => {}, }; export class ExtensionRunner { #uiContext: ExtensionUIContext; #errorListeners: Set = new Set(); #getModel: () => Model | undefined = () => undefined; #isIdleFn: () => boolean = () => true; #waitForIdleFn: () => Promise = async () => {}; #abortFn: () => void = () => {}; #hasPendingMessagesFn: () => boolean = () => false; #getContextUsageFn: () => ContextUsage | undefined = () => undefined; #compactFn: (instructionsOrOptions?: string | CompactOptions) => Promise = async () => {}; #getSystemPromptFn: () => string[] = () => []; #newSessionHandler: NewSessionHandler = async () => ({ cancelled: false }); #branchHandler: BranchHandler = async () => ({ cancelled: false }); #navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false }); #switchSessionHandler: SwitchSessionHandler = async () => ({ cancelled: false }); #reloadHandler: () => Promise = async () => {}; #shutdownHandler: ShutdownHandler = () => {}; #commandDiagnostics: Array<{ type: string; message: string; path: string }> = []; #initialized = false; /** * Buffer for `credential_disabled` events received via {@link emitCredentialDisabled} * before {@link initialize} has run. Drained through {@link emit} once initialize sets * up the runtime context, so extension handlers see a populated UI/runtime context * rather than the constructor's no-op default. Bounded at * {@link MAX_PENDING_CREDENTIAL_DISABLED}; oldest entries are dropped under pressure. */ #pendingCredentialDisabled: CredentialDisabledEvent[] = []; constructor( private readonly extensions: Extension[], private readonly runtime: ExtensionRuntime, private readonly cwd: string, private readonly sessionManager: SessionManager, private readonly modelRegistry: ModelRegistry, ) { this.#uiContext = noOpUIContext; } initialize( actions: ExtensionActions, contextActions: ExtensionContextActions, commandContextActions?: ExtensionCommandContextActions, uiContext?: ExtensionUIContext, ): void { // Copy actions into the shared runtime (all extension APIs reference this) this.runtime.sendMessage = actions.sendMessage; this.runtime.sendUserMessage = actions.sendUserMessage; this.runtime.appendEntry = actions.appendEntry; this.runtime.getActiveTools = actions.getActiveTools; this.runtime.getAllTools = actions.getAllTools; this.runtime.setActiveTools = actions.setActiveTools; this.runtime.getCommands = actions.getCommands; this.runtime.setModel = actions.setModel; this.runtime.getThinkingLevel = actions.getThinkingLevel; this.runtime.setThinkingLevel = actions.setThinkingLevel; this.runtime.getSessionName = actions.getSessionName; this.runtime.setSessionName = actions.setSessionName; // Context actions (required) this.#getModel = contextActions.getModel; this.#isIdleFn = contextActions.isIdle; this.#abortFn = contextActions.abort; this.#hasPendingMessagesFn = contextActions.hasPendingMessages; this.#shutdownHandler = contextActions.shutdown; this.#getSystemPromptFn = contextActions.getSystemPrompt; // Command context actions (optional, only for interactive mode) if (commandContextActions) { this.#waitForIdleFn = commandContextActions.waitForIdle; this.#newSessionHandler = commandContextActions.newSession; this.#branchHandler = commandContextActions.branch; this.#navigateTreeHandler = commandContextActions.navigateTree; this.#switchSessionHandler = commandContextActions.switchSession; this.#reloadHandler = commandContextActions.reload; this.#getContextUsageFn = commandContextActions.getContextUsage; this.#compactFn = commandContextActions.compact; } this.#uiContext = uiContext ?? noOpUIContext; this.#initialized = true; // Drain events buffered by emitCredentialDisabled() before initialize ran. The // spread adds the `type` discriminator — `event` is the pi-ai shape (no `type`). // Deferred by one microtask so callers that register an onError listener // synchronously after initialize() see handler errors routed through it. const pending = this.#pendingCredentialDisabled.splice(0); queueMicrotask(() => { for (const event of pending) { this.emit({ type: "credential_disabled", ...event }).catch((error: unknown) => { logger.warn("credential_disabled handler threw during initialize flush", { provider: event.provider, error: error instanceof Error ? error.message : String(error), }); }); } }); } /** * Forward a `credential_disabled` event from `AuthStorage` to extension handlers. * * If {@link initialize} has not yet run, the event is buffered and replayed once * initialize wires the runtime/UI context. This matters because mode controllers * (interactive, RPC, ACP, print, subagent) call `initialize()` AFTER `createAgentSession` * returns, but `AuthStorage` can fire `credential_disabled` during startup model probes * inside `createAgentSession()`. Without deferral, extension handlers would observe * `hasUI=false`, an unset model, and no-op runtime actions on exactly the headline * "OAuth invalid_grant during startup" path the event was designed to surface. * * Always returns; never throws. Errors from handlers are routed through * {@link onError} via {@link emit}'s normal isolation. */ async emitCredentialDisabled(event: CredentialDisabledEvent): Promise { if (!this.#initialized) { if (this.#pendingCredentialDisabled.length >= MAX_PENDING_CREDENTIAL_DISABLED) { this.#pendingCredentialDisabled.shift(); } this.#pendingCredentialDisabled.push(event); return; } await this.emit({ type: "credential_disabled", ...event }); } getUIContext(): ExtensionUIContext { return this.#uiContext; } hasUI(): boolean { return this.#uiContext !== noOpUIContext; } getExtensionPaths(): string[] { return this.extensions.map(e => e.path); } /** Get all registered tools from all extensions. */ getAllRegisteredTools(): RegisteredTool[] { const tools: RegisteredTool[] = []; for (const ext of this.extensions) { for (const tool of ext.tools.values()) { tools.push(tool); } } return tools; } getFlags(): Map { const allFlags = new Map(); for (const ext of this.extensions) { for (const [name, flag] of ext.flags) { allFlags.set(name, flag); } } return allFlags; } getFlagValues(): Map { return new Map(this.runtime.flagValues); } setFlagValue(name: string, value: boolean | string): void { this.runtime.flagValues.set(name, value); } static readonly #RESERVED_SHORTCUTS = new Set([ "ctrl+c", "ctrl+d", "ctrl+z", "ctrl+k", "ctrl+p", "ctrl+l", "ctrl+o", "ctrl+t", "ctrl+g", "shift+tab", "shift+ctrl+p", "alt+enter", "escape", "enter", ]); getShortcuts(): Map { const allShortcuts = new Map(); for (const ext of this.extensions) { for (const [key, shortcut] of ext.shortcuts) { const normalizedKey = key.toLowerCase() as KeyId; if (ExtensionRunner.#RESERVED_SHORTCUTS.has(normalizedKey)) { logger.warn("Extension shortcut conflicts with built-in shortcut", { key, extensionPath: shortcut.extensionPath, }); continue; } const existing = allShortcuts.get(normalizedKey); if (existing) { logger.warn("Extension shortcut conflict", { key, extensionPath: shortcut.extensionPath, existingExtensionPath: existing.extensionPath, }); } allShortcuts.set(normalizedKey, shortcut); } } return allShortcuts; } onError(listener: ExtensionErrorListener): () => void { this.#errorListeners.add(listener); return () => this.#errorListeners.delete(listener); } emitError(error: ExtensionError): void { for (const listener of this.#errorListeners) { listener(error); } } hasHandlers(eventType: string): boolean { for (const ext of this.extensions) { const handlers = ext.handlers.get(eventType); if (handlers && handlers.length > 0) { return true; } } return false; } getMessageRenderer(customType: string): MessageRenderer | undefined { for (const ext of this.extensions) { const renderer = ext.messageRenderers.get(customType); if (renderer) { return renderer; } } return undefined; } getRegisteredCommands(reserved?: Set): RegisteredCommand[] { this.#commandDiagnostics = []; const commands = new Map(); for (const ext of this.extensions) { for (const command of ext.commands.values()) { if (reserved?.has(command.name)) { const message = `Extension command '${command.name}' from ${ext.path} conflicts with built-in commands. Skipping.`; this.#commandDiagnostics.push({ type: "warning", message, path: ext.path }); if (!this.hasUI()) { logger.warn(message); } continue; } commands.set(command.name, command); } } return [...commands.values()]; } getCommandDiagnostics(): Array<{ type: string; message: string; path: string }> { return this.#commandDiagnostics; } getCommand(name: string): RegisteredCommand | undefined { for (let index = this.extensions.length - 1; index >= 0; index -= 1) { const command = this.extensions[index]?.commands.get(name); if (command) { return command; } } return undefined; } createContext(): ExtensionContext { const getModel = this.#getModel; return { ui: this.#uiContext, getContextUsage: () => this.#getContextUsageFn(), compact: instructionsOrOptions => this.#compactFn(instructionsOrOptions), hasUI: this.hasUI(), cwd: this.cwd, sessionManager: this.sessionManager, modelRegistry: this.modelRegistry, get model() { return getModel(); }, isIdle: () => this.#isIdleFn(), abort: () => this.#abortFn(), hasPendingMessages: () => this.#hasPendingMessagesFn(), shutdown: () => this.#shutdownHandler(), getSystemPrompt: () => this.#getSystemPromptFn(), hasQueuedMessages: () => this.#hasPendingMessagesFn(), // deprecated alias }; } /** * Request a graceful shutdown. Called by extension tools and event handlers. */ shutdown(): void { this.#shutdownHandler(); } createCommandContext(): ExtensionCommandContext { return { ...this.createContext(), getContextUsage: () => this.#getContextUsageFn(), waitForIdle: () => this.#waitForIdleFn(), newSession: options => this.#newSessionHandler(options), branch: entryId => this.#branchHandler(entryId), navigateTree: (targetId, options) => this.#navigateTreeHandler(targetId, options), switchSession: sessionPath => this.#switchSessionHandler(sessionPath), reload: () => this.#reloadHandler(), compact: instructionsOrOptions => this.#compactFn(instructionsOrOptions), }; } #isSessionBeforeEvent(event: RunnerEmitEvent): event is SessionBeforeEvent { return ( event.type === "session_before_switch" || event.type === "session_before_branch" || event.type === "session_before_compact" || event.type === "session_before_tree" ); } async #runHandlerWithTimeout( handler: (event: TEvent, ctx: ExtensionContext) => Promise | TResult | undefined, event: TEvent, ctx: ExtensionContext, ext: Extension, timeoutMs: number, ): Promise { try { const handlerResult = await Promise.race([ Promise.resolve(handler(event, ctx)), Bun.sleep(timeoutMs).then(() => EXTENSION_HANDLER_TIMEOUT), ]); if (handlerResult === EXTENSION_HANDLER_TIMEOUT) { const error = `handler timed out after ${timeoutMs}ms`; logger.warn("Extension handler timed out", { extensionPath: ext.path, event: event.type, timeoutMs, }); this.emitError({ extensionPath: ext.path, event: event.type, error, }); return undefined; } return handlerResult as TResult | undefined; } catch (err) { const message = err instanceof Error ? err.message : String(err); const stack = err instanceof Error ? err.stack : undefined; this.emitError({ extensionPath: ext.path, event: event.type, error: message, stack, }); return undefined; } } async emit(event: TEvent): Promise> { const ctx = this.createContext(); let result: SessionBeforeEventResult | SessionCompactingResult | undefined; for (const ext of this.extensions) { const handlers = ext.handlers.get(event.type); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const handlerResult = await this.#runHandlerWithTimeout( handler, event, ctx, ext, extensionHandlerTimeoutMs, ); if (this.#isSessionBeforeEvent(event) && handlerResult) { result = handlerResult as SessionBeforeEventResult; if (result.cancel) { return result as RunnerEmitResult; } } if (event.type === "session.compacting" && handlerResult) { result = handlerResult as SessionCompactingResult; } } } return result as RunnerEmitResult; } async emitToolResult(event: ToolResultEvent): Promise { const ctx = this.createContext(); const currentEvent: ToolResultEvent = { ...event }; let modified = false; for (const ext of this.extensions) { const handlers = ext.handlers.get("tool_result"); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const handlerResult = (await this.#runHandlerWithTimeout( handler, currentEvent, ctx, ext, extensionHandlerTimeoutMs, )) as ToolResultEventResult | undefined; if (!handlerResult) continue; if (handlerResult.content !== undefined) { currentEvent.content = handlerResult.content; modified = true; } if (handlerResult.details !== undefined) { currentEvent.details = handlerResult.details; modified = true; } if (handlerResult.isError !== undefined) { currentEvent.isError = handlerResult.isError; modified = true; } } } if (!modified) return undefined; return { content: currentEvent.content, details: currentEvent.details, isError: currentEvent.isError, }; } async emitToolCall(event: ToolCallEvent): Promise { const ctx = this.createContext(); let result: ToolCallEventResult | undefined; for (const ext of this.extensions) { const handlers = ext.handlers.get("tool_call"); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { try { const handlerResult = await handler(event, ctx); if (handlerResult) { result = handlerResult as ToolCallEventResult; if (result.block) { return result; } } } catch (err) { const message = err instanceof Error ? err.message : String(err); const stack = err instanceof Error ? err.stack : undefined; this.emitError({ extensionPath: ext.path, event: "tool_call", error: message, stack, }); return { block: true, reason: `Extension ${ext.path} failed: ${message}` }; } } } return result; } async emitUserBash(event: UserBashEvent): Promise { return this.emitUserEvent(event, "user_bash"); } async emitUserPython(event: UserPythonEvent): Promise { return this.emitUserEvent(event, "user_python"); } private async emitUserEvent( event: UserBashEvent | UserPythonEvent, eventName: "user_bash" | "user_python", ): Promise { const ctx = this.createContext(); for (const ext of this.extensions) { const handlers = ext.handlers.get(eventName); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const handlerResult = await this.#runHandlerWithTimeout( handler, event, ctx, ext, extensionHandlerTimeoutMs, ); if (handlerResult) { return handlerResult as R; } } } return undefined; } async emitResourcesDiscover( cwd: string, reason: ResourcesDiscoverEvent["reason"], ): Promise<{ skillPaths: Array<{ path: string; extensionPath: string }>; promptPaths: Array<{ path: string; extensionPath: string }>; themePaths: Array<{ path: string; extensionPath: string }>; }> { const ctx = this.createContext(); const skillPaths: Array<{ path: string; extensionPath: string }> = []; const promptPaths: Array<{ path: string; extensionPath: string }> = []; const themePaths: Array<{ path: string; extensionPath: string }> = []; for (const ext of this.extensions) { const handlers = ext.handlers.get("resources_discover"); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const event: ResourcesDiscoverEvent = { type: "resources_discover", cwd, reason }; const handlerResult = await this.#runHandlerWithTimeout( handler, event, ctx, ext, extensionHandlerTimeoutMs, ); const result = handlerResult as ResourcesDiscoverResult | undefined; if (result?.skillPaths?.length) { skillPaths.push(...result.skillPaths.map(path => ({ path, extensionPath: ext.path }))); } if (result?.promptPaths?.length) { promptPaths.push(...result.promptPaths.map(path => ({ path, extensionPath: ext.path }))); } if (result?.themePaths?.length) { themePaths.push(...result.themePaths.map(path => ({ path, extensionPath: ext.path }))); } } } return { skillPaths, promptPaths, themePaths }; } /** Emit input event. Transforms chain, "handled" short-circuits. */ async emitInput( text: string, images: ImageContent[] | undefined, source: "interactive" | "rpc" | "extension", ): Promise { const ctx = this.createContext(); let currentText = text; let currentImages = images; for (const ext of this.extensions) { for (const handler of ext.handlers.get("input") ?? []) { const event: InputEvent = { type: "input", text: currentText, images: currentImages, source }; const result = (await this.#runHandlerWithTimeout(handler, event, ctx, ext, extensionHandlerTimeoutMs)) as | InputEventResult | undefined; if (result?.handled) return result; if (result?.text !== undefined) { currentText = result.text; currentImages = result.images ?? currentImages; } } } return currentText !== text || currentImages !== images ? { text: currentText, images: currentImages } : {}; } async emitContext(messages: AgentMessage[]): Promise { const ctx = this.createContext(); // Check if any extensions actually have context handlers before cloning let hasContextHandlers = false; for (const ext of this.extensions) { if (ext.handlers.get("context")?.length) { hasContextHandlers = true; break; } } if (!hasContextHandlers) return messages; let currentMessages: AgentMessage[]; try { currentMessages = structuredClone(messages); } catch { // Messages may contain non-cloneable objects (e.g. in ToolResultMessage.details // or ProviderPayload). Fall back to a shallow array clone — extensions should // return new message arrays rather than mutating in place. currentMessages = [...messages]; } for (const ext of this.extensions) { const handlers = ext.handlers.get("context"); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const event: ContextEvent = { type: "context", messages: currentMessages }; const handlerResult = await this.#runHandlerWithTimeout( handler, event, ctx, ext, extensionHandlerTimeoutMs, ); if (handlerResult && (handlerResult as ContextEventResult).messages) { currentMessages = (handlerResult as ContextEventResult).messages!; } } } return currentMessages; } async emitBeforeProviderRequest(payload: unknown): Promise { const ctx = this.createContext(); let currentPayload = payload; for (const ext of this.extensions) { const handlers = ext.handlers.get("before_provider_request"); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const event: BeforeProviderRequestEvent = { type: "before_provider_request", payload: currentPayload, }; const handlerResult = await this.#runHandlerWithTimeout( handler, event, ctx, ext, extensionHandlerTimeoutMs, ); if (handlerResult !== undefined) { currentPayload = handlerResult; } } } return currentPayload; } async emitAfterProviderResponse(response: ProviderResponseMetadata, _model?: Model): Promise { const ctx = this.createContext(); for (const ext of this.extensions) { const handlers = ext.handlers.get("after_provider_response"); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const event: AfterProviderResponseEvent = { type: "after_provider_response", status: response.status, headers: response.headers, requestId: response.requestId, metadata: response.metadata, }; await this.#runHandlerWithTimeout(handler, event, ctx, ext, extensionHandlerTimeoutMs); } } } async emitBeforeAgentStart( prompt: string, images: ImageContent[] | undefined, systemPrompt: string[], ): Promise { const ctx = this.createContext(); const messages: NonNullable[] = []; let currentSystemPrompt = systemPrompt; let systemPromptModified = false; for (const ext of this.extensions) { const handlers = ext.handlers.get("before_agent_start"); if (!handlers || handlers.length === 0) continue; for (const handler of handlers) { const event: BeforeAgentStartEvent = { type: "before_agent_start", prompt, images, systemPrompt: currentSystemPrompt, }; const handlerResult = await this.#runHandlerWithTimeout( handler, event, ctx, ext, extensionHandlerTimeoutMs, ); if (handlerResult) { const result = handlerResult as BeforeAgentStartEventResult; if (result.message) { messages.push(result.message); } if (result.systemPrompt !== undefined) { currentSystemPrompt = result.systemPrompt; systemPromptModified = true; } } } } if (messages.length > 0 || systemPromptModified) { return { messages: messages.length > 0 ? messages : undefined, systemPrompt: systemPromptModified ? currentSystemPrompt : undefined, }; } return undefined; } }