Skip to content

Commit

Permalink
fixup q3
Browse files: browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 20, 2024
1 parent 85bfc03 commit 67698aa
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 31 deletions.
3 changes: 3 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def select(
def rename(self, mapping: dict[str, str]) -> Self:
return self._from_dataframe(self._dataframe.rename(mapping))

def head(self, n: int) -> Self:
return self._from_dataframe(self._dataframe.head(n))

def drop(self, *columns: str | Iterable[str]) -> Self:
return self._from_dataframe(self._dataframe.drop(*columns))

Expand Down
8 changes: 8 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,11 @@ def test_expr_na(df_raw: Any) -> None:
)
expected = {"a": [2], "b": [6], "z": [9]}
compare_dicts(result_nna, expected)


@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy])
def test_head(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
result = nw.to_native(df.head(2))
expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]}
compare_dicts(result, expected)
33 changes: 16 additions & 17 deletions tpch/q2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import polars
import pandas as pd

from narwhals import translate_frame, get_namespace, to_native
import narwhals as nw

polars.Config.set_tbl_cols(10)
pd.set_option("display.max_columns", 10)
Expand All @@ -21,21 +21,20 @@ def q2(
var_2 = "BRASS"
var_3 = "EUROPE"

region_ds = translate_frame(region_ds_raw, is_lazy=True)
nation_ds = translate_frame(nation_ds_raw, is_lazy=True)
supplier_ds = translate_frame(supplier_ds_raw, is_lazy=True)
part_ds = translate_frame(part_ds_raw, is_lazy=True)
part_supp_ds = translate_frame(part_supp_ds_raw, is_lazy=True)
pl = get_namespace(region_ds)
region_ds = nw.LazyFrame(region_ds_raw)
nation_ds = nw.LazyFrame(nation_ds_raw)
supplier_ds = nw.LazyFrame(supplier_ds_raw)
part_ds = nw.LazyFrame(part_ds_raw)
part_supp_ds = nw.LazyFrame(part_supp_ds_raw)

result_q2 = (
part_ds.join(part_supp_ds, left_on="p_partkey", right_on="ps_partkey")
.join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey")
.join(nation_ds, left_on="s_nationkey", right_on="n_nationkey")
.join(region_ds, left_on="n_regionkey", right_on="r_regionkey")
.filter(pl.col("p_size") == var_1)
.filter(pl.col("p_type").str.ends_with(var_2))
.filter(pl.col("r_name") == var_3)
.filter(nw.col("p_size") == var_1)
.filter(nw.col("p_type").str.ends_with(var_2))
.filter(nw.col("r_name") == var_3)
).cache()

final_cols = [
Expand All @@ -51,7 +50,7 @@ def q2(

q_final = (
result_q2.group_by("p_partkey")
.agg(pl.min("ps_supplycost").alias("ps_supplycost"))
.agg(nw.min("ps_supplycost").alias("ps_supplycost"))
.join(
result_q2,
left_on=["p_partkey", "ps_supplycost"],
Expand All @@ -65,14 +64,14 @@ def q2(
.head(100)
)

return to_native(q_final.collect())
return nw.to_native(q_final.collect())


region_ds = polars.scan_parquet("../tpch-data/region.parquet")
ration_ds = polars.scan_parquet("../tpch-data/nation.parquet")
supplier_ds = polars.scan_parquet("../tpch-data/supplier.parquet")
part_ds = polars.scan_parquet("../tpch-data/part.parquet")
part_supp_ds = polars.scan_parquet("../tpch-data/partsupp.parquet")
region_ds = polars.scan_parquet("../tpch-data/s1/region.parquet")
ration_ds = polars.scan_parquet("../tpch-data/s1/nation.parquet")
supplier_ds = polars.scan_parquet("../tpch-data/s1/supplier.parquet")
part_ds = polars.scan_parquet("../tpch-data/s1/part.parquet")
part_supp_ds = polars.scan_parquet("../tpch-data/s1/partsupp.parquet")
print(
q2(
region_ds.collect().to_pandas(),
Expand Down
28 changes: 14 additions & 14 deletions tpch/q3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import polars
import pandas as pd

from narwhals import translate_frame
import narwhals as nw
import polars

polars.Config.set_tbl_cols(10)
Expand All @@ -21,24 +21,24 @@ def q3(
var_1 = var_2 = datetime(1995, 3, 15)
var_3 = "BUILDING"

customer_ds, pl = translate_frame(customer_ds_raw, is_lazy=True)
line_item_ds, _ = translate_frame(line_item_ds_raw, is_lazy=True)
orders_ds, _ = translate_frame(orders_ds_raw, is_lazy=True)
customer_ds = nw.LazyFrame(customer_ds_raw)
line_item_ds = nw.LazyFrame(line_item_ds_raw)
orders_ds = nw.LazyFrame(orders_ds_raw)

q_final = (
customer_ds.filter(pl.col("c_mktsegment") == var_3)
customer_ds.filter(nw.col("c_mktsegment") == var_3)
.join(orders_ds, left_on="c_custkey", right_on="o_custkey")
.join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey")
.filter(pl.col("o_orderdate") < var_2)
.filter(pl.col("l_shipdate") > var_1)
.filter(nw.col("o_orderdate") < var_2)
.filter(nw.col("l_shipdate") > var_1)
.with_columns(
(pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue")
(nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue")
)
.group_by(["o_orderkey", "o_orderdate", "o_shippriority"])
.agg([pl.sum("revenue")])
.agg([nw.sum("revenue")])
.select(
[
pl.col("o_orderkey").alias("l_orderkey"),
nw.col("o_orderkey").alias("l_orderkey"),
"revenue",
"o_orderdate",
"o_shippriority",
Expand All @@ -48,12 +48,12 @@ def q3(
.head(10)
)

return q_final.collect().to_native()
return nw.to_native(q_final.collect())


customer_ds = polars.scan_parquet("../tpch-data/customer.parquet")
lineitem_ds = polars.scan_parquet("../tpch-data/lineitem.parquet")
orders_ds = polars.scan_parquet("../tpch-data/orders.parquet")
customer_ds = polars.scan_parquet("../tpch-data/s1/customer.parquet")
lineitem_ds = polars.scan_parquet("../tpch-data/s1/lineitem.parquet")
orders_ds = polars.scan_parquet("../tpch-data/s1/orders.parquet")
print(
q3(
customer_ds.collect().to_pandas(),
Expand Down

0 comments on commit 67698aa

Please sign in to comment.