import { ROW_FOCUS_ID } from '../constants'
import type { HeaderHierarchy, HeaderHierarchyGroup } from '../state'
import type { GridFocusState } from '../state/reducer/focus'
import type { GridRowId } from '../types'

/**
 * Resolves a column id to its index within `columnIds`.
 *
 * `ROW_FOCUS_ID` is handled specially: it maps to `-1` when `subFocus` is
 * `'last'` and to `columnIds.length` otherwise, so it sits just outside the
 * range of real column indices.
 *
 * @param columnIds - Ordered column ids.
 * @param columnId - Id of the column to locate.
 * @param subFocus - Sub-focus of the current focus; only consulted for `ROW_FOCUS_ID`.
 * @returns The column's index, or `-1` when `columnId` is not in `columnIds`.
 */
export const getColumnIndex = (
  columnIds: readonly string[],
  columnId: string,
  subFocus: number | 'first' | 'last'
): number => {
  if (columnId === ROW_FOCUS_ID) {
    return subFocus === 'last' ? -1 : columnIds.length
  }
  return columnIds.indexOf(columnId)
}

/**
 * Returns the index of `rowId` within the header rows, or `-1` if it is not a header row.
 */
export const getHeaderRowIndex = (
  headerRowIds: readonly GridRowId[],
  rowId: GridRowId
): number => headerRowIds.indexOf(rowId)

/**
 * Finds the header group that owns the focused header cell.
 *
 * A header row at index `rowIndex` is treated as level `levels - rowIndex`.
 * Among the groups that contain the focused column and whose level does not
 * exceed that row level, the one with the highest level is returned (ties keep
 * the group that appears first in `groups`, since `Array.prototype.sort` is stable).
 *
 * @returns The matching group, or `undefined` when `levels` is falsy or no
 *   group qualifies.
 */
const getFocusedGroup = (
  { rowId, columnId }: GridFocusState['focus'],
  { levels, groups }: HeaderHierarchy,
  headerRowIds: readonly GridRowId[]
): HeaderHierarchyGroup | undefined => {
  if (!levels) {
    return undefined
  }
  // NOTE(review): assumes the focused row is a header row; if `rowId` is not in
  // `headerRowIds`, rowIndex is -1 and the row level becomes `levels + 1`.
  const rowIndex = getHeaderRowIndex(headerRowIds, rowId)
  const currLevel = levels - rowIndex
  // `filter` returns a fresh array, so sorting it does not mutate `groups`.
  const groupsFound = groups
    .filter((g) => g.columnIds.includes(columnId) && g.level <= currLevel)
    .sort((g1, g2) => g2.level - g1.level)
  return groupsFound[0]
}

/**
 * Computes where header focus moves for an arrow-key press.
 *
 * - `up`: inside a group that has parents, moves one row up (never below row 0);
 *   inside a group without parents, stays on the same row. Outside any group,
 *   jumps to the row of the lowest-level group containing the column
 *   (`levels - group.level`), or stays if there is none.
 * - `down`: inside a group, moves one row down when a lower-level group also
 *   contains the column, otherwise two rows down, clamped to `levels`. Outside
 *   any group, returns row index `-1` (presumably "leave the header" — confirm
 *   against the caller).
 * - `left` / `right`: inside a group, moves to the nearest column in that
 *   direction that is not part of the group, or stays if there is none. Outside
 *   any group, moves to the adjacent column; stepping past either end of
 *   `columnIds` yields column index `-1`.
 *
 * @param direction - Arrow direction pressed.
 * @param currentFocus - The current focus (row id, column id, sub-focus).
 * @param columnIds - Ordered column ids.
 * @param headerHierarchy - Group hierarchy of the header (`levels`, `groups`).
 * @param headerRowIds - Ordered ids of the header rows.
 * @returns The target header row index and column index.
 */
export const findNextHeaderCell = (
  direction: 'up' | 'down' | 'left' | 'right',
  currentFocus: GridFocusState['focus'],
  columnIds: readonly string[],
  headerHierarchy: HeaderHierarchy,
  headerRowIds: readonly GridRowId[]
): { rowIndex: number; columnIndex: number } => {
  const columnIndex = getColumnIndex(
    columnIds,
    currentFocus.columnId,
    currentFocus.subFocus
  )
  const groupWithFocus = getFocusedGroup(
    currentFocus,
    headerHierarchy,
    headerRowIds
  )

  // `undefined` means the move stepped outside `columnIds` (left of the first
  // column or right of the last one). It is resolved to -1 below instead of
  // being smuggled through a `string`-typed variable.
  let nextColumnId: string | undefined = currentFocus.columnId
  let nextRowIndex = getHeaderRowIndex(headerRowIds, currentFocus.rowId)

  if (direction === 'up') {
    if (groupWithFocus) {
      nextRowIndex = groupWithFocus.parentIds.length
        ? Math.max(nextRowIndex - 1, 0)
        : nextRowIndex
    } else {
      // Lowest-level group containing the column = the nearest group row above.
      const anyParentsAbove = headerHierarchy.groups
        .filter((g) => g.columnIds.includes(currentFocus.columnId))
        .sort((g1, g2) => g1.level - g2.level)
      const parentAbove = anyParentsAbove[0]
      nextRowIndex = parentAbove
        ? headerHierarchy.levels - parentAbove.level
        : nextRowIndex
    }
  } else if (direction === 'down') {
    if (groupWithFocus) {
      const anyParentBelow = !!headerHierarchy.groups.find(
        (g) =>
          g.columnIds.includes(currentFocus.columnId) &&
          g.level < groupWithFocus.level
      )
      // NOTE(review): why two rows are skipped when no lower group exists is not
      // evident from this file; the result is clamped to `levels` either way.
      nextRowIndex = Math.min(
        nextRowIndex + (anyParentBelow ? 1 : 2),
        headerHierarchy.levels
      )
    } else {
      nextRowIndex = -1
    }
  } else if (direction === 'left') {
    if (groupWithFocus) {
      // Scan leftwards from the column just before the current one.
      const remainingIds = columnIds.slice(0, columnIndex).reverse()
      nextColumnId =
        remainingIds.find((id) => !groupWithFocus.columnIds.includes(id)) ??
        nextColumnId
    } else {
      nextColumnId = columnIds[columnIndex - 1]
    }
  } else if (direction === 'right') {
    if (groupWithFocus) {
      // Includes the current column, which is skipped because it is in the group.
      const remainingIds = columnIds.slice(columnIndex)
      nextColumnId =
        remainingIds.find((id) => !groupWithFocus.columnIds.includes(id)) ??
        nextColumnId
    } else {
      nextColumnId = columnIds[columnIndex + 1]
    }
  }

  return {
    rowIndex: nextRowIndex,
    // Out-of-range neighbours map to -1, matching what `indexOf(undefined)` gave before.
    columnIndex:
      nextColumnId === undefined
        ? -1
        : getColumnIndex(columnIds, nextColumnId, currentFocus.subFocus),
  }
}