Data parallel helper #1407
Conversation
python/mlx/nn/utils.py (Outdated)
        return tree_map(_average, gradients)
    else:
        flat_grads = sorted(tree_flatten(gradients), key=lambda x: x[0])
What is the purpose of the sort here?
Yeah, this should either be removed or the code just above is wrong. Basically, in order for this to work, `tree_map` needs to have a deterministic ordering across machines (because `all_sum` needs to be called on the equivalent arrays).
So this comes down to the iteration ordering of dicts in Python. Looking into it a bit more: since Python 3.7, dict iteration is guaranteed to follow insertion order, which is why the code in line 110 works as well. So the conclusion is I will remove the sort :-)
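For context, a minimal sketch (not the PR's exact code; it assumes `mlx` is installed and a distributed group is available) of how `tree_map` plus `all_sum` relies on that deterministic dict ordering:

```python
# Sketch only: average gradients across ranks with tree_map + all_sum.
# This works because tree_map visits dict keys in insertion order
# (guaranteed since Python 3.7), so every rank reduces the matching array.
import mlx.core as mx
from mlx.utils import tree_map

def average_gradients_sketch(gradients, group=None):
    group = group or mx.distributed.init()
    N = group.size()
    if N == 1:
        return gradients  # single process: nothing to average
    return tree_map(lambda g: mx.distributed.all_sum(g, group=group) / N, gradients)
```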
I hadn't thought of the ordering across machines. That makes sense!
@@ -68,3 +69,93 @@ def wrapped_checkpointed_fn(*args, **kwargs):
        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

    return wrapped_checkpointed_fn


def average_gradients(
I'm not sure it matters, but this function has two different return modes, which could lead to some confusion:
- Returns a reference to the input tree if it's a no-op
- Returns a copied tree structure otherwise
So if you do:
avg_grads = average_gradients(grads)
grads[0] = ...
The behavior would be different in the two cases. It seems like a really odd usage pattern... but nevertheless it might be worth making a copy of the tree structure in all cases.
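For illustration, one way to make both paths return a fresh tree would be an identity `tree_map` on the no-op path. This is a hypothetical sketch, not the PR's code; the group size `N` and the reduction function are assumed to be passed in:

```python
# Hypothetical sketch: always rebuild the tree so callers never alias the input.
from mlx.utils import tree_map

def average_gradients_copy_always(gradients, all_sum, N):
    if N == 1:
        # Identity tree_map rebuilds the dict/list structure without copying arrays.
        return tree_map(lambda g: g, gradients)
    return tree_map(lambda g: all_sum(g) / N, gradients)
```

The trade-off discussed below is that this adds per-parameter Python overhead even when averaging is a no-op.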
Hm, hadn't thought of that. One reason to keep it as is: in the no-op case we have absolutely no overhead (except a Python function call, which is nanoseconds), whereas making a copy would add some overhead proportional to the number of parameters.
Let me know what you think. It could be fine as it is, since it's likely hidden behind computation anyway.
Yeah, I don't like adding overhead for no reason, and those trees can get pretty large. So let's just keep it the way it is and deal with it in the future if needed.
Looks really nice. I think of all the places this could go, `nn.utils` is probably the best. If you think there is a world in which we want to do more distributed stuff in `nn`, we could make a new sub-package `nn.distributed` and put it there (which may also be a good home for the distributed layers in #1270).
The benchmark scaling is pretty remarkable btw... so much potential for fast distributed fine-tuning!
This PR adds float16 and bfloat16 types to MPI, as well as a helper to average gradients over a distributed group. It is a bit ugly, but it helps nicely on ml-explore/mlx-examples#821, so I am not sure if it belongs in `nn.utils` 🤔. Here is a simple scaling plot for a Llama 8B fine-tuning.
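A hedged usage sketch of the helper in a training step (assuming it ends up exposed as `nn.average_gradients`; `model`, `optimizer`, `loss_fn`, and `batch` are placeholders, and the script would be launched across ranks with e.g. `mpirun`):

```python
# Usage sketch, not from the PR: average gradients before the optimizer update.
import mlx.nn as nn

def train_step(model, optimizer, loss_fn, batch):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, *batch)
    grads = nn.average_gradients(grads)  # all-reduce and average across the group
    optimizer.update(model, grads)
    return loss
```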