diff --git a/packages/material-react-table/src/components/head/MRT_TableHeadCell.tsx b/packages/material-react-table/src/components/head/MRT_TableHeadCell.tsx index ae37780ab..880826e83 100644 --- a/packages/material-react-table/src/components/head/MRT_TableHeadCell.tsx +++ b/packages/material-react-table/src/components/head/MRT_TableHeadCell.tsx @@ -1,4 +1,4 @@ -import { type DragEvent, useMemo } from 'react'; +import { type DragEvent, useMemo, useCallback } from 'react'; import Box from '@mui/material/Box'; import TableCell, { type TableCellProps } from '@mui/material/TableCell'; import { useTheme } from '@mui/material/styles'; @@ -159,6 +159,20 @@ export const MRT_TableHeadCell = ({ }); }; + const handleRef = useCallback( + (node: HTMLTableCellElement) => { + if (node) { + if (tableHeadCellRefs.current) { + tableHeadCellRefs.current[column.id] = node; + } + if (columnDefType !== 'group') { + columnVirtualizer?.measureElement?.(node); + } + } + }, + [column.id, columnDefType, columnVirtualizer, tableHeadCellRefs], + ); + const HeaderElement = parseFromValuesOrFunc(columnDef.Header, { column, @@ -189,14 +203,7 @@ export const MRT_TableHeadCell = ({ data-sort={column.getIsSorted() || undefined} onDragEnter={handleDragEnter} onDragOver={handleDragOver} - ref={(node: HTMLTableCellElement) => { - if (node) { - tableHeadCellRefs.current![column.id] = node; - if (columnDefType !== 'group') { - columnVirtualizer?.measureElement?.(node); - } - } - }} + ref={handleRef} tabIndex={enableKeyboardShortcuts ? 0 : undefined} {...tableCellProps} onKeyDown={handleKeyDown}