diff --git a/narwhals/pandas_like/group_by.py b/narwhals/pandas_like/group_by.py index a28649b7a..05d04c7af 100644 --- a/narwhals/pandas_like/group_by.py +++ b/narwhals/pandas_like/group_by.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import os import warnings from copy import copy from typing import TYPE_CHECKING @@ -9,9 +8,6 @@ from typing import Callable from typing import Iterable -from narwhals.pandas_like.utils import dataframe_from_dict -from narwhals.pandas_like.utils import evaluate_simple_aggregation -from narwhals.pandas_like.utils import horizontal_concat from narwhals.pandas_like.utils import is_simple_aggregation from narwhals.pandas_like.utils import item from narwhals.pandas_like.utils import parse_into_exprs @@ -43,7 +39,7 @@ def agg( grouped = df.groupby( list(self._keys), sort=False, - as_index=False, + as_index=True, ) implementation: str = self._df._implementation output_names: list[str] = copy(self._keys) @@ -57,23 +53,13 @@ def agg( raise ValueError(msg) output_names.extend(expr._output_names) - if implementation in ("pandas", "modin") and not os.environ.get( - "NARWHALS_FORCE_GENERIC" - ): - return agg_pandas( - grouped, - exprs, - self._keys, - output_names, - self._from_dataframe, - ) - return agg_generic( + return agg_pandas( grouped, exprs, self._keys, output_names, - implementation, self._from_dataframe, + implementation, ) def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame: @@ -85,12 +71,13 @@ def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame: ) -def agg_pandas( +def agg_pandas( # noqa: PLR0913,PLR0915 grouped: Any, exprs: list[PandasExpr], keys: list[str], output_names: list[str], from_dataframe: Callable[[Any], PandasDataFrame], + implementation: Any, ) -> PandasDataFrame: """ This should be the fastpath, but cuDF is too far behind to use it. @@ -100,6 +87,8 @@ def agg_pandas( """ import pandas as pd + from narwhals.pandas_like.namespace import PandasNamespace + simple_aggs = [] complex_aggs = [] for expr in exprs: @@ -113,8 +102,9 @@ def agg_pandas( # e.g. agg(pl.len()) assert expr._output_names is not None for output_name in expr._output_names: - simple_aggregations[output_name] = pd.NamedAgg( - column=keys[0], aggfunc=expr._function_name.replace("len", "size") + simple_aggregations[output_name] = ( + keys[0], + expr._function_name.replace("len", "size"), ) continue @@ -122,8 +112,21 @@ def agg_pandas( assert expr._output_names is not None for root_name, output_name in zip(expr._root_names, expr._output_names): name = remove_prefix(expr._function_name, "col->") - simple_aggregations[output_name] = pd.NamedAgg(column=root_name, aggfunc=name) - result_simple = grouped.agg(**simple_aggregations) if simple_aggregations else None + simple_aggregations[output_name] = (root_name, name) + + if simple_aggregations: + aggs = collections.defaultdict(list) + name_mapping = {} + for output_name, named_agg in simple_aggregations.items(): + aggs[named_agg[0]].append(named_agg[1]) + name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name + result_simple = grouped.agg(aggs) + result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns] + result_simple = result_simple.rename(columns=name_mapping).reset_index() + else: + result_simple = None + + plx = PandasNamespace(implementation=implementation) def func(df: Any) -> Any: out_group = [] @@ -133,7 +136,7 @@ def func(df: Any) -> Any: for result_keys in results_keys: out_group.append(item(result_keys._series)) out_names.append(result_keys.name) - return pd.Series(out_group, index=out_names) + return plx.make_native_series(name="", data=out_group, index=out_names) if complex_aggs: warnings.warn( @@ -143,53 +146,26 @@ def func(df: Any) -> Any: UserWarning, stacklevel=2, ) - if parse_version(pd.__version__) < parse_version("2.2.0"): - result_complex = grouped.apply(func) + if implementation == "pandas": + import pandas as pd + + if parse_version(pd.__version__) < parse_version("2.2.0"): + result_complex = grouped.apply(func) + else: + result_complex = grouped.apply(func, include_groups=False) else: - result_complex = grouped.apply(func, include_groups=False) + result_complex = grouped.apply(func) if result_simple is not None and not complex_aggs: result = result_simple elif result_simple is not None and complex_aggs: result = pd.concat( - [result_simple, result_complex.drop(columns=keys)], + [result_simple, result_complex.reset_index(drop=True)], axis=1, copy=False, ) elif complex_aggs: - result = result_complex + result = result_complex.reset_index() else: raise AssertionError("At least one aggregation should have been passed") return from_dataframe(result.loc[:, output_names]) - - -def agg_generic( # noqa: PLR0913 - grouped: Any, - exprs: list[PandasExpr], - group_by_keys: list[str], - output_names: list[str], - implementation: str, - from_dataframe: Callable[[Any], PandasDataFrame], -) -> PandasDataFrame: - dfs: list[Any] = [] - to_remove: list[int] = [] - for i, expr in enumerate(exprs): - if is_simple_aggregation(expr): - dfs.append(evaluate_simple_aggregation(expr, grouped, group_by_keys)) - to_remove.append(i) - exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove] - - out: dict[str, list[Any]] = collections.defaultdict(list) - for keys, df_keys in grouped: - for key, name in zip(keys, group_by_keys): - out[name].append(key) - for expr in exprs: - results_keys = expr._call(from_dataframe(df_keys)) - for result_keys in results_keys: - out[result_keys.name].append(result_keys.item()) - - results_keys = dataframe_from_dict(out, implementation=implementation) - results_df = horizontal_concat( - [results_keys, *dfs], implementation=implementation - ).loc[:, output_names] - return from_dataframe(results_df) diff --git a/narwhals/pandas_like/namespace.py b/narwhals/pandas_like/namespace.py index bc418f7bd..594592769 100644 --- a/narwhals/pandas_like/namespace.py +++ b/narwhals/pandas_like/namespace.py @@ -33,15 +33,19 @@ class PandasNamespace: Boolean = dtypes.Boolean String = dtypes.String - def Series(self, name: str, data: list[Any]) -> PandasSeries: # noqa: N802 - from narwhals.pandas_like.series import PandasSeries - + def make_native_series(self, name: str, data: list[Any], index: Any) -> Any: if self._implementation == "pandas": import pandas as pd - return PandasSeries( - pd.Series(name=name, data=data), implementation=self._implementation - ) + return pd.Series(name=name, data=data, index=index) + if self._implementation == "modin": + import modin.pandas as mpd + + return mpd.Series(name=name, data=data, index=index) + if self._implementation == "cudf": + import cudf + + return cudf.Series(name=name, data=data, index=index) raise NotImplementedError # --- not in spec --- diff --git a/tests/tpch_q1_test.py b/tests/tpch_q1_test.py index a5d421c67..c94126eab 100644 --- a/tests/tpch_q1_test.py +++ b/tests/tpch_q1_test.py @@ -76,6 +76,74 @@ def test_q1(library: str) -> None: compare_dicts(result, expected) +@pytest.mark.parametrize( + "library", + ["pandas", "polars"], +) +@pytest.mark.filterwarnings( + "ignore:.*Passing a BlockManager.*:DeprecationWarning", + "ignore:.*Complex.*:UserWarning", +) +def test_q1_w_generic_funcs(library: str) -> None: + if library == "pandas": + df_raw = pd.read_parquet("tests/data/lineitem.parquet") + df_raw["l_shipdate"] = pd.to_datetime(df_raw["l_shipdate"]) + elif library == "polars": + df_raw = pl.scan_parquet("tests/data/lineitem.parquet") + var_1 = datetime(1998, 9, 2) + df = nw.LazyFrame(df_raw) + query_result = ( + df.filter(nw.col("l_shipdate") <= var_1) + .with_columns( + charge=( + nw.col("l_extendedprice") + * (1.0 - nw.col("l_discount")) + * (1.0 + nw.col("l_tax")) + ), + ) + .group_by(["l_returnflag", "l_linestatus"]) + .agg( + [ + nw.sum("l_quantity").alias("sum_qty"), + nw.sum("l_extendedprice").alias("sum_base_price"), + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) + .sum() + .alias("sum_disc_price"), + nw.col("charge").sum().alias("sum_charge"), + nw.mean("l_quantity").alias("avg_qty"), + nw.mean("l_extendedprice").alias("avg_price"), + nw.mean("l_discount").alias("avg_disc"), + nw.len().alias("count_order"), + ], + ) + .sort(["l_returnflag", "l_linestatus"]) + ) + result = query_result.collect().to_dict(as_series=False) + expected = { + "l_returnflag": ["A", "N", "N", "R"], + "l_linestatus": ["F", "F", "O", "F"], + "sum_qty": [2109.0, 29.0, 3682.0, 1876.0], + "sum_base_price": [3114026.44, 39824.83, 5517101.99, 2947892.16], + "sum_disc_price": [2954950.8082, 39028.3334, 5205468.4852, 2816542.4816999994], + "sum_charge": [ + 3092840.4194289995, + 39808.900068, + 5406966.873419, + 2935797.8313019997, + ], + "avg_qty": [27.75, 29.0, 25.047619047619047, 26.422535211267604], + "avg_price": [ + 40974.032105263155, + 39824.83, + 37531.30605442177, + 41519.607887323946, + ], + "avg_disc": [0.05039473684210526, 0.02, 0.05537414965986395, 0.04507042253521127], + "count_order": [76, 1, 147, 71], + } + compare_dicts(result, expected) + + @mock.patch.dict(os.environ, {"NARWHALS_FORCE_GENERIC": "1"}) @pytest.mark.filterwarnings("ignore:.*Passing a BlockManager.*:DeprecationWarning") def test_q1_w_pandas_agg_generic_path() -> None: diff --git a/tpch/q3.py b/tpch/q3.py index 9f2591f98..684bc8c53 100644 --- a/tpch/q3.py +++ b/tpch/q3.py @@ -21,9 +21,9 @@ def q3( var_1 = var_2 = datetime(1995, 3, 15) var_3 = "BUILDING" - customer_ds = nw.LazyFrame(customer_ds_raw) - line_item_ds = nw.LazyFrame(line_item_ds_raw) - orders_ds = nw.LazyFrame(orders_ds_raw) + customer_ds = nw.from_native(customer_ds_raw) + line_item_ds = nw.from_native(line_item_ds_raw) + orders_ds = nw.from_native(orders_ds_raw) q_final = ( customer_ds.filter(nw.col("c_mktsegment") == var_3) @@ -48,7 +48,7 @@ def q3( .head(10) ) - return nw.to_native(q_final.collect()) + return nw.to_native(q_final) customer_ds = polars.scan_parquet("../tpch-data/s1/customer.parquet") @@ -66,5 +66,5 @@ def q3( customer_ds, lineitem_ds, orders_ds, - ) + ).collect() )