"use client" import { Check } from "lucide-react" import * as React from "react" import { Slot } from "@radix-ui/react-slot" import { cn } from "../../../lib/utils" const ROOT_NAME = "Stepper" const LIST_NAME = "StepperList" const ITEM_NAME = "StepperItem" const TRIGGER_NAME = "StepperTrigger" const INDICATOR_NAME = "StepperIndicator" const SEPARATOR_NAME = "StepperSeparator" const TITLE_NAME = "StepperTitle" const DESCRIPTION_NAME = "StepperDescription" const CONTENT_NAME = "StepperContent" const PREV_NAME = "StepperPrev" const NEXT_NAME = "StepperNext" const ENTRY_FOCUS = "stepperFocusGroup.onEntryFocus" const EVENT_OPTIONS = { bubbles: false, cancelable: true } const ARROW_KEYS = ["ArrowUp", "ArrowDown", "ArrowLeft", "ArrowRight"] type Direction = "ltr" | "rtl" type Orientation = "horizontal" | "vertical" type NavigationDirection = "next" | "prev" type ActivationMode = "automatic" | "manual" type DataState = "inactive" | "active" | "completed" interface DivProps extends React.ComponentProps<"div"> { asChild?: boolean } interface ButtonProps extends React.ComponentProps<"button"> { asChild?: boolean } type ListElement = HTMLDivElement type TriggerElement = HTMLButtonElement function getId( id: string, variant: "trigger" | "content" | "title" | "description", value: string, ) { return `${id}-${variant}-${value}` } type FocusIntent = "first" | "last" | "prev" | "next" const MAP_KEY_TO_FOCUS_INTENT: Record = { ArrowLeft: "prev", ArrowUp: "prev", ArrowRight: "next", ArrowDown: "next", PageUp: "first", Home: "first", PageDown: "last", End: "last", } function getDirectionAwareKey(key: string, dir?: Direction) { if (dir !== "rtl") return key return key === "ArrowLeft" ? "ArrowRight" : key === "ArrowRight" ? "ArrowLeft" : key } function getFocusIntent( event: React.KeyboardEvent, dir?: Direction, orientation?: Orientation, ) { const key = getDirectionAwareKey(event.key, dir) if (orientation === "horizontal" && ["ArrowUp", "ArrowDown"].includes(key)) return undefined if (orientation === "vertical" && ["ArrowLeft", "ArrowRight"].includes(key)) return undefined return MAP_KEY_TO_FOCUS_INTENT[key] } function focusFirst( candidates: React.RefObject[], preventScroll = false, ) { const PREVIOUSLY_FOCUSED_ELEMENT = document.activeElement for (const candidateRef of candidates) { const candidate = candidateRef.current if (!candidate) continue if (candidate === PREVIOUSLY_FOCUSED_ELEMENT) return candidate.focus({ preventScroll }) if (document.activeElement !== PREVIOUSLY_FOCUSED_ELEMENT) return } } function wrapArray(array: T[], startIndex: number) { return array.map( (_, index) => array[(startIndex + index) % array.length] as T, ) } function getDataState( value: string | undefined, itemValue: string, stepState: StepState | undefined, steps: Map, variant: "item" | "separator" = "item", ): DataState { const stepKeys = Array.from(steps.keys()) const currentIndex = stepKeys.indexOf(itemValue) if (stepState?.completed) return "completed" if (value === itemValue) { return variant === "separator" ? "inactive" : "active" } if (value) { const activeIndex = stepKeys.indexOf(value) if (activeIndex > currentIndex) return "completed" } return "inactive" } interface StepState { value: string completed: boolean disabled: boolean } interface StoreState { steps: Map value: string } interface Store { subscribe: (callback: () => void) => () => void getState: () => StoreState setState: (key: K, value: StoreState[K]) => void setStateWithValidation: ( value: string, direction: NavigationDirection, ) => Promise hasValidation: () => boolean notify: () => void addStep: (value: string, completed: boolean, disabled: boolean) => void removeStep: (value: string) => void setStep: (value: string, completed: boolean, disabled: boolean) => void } const StoreContext = React.createContext(null) function useStoreContext(consumerName: string) { const context = React.useContext(StoreContext) if (!context) { throw new Error(`\`${consumerName}\` must be used within \`${ROOT_NAME}\``) } return context } function useStore(selector: (state: StoreState) => T): T { const store = useStoreContext("useStore") const getSnapshot = React.useCallback( () => selector(store.getState()), [store, selector], ) return React.useSyncExternalStore(store.subscribe, getSnapshot, getSnapshot) } interface ItemData { id: string ref: React.RefObject value: string active: boolean disabled: boolean } interface StepperContextValue { rootId: string dir: Direction orientation: Orientation activationMode: ActivationMode disabled: boolean nonInteractive: boolean loop: boolean } const StepperContext = React.createContext(null) function useStepperContext(consumerName: string) { const context = React.useContext(StepperContext) if (!context) { throw new Error(`\`${consumerName}\` must be used within \`${ROOT_NAME}\``) } return context } interface StepperProps extends DivProps { value?: string defaultValue?: string onValueChange?: (value: string) => void onValueComplete?: (value: string, completed: boolean) => void onValueAdd?: (value: string) => void onValueRemove?: (value: string) => void onValidate?: ( value: string, direction: NavigationDirection, ) => boolean | Promise activationMode?: ActivationMode dir?: Direction orientation?: Orientation disabled?: boolean loop?: boolean nonInteractive?: boolean } function Stepper(props: StepperProps) { const { value, defaultValue, onValueChange, onValueComplete, onValueAdd, onValueRemove, onValidate, dir: dirProp, orientation = "horizontal", activationMode = "automatic", asChild, disabled = false, nonInteractive = false, loop = false, className, id, ...rootProps } = props const listenersRef = React.useRef void>>(new Set()) const stateRef = React.useRef({ steps: new Map(), value: value ?? defaultValue ?? "", }) const propsRef = React.useRef({ onValueChange, onValueComplete, onValueAdd, onValueRemove, onValidate, }) React.useEffect(() => { propsRef.current = { onValueChange, onValueComplete, onValueAdd, onValueRemove, onValidate, } }, [onValueChange, onValueComplete, onValueAdd, onValueRemove, onValidate]) const store = React.useMemo(() => { return { subscribe: (cb) => { listenersRef.current.add(cb) return () => listenersRef.current.delete(cb) }, getState: () => stateRef.current, setState: (key, value) => { if (Object.is(stateRef.current[key], value)) return if (key === "value" && typeof value === "string") { stateRef.current.value = value propsRef.current.onValueChange?.(value) } else { stateRef.current[key] = value } store.notify() }, setStateWithValidation: async (value, direction) => { if (!propsRef.current.onValidate) { store.setState("value", value) return true } try { const isValid = await propsRef.current.onValidate(value, direction) if (isValid) { store.setState("value", value) } return isValid } catch { return false } }, hasValidation: () => !!propsRef.current.onValidate, addStep: (value, completed, disabled) => { const newStep: StepState = { value, completed, disabled } stateRef.current.steps.set(value, newStep) propsRef.current.onValueAdd?.(value) store.notify() }, removeStep: (value) => { stateRef.current.steps.delete(value) propsRef.current.onValueRemove?.(value) store.notify() }, setStep: (value, completed, disabled) => { const step = stateRef.current.steps.get(value) if (step) { const updatedStep: StepState = { ...step, completed, disabled } stateRef.current.steps.set(value, updatedStep) if (completed !== step.completed) { propsRef.current.onValueComplete?.(value, completed) } store.notify() } }, notify: () => { for (const cb of listenersRef.current) { cb() } }, } }, []) React.useEffect(() => { if (value !== undefined) { store.setState("value", value) } }, [value, store]) const dir: Direction = dirProp ?? "ltr" const instanceId = React.useId() const rootId = id ?? instanceId const contextValue = React.useMemo( () => ({ rootId, dir, orientation, activationMode, disabled, nonInteractive, loop, }), [rootId, dir, orientation, activationMode, disabled, nonInteractive, loop], ) const RootPrimitive = asChild ? Slot : "div" return ( ) } interface FocusContextValue { tabStopId: string | null onItemFocus: (tabStopId: string) => void onItemShiftTab: () => void onFocusableItemAdd: () => void onFocusableItemRemove: () => void onItemRegister: (item: ItemData) => void onItemUnregister: (id: string) => void getItems: () => ItemData[] } const FocusContext = React.createContext(null) function useFocusContext(consumerName: string) { const context = React.useContext(FocusContext) if (!context) { throw new Error( `\`${consumerName}\` must be used within \`FocusProvider\``, ) } return context } function StepperList(props: DivProps) { const { asChild, onBlur: onBlurProp, onFocus: onFocusProp, onMouseDown: onMouseDownProp, className, children, ref, ...listProps } = props const context = useStepperContext(LIST_NAME) const orientation = context.orientation const currentValue = useStore((state) => state.value) const propsRef = React.useRef({ onBlur: onBlurProp, onFocus: onFocusProp, onMouseDown: onMouseDownProp, }) React.useEffect(() => { propsRef.current = { onBlur: onBlurProp, onFocus: onFocusProp, onMouseDown: onMouseDownProp } }, [onBlurProp, onFocusProp, onMouseDownProp]) const [tabStopId, setTabStopId] = React.useState(null) const [isTabbingBackOut, setIsTabbingBackOut] = React.useState(false) const [focusableItemCount, setFocusableItemCount] = React.useState(0) const isClickFocusRef = React.useRef(false) const itemsRef = React.useRef>(new Map()) const listRef = React.useRef(null) const composedRef = React.useCallback( (node: ListElement | null) => { listRef.current = node if (typeof ref === "function") { ref(node) } else if (ref) { ref.current = node } }, [ref], ) const onItemFocus = React.useCallback((tabStopId: string) => { setTabStopId(tabStopId) }, []) const onItemShiftTab = React.useCallback(() => { setIsTabbingBackOut(true) }, []) const onFocusableItemAdd = React.useCallback(() => { setFocusableItemCount((prevCount) => prevCount + 1) }, []) const onFocusableItemRemove = React.useCallback(() => { setFocusableItemCount((prevCount) => prevCount - 1) }, []) const onItemRegister = React.useCallback((item: ItemData) => { itemsRef.current.set(item.id, item) }, []) const onItemUnregister = React.useCallback((id: string) => { itemsRef.current.delete(id) }, []) const getItems = React.useCallback(() => { return Array.from(itemsRef.current.values()) .filter((item) => item.ref.current) .sort((a, b) => { const elementA = a.ref.current const elementB = b.ref.current if (!elementA || !elementB) return 0 const position = elementA.compareDocumentPosition(elementB) if (position & Node.DOCUMENT_POSITION_FOLLOWING) { return -1 } if (position & Node.DOCUMENT_POSITION_PRECEDING) { return 1 } return 0 }) }, []) const onBlur = React.useCallback( (event: React.FocusEvent) => { propsRef.current.onBlur?.(event) if (event.defaultPrevented) return setIsTabbingBackOut(false) }, [], ) const onFocus = React.useCallback( (event: React.FocusEvent) => { propsRef.current.onFocus?.(event) if (event.defaultPrevented) return const isKeyboardFocus = !isClickFocusRef.current if ( event.target === event.currentTarget && isKeyboardFocus && !isTabbingBackOut ) { const entryFocusEvent = new CustomEvent(ENTRY_FOCUS, EVENT_OPTIONS) event.currentTarget.dispatchEvent(entryFocusEvent) if (!entryFocusEvent.defaultPrevented) { const items = Array.from(itemsRef.current.values()).filter( (item) => !item.disabled, ) const selectedItem = currentValue ? items.find((item) => item.value === currentValue) : undefined const activeItem = items.find((item) => item.active) const currentItem = items.find((item) => item.id === tabStopId) const candidateItems = [ selectedItem, activeItem, currentItem, ...items, ].filter(Boolean) as ItemData[] const candidateRefs = candidateItems.map((item) => item.ref) focusFirst(candidateRefs, false) } } isClickFocusRef.current = false }, [isTabbingBackOut, currentValue, tabStopId], ) const onMouseDown = React.useCallback( (event: React.MouseEvent) => { propsRef.current.onMouseDown?.(event) if (event.defaultPrevented) return isClickFocusRef.current = true }, [], ) const focusContextValue = React.useMemo( () => ({ tabStopId, onItemFocus, onItemShiftTab, onFocusableItemAdd, onFocusableItemRemove, onItemRegister, onItemUnregister, getItems, }), [ tabStopId, onItemFocus, onItemShiftTab, onFocusableItemAdd, onFocusableItemRemove, onItemRegister, onItemUnregister, getItems, ], ) const ListPrimitive = asChild ? Slot : "div" return ( {children} ) } interface StepperItemContextValue { value: string stepState: StepState | undefined } const StepperItemContext = React.createContext( null, ) function useStepperItemContext(consumerName: string) { const context = React.useContext(StepperItemContext) if (!context) { throw new Error(`\`${consumerName}\` must be used within \`${ITEM_NAME}\``) } return context } interface StepperItemProps extends DivProps { value: string completed?: boolean disabled?: boolean } function StepperItem(props: StepperItemProps) { const { value: itemValue, completed = false, disabled = false, asChild, className, children, ref, ...itemProps } = props const context = useStepperContext(ITEM_NAME) const store = useStoreContext(ITEM_NAME) const orientation = context.orientation const value = useStore((state) => state.value) React.useEffect(() => { store.addStep(itemValue, completed, disabled) return () => { store.removeStep(itemValue) } }, [itemValue, completed, disabled, store]) React.useEffect(() => { store.setStep(itemValue, completed, disabled) }, [itemValue, completed, disabled, store]) const stepState = useStore((state) => state.steps.get(itemValue)) const steps = useStore((state) => state.steps) const dataState = getDataState(value, itemValue, stepState, steps) const itemContextValue = React.useMemo( () => ({ value: itemValue, stepState, }), [itemValue, stepState], ) const ItemPrimitive = asChild ? Slot : "div" return ( {children} ) } function StepperTrigger(props: ButtonProps) { const { asChild, onClick: onClickProp, onFocus: onFocusProp, onKeyDown: onKeyDownProp, onMouseDown: onMouseDownProp, disabled, className, ref, ...triggerProps } = props const context = useStepperContext(TRIGGER_NAME) const itemContext = useStepperItemContext(TRIGGER_NAME) const itemValue = itemContext.value const store = useStoreContext(TRIGGER_NAME) const focusContext = useFocusContext(TRIGGER_NAME) const value = useStore((state) => state.value) const steps = useStore((state) => state.steps) const stepState = useStore((state) => state.steps.get(itemValue)) const propsRef = React.useRef({ onClick: onClickProp, onFocus: onFocusProp, onKeyDown: onKeyDownProp, onMouseDown: onMouseDownProp, }) React.useEffect(() => { propsRef.current = { onClick: onClickProp, onFocus: onFocusProp, onKeyDown: onKeyDownProp, onMouseDown: onMouseDownProp, } }, [onClickProp, onFocusProp, onKeyDownProp, onMouseDownProp]) const activationMode = context.activationMode const orientation = context.orientation const loop = context.loop const stepIndex = Array.from(steps.keys()).indexOf(itemValue) const stepPosition = stepIndex + 1 const stepCount = steps.size const triggerId = getId(context.rootId, "trigger", itemValue) const contentId = getId(context.rootId, "content", itemValue) const titleId = getId(context.rootId, "title", itemValue) const descriptionId = getId(context.rootId, "description", itemValue) const isDisabled = disabled || stepState?.disabled || context.disabled const isActive = value === itemValue const isTabStop = focusContext.tabStopId === triggerId const dataState = getDataState(value, itemValue, stepState, steps) const triggerRef = React.useRef(null) const composedRef = React.useCallback( (node: TriggerElement | null) => { triggerRef.current = node if (typeof ref === "function") { ref(node) } else if (ref) { ref.current = node } }, [ref], ) const isArrowKeyPressedRef = React.useRef(false) const isMouseClickRef = React.useRef(false) React.useEffect(() => { function onKeyDown(event: KeyboardEvent) { if (ARROW_KEYS.includes(event.key)) { isArrowKeyPressedRef.current = true } } function onKeyUp() { isArrowKeyPressedRef.current = false } document.addEventListener("keydown", onKeyDown) document.addEventListener("keyup", onKeyUp) return () => { document.removeEventListener("keydown", onKeyDown) document.removeEventListener("keyup", onKeyUp) } }, []) React.useEffect(() => { focusContext.onItemRegister({ id: triggerId, ref: triggerRef, value: itemValue, active: isTabStop, disabled: !!isDisabled, }) if (!isDisabled) { focusContext.onFocusableItemAdd() } return () => { focusContext.onItemUnregister(triggerId) if (!isDisabled) { focusContext.onFocusableItemRemove() } } }, [focusContext, triggerId, itemValue, isTabStop, isDisabled]) const onClick = React.useCallback( async (event: React.MouseEvent) => { propsRef.current.onClick?.(event) if (event.defaultPrevented) return if (!isDisabled && !context.nonInteractive) { const currentStepIndex = Array.from(steps.keys()).indexOf(value ?? "") const targetStepIndex = Array.from(steps.keys()).indexOf(itemValue) const direction = targetStepIndex > currentStepIndex ? "next" : "prev" await store.setStateWithValidation(itemValue, direction) } }, [isDisabled, context.nonInteractive, store, itemValue, value, steps], ) const onFocus = React.useCallback( async (event: React.FocusEvent) => { propsRef.current.onFocus?.(event) if (event.defaultPrevented) return focusContext.onItemFocus(triggerId) const isKeyboardFocus = !isMouseClickRef.current if ( !isActive && !isDisabled && activationMode !== "manual" && !context.nonInteractive && isKeyboardFocus ) { const currentStepIndex = Array.from(steps.keys()).indexOf(value || "") const targetStepIndex = Array.from(steps.keys()).indexOf(itemValue) const direction = targetStepIndex > currentStepIndex ? "next" : "prev" await store.setStateWithValidation(itemValue, direction) } isMouseClickRef.current = false }, [ focusContext, triggerId, activationMode, isActive, isDisabled, context.nonInteractive, store, itemValue, value, steps, ], ) const onKeyDown = React.useCallback( async (event: React.KeyboardEvent) => { propsRef.current.onKeyDown?.(event) if (event.defaultPrevented) return if (event.key === "Enter" && context.nonInteractive) { event.preventDefault() return } if ( (event.key === "Enter" || event.key === " ") && activationMode === "manual" && !context.nonInteractive ) { event.preventDefault() if (!isDisabled && triggerRef.current) { triggerRef.current.click() } return } if (event.key === "Tab" && event.shiftKey) { focusContext.onItemShiftTab() return } if (event.target !== event.currentTarget) return const focusIntent = getFocusIntent(event, context.dir, orientation) if (focusIntent !== undefined) { if (event.metaKey || event.ctrlKey || event.altKey || event.shiftKey) return event.preventDefault() const items = focusContext.getItems().filter((item) => !item.disabled) let candidateRefs = items.map((item) => item.ref) if (focusIntent === "last") { candidateRefs.reverse() } else if (focusIntent === "prev" || focusIntent === "next") { if (focusIntent === "prev") candidateRefs.reverse() const currentIndex = candidateRefs.findIndex( (ref) => ref.current === event.currentTarget, ) candidateRefs = loop ? wrapArray(candidateRefs, currentIndex + 1) : candidateRefs.slice(currentIndex + 1) } if (store.hasValidation() && candidateRefs.length > 0) { const nextRef = candidateRefs[0] const nextElement = nextRef?.current const nextItem = items.find( (item) => item.ref.current === nextElement, ) if (nextItem && nextItem.value !== itemValue) { const currentStepIndex = Array.from(steps.keys()).indexOf( value || "", ) const targetStepIndex = Array.from(steps.keys()).indexOf( nextItem.value, ) const direction: NavigationDirection = targetStepIndex > currentStepIndex ? "next" : "prev" if (direction === "next") { const isValid = await store.setStateWithValidation( nextItem.value, direction, ) if (!isValid) return } else { store.setState("value", nextItem.value) } queueMicrotask(() => nextElement?.focus()) return } } queueMicrotask(() => focusFirst(candidateRefs)) } }, [ focusContext, context.nonInteractive, context.dir, activationMode, orientation, loop, isDisabled, store, itemValue, value, steps, ], ) const onMouseDown = React.useCallback( (event: React.MouseEvent) => { propsRef.current.onMouseDown?.(event) if (event.defaultPrevented) return isMouseClickRef.current = true if (isDisabled) { event.preventDefault() } else { focusContext.onItemFocus(triggerId) } }, [focusContext, triggerId, isDisabled], ) const TriggerPrimitive = asChild ? Slot : "button" return ( ) } interface StepperIndicatorProps extends Omit { children?: React.ReactNode | ((dataState: DataState) => React.ReactNode) } function StepperIndicator(props: StepperIndicatorProps) { const { className, children, asChild, ref, ...indicatorProps } = props const context = useStepperContext(INDICATOR_NAME) const itemContext = useStepperItemContext(INDICATOR_NAME) const value = useStore((state) => state.value) const itemValue = itemContext.value const stepState = useStore((state) => state.steps.get(itemValue)) const steps = useStore((state) => state.steps) const stepPosition = Array.from(steps.keys()).indexOf(itemValue) + 1 const dataState = getDataState(value, itemValue, stepState, steps) const IndicatorPrimitive = asChild ? Slot : "div" return ( {typeof children === "function" ? ( children(dataState) ) : children ? ( children ) : dataState === "completed" ? ( ) : ( stepPosition )} ) } interface StepperSeparatorProps extends DivProps { forceMount?: boolean } function StepperSeparator(props: StepperSeparatorProps) { const { className, asChild, forceMount = false, ref, ...separatorProps } = props const context = useStepperContext(SEPARATOR_NAME) const itemContext = useStepperItemContext(SEPARATOR_NAME) const value = useStore((state) => state.value) const steps = useStore((state) => state.steps) const orientation = context.orientation const stepIndex = Array.from(steps.keys()).indexOf(itemContext.value) const isLastStep = stepIndex === steps.size - 1 if (isLastStep && !forceMount) return null const dataState = getDataState( value, itemContext.value, itemContext.stepState, steps, "separator", ) const SeparatorPrimitive = asChild ? Slot : "div" return (