Skip to content

Commit

Permalink
Add tests covering magic methods of Expr objects (#15996)
Browse files Browse the repository at this point in the history
repr is not stable for now because the pylibcudf datatype repr is not stable (it includes the address).

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #15996
  • Loading branch information
wence- authored Jun 12, 2024
1 parent c0c2ad3 commit 0891c5d
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 5 deletions.
12 changes: 7 additions & 5 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ def is_equal(self, other: Any) -> bool:
True if the two expressions are equal, false otherwise.
"""
if type(self) is not type(other):
return False
return False # pragma: no cover; __eq__ trips first
return self._ctor_arguments(self.children) == other._ctor_arguments(
other.children
)

def __eq__(self, other: Any) -> bool:
"""Equality of expressions."""
if type(self) != type(other) or hash(self) != hash(other):
if type(self) is not type(other) or hash(self) != hash(other):
return False
else:
return self.is_equal(other)
Expand Down Expand Up @@ -196,7 +196,9 @@ def do_evaluate(
are returned during translation to the IR, but for now we
are not perfect.
"""
raise NotImplementedError(f"Evaluation of {type(self).__name__}")
raise NotImplementedError(
f"Evaluation of expression {type(self).__name__}"
) # pragma: no cover; translation of unimplemented nodes trips first

def evaluate(
self,
Expand Down Expand Up @@ -266,7 +268,7 @@ def collect_agg(self, *, depth: int) -> AggInfo:
"""
raise NotImplementedError(
f"Collecting aggregation info for {type(self).__name__}"
)
) # pragma: no cover; check_agg trips first


class NamedExpr:
Expand All @@ -287,7 +289,7 @@ def __hash__(self) -> int:

def __repr__(self) -> str:
"""Repr of the expression."""
return f"NamedExpr({self.name}, {self.value}"
return f"NamedExpr({self.name}, {self.value})"

def __eq__(self, other: Any) -> bool:
"""Equality of two expressions."""
Expand Down
6 changes: 6 additions & 0 deletions python/cudf_polars/tests/dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

__all__: list[str] = []
76 changes: 76 additions & 0 deletions python/cudf_polars/tests/dsl/test_expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import cudf._lib.pylibcudf as plc

from cudf_polars.dsl import expr


def test_expression_equality_not_expression():
col = expr.Col(plc.DataType(plc.TypeId.INT8), "a")
assert not (col == "a") # noqa: SIM201
assert col != "a"


@pytest.mark.parametrize("dtype", [plc.TypeId.INT8, plc.TypeId.INT16])
def test_column_ne_dtypes_differ(dtype):
a = expr.Col(plc.DataType(dtype), "a")
b = expr.Col(plc.DataType(plc.TypeId.FLOAT32), "a")
assert a != b


@pytest.mark.parametrize("dtype", [plc.TypeId.INT8, plc.TypeId.INT16])
def test_column_ne_names_differ(dtype):
a = expr.Col(plc.DataType(dtype), "a")
b = expr.Col(plc.DataType(dtype), "b")
assert a != b


@pytest.mark.parametrize("dtype", [plc.TypeId.INT8, plc.TypeId.INT16])
def test_column_eq_names_eq(dtype):
a = expr.Col(plc.DataType(dtype), "a")
b = expr.Col(plc.DataType(dtype), "a")
assert a == b


def test_expr_hashable():
a = expr.Col(plc.DataType(plc.TypeId.INT8), "a")
b = expr.Col(plc.DataType(plc.TypeId.INT8), "b")
c = expr.Col(plc.DataType(plc.TypeId.FLOAT32), "c")

collection = {a, b, c}
assert len(collection) == 3
assert a in collection
assert b in collection
assert c in collection


def test_namedexpr_hashable():
b = expr.NamedExpr("b", expr.Col(plc.DataType(plc.TypeId.INT8), "a"))
c = expr.NamedExpr("c", expr.Col(plc.DataType(plc.TypeId.INT8), "a"))

collection = {b, c}

assert len(collection) == 2

assert b in collection
assert c in collection


def test_namedexpr_ne_values():
b1 = expr.NamedExpr("b1", expr.Col(plc.DataType(plc.TypeId.INT8), "a"))
b2 = expr.NamedExpr("b2", expr.Col(plc.DataType(plc.TypeId.INT16), "a"))

assert b1 != b2


@pytest.mark.xfail(reason="pylibcudf datatype repr not stable")
def test_namedexpr_repr_stable():
b1 = expr.NamedExpr("b1", expr.Col(plc.DataType(plc.TypeId.INT8), "a"))
b2 = expr.NamedExpr("b1", expr.Col(plc.DataType(plc.TypeId.INT8), "a"))

assert repr(b1) == repr(b2)

0 comments on commit 0891c5d

Please sign in to comment.