Data parallel helper #1407
Changes from 2 commits
```diff
@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.

-from functools import wraps
-from typing import Callable, Optional
+from functools import reduce, wraps
+from typing import Any, Callable, Optional

 import mlx.core as mx

 from ..utils import tree_flatten, tree_map, tree_unflatten
 from .layers.base import Module

@@ -68,3 +69,93 @@ def wrapped_checkpointed_fn(*args, **kwargs):
         return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

     return wrapped_checkpointed_fn


+def average_gradients(
+    gradients: Any,
+    group: Optional[mx.distributed.Group] = None,
+    all_reduce_size: int = 32 * 1024**2,
+    communication_type: Optional[mx.Dtype] = None,
+):
+    """Average the gradients across the distributed processes in the passed group.
+
+    This helper enables concatenating several gradients of small arrays into one
+    big all-reduce call for better networking performance.
+
+    Args:
+        gradients (Any): The Python tree containing the gradients (it should
+            have the same structure across processes).
+        group (Optional[mlx.core.distributed.Group]): The group of processes to
+            average the gradients over. If set to ``None`` the global group is
+            used. Default: ``None``.
+        all_reduce_size (int): Group arrays until their size in bytes exceeds
+            this number. Perform one communication step per group of arrays. If
+            less than or equal to 0, array grouping is disabled. Default:
+            ``32MiB``.
+        communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
+            type before performing the communication. Typically cast to a
+            smaller float to reduce the communication size. Default: ``None``.
+    """
+    group = group or mx.distributed.init()
+    N = group.size()
+
+    if N == 1:
+        return gradients
+
+    def _average(x):
+        dt = x.dtype
+        x = x.astype(communication_type) if communication_type is not None else x
+        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
+
+    if all_reduce_size <= 0:
+        return tree_map(_average, gradients)
+
+    else:
+        flat_grads = sorted(tree_flatten(gradients), key=lambda x: x[0])
```
**Review comment** (on the `flat_grads = sorted(...)` line): What is the purpose of the sort here?

**Reply:** Yeah, this should either be removed or the code just above is wrong: for this to work, every process must produce the flattened gradients in the same order. So this comes down to the iteration ordering of dicts in Python. Looking into it a bit more, since Python 3.7 dict iteration is guaranteed to follow insertion order, which is why the code in line 110 works as well. So the conclusion is that I will remove the sort :-)

**Reply:** I hadn't thought of the ordering across machines. That makes sense!
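The ordering argument above can be illustrated with a small pure-Python sketch. The `flatten` helper here is a hypothetical stand-in for mlx's `tree_flatten`, not its actual implementation; the point is only that two processes which build the same parameter tree with the same insertion order flatten it to identically ordered key lists, so no sort is needed.

```python
# Since Python 3.7, dict iteration order is guaranteed to be insertion
# order, so two processes building the same parameter tree the same way
# flatten it to identically ordered (key, value) lists.

def flatten(tree, prefix=""):
    # Hypothetical stand-in for mlx's tree_flatten: depth-first over dicts.
    if not isinstance(tree, dict):
        return [(prefix, tree)]
    out = []
    for k, v in tree.items():
        out.extend(flatten(v, f"{prefix}.{k}" if prefix else k))
    return out

# Two "processes" building structurally identical trees:
tree_a = {"layer1": {"w": 1, "b": 2}, "layer2": {"w": 3}}
tree_b = {"layer1": {"w": 10, "b": 20}, "layer2": {"w": 30}}

keys_a = [k for k, _ in flatten(tree_a)]
keys_b = [k for k, _ in flatten(tree_b)]
assert keys_a == keys_b  # same order on both "processes", no sort needed
print(keys_a)  # → ['layer1.w', 'layer1.b', 'layer2.w']
```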
```diff
+        if len(flat_grads) == 0:
+            return gradients
+
+        # Extract some info for the gradients
+        keys = [k for k, _ in flat_grads]
+        shapes = [v.shape for _, v in flat_grads]
+        sizes = [v.size for _, v in flat_grads]
+        dtypes = [v.dtype for _, v in flat_grads]
+
+        # We can't group them if they have mixed types
+        if not all(dt == dtypes[0] for dt in dtypes):
+            return average_gradients(gradients, group, 0, communication_type)
+        itemsize = (
+            communication_type.size
+            if communication_type is not None
+            else dtypes[0].size
+        )
+
+        # Gather the gradients in groups that are just above or equal to
+        # all_reduce_size
+        grad_groups = []
+        grad_group = []
+        grad_group_size = 0
+        for i in range(len(keys)):
+            grad_group.append(i)
+            grad_group_size += sizes[i] * itemsize
+            if grad_group_size >= all_reduce_size:
+                grad_groups.append(grad_group)
+                grad_group = []
+                grad_group_size = 0
+        if grad_group:
+            grad_groups.append(grad_group)
+            grad_group = []
+
+        # Concatenate-reduce-split
+        new_flat_grads = []
+        for grad_group in grad_groups:
+            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
+            big_grad = mx.concatenate(
+                [flat_grads[i][1].reshape(-1) for i in grad_group]
+            )
+            big_grad = _average(big_grad)
+            big_grad = mx.split(big_grad, indices[1:-1])
+            new_flat_grads.extend(
+                (keys[j], big_grad[i].reshape(shapes[j]))
+                for i, j in enumerate(grad_group)
+            )
+
+        return tree_unflatten(new_flat_grads)
```
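The grouping bookkeeping in the diff can be sketched without mlx. This is an illustrative reimplementation of just the greedy bucketing and split-offset computation; `group_by_bytes` is a made-up name, and the array concatenation/reduction itself is omitted since it needs a distributed backend.

```python
from functools import reduce

def group_by_bytes(sizes, itemsize, all_reduce_size):
    # Greedy bucketing as in the diff: accumulate arrays until a group's
    # byte size reaches all_reduce_size, then start a new group.
    groups, group, group_bytes = [], [], 0
    for i, size in enumerate(sizes):
        group.append(i)
        group_bytes += size * itemsize
        if group_bytes >= all_reduce_size:
            groups.append(group)
            group, group_bytes = [], 0
    if group:
        groups.append(group)
    return groups

# Example: float32 gradients (4 bytes/element) bucketed at a 64-byte threshold.
sizes = [4, 4, 8, 2, 10]  # element counts of each flat gradient
groups = group_by_bytes(sizes, itemsize=4, all_reduce_size=64)
print(groups)  # → [[0, 1, 2], [3, 4]]

# Split offsets for one group, mirroring the reduce(...) in the diff:
g = groups[0]
indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], g, [0])
print(indices)  # → [0, 4, 8, 16]; mx.split would use indices[1:-1]
```

One all-reduce then runs per group instead of one per array, which is the networking win the docstring describes.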
**Review comment:** I'm not sure it matters, but this function has two different return modes, which could lead to some confusion: in the no-op cases it returns the input tree itself, while otherwise it builds a new tree. So if you modify the returned tree in place, the behavior would be different in the two cases. It seems like a really odd usage pattern, but nevertheless it might be worth making a copy of the tree structure in all cases.

**Reply:** Hm, hadn't thought of that. One reason to keep it as is would be that in the no-op case we have absolutely no overhead (except a Python function call, which is nanoseconds), whereas making a copy would add overhead proportional to the number of parameters. Let me know what you think. It could be fine as it is, since it is likely hidden behind computation anyway.

**Reply:** Yeah, I don't like adding overhead for no reason, and those trees can get pretty large. So let's just keep it the way it is and deal with it in the future if needed.