import type { DragMoveEvent, DragStartEvent } from '@dnd-kit/core'
import React from 'react'
import { useGridContext } from '../../../context/grid-context/hook'
import type { ColumnState } from '../../../state'
import { DRAG_TYPE_HEADER_GROUP_CELL } from '../constants'
import type { ColumnCompare } from './utils'
import { calculateColumnMoving, calculateGroupMoving } from './utils'
import {
  buildHeaderHierarchy,
  getLowestLeafs,
} from '../../../state/reducer/column-utils'

/**
 * Drag-to-reorder handlers for grid header columns and header groups,
 * driven by `@dnd-kit/core` drag events.
 *
 * Lifecycle:
 * - `onMoveStart(id, e)` snapshots the current column order (hidden columns
 *   included), resolves which leaf column ids move (just `id`, or every
 *   lowest leaf of `id` when it is a parent group), and computes the drop
 *   bounds via `calculateGroupMoving` / `calculateColumnMoving`.
 * - `onMove(e)` maps the horizontal drag delta onto a target index, clamps
 *   it to `[minIndex, maxIndex]`, and calls `grid.api.column.reorderColumns`
 *   whenever that target index changes.
 * - `onMoveEnd()` clears all in-progress state and emits
 *   `onColumnStateChange`.
 * - `onMoveCancel()` restores the order captured at drag start (if any),
 *   then runs `onMoveEnd()`.
 *
 * @returns The four handlers plus `minLeft` / `maxLeft` from the current
 *   move bounds (`0` / `Infinity` while no move is in progress). NOTE(review):
 *   presumably used by the caller to clamp the dragged header's position —
 *   confirm against call sites.
 */
export const useColumnMove = () => {
  const grid = useGridContext()

  type ColumnIdList = ReturnType<
    typeof grid.selectors.selectColumnIdsWithHidden
  >

  // Column order (hidden columns included) captured at drag start so a
  // cancelled drag can restore it. Explicitly typed: an untyped
  // `useRef(null)` infers `MutableRefObject<null>` under strictNullChecks,
  // which rejects the assignment made in onMoveStart.
  const initialColumnIdsWhileMoving = React.useRef<ColumnIdList | null>(null)

  // Column entities in drag-start order; onMove rebuilds each new order
  // from this snapshot rather than from the live (already reordered) state.
  const [columnsWhileMoving, setColumnsWhileMoving] = React.useState<
    ColumnState[] | null
  >(null)

  // Last target index applied by onMove, used to skip redundant reorders.
  const [movedColumnIndex, setMovedColumnIndex] = React.useState<
    number | null
  >(null)

  const [moveState, setMoveState] = React.useState<{
    id: string
    columnIds: string[]
    // Stored at drag start; not read elsewhere in this hook.
    order: string[]
    initialX: number
    bounds: {
      initialLeft: number
      locations: ColumnCompare[]
      currentIndex: number
      maxIndex: number
      minIndex: number
      minLeft: number
      maxLeft: number
    }
  } | null>(null)

  const onMoveStart = React.useCallback(
    (id: string, e: DragStartEvent) => {
      const state = grid.getState()
      const columnIdsState = grid.selectors.selectColumnIdsWithHidden(state)
      const hiddenIds = grid.selectors.selectHiddenIds(state)
      const columns = columnIdsState.map(
        (columnId) => state.columns.entities[columnId]
      )

      const getGroup = (groupId: string) => state.columns.entities[groupId]
      const isGroup = !!getGroup(id).isParent

      // A group moves all of its lowest leaf columns, kept in their current
      // relative order; a plain column moves alone.
      const lowestLeafs = isGroup
        ? getLowestLeafs(id, state.columns.entities, [])
        : []
      const columnIds = isGroup
        ? lowestLeafs
            .map((c) => c.id)
            .sort(
              (a, b) => columnIdsState.indexOf(a) - columnIdsState.indexOf(b)
            )
        : [id]

      const entities = grid.selectors.selectColumnEntities(state)
      const headerHierarchy = buildHeaderHierarchy(entities, [])

      setColumnsWhileMoving(columns)
      initialColumnIdsWhileMoving.current = columnIdsState

      const bounds = isGroup
        ? calculateGroupMoving(
            id,
            columns,
            headerHierarchy.groups,
            getGroup,
            hiddenIds
          )
        : calculateColumnMoving(
            id,
            columns,
            headerHierarchy.groups,
            getGroup,
            hiddenIds
          )

      setMoveState({
        id,
        columnIds,
        initialX: e.active.rect.current.initial?.left ?? 0,
        order: columns.map((c) => c.id),
        bounds,
      })
    },
    [grid]
  )

  const onMove = React.useCallback(
    (e: DragMoveEvent) => {
      // Both are set together in onMoveStart and cleared together in
      // onMoveEnd; checking both narrows the types without `!`.
      if (!moveState || !columnsWhileMoving) {
        return
      }

      const { bounds, columnIds } = moveState
      const { initialLeft, currentIndex, locations, maxIndex, minIndex } =
        bounds

      // Left edge of the dragged header after the horizontal delta.
      const newLeft = initialLeft + e.delta.x

      // Header-group cells add half of their initial width before the
      // 'more' comparison below; other drag types add nothing.
      const isGroupCell =
        e.active.data.current?.type === DRAG_TYPE_HEADER_GROUP_CELL
      const rightOffset = isGroupCell
        ? (e.active.rect.current.initial?.width ?? 0) * 0.5
        : 0

      // First 'less' location that the new left edge is left of.
      const leftMost = locations.find(
        ({ location, compare }) => compare === 'less' && newLeft < location
      )
      // Last (searching from the end) 'more' location that the offset left
      // edge is right of.
      const rightMost = [...locations]
        .reverse()
        .find(
          ({ location, compare }) =>
            compare === 'more' && newLeft + rightOffset > location
        )

      // Prefer a 'less' match, then a 'more' match, else stay in place;
      // always clamp to the bounds computed at drag start.
      const { newIndex: rawIndex } = leftMost ??
        rightMost ?? { newIndex: currentIndex }
      const newIndex = Math.max(minIndex, Math.min(maxIndex, rawIndex))

      setMovedColumnIndex(newIndex)

      // Only reorder when the target differs from the last one applied.
      if (movedColumnIndex === newIndex) {
        return
      }

      // Start from the drag-start order; when back at the original index
      // this restores that order unchanged.
      let ids = columnsWhileMoving.map((c) => c.id)
      if (newIndex !== currentIndex) {
        ids = ids.filter((cid) => !columnIds.includes(cid))
        // When moving right, offset by the extra moved columns
        // (columnIds.length - 1) that were filtered out of `ids`.
        const modifiedIndex =
          newIndex > currentIndex
            ? newIndex - (columnIds.length - 1)
            : newIndex
        ids.splice(modifiedIndex, 0, ...columnIds)
      }
      grid.api.column.reorderColumns(ids)
    },
    [moveState, columnsWhileMoving, movedColumnIndex, grid]
  )

  const onMoveEnd = React.useCallback(() => {
    setMoveState(null)
    setMovedColumnIndex(null)
    setColumnsWhileMoving(null)
    grid.events.emit('onColumnStateChange')
    initialColumnIdsWhileMoving.current = null
  }, [grid])

  const onMoveCancel = React.useCallback(() => {
    const initialIds = initialColumnIdsWhileMoving.current
    // Without a drag-start snapshot (no move started, or it already ended)
    // there is nothing to restore; avoid passing null to reorderColumns.
    if (initialIds) {
      grid.api.column.reorderColumns(initialIds)
    }
    onMoveEnd()
  }, [grid, onMoveEnd])

  return {
    minLeft: moveState?.bounds.minLeft ?? 0,
    maxLeft: moveState?.bounds.maxLeft ?? Infinity,
    onMoveStart,
    onMove,
    onMoveEnd,
    onMoveCancel,
  }
}