Skip to content

Commit

Permalink
remove packaging
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 21, 2024
1 parent a21638c commit e70f700
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
5 changes: 2 additions & 3 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
from typing import Callable
from typing import Iterable

from packaging.version import parse

from narwhals.pandas_like.utils import dataframe_from_dict
from narwhals.pandas_like.utils import evaluate_simple_aggregation
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import is_simple_aggregation
from narwhals.pandas_like.utils import item
from narwhals.pandas_like.utils import parse_into_exprs
from narwhals.utils import parse_version
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
Expand Down Expand Up @@ -144,7 +143,7 @@ def func(df: Any) -> Any:
UserWarning,
stacklevel=2,
)
if parse(pd.__version__) < parse("2.2.0"):
if parse_version(pd.__version__) < parse_version("2.2.0"):
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
Expand Down
9 changes: 9 additions & 0 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import re
from typing import Any
from typing import Iterable
from typing import Sequence


def remove_prefix(text: str, prefix: str) -> str:
Expand Down Expand Up @@ -50,3 +52,10 @@ def flatten_bool(*args: bool | Iterable[bool]) -> list[bool]:
raise TypeError(msg)
out.append(item)
return out


def parse_version(version: Sequence[str | int]) -> tuple[int, ...]:
"""Simple version parser; split into a tuple of ints for comparison."""
if isinstance(version, str):
version = version.split(".")
return tuple(int(re.sub(r"\D", "", str(v))) for v in version)
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"packaging; python_version < '3.9'",
]

[project.urls]
"Homepage" = "https://github.com/MarcoGorelli/narwhals"
Expand Down
24 changes: 12 additions & 12 deletions tpch/q5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import polars

from narwhals import translate_frame
import narwhals as nw


def q5(
Expand All @@ -19,12 +19,12 @@ def q5(
var_2 = datetime(1994, 1, 1)
var_3 = datetime(1995, 1, 1)

region_ds, pl = translate_frame(region_ds_raw, is_lazy=True)
nation_ds, _ = translate_frame(nation_ds_raw, is_lazy=True)
customer_ds, _ = translate_frame(customer_ds_raw, is_lazy=True)
line_item_ds, _ = translate_frame(lineitem_ds_raw, is_lazy=True)
orders_ds, _ = translate_frame(orders_ds_raw, is_lazy=True)
supplier_ds, _ = translate_frame(supplier_ds_raw, is_lazy=True)
region_ds = nw.LazyFrame(region_ds_raw)
nation_ds = nw.LazyFrame(nation_ds_raw)
customer_ds = nw.LazyFrame(customer_ds_raw)
line_item_ds = nw.LazyFrame(lineitem_ds_raw)
orders_ds = nw.LazyFrame(orders_ds_raw)
supplier_ds = nw.LazyFrame(supplier_ds_raw)

result = (
region_ds.join(nation_ds, left_on="r_regionkey", right_on="n_regionkey")
Expand All @@ -36,17 +36,17 @@ def q5(
left_on=["l_suppkey", "n_nationkey"],
right_on=["s_suppkey", "s_nationkey"],
)
.filter(pl.col("r_name") == var_1)
.filter(pl.col("o_orderdate").is_between(var_2, var_3, closed="left"))
.filter(nw.col("r_name") == var_1)
.filter(nw.col("o_orderdate").is_between(var_2, var_3, closed="left"))
.with_columns(
(pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue")
(nw.col("l_extendedprice") * (1 - nw.col("l_discount"))).alias("revenue")
)
.group_by("n_name")
.agg([pl.sum("revenue")])
.agg([nw.sum("revenue")])
.sort(by="revenue", descending=True)
)

return result.collect().to_native()
return nw.to_native(result.collect())


region_ds = polars.scan_parquet("../tpch-data/region.parquet")
Expand Down

0 comments on commit e70f700

Please sign in to comment.