From beafaea473c6a17a8a3e66675d8ebac3ac3ebf6d Mon Sep 17 00:00:00 2001
From: Marija Najdova <mnajdova@gmail.com>
Date: Mon, 2 Oct 2023 13:46:12 +0200
Subject: [PATCH] [system] Add support for `variants` in the styled() util
 (#39073)

---
 .../mui-material/src/styles/styled.spec.tsx   |  30 +++
 packages/mui-styled-engine-sc/src/index.d.ts  |  11 +-
 packages/mui-styled-engine/src/index.d.ts     |  11 +-
 packages/mui-system/src/createStyled.js       | 154 +++++++++++----
 packages/mui-system/src/createStyled.test.js  | 177 ++++++++++++++++++
 5 files changed, 343 insertions(+), 40 deletions(-)

diff --git a/packages/mui-material/src/styles/styled.spec.tsx b/packages/mui-material/src/styles/styled.spec.tsx
index e76d9757df168b..8053428c818c6a 100644
--- a/packages/mui-material/src/styles/styled.spec.tsx
+++ b/packages/mui-material/src/styles/styled.spec.tsx
@@ -155,3 +155,33 @@ function Button({
     Hello
   </Button>
 </ThemeProvider>;
+
+function variantsAPI() {
+  const ObjectSyntax = styled('div')<{ foo?: string; bar?: number }>({
+    variants: [
+      {
+        props: { foo: 'a' },
+        style: { color: 'blue' },
+      },
+    ],
+  });
+
+  const FunctionSyntax = styled('div')<{ foo?: string; bar?: number }>(() => ({
+    variants: [
+      {
+        props: { foo: 'a' },
+        style: { color: 'blue' },
+      },
+    ],
+  }));
+
+  // @ts-expect-error the API is not valid for CSS properties
+  const WrongUsage = styled('div')<{ foo?: string; bar?: number }>({
+    color: [
+      {
+        props: { foo: 'a' },
+        style: { color: 'blue' },
+      },
+    ],
+  });
+}
diff --git a/packages/mui-styled-engine-sc/src/index.d.ts b/packages/mui-styled-engine-sc/src/index.d.ts
index 59f8108776256a..d338703d00e04d 100644
--- a/packages/mui-styled-engine-sc/src/index.d.ts
+++ b/packages/mui-styled-engine-sc/src/index.d.ts
@@ -50,11 +50,20 @@ export interface CSSOthersObjectForCSSObject {
   [propertiesName: string]: CSSInterpolation;
 }
 
-export interface CSSObject extends CSSPropertiesWithMultiValues, CSSPseudos, CSSOthersObject {}
+// Omit variants as a key, because we have a special handling for it
+export interface CSSObject
+  extends CSSPropertiesWithMultiValues,
+    CSSPseudos,
+    Omit<CSSOthersObject, 'variants'> {}
+
+interface CSSObjectWithVariants<Props> extends Omit<CSSObject, 'variants'> {
+  variants: Array<{ props: Props; variants: CSSObject }>;
+}
 
 export type FalseyValue = undefined | null | false;
 export type Interpolation<P> =
   | InterpolationValue
+  | CSSObjectWithVariants<P>
   | InterpolationFunction<P>
   | FlattenInterpolation<P>;
 // cannot be made a self-referential interface, breaks WithPropNested
diff --git a/packages/mui-styled-engine/src/index.d.ts b/packages/mui-styled-engine/src/index.d.ts
index 1130164844f7b0..68c4776444e8ff 100644
--- a/packages/mui-styled-engine/src/index.d.ts
+++ b/packages/mui-styled-engine/src/index.d.ts
@@ -49,7 +49,15 @@ export interface CSSOthersObjectForCSSObject {
   [propertiesName: string]: CSSInterpolation;
 }
 
-export interface CSSObject extends CSSPropertiesWithMultiValues, CSSPseudos, CSSOthersObject {}
+// Omit variants as a key, because we have a special handling for it
+export interface CSSObject
+  extends CSSPropertiesWithMultiValues,
+    CSSPseudos,
+    Omit<CSSOthersObject, 'variants'> {}
+
+interface CSSObjectWithVariants<Props> extends Omit<CSSObject, 'variants'> {
+  variants: Array<{ props: Props; variants: CSSObject }>;
+}
 
 export interface ComponentSelector {
   __emotion_styles: any;
@@ -85,6 +93,7 @@ export interface ArrayInterpolation<Props> extends Array<Interpolation<Props>> {
 
 export type Interpolation<Props> =
   | InterpolationPrimitive
+  | CSSObjectWithVariants<Props>
   | ArrayInterpolation<Props>
   | FunctionInterpolation<Props>;
 
diff --git a/packages/mui-system/src/createStyled.js b/packages/mui-system/src/createStyled.js
index 0fc9686fb80530..768acec584a320 100644
--- a/packages/mui-system/src/createStyled.js
+++ b/packages/mui-system/src/createStyled.js
@@ -1,6 +1,11 @@
 /* eslint-disable no-underscore-dangle */
 import styledEngineStyled, { internal_processStyles as processStyles } from '@mui/styled-engine';
-import { getDisplayName, unstable_capitalize as capitalize } from '@mui/utils';
+import {
+  getDisplayName,
+  unstable_capitalize as capitalize,
+  isPlainObject,
+  deepmerge,
+} from '@mui/utils';
 import createTheme from './createTheme';
 import propsToClassKey from './propsToClassKey';
 import styleFunctionSx from './styleFunctionSx';
@@ -28,36 +33,41 @@ const getStyleOverrides = (name, theme) => {
   return null;
 };
 
+const transformVariants = (variants) => {
+  const variantsStyles = {};
+
+  if (variants) {
+    variants.forEach((definition) => {
+      const key = propsToClassKey(definition.props);
+      variantsStyles[key] = definition.style;
+    });
+  }
+
+  return variantsStyles;
+};
 const getVariantStyles = (name, theme) => {
   let variants = [];
   if (theme && theme.components && theme.components[name] && theme.components[name].variants) {
     variants = theme.components[name].variants;
   }
 
-  const variantsStyles = {};
-
-  variants.forEach((definition) => {
-    const key = propsToClassKey(definition.props);
-    variantsStyles[key] = definition.style;
-  });
-
-  return variantsStyles;
+  return transformVariants(variants);
 };
 
-const variantsResolver = (props, styles, theme, name) => {
+const variantsResolver = (props, styles, variants) => {
   const { ownerState = {} } = props;
   const variantsStyles = [];
-  const themeVariants = theme?.components?.[name]?.variants;
-  if (themeVariants) {
-    themeVariants.forEach((themeVariant) => {
+
+  if (variants) {
+    variants.forEach((variant) => {
       let isMatch = true;
-      Object.keys(themeVariant.props).forEach((key) => {
-        if (ownerState[key] !== themeVariant.props[key] && props[key] !== themeVariant.props[key]) {
+      Object.keys(variant.props).forEach((key) => {
+        if (ownerState[key] !== variant.props[key] && props[key] !== variant.props[key]) {
           isMatch = false;
         }
       });
       if (isMatch) {
-        variantsStyles.push(styles[propsToClassKey(themeVariant.props)]);
+        variantsStyles.push(styles[propsToClassKey(variant.props)]);
       }
     });
   }
@@ -65,6 +75,11 @@ const variantsResolver = (props, styles, theme, name) => {
   return variantsStyles;
 };
 
+const themeVariantsResolver = (props, styles, theme, name) => {
+  const themeVariants = theme?.components?.[name]?.variants;
+  return variantsResolver(props, styles, themeVariants);
+};
+
 // Update /system/styled/#api in case if this changes
 export function shouldForwardProp(prop) {
   return prop !== 'ownerState' && prop !== 'theme' && prop !== 'sx' && prop !== 'as';
@@ -90,6 +105,30 @@ function defaultOverridesResolver(slot) {
   return (props, styles) => styles[slot];
 }
 
+const muiStyledFunctionResolver = ({ styledArg, props, defaultTheme, themeId }) => {
+  const resolvedStyles = styledArg({
+    ...props,
+    theme: resolveTheme({ ...props, defaultTheme, themeId }),
+  });
+
+  let optionalVariants;
+  if (resolvedStyles && resolvedStyles.variants) {
+    optionalVariants = resolvedStyles.variants;
+    delete resolvedStyles.variants;
+  }
+  if (optionalVariants) {
+    const variantsStyles = variantsResolver(
+      props,
+      transformVariants(optionalVariants),
+      optionalVariants,
+    );
+
+    return [resolvedStyles, ...variantsStyles];
+  }
+
+  return resolvedStyles;
+};
+
 export default function createStyled(input = {}) {
   const {
     themeId,
@@ -163,19 +202,72 @@ export default function createStyled(input = {}) {
             // On the server Emotion doesn't use React.forwardRef for creating components, so the created
             // component stays as a function. This condition makes sure that we do not interpolate functions
             // which are basically components used as a selectors.
-            return typeof stylesArg === 'function' && stylesArg.__emotion_real !== stylesArg
-              ? (props) => {
-                  return stylesArg({
-                    ...props,
-                    theme: resolveTheme({ ...props, defaultTheme, themeId }),
+            if (typeof stylesArg === 'function' && stylesArg.__emotion_real !== stylesArg) {
+              return (props) =>
+                muiStyledFunctionResolver({ styledArg: stylesArg, props, defaultTheme, themeId });
+            }
+            if (isPlainObject(stylesArg)) {
+              let transformedStylesArg = stylesArg;
+              let styledArgVariants;
+
+              if (stylesArg && stylesArg.variants) {
+                styledArgVariants = stylesArg.variants;
+                delete transformedStylesArg.variants;
+
+                transformedStylesArg = (props) => {
+                  let result = stylesArg;
+                  const variantStyles = variantsResolver(
+                    props,
+                    transformVariants(styledArgVariants),
+                    styledArgVariants,
+                  );
+                  variantStyles.forEach((variantStyle) => {
+                    result = deepmerge(result, variantStyle);
                   });
-                }
-              : stylesArg;
+
+                  return result;
+                };
+              }
+              return transformedStylesArg;
+            }
+            return stylesArg;
           })
         : [];
 
       let transformedStyleArg = styleArg;
 
+      if (isPlainObject(styleArg)) {
+        let styledArgVariants;
+        if (styleArg && styleArg.variants) {
+          styledArgVariants = styleArg.variants;
+          delete transformedStyleArg.variants;
+
+          transformedStyleArg = (props) => {
+            let result = styleArg;
+            const variantStyles = variantsResolver(
+              props,
+              transformVariants(styledArgVariants),
+              styledArgVariants,
+            );
+            variantStyles.forEach((variantStyle) => {
+              result = deepmerge(result, variantStyle);
+            });
+
+            return result;
+          };
+        }
+      } else if (
+        typeof styleArg === 'function' &&
+        // On the server Emotion doesn't use React.forwardRef for creating components, so the created
+        // component stays as a function. This condition makes sure that we do not interpolate functions
+        // which are basically components used as a selectors.
+        styleArg.__emotion_real !== styleArg
+      ) {
+        // If the type is function, we need to define the default theme.
+        transformedStyleArg = (props) =>
+          muiStyledFunctionResolver({ styledArg: styleArg, props, defaultTheme, themeId });
+      }
+
       if (componentName && overridesResolver) {
         expressionsWithDefaultTheme.push((props) => {
           const theme = resolveTheme({ ...props, defaultTheme, themeId });
@@ -197,7 +289,7 @@ export default function createStyled(input = {}) {
       if (componentName && !skipVariantsResolver) {
         expressionsWithDefaultTheme.push((props) => {
           const theme = resolveTheme({ ...props, defaultTheme, themeId });
-          return variantsResolver(
+          return themeVariantsResolver(
             props,
             getVariantStyles(componentName, theme),
             theme,
@@ -217,21 +309,7 @@ export default function createStyled(input = {}) {
         // If the type is array, than we need to add placeholders in the template for the overrides, variants and the sx styles.
         transformedStyleArg = [...styleArg, ...placeholders];
         transformedStyleArg.raw = [...styleArg.raw, ...placeholders];
-      } else if (
-        typeof styleArg === 'function' &&
-        // On the server Emotion doesn't use React.forwardRef for creating components, so the created
-        // component stays as a function. This condition makes sure that we do not interpolate functions
-        // which are basically components used as a selectors.
-        styleArg.__emotion_real !== styleArg
-      ) {
-        // If the type is function, we need to define the default theme.
-        transformedStyleArg = (props) =>
-          styleArg({
-            ...props,
-            theme: resolveTheme({ ...props, defaultTheme, themeId }),
-          });
       }
-
       const Component = defaultStyledResolver(transformedStyleArg, ...expressionsWithDefaultTheme);
 
       if (process.env.NODE_ENV !== 'production') {
diff --git a/packages/mui-system/src/createStyled.test.js b/packages/mui-system/src/createStyled.test.js
index c03c503374db88..35d36dc0e3f18e 100644
--- a/packages/mui-system/src/createStyled.test.js
+++ b/packages/mui-system/src/createStyled.test.js
@@ -411,4 +411,181 @@ describe('createStyled', () => {
       expect(container.firstChild).to.have.tagName('span');
     });
   });
+
+  describe('variants key', () => {
+    it('should accept variants in object style arg', () => {
+      const styled = createStyled({});
+
+      const Test = styled('div')({
+        variants: [
+          {
+            props: { color: 'blue', variant: 'filled' },
+            style: {
+              backgroundColor: 'rgb(0,0,255)',
+            },
+          },
+          {
+            props: { color: 'blue', variant: 'text' },
+            style: {
+              color: 'rgb(0,0,255)',
+            },
+          },
+        ],
+      });
+
+      const { getByTestId } = render(
+        <React.Fragment>
+          <Test data-testid="filled" color="blue" variant="filled">
+            Filled
+          </Test>
+          <Test data-testid="text" color="blue" variant="text">
+            Filled
+          </Test>
+        </React.Fragment>,
+      );
+      expect(getByTestId('filled')).toHaveComputedStyle({ backgroundColor: 'rgb(0, 0, 255)' });
+      expect(getByTestId('text')).toHaveComputedStyle({ color: 'rgb(0, 0, 255)' });
+    });
+
+    it('should accept variants in function style arg', () => {
+      const styled = createStyled({ defaultTheme: { colors: { blue: 'rgb(0, 0, 255)' } } });
+
+      const Test = styled('div')(({ theme }) => ({
+        variants: [
+          {
+            props: { color: 'blue', variant: 'filled' },
+            style: {
+              backgroundColor: theme.colors.blue,
+            },
+          },
+          {
+            props: { color: 'blue', variant: 'text' },
+            style: {
+              color: theme.colors.blue,
+            },
+          },
+        ],
+      }));
+
+      const { getByTestId } = render(
+        <React.Fragment>
+          <Test data-testid="filled" color="blue" variant="filled">
+            Filled
+          </Test>
+          <Test data-testid="text" color="blue" variant="text">
+            Filled
+          </Test>
+        </React.Fragment>,
+      );
+      expect(getByTestId('filled')).toHaveComputedStyle({ backgroundColor: 'rgb(0, 0, 255)' });
+      expect(getByTestId('text')).toHaveComputedStyle({ color: 'rgb(0, 0, 255)' });
+    });
+
+    it('should accept variants in arrays', () => {
+      const styled = createStyled({ defaultTheme: { colors: { blue: 'rgb(0, 0, 255)' } } });
+
+      const Test = styled('div')(
+        ({ theme }) => ({
+          variants: [
+            {
+              props: { color: 'blue', variant: 'filled' },
+              style: {
+                backgroundColor: theme.colors.blue,
+              },
+            },
+            {
+              props: { color: 'blue', variant: 'text' },
+              style: {
+                color: theme.colors.blue,
+              },
+            },
+          ],
+        }),
+        {
+          variants: [
+            {
+              props: { color: 'blue', variant: 'outlined' },
+              style: {
+                borderTopColor: 'rgb(0,0,255)',
+              },
+            },
+            // This is overriding the previous definition
+            {
+              props: { color: 'blue', variant: 'text' },
+              style: {
+                color: 'rgb(0,0,220)',
+              },
+            },
+          ],
+        },
+      );
+
+      const { getByTestId } = render(
+        <React.Fragment>
+          <Test data-testid="filled" color="blue" variant="filled">
+            Filled
+          </Test>
+          <Test data-testid="text" color="blue" variant="text">
+            Filled
+          </Test>
+          <Test data-testid="outlined" color="blue" variant="outlined">
+            Outlined
+          </Test>
+        </React.Fragment>,
+      );
+      expect(getByTestId('filled')).toHaveComputedStyle({ backgroundColor: 'rgb(0, 0, 255)' });
+      expect(getByTestId('text')).toHaveComputedStyle({ color: 'rgb(0, 0, 220)' });
+      expect(getByTestId('outlined')).toHaveComputedStyle({ borderTopColor: 'rgb(0, 0, 255)' });
+    });
+
+    it('theme variants should override styled variants', () => {
+      const styled = createStyled({});
+
+      const Test = styled('div', { name: 'Test' })({
+        variants: [
+          {
+            props: { color: 'blue', variant: 'filled' },
+            style: {
+              backgroundColor: 'rgb(0,0,255)',
+            },
+          },
+          // This is overriding the previous definition
+          {
+            props: { color: 'blue', variant: 'text' },
+            style: {
+              color: 'rgb(0,0,255)',
+            },
+          },
+        ],
+      });
+
+      const { getByTestId } = render(
+        <ThemeProvider
+          theme={{
+            components: {
+              Test: {
+                variants: [
+                  {
+                    props: { variant: 'text', color: 'blue' },
+                    style: {
+                      color: 'rgb(0,0,220)',
+                    },
+                  },
+                ],
+              },
+            },
+          }}
+        >
+          <Test data-testid="filled" color="blue" variant="filled">
+            Filled
+          </Test>
+          <Test data-testid="text" color="blue" variant="text">
+            Filled
+          </Test>
+        </ThemeProvider>,
+      );
+      expect(getByTestId('filled')).toHaveComputedStyle({ backgroundColor: 'rgb(0, 0, 255)' });
+      expect(getByTestId('text')).toHaveComputedStyle({ color: 'rgb(0, 0, 220)' });
+    });
+  });
 });