Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add DuckDB: nw.nth, nw.sum_horizontal, nw.concat_str, group_by with drop_null_keys #1832

Merged
merged 18 commits into from
Jan 20, 2025
Merged
14 changes: 7 additions & 7 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,19 +378,19 @@ def concat_str(
dtypes = import_dtypes_module(self._version)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = (
s._native_series
for _expr in parsed_exprs
for s in _expr.cast(dtypes.String())(df)
)
compliant_series_list = [
s for _expr in parsed_exprs for s in _expr.cast(dtypes.String())(df)
]
null_handling = "skip" if ignore_nulls else "emit_null"
result_series = pc.binary_join_element_wise(
*series, separator, null_handling=null_handling
*(s._native_series for s in compliant_series_list),
separator,
null_handling=null_handling,
)
return [
ArrowSeries(
native_series=result_series,
name="",
name=compliant_series_list[0].name,
backend_version=self._backend_version,
version=self._version,
)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
init_value,
)

return [result]
return [result.rename(null_mask[0].name)]

return DaskExpr(
call=func,
Expand Down
4 changes: 0 additions & 4 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,6 @@ def _from_native_frame(self: Self, df: Any) -> Self:
def group_by(self: Self, *keys: str, drop_null_keys: bool) -> DuckDBGroupBy:
    """Group this frame by ``keys``.

    Null-key handling is delegated to ``DuckDBGroupBy`` (which receives
    ``drop_null_keys`` and drops rows with null keys itself), so no
    pre-filtering happens here.
    """
    from narwhals._duckdb.group_by import DuckDBGroupBy

    return DuckDBGroupBy(
        compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys
    )
Expand Down
26 changes: 26 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,32 @@ def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={},
)

@classmethod
def from_column_indices(
    cls: type[Self],
    *column_indices: int,
    backend_version: tuple[int, ...],
    version: Version,
) -> Self:
    """Build an expression that selects columns by position (``nth``)."""

    def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
        from duckdb import ColumnExpression

        # Resolve positional indices against the frame's column names
        # lazily, at call time, so the expression works for any frame.
        names = df.columns
        return [ColumnExpression(names[idx]) for idx in column_indices]

    return cls(
        func,
        depth=0,
        function_name="nth",
        root_names=None,
        output_names=None,
        returns_scalar=False,
        backend_version=backend_version,
        version=version,
        kwargs={},
    )

def _from_call(
self,
call: Callable[..., duckdb.Expression],
Expand Down
7 changes: 5 additions & 2 deletions narwhals/_duckdb/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def __init__(
keys: list[str],
drop_null_keys: bool, # noqa: FBT001
) -> None:
self._compliant_frame = compliant_frame
if drop_null_keys:
self._compliant_frame = compliant_frame.drop_nulls(subset=None)
else:
self._compliant_frame = compliant_frame
self._keys = keys

def agg(
Expand Down Expand Up @@ -46,7 +49,7 @@ def agg(
try:
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.aggregate(
agg_columns, group_expr=",".join(self._keys)
agg_columns, group_expr=",".join(f'"{key}"' for key in self._keys)
)
)
except ValueError as exc: # pragma: no cover
Expand Down
111 changes: 111 additions & 0 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import cast
Expand Down Expand Up @@ -74,6 +75,83 @@ def concat(
)
return first._from_native_frame(res)

def concat_str(
    self,
    exprs: Iterable[IntoDuckDBExpr],
    *more_exprs: IntoDuckDBExpr,
    separator: str,
    ignore_nulls: bool,
) -> DuckDBExpr:
    """Horizontally concatenate columns into a single string column.

    Parameters
    ----------
    exprs, more_exprs
        Expressions producing the columns to concatenate.
    separator
        String inserted between consecutive values.
    ignore_nulls
        If False, a null in any input makes the whole result null.
        If True, null inputs are replaced by the empty string and
        their separators suppressed.
    """
    parsed_exprs = [
        *parse_into_exprs(*exprs, namespace=self),
        *parse_into_exprs(*more_exprs, namespace=self),
    ]
    from duckdb import CaseExpression
    from duckdb import ConstantExpression
    from duckdb import FunctionExpression

    def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
        cols = [s for _expr in parsed_exprs for s in _expr(df)]
        # Derive the null masks from the already-evaluated columns instead
        # of evaluating every parsed expression a second time.
        null_mask = [s.isnull() for s in cols]
        # Output takes the name of the first input column.
        first_column_name = get_column_name(df, cols[0])

        if not ignore_nulls:
            null_mask_result = reduce(lambda x, y: x | y, null_mask)
            # Interleave casted columns with the separator constant; the
            # last column gets no trailing separator.
            cols_separated = [
                y
                for x in [
                    (col.cast("string"),)
                    if i == len(cols) - 1
                    else (col.cast("string"), ConstantExpression(separator))
                    for i, col in enumerate(cols)
                ]
                for y in x
            ]
            # Null out the whole result when any input is null.
            result = CaseExpression(
                condition=~null_mask_result,
                value=FunctionExpression("concat", *cols_separated),
            )
        else:
            # Replace null values with "" so they vanish from the output.
            init_value, *values = [
                CaseExpression(~nm, col.cast("string")).otherwise(
                    ConstantExpression("")
                )
                for col, nm in zip(cols, null_mask)
            ]
            # Suppress a separator when its neighbouring value is null.
            # NOTE(review): separator i is gated on null_mask[i] (the value
            # *before* it) — confirm trailing-null inputs don't leave a
            # dangling separator.
            separators = (
                CaseExpression(nm, ConstantExpression("")).otherwise(
                    ConstantExpression(separator)
                )
                for nm in null_mask[:-1]
            )
            result = reduce(
                lambda x, y: FunctionExpression("concat", x, y),
                (
                    FunctionExpression("concat", s, v)
                    for s, v in zip(separators, values)
                ),
                init_value,
            )

        return [result.alias(first_column_name)]

    return DuckDBExpr(
        call=func,
        depth=max(x._depth for x in parsed_exprs) + 1,
        function_name="concat_str",
        root_names=combine_root_names(parsed_exprs),
        output_names=reduce_output_names(parsed_exprs),
        returns_scalar=False,
        backend_version=self._backend_version,
        version=self._version,
        kwargs={
            "exprs": exprs,
            "more_exprs": more_exprs,
            "separator": separator,
            "ignore_nulls": ignore_nulls,
        },
    )

def all_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

Expand Down Expand Up @@ -158,6 +236,34 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={"exprs": exprs},
)

def sum_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr:
    """Row-wise sum across columns; null values contribute zero."""
    from duckdb import CoalesceOperator
    from duckdb import ConstantExpression

    parsed_exprs = parse_into_exprs(*exprs, namespace=self)

    def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
        columns = [col for _expr in parsed_exprs for col in _expr(df)]
        # Result carries the first input column's name.
        output_name = get_column_name(df, columns[0])
        total = reduce(
            operator.add,
            (CoalesceOperator(col, ConstantExpression(0)) for col in columns),
        )
        return [total.alias(output_name)]

    return DuckDBExpr(
        call=func,
        depth=max(x._depth for x in parsed_exprs) + 1,
        function_name="sum_horizontal",
        root_names=combine_root_names(parsed_exprs),
        output_names=reduce_output_names(parsed_exprs),
        returns_scalar=False,
        backend_version=self._backend_version,
        version=self._version,
        kwargs={"exprs": exprs},
    )

def when(
self,
*predicates: IntoDuckDBExpr,
Expand All @@ -173,6 +279,11 @@ def col(self, *column_names: str) -> DuckDBExpr:
*column_names, backend_version=self._backend_version, version=self._version
)

def nth(self, *column_indices: int) -> DuckDBExpr:
    """Select columns by (0-indexed) position."""
    return DuckDBExpr.from_column_indices(
        *column_indices,
        backend_version=self._backend_version,
        version=self._version,
    )

def lit(self, value: Any, dtype: DType | None) -> DuckDBExpr:
from duckdb import ConstantExpression

Expand Down
25 changes: 25 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,31 @@ def func(_: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

@classmethod
def from_column_indices(
    cls: type[Self],
    *column_indices: int,
    backend_version: tuple[int, ...],
    version: Version,
) -> Self:
    """Build an expression that selects columns by position (``nth``)."""

    def func(df: SparkLikeLazyFrame) -> list[Column]:
        from pyspark.sql import functions as F  # noqa: N812

        # Resolve indices at call time against the frame's column names.
        names = df.columns
        return [F.col(names[idx]) for idx in column_indices]

    return cls(
        func,
        depth=0,
        function_name="nth",
        root_names=None,
        output_names=None,
        returns_scalar=False,
        backend_version=backend_version,
        version=version,
        kwargs={},
    )

def _from_call(
self,
call: Callable[..., Column],
Expand Down
18 changes: 14 additions & 4 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def col(self, *column_names: str) -> SparkLikeExpr:
*column_names, backend_version=self._backend_version, version=self._version
)

def nth(self, *column_indices: int) -> SparkLikeExpr:
    """Select columns by (0-indexed) position."""
    return SparkLikeExpr.from_column_indices(
        *column_indices,
        backend_version=self._backend_version,
        version=self._version,
    )

def lit(self, value: object, dtype: DType | None) -> SparkLikeExpr:
if dtype is not None:
msg = "todo"
Expand Down Expand Up @@ -293,19 +298,24 @@ def concat_str(
]

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = (s.cast(StringType()) for _expr in parsed_exprs for s in _expr(df))
cols = [s for _expr in parsed_exprs for s in _expr(df)]
cols_casted = [s.cast(StringType()) for s in cols]
null_mask = [F.isnull(s) for _expr in parsed_exprs for s in _expr(df)]
first_column_name = get_column_name(df, cols[0])

if not ignore_nulls:
null_mask_result = reduce(lambda x, y: x | y, null_mask)
result = F.when(
~null_mask_result,
reduce(lambda x, y: F.format_string(f"%s{separator}%s", x, y), cols),
reduce(
lambda x, y: F.format_string(f"%s{separator}%s", x, y),
cols_casted,
),
).otherwise(F.lit(None))
else:
init_value, *values = [
F.when(~nm, col).otherwise(F.lit(""))
for col, nm in zip(cols, null_mask)
for col, nm in zip(cols_casted, null_mask)
]

separators = (
Expand All @@ -318,7 +328,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
init_value,
)

return [result]
return [result.alias(first_column_name)]

return SparkLikeExpr(
call=func,
Expand Down
2 changes: 0 additions & 2 deletions tests/expr_and_series/all_horizontal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def test_allh_nth(
) -> None:
if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
request.applymarker(pytest.mark.xfail)
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
Expand Down
16 changes: 9 additions & 7 deletions tests/expr_and_series/concat_str_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

Expand All @@ -27,7 +28,8 @@ def test_concat_str(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if "polars" in str(constructor) and POLARS_VERSION < (1, 0, 0):
# nth only available after 1.0
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = (
Expand All @@ -49,16 +51,16 @@ def test_concat_str(
assert_equal_data(result, {"full_sentence": expected})
result = (
df.select(
"a",
nw.col("a").alias("a_original"),
nw.concat_str(
nw.col("a") * 2,
nw.nth(0) * 2,
nw.col("b"),
nw.col("c"),
separator=" ",
ignore_nulls=ignore_nulls, # default behavior is False
).alias("full_sentence"),
),
)
.sort("a")
.select("full_sentence")
.sort("a_original")
.select("a")
)
assert_equal_data(result, {"full_sentence": expected})
assert_equal_data(result, {"a": expected})
1 change: 0 additions & 1 deletion tests/expr_and_series/dt/datetime_attributes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def test_to_date(request: pytest.FixtureRequest, constructor: Constructor) -> No
"pandas_nullable_constructor",
"cudf",
"modin_constructor",
"pyspark",
)
):
request.applymarker(pytest.mark.xfail)
Expand Down
2 changes: 0 additions & 2 deletions tests/expr_and_series/nth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def test_nth(
expected: dict[str, list[int]],
request: pytest.FixtureRequest,
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor) and POLARS_VERSION < (1, 0, 0):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
Expand Down
Loading
Loading