/* Copyright 2026 Marimo. All rights reserved. */
import type { Atom } from "jotai";
import { type Edge, MarkerType, type Node, type NodeProps } from "reactflow";
import { getNotebook } from "@/core/cells/cells";
import type { CellId } from "@/core/cells/ids";
import type { CellData } from "@/core/cells/types";
import { store } from "@/core/state/jotai";
import type { Variables } from "@/core/variables/types";
import { Arrays } from "@/utils/arrays";

/** Payload stored on every dependency-graph node. */
export interface NodeData {
  /** Atom holding the cell this node renders. */
  atom: Atom<CellData>;
  /** Optional fixed width for the node (the tree layout sets 300). */
  forceWidth?: number;
}

/** Props received by the custom node component. */
export type CustomNodeProps = NodeProps<NodeData>;

/**
 * Computes a node's height from the number of code lines it displays.
 *
 * @param linesOfCode Number of lines of code shown in the node.
 * @returns `linesOfCode * 11 + 35`, capped at 200.
 */
export function getNodeHeight(linesOfCode: number): number {
  const LINE_HEIGHT = 11; // matches TinyCode.css
  return Math.min(linesOfCode * LINE_HEIGHT + 35, 200);
}

// The nodes must have the same handle IDs to ensure edges connect correctly
export const OUTPUTS_HANDLE_ID = "outputs";
export const INPUTS_HANDLE_ID = "inputs";

/** Nodes and edges produced by an {@link ElementsBuilder}. */
interface GraphElements {
  nodes: Array<Node<NodeData>>;
  edges: Edge[];
}

interface ElementsBuilder {
  createElements: (
    cellIds: CellId[],
    cellAtoms: Array<Atom<CellData>>,
    variables: Variables,
    hidePureMarkdown: boolean,
    hideReusableFunctions: boolean,
  ) => GraphElements;
}

/**
 * Height for the node of the cell held by `atom`, based on the number of
 * lines in the cell's trimmed code.
 */
function getCellNodeHeight(atom: Atom<CellData>): number {
  const linesOfCode = store.get(atom).code.trim().split("\n").length;
  return getNodeHeight(linesOfCode);
}

/**
 * Collects every distinct (declaring cell, using cell) pair across all
 * variables, in first-seen order.
 *
 * @param variables Variables whose `declaredBy` / `usedBy` lists define the pairs.
 * @param shouldSkip Returns true for variables that must contribute no pairs.
 * @returns Unique `[fromId, toId]` pairs; a pair reachable through several
 *   variables appears only once.
 */
function collectDependencyPairs(
  variables: Variables,
  shouldSkip: (variable: Variables[keyof Variables]) => boolean = () => false,
): Array<[from: CellId, to: CellId]> {
  const seen = new Set<string>();
  const pairs: Array<[from: CellId, to: CellId]> = [];
  for (const variable of Object.values(variables)) {
    if (shouldSkip(variable)) {
      continue;
    }
    const { declaredBy, usedBy } = variable;
    for (const fromId of declaredBy) {
      for (const toId of usedBy) {
        const key = `${fromId}-${toId}`;
        if (seen.has(key)) {
          continue;
        }
        seen.add(key);
        pairs.push([fromId, toId]);
      }
    }
  }
  return pairs;
}

/**
 * Stacks every cell in a single column (x = 0) in the given order. Each
 * dependency gets two smoothstep edges, one through each handle ID.
 */
export class VerticalElementsBuilder implements ElementsBuilder {
  private createEdge(source: CellId, target: CellId, direction: string): Edge {
    return {
      type: "smoothstep",
      pathOptions: {
        offset: 20,
        borderRadius: 100,
      },
      data: {
        direction: direction,
      },
      markerEnd: {
        type: MarkerType.Arrow,
      },
      id: `${source}-${target}-${direction}`,
      source: source,
      sourceHandle: direction,
      targetHandle: direction,
      target: target,
    };
  }

  /** Builds a node placed 20 units below `prevY` (the previous node's bottom). */
  private createNode(
    id: string,
    atom: Atom<CellData>,
    prevY: number,
  ): Node<NodeData> {
    return {
      id: id,
      data: { atom },
      width: 250,
      type: "custom",
      height: getCellNodeHeight(atom),
      position: { x: 0, y: prevY + 20 },
    };
  }

  /**
   * Creates one node per cell plus the dependency edges.
   * The hide flags are accepted for interface compatibility but ignored:
   * this layout always shows every cell.
   */
  createElements(
    cellIds: CellId[],
    cellAtoms: Array<Atom<CellData>>,
    variables: Variables,
    _hidePureMarkdown: boolean,
    _hideReusableFunctions: boolean,
  ): GraphElements {
    let prevY = 0;
    const nodes: Array<Node<NodeData>> = [];
    for (const [cellId, cellAtom] of Arrays.zip(cellIds, cellAtoms)) {
      const node = this.createNode(cellId, cellAtom, prevY);
      nodes.push(node);
      // Next node starts below the bottom of this one.
      prevY = node.position.y + (node.height ?? 0);
    }

    const edges: Edge[] = [];
    for (const [fromId, toId] of collectDependencyPairs(variables)) {
      edges.push(
        this.createEdge(fromId, toId, INPUTS_HANDLE_ID),
        this.createEdge(fromId, toId, OUTPUTS_HANDLE_ID),
      );
    }

    return { nodes, edges };
  }
}

/**
 * Produces unpositioned nodes (all at 0,0) and one animated edge per
 * dependency. It can optionally drop disconnected markdown or reusable cells.
 */
export class TreeElementsBuilder implements ElementsBuilder {
  private createEdge(source: CellId, target: CellId): Edge {
    return {
      animated: true,
      markerEnd: {
        type: MarkerType.ArrowClosed,
      },
      id: `${source}-${target}`,
      // Make thicker
      style: { strokeWidth: 2 },
      source: source,
      // Use the same handle ids as the custom node
      sourceHandle: OUTPUTS_HANDLE_ID,
      targetHandle: INPUTS_HANDLE_ID,
      target: target,
    };
  }

  private createNode(id: string, atom: Atom<CellData>): Node<NodeData> {
    return {
      id: id,
      data: { atom, forceWidth: 300 },
      width: 300,
      type: "custom",
      height: getCellNodeHeight(atom),
      position: { x: 0, y: 0 },
    };
  }

  /**
   * Creates dependency edges, then one node per cell. A cell with no edges
   * is omitted when either of these applies:
   * - `hidePureMarkdown` is set and its trimmed code starts with `mo.md`;
   * - `hideReusableFunctions` is set and its runtime `serialization` is
   *   "valid" (case-insensitive). NOTE(review): presumably this marks
   *   reusable top-level functions; confirm.
   */
  createElements(
    cellIds: CellId[],
    cellAtoms: Array<Atom<CellData>>,
    variables: Variables,
    hidePureMarkdown: boolean,
    hideReusableFunctions: boolean,
  ): GraphElements {
    // Skip marimo, since likely every cell uses it
    const pairs = collectDependencyPairs(
      variables,
      (variable) => variable.value === "marimo" && variable.name === "mo",
    );

    const edges: Edge[] = [];
    const nodesWithEdges = new Set<CellId>();
    for (const [fromId, toId] of pairs) {
      nodesWithEdges.add(fromId);
      nodesWithEdges.add(toId);
      edges.push(this.createEdge(fromId, toId));
    }

    const cellRuntime = getNotebook().cellRuntime;
    const nodes: Array<Node<NodeData>> = [];
    for (const [cellId, cellAtom] of Arrays.zip(cellIds, cellAtoms)) {
      const code = store.get(cellAtom).code.trim();
      const hasEdge = nodesWithEdges.has(cellId);
      const isMarkdown = code.startsWith("mo.md");
      const runtime = cellRuntime[cellId];
      const isReusable = runtime?.serialization?.toLowerCase() === "valid";

      // Apply filters: only disconnected cells are ever hidden.
      if (hidePureMarkdown && isMarkdown && !hasEdge) {
        continue;
      }
      if (hideReusableFunctions && isReusable && !hasEdge) {
        continue;
      }

      // Show every cell that wasn't filtered out
      nodes.push(this.createNode(cellId, cellAtom));
    }

    return { nodes, edges };
  }
}