Data parallel helper #1407

Merged (3 commits) on Sep 17, 2024
71 changes: 66 additions & 5 deletions mlx/distributed/mpi/mpi.cpp
@@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) {
}
}

template <typename T>
void simple_sum(
void* input,
void* accumulator,
int* len,
MPI_Datatype* datatype) {
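// MPI user-op signature: element-wise reduce `input` into
// `accumulator` over `*len` elements; `datatype` is unused here.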
T* in = (T*)input;
T* acc = (T*)accumulator;
int N = *len;

while (N-- > 0) {
*acc += *in;
acc++;
in++;
}
}
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);

struct MPIWrapper {
MPIWrapper() {
initialized_ = false;

libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) {
return;
@@ -50,6 +71,9 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
LOAD_SYMBOL(MPI_Op_create, mpi_op_create);

// Objects
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -79,7 +103,24 @@ struct MPIWrapper {
if (!is_available()) {
return false;
}
return init(nullptr, nullptr) == MPI_SUCCESS;
bool success = init(nullptr, nullptr) == MPI_SUCCESS;

// Initialize custom types and ops
if (success && !initialized_) {
// Custom float16 dtypes
mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
mpi_type_commit(&mpi_float16_);
mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
mpi_type_commit(&mpi_bfloat16_);
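// Each 16-bit float is described to MPI as two opaque bytes; the
// actual arithmetic is done by the simple_sum user ops registered below.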

// Custom sum ops
mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);

initialized_ = true;
}

return success;
}

void finalize_safe() {
@@ -117,13 +158,21 @@ struct MPIWrapper {
case complex64:
return mpi_complex_;
case float16:
return mpi_float16_;
case bfloat16:
throw std::runtime_error("MPI doesn't support 16-bit floats");
return mpi_bfloat16_;
}
}

MPI_Op op_sum() {
return op_sum_;
MPI_Op op_sum(const array& arr) {
switch (arr.dtype()) {
case float16:
return op_sum_f16_;
case bfloat16:
return op_sum_bf16_;
default:
return op_sum_;
}
}

void* libmpi_handle_;
@@ -152,6 +201,8 @@ struct MPIWrapper {

// Ops
MPI_Op op_sum_;
MPI_Op op_sum_f16_;
MPI_Op op_sum_bf16_;

// Datatypes
MPI_Datatype mpi_bool_;
@@ -165,6 +216,16 @@ MPI_Datatype mpi_uint64_;
MPI_Datatype mpi_uint64_;
MPI_Datatype mpi_float_;
MPI_Datatype mpi_complex_;
MPI_Datatype mpi_float16_;
MPI_Datatype mpi_bfloat16_;

private:
bool initialized_;

// Private API
int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
int (*mpi_type_commit)(MPI_Datatype*);
int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
};

MPIWrapper& mpi() {
@@ -273,7 +334,7 @@ void all_sum(Group group, const array& input_, array& output) {
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_sum(),
mpi().op_sum(input),
to_comm(group));
}
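
For orientation, the net effect of these mpi.cpp changes is that all-reduce now works for 16-bit floats: float16 and bfloat16 are registered as 2-byte contiguous MPI datatypes and reduced with the simple_sum user ops instead of throwing. A minimal sketch of what this enables on the Python side (assuming the script is launched under MPI with at least two processes, e.g. with mpirun; the launch command is an assumption, not part of this diff):

import mlx.core as mx

group = mx.distributed.init()

# Previously this raised "MPI doesn't support 16-bit floats"; now the
# custom datatype and simple_sum op added above handle the reduction.
x = mx.ones(8, dtype=mx.float16)
y = mx.distributed.all_sum(x, stream=mx.cpu)
print(group.rank(), y)  # every entry equals group.size()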

95 changes: 93 additions & 2 deletions python/mlx/nn/utils.py
@@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc.

from functools import wraps
from typing import Callable, Optional
from functools import reduce, wraps
from typing import Any, Callable, Optional

import mlx.core as mx

from ..utils import tree_flatten, tree_map, tree_unflatten
from .layers.base import Module


@@ -68,3 +69,93 @@ def wrapped_checkpointed_fn(*args, **kwargs):
return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

return wrapped_checkpointed_fn


def average_gradients(
Member commented:

I'm not sure it matters, but this function has two different return modes, which could lead to some confusion:

  • Return a reference to the input tree if it's a no-op
  • Return a copied tree structure otherwise

So if you do:

avg_grads = average_gradients(grads)
grads[0] = ...

the behavior would be different in the two cases. It seems like a really odd usage pattern, but nevertheless it might be worth making a copy of the tree structure in all cases.

Member (author) replied:

Hm, I hadn't thought of that. One reason to keep it as is: in the no-op case we have absolutely no overhead (except a Python function call, which is nanoseconds), whereas making a copy would add overhead proportional to the number of parameters.

Let me know what you think. It could be fine either way, as the copy is likely hidden behind computation anyway.

Member replied:

Yeah, I don't like adding overhead for no reason, and those trees can get pretty large. So let's just keep it the way it is and deal with it in the future if needed.

gradients: Any,
group: Optional[mx.distributed.Group] = None,
all_reduce_size: int = 32 * 1024**2,
communication_type: Optional[mx.Dtype] = None,
):
"""Average the gradients across the distributed processes in the passed group.

This helper enables concatenating the gradients of several small arrays into
one big all-reduce call for better networking performance.

Args:
gradients (Any): The Python tree containing the gradients. It should
have the same structure across processes.
group (Optional[mlx.core.distributed.Group]): The group of processes to
average the gradients over. If set to ``None``, the global group is
used. Default: ``None``.
all_reduce_size (int): Group arrays until their size in bytes exceeds
this number, and perform one communication step per group of arrays.
If less than or equal to 0, array grouping is disabled.
Default: ``32MiB``.
communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
type before performing the communication. Typically a smaller float
is used to reduce the communication size. Default: ``None``.
"""
group = group or mx.distributed.init()
N = group.size()

if N == 1:
return gradients

def _average(x):
dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x
return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N

if all_reduce_size <= 0:
return tree_map(_average, gradients)

else:
flat_grads = sorted(tree_flatten(gradients), key=lambda x: x[0])
Member commented:

What is the purpose of the sort here?

Member (author) replied:

Yeah, this should either be removed or the code just above is wrong. Basically, for this to work, tree_map needs a deterministic ordering across machines (because all_sum needs to be called on the equivalent arrays).

So this comes down to the iteration order of dicts in Python. Looking into it a bit more: since Python 3.7, iteration is guaranteed to follow insertion order, which is why the code in line 110 works as well. So the conclusion is I will remove the sort :-)

Member replied:

I hadn't thought of the ordering across machines. That makes sense!

if len(flat_grads) == 0:
return gradients

# Extract some info for the gradients
keys = [k for k, _ in flat_grads]
shapes = [v.shape for _, v in flat_grads]
sizes = [v.size for _, v in flat_grads]
dtypes = [v.dtype for _, v in flat_grads]

# We can't group them if they have mixed types
if not all(dt == dtypes[0] for dt in dtypes):
return average_gradients(gradients, group, 0, communication_type)
itemsize = (
communication_type.size
if communication_type is not None
else dtypes[0].size
)

# Gather the gradients in groups that are just above or equal to all_reduce_size
grad_groups = []
grad_group = []
grad_group_size = 0
for i in range(len(keys)):
grad_group.append(i)
grad_group_size += sizes[i] * itemsize
if grad_group_size >= all_reduce_size:
grad_groups.append(grad_group)
grad_group = []
grad_group_size = 0
if grad_group:
grad_groups.append(grad_group)
grad_group = []

# Concatenate-reduce-split
new_flat_grads = []
for grad_group in grad_groups:
indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
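# `indices` accumulates the flattened sizes, e.g. sizes [3, 2, 4] give
# [0, 3, 5, 9]; the interior values are the split points used to undo
# the concatenation after averaging.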
big_grad = mx.concatenate(
[flat_grads[i][1].reshape(-1) for i in grad_group]
)
big_grad = _average(big_grad)
big_grad = mx.split(big_grad, indices[1:-1])
new_flat_grads.extend(
(keys[j], big_grad[i].reshape(shapes[j]))
for i, j in enumerate(grad_group)
)

return tree_unflatten(new_flat_grads)
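
As a usage sketch, the helper is meant to sit between gradient computation and the optimizer update in an existing training step. The names model, optimizer, loss_fn, and batch below are placeholders, not part of this PR:

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import average_gradients

def train_step(model, optimizer, loss_fn, batch):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(model, *batch)
    # Average across processes; small gradients are concatenated into
    # ~32 MiB buckets by default, and casting to float16 halves the
    # traffic for float32 gradients.
    grads = average_gradients(grads, communication_type=mx.float16)
    optimizer.update(model, grads)
    return loss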
54 changes: 54 additions & 0 deletions python/tests/mpi_test_distributed.py
@@ -4,6 +4,7 @@

import mlx.core as mx
import mlx_tests
from mlx.nn.utils import average_gradients


class TestDistributed(mlx_tests.MLXTestCase):
@@ -110,6 +111,59 @@ def test_send_recv(self):

self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))

def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0
xtype = None

def new_all_sum(x, **kwargs):
nonlocal n_calls
nonlocal xtype

n_calls += 1
if xtype is not None:
self.assertEqual(xtype, x.dtype)

return original_all_sum(x, **kwargs)

mx.distributed.all_sum = new_all_sum

try:
grads = [mx.ones(10) for i in range(10)]
new_grads = average_gradients(grads)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 1)

n_calls = 0
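# Each gradient is 10 float32 values (40 bytes), so with a 200-byte
# threshold the 10 arrays are packed into two groups of 5 and only
# two all_sum calls are issued.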
new_grads = average_gradients(grads, all_reduce_size=4 * 50)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)

n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=0)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 10)

n_calls = 0
xtype = mx.float16
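# With float16 communication each gradient occupies 20 bytes, so the
# 100-byte threshold again yields two groups; the patched all_sum
# checks that the communicated dtype is float16 while the averaged
# gradients come back as float32.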
new_grads = average_gradients(
grads, all_reduce_size=2 * 50, communication_type=mx.float16
)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)

finally:
mx.distributed.all_sum = original_all_sum


if __name__ == "__main__":
unittest.main()