diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 00b4584b2..f47d7aefd 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -437,7 +437,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: condition = parse_into_expr(self._condition, namespace=plx)(df)[0] try: value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] - if len(value_series) == 1: # literal case + if len(value_series) == 1: # literal or reduction case value_series = condition.__class__._from_iterable( pa.repeat(pa.scalar(value_series[0]), len(condition)), name="literal", diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index a36f636d7..5284f46be 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -400,8 +400,9 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: condition = cast("dx.Series", condition) try: - value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] - if getattr(self._then_value, "_function_name", None) == "lit": + then_expr = parse_into_expr(self._then_value, namespace=plx) + value_series = then_expr(df)[0] + if getattr(then_expr, "_returns_scalar", False): # literal or reduction case _df = condition.to_frame("a") _df["tmp"] = value_series[0] value_series = _df["tmp"] diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index f87c3417a..854229431 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -463,7 +463,7 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: condition = parse_into_expr(self._condition, namespace=plx)(df)[0] try: value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] - if len(value_series) == 1: # literal case + if len(value_series) == 1: # literal or reduction case value_series = condition.__class__._from_iterable( [value_series[0]] * len(condition), name="literal",