Data parallel helper #1407

Merged: 3 commits into main, Sep 17, 2024
Conversation

angeloskath (Member):

This PR adds float16 and bfloat16 types to MPI, as well as a helper to average gradients over a distributed group. It is a bit ugly, but it helps nicely with ml-explore/mlx-examples#821, so I am not sure if it belongs in nn.utils 🤔.

Here is a simple scaling plot for a Llama 8B finetuning.

[Figure: llama-8b-finetuning, scaling plot for Llama 8B fine-tuning]
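For context, here is a minimal sketch of how the helper is meant to be used in a data-parallel training step. The import path, model, optimizer, and data are placeholders for illustration and are not part of this PR's diff:

```python
# Sketch only: the exact import location of average_gradients is assumed.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.nn.utils import average_gradients  # assumed location

model = nn.Linear(1024, 1024)
optimizer = optim.SGD(learning_rate=1e-2)


def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)


loss_and_grad = nn.value_and_grad(model, loss_fn)


def step(x, y):
    loss, grads = loss_and_grad(model, x, y)
    # Average the gradients across the distributed group so that every rank
    # applies the same update (plain data parallelism).
    grads = average_gradients(grads)
    optimizer.update(model, grads)
    return loss


# Each rank would feed its own shard of the batch into step(); launched
# under MPI, the per-rank gradients are averaged before the update.
x = mx.random.normal((32, 1024))
y = mx.random.normal((32, 1024))
print(step(x, y))
```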

@angeloskath requested a review from awni, September 12, 2024 23:04
```python
        return tree_map(_average, gradients)
    else:
        flat_grads = sorted(tree_flatten(gradients), key=lambda x: x[0])
```
awni (Member):

What is the purpose of the sort here?

angeloskath (Member, Author):

Yeah, this should either be removed or the code just above is wrong. Basically, for this to work, tree_map needs to traverse the tree in a deterministic order across machines (because all_sum needs to be called on the equivalent arrays on every rank).

So this comes down to the iteration order of dicts in Python. Looking into it a bit more: since Python 3.7, dict iteration is guaranteed to follow insertion order, which is why the code in line 110 works as well. So the conclusion is that I will remove the sort :-)
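A small sketch of the ordering argument (not from the diff; the toy trees just stand in for real gradient trees on two ranks):

```python
from mlx.utils import tree_flatten

# Since Python 3.7, dicts iterate in insertion order. Two ranks that build
# their gradient trees identically therefore flatten them to the same key
# order, so the collective all_sum calls line up without an explicit sort.
grads_rank0 = {"layers": {"w": 1.0, "b": 2.0}, "head": {"w": 3.0}}
grads_rank1 = {"layers": {"w": 10.0, "b": 20.0}, "head": {"w": 30.0}}

keys0 = [k for k, _ in tree_flatten(grads_rank0)]
keys1 = [k for k, _ in tree_flatten(grads_rank1)]
assert keys0 == keys1  # identical traversal order on both "ranks"
```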

awni (Member):

I hadn't thought of the ordering across machines. That makes sense!

```diff
@@ -68,3 +69,93 @@ def wrapped_checkpointed_fn(*args, **kwargs):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

     return wrapped_checkpointed_fn
+
+
+def average_gradients(
```
awni (Member):

I'm not sure it matters, but this function has two different return modes, which could lead to some confusion:

  • It returns a reference to the input tree if the call is a no-op
  • It returns a copied tree structure otherwise

So if you do:

avg_grads = average_gradients(grads)
grads[0] = ...

the behavior would be different in the two cases. It seems like a really odd usage pattern... but nevertheless it might be worth making a copy of the tree structure in all cases.
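A small sketch of that aliasing difference, assuming a single-process run where the distributed group has size 1 and the call is therefore a no-op (import path assumed):

```python
import mlx.core as mx
from mlx.nn.utils import average_gradients  # assumed location

grads = {"w": mx.zeros((4,)), "b": mx.zeros((4,))}
avg_grads = average_gradients(grads)

# No-op case: the input tree itself is returned, so a later mutation of
# `grads` is also visible through `avg_grads`.
assert avg_grads is grads
grads["w"] = mx.ones((4,))  # avg_grads["w"] now sees the ones as well

# With more than one rank, a new tree is built, so the same assignment
# would leave avg_grads unchanged.
```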

angeloskath (Member, Author):

Hm, I hadn't thought of that. One reason to keep it as is: in the no-op case we have absolutely no overhead (except a Python function call, which is nanoseconds), whereas making a copy would add overhead proportional to the number of parameters.

Let me know what you think. It could be fine, as it is likely hidden behind the computation anyway.

awni (Member):

Yeah, I don't like adding overhead for no reason, and those trees can get pretty large. So let's just keep it the way it is and deal with it in the future if needed.

@awni (Member) left a comment:

Looks really nice. Of all the places this could go, I think nn.utils is probably the best. If you think there is a world in which we want to do more distributed stuff in nn, we could make a new sub-package, nn.distributed, and put it there (which may also be a good home for the distributed layers in #1270).

awni (Member) commented on Sep 16, 2024:

The benchmark scaling is pretty remarkable, btw... so much potential for fast distributed fine-tuning!

@angeloskath merged commit 914409f into main on Sep 17, 2024
4 checks passed
@angeloskath deleted the data-parallel-helper branch on September 17, 2024