From e70f70018664d5cc326680968804ef678ad87c58 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 21 Mar 2024 12:44:42 +0000 Subject: [PATCH] remove packaging --- narwhals/pandas_like/group_by.py | 5 ++--- narwhals/utils.py | 9 +++++++++ pyproject.toml | 3 --- tpch/q5.py | 24 ++++++++++++------------ 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/narwhals/pandas_like/group_by.py b/narwhals/pandas_like/group_by.py index 9253683e2..a28649b7a 100644 --- a/narwhals/pandas_like/group_by.py +++ b/narwhals/pandas_like/group_by.py @@ -9,14 +9,13 @@ from typing import Callable from typing import Iterable -from packaging.version import parse - from narwhals.pandas_like.utils import dataframe_from_dict from narwhals.pandas_like.utils import evaluate_simple_aggregation from narwhals.pandas_like.utils import horizontal_concat from narwhals.pandas_like.utils import is_simple_aggregation from narwhals.pandas_like.utils import item from narwhals.pandas_like.utils import parse_into_exprs +from narwhals.utils import parse_version from narwhals.utils import remove_prefix if TYPE_CHECKING: @@ -144,7 +143,7 @@ def func(df: Any) -> Any: UserWarning, stacklevel=2, ) - if parse(pd.__version__) < parse("2.2.0"): + if parse_version(pd.__version__) < parse_version("2.2.0"): result_complex = grouped.apply(func) else: result_complex = grouped.apply(func, include_groups=False) diff --git a/narwhals/utils.py b/narwhals/utils.py index 911559ad7..906498c5c 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -1,7 +1,9 @@ from __future__ import annotations +import re from typing import Any from typing import Iterable +from typing import Sequence def remove_prefix(text: str, prefix: str) -> str: @@ -50,3 +52,10 @@ def flatten_bool(*args: bool | Iterable[bool]) -> list[bool]: raise TypeError(msg) out.append(item) return out + + +def parse_version(version: Sequence[str | int]) -> tuple[int, ...]: + """Simple version parser; split into a tuple of ints for comparison.""" + if isinstance(version, str): + version = version.split(".") + return tuple(int(re.sub(r"\D", "", str(v))) for v in version) diff --git a/pyproject.toml b/pyproject.toml index a35ccf1e8..150789b99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,9 +16,6 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = [ - "packaging; python_version < '3.9'", -] [project.urls] "Homepage" = "https://github.com/MarcoGorelli/narwhals" diff --git a/tpch/q5.py b/tpch/q5.py index ca75b3a20..cb8231484 100644 --- a/tpch/q5.py +++ b/tpch/q5.py @@ -4,7 +4,7 @@ import polars -from narwhals import translate_frame +import narwhals as nw def q5( @@ -19,12 +19,12 @@ def q5( var_2 = datetime(1994, 1, 1) var_3 = datetime(1995, 1, 1) - region_ds, pl = translate_frame(region_ds_raw, is_lazy=True) - nation_ds, _ = translate_frame(nation_ds_raw, is_lazy=True) - customer_ds, _ = translate_frame(customer_ds_raw, is_lazy=True) - line_item_ds, _ = translate_frame(lineitem_ds_raw, is_lazy=True) - orders_ds, _ = translate_frame(orders_ds_raw, is_lazy=True) - supplier_ds, _ = translate_frame(supplier_ds_raw, is_lazy=True) + region_ds = nw.LazyFrame(region_ds_raw) + nation_ds = nw.LazyFrame(nation_ds_raw) + customer_ds = nw.LazyFrame(customer_ds_raw) + line_item_ds = nw.LazyFrame(lineitem_ds_raw) + orders_ds = nw.LazyFrame(orders_ds_raw) + supplier_ds = nw.LazyFrame(supplier_ds_raw) result = ( region_ds.join(nation_ds, left_on="r_regionkey", right_on="n_regionkey") @@ -36,17 +36,17 @@ def q5( left_on=["l_suppkey", "n_nationkey"], right_on=["s_suppkey", "s_nationkey"], ) - .filter(pl.col("r_name") == var_1) - .filter(pl.col("o_orderdate").is_between(var_2, var_3, closed="left")) + .filter(nw.col("r_name") == var_1) + .filter(nw.col("o_orderdate").is_between(var_2, var_3, closed="left")) .with_columns( - (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue") + (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue") ) .group_by("n_name") - .agg([pl.sum("revenue")]) + .agg([nw.sum("revenue")]) .sort(by="revenue", descending=True) ) - return result.collect().to_native() + return nw.to_native(result.collect()) region_ds = polars.scan_parquet("../tpch-data/region.parquet")