/* https://github.com/mui/base-ui/blob/f2d7a90e3a20dee84955beb5be0d59e50f45ae7e/packages/utils/src/useMergedRefs.ts */
import React from "react";
import { useRefWithInit } from "./useRefWithInit";

/** The merged callback ref: receives the instance when attached and `null` when detached. */
type RefCallback<T> = (instance: T | null) => void;
type Empty = null | undefined;
/** A callback ref or an object ref pointing at `T`, or nothing at all. */
type InputRef<T> = React.Ref<T> | Empty;
type Result<T> = RefCallback<T> | null;
type Cleanup = () => void;

/**
 * State kept across renders for one merged ref.
 * - `callback`: the memoized merged ref handed back to the caller (`null` when every input ref is empty).
 * - `cleanup`: detaches the currently attached instance from every input ref; `null` while nothing is attached.
 * - `refs`: the input refs `callback` was built from, compared on each render to detect changes.
 */
type ForkRef<T> = {
  callback: RefCallback<T> | null;
  cleanup: Cleanup | null;
  refs: InputRef<T>[];
};

/**
 * Merges refs into a single memoized callback ref or `null`.
 * This makes sure multiple refs are updated together and have the same value.
 *
 * The returned callback is only rebuilt when one of the input refs changes identity,
 * and it is `null` when every input ref is `null`/`undefined`.
 *
 * This function accepts up to four refs. If you need to merge more, or have an unspecified number of refs to merge,
 * use `useMergeRefsN` instead.
 *
 * @param a First ref to merge.
 * @param b Second ref to merge.
 * @param c Optional third ref to merge.
 * @param d Optional fourth ref to merge.
 * @returns A callback ref that forwards the instance to every non-empty input ref, or `null`.
 */
export function useMergeRefs<T>(a: InputRef<T>, b: InputRef<T>): Result<T>;
export function useMergeRefs<T>(
  a: InputRef<T>,
  b: InputRef<T>,
  c: InputRef<T>,
): Result<T>;
export function useMergeRefs<T>(
  a: InputRef<T>,
  b: InputRef<T>,
  c: InputRef<T>,
  d: InputRef<T>,
): Result<T>;
export function useMergeRefs<T>(
  a: InputRef<T>,
  b: InputRef<T>,
  c?: InputRef<T>,
  d?: InputRef<T>,
): Result<T> {
  const forkRef = useRefWithInit(createForkRef<T>).current!;
  if (didChange(forkRef, a, b, c, d)) {
    update(forkRef, [a, b, c, d]);
  }
  return forkRef.callback;
}

/**
 * Merges an array of refs into a single memoized callback ref or `null`.
 *
 * The array is compared with the previous render's array by length and element identity, so a fresh
 * array literal may be passed on every render. The array object itself is retained, so do not mutate
 * it in place after passing it: such changes are not detected.
 *
 * If you need to merge a fixed number (up to four) of refs, use `useMergeRefs` instead for better performance.
 *
 * @param refs The refs to merge; `null`/`undefined` entries are skipped.
 * @returns A callback ref that forwards the instance to every non-empty input ref, or `null`.
 */
export function useMergeRefsN<T>(refs: InputRef<T>[]): Result<T> {
  const forkRef = useRefWithInit(createForkRef<T>).current!;
  if (didChangeN(forkRef, refs)) {
    update(forkRef, refs);
  }
  return forkRef.callback;
}

/** Creates the initial, empty per-component state. */
function createForkRef<T>(): ForkRef<T> {
  return {
    callback: null,
    cleanup: null,
    refs: [],
  };
}

/**
 * Whether any of the four positional refs differs (by identity) from the ones `forkRef` was built from.
 * On the first render `forkRef.refs` is empty, so every slot reads as `undefined`.
 */
function didChange<T>(
  forkRef: ForkRef<T>,
  a: InputRef<T>,
  b: InputRef<T>,
  c: InputRef<T>,
  d: InputRef<T>,
): boolean {
  return (
    forkRef.refs[0] !== a ||
    forkRef.refs[1] !== b ||
    forkRef.refs[2] !== c ||
    forkRef.refs[3] !== d
  );
}

/** Whether `newRefs` differs in length or in any element (by identity) from the refs `forkRef` was built from. */
function didChangeN<T>(forkRef: ForkRef<T>, newRefs: InputRef<T>[]): boolean {
  return (
    forkRef.refs.length !== newRefs.length ||
    forkRef.refs.some((ref, index) => ref !== newRefs[index])
  );
}

/**
 * Rebuilds the merged callback for a new set of input refs and stores it on `forkRef`.
 * When every ref is empty there is nothing to forward to, so the callback becomes `null`.
 */
function update<T>(forkRef: ForkRef<T>, refs: InputRef<T>[]): void {
  forkRef.refs = refs;

  if (refs.every((ref) => ref == null)) {
    forkRef.callback = null;
    return;
  }

  forkRef.callback = (instance) => {
    // Detach whatever is currently attached before attaching the new instance (or before detaching on `null`).
    if (forkRef.cleanup) {
      forkRef.cleanup();
      forkRef.cleanup = null;
    }

    if (instance != null) {
      forkRef.cleanup = attachRefs(refs, instance);
    }
  };
}

/**
 * Points every non-empty ref at `instance` and returns a function that detaches them again.
 * Callback refs are called with the instance, and any cleanup function they return is remembered.
 * Object refs get `instance` assigned to `current`.
 */
function attachRefs<T>(refs: InputRef<T>[], instance: T): Cleanup {
  // One slot per input ref, holding the cleanup returned by that callback ref (if any).
  const cleanupCallbacks = new Array<Cleanup | null>(refs.length).fill(null);

  for (let i = 0; i < refs.length; i += 1) {
    const ref = refs[i];
    if (ref == null) {
      continue;
    }
    switch (typeof ref) {
      case "function": {
        const refCleanup = ref(instance);
        if (typeof refCleanup === "function") {
          cleanupCallbacks[i] = refCleanup;
        }
        break;
      }
      case "object": {
        (ref as React.MutableRefObject<T | null>).current = instance;
        break;
      }
      default:
    }
  }

  return () => {
    for (let i = 0; i < refs.length; i += 1) {
      const ref = refs[i];
      if (ref == null) {
        continue;
      }
      switch (typeof ref) {
        case "function": {
          // A callback ref that returned a cleanup function gets that cleanup instead of a `null` call.
          const cleanupCallback = cleanupCallbacks[i];
          if (typeof cleanupCallback === "function") {
            cleanupCallback();
          } else {
            ref(null);
          }
          break;
        }
        case "object": {
          (ref as React.MutableRefObject<T | null>).current = null;
          break;
        }
        default:
      }
    }
  };
}