import * as React from "react";
import { classList, nodeListToArray, findNextFocusableElement, focusLastActive, ContainerProps } from "../../util";
import { addRegion, FocusTrapProvider, removeRegion, useFocusTrapDispatch, useFocusTrapState } from "./context";
import { useId } from "../../../hooks/useId";
export interface FocusTrapProps extends ContainerProps {
onEscape: () => void;
className?: string;
arrowKeyNavigation?: boolean;
includeOutsideTabOrder?: boolean;
dontStealFocus?: boolean;
dontRestoreFocus?: boolean;
dontTrapFocus?: boolean;
focusFirstItem?: boolean;
tagName?: keyof JSX.IntrinsicElements;
ariaLabelledby?: string;
}
export const FocusTrap = (props: FocusTrapProps) => {
return (
);
}
const FocusTrapInner = (props: FocusTrapProps) => {
const {
children,
id,
className,
onEscape,
arrowKeyNavigation,
dontStealFocus,
includeOutsideTabOrder,
dontRestoreFocus,
dontTrapFocus,
focusFirstItem,
tagName,
role,
ariaLabelledby,
ariaLabel,
ariaHidden
} = props;
const containerRef = React.useRef(null);
const previouslyFocused = React.useRef(document.activeElement);
const [stoleFocus, setStoleFocus] = React.useState(false);
const lastValidTabElement = React.useRef(null);
const { regions } = useFocusTrapState();
React.useEffect(() => {
return () => {
if (!dontRestoreFocus && previouslyFocused.current) {
focusLastActive(previouslyFocused.current as HTMLElement)
}
}
}, [])
const getElements = React.useCallback(() => {
let all = nodeListToArray(
includeOutsideTabOrder ? containerRef.current?.querySelectorAll(`[tabindex]`) :
containerRef.current?.querySelectorAll(`[tabindex]:not([tabindex="-1"])`)
);
if (regions.length) {
const regionElements: pxt.Map = {};
for (const region of regions) {
const el = containerRef.current?.querySelector(`[data-focus-trap-region="${region.id}"]`);
if (el) {
regionElements[region.id] = el;
}
}
for (const region of regions) {
const regionElement = regionElements[region.id];
if (!region.enabled && regionElement) {
all = all.filter(el => !regionElement.contains(el));
}
}
const initialOrder = all.slice();
all.sort((a, b) => {
const aRegion = regions.find(r => r.enabled && regionElements[r.id]?.contains(a));
const bRegion = regions.find(r => r.enabled && regionElements[r.id]?.contains(b));
if (aRegion?.order === bRegion?.order) {
const aIndex = initialOrder.indexOf(a);
const bIndex = initialOrder.indexOf(b);
return aIndex - bIndex;
}
else if (!aRegion) {
return 1;
}
else if (!bRegion) {
return -1;
}
else {
return aRegion.order - bRegion.order;
}
});
}
return all as HTMLElement[];
}, [regions, includeOutsideTabOrder]);
const handleRef = React.useCallback((ref: HTMLDivElement) => {
if (!ref) return;
containerRef.current = ref;
const elements = getElements();
if (!dontStealFocus && !stoleFocus && !ref.contains(document.activeElement) && elements.length) {
containerRef.current.focus();
if (focusFirstItem) {
findNextFocusableElement(elements, -1, 0, true).focus();
}
// Only steal focus once
setStoleFocus(true);
}
}, [getElements, dontStealFocus, stoleFocus, focusFirstItem]);
const onKeyDown = React.useCallback((e: React.KeyboardEvent) => {
if (!containerRef.current) return;
const moveFocus = (forward: boolean, goToEnd: boolean) => {
const focusable = getElements();
if (!focusable.length) return;
let index = focusable.indexOf(e.target as HTMLElement);
if (index < 0) {
// If we have arrived at a non-indexed focusable, it's probably
// been triggered by a calling focus() on an element with
// tabindex=-1, from the last focusable element, so try to use
// that.
index = focusable.indexOf(lastValidTabElement.current);
}
let nextFocusableElement;
if (forward) {
if (goToEnd) {
nextFocusableElement = findNextFocusableElement(focusable, index, focusable.length - 1, forward);
}
else if (index === focusable.length - 1) {
nextFocusableElement = findNextFocusableElement(focusable, index, 0, forward);
}
else {
nextFocusableElement = findNextFocusableElement(focusable, index, index + 1, forward);
}
}
else {
if (goToEnd) {
nextFocusableElement = findNextFocusableElement(focusable, index, 0, forward);
}
else if (index === 0) {
nextFocusableElement = findNextFocusableElement(focusable, index, focusable.length - 1, forward);
}
else {
nextFocusableElement = findNextFocusableElement(focusable, index, Math.max(index - 1, 0), forward);
}
}
lastValidTabElement.current = nextFocusableElement;
nextFocusableElement.focus();
e.preventDefault();
e.stopPropagation();
}
if (e.key === "Escape") {
let foundHandler = false;
if (regions.length) {
for (const region of regions) {
if (!region.onEscape) continue;
const regionElement = containerRef.current?.querySelector(`[data-focus-trap-region="${region.id}"]`);
if (regionElement?.contains(document.activeElement)) {
foundHandler = true;
region.onEscape();
break;
}
}
}
if (!foundHandler) {
onEscape();
}
e.preventDefault();
e.stopPropagation();
}
else if (e.key === "Tab") {
if (dontTrapFocus) {
onEscape();
}
else if (e.shiftKey) moveFocus(false, false);
else moveFocus(true, false);
}
else if (arrowKeyNavigation) {
if (e.key === "ArrowDown") {
moveFocus(true, false);
}
else if (e.key === "ArrowUp") {
moveFocus(false, false);
}
else if (e.key === "Home") {
moveFocus(false, true);
}
else if (e.key === "End") {
moveFocus(true, true);
}
}
}, [getElements, onEscape, arrowKeyNavigation, regions, dontTrapFocus])
return React.createElement(
tagName || "div",
{
id,
className: classList("common-focus-trap", className),
ref: handleRef,
onKeyDown,
role,
tabIndex: -1,
"aria-labelledby": ariaLabelledby,
"aria-label": ariaLabel,
"aria-hidden": ariaHidden,
},
children
);
}
interface FocusTrapRegionProps extends React.PropsWithChildren<{}> {
enabled: boolean;
order?: number;
onEscape?: () => void;
id?: string;
className?: string;
divRef?: (ref: HTMLDivElement) => void;
}
export const FocusTrapRegion = (props: FocusTrapRegionProps) => {
const {
className,
id,
onEscape,
order,
enabled,
children,
divRef
} = props;
const regionId = useId();
const dispatch = useFocusTrapDispatch();
React.useEffect(() => {
dispatch(addRegion(regionId, order, enabled, onEscape));
return () => dispatch(removeRegion(regionId));
}, [regionId, enabled, order])
return (
{children}
)
}