Skip to content

Commit

Permalink
Fix argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 24, 2024
1 parent ed532b5 commit a09cac9
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 37 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
Changelog
=========

0.9.3 (unreleased)
------------------

- Reduced the number of unnecessary casts in :func:`ndonnx.argmax` and :func:`ndonnx.argmin`.


0.9.2 (2024-10-03)
------------------

Expand Down
Binary file modified docs/_static/classify_iris.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
64 changes: 30 additions & 34 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,43 +355,39 @@ def matrix_transpose(self, x) -> ndx.Array:

@validate_core
def argmax(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), [x])
out = via_upcast(
lambda x: opx.arg_max(
x,
axis=axis or 0,
keepdims=int(keepdims),
),
[ndx.reshape(x, [-1]) if axis is None else x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

while keepdims and out.ndim < x.ndim:
out = ndx.expand_dims(out, axis=0)
return out

@validate_core
def argmin(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x])
out = via_upcast(
lambda x: opx.arg_min(
x,
axis=axis or 0,
keepdims=int(keepdims),
),
[ndx.reshape(x, [-1]) if axis is None else x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

while keepdims and out.ndim < x.ndim:
out = ndx.expand_dims(out, axis=0)
return out

@validate_core
def nonzero(self, x) -> tuple[Array, ...]:
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False):

def argmin(x, axis=None, keepdims=False):
if (
out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims)
out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims)
) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'")
Expand Down
46 changes: 46 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import re
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -992,3 +993,48 @@ def test_no_unsafe_cumulative_sum_cast():
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)


@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)),
],
)
def test_argmaxmin(func, x, keepdims):
np_result = func(x, keepdims=keepdims)
ndx_result = getattr(ndx, func.__name__)(
ndx.asarray(x), keepdims=keepdims
).to_numpy()
assert_array_equal(np_result, ndx_result)


# Pending ORT 1.19 conda-forge release before this becomes supported:
# https://github.com/conda-forge/onnxruntime-feedstock/pull/128
@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
],
)
def test_argmaxmin_unsupported_kernels(func, x):
import onnxruntime as ort

if ort.__version__.startswith("19"):
warnings.warn(
"Please remove this test and update `argmax` and `argmin` to reflect expanded kernel support.",
Warning,
)

with pytest.raises(TypeError):
getattr(ndx, func.__name__)(ndx.asarray(x))
2 changes: 0 additions & 2 deletions xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt
array_api_tests/test_operators_and_elementwise_functions.py::test_tan
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_searching_functions.py::test_where
Expand Down

0 comments on commit a09cac9

Please sign in to comment.