import { useCallback, useMemo } from 'react';

import useViewportHelper from './useViewportHelper';
import { useStoreApi } from '../hooks/useStore';
import type {
  ReactFlowInstance,
  Instance,
  NodeAddChange,
  EdgeAddChange,
  NodeResetChange,
  EdgeResetChange,
  NodeRemoveChange,
  EdgeRemoveChange,
  NodeChange,
  Node,
  Rect,
} from '../types';
import { getConnectedEdges } from '../utils/graph';
import { getOverlappingArea, isRectObject, nodeToRect } from '../utils';

/**
 * Returns an imperative API for reading and mutating the graph held in the React Flow store,
 * merged with the helpers returned by `useViewportHelper`.
 *
 * Every function reads the store lazily through `store.getState()` at call time, so the
 * callbacks are created once (empty dependency lists) and the returned object is memoized.
 *
 * `setNodes` / `setEdges` / `addNodes` / `addEdges` write through the store's own setters when
 * `hasDefaultNodes` / `hasDefaultEdges` is set, and otherwise emit change objects via
 * `onNodesChange` / `onEdgesChange` (if provided). `deleteElements` updates the store in the
 * default case AND notifies the delete/change callbacks.
 */
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export default function useReactFlow<NodeData = any, EdgeData = any>(): ReactFlowInstance<NodeData, EdgeData> {
  const viewportHelper = useViewportHelper();
  const store = useStoreApi();

  /** Returns shallow copies of all nodes currently in the store. */
  const getNodes = useCallback<Instance.GetNodes<NodeData>>(() => {
    const { nodeInternals } = store.getState();
    const nodes = Array.from(nodeInternals.values());

    return nodes.map((n) => ({ ...n }));
  }, []);

  /** Returns the stored node with the given id (not a copy), or `undefined`. */
  const getNode = useCallback<Instance.GetNode<NodeData>>((id) => {
    const { nodeInternals } = store.getState();

    return nodeInternals.get(id);
  }, []);

  /** Returns shallow copies of all edges currently in the store. */
  const getEdges = useCallback<Instance.GetEdges<EdgeData>>(() => {
    const { edges = [] } = store.getState();

    return edges.map((e) => ({ ...e }));
  }, []);

  /** Returns the stored edge with the given id (not a copy), or `undefined`. */
  const getEdge = useCallback<Instance.GetEdge<EdgeData>>((id) => {
    const { edges = [] } = store.getState();

    return edges.find((e) => e.id === id);
  }, []);

  /**
   * Replaces all nodes. `payload` is the next node array or an updater receiving the current
   * nodes. Without default-node handling the replacement is emitted as `reset` changes, or as
   * `remove` changes for every current node when the result is empty (mapping an empty array
   * would otherwise produce no changes at all).
   */
  const setNodes = useCallback<Instance.SetNodes<NodeData>>((payload) => {
    const { nodeInternals, setNodes, hasDefaultNodes, onNodesChange } = store.getState();
    const nodes = Array.from(nodeInternals.values());
    const nextNodes = typeof payload === 'function' ? payload(nodes) : payload;

    if (hasDefaultNodes) {
      setNodes(nextNodes);
    } else if (onNodesChange) {
      const changes =
        nextNodes.length === 0
          ? nodes.map((node) => ({ type: 'remove', id: node.id } as NodeRemoveChange))
          : nextNodes.map((node) => ({ item: node, type: 'reset' } as NodeResetChange));

      onNodesChange(changes);
    }
  }, []);

  /**
   * Replaces all edges; mirrors `setNodes` (reset changes, or remove changes for every current
   * edge when the result is empty).
   */
  const setEdges = useCallback<Instance.SetEdges<EdgeData>>((payload) => {
    const { edges = [], setEdges, hasDefaultEdges, onEdgesChange } = store.getState();
    const nextEdges = typeof payload === 'function' ? payload(edges) : payload;

    if (hasDefaultEdges) {
      setEdges(nextEdges);
    } else if (onEdgesChange) {
      const changes =
        nextEdges.length === 0
          ? edges.map((edge) => ({ type: 'remove', id: edge.id } as EdgeRemoveChange))
          : nextEdges.map((edge) => ({ item: edge, type: 'reset' } as EdgeResetChange));

      onEdgesChange(changes);
    }
  }, []);

  /** Appends one node or an array of nodes (as `add` changes when not using default nodes). */
  const addNodes = useCallback<Instance.AddNodes<NodeData>>((payload) => {
    const nodes = Array.isArray(payload) ? payload : [payload];
    const { nodeInternals, setNodes, hasDefaultNodes, onNodesChange } = store.getState();

    if (hasDefaultNodes) {
      const currentNodes = Array.from(nodeInternals.values());
      const nextNodes = [...currentNodes, ...nodes];
      setNodes(nextNodes);
    } else if (onNodesChange) {
      const changes = nodes.map((node) => ({ item: node, type: 'add' } as NodeAddChange));
      onNodesChange(changes);
    }
  }, []);

  /** Appends one edge or an array of edges (as `add` changes when not using default edges). */
  const addEdges = useCallback<Instance.AddEdges<EdgeData>>((payload) => {
    const nextEdges = Array.isArray(payload) ? payload : [payload];
    const { edges = [], setEdges, hasDefaultEdges, onEdgesChange } = store.getState();

    if (hasDefaultEdges) {
      setEdges([...edges, ...nextEdges]);
    } else if (onEdgesChange) {
      const changes = nextEdges.map((edge) => ({ item: edge, type: 'add' } as EdgeAddChange));
      onEdgesChange(changes);
    }
  }, []);

  /** Snapshot of shallow-copied nodes and edges plus the viewport taken from `transform`. */
  const toObject = useCallback<Instance.ToObject<NodeData, EdgeData>>(() => {
    const { nodeInternals, edges = [], transform } = store.getState();
    const nodes = Array.from(nodeInternals.values());
    const [x, y, zoom] = transform;

    return {
      nodes: nodes.map((n) => ({ ...n })),
      edges: edges.map((e) => ({ ...e })),
      viewport: {
        x,
        y,
        zoom,
      },
    };
  }, []);

  /**
   * Deletes the given nodes and edges, plus children of removed nodes and edges attached to
   * removed nodes. Elements with `deletable === false` are kept. The store is updated when
   * default nodes/edges are used, then `onEdgesDelete`/`onEdgesChange` and
   * `onNodesDelete`/`onNodesChange` are notified for whatever was actually removed.
   */
  const deleteElements = useCallback<Instance.DeleteElements>(({ nodes: nodesDeleted, edges: edgesDeleted }) => {
    const {
      nodeInternals,
      edges = [],
      hasDefaultNodes,
      hasDefaultEdges,
      onNodesDelete,
      onEdgesDelete,
      onNodesChange,
      onEdgesChange,
    } = store.getState();
    const nodes = Array.from(nodeInternals.values());
    const requestedNodeIds = new Set((nodesDeleted || []).map((node) => node.id));
    const requestedEdgeIds = new Set((edgesDeleted || []).map((edge) => edge.id));

    // A node is removed when requested directly, or when its parent was already selected for
    // removal earlier in this pass (only parents iterated before the child are considered).
    const nodesToRemove: Node[] = [];
    const removedNodeIds = new Set<string>();
    nodes.forEach((node) => {
      const isRequested = requestedNodeIds.has(node.id);
      const parentHit = !isRequested && !!node.parentNode && removedNodeIds.has(node.parentNode);
      const deletable = typeof node.deletable === 'boolean' ? node.deletable : true;

      if (deletable && (isRequested || parentHit)) {
        nodesToRemove.push(node);
        removedNodeIds.add(node.id);
      }
    });

    const deletableEdges = edges.filter((e) => (typeof e.deletable === 'boolean' ? e.deletable : true));
    const initialHitEdges = deletableEdges.filter((e) => requestedEdgeIds.has(e.id));
    const connectedEdges = getConnectedEdges(nodesToRemove, deletableEdges);

    // Requested edges and edges attached to removed nodes can overlap; keep the first
    // occurrence of each id so callbacks never receive the same edge twice.
    const edgeIdsToRemove = new Set<string>();
    const edgesToRemove = [...initialHitEdges, ...connectedEdges].filter((edge) => {
      if (edgeIdsToRemove.has(edge.id)) {
        return false;
      }
      edgeIdsToRemove.add(edge.id);
      return true;
    });

    if (hasDefaultEdges && edgesToRemove.length > 0) {
      store.setState({
        edges: edges.filter((e) => !edgeIdsToRemove.has(e.id)),
      });
    }

    if (hasDefaultNodes && nodesToRemove.length > 0) {
      // Copy first so the Map referenced by the current store state is never mutated.
      const nextNodeInternals = new Map(nodeInternals);
      nodesToRemove.forEach((node) => {
        nextNodeInternals.delete(node.id);
      });
      store.setState({ nodeInternals: nextNodeInternals });
    }

    if (edgesToRemove.length > 0) {
      onEdgesDelete?.(edgesToRemove);
      onEdgesChange?.(edgesToRemove.map((edge) => ({ id: edge.id, type: 'remove' } as EdgeRemoveChange)));
    }

    if (nodesToRemove.length > 0) {
      onNodesDelete?.(nodesToRemove);
      const nodeChanges: NodeChange[] = nodesToRemove.map((n) => ({ id: n.id, type: 'remove' }));
      onNodesChange?.(nodeChanges);
    }
  }, []);

  /**
   * Resolves a node reference (looked up by id in the store) or a plain rect into
   * `[rect, node, isRect]`. Returns `[null, null, false]` when the node id is unknown.
   */
  const getNodeRect = useCallback(
    (
      nodeOrRect: (Partial<Node<NodeData>> & { id: Node['id'] }) | Rect
    ): [Rect | null, Node<NodeData> | null | undefined, boolean] => {
      const isRect = isRectObject(nodeOrRect);
      const node = isRect ? null : store.getState().nodeInternals.get(nodeOrRect.id);

      if (!isRect && !node) {
        return [null, null, isRect];
      }

      const nodeRect = isRect ? nodeOrRect : nodeToRect(node!);

      return [nodeRect, node, isRect];
    },
    []
  );

  /**
   * Returns the nodes (from `nodes` or the store) overlapping the given node/rect. With
   * `partially` any positive overlap counts; otherwise the overlap must cover the whole
   * reference rect. When a node is given, the node itself and nodes without
   * `positionAbsolute` are skipped.
   */
  const getIntersectingNodes = useCallback<Instance.GetIntersectingNodes<NodeData>>(
    (nodeOrRect, partially = true, nodes) => {
      const [nodeRect, node, isRect] = getNodeRect(nodeOrRect);

      if (!nodeRect) {
        return [];
      }

      return (nodes || Array.from(store.getState().nodeInternals.values())).filter((n) => {
        if (!isRect && (n.id === node!.id || !n.positionAbsolute)) {
          return false;
        }

        const currNodeRect = nodeToRect(n);
        const overlappingArea = getOverlappingArea(currNodeRect, nodeRect);
        const partiallyVisible = partially && overlappingArea > 0;

        // Use the resolved rect: `nodeOrRect` may be an id-only reference without dimensions.
        return partiallyVisible || overlappingArea >= nodeRect.width * nodeRect.height;
      });
    },
    []
  );

  /**
   * Whether the given node/rect overlaps `area` — partially (any positive overlap) or, with
   * `partially = false`, fully (overlap covers the whole node/rect). Unknown node ids yield false.
   */
  const isNodeIntersecting = useCallback<Instance.IsNodeIntersecting<NodeData>>(
    (nodeOrRect, area, partially = true) => {
      const [nodeRect] = getNodeRect(nodeOrRect);

      if (!nodeRect) {
        return false;
      }

      const overlappingArea = getOverlappingArea(nodeRect, area);
      const partiallyVisible = partially && overlappingArea > 0;

      // Use the resolved rect: `nodeOrRect` may be an id-only reference without dimensions.
      return partiallyVisible || overlappingArea >= nodeRect.width * nodeRect.height;
    },
    []
  );

  return useMemo(() => {
    return {
      ...viewportHelper,
      getNodes,
      getNode,
      getEdges,
      getEdge,
      setNodes,
      setEdges,
      addNodes,
      addEdges,
      toObject,
      deleteElements,
      getIntersectingNodes,
      isNodeIntersecting,
    };
  }, [
    viewportHelper,
    getNodes,
    getNode,
    getEdges,
    getEdge,
    setNodes,
    setEdges,
    addNodes,
    addEdges,
    toObject,
    deleteElements,
    getIntersectingNodes,
    isNodeIntersecting,
  ]);
}