Skip to content

Commit

Permalink
Minor updates to address a few issues (#537)
Browse files Browse the repository at this point in the history
* docs on arg indices return type

* arange with nan

* undo isort
  • Loading branch information
awni authored Jan 24, 2024
1 parent 4fe2fa2 commit f30e633
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
9 changes: 8 additions & 1 deletion mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,14 @@ array arange(
msg << bool_ << " not supported for arange.";
throw std::invalid_argument(msg.str());
}
int size = std::max(static_cast<int>(std::ceil((stop - start) / step)), 0);
if (std::isnan(start) || std::isnan(step) || std::isnan(stop)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}
double real_size = std::ceil((stop - start) / step);
if (std::isnan(real_size)) {
throw std::invalid_argument("[arange] Cannot compute length.");
}
int size = std::max(static_cast<int>(real_size), 0);
return array(
{size},
dtype,
Expand Down
8 changes: 4 additions & 4 deletions python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2254,7 +2254,7 @@ void init_ops(py::module_& m) {
singleton dimensions, defaults to `False`.
Returns:
array: The output array with the indices of the minimum values.
array: The ``uint32`` array with the indices of the minimum values.
)pbdoc");
m.def(
"argmax",
Expand Down Expand Up @@ -2287,7 +2287,7 @@ void init_ops(py::module_& m) {
singleton dimensions, defaults to `False`.
Returns:
array: The output array with the indices of the maximum values.
array: The ``uint32`` array with the indices of the maximum values.
)pbdoc");
m.def(
"sort",
Expand Down Expand Up @@ -2343,7 +2343,7 @@ void init_ops(py::module_& m) {
If unspecified, it defaults to -1 (sorting over the last axis).
Returns:
array: The indices that sort the input array.
array: The ``uint32`` array containing indices that sort the input.
)pbdoc");
m.def(
"partition",
Expand Down Expand Up @@ -2416,7 +2416,7 @@ void init_ops(py::module_& m) {
If unspecified, it defaults to ``-1``.
Returns:
array: The indices that partition the input array.
array: The `uint32`` array containing indices that partition the input.
)pbdoc");
m.def(
"topk",
Expand Down
11 changes: 11 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,17 @@ def test_split(self):
self.assertEqual(z.tolist(), [5, 6, 7])

def test_arange_overload_dispatch(self):
with self.assertRaises(ValueError):
a = mx.arange(float("nan"), 1, 5)
with self.assertRaises(ValueError):
a = mx.arange(0, float("nan"), 5)
with self.assertRaises(ValueError):
a = mx.arange(0, 2, float("nan"))
with self.assertRaises(ValueError):
a = mx.arange(0, float("inf"), float("inf"))
with self.assertRaises(ValueError):
a = mx.arange(float("inf"), 1, float("inf"))

a = mx.arange(5)
expected = [0, 1, 2, 3, 4]
self.assertListEqual(a.tolist(), expected)
Expand Down

0 comments on commit f30e633

Please sign in to comment.