diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index a3c92e7..f779bd4 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -371,7 +371,11 @@ def argmax(self, x, axis=None, keepdims=False): opx.const([], dtype=dtypes.int64), ) ) - return _via_i64_f64(lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), [x]) + return _via_i64_f64( + lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), + [x], + cast_return=False, + ) @validate_core def argmin(self, x, axis=None, keepdims=False): @@ -391,7 +395,11 @@ def argmin(self, x, axis=None, keepdims=False): opx.const([], dtype=dtypes.int64), ) ) - return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x]) + return _via_i64_f64( + lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), + [x], + cast_return=False, + ) @validate_core def nonzero(self, x) -> tuple[Array, ...]: diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index d15dd16..3021b2d 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False): def argmin(x, axis=None, keepdims=False): if ( - out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims) + out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims) ) is not NotImplemented: return out raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'") diff --git a/tests/test_core.py b/tests/test_core.py index 5063954..2bdcab5 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -992,3 +992,18 @@ def test_no_unsafe_cumulative_sum_cast(): ): a = ndx.asarray([1, 2, 3], ndx.int32) ndx.cumulative_sum(a, dtype=ndx.uint64) + + +@pytest.mark.parametrize( + "func, x", + [ + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)), + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + ], +) +def test_argmaxmin(func, x): + np_result = func(x) + ndx_result = getattr(ndx, func.__name__)(ndx.asarray(x)).to_numpy() + assert_array_equal(np_result, ndx_result) diff --git a/xfails.txt b/xfails.txt index 69431ee..9665406 100644 --- a/xfails.txt +++ b/xfails.txt @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit array_api_tests/test_operators_and_elementwise_functions.py::test_sinh array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_tan -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_searching_functions.py::test_where