diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 25858d7be..0dd033018 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -165,24 +165,24 @@ array eval_impl(std::vector outputs, bool async) { auto stream = arr.primitive().stream(); - // Lookup corresponding event and increment counter + Event e; if (stream.threads > 1) { - auto e = Event(stream); - e.set_value(e.value() + 1); - arr.attach_event(e); - for (auto& s : arr.siblings()) { - s.attach_event(e); - } + // Use unique events for multi-threaded streams + e = Event(stream); + e.set_value(1); } else { - auto e = events.find(stream.index); - if (e == events.end()) { - e = events.emplace(stream.index, Event{stream}).first; - } - e->second.set_value(e->second.value() + 1); - arr.attach_event(e->second); - for (auto& s : arr.siblings()) { - s.attach_event(e->second); + // Share events for single-threaded streams + auto e_it = events.find(stream.index); + if (e_it == events.end()) { + e_it = events.emplace(stream.index, Event{stream}).first; } + e_it->second.set_value(e_it->second.value() + 1); + e = e_it->second; + } + // Increment event counter and attach to the array and siblings + arr.attach_event(e); + for (auto& s : arr.siblings()) { + s.attach_event(e); } // Set the status of the array and siblings. @@ -192,7 +192,8 @@ array eval_impl(std::vector outputs, bool async) { } std::vector> arr_deps; - bool signal = needs_signal.find(arr.id()) != needs_signal.end(); + bool signal = + stream.threads > 1 || needs_signal.find(arr.id()) != needs_signal.end(); if (arr.primitive().device() == Device::gpu) { if (!metal::is_available()) { @@ -202,7 +203,9 @@ array eval_impl(std::vector outputs, bool async) { } else { auto task = [arr = std::move(arr), stream, signal]() mutable { for (auto& input : arr.inputs()) { - if (input.event().valid()) { + if (input.event().valid() && + (stream.threads > 1 || + input.event().stream() != arr.primitive().stream())) { input.event().wait(); } } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index e3c2c72bd..4a436f41a 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -141,7 +141,7 @@ std::ostream& operator<<(std::ostream& os, const Device& d) { std::ostream& operator<<(std::ostream& os, const Stream& s) { os << "Stream("; os << s.device; - os << ", " << s.index << ")"; + os << ", index=" << s.index << ", threads=" << s.threads << ")"; return os; } diff --git a/python/src/array.cpp b/python/src/array.cpp index c5a3c0cdd..e518f2765 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -184,10 +184,6 @@ void init_array(nb::module_& m) { R"pbdoc( A helper object to apply updates at specific indices. )pbdoc") - .def( - nb::init(), - "x"_a, - nb::sig("def __init__(self, x: array)")) .def("__getitem__", &ArrayAt::set_indices, "indices"_a.none()) .def("add", &ArrayAt::add, "value"_a) .def("subtract", &ArrayAt::subtract, "value"_a) @@ -202,10 +198,6 @@ void init_array(nb::module_& m) { R"pbdoc( A helper object to iterate over the 1st dimension of an array. )pbdoc") - .def( - nb::init(), - "x"_a, - nb::sig("def __init__(self, x: array)")) .def("__next__", &ArrayPythonIterator::next) .def("__iter__", [](const ArrayPythonIterator& it) { return it; }); diff --git a/python/src/stream.cpp b/python/src/stream.cpp index 6a76934d4..7e9c3d107 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -47,6 +47,8 @@ void init_stream(nb::module_& m) { "Stream", R"pbdoc( A stream for running operations on a given device. + + Use :func:`new_stream` to create new streams. )pbdoc") .def(nb::init(), "index"_a, "device"_a) .def_ro("device", &Stream::device) @@ -79,7 +81,7 @@ void init_stream(nb::module_& m) { streams device. It will not change the default device. Args: - stream (stream): Stream to make the default. + stream (Stream): Stream to make the default. )pbdoc"); m.def( "new_stream", diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py index 5856d84fd..c1c32b754 100644 --- a/python/tests/test_eval.py +++ b/python/tests/test_eval.py @@ -137,6 +137,18 @@ def test_async_eval_with_multiple_streams(self): mx.async_eval(x) mx.eval(a + b) + def test_multithreaded_stream(self): + arrays = [mx.random.uniform(shape=(4, 4)) for _ in range(8)] + mx.eval(arrays) + s = mx.new_stream(mx.cpu, threads=2) + with mx.stream(s): + outputs = [mx.exp(-mx.abs(x)) for x in arrays] + out_multi = sum(outputs) + + outputs = [mx.exp(-mx.abs(x)) for x in arrays] + out = sum(outputs) + self.assertTrue(mx.allclose(out, out_multi)) + if __name__ == "__main__": unittest.main()