Skip to content

Commit

Permalink
fix: cast keys to list in dask joins, add dask to tpch benchmarks (#957)
Browse files Browse the repository at this point in the history
* replace polars lazy with eager in q19, add dask for some queries that allow it

* fix dask left join for string keys, replace some queries with polars eager

* fix semi-join with str keys

* add unique method to dask, todo: understand dtype warning

* remove unique for dask expr, modify q20 and q22 to use unique for dataframes, add semi-join test for str keys

* change test for left join, convert to lazy polars in tpch

* convert str to list[str] at the beginning of join for dask
  • Loading branch information
raisadz authored Sep 12, 2024
1 parent aed2d51 commit 8d4c658
Show file tree
Hide file tree
Showing 25 changed files with 138 additions and 6 deletions.
4 changes: 4 additions & 0 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def join(
right_on: str | list[str] | None,
suffix: str,
) -> Self:
if isinstance(left_on, str):
left_on = [left_on]
if isinstance(right_on, str):
right_on = [right_on]
if how == "cross":
key_token = generate_unique_token(
n_bytes=8, columns=[*self.columns, *other.columns]
Expand Down
5 changes: 5 additions & 0 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,11 @@ def round(self, decimals: int) -> Self:
returns_scalar=False,
)

def unique(self) -> NoReturn:
    """Reject ``Expr.unique`` on the Dask backend.

    This method never returns: it unconditionally raises, and the error
    message points callers at ``LazyFrame.unique`` as the alternative.

    Raises:
        NotImplementedError: On every call.
    """
    # We can't (yet?) allow methods which modify the index
    msg = "`Expr.unique` is not supported for the Dask backend. Please use `LazyFrame.unique` instead."
    raise NotImplementedError(msg)

def drop_nulls(self) -> NoReturn:
# We can't (yet?) allow methods which modify the index
msg = "`Expr.drop_nulls` is not supported for the Dask backend. Please use `LazyFrame.drop_nulls` instead."
Expand Down
26 changes: 24 additions & 2 deletions tests/frame/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def test_anti_join(
@pytest.mark.parametrize(
("join_key", "filter_expr", "expected"),
[
(
"antananarivo",
(nw.col("bob") > 5),
{"antananarivo": [2], "bob": [6], "zorro": [9]},
),
(
["antananarivo"],
(nw.col("bob") > 5),
Expand Down Expand Up @@ -222,10 +227,14 @@ def test_left_join(constructor: Any) -> None:
"bob": [4.0, 5, 6],
"index": [0.0, 1.0, 2.0],
}
data_right = {"antananarivo": [1.0, 2, 3], "c": [4.0, 5, 7], "index": [0.0, 1.0, 2.0]}
data_right = {
"antananarivo": [1.0, 2, 3],
"co": [4.0, 5, 7],
"index": [0.0, 1.0, 2.0],
}
df_left = nw.from_native(constructor(data_left))
df_right = nw.from_native(constructor(data_right))
result = df_left.join(df_right, left_on="bob", right_on="c", how="left").select( # type: ignore[arg-type]
result = df_left.join(df_right, left_on="bob", right_on="co", how="left").select( # type: ignore[arg-type]
nw.all().fill_null(float("nan"))
)
result = result.sort("index")
Expand All @@ -236,7 +245,20 @@ def test_left_join(constructor: Any) -> None:
"antananarivo_right": [1, 2, float("nan")],
"index": [0, 1, 2],
}
result_on_list = df_left.join(
df_right, # type: ignore[arg-type]
on=["antananarivo", "index"],
how="left",
).select(nw.all().fill_null(float("nan")))
result_on_list = result_on_list.sort("index")
expected_on_list = {
"antananarivo": [1, 2, 3],
"bob": [4, 5, 6],
"index": [0, 1, 2],
"co": [4, 5, 7],
}
compare_dicts(result, expected)
compare_dicts(result_on_list, expected_on_list)


@pytest.mark.filterwarnings("ignore: the default coalesce behavior")
Expand Down
4 changes: 4 additions & 0 deletions tpch/execute/q10.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q10.query(fn(customer), fn(nation), fn(lineitem), fn(orders)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q10.query(fn(customer), fn(nation), fn(lineitem), fn(orders)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q11.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q11.query(fn(nation), fn(partsupp), fn(supplier)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q11.query(fn(nation), fn(partsupp), fn(supplier)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q12.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q12.query(fn(line_item), fn(orders)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q12.query(fn(line_item), fn(orders)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q13.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q13.query(fn(customer), fn(orders)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q13.query(fn(customer), fn(orders)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q14.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q14.query(fn(line_item), fn(part)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q14.query(fn(line_item), fn(part)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q15.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q15.query(fn(lineitem), fn(supplier)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q15.query(fn(lineitem), fn(supplier)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q16.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q16.query(fn(part), fn(partsupp), fn(supplier)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q16.query(fn(part), fn(partsupp), fn(supplier)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q17.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q17.query(fn(lineitem), fn(part)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q17.query(fn(lineitem), fn(part)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q18.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q18.query(fn(customer), fn(lineitem), fn(orders)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q18.query(fn(customer), fn(lineitem), fn(orders)).compute())
3 changes: 3 additions & 0 deletions tpch/execute/q19.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@

fn = IO_FUNCS["pyarrow"]
print(q19.query(fn(lineitem), fn(part)))

fn = IO_FUNCS["dask"]
print(q19.query(fn(lineitem), fn(part)).compute())
3 changes: 3 additions & 0 deletions tpch/execute/q20.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@

fn = IO_FUNCS["pyarrow"]
print(q20.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier)))

fn = IO_FUNCS["dask"]
print(q20.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(supplier)).compute())
3 changes: 3 additions & 0 deletions tpch/execute/q21.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@

fn = IO_FUNCS["pyarrow"]
print(q21.query(fn(lineitem), fn(nation), fn(orders), fn(supplier)))

fn = IO_FUNCS["dask"]
print(q21.query(fn(lineitem), fn(nation), fn(orders), fn(supplier)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q22.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q22.query(fn(customer), fn(orders)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q22.query(fn(customer), fn(orders)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q3.query(fn(customer), fn(lineitem), fn(orders)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q3.query(fn(customer), fn(lineitem), fn(orders)).compute())
4 changes: 4 additions & 0 deletions tpch/execute/q4.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q4.query(fn(line_item), fn(orders)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q4.query(fn(line_item), fn(orders)).compute())
8 changes: 8 additions & 0 deletions tpch/execute/q5.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@
fn(region), fn(nation), fn(customer), fn(line_item), fn(orders), fn(supplier)
)
)

tool = "dask"
fn = IO_FUNCS[tool]
print(
q5.query(
fn(region), fn(nation), fn(customer), fn(line_item), fn(orders), fn(supplier)
).compute()
)
4 changes: 4 additions & 0 deletions tpch/execute/q6.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q6.query(fn(lineitem)))

tool = "dask"
fn = IO_FUNCS[tool]
print(q6.query(fn(lineitem)).compute())
6 changes: 6 additions & 0 deletions tpch/execute/q7.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,9 @@
tool = "pyarrow"
fn = IO_FUNCS[tool]
print(q7.query(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)))

tool = "dask"
fn = IO_FUNCS[tool]
print(
q7.query(fn(nation), fn(customer), fn(lineitem), fn(orders), fn(supplier)).compute()
)
14 changes: 14 additions & 0 deletions tpch/execute/q8.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,17 @@
fn(region),
)
)

tool = "dask"
fn = IO_FUNCS[tool]
print(
q8.query(
fn(part),
fn(supplier),
fn(lineitem),
fn(orders),
fn(customer),
fn(nation),
fn(region),
).compute()
)
8 changes: 8 additions & 0 deletions tpch/execute/q9.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@
print(
q9.query(fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier))
)

tool = "dask"
fn = IO_FUNCS[tool]
print(
q9.query(
fn(part), fn(partsupp), fn(nation), fn(lineitem), fn(orders), fn(supplier)
).compute()
)
6 changes: 4 additions & 2 deletions tpch/queries/q20.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ def query(

return (
part_ds.filter(nw.col("p_name").str.starts_with(var4))
.select(nw.col("p_partkey").unique())
.select("p_partkey")
.unique("p_partkey")
.join(partsupp_ds, left_on="p_partkey", right_on="ps_partkey")
.join(
query1,
left_on=["ps_suppkey", "p_partkey"],
right_on=["l_suppkey", "l_partkey"],
)
.filter(nw.col("ps_availqty") > nw.col("sum_quantity"))
.select(nw.col("ps_suppkey").unique())
.select("ps_suppkey")
.unique("ps_suppkey")
.join(query3, left_on="ps_suppkey", right_on="s_suppkey")
.select("s_name", "s_address")
.sort("s_name")
Expand Down
6 changes: 4 additions & 2 deletions tpch/queries/q22.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ def query(customer_ds: FrameT, orders_ds: FrameT) -> FrameT:
nw.col("c_acctbal").mean().alias("avg_acctbal")
)

q3 = orders_ds.select(nw.col("o_custkey").unique()).with_columns(
nw.col("o_custkey").alias("c_custkey")
q3 = (
orders_ds.select("o_custkey")
.unique("o_custkey")
.with_columns(nw.col("o_custkey").alias("c_custkey"))
)

return (
Expand Down

0 comments on commit 8d4c658

Please sign in to comment.