From 0891c5dec7fd8ce0f2e0233fe1c637e49a53f86e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 12 Jun 2024 17:50:52 +0100 Subject: [PATCH] Add tests covering magic methods of Expr objects (#15996) repr is not stable for now because the pylibcudf datatype repr is not stable (it includes the address). Authors: - Lawrence Mitchell (https://github.com/wence-) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/cudf/pull/15996 --- python/cudf_polars/cudf_polars/dsl/expr.py | 12 ++-- python/cudf_polars/tests/dsl/__init__.py | 6 ++ python/cudf_polars/tests/dsl/test_expr.py | 76 ++++++++++++++++++++++ 3 files changed, 89 insertions(+), 5 deletions(-) create mode 100644 python/cudf_polars/tests/dsl/__init__.py create mode 100644 python/cudf_polars/tests/dsl/test_expr.py diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index a81cdcbf0c3..13e496136b5 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -134,14 +134,14 @@ def is_equal(self, other: Any) -> bool: True if the two expressions are equal, false otherwise. """ if type(self) is not type(other): - return False + return False # pragma: no cover; __eq__ trips first return self._ctor_arguments(self.children) == other._ctor_arguments( other.children ) def __eq__(self, other: Any) -> bool: """Equality of expressions.""" - if type(self) != type(other) or hash(self) != hash(other): + if type(self) is not type(other) or hash(self) != hash(other): return False else: return self.is_equal(other) @@ -196,7 +196,9 @@ def do_evaluate( are returned during translation to the IR, but for now we are not perfect. """ - raise NotImplementedError(f"Evaluation of {type(self).__name__}") + raise NotImplementedError( + f"Evaluation of expression {type(self).__name__}" + ) # pragma: no cover; translation of unimplemented nodes trips first def evaluate( self, @@ -266,7 +268,7 @@ def collect_agg(self, *, depth: int) -> AggInfo: """ raise NotImplementedError( f"Collecting aggregation info for {type(self).__name__}" - ) + ) # pragma: no cover; check_agg trips first class NamedExpr: @@ -287,7 +289,7 @@ def __hash__(self) -> int: def __repr__(self) -> str: """Repr of the expression.""" - return f"NamedExpr({self.name}, {self.value}" + return f"NamedExpr({self.name}, {self.value})" def __eq__(self, other: Any) -> bool: """Equality of two expressions.""" diff --git a/python/cudf_polars/tests/dsl/__init__.py b/python/cudf_polars/tests/dsl/__init__.py new file mode 100644 index 00000000000..4611d642f14 --- /dev/null +++ b/python/cudf_polars/tests/dsl/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +__all__: list[str] = [] diff --git a/python/cudf_polars/tests/dsl/test_expr.py b/python/cudf_polars/tests/dsl/test_expr.py new file mode 100644 index 00000000000..ddc3ca66d86 --- /dev/null +++ b/python/cudf_polars/tests/dsl/test_expr.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import cudf._lib.pylibcudf as plc + +from cudf_polars.dsl import expr + + +def test_expression_equality_not_expression(): + col = expr.Col(plc.DataType(plc.TypeId.INT8), "a") + assert not (col == "a") # noqa: SIM201 + assert col != "a" + + +@pytest.mark.parametrize("dtype", [plc.TypeId.INT8, plc.TypeId.INT16]) +def test_column_ne_dtypes_differ(dtype): + a = expr.Col(plc.DataType(dtype), "a") + b = expr.Col(plc.DataType(plc.TypeId.FLOAT32), "a") + assert a != b + + +@pytest.mark.parametrize("dtype", [plc.TypeId.INT8, plc.TypeId.INT16]) +def test_column_ne_names_differ(dtype): + a = expr.Col(plc.DataType(dtype), "a") + b = expr.Col(plc.DataType(dtype), "b") + assert a != b + + +@pytest.mark.parametrize("dtype", [plc.TypeId.INT8, plc.TypeId.INT16]) +def test_column_eq_names_eq(dtype): + a = expr.Col(plc.DataType(dtype), "a") + b = expr.Col(plc.DataType(dtype), "a") + assert a == b + + +def test_expr_hashable(): + a = expr.Col(plc.DataType(plc.TypeId.INT8), "a") + b = expr.Col(plc.DataType(plc.TypeId.INT8), "b") + c = expr.Col(plc.DataType(plc.TypeId.FLOAT32), "c") + + collection = {a, b, c} + assert len(collection) == 3 + assert a in collection + assert b in collection + assert c in collection + + +def test_namedexpr_hashable(): + b = expr.NamedExpr("b", expr.Col(plc.DataType(plc.TypeId.INT8), "a")) + c = expr.NamedExpr("c", expr.Col(plc.DataType(plc.TypeId.INT8), "a")) + + collection = {b, c} + + assert len(collection) == 2 + + assert b in collection + assert c in collection + + +def test_namedexpr_ne_values(): + b1 = expr.NamedExpr("b1", expr.Col(plc.DataType(plc.TypeId.INT8), "a")) + b2 = expr.NamedExpr("b2", expr.Col(plc.DataType(plc.TypeId.INT16), "a")) + + assert b1 != b2 + + +@pytest.mark.xfail(reason="pylibcudf datatype repr not stable") +def test_namedexpr_repr_stable(): + b1 = expr.NamedExpr("b1", expr.Col(plc.DataType(plc.TypeId.INT8), "a")) + b2 = expr.NamedExpr("b1", expr.Col(plc.DataType(plc.TypeId.INT8), "a")) + + assert repr(b1) == repr(b2)