Skip to content

Commit

Permalink
Fix cumsum with dtype specified
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 29, 2024
1 parent 13fb397 commit 7b6c8aa
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 20 deletions.
30 changes: 16 additions & 14 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,18 +568,17 @@ def cumulative_sum(
else:
raise ValueError("axis must be specified for multi-dimensional arrays")

if dtype is None:
if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
if ndx.iinfo(x.dtype).bits < 64:
out = x.astype(dtypes.int64)
else:
raise ndx.UnsupportedOperationError(
f"Cannot perform `cumulative_sum` using {x.dtype}"
)
if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)):
if ndx.iinfo(x.dtype).bits < 64:
out = x.astype(dtypes.int64)
else:
out = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.int64))
return NotImplemented
elif dtype == dtypes.uint64 or dtype == dtypes.nuint64:
raise ndx.UnsupportedOperationError(
f"Unsupported dtype parameter for cumulative_sum {dtype} due to missing kernel support"
)
else:
out = out.astype(dtype)
out = x.astype(_determine_reduce_op_dtype(x, None, dtypes.int64))

out = from_corearray(
opx.cumsum(
Expand All @@ -589,10 +588,13 @@ def cumulative_sum(
)
)

if isinstance(x.dtype, dtypes.Unsigned):
out = out.astype(ndx.uint64)
elif isinstance(x.dtype, dtypes.NullableUnsigned):
out = out.astype(ndx.nuint64)
if dtype is None:
if isinstance(x.dtype, dtypes.Unsigned):
out = out.astype(ndx.uint64)
elif isinstance(x.dtype, dtypes.NullableUnsigned):
out = out.astype(ndx.nuint64)
else:
out = out.astype(dtype)

# Exclude axis and create zeros of that shape
if include_initial:
Expand Down
6 changes: 5 additions & 1 deletion ndonnx/_data_types/coretype.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def _parse_input(self, data: np.ndarray) -> dict[str, np.ndarray]:

def _cast_from(self, array: Array) -> Array:
if isinstance(array.dtype, CoreType):
return ndx.Array._from_fields(self, data=opx.cast(array._core(), to=self))
return (
ndx.Array._from_fields(self, data=opx.cast(array._core(), to=self))
if array.dtype != self
else array.copy()
)
else:
raise CastError(f"Cannot cast from {array.dtype} to {self}")

Expand Down
25 changes: 20 additions & 5 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ def test_dynamic_reshape_has_no_static_shape(x, shape):
)
@pytest.mark.parametrize("include_initial", [True, False])
@pytest.mark.parametrize(
"dtype",
"array_dtype",
[ndx.int32, ndx.int64, ndx.float32, ndx.float64, ndx.uint8, ndx.uint16, ndx.uint32],
)
@pytest.mark.parametrize(
Expand All @@ -948,21 +948,36 @@ def test_dynamic_reshape_has_no_static_shape(x, shape):
([[[[1]]], [[[3]]]], 1),
],
)
def test_cumulative_sum(array, axis, include_initial, dtype):
a = ndx.asarray(array, dtype=dtype)
@pytest.mark.parametrize(
"cumsum_dtype",
[None, ndx.int32, ndx.float32, ndx.float64, ndx.uint8],
)
def test_cumulative_sum(array, axis, include_initial, array_dtype, cumsum_dtype):
a = ndx.asarray(array, dtype=array_dtype)
assert_array_equal(
ndx.cumulative_sum(a, include_initial=include_initial, axis=axis).to_numpy(),
ndx.cumulative_sum(
a, include_initial=include_initial, axis=axis, dtype=cumsum_dtype
).to_numpy(),
np.cumulative_sum(
np.asarray(array, a.dtype.to_numpy_dtype()),
include_initial=include_initial,
axis=axis,
dtype=cumsum_dtype.to_numpy_dtype() if cumsum_dtype is not None else None,
),
)


def test_no_unsafe_cumulative_sum_cast():
with pytest.raises(
ndx.UnsupportedOperationError, match="Cannot perform `cumulative_sum`"
ndx.UnsupportedOperationError,
match="Unsupported operand type for cumulative_sum",
):
a = ndx.asarray([1, 2, 3], ndx.uint64)
ndx.cumulative_sum(a)

with pytest.raises(
ndx.UnsupportedOperationError,
match="Unsupported dtype parameter for cumulative_sum",
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)

0 comments on commit 7b6c8aa

Please sign in to comment.