diff --git a/packages/mui-base/src/ClickAwayListener/ClickAwayListener.tsx b/packages/mui-base/src/ClickAwayListener/ClickAwayListener.tsx index ee00d7d9fc8d2b..a3270f05b34c5c 100644 --- a/packages/mui-base/src/ClickAwayListener/ClickAwayListener.tsx +++ b/packages/mui-base/src/ClickAwayListener/ClickAwayListener.tsx @@ -7,6 +7,7 @@ import { unstable_ownerDocument as ownerDocument, unstable_useForkRef as useForkRef, unstable_useEventCallback as useEventCallback, + unstable_getReactElementRef as getReactElementRef, } from '@mui/utils'; // TODO: return `EventHandlerName extends `on${infer EventName}` ? Lowercase : never` once generatePropTypes runs with TS 4.1 @@ -94,11 +95,7 @@ function ClickAwayListener(props: ClickAwayListenerProps): React.JSX.Element { }; }, []); - const handleRef = useForkRef( - // @ts-expect-error TODO upstream fix - children.ref, - nodeRef, - ); + const handleRef = useForkRef(getReactElementRef(children), nodeRef); // The handler doesn't take event.defaultPrevented into account: // diff --git a/packages/mui-base/src/FocusTrap/FocusTrap.tsx b/packages/mui-base/src/FocusTrap/FocusTrap.tsx index 3d00b5fc6f7bdc..5888a8031c634f 100644 --- a/packages/mui-base/src/FocusTrap/FocusTrap.tsx +++ b/packages/mui-base/src/FocusTrap/FocusTrap.tsx @@ -7,6 +7,7 @@ import { elementAcceptingRef, unstable_useForkRef as useForkRef, unstable_ownerDocument as ownerDocument, + unstable_getReactElementRef as getReactElementRef, } from '@mui/utils'; import { FocusTrapProps } from './FocusTrap.types'; @@ -152,8 +153,7 @@ function FocusTrap(props: FocusTrapProps): React.JSX.Element { const activated = React.useRef(false); const rootRef = React.useRef(null); - // @ts-expect-error TODO upstream fix - const handleRef = useForkRef(children.ref, rootRef); + const handleRef = useForkRef(getReactElementRef(children), rootRef); const lastKeydown = React.useRef(null); React.useEffect(() => { diff --git a/packages/mui-base/src/Portal/Portal.tsx b/packages/mui-base/src/Portal/Portal.tsx index c7b5c403decba3..830b229d8eb175 100644 --- a/packages/mui-base/src/Portal/Portal.tsx +++ b/packages/mui-base/src/Portal/Portal.tsx @@ -2,6 +2,7 @@ import * as React from 'react'; import * as ReactDOM from 'react-dom'; import PropTypes from 'prop-types'; +import getReactElementRef from '@mui/utils/getReactElementRef'; import { exactProp, HTMLElementType, @@ -33,8 +34,11 @@ const Portal = React.forwardRef(function Portal( ) { const { children, container, disablePortal = false } = props; const [mountNode, setMountNode] = React.useState>(null); - // @ts-expect-error TODO upstream fix - const handleRef = useForkRef(React.isValidElement(children) ? children.ref : null, forwardedRef); + + const handleRef = useForkRef( + React.isValidElement(children) ? getReactElementRef(children) : null, + forwardedRef, + ); useEnhancedEffect(() => { if (!disablePortal) { diff --git a/packages/mui-joy/src/Tooltip/Tooltip.tsx b/packages/mui-joy/src/Tooltip/Tooltip.tsx index 531d3a8b288c1b..7f48d74360228b 100644 --- a/packages/mui-joy/src/Tooltip/Tooltip.tsx +++ b/packages/mui-joy/src/Tooltip/Tooltip.tsx @@ -11,6 +11,7 @@ import { unstable_useId as useId, unstable_useTimeout as useTimeout, unstable_Timeout as Timeout, + unstable_getReactElementRef as getReactElementRef, } from '@mui/utils'; import { Popper, unstable_composeClasses as composeClasses } from '@mui/base'; import { OverridableComponent } from '@mui/types'; @@ -424,10 +425,7 @@ const Tooltip = React.forwardRef(function Tooltip(inProps, ref) { const handleUseRef = useForkRef(setChildNode, ref); const handleFocusRef = useForkRef(focusVisibleRef, handleUseRef); - const handleRef = useForkRef( - (children as unknown as { ref: React.Ref }).ref, - handleFocusRef, - ); + const handleRef = useForkRef(getReactElementRef(children), handleFocusRef); // There is no point in displaying an empty tooltip. if (typeof title !== 'number' && !title) { diff --git a/packages/mui-material/src/ClickAwayListener/ClickAwayListener.tsx b/packages/mui-material/src/ClickAwayListener/ClickAwayListener.tsx index 65333a0180f1e3..5c6ee0f77cc3ee 100644 --- a/packages/mui-material/src/ClickAwayListener/ClickAwayListener.tsx +++ b/packages/mui-material/src/ClickAwayListener/ClickAwayListener.tsx @@ -8,6 +8,7 @@ import { unstable_useForkRef as useForkRef, unstable_useEventCallback as useEventCallback, } from '@mui/utils'; +import getReactElementRef from '@mui/utils/getReactElementRef'; // TODO: return `EventHandlerName extends `on${infer EventName}` ? Lowercase : never` once generatePropTypes runs with TS 4.1 function mapEventPropToEvent( @@ -95,11 +96,7 @@ function ClickAwayListener(props: ClickAwayListenerProps): JSX.Element { }; }, []); - const handleRef = useForkRef( - // @ts-expect-error TODO upstream fix - children.ref, - nodeRef, - ); + const handleRef = useForkRef(getReactElementRef(children), nodeRef); // The handler doesn't take event.defaultPrevented into account: // diff --git a/packages/mui-material/src/Fade/Fade.js b/packages/mui-material/src/Fade/Fade.js index 0258a3adae8723..6521768d87c407 100644 --- a/packages/mui-material/src/Fade/Fade.js +++ b/packages/mui-material/src/Fade/Fade.js @@ -3,6 +3,7 @@ import * as React from 'react'; import PropTypes from 'prop-types'; import { Transition } from 'react-transition-group'; import elementAcceptingRef from '@mui/utils/elementAcceptingRef'; +import getReactElementRef from '@mui/utils/getReactElementRef'; import useTheme from '../styles/useTheme'; import { reflow, getTransitionProps } from '../transitions/utils'; import useForkRef from '../utils/useForkRef'; @@ -48,7 +49,7 @@ const Fade = React.forwardRef(function Fade(props, ref) { const enableStrictModeCompat = true; const nodeRef = React.useRef(null); - const handleRef = useForkRef(nodeRef, children.ref, ref); + const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref); const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => { if (callback) { diff --git a/packages/mui-material/src/Grow/Grow.js b/packages/mui-material/src/Grow/Grow.js index 77866f1f107765..0488f1f5810e5c 100644 --- a/packages/mui-material/src/Grow/Grow.js +++ b/packages/mui-material/src/Grow/Grow.js @@ -3,6 +3,7 @@ import * as React from 'react'; import PropTypes from 'prop-types'; import useTimeout from '@mui/utils/useTimeout'; import elementAcceptingRef from '@mui/utils/elementAcceptingRef'; +import getReactElementRef from '@mui/utils/getReactElementRef'; import { Transition } from 'react-transition-group'; import useTheme from '../styles/useTheme'; import { getTransitionProps, reflow } from '../transitions/utils'; @@ -61,7 +62,7 @@ const Grow = React.forwardRef(function Grow(props, ref) { const theme = useTheme(); const nodeRef = React.useRef(null); - const handleRef = useForkRef(nodeRef, children.ref, ref); + const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref); const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => { if (callback) { diff --git a/packages/mui-material/src/Portal/Portal.tsx b/packages/mui-material/src/Portal/Portal.tsx index e2eab377a2d40c..8bd141d8a5f858 100644 --- a/packages/mui-material/src/Portal/Portal.tsx +++ b/packages/mui-material/src/Portal/Portal.tsx @@ -8,6 +8,7 @@ import { unstable_useEnhancedEffect as useEnhancedEffect, unstable_useForkRef as useForkRef, unstable_setRef as setRef, + unstable_getReactElementRef as getReactElementRef, } from '@mui/utils'; import { PortalProps } from './Portal.types'; @@ -33,8 +34,11 @@ const Portal = React.forwardRef(function Portal( ) { const { children, container, disablePortal = false } = props; const [mountNode, setMountNode] = React.useState>(null); - // @ts-expect-error TODO upstream fix - const handleRef = useForkRef(React.isValidElement(children) ? children.ref : null, forwardedRef); + + const handleRef = useForkRef( + React.isValidElement(children) ? getReactElementRef(children) : null, + forwardedRef, + ); useEnhancedEffect(() => { if (!disablePortal) { diff --git a/packages/mui-material/src/Select/Select.js b/packages/mui-material/src/Select/Select.js index 7c8c7d0836bf73..8bc83230120001 100644 --- a/packages/mui-material/src/Select/Select.js +++ b/packages/mui-material/src/Select/Select.js @@ -3,6 +3,7 @@ import * as React from 'react'; import PropTypes from 'prop-types'; import clsx from 'clsx'; import deepmerge from '@mui/utils/deepmerge'; +import getReactElementRef from '@mui/utils/getReactElementRef'; import SelectInput from './SelectInput'; import formControlState from '../FormControl/formControlState'; import useFormControl from '../FormControl/useFormControl'; @@ -84,7 +85,7 @@ const Select = React.forwardRef(function Select(inProps, ref) { filled: , }[variant]; - const inputComponentRef = useForkRef(ref, InputComponent.ref); + const inputComponentRef = useForkRef(ref, getReactElementRef(InputComponent)); return ( diff --git a/packages/mui-material/src/Slide/Slide.js b/packages/mui-material/src/Slide/Slide.js index f2d30a7a97d588..3bcdb311e85eab 100644 --- a/packages/mui-material/src/Slide/Slide.js +++ b/packages/mui-material/src/Slide/Slide.js @@ -5,6 +5,7 @@ import { Transition } from 'react-transition-group'; import chainPropTypes from '@mui/utils/chainPropTypes'; import HTMLElementType from '@mui/utils/HTMLElementType'; import elementAcceptingRef from '@mui/utils/elementAcceptingRef'; +import getReactElementRef from '@mui/utils/getReactElementRef'; import debounce from '../utils/debounce'; import useForkRef from '../utils/useForkRef'; import useTheme from '../styles/useTheme'; @@ -119,7 +120,7 @@ const Slide = React.forwardRef(function Slide(props, ref) { } = props; const childrenRef = React.useRef(null); - const handleRef = useForkRef(children.ref, childrenRef, ref); + const handleRef = useForkRef(getReactElementRef(children), childrenRef, ref); const normalizedTransitionCallback = (callback) => (isAppearing) => { if (callback) { diff --git a/packages/mui-material/src/Tooltip/Tooltip.js b/packages/mui-material/src/Tooltip/Tooltip.js index a0dd1879ffa26c..d72c599e6ba863 100644 --- a/packages/mui-material/src/Tooltip/Tooltip.js +++ b/packages/mui-material/src/Tooltip/Tooltip.js @@ -8,6 +8,7 @@ import composeClasses from '@mui/utils/composeClasses'; import { alpha } from '@mui/system/colorManipulator'; import { useRtl } from '@mui/system/RtlProvider'; import appendOwnerState from '@mui/utils/appendOwnerState'; +import getReactElementRef from '@mui/utils/getReactElementRef'; import { styled, useTheme } from '../styles'; import { useDefaultProps } from '../DefaultPropsProvider'; import capitalize from '../utils/capitalize'; @@ -485,7 +486,7 @@ const Tooltip = React.forwardRef(function Tooltip(inProps, ref) { }; }, [handleClose, open]); - const handleRef = useForkRef(children.ref, focusVisibleRef, setChildNode, ref); + const handleRef = useForkRef(getReactElementRef(children), focusVisibleRef, setChildNode, ref); // There is no point in displaying an empty tooltip. // So we exclude all falsy values, except 0, which is valid. diff --git a/packages/mui-material/src/Unstable_TrapFocus/FocusTrap.tsx b/packages/mui-material/src/Unstable_TrapFocus/FocusTrap.tsx index c936f55b0bf006..df0e9476238f71 100644 --- a/packages/mui-material/src/Unstable_TrapFocus/FocusTrap.tsx +++ b/packages/mui-material/src/Unstable_TrapFocus/FocusTrap.tsx @@ -7,6 +7,7 @@ import { elementAcceptingRef, unstable_useForkRef as useForkRef, unstable_ownerDocument as ownerDocument, + unstable_getReactElementRef as getReactElementRef, } from '@mui/utils'; import { FocusTrapProps } from './FocusTrap.types'; @@ -144,8 +145,7 @@ function FocusTrap(props: FocusTrapProps): JSX.Element { const activated = React.useRef(false); const rootRef = React.useRef(null); - // @ts-expect-error TODO upstream fix - const handleRef = useForkRef(children.ref, rootRef); + const handleRef = useForkRef(getReactElementRef(children), rootRef); const lastKeydown = React.useRef(null); React.useEffect(() => { diff --git a/packages/mui-material/src/Zoom/Zoom.js b/packages/mui-material/src/Zoom/Zoom.js index 5f0ebfd5d9f781..62c721b2eec490 100644 --- a/packages/mui-material/src/Zoom/Zoom.js +++ b/packages/mui-material/src/Zoom/Zoom.js @@ -3,6 +3,7 @@ import * as React from 'react'; import PropTypes from 'prop-types'; import { Transition } from 'react-transition-group'; import elementAcceptingRef from '@mui/utils/elementAcceptingRef'; +import getReactElementRef from '@mui/utils/getReactElementRef'; import useTheme from '../styles/useTheme'; import { reflow, getTransitionProps } from '../transitions/utils'; import useForkRef from '../utils/useForkRef'; @@ -48,7 +49,7 @@ const Zoom = React.forwardRef(function Zoom(props, ref) { } = props; const nodeRef = React.useRef(null); - const handleRef = useForkRef(nodeRef, children.ref, ref); + const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref); const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => { if (callback) { diff --git a/packages/mui-utils/src/getReactElementRef/getReactElementRef.spec.tsx b/packages/mui-utils/src/getReactElementRef/getReactElementRef.spec.tsx new file mode 100644 index 00000000000000..434cf59fba109a --- /dev/null +++ b/packages/mui-utils/src/getReactElementRef/getReactElementRef.spec.tsx @@ -0,0 +1,19 @@ +import getReactElementRef from '@mui/utils/getReactElementRef'; +import * as React from 'react'; + +// @ts-expect-error +getReactElementRef(false); + +// @ts-expect-error +getReactElementRef(null); + +// @ts-expect-error +getReactElementRef(undefined); + +// @ts-expect-error +getReactElementRef(1); + +// @ts-expect-error +getReactElementRef([
,
]); + +getReactElementRef(
); diff --git a/packages/mui-utils/src/getReactElementRef/getReactElementRef.test.tsx b/packages/mui-utils/src/getReactElementRef/getReactElementRef.test.tsx new file mode 100644 index 00000000000000..b7bfbd3da5b67b --- /dev/null +++ b/packages/mui-utils/src/getReactElementRef/getReactElementRef.test.tsx @@ -0,0 +1,39 @@ +import { expect } from 'chai'; +import getReactElementRef from '@mui/utils/getReactElementRef'; +import * as React from 'react'; + +describe('getReactElementRef', () => { + it('should return undefined when not used correctly', () => { + // @ts-expect-error + expect(getReactElementRef(false)).to.equal(null); + // @ts-expect-error + expect(getReactElementRef()).to.equal(null); + // @ts-expect-error + expect(getReactElementRef(1)).to.equal(null); + + const children = [
,
]; + // @ts-expect-error + expect(getReactElementRef(children)).to.equal(null); + }); + + it('should return the ref of a React element', () => { + const ref = React.createRef(); + const element =
; + expect(getReactElementRef(element)).to.equal(ref); + }); + + it('should return null for a fragment', () => { + const element = ( + +

Hello

+

Hello

+
+ ); + expect(getReactElementRef(element)).to.equal(null); + }); + + it('should return null for element with no ref', () => { + const element =
; + expect(getReactElementRef(element)).to.equal(null); + }); +}); diff --git a/packages/mui-utils/src/getReactElementRef/getReactElementRef.ts b/packages/mui-utils/src/getReactElementRef/getReactElementRef.ts new file mode 100644 index 00000000000000..75394cbfa0ae8e --- /dev/null +++ b/packages/mui-utils/src/getReactElementRef/getReactElementRef.ts @@ -0,0 +1,18 @@ +import * as React from 'react'; + +/** + * Returns the ref of a React element handling differences between React 19 and older versions. + * It will throw runtime error if the element is not a valid React element. + * + * @param element React.ReactElement + * @returns React.Ref | null + */ +export default function getReactElementRef(element: React.ReactElement): React.Ref | null { + // 'ref' is passed as prop in React 19, whereas 'ref' is directly attached to children in older versions + if (parseInt(React.version, 10) >= 19) { + return (element?.props as any)?.ref || null; + } + // @ts-expect-error element.ref is not included in the ReactElement type + // https://github.com/DefinitelyTyped/DefinitelyTyped/discussions/70189 + return element?.ref || null; +} diff --git a/packages/mui-utils/src/getReactElementRef/index.ts b/packages/mui-utils/src/getReactElementRef/index.ts new file mode 100644 index 00000000000000..e71d03be6f7988 --- /dev/null +++ b/packages/mui-utils/src/getReactElementRef/index.ts @@ -0,0 +1 @@ +export { default } from './getReactElementRef'; diff --git a/packages/mui-utils/src/index.ts b/packages/mui-utils/src/index.ts index d4b0d45770f7f6..3a1aaabf05b8a8 100644 --- a/packages/mui-utils/src/index.ts +++ b/packages/mui-utils/src/index.ts @@ -49,4 +49,5 @@ export { default as unstable_useSlotProps } from './useSlotProps'; export type { UseSlotPropsParameters, UseSlotPropsResult } from './useSlotProps'; export { default as unstable_resolveComponentProps } from './resolveComponentProps'; export { default as unstable_extractEventHandlers } from './extractEventHandlers'; +export { default as unstable_getReactElementRef } from './getReactElementRef'; export * from './types'; diff --git a/packages/mui-utils/src/useForkRef/useForkRef.test.js b/packages/mui-utils/src/useForkRef/useForkRef.test.js index 57a7290a635b48..4b1a688b1abb0e 100644 --- a/packages/mui-utils/src/useForkRef/useForkRef.test.js +++ b/packages/mui-utils/src/useForkRef/useForkRef.test.js @@ -2,6 +2,7 @@ import * as React from 'react'; import { expect } from 'chai'; import { createRenderer, screen } from '@mui-internal/test-utils'; import useForkRef from './useForkRef'; +import getReactElementRef from '../getReactElementRef'; describe('useForkRef', () => { const { render } = createRenderer(); @@ -47,7 +48,7 @@ describe('useForkRef', () => { it('does nothing if none of the forked branches requires a ref', () => { const Outer = React.forwardRef(function Outer(props, ref) { const { children } = props; - const handleRef = useForkRef(children.ref, ref); + const handleRef = useForkRef(getReactElementRef(children), ref); return React.cloneElement(children, { ref: handleRef }); });