Skip to content

Commit

Permalink
Fix split layer grid tag
Browse files: browse the repository at this point in the history
  • Loading branch information
tbennun committed Feb 7, 2024
1 parent ff0afa0 commit d968f30
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/models/model.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -1382,10 +1382,13 @@ void model::add_split_layers(std::unordered_set<std::string>& layer_names)
split->set_name(name);
layer_names.insert(name);

// Copy parallel strategy from parent.
// Copy parallel strategy and grid tag from parent.
ParallelStrategy& ps = split->get_parallel_strategy();
ParallelStrategy& orig_ps = l.get_parallel_strategy();
ps = orig_ps;
if (l.grid_tag() >= 0) {
split->grid_tag(l.grid_tag());
}

// Setup relationships between split layer and child layers
for (int j = 0; j < l.get_num_children(); ++j) {
Expand Down Expand Up @@ -1674,8 +1677,9 @@ void model::backward_prop(bool compute_weight_grads_only, bool skip_callbacks)

// Based on gradient/optimizer requirements
if (compute_weight_grads_only && m_needed_for_backprop.size() > 0 &&
m_needed_for_backprop.find(&l) == m_needed_for_backprop.end())
m_needed_for_backprop.find(&l) == m_needed_for_backprop.end()) {
enable_layer = false;
}
}

// Check if all children skip gradient backpropagation
Expand Down

0 comments on commit d968f30

Please sign in to comment.