From f30e63353a3319f22be54e2315d04bc0cbbef1a7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 23 Jan 2024 22:24:41 -0800 Subject: [PATCH] Minor updates to address a few issues (#537) * docs on arg indices return type * arange with nan * undo isort --- mlx/ops.cpp | 9 ++++++++- python/src/ops.cpp | 8 ++++---- python/tests/test_ops.py | 11 +++++++++++ 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 54ba8bec9..df4b0495e 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -79,7 +79,14 @@ array arange( msg << bool_ << " not supported for arange."; throw std::invalid_argument(msg.str()); } - int size = std::max(static_cast(std::ceil((stop - start) / step)), 0); + if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) { + throw std::invalid_argument("[arange] Cannot compute length."); + } + double real_size = std::ceil((stop - start) / step); + if (std::isnan(real_size)) { + throw std::invalid_argument("[arange] Cannot compute length."); + } + int size = std::max(static_cast(real_size), 0); return array( {size}, dtype, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 95a05436c..7f2ce27ee 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2254,7 +2254,7 @@ void init_ops(py::module_& m) { singleton dimensions, defaults to `False`. Returns: - array: The output array with the indices of the minimum values. + array: The ``uint32`` array with the indices of the minimum values. )pbdoc"); m.def( "argmax", @@ -2287,7 +2287,7 @@ void init_ops(py::module_& m) { singleton dimensions, defaults to `False`. Returns: - array: The output array with the indices of the maximum values. + array: The ``uint32`` array with the indices of the maximum values. )pbdoc"); m.def( "sort", @@ -2343,7 +2343,7 @@ void init_ops(py::module_& m) { If unspecified, it defaults to -1 (sorting over the last axis). Returns: - array: The indices that sort the input array. + array: The ``uint32`` array containing indices that sort the input. )pbdoc"); m.def( "partition", @@ -2416,7 +2416,7 @@ void init_ops(py::module_& m) { If unspecified, it defaults to ``-1``. Returns: - array: The indices that partition the input array. + array: The `uint32`` array containing indices that partition the input. )pbdoc"); m.def( "topk", diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index f4e31df80..7284085d0 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -980,6 +980,17 @@ def test_split(self): self.assertEqual(z.tolist(), [5, 6, 7]) def test_arange_overload_dispatch(self): + with self.assertRaises(ValueError): + a = mx.arange(float("nan"), 1, 5) + with self.assertRaises(ValueError): + a = mx.arange(0, float("nan"), 5) + with self.assertRaises(ValueError): + a = mx.arange(0, 2, float("nan")) + with self.assertRaises(ValueError): + a = mx.arange(0, float("inf"), float("inf")) + with self.assertRaises(ValueError): + a = mx.arange(float("inf"), 1, float("inf")) + a = mx.arange(5) expected = [0, 1, 2, 3, 4] self.assertListEqual(a.tolist(), expected)