Skip to content

Commit

Permalink
Ensure comparisons with pyints and integer series always succeed
Browse files Browse the repository at this point in the history
When a Python integer is compared to a series of integers, the result
is always well defined, no matter the value of the Python
integer.

This was always a very mild issue.  But since NumPy 2 no longer
upcasts the result type of a computation based on the scalar's value,
even things like:
```
cudf.Series([1, 2, 3], dtype="int8") < 1000
```
would fail.

N.B. NumPy/pandas also support exact comparisons when mixing e.g.
uint64 and int64.  This is another rare exception that cudf currently
does not support.
  • Loading branch information
seberg committed Aug 12, 2024
1 parent 45b20d1 commit 029dcd4
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 14 deletions.
54 changes: 40 additions & 14 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,53 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
np.bool_: np.float32,
}

out_dtype = None
if op in {"__truediv__", "__rtruediv__"}:
# Division with integer types results in a suitable float.
if truediv_type := int_float_dtype_mapping.get(self.dtype.type):
return self.astype(truediv_type)._binaryop(other, op)
elif op in {
"__lt__",
"__gt__",
"__le__",
"__ge__",
"__eq__",
"__ne__"
}:
out_dtype = "bool"

# If `other` is a Python integer and it is out-of-bounds
# promotion could fail but we can trivially define the result
# in terms of `notnull` or `NULL_NOT_EQUALS`.
if type(other) is int and self.dtype.kind in "iu": # noqa: E721
truthiness = None
iinfo = np.iinfo(self.dtype)
if iinfo.min > other:
truthiness = op in {"__ne__", "__gt__", "__ge__"}
elif iinfo.max < other:
truthiness = op in {"__ne__", "__lt__", "__le__"}

# Compare with minimum value so that the result is true/false
if truthiness is True:
other = iinfo.min
op = "__ge__"
elif truthiness is False:
other = iinfo.min
op = "__lt__"

elif op in {"NULL_EQUALS", "NULL_NOT_EQUALS"}:
out_dtype = "bool"

reflect, op = self._check_reflected_op(op)
if (other := self._wrap_binop_normalization(other)) is NotImplemented:
return NotImplemented
out_dtype = self.dtype
if other is not None:

if out_dtype is not None:
pass # out_dtype was already set to bool
if other is None:
# not a binary operator, so no need to promote
out_dtype = self.dtype
elif out_dtype is None:
out_dtype = np.result_type(self.dtype, other.dtype)
if op in {"__mod__", "__floordiv__"}:
tmp = self if reflect else other
Expand All @@ -225,17 +262,6 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
out_dtype = cudf.dtype("float64")
elif is_scalar(tmp) and tmp == 0:
out_dtype = cudf.dtype("float64")
if op in {
"__lt__",
"__gt__",
"__le__",
"__ge__",
"__eq__",
"__ne__",
"NULL_EQUALS",
"NULL_NOT_EQUALS",
}:
out_dtype = "bool"

if op in {"__and__", "__or__", "__xor__"}:
if self.dtype.kind == "f" or other.dtype.kind == "f":
Expand All @@ -247,7 +273,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
if self.dtype.kind == "b" or other.dtype.kind == "b":
out_dtype = "bool"

if (
elif (
op == "__pow__"
and self.dtype.kind in "iu"
and (is_integer(other) or other.dtype.kind in "iu")
Expand Down
41 changes: 41 additions & 0 deletions python/cudf/cudf/tests/test_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,47 @@ def test_series_compare(cmpop, obj_class, dtype):
np.testing.assert_equal(result3.to_numpy(), cmpop(arr1, arr2))


@pytest.mark.parametrize(
    "dtype,val",
    [("int8", 200), ("int32", 2**32), ("uint8", -128), ("uint64", -1)],
)
@pytest.mark.parametrize(
    "op",
    [
        operator.eq,
        operator.ne,
        operator.lt,
        operator.le,
        operator.gt,
        operator.ge,
    ],
)
@pytest.mark.parametrize("reverse", [False, True])
def test_series_compare_integer(dtype, val, op, reverse):
    """Compare an integer Series against a Python int outside its dtype range.

    ``val`` never fits in ``dtype``.  The series holds the dtype's minimum
    and maximum, ``val`` force-cast to ``dtype``, and one null entry.  With
    ``reverse`` set, the operands of ``op`` are swapped.
    """
    limits = np.iinfo(dtype)
    # Force the out-of-range value into the dtype so that its cast
    # representation is one of the compared elements as well.
    cast_val = np.array(val).astype(dtype)
    sr = Series([limits.min, limits.max, cast_val, None], dtype=dtype)

    def compare(lhs, rhs):
        # Apply ``op`` with the operand order optionally swapped.
        if reverse:
            return op(rhs, lhs)
        return op(lhs, rhs)

    # Every non-null element is expected to compare against ``val`` the same
    # way an in-range value such as 0 does; the null row is expected to stay
    # null in the result rather than become False.
    if compare(0, val):
        expected = Series([True, True, True, None])
    else:
        expected = Series([False, False, False, None])

    res = compare(sr, val)
    assert_eq(res, expected)


def _series_compare_nulls_typegen():
return [
*combinations_with_replacement(DATETIME_TYPES, 2),
Expand Down

0 comments on commit 029dcd4

Please sign in to comment.