Skip to content

Commit

Permalink
generalize dask case
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi committed Jan 15, 2025
1 parent e344296 commit 129d734
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
if len(value_series) == 1: # literal case
if len(value_series) == 1: # literal or reduction case
value_series = condition.__class__._from_iterable(
pa.repeat(pa.scalar(value_series[0]), len(condition)),
name="literal",
Expand Down
5 changes: 3 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,9 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
condition = cast("dx.Series", condition)

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
if getattr(self._then_value, "_function_name", None) == "lit":
then_expr = parse_into_expr(self._then_value, namespace=plx)
value_series = then_expr(df)[0]
if getattr(then_expr, "_returns_scalar", False): # literal or reduction case
_df = condition.to_frame("a")
_df["tmp"] = value_series[0]
value_series = _df["tmp"]
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
if len(value_series) == 1: # literal case
if len(value_series) == 1: # literal or reduction case
value_series = condition.__class__._from_iterable(
[value_series[0]] * len(condition),
name="literal",
Expand Down

0 comments on commit 129d734

Please sign in to comment.