Skip to content

Commit

Permalink
Fix true division integer casting. (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Jul 11, 2024
1 parent b824fec commit f63dd6e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 6 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,19 @@
Changelog
=========

0.6.1 (unreleased)
------------------

**Bug fixes**

- Division now complies more strictly with the Array API standard by returning a floating-point result regardless of input data types.

0.6.0 (2024-07-11)
------------------

**Other changes**

- :func:`ndonnx.promote_nullable` is now publicly exported
- ``ndonnx.promote_nullable`` is now publicly exported.

0.5.0 (2024-07-01)
------------------
Expand Down
30 changes: 28 additions & 2 deletions ndonnx/_core/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,20 @@ def cosh(self, x):
return _unary_op(x, opx.cosh, dtypes.float32)

def divide(self, x, y):
return _variadic_op([x, y], opx.div, via_dtype=dtypes.float64)
x, y = promote(x, y)
if not isinstance(x.dtype, (dtypes.Numerical, dtypes.NullableNumerical)):
raise TypeError(f"Unsupported dtype for divide: {x.dtype}")
bits = (
ndx.iinfo(x.dtype).bits
if isinstance(x.dtype, (dtypes.Integral, dtypes.NullableIntegral))
else ndx.finfo(x.dtype).bits
)
via_dtype = (
dtypes.float64
if bits > 32 or x.dtype in (dtypes.nuint32, dtypes.uint32)
else dtypes.float32
)
return _variadic_op([x, y], opx.div, via_dtype=via_dtype, cast_return=False)

def equal(self, x, y) -> Array:
x, y = promote(x, y)
Expand All @@ -134,7 +147,20 @@ def floor(self, x):
return x

def floor_divide(self, x, y):
return self.floor(self.divide(x, y))
x, y = promote(x, y)
dtype = x.dtype
out = self.floor(self.divide(x, y))
if isinstance(
dtype,
(
dtypes.Integral,
dtypes.NullableIntegral,
dtypes.Unsigned,
dtypes.NullableUnsigned,
),
):
out = out.astype(dtype)
return out

def greater(self, x, y):
return _via_i64_f64(opx.greater, [x, y], cast_return=False)
Expand Down
6 changes: 3 additions & 3 deletions skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-cross]
array_api_tests/test_has_names.py::test_has_names[linalg-cholesky]
array_api_tests/test_linalg.py
array_api_tests/test_operators_and_elementwise_functions.py::test_atan2
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__itruediv__(x1, x2)]
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_set_functions.py::test_unique_all
Expand Down Expand Up @@ -68,6 +65,8 @@ array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -0 and x2_i is
array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -infinity and x2_i is +infinity) -> roughly -pi/4]
array_api_tests/test_special_cases.py::test_binary[atan2(x1_i is -infinity and x2_i is -infinity) -> roughly -3pi/4]
array_api_tests/test_special_cases.py::test_binary[floor_divide(copysign(1, x1_i) != copysign(1, x2_i) and isfinite(x1_i) and x1_i != 0 and isfinite(x2_i) and x2_i != 0) -> negative sign]
array_api_tests/test_special_cases.py::test_binary[__floordiv__(copysign(1, x1_i) != copysign(1, x2_i) and isfinite(x1_i) and x1_i != 0 and isfinite(x2_i) and x2_i != 0) -> negative sign]
array_api_tests/test_special_cases.py::test_binary[__floordiv__((x1_i is +infinity or x1_i == -infinity) and (x2_i is +infinity or x2_i == -infinity)) -> NaN]
array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i]
array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i]
array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0]
Expand All @@ -77,6 +76,7 @@ array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i
array_api_tests/test_special_cases.py::test_iop[__imod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i]
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(copysign(1, x1_i) != copysign(1, x2_i) and isfinite(x1_i) and x1_i != 0 and isfinite(x2_i) and x2_i != 0) -> negative sign]
array_api_tests/test_special_cases.py::test_nan_propagation[prod]
array_api_tests/test_special_cases.py::test_unary[acos(x_i < -1) -> NaN]
array_api_tests/test_special_cases.py::test_unary[acos(x_i > 1) -> NaN]
Expand Down
8 changes: 8 additions & 0 deletions tests/ndonnx/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,3 +531,11 @@ def test_searchsorted_raises():
b = ndx.array(shape=(3,), dtype=ndx.int64)

ndx.searchsorted(a, b, side="middle") # type: ignore[arg-type]


def test_truediv():
x = ndx.asarray([1, 2, 3], dtype=ndx.int64)
y = ndx.asarray([2, 3, 3], dtype=ndx.int64)
z = x / y
assert isinstance(z.dtype, ndx.Floating)
np.testing.assert_array_equal(z.to_numpy(), np.array([0.5, 2 / 3, 1.0]))

0 comments on commit f63dd6e

Please sign in to comment.