From 67698aafff90aa96621ab15337a4360065ebd392 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 20 Mar 2024 15:01:55 +0000 Subject: [PATCH] fixup q3 --- narwhals/dataframe.py | 3 +++ tests/test_common.py | 8 ++++++++ tpch/q2.py | 33 ++++++++++++++++----------------- tpch/q3.py | 28 ++++++++++++++-------------- 4 files changed, 41 insertions(+), 31 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 87ec52ef5..45bd37e05 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -104,6 +104,9 @@ def select( def rename(self, mapping: dict[str, str]) -> Self: return self._from_dataframe(self._dataframe.rename(mapping)) + def head(self, n: int) -> Self: + return self._from_dataframe(self._dataframe.head(n)) + def drop(self, *columns: str | Iterable[str]) -> Self: return self._from_dataframe(self._dataframe.drop(*columns)) diff --git a/tests/test_common.py b/tests/test_common.py index 040cb6c49..011afaf11 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -341,3 +341,11 @@ def test_expr_na(df_raw: Any) -> None: ) expected = {"a": [2], "b": [6], "z": [9]} compare_dicts(result_nna, expected) + + +@pytest.mark.parametrize("df_raw", [df_pandas, df_lazy]) +def test_head(df_raw: Any) -> None: + df = nw.LazyFrame(df_raw) + result = nw.to_native(df.head(2)) + expected = {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]} + compare_dicts(result, expected) diff --git a/tpch/q2.py b/tpch/q2.py index ffa0a9e8b..933f96391 100644 --- a/tpch/q2.py +++ b/tpch/q2.py @@ -4,7 +4,7 @@ import polars import pandas as pd -from narwhals import translate_frame, get_namespace, to_native +import narwhals as nw polars.Config.set_tbl_cols(10) pd.set_option("display.max_columns", 10) @@ -21,21 +21,20 @@ def q2( var_2 = "BRASS" var_3 = "EUROPE" - region_ds = translate_frame(region_ds_raw, is_lazy=True) - nation_ds = translate_frame(nation_ds_raw, is_lazy=True) - supplier_ds = translate_frame(supplier_ds_raw, is_lazy=True) - part_ds = translate_frame(part_ds_raw, is_lazy=True) - part_supp_ds = translate_frame(part_supp_ds_raw, is_lazy=True) - pl = get_namespace(region_ds) + region_ds = nw.LazyFrame(region_ds_raw) + nation_ds = nw.LazyFrame(nation_ds_raw) + supplier_ds = nw.LazyFrame(supplier_ds_raw) + part_ds = nw.LazyFrame(part_ds_raw) + part_supp_ds = nw.LazyFrame(part_supp_ds_raw) result_q2 = ( part_ds.join(part_supp_ds, left_on="p_partkey", right_on="ps_partkey") .join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey") .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") .join(region_ds, left_on="n_regionkey", right_on="r_regionkey") - .filter(pl.col("p_size") == var_1) - .filter(pl.col("p_type").str.ends_with(var_2)) - .filter(pl.col("r_name") == var_3) + .filter(nw.col("p_size") == var_1) + .filter(nw.col("p_type").str.ends_with(var_2)) + .filter(nw.col("r_name") == var_3) ).cache() final_cols = [ @@ -51,7 +50,7 @@ def q2( q_final = ( result_q2.group_by("p_partkey") - .agg(pl.min("ps_supplycost").alias("ps_supplycost")) + .agg(nw.min("ps_supplycost").alias("ps_supplycost")) .join( result_q2, left_on=["p_partkey", "ps_supplycost"], @@ -65,14 +64,14 @@ def q2( .head(100) ) - return to_native(q_final.collect()) + return nw.to_native(q_final.collect()) -region_ds = polars.scan_parquet("../tpch-data/region.parquet") -ration_ds = polars.scan_parquet("../tpch-data/nation.parquet") -supplier_ds = polars.scan_parquet("../tpch-data/supplier.parquet") -part_ds = polars.scan_parquet("../tpch-data/part.parquet") -part_supp_ds = polars.scan_parquet("../tpch-data/partsupp.parquet") +region_ds = polars.scan_parquet("../tpch-data/s1/region.parquet") +ration_ds = polars.scan_parquet("../tpch-data/s1/nation.parquet") +supplier_ds = polars.scan_parquet("../tpch-data/s1/supplier.parquet") +part_ds = polars.scan_parquet("../tpch-data/s1/part.parquet") +part_supp_ds = polars.scan_parquet("../tpch-data/s1/partsupp.parquet") print( q2( region_ds.collect().to_pandas(), diff --git a/tpch/q3.py b/tpch/q3.py index 689ba55ab..9f2591f98 100644 --- a/tpch/q3.py +++ b/tpch/q3.py @@ -6,7 +6,7 @@ import polars import pandas as pd -from narwhals import translate_frame +import narwhals as nw import polars polars.Config.set_tbl_cols(10) @@ -21,24 +21,24 @@ def q3( var_1 = var_2 = datetime(1995, 3, 15) var_3 = "BUILDING" - customer_ds, pl = translate_frame(customer_ds_raw, is_lazy=True) - line_item_ds, _ = translate_frame(line_item_ds_raw, is_lazy=True) - orders_ds, _ = translate_frame(orders_ds_raw, is_lazy=True) + customer_ds = nw.LazyFrame(customer_ds_raw) + line_item_ds = nw.LazyFrame(line_item_ds_raw) + orders_ds = nw.LazyFrame(orders_ds_raw) q_final = ( - customer_ds.filter(pl.col("c_mktsegment") == var_3) + customer_ds.filter(nw.col("c_mktsegment") == var_3) .join(orders_ds, left_on="c_custkey", right_on="o_custkey") .join(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") - .filter(pl.col("o_orderdate") < var_2) - .filter(pl.col("l_shipdate") > var_1) + .filter(nw.col("o_orderdate") < var_2) + .filter(nw.col("l_shipdate") > var_1) .with_columns( - (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue") + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") ) .group_by(["o_orderkey", "o_orderdate", "o_shippriority"]) - .agg([pl.sum("revenue")]) + .agg([nw.sum("revenue")]) .select( [ - pl.col("o_orderkey").alias("l_orderkey"), + nw.col("o_orderkey").alias("l_orderkey"), "revenue", "o_orderdate", "o_shippriority", @@ -48,12 +48,12 @@ def q3( .head(10) ) - return q_final.collect().to_native() + return nw.to_native(q_final.collect()) -customer_ds = polars.scan_parquet("../tpch-data/customer.parquet") -lineitem_ds = polars.scan_parquet("../tpch-data/lineitem.parquet") -orders_ds = polars.scan_parquet("../tpch-data/orders.parquet") +customer_ds = polars.scan_parquet("../tpch-data/s1/customer.parquet") +lineitem_ds = polars.scan_parquet("../tpch-data/s1/lineitem.parquet") +orders_ds = polars.scan_parquet("../tpch-data/s1/orders.parquet") print( q3( customer_ds.collect().to_pandas(),