Skip to content

Commit

Permalink
Fix argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 24, 2024
1 parent ed532b5 commit 8d770ff
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False):

def argmin(x, axis=None, keepdims=False):
if (
out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims)
out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims)
) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'")
Expand Down
15 changes: 15 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,3 +992,18 @@ def test_no_unsafe_cumulative_sum_cast():
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)


@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
],
)
def test_argmaxmin(func, x):
np_result = func(x)
ndx_result = getattr(ndx, func.__name__)(ndx.asarray(x)).to_numpy()
assert_array_equal(np_result, ndx_result)
2 changes: 0 additions & 2 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt
array_api_tests/test_operators_and_elementwise_functions.py::test_tan
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_searching_functions.py::test_where
Expand Down

0 comments on commit 8d770ff

Please sign in to comment.