diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 54f676203..6c6f81868 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -40,7 +40,7 @@ int detail::InTracing::tracing_counter{0}; int detail::RetainGraph::tracing_counter{0}; array eval_impl(std::vector outputs, bool async) { - std::queue tape; + std::vector tape; // stream events to use for synchronization std::unordered_map events; @@ -64,7 +64,9 @@ array eval_impl(std::vector outputs, bool async) { events.emplace(stream.index, Event{stream}); { - std::unordered_set cache; + // Record the degree of each input + std::unordered_map cache; + std::stack, int>> dfs; dfs.emplace(synchronizer, 0); while (!dfs.empty()) { @@ -104,42 +106,75 @@ array eval_impl(std::vector outputs, bool async) { } } - if (cache.find(in.id()) == cache.end()) { + // All siblings have the same degree + auto cache_it = cache.find(in.id()); + if (cache_it == cache.end()) { dfs.emplace(in, 0); - cache.insert(in.id()); + cache.insert({in.id(), 1}); + for (auto& s : in.siblings()) { + cache.insert({s.id(), 1}); + } + } else { + cache_it->second++; for (auto& s : in.siblings()) { - cache.insert(s.id()); + cache[s.id()]++; } } continue; } - - // All inputs are done being processed, process this array if ((a.status() != array::Status::unscheduled) && !a.is_tracer() && a.has_primitive()) { // If the array is evaluated and is no longer a tracer, detach it a.detach(); - } else if (a.status() == array::Status::unscheduled) { - tape.push(a); - // Lookup corresponding event and increment counter - auto& stream = a.primitive().stream(); - auto e = events.find(stream.index); - if (e == events.end()) { - e = events.emplace(stream.index, Event{stream}).first; + } + dfs.pop(); + } + + // Build the tape in BFS order + tape.push_back(synchronizer); + for (int i = 0; !cache.empty() && i < tape.size(); ++i) { + auto& a = tape[i]; + for (auto& in : a.inputs()) { + if (in.status() != array::Status::unscheduled) { + continue; } - e->second.set_value(e->second.value() + 1); - a.attach_event(e->second); - for (auto& s : a.siblings()) { - s.attach_event(e->second); + auto it = cache.find(in.id()); + it->second -= 1; + + if (it->second != 0) { + for (auto& s : in.siblings()) { + cache[s.id()] -= 1; + } + continue; + } + + // Remove input and siblings from cache + cache.erase(it); + for (auto& s : in.siblings()) { + cache.erase(s.id()); } + + tape.push_back(in); } - dfs.pop(); } } while (!tape.empty()) { - auto arr = std::move(tape.front()); - tape.pop(); + auto arr = std::move(tape.back()); + tape.pop_back(); + + auto stream = arr.primitive().stream(); + + // Lookup corresponding event and increment counter + auto e = events.find(stream.index); + if (e == events.end()) { + e = events.emplace(stream.index, Event{stream}).first; + } + e->second.set_value(e->second.value() + 1); + arr.attach_event(e->second); + for (auto& s : arr.siblings()) { + s.attach_event(e->second); + } // Set the status of the array and siblings. arr.set_status(array::Status::scheduled); @@ -147,7 +182,6 @@ array eval_impl(std::vector outputs, bool async) { s.set_status(array::Status::scheduled); } - auto stream = arr.primitive().stream(); std::vector> arr_deps; bool signal = needs_signal.find(arr.id()) != needs_signal.end();