Skip to content

Commit

Permalink
Fix argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 24, 2024
1 parent ed532b5 commit 49e8ee9
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 35 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
Changelog
=========

0.9.3 (unreleased)
------------------

- Reduced the number of unnecessary casts in :func:`ndonnx.argmax` and :func:`ndonnx.argmin`.


0.9.2 (2024-10-03)
------------------

Expand Down
Binary file modified docs/_static/classify_iris.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
51 changes: 19 additions & 32 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,42 +356,29 @@ def matrix_transpose(self, x) -> ndx.Array:
@validate_core
def argmax(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), [x])
x = ndx.reshape(x, [-1])
axis = 0
return via_upcast(
lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims),
[x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

@validate_core
def argmin(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x])
x = ndx.reshape(x, [-1])
axis = 0

return via_upcast(
lambda x: opx.arg_min(x, axis=axis, keepdims=int(keepdims)),
[x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

@validate_core
def nonzero(self, x) -> tuple[Array, ...]:
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False):

def argmin(x, axis=None, keepdims=False):
if (
out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims)
out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims)
) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'")
Expand Down
41 changes: 41 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import re
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -992,3 +993,43 @@ def test_no_unsafe_cumulative_sum_cast():
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)


@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)),
],
)
def test_argmaxmin(func, x):
np_result = func(x)
ndx_result = getattr(ndx, func.__name__)(ndx.asarray(x)).to_numpy()
assert_array_equal(np_result, ndx_result)


# Pending ORT 1.19 conda-forge release before this becomes supported:
# https://github.com/conda-forge/onnxruntime-feedstock/pull/128
@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
],
)
def test_argmaxmin_unsupported_kernels(func, x):
import onnxruntime as ort

if ort.__version__.startswith("19"):
warnings.warn(
"Please remove this test and update `argmax` and `argmin` to reflect expanded kernel support.",
DeprecationWarning,
)

with pytest.raises(TypeError):
getattr(ndx, func.__name__)(ndx.asarray(x))
2 changes: 0 additions & 2 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt
array_api_tests/test_operators_and_elementwise_functions.py::test_tan
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_searching_functions.py::test_where
Expand Down

0 comments on commit 49e8ee9

Please sign in to comment.