Skip to content

Commit

Permalink
BFS graph evaluation order (#1525)
Browse files (browse the repository at this point in the history)
* bfs order

* try to fix event issue
  • Loading branch information
awni authored Oct 25, 2024
1 parent 0eb56d5 commit 8e88e30
Showing 1 changed file with 56 additions and 22 deletions.
78 changes: 56 additions & 22 deletions mlx/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ int detail::InTracing::tracing_counter{0};
int detail::RetainGraph::tracing_counter{0};

array eval_impl(std::vector<array> outputs, bool async) {
std::queue<array> tape;
std::vector<array> tape;

// stream events to use for synchronization
std::unordered_map<uint32_t, Event> events;
Expand All @@ -64,7 +64,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
events.emplace(stream.index, Event{stream});

{
std::unordered_set<std::uintptr_t> cache;
// Record the degree of each input
std::unordered_map<std::uintptr_t, int> cache;

std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
dfs.emplace(synchronizer, 0);
while (!dfs.empty()) {
Expand Down Expand Up @@ -104,50 +106,82 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
}

if (cache.find(in.id()) == cache.end()) {
// All siblings have the same degree
auto cache_it = cache.find(in.id());
if (cache_it == cache.end()) {
dfs.emplace(in, 0);
cache.insert(in.id());
cache.insert({in.id(), 1});
for (auto& s : in.siblings()) {
cache.insert({s.id(), 1});
}
} else {
cache_it->second++;
for (auto& s : in.siblings()) {
cache.insert(s.id());
cache[s.id()]++;
}
}
continue;
}

// All inputs are done being processed, process this array
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
a.has_primitive()) {
// If the array is evaluated and is no longer a tracer, detach it
a.detach();
} else if (a.status() == array::Status::unscheduled) {
tape.push(a);
// Lookup corresponding event and increment counter
auto& stream = a.primitive().stream();
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
dfs.pop();
}

// Build the tape in BFS order
tape.push_back(synchronizer);
for (int i = 0; !cache.empty() && i < tape.size(); ++i) {
auto& a = tape[i];
for (auto& in : a.inputs()) {
if (in.status() != array::Status::unscheduled) {
continue;
}
e->second.set_value(e->second.value() + 1);
a.attach_event(e->second);
for (auto& s : a.siblings()) {
s.attach_event(e->second);
auto it = cache.find(in.id());
it->second -= 1;

if (it->second != 0) {
for (auto& s : in.siblings()) {
cache[s.id()] -= 1;
}
continue;
}

// Remove input and siblings from cache
cache.erase(it);
for (auto& s : in.siblings()) {
cache.erase(s.id());
}

tape.push_back(in);
}
dfs.pop();
}
}

while (!tape.empty()) {
auto arr = std::move(tape.front());
tape.pop();
auto arr = std::move(tape.back());
tape.pop_back();

auto stream = arr.primitive().stream();

// Lookup corresponding event and increment counter
auto e = events.find(stream.index);
if (e == events.end()) {
e = events.emplace(stream.index, Event{stream}).first;
}
e->second.set_value(e->second.value() + 1);
arr.attach_event(e->second);
for (auto& s : arr.siblings()) {
s.attach_event(e->second);
}

// Set the status of the array and siblings.
arr.set_status(array::Status::scheduled);
for (auto& s : arr.siblings()) {
s.set_status(array::Status::scheduled);
}

auto stream = arr.primitive().stream();
std::vector<std::shared_future<void>> arr_deps;
bool signal = needs_signal.find(arr.id()) != needs_signal.end();

Expand Down

0 comments on commit 8e88e30

Please sign in to comment.