/**
 * Extension loader - loads TypeScript extension modules using native Bun import.
 */
import type * as fs1 from "node:fs";
import * as fs from "node:fs/promises";
import * as path from "node:path";
import type { ThinkingLevel } from "@oh-my-pi/pi-agent-core";
import type { ImageContent, Model, TextContent } from "@oh-my-pi/pi-ai";
import type { KeyId } from "@oh-my-pi/pi-tui";
import { hasFsCode, isEacces, isEnoent, logger } from "@oh-my-pi/pi-utils";
import * as Zod from "zod/v4";
import { type ExtensionModule, extensionModuleCapability } from "../../capability/extension-module";
import { loadCapability } from "../../discovery";
import { getExtensionNameFromPath } from "../../discovery/helpers";
import type { ExecOptions } from "../../exec/exec";
import { execCommand } from "../../exec/exec";
import type { CustomMessage } from "../../session/messages";
import { EventBus } from "../../utils/event-bus";
import { installLegacyPiSpecifierShim, loadLegacyPiModule } from "../plugins/legacy-pi-compat";
import { getAllPluginExtensionPaths } from "../plugins/loader";
import * as TypeBox from "../typebox";
import { resolvePath } from "../utils";
import type {
  Extension,
  ExtensionAPI,
  ExtensionContext,
  ExtensionFactory,
  ExtensionRuntime as IExtensionRuntime,
  LoadExtensionsResult,
  MessageRenderer,
  RegisteredCommand,
  ToolDefinition,
} from "./types";

// Module-level side effect: runs once when this file is first imported.
installLegacyPiSpecifierShim();

/** Shape of an event handler stored on an extension: any arguments, asynchronous result. */
type HandlerFn = (...args: unknown[]) => Promise<unknown>;

/** What a loaded extension module may look like: the factory itself, or an object carrying it as `default`. */
type LoadedExtensionModule = ExtensionFactory | { default?: ExtensionFactory };

/**
 * Pick the extension factory out of a loaded module.
 *
 * @param module - The loaded module: either a factory function or an object with an optional `default`.
 * @returns The factory when the module (or its `default`) is a function, otherwise `null`.
 */
function getExtensionFactory(module: LoadedExtensionModule): ExtensionFactory | null {
  const candidate = typeof module === "function" ? module : module.default;
  return typeof candidate === "function" ? candidate : null;
}

/** Thrown by the stub runtime when an action method is called while extensions are still loading. */
export class ExtensionRuntimeNotInitializedError extends Error {
  constructor() {
    super("Extension runtime not initialized. Action methods cannot be called during extension loading.");
    // Give the error its own name so logs and `String(err)` identify it.
    this.name = "ExtensionRuntimeNotInitializedError";
  }
}

/**
 * Extension runtime with throwing stubs for action methods.
* These are replaced with real implementations during initialization.
 *
 * Until then every action method throws {@link ExtensionRuntimeNotInitializedError}.
 * `flagValues` and `pendingProviderRegistrations` are plain data and usable immediately.
 */
export class ExtensionRuntime implements IExtensionRuntime {
  /** Current flag values keyed by flag name. */
  flagValues = new Map<string, boolean | string>();
  /** Provider registrations queued by extensions while they load. */
  pendingProviderRegistrations: Array<{ name: string; config: import("./types").ProviderConfig; sourceId: string }> = [];

  sendMessage(): void {
    throw new ExtensionRuntimeNotInitializedError();
  }

  sendUserMessage(): void {
    throw new ExtensionRuntimeNotInitializedError();
  }

  appendEntry(): void {
    throw new ExtensionRuntimeNotInitializedError();
  }

  setLabel(): void {
    throw new ExtensionRuntimeNotInitializedError();
  }

  getActiveTools(): string[] {
    throw new ExtensionRuntimeNotInitializedError();
  }

  getAllTools(): string[] {
    throw new ExtensionRuntimeNotInitializedError();
  }

  // The promise-returning stubs throw synchronously and never resolve, hence
  // `Promise<never>`; that type is assignable to any `Promise<T>`.
  setActiveTools(): Promise<never> {
    throw new ExtensionRuntimeNotInitializedError();
  }

  getCommands(): never {
    throw new ExtensionRuntimeNotInitializedError();
  }

  setModel(): Promise<never> {
    throw new ExtensionRuntimeNotInitializedError();
  }

  getThinkingLevel(): ThinkingLevel {
    throw new ExtensionRuntimeNotInitializedError();
  }

  setThinkingLevel(): void {
    throw new ExtensionRuntimeNotInitializedError();
  }

  getSessionName(): string | undefined {
    throw new ExtensionRuntimeNotInitializedError();
  }

  setSessionName(): Promise<never> {
    throw new ExtensionRuntimeNotInitializedError();
  }
}

/**
 * ExtensionAPI implementation for an extension.
 * Registration methods write to the extension object.
 * Action methods delegate to the shared runtime.
*/
class ConcreteExtensionAPI implements ExtensionAPI, IExtensionRuntime {
  readonly logger = logger;
  readonly typebox = TypeBox;
  readonly zod = Zod;
  // These two fields exist on the API object itself; the methods below read and
  // write the shared runtime's copies (`this.runtime.*`), not these.
  readonly flagValues = new Map<string, boolean | string>();
  readonly pendingProviderRegistrations: Array<{
    name: string;
    config: import("./types").ProviderConfig;
    sourceId: string;
  }> = [];

  /**
   * @param pi - The `@oh-my-pi/pi-coding-agent` module namespace, exposed to the extension.
   * @param extension - The extension record that registration methods write to.
   * @param runtime - Shared runtime that action methods delegate to.
   * @param cwd - Default working directory for `exec`.
   * @param events - Event bus exposed to the extension.
   */
  constructor(
    public readonly pi: typeof import("@oh-my-pi/pi-coding-agent"),
    private readonly extension: Extension,
    private readonly runtime: IExtensionRuntime,
    private readonly cwd: string,
    public readonly events: EventBus,
  ) {}

  /** Append `handler` to the extension's handler list for `event`, creating the list on first use. */
  on<F extends HandlerFn>(event: string, handler: F): void {
    const list = this.extension.handlers.get(event) ?? [];
    list.push(handler);
    this.extension.handlers.set(event, list);
  }

  /** Register a tool under its own name; a later registration with the same name replaces the earlier one. */
  registerTool<
    TParams extends import("@oh-my-pi/pi-ai").TSchema = import("@oh-my-pi/pi-ai").TSchema,
    TDetails = unknown,
  >(tool: ToolDefinition<TParams, TDetails>): void {
    this.extension.tools.set(tool.name, {
      definition: tool,
      extensionPath: this.extension.path,
    });
  }

  /** Register a command on the extension, keyed by `name`. */
  registerCommand(
    name: string,
    options: {
      description?: string;
      getArgumentCompletions?: RegisteredCommand["getArgumentCompletions"];
      handler: RegisteredCommand["handler"];
    },
  ): void {
    this.extension.commands.set(name, { name, ...options });
  }

  /** Set the extension's label. Writes to the extension record, not to the runtime. */
  setLabel(label: string): void {
    this.extension.label = label;
  }

  /** Register a keyboard shortcut on the extension, keyed by the key id. */
  registerShortcut(
    shortcut: KeyId,
    options: {
      description?: string;
      handler: (ctx: ExtensionContext) => Promise<void> | void;
    },
  ): void {
    this.extension.shortcuts.set(shortcut, { shortcut, extensionPath: this.extension.path, ...options });
  }

  /** Register a flag; when a default is given it is stored as the flag's value on the shared runtime. */
  registerFlag(
    name: string,
    options: { description?: string; type: "boolean" | "string"; default?: boolean | string },
  ): void {
    this.extension.flags.set(name, { name, extensionPath: this.extension.path, ...options });
    if (options.default !== undefined) {
      this.runtime.flagValues.set(name, options.default);
    }
  }

  /** Register a renderer for messages of the given custom type. */
  registerMessageRenderer<T>(customType: string, renderer: MessageRenderer<T>): void {
    this.extension.messageRenderers.set(customType, renderer as MessageRenderer);
  }

  /** Read a flag's value from the runtime; `undefined` for flags this extension did not register. */
  getFlag(name: string): boolean | string | undefined {
    if (!this.extension.flags.has(name)) return undefined;
    return this.runtime.flagValues.get(name);
  }

  sendMessage<T = unknown>(
    message: Pick<CustomMessage<T>, "customType" | "content" | "display" | "details" | "attribution">,
    options?: { triggerTurn?: boolean; deliverAs?: "steer" | "followUp" | "nextTurn" },
  ): void {
    this.runtime.sendMessage(message, options);
  }

  sendUserMessage(
    content: string | (TextContent | ImageContent)[],
    options?: { deliverAs?: "steer" | "followUp" },
  ): void {
    this.runtime.sendUserMessage(content, options);
  }

  appendEntry(customType: string, data?: unknown): void {
    this.runtime.appendEntry(customType, data);
  }

  /** Run a command; the working directory is `options.cwd` when given, otherwise the loader's `cwd`. */
  exec(command: string, args: string[], options?: ExecOptions) {
    return execCommand(command, args, options?.cwd ?? this.cwd, options);
  }

  getActiveTools(): string[] {
    return this.runtime.getActiveTools();
  }

  getAllTools(): string[] {
    return this.runtime.getAllTools();
  }

  // Return types of the promise-returning delegators are taken from the runtime
  // contract so they cannot drift from it.
  setActiveTools(toolNames: string[]): ReturnType<IExtensionRuntime["setActiveTools"]> {
    return this.runtime.setActiveTools(toolNames);
  }

  getCommands() {
    return this.runtime.getCommands();
  }

  // NOTE(review): confirm whether `Model` needs a type argument here.
  setModel(model: Model): ReturnType<IExtensionRuntime["setModel"]> {
    return this.runtime.setModel(model);
  }

  getThinkingLevel(): ThinkingLevel | undefined {
    return this.runtime.getThinkingLevel();
  }

  setThinkingLevel(level: ThinkingLevel, persist?: boolean): void {
    this.runtime.setThinkingLevel(level, persist);
  }

  getSessionName(): string | undefined {
    return this.runtime.getSessionName();
  }

  setSessionName(name: string): ReturnType<IExtensionRuntime["setSessionName"]> {
    return this.runtime.setSessionName(name);
  }

  /** Queue a provider registration on the shared runtime, tagged with this extension's path. */
  registerProvider(name: string, config: import("./types").ProviderConfig): void {
    this.runtime.pendingProviderRegistrations.push({ name, config, sourceId: this.extension.path });
  }
}

/**
 * Create an Extension object with empty collections.
*/
function createExtension(extensionPath: string, resolvedPath: string): Extension {
  return {
    path: extensionPath,
    resolvedPath,
    handlers: new Map(),
    tools: new Map(),
    messageRenderers: new Map(),
    commands: new Map(),
    flags: new Map(),
    shortcuts: new Map(),
  };
}

/**
 * Load one extension module and run its factory.
 *
 * Never throws: any failure while importing the module or running the factory
 * is returned as an error string instead.
 *
 * @param extensionPath - Path as given by the caller; resolved against `cwd`.
 * @param cwd - Base directory for resolving the path and default working directory for the API.
 * @param eventBus - Event bus handed to the extension.
 * @param runtime - Shared runtime the extension's API delegates to.
 * @returns The loaded extension with `error: null`, or `extension: null` with a message.
 */
async function loadExtension(
  extensionPath: string,
  cwd: string,
  eventBus: EventBus,
  runtime: IExtensionRuntime,
): Promise<{ extension: Extension | null; error: string | null }> {
  const resolvedPath = resolvePath(extensionPath, cwd);

  try {
    const module = (await loadLegacyPiModule(resolvedPath)) as LoadedExtensionModule;
    const factory = getExtensionFactory(module);
    if (!factory) {
      return {
        extension: null,
        error: `Extension does not export a valid factory function: ${extensionPath}`,
      };
    }

    const extension = createExtension(extensionPath, resolvedPath);
    const api = new ConcreteExtensionAPI(
      await import("@oh-my-pi/pi-coding-agent"),
      extension,
      runtime,
      cwd,
      eventBus,
    );
    await factory(api);

    return { extension, error: null };
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    return { extension: null, error: `Failed to load extension: ${message}` };
  }
}

/**
 * Create an Extension from an inline factory function.
 *
 * Unlike `loadExtension`, errors thrown by the factory propagate to the caller.
 *
 * @param name - Used as both `path` and `resolvedPath` of the created extension.
 *   NOTE(review): the default is an empty string; confirm that is intended.
 */
export async function loadExtensionFromFactory(
  factory: ExtensionFactory,
  cwd: string,
  eventBus: EventBus,
  runtime: IExtensionRuntime,
  name = "",
): Promise<Extension> {
  const extension = createExtension(name, name);
  const api = new ConcreteExtensionAPI(await import("@oh-my-pi/pi-coding-agent"), extension, runtime, cwd, eventBus);
  await factory(api);
  return extension;
}

/**
 * Load extensions from paths.
 *
 * Extensions are loaded one at a time, in the order given, against a single
 * fresh stub runtime. A path that fails to load is recorded in `errors` and
 * does not stop the remaining paths from loading.
 *
 * @param eventBus - Shared event bus; a new one is created when omitted.
 */
export async function loadExtensions(paths: string[], cwd: string, eventBus?: EventBus): Promise<LoadExtensionsResult> {
  const extensions: Extension[] = [];
  const errors: Array<{ path: string; error: string }> = [];
  const resolvedEventBus = eventBus ?? new EventBus();
  const runtime = new ExtensionRuntime();

  for (const extPath of paths) {
    const { extension, error } = await loadExtension(extPath, cwd, resolvedEventBus, runtime);

    if (error) {
      errors.push({ path: extPath, error });
      continue;
    }

    if (extension) {
      extensions.push(extension);
    }
  }

  return {
    extensions,
    errors,
    runtime,
  };
}

/** The `omp` / `pi` section of an extension package's package.json. */
interface ExtensionManifest {
  extensions?: string[];
  themes?: string[];
  skills?: string[];
}

/**
 * Read the extension manifest from a package.json file.
 *
 * The `omp` field is preferred; `pi` is the fallback. Field contents are not
 * validated beyond being a non-null object.
 *
 * @returns The manifest object, or `null` when the file is missing, unreadable,
 *   unparsable, or has no manifest. Only failures other than ENOENT/EACCES/EPERM
 *   are logged.
 */
async function readExtensionManifest(packageJsonPath: string): Promise<ExtensionManifest | null> {
  try {
    const pkg = (await Bun.file(packageJsonPath).json()) as { omp?: ExtensionManifest; pi?: ExtensionManifest };
    const manifest = pkg.omp ?? pkg.pi;
    if (manifest && typeof manifest === "object") {
      return manifest;
    }
    return null;
  } catch (error) {
    if (isEnoent(error) || isEacces(error) || hasFsCode(error, "EPERM")) {
      return null;
    }
    logger.warn("Failed to read extension manifest", { path: packageJsonPath, error: String(error) });
    return null;
  }
}

/** True when the file name ends in `.ts` or `.js`. */
function isExtensionFile(name: string): boolean {
  return name.endsWith(".ts") || name.endsWith(".js");
}

/**
 * Resolve extension entry points from a directory.
*/
async function resolveExtensionEntries(dir: string): Promise<string[] | null> {
  // Resolution order: paths declared in package.json, then index.ts, then index.js.
  const manifest = await readExtensionManifest(path.join(dir, "package.json"));

  // The manifest is unvalidated JSON, so check its shape before use: a string
  // value would be iterated character by character, and a non-string element
  // would make `path.resolve` throw.
  const declared: unknown = manifest?.extensions;
  if (Array.isArray(declared) && declared.length > 0) {
    const entries: string[] = [];
    for (const extPath of declared) {
      if (typeof extPath !== "string") continue;
      const resolvedExtPath = path.resolve(dir, extPath);
      if (await extensionEntryExists(resolvedExtPath)) {
        entries.push(resolvedExtPath);
      }
    }
    // Fall through to the index lookup when none of the declared paths exist.
    if (entries.length > 0) {
      return entries;
    }
  }

  for (const indexName of ["index.ts", "index.js"]) {
    const candidate = path.join(dir, indexName);
    if (await extensionEntryExists(candidate)) {
      return [candidate];
    }
  }

  return null;
}

/**
 * Report whether `fs.stat` succeeds for `target`.
 *
 * ENOENT, EACCES and EPERM count as "not there"; any other error is rethrown.
 */
async function extensionEntryExists(target: string): Promise<boolean> {
  try {
    await fs.stat(target);
    return true;
  } catch (err) {
    if (isEnoent(err) || isEacces(err) || hasFsCode(err, "EPERM")) return false;
    throw err;
  }
}

/**
 * Discover extensions in a directory.
 *
 * Discovery rules:
 * 1. Direct files: `extensions/*.ts` or `*.js` → load
 * 2. Subdirectory with index: `extensions/<name>/index.ts` or `index.js` → load
 * 3. Subdirectory with package.json: `extensions/<name>/package.json` with "omp"/"pi" field → load declared paths
 *
 * No recursion beyond one level. Complex packages must use package.json manifest.
*/
async function discoverExtensionsInDir(dir: string): Promise<string[]> {
  // First check if this directory itself has explicit extension entries (package.json or index)
  const rootEntries = await resolveExtensionEntries(dir);
  if (rootEntries) {
    return rootEntries;
  }

  // Otherwise, discover extensions from directory contents
  let entries: fs1.Dirent[];
  try {
    entries = await fs.readdir(dir, { withFileTypes: true });
  } catch (err) {
    // A missing directory is silent; anything else is logged. Both yield no extensions.
    if (isEnoent(err)) return [];
    logger.warn("Failed to discover extensions in directory", { path: dir, error: String(err) });
    return [];
  }

  const discovered: string[] = [];
  for (const entry of entries) {
    const entryPath = path.join(dir, entry.name);
    const isLink = entry.isSymbolicLink();

    // A symlink with an extension-file name is taken as a file without checking its target.
    if ((entry.isFile() || isLink) && isExtensionFile(entry.name)) {
      discovered.push(entryPath);
      continue;
    }

    // Any other symlink is probed the same way as a subdirectory.
    if (entry.isDirectory() || isLink) {
      const resolved = await resolveExtensionEntries(entryPath);
      if (resolved) {
        discovered.push(...resolved);
      }
    }
  }

  return discovered;
}

/**
 * Discover and load extensions from standard locations.
 *
 * Paths are collected in this order and de-duplicated by absolute path
 * (first occurrence wins):
 * 1. extension modules reported by the capability API with provider "native",
 * 2. extension entry points of installed plugins,
 * 3. explicitly configured paths (directories are expanded to their entries).
 *
 * @param configuredPaths - Files or directories, resolved against `cwd`.
 * @param cwd - Base directory for discovery and loading.
 * @param eventBus - Optional shared event bus, forwarded to `loadExtensions`.
 * @param disabledExtensionIds - Ids of the form `extension-module:<name>`. Matching
 *   extensions are skipped, except a configured path that is not a directory,
 *   which is always added.
 * @throws When `stat` on a configured path fails with anything other than ENOENT.
 */
export async function discoverAndLoadExtensions(
  configuredPaths: string[],
  cwd: string,
  eventBus?: EventBus,
  disabledExtensionIds: string[] = [],
): Promise<LoadExtensionsResult> {
  const allPaths: string[] = [];
  const seen = new Set<string>();
  const disabled = new Set<string>(disabledExtensionIds);

  const isDisabledName = (name: string): boolean => disabled.has(`extension-module:${name}`);

  // De-duplicates on the absolute path but records the path as it was given.
  const addPath = (extPath: string): void => {
    const resolved = path.resolve(extPath);
    if (!seen.has(resolved)) {
      seen.add(resolved);
      allPaths.push(extPath);
    }
  };

  const addPaths = (paths: string[]): void => {
    for (const extPath of paths) {
      if (isDisabledName(getExtensionNameFromPath(extPath))) continue;
      addPath(extPath);
    }
  };

  // 1. Discover extension modules via capability API (native .omp/.pi only)
  const capabilityResult = await loadCapability<ExtensionModule>(extensionModuleCapability.id, { cwd });
  for (const ext of capabilityResult.items) {
    if (ext._source.provider !== "native") continue;
    if (isDisabledName(ext.name)) continue;
    addPath(ext.path);
  }

  // 2. Discover extension entry points from installed plugins
  addPaths(await getAllPluginExtensionPaths(cwd));

  // 3. Explicitly configured paths
  for (const configuredPath of configuredPaths) {
    const resolved = resolvePath(configuredPath, cwd);

    let stat: fs1.Stats | null = null;
    try {
      stat = await fs.stat(resolved);
    } catch (err) {
      if (!isEnoent(err)) throw err;
    }

    if (stat?.isDirectory()) {
      // discoverExtensionsInDir already checks the directory's own package.json /
      // index entries before scanning its contents, so one call covers both cases.
      addPaths(await discoverExtensionsInDir(resolved));
      continue;
    }

    // Files and missing paths are added as-is; a missing path is reported as a
    // load error by loadExtensions.
    addPath(resolved);
  }

  return loadExtensions(allPaths, cwd, eventBus);
}