Skip to content

Commit

Permalink
chore: use __narwhals_namespace__ more (#1658)
Browse files: browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Dec 24, 2024
1 parent a92cc22 commit ccf30e2
Show file tree
Hide file tree
Showing 9 changed files with 18 additions and 36 deletions.
2 changes: 1 addition & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def filter(self: Self, *predicates: IntoArrowExpr, **constraints: Any) -> Self:
predicates, (plx.col(name) == v for name, v in constraints.items())
)
)
# Safety: all_horizontal's expression only returns a single column.
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]._native_series
return self._from_native_frame(self._native_frame.filter(mask))

Expand Down
4 changes: 1 addition & 3 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,9 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.namespace import ArrowNamespace
from narwhals._expression_parsing import parse_into_expr

plx = ArrowNamespace(backend_version=self._backend_version, version=self._version)

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
Expand Down
6 changes: 2 additions & 4 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,11 @@ def filter(self, *predicates: DaskExpr, **constraints: Any) -> Self:
)
raise NotImplementedError(msg)

from narwhals._dask.namespace import DaskNamespace

plx = DaskNamespace(backend_version=self._backend_version, version=self._version)
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(
*chain(predicates, (plx.col(name) == v for name, v in constraints.items()))
)
# Safety: all_horizontal's expression only returns a single column.
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]
return self._from_native_frame(self._native_frame.loc[mask])

Expand Down
4 changes: 1 addition & 3 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,9 @@ def __init__(
self._version = version

def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]:
from narwhals._dask.namespace import DaskNamespace
from narwhals._expression_parsing import parse_into_expr

plx = DaskNamespace(backend_version=self._backend_version, version=self._version)

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
condition = cast("dask_expr.Series", condition)
try:
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def filter(self, *predicates: IntoPandasLikeExpr, **constraints: Any) -> Self:
predicates, (plx.col(name) == v for name, v in constraints.items())
)
)
# Safety: all_horizontal's expression only returns a single column.
# `[0]` is safe as all_horizontal's expression only returns a single column
mask = expr._call(self)[0]
_mask = validate_dataframe_comparand(self._native_frame.index, mask)
return self._from_native_frame(self._native_frame.loc[_mask])
Expand Down Expand Up @@ -1006,7 +1006,6 @@ def explode(self: Self, columns: str | Sequence[str], *more_columns: str) -> Sel
]

plx = self.__native_namespace__()

return self._from_native_frame(
plx.concat([exploded_frame, *exploded_series], axis=1)[original_columns]
)
8 changes: 1 addition & 7 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,15 +503,9 @@ def __init__(

def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
from narwhals._expression_parsing import parse_into_expr
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.utils import broadcast_align_and_extract_native

plx = PandasLikeNamespace(
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
)

plx = df.__narwhals_namespace__()
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
Expand Down
8 changes: 2 additions & 6 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,16 @@ def select(
return self._from_native_frame(self._native_frame.select(*new_columns_list))

def filter(self, *predicates: SparkLikeExpr) -> Self:
from narwhals._spark_like.namespace import SparkLikeNamespace

if (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
msg = "`LazyFrame.filter` is not supported for PySpark backend with boolean masks."
raise NotImplementedError(msg)
plx = SparkLikeNamespace(
backend_version=self._backend_version, version=self._version
)
plx = self.__narwhals_namespace__()
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
# `[0]` is safe as all_horizontal's expression only returns a single column
condition = expr._call(self)[0]
spark_df = self._native_frame.where(condition)
return self._from_native_frame(spark_df)
Expand Down
9 changes: 1 addition & 8 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_pandas_fixed_offset_1302() -> None:
def test_huge_int() -> None:
df = pl.DataFrame({"a": [1, 2, 3]})
if POLARS_VERSION >= (1, 18): # pragma: no cover
result = nw.from_native(df).schema
result = nw.from_native(df.select(pl.col("a").cast(pl.Int128))).schema
assert result["a"] == nw.Int128
else: # pragma: no cover
# Int128 was not available yet
Expand All @@ -221,13 +221,6 @@ def test_huge_int() -> None:
result = nw.from_native(rel).schema
assert result["a"] == nw.UInt128

if POLARS_VERSION >= (1, 18): # pragma: no cover
result = nw.from_native(df).schema
assert result["a"] == nw.UInt128
else: # pragma: no cover
# UInt128 was not available yet
pass

# TODO(unassigned): once other libraries support Int128/UInt128,
# add tests for them too

Expand Down
10 changes: 8 additions & 2 deletions tests/frame/concat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,15 @@ def test_concat_vertical(constructor: Constructor) -> None:
with pytest.raises(ValueError, match="No items"):
nw.concat([], how="vertical")

with pytest.raises((Exception, TypeError), match="unable to vstack"):
with pytest.raises(
(Exception, TypeError),
match="unable to vstack|inputs should all have the same schema",
):
nw.concat([df_left, df_right.rename({"d": "i"})], how="vertical").collect()
with pytest.raises((Exception, TypeError), match="unable to vstack|unable to append"):
with pytest.raises(
(Exception, TypeError),
match="unable to vstack|unable to append|inputs should all have the same schema",
):
nw.concat([df_left, df_left.select("d")], how="vertical").collect()


Expand Down

0 comments on commit ccf30e2

Please sign in to comment.