/* Copyright 2026 Marimo. All rights reserved. */
/* oxlint-disable typescript/no-explicit-any */

import { useCallback, useEffect, useRef } from "react";
import { z } from "zod";
import { useEventListener } from "@/hooks/useEventListener";
import { createPlugin } from "@/plugins/core/builder";
import { isTrustedVirtualFileUrl } from "@/plugins/core/trusted-url";
import { MODEL_MANAGER, type Model } from "@/plugins/impl/anywidget/model";
import type { ModelState, WidgetModelId } from "@/plugins/impl/anywidget/types";
import type { IPluginProps } from "@/plugins/types";
import { downloadBlob } from "@/utils/download";
import { Logger } from "@/utils/Logger";
import { MplCommWebSocket } from "./mpl-websocket-shim";
import { Functions } from "@/utils/functions";

const MPL_SCOPE_CLASS = "mpl-interactive-figure";

/** Plugin data; mirrors the zod schema passed to `withData` below. */
interface Data {
  mplJsUrl: string;
  cssUrl: string;
  /** Toolbar image name -> inline data URI (see `patchToolbarImages`). */
  toolbarImages: Record<string, string>;
  width: number;
  height: number;
}

/** Plugin value: the id of the model backing this figure. */
interface ModelIdRef {
  model_id: WidgetModelId;
}

declare global {
  interface Window {
    /** Global installed by mpl.js; `ensureMplJs` uses its presence as the "loaded" signal. */
    mpl: {
      figure: new (
        id: string,
        ws: MplCommWebSocket,
        ondownload: (figure: MplFigure, format: string) => void,
        element: HTMLElement,
      ) => MplFigure;
      toolbar_items: [
        string | null,
        string | null,
        string | null,
        string | null,
      ][];
    };
  }
}

/** The subset of an mpl.js figure instance used by this module. */
interface MplFigure {
  id: string;
  ws: MplCommWebSocket;
  root: HTMLElement;
  send_message: (type: string, properties: Record<string, unknown>) => void;
}

export const MplInteractivePlugin = createPlugin<ModelIdRef>(
  "marimo-mpl-interactive",
)
  .withData(
    z.object({
      mplJsUrl: z.string(),
      cssUrl: z.string(),
      toolbarImages: z.record(z.string(), z.string()),
      width: z.number(),
      height: z.number(),
    }),
  )
  .withFunctions({})
  .renderer((props) => <MplInteractiveSlot {...props} />);

/**
 * Shared in-flight (or settled) load of mpl.js, so concurrent callers append
 * only one script tag. Reset to `null` on failure so a later call can retry.
 */
let mplJsLoading: Promise<void> | null = null;

/**
 * Load matplotlib's `mpl.js` into the page at most once.
 *
 * @param jsUrl URL of mpl.js; must pass `isTrustedVirtualFileUrl`.
 * @returns a promise that resolves once `window.mpl` is defined.
 * @throws Error if the URL is untrusted, the script fails to load, or the
 *   script loads without defining `window.mpl`.
 */
async function ensureMplJs(jsUrl: string): Promise<void> {
  if (window.mpl) {
    return;
  }
  if (!isTrustedVirtualFileUrl(jsUrl)) {
    throw new Error(
      `Refusing to load mpl.js from untrusted URL: ${String(jsUrl)}`,
    );
  }
  if (mplJsLoading) {
    return mplJsLoading;
  }
  mplJsLoading = new Promise<void>((resolve, reject) => {
    const script = document.createElement("script");
    // Drop both the dead <script> tag and the cached promise so a retry
    // starts clean instead of accumulating failed tags in <head>.
    const fail = (message: string) => {
      script.remove();
      mplJsLoading = null;
      reject(new Error(message));
    };
    script.src = jsUrl;
    script.onload = () => {
      // `window.mpl` is the "loaded" signal used by the early return above.
      // A script that executes without defining it (e.g. an error page
      // served with 200) would otherwise surface later as an opaque
      // TypeError when constructing the figure.
      if (window.mpl) {
        resolve();
      } else {
        fail("mpl.js loaded but did not define window.mpl");
      }
    };
    script.onerror = () => fail("Failed to load mpl.js");
    document.head.append(script);
  });
  return mplJsLoading;
}

/**
 * Patch mpl.js toolbar image references to use inline data URIs.
 *
 * mpl.js sets `icon_img.src = '_images/' + image + '.png'` and
 * `icon_img.srcset = '_images/' + image + '_large.png 2x'`.
 *
 * Images already inside `container` are patched immediately; a
 * MutationObserver patches images that are added to its subtree later.
 *
 * @param container element whose descendant <img> tags are rewritten.
 * @param toolbarImages image name (the text between `_images/` and `.png`)
 *   -> data URI; names without an own entry are left untouched.
 * @returns a cleanup function that disconnects the observer.
 */
function patchToolbarImages(
  container: HTMLElement,
  toolbarImages: Record<string, string>,
): () => void {
  const srcPattern = /_images\/(.+)\.png$/;
  const srcsetPattern = /_images\/(.+)\.png\s+2x$/;

  // Own-property lookup so names like "constructor" never resolve to
  // Object.prototype members.
  const dataUriFor = (
    value: string | null,
    pattern: RegExp,
  ): string | undefined => {
    const name = value === null ? undefined : pattern.exec(value)?.[1];
    if (
      name === undefined ||
      !Object.prototype.hasOwnProperty.call(toolbarImages, name)
    ) {
      return undefined;
    }
    return toolbarImages[name];
  };

  const patchImg = (img: HTMLImageElement) => {
    const srcUri = dataUriFor(img.getAttribute("src"), srcPattern);
    if (srcUri) {
      img.src = srcUri;
    }
    const srcsetUri = dataUriFor(img.getAttribute("srcset"), srcsetPattern);
    if (srcsetUri) {
      img.srcset = `${srcsetUri} 2x`;
    }
  };

  // Patch any existing images
  for (const img of container.querySelectorAll("img")) {
    patchImg(img);
  }

  // Observe for new images added by mpl.js
  const observer = new MutationObserver((mutations) => {
    for (const mutation of mutations) {
      for (const node of mutation.addedNodes) {
        if (node instanceof HTMLImageElement) {
          patchImg(node);
        } else if (node instanceof HTMLElement) {
          for (const img of node.querySelectorAll("img")) {
            patchImg(img);
          }
        }
      }
    }
  });
  observer.observe(container, { childList: true, subtree: true });
  return () => observer.disconnect();
}

/**
 * Append the mpl stylesheet as a <link> element inside `container`.
 *
 * @returns a cleanup function that removes the <link>. When `cssUrl` is not
 *   trusted, the rejection is logged (not thrown) and a no-op is returned.
 */
function injectCss(container: HTMLElement, cssUrl: string): () => void {
  if (!isTrustedVirtualFileUrl(cssUrl)) {
    Logger.error(
      `Refusing to load mpl CSS from untrusted URL: ${String(cssUrl)}`,
    );
    return Functions.NOOP;
  }
  const link = document.createElement("link");
  link.rel = "stylesheet";
  link.href = cssUrl;
  container.append(link);
  return () => link.remove();
}

const MplInteractiveSlot = (props: IPluginProps) => {
  const { mplJsUrl, cssUrl, toolbarImages, width, height } = props.data;
  const { model_id: modelId } = props.value;
  const containerRef = useRef(null);
  const figureRef = useRef(null);
  const wsRef = useRef(null);
  const setupFigure = useCallback(
    async (container: HTMLElement) => {
      // Load mpl.js globally (only once, via