diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 9b52baa91..1f9556346 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -211,6 +211,8 @@ std::uintptr_t get_function_address(const std::function& fun) { class CompilerCache { public: struct CacheEntry { + CacheEntry(Stream stream) : stream(stream) {}; + Stream stream; std::vector inputs; std::vector outputs; std::vector tape; @@ -227,6 +229,7 @@ class CompilerCache { const std::vector& constants) { // Find the cache entries for |fun_id|. std::vector& entries = cache_[fun_id]; + // Compare if 2 arrays have same shape and dtype. auto has_same_shape_and_dtype = [shapeless]( const std::vector& in1, @@ -247,11 +250,16 @@ class CompilerCache { } return true; }; - // Loop over entries and check inputs match i.e. shapes and types must be - // equal. Note this could get really slow if one compiles the same - // function with many different shapes. May want to store entries in a - // more easily searchable structure. + // Loop over entries and check: + // - Default stream and device match the entry's default stream + // - Inputs match i.e. shapes and types must be equal. + auto stream = default_stream(default_device()); for (CacheEntry& entry : entries) { + // Check that the default stream and device match + if (entry.stream != stream) { + continue; + } + // Check the inputs match and return if so if (has_same_shape_and_dtype(inputs, entry.inputs) && constants == entry.constants) { @@ -259,7 +267,7 @@ class CompilerCache { } } // Otherwise append a new cache entry - entries.push_back(CacheEntry{}); + entries.push_back(CacheEntry{stream}); return entries.back(); } diff --git a/python/src/array.cpp b/python/src/array.cpp index c5a3c0cdd..e518f2765 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -184,10 +184,6 @@ void init_array(nb::module_& m) { R"pbdoc( A helper object to apply updates at specific indices. )pbdoc") - .def( - nb::init(), - "x"_a, - nb::sig("def __init__(self, x: array)")) .def("__getitem__", &ArrayAt::set_indices, "indices"_a.none()) .def("add", &ArrayAt::add, "value"_a) .def("subtract", &ArrayAt::subtract, "value"_a) @@ -202,10 +198,6 @@ void init_array(nb::module_& m) { R"pbdoc( A helper object to iterate over the 1st dimension of an array. )pbdoc") - .def( - nb::init(), - "x"_a, - nb::sig("def __init__(self, x: array)")) .def("__next__", &ArrayPythonIterator::next) .def("__iter__", [](const ArrayPythonIterator& it) { return it; }); diff --git a/python/src/stream.cpp b/python/src/stream.cpp index 81260abb7..e1eb6b953 100644 --- a/python/src/stream.cpp +++ b/python/src/stream.cpp @@ -48,7 +48,6 @@ void init_stream(nb::module_& m) { R"pbdoc( A stream for running operations on a given device. )pbdoc") - .def(nb::init(), "index"_a, "device"_a) .def_ro("device", &Stream::device) .def( "__repr__", diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index d1ca59c99..a1559b7d3 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -719,3 +719,14 @@ TEST_CASE("test compile strides") { CHECK_EQ(out.strides().size(), 3); } } + +TEST_CASE("test compile change streams") { + auto cfun = compile(simple_fun); + auto out = cfun({array(1.0f), array(2.0f)})[0]; + CHECK_EQ(out.primitive().stream(), default_stream(default_device())); + + auto s = new_stream(default_device()); + StreamContext sctx(s); + out = cfun({array(1.0f), array(2.0f)})[0]; + CHECK_EQ(out.primitive().stream(), s); +}