Skip to content

Commit

Permalink
patch: fix when-then double lit case (#1810)
Browse files (browse the repository at this point in the history)
  • Loading branch information
FBruzzesi authored Jan 16, 2025
1 parent e4e881b commit 8564711
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 25 deletions.
19 changes: 8 additions & 11 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,19 +435,18 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable(
pa.repeat(pa.scalar(self._then_value), len(condition)),
name="literal",
backend_version=self._backend_version,
version=self._version,
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
)

value_series_native = value_series._native_series
condition_native = condition._native_series
condition_native, value_series_native = broadcast_series(
[condition, value_series]
)

if self._otherwise_value is None:
otherwise_native = pa.repeat(
Expand All @@ -472,9 +471,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
]
else:
otherwise_series = otherwise_expr(df)[0]
condition_native, otherwise_native = broadcast_series(
[condition, otherwise_series]
)
_, otherwise_native = broadcast_series([condition, otherwise_series])
return [
value_series._from_native_series(
pc.if_else(condition_native, value_series_native, otherwise_native)
Expand Down
17 changes: 14 additions & 3 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,24 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
condition = cast("dx.Series", condition)

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
then_expr = parse_into_expr(self._then_value, namespace=plx)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
# `self._then_value` is a scalar and can't be converted to an expression
value_sequence: Sequence[Any] = [self._then_value]
is_scalar = True
else:
is_scalar = then_expr._returns_scalar # type: ignore[attr-defined]
value_sequence = then_expr(df)[0]

if is_scalar:
_df = condition.to_frame("a")
_df["tmp"] = self._then_value
_df["tmp"] = value_sequence[0]
value_series = _df["tmp"]
else:
value_series = value_sequence

value_series = cast("dx.Series", value_series)
validate_comparand(condition, value_series)

Expand Down
26 changes: 15 additions & 11 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,20 +461,17 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]

try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable(
[self._then_value] * len(condition),
name="literal",
index=condition._native_series.index,
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
)
value_series_native, condition_native = broadcast_align_and_extract_native(
value_series, condition

condition_native, value_series_native = broadcast_align_and_extract_native(
condition, value_series
)

if self._otherwise_value is None:
Expand All @@ -494,7 +491,14 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
]
else:
otherwise_series = otherwise_expr(df)[0]
return [value_series.zip_with(condition, otherwise_series)]
_, otherwise_native = broadcast_align_and_extract_native(
condition, otherwise_series
)
return [
value_series._from_native_series(
value_series_native.where(condition_native, otherwise_native)
)
]

def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen:
self._then_value = value
Expand Down
10 changes: 10 additions & 0 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,13 @@ def test_when_then_otherwise_lit_str(
result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z")))
expected = {"b": ["z", "b", "c"]}
assert_equal_data(result, expected)


def test_when_then_otherwise_both_lit(constructor: Constructor) -> None:
    """Check `when/then/otherwise` when both branches are literals.

    Builds two output columns from the same frame, each selecting between
    the literals 42 and -1 depending on whether column "a" exceeds a
    threshold (1 for `x1`, 2 for `x2`), and compares them with the expected
    row-wise values.
    """
    # NOTE(review): `data` is a module-level fixture not shown in this hunk;
    # the expected values imply a three-row column "a" - confirm against the
    # top of the test module.
    frame = nw.from_native(constructor(data))

    def _lit_branches(threshold):
        # Fresh literal expressions are created on every call, mirroring the
        # original test which spelled out each chain in full.
        is_above = nw.col("a") > threshold
        return nw.when(is_above).then(nw.lit(42)).otherwise(nw.lit(-1))

    result = frame.select(x1=_lit_branches(1), x2=_lit_branches(2))

    expected = {"x1": [-1, 42, 42], "x2": [-1, -1, 42]}
    assert_equal_data(result, expected)

0 comments on commit 8564711

Please sign in to comment.