/* Copyright 2026 Marimo. All rights reserved. */
import type * as Plotly from "plotly.js";
import { z } from "zod";
import type { IPlugin, IPluginProps, Setter } from "@/plugins/types";
import { Logger } from "@/utils/Logger";
import type { Figure } from "./Plot";

import "./plotly.css";
import "./mapbox.css";
import { set } from "lodash-es";
import { type JSX, lazy, memo, useMemo } from "react";
import useEvent from "react-use-event-hook";
import { useDeepCompareMemoize } from "@/hooks/useDeepCompareMemoize";
import { useScript } from "@/hooks/useScript";
import { Arrays } from "@/utils/arrays";
import {
  extractIndices,
  extractPoints,
  extractSunburstPoints,
  extractTreemapPoints,
  hasAreaTrace,
  hasPureLineTrace,
  lineSelectionButtons,
  type ModeBarButton,
  mergeModeBarButtonsToAdd,
  shouldHandleClickSelection,
} from "./selection";
import { usePlotlyLayout } from "./usePlotlyLayout";

/** Static data attached to the `marimo-plotly` element. */
interface Data {
  /** The Plotly figure (data, layout, frames, ...) to render. */
  figure: Figure;
  /** User-supplied Plotly config; merged over the plugin's defaults. */
  config: Partial<Plotly.Config>;
}

/**
 * The plugin's value, written through `setValue` as the user interacts with
 * the chart (selections, clicks, drag mode and axis changes). May be
 * `undefined`.
 */
type T =
  | {
      /** Points extracted from the latest selection or click. */
      points?: Record<string, unknown>[] | Plotly.PlotDatum[];
      /** Point indices extracted from the latest selection or click. */
      indices?: number[];
      /** Box-selection range, as reported by the `selected` event. */
      range?: {
        x?: number[];
        y?: number[];
      };
      /** Lasso outline, as reported by the `selected` event. */
      lasso?: {
        x?: unknown[];
        y?: unknown[];
      };
      // These are kept in the state to persist selections across re-renders
      // on the frontend, but likely not used in the backend.
      selections?: unknown[];
      dragmode?: Plotly.Layout["dragmode"];
      xaxis?: Partial<Plotly.LayoutAxis>;
      yaxis?: Partial<Plotly.LayoutAxis>;
    }
  | undefined;

/**
 * Plugin for the `marimo-plotly` element: validates the `figure` and
 * `config` data and renders them with {@link PlotlyComponent}.
 */
export class PlotlyPlugin implements IPlugin<T, Data> {
  tagName = "marimo-plotly";

  validator = z.object({
    figure: z
      .object({})
      .passthrough()
      .transform((spec) => spec as unknown as Figure),
    config: z.object({}).passthrough(),
  });

  render(props: IPluginProps<T, Data>): JSX.Element {
    return (
      <PlotlyComponent
        {...props.data}
        host={props.host}
        value={props.value}
        setValue={props.setValue}
      />
    );
  }
}

interface PlotlyPluginProps extends Data {
  value: T;
  setValue: Setter<T>;
  host: HTMLElement;
}

// Code-split the Plotly renderer so it is only loaded when a chart is shown.
const LazyPlot = lazy(() =>
  import("./Plot").then((mod) => ({ default: mod.Plot })),
);

/**
 * Renders a Plotly figure and mirrors user interaction into the plugin value.
 *
 * - Loads MathJax from a CDN so LaTeX in the figure can be rendered.
 * - Adds a "Reset state" mode bar button, plus line-selection buttons when
 *   the figure has pure line or area traces.
 * - Persists drag mode, axis changes, selections and clicks via `setValue`
 *   so they survive re-renders.
 */
export const PlotlyComponent = memo(
  ({ figure: originalFigure, value, setValue, config }: PlotlyPluginProps) => {
    // Used for rendering LaTeX. TODO: Serve this library from Marimo
    const scriptStatus = useScript(
      "https://cdn.jsdelivr.net/npm/mathjax-full@3.2.2/es5/tex-mml-svg.min.js",
    );
    const isScriptLoaded = scriptStatus === "ready";

    const { figure, layout, setLayout, handleReset } = usePlotlyLayout({
      originalFigure,
      initialValue: value,
      isScriptLoaded,
    });

    // Reset the layout and clear every persisted selection/view field.
    const handleResetWithClear = useEvent(() => {
      handleReset();
      setValue({});
    });

    // Called by the line-selection mode bar buttons: switch the drag mode
    // and persist it so it is kept across re-renders.
    const handleSetDragmode = useEvent(
      (dragmode: Plotly.Layout["dragmode"]) => {
        setLayout((prev) => ({ ...prev, dragmode }));
        setValue((prev) => ({ ...prev, dragmode }));
      },
    );

    // Deep-compare so a structurally equal config does not rebuild the
    // Plotly config on every render.
    const configMemo = useDeepCompareMemoize(config);

    const plotlyConfig = useMemo((): Partial<Plotly.Config> => {
      const hasLineOrAreaTrace =
        hasPureLineTrace(figure.data) || hasAreaTrace(figure.data);

      const defaultButtons: ModeBarButton[] = [
        // Custom button to reset the state
        {
          name: "reset",
          title: "Reset state",
          icon: {
            // Counter-clockwise "reset" arrow (Lucide rotate-ccw).
            svg: `<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M3 12a9 9 0 1 0 9-9 9.75 9.75 0 0 0-6.74 2.74L3 8"/><path d="M3 3v5h5"/></svg>`,
          },
          click: handleResetWithClear,
        },
      ];

      if (hasLineOrAreaTrace) {
        defaultButtons.push(...lineSelectionButtons(handleSetDragmode));
      }

      return {
        displaylogo: false,
        // Prioritize user's config
        ...configMemo,
        // ...but always keep our buttons alongside any user-provided ones.
        modeBarButtonsToAdd: mergeModeBarButtonsToAdd(
          defaultButtons,
          configMemo.modeBarButtonsToAdd as
            | readonly ModeBarButton[]
            | undefined,
        ),
      };
    }, [handleResetWithClear, handleSetDragmode, configMemo, figure.data]);

    return (
      <LazyPlot
        {...figure}
        layout={layout}
        onRelayout={(layoutUpdate) => {
          // Persist dragmode in the state to keep it across re-renders
          if ("dragmode" in layoutUpdate) {
            setValue((prev) => ({ ...prev, dragmode: layoutUpdate.dragmode }));
          }

          // Persist xaxis/yaxis changes in the state to keep it across re-renders
          if (
            Object.keys(layoutUpdate).some(
              (key) => key.includes("xaxis") || key.includes("yaxis"),
            )
          ) {
            // Axis changes are keypath updates, so need to use lodash.set
            // e.g. xaxis.range[0], xaxis.range[1], yaxis.range[0], yaxis.range[1]
            const obj: Partial<Plotly.Layout> = {};
            for (const [keyPath, update] of Object.entries(layoutUpdate)) {
              set(obj, keyPath, update);
            }
            setValue((prev) => ({ ...prev, ...obj }));
          }
        }}
        onDeselect={useEvent(() => {
          setValue((prev) => {
            return {
              ...prev,
              selections: Arrays.EMPTY,
              points: Arrays.EMPTY,
              indices: Arrays.EMPTY,
              range: undefined,
              lasso: undefined,
            };
          });
        })}
        onTreemapClick={useEvent((evt: Readonly<Plotly.PlotMouseEvent>) => {
          if (!evt) {
            return;
          }
          setValue((prev) => ({
            ...prev,
            points: extractTreemapPoints(evt.points),
          }));
        })}
        onSunburstClick={useEvent((evt: Readonly<Plotly.PlotMouseEvent>) => {
          if (!evt) {
            return;
          }
          setValue((prev) => ({
            ...prev,
            points: extractSunburstPoints(evt.points),
          }));
        })}
        config={plotlyConfig}
        onClick={useEvent((evt: Readonly<Plotly.PlotMouseEvent>) => {
          if (!evt) {
            return;
          }

          // Handle clicks for chart types where box/lasso selection
          // is limited or unavailable (e.g. bar, heatmaps, histograms, pure line traces).
          if (!shouldHandleClickSelection(evt.points)) {
            return;
          }

          const extractedPoints = extractPoints(evt.points);
          const extractedIndices = extractIndices(evt.points);

          // A click replaces any previous box/lasso selection.
          setValue((prev) => ({
            ...prev,
            selections: Arrays.EMPTY,
            range: undefined,
            lasso: undefined,
            points: extractedPoints,
            indices: extractedIndices,
          }));
        })}
        onSelected={useEvent((evt: Readonly<Plotly.PlotSelectionEvent>) => {
          if (!evt) {
            return;
          }

          setValue((prev) => ({
            ...prev,
            selections:
              "selections" in evt ? (evt.selections as unknown[]) : [],
            points: extractPoints(evt.points),
            indices: extractIndices(evt.points),
            range: evt.range,
            lasso:
              "lassoPoints" in evt
                ? (evt.lassoPoints as { x?: unknown[]; y?: unknown[] })
                : undefined,
          }));
        })}
        className="w-full"
        useResizeHandler={true}
        frames={figure.frames ?? undefined}
        onError={useEvent((err: Error) => {
          Logger.error("PlotlyPlugin: ", err);
        })}
      />
    );
  },
);
PlotlyComponent.displayName = "PlotlyComponent";