Data parallel helper (#1407)
angeloskath authored Sep 17, 2024
1 parent 8d68a3e commit 914409f
Showing 3 changed files with 213 additions and 7 deletions.
71 changes: 66 additions & 5 deletions mlx/distributed/mpi/mpi.cpp
@@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) {
  }
}

template <typename T>
void simple_sum(
    void* input,
    void* accumulator,
    int* len,
    MPI_Datatype* datatype) {
  T* in = (T*)input;
  T* acc = (T*)accumulator;
  int N = *len;

  while (N-- > 0) {
    *acc += *in;
    acc++;
    in++;
  }
}
template void simple_sum<float16_t>(void*, void*, int*, MPI_Datatype*);
template void simple_sum<bfloat16_t>(void*, void*, int*, MPI_Datatype*);

struct MPIWrapper {
  MPIWrapper() {
    initialized_ = false;

    libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
    if (libmpi_handle_ == nullptr) {
      return;
@@ -50,6 +71,9 @@ struct MPIWrapper {
    LOAD_SYMBOL(MPI_Allgather, all_gather);
    LOAD_SYMBOL(MPI_Send, send);
    LOAD_SYMBOL(MPI_Recv, recv);
    LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous);
    LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit);
    LOAD_SYMBOL(MPI_Op_create, mpi_op_create);

    // Objects
    LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -79,7 +103,24 @@ struct MPIWrapper {
    if (!is_available()) {
      return false;
    }
-    return init(nullptr, nullptr) == MPI_SUCCESS;
    bool success = init(nullptr, nullptr) == MPI_SUCCESS;

    // Initialize custom types and ops
    if (success && !initialized_) {
      // Custom float16 dtypes
      mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_);
      mpi_type_commit(&mpi_float16_);
      mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_);
      mpi_type_commit(&mpi_bfloat16_);

      // Custom sum ops
      mpi_op_create(&simple_sum<float16_t>, 1, &op_sum_f16_);
      mpi_op_create(&simple_sum<bfloat16_t>, 1, &op_sum_bf16_);

      initialized_ = true;
    }

    return success;
  }

  void finalize_safe() {
@@ -117,13 +158,21 @@ struct MPIWrapper {
      case complex64:
        return mpi_complex_;
      case float16:
        return mpi_float16_;
      case bfloat16:
-        throw std::runtime_error("MPI doesn't support 16-bit floats");
        return mpi_bfloat16_;
    }
  }

-  MPI_Op op_sum() {
-    return op_sum_;
  MPI_Op op_sum(const array& arr) {
    switch (arr.dtype()) {
      case float16:
        return op_sum_f16_;
      case bfloat16:
        return op_sum_bf16_;
      default:
        return op_sum_;
    }
  }

  void* libmpi_handle_;
@@ -152,6 +201,8 @@ struct MPIWrapper {

  // Ops
  MPI_Op op_sum_;
  MPI_Op op_sum_f16_;
  MPI_Op op_sum_bf16_;

  // Datatypes
  MPI_Datatype mpi_bool_;
@@ -165,6 +216,16 @@
  MPI_Datatype mpi_uint64_;
  MPI_Datatype mpi_float_;
  MPI_Datatype mpi_complex_;
  MPI_Datatype mpi_float16_;
  MPI_Datatype mpi_bfloat16_;

 private:
  bool initialized_;

  // Private API
  int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*);
  int (*mpi_type_commit)(MPI_Datatype*);
  int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*);
};

MPIWrapper& mpi() {
@@ -276,7 +337,7 @@ void all_sum(Group group, const array& input_, array& output) {
      output.data<void>(),
      input.size(),
      mpi().datatype(input),
-      mpi().op_sum(),
      mpi().op_sum(input),
      to_comm(group));
}

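With the custom float16/bfloat16 MPI datatypes and sum ops registered above, half-precision arrays can now be reduced instead of hitting the old "MPI doesn't support 16-bit floats" error. A minimal sketch of exercising the new path from Python follows; it is illustrative only (not part of this commit) and assumes MPI is available and the script is launched on at least two ranks:

# Illustrative sketch, not part of the commit; assumes an MPI launcher
# starts this script on two or more ranks.
import mlx.core as mx

group = mx.distributed.init()
x = mx.ones((4, 4), dtype=mx.float16)
# all_sum on a float16 array now goes through the custom MPI datatype and sum op.
y = mx.distributed.all_sum(x, stream=mx.cpu)
print(group.rank(), y.dtype, y[0, 0].item())
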
95 changes: 93 additions & 2 deletions python/mlx/nn/utils.py
@@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc.

-from functools import wraps
-from typing import Callable, Optional
from functools import reduce, wraps
from typing import Any, Callable, Optional

import mlx.core as mx

from ..utils import tree_flatten, tree_map, tree_unflatten
from .layers.base import Module


@@ -68,3 +69,93 @@ def wrapped_checkpointed_fn(*args, **kwargs):
        return checkpointed_fn(module.trainable_parameters(), *args, **kwargs)

    return wrapped_checkpointed_fn


def average_gradients(
    gradients: Any,
    group: Optional[mx.distributed.Group] = None,
    all_reduce_size: int = 32 * 1024**2,
    communication_type: Optional[mx.Dtype] = None,
):
    """Average the gradients across the distributed processes in the passed group.

    This helper enables concatenating several gradients of small arrays to one
    big all reduce call for better networking performance.

    Args:
        gradients (Any): The Python tree containing the gradients (it should
            have the same structure across processes)
        group (Optional[mlx.core.distributed.Group]): The group of processes to
            average the gradients. If set to ``None``, the global group is used.
            Default: ``None``.
        all_reduce_size (int): Group arrays until their size in bytes exceeds
            this number. Perform one communication step per group of arrays. If
            less than or equal to 0, array grouping is disabled. Default: ``32MiB``.
        communication_type (Optional[mlx.core.Dtype]): If provided, cast to this
            type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
    """
    group = group or mx.distributed.init()
    N = group.size()

    if N == 1:
        return gradients

    def _average(x):
        dt = x.dtype
        x = x.astype(communication_type) if communication_type is not None else x
        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N

    if all_reduce_size <= 0:
        return tree_map(_average, gradients)

    else:
        flat_grads = tree_flatten(gradients)
        if len(flat_grads) == 0:
            return gradients

        # Extract some info for the gradient
        keys = [k for k, _ in flat_grads]
        shapes = [v.shape for _, v in flat_grads]
        sizes = [v.size for _, v in flat_grads]
        dtypes = [v.dtype for _, v in flat_grads]

        # We can't group them if they have mixed types
        if not all(dt == dtypes[0] for dt in dtypes):
            return average_gradients(gradients, group, 0, communication_type)
        itemsize = (
            communication_type.size
            if communication_type is not None
            else dtypes[0].size
        )

        # Gather the gradients in groups that are just above or equal to all_reduce_size
        grad_groups = []
        grad_group = []
        grad_group_size = 0
        for i in range(len(keys)):
            grad_group.append(i)
            grad_group_size += sizes[i] * itemsize
            if grad_group_size >= all_reduce_size:
                grad_groups.append(grad_group)
                grad_group = []
                grad_group_size = 0
        if grad_group:
            grad_groups.append(grad_group)
            grad_group = []

        # Concatenate-reduce-split
        new_flat_grads = []
        for grad_group in grad_groups:
            indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0])
            big_grad = mx.concatenate(
                [flat_grads[i][1].reshape(-1) for i in grad_group]
            )
            big_grad = _average(big_grad)
            big_grad = mx.split(big_grad, indices[1:-1])
            new_flat_grads.extend(
                (keys[j], big_grad[i].reshape(shapes[j]))
                for i, j in enumerate(grad_group)
            )

        return tree_unflatten(new_flat_grads)
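
For context, a hedged sketch of how average_gradients might sit in a data-parallel training step; model, optimizer, loss_fn, and batch are placeholders and not part of this commit:

# Illustrative sketch only; model, optimizer, loss_fn, and batch are
# assumed to be defined elsewhere.
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.utils import average_gradients

def train_step(model, optimizer, loss_fn, batch):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, *batch)
    # One all_sum per ~32 MiB of concatenated gradients, cast to bfloat16
    # for the communication and back to the original dtype afterwards.
    grads = average_gradients(grads, communication_type=mx.bfloat16)
    optimizer.update(model, grads)
    return loss
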
54 changes: 54 additions & 0 deletions python/tests/mpi_test_distributed.py
@@ -4,6 +4,7 @@

import mlx.core as mx
import mlx_tests
from mlx.nn.utils import average_gradients


class TestDistributed(mlx_tests.MLXTestCase):
@@ -110,6 +111,59 @@ def test_send_recv(self):

        self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))

    def test_average_gradients(self):
        original_all_sum = mx.distributed.all_sum
        n_calls = 0
        xtype = None

        def new_all_sum(x, **kwargs):
            nonlocal n_calls
            nonlocal xtype

            n_calls += 1
            if xtype is not None:
                self.assertEqual(xtype, x.dtype)

            return original_all_sum(x, **kwargs)

        mx.distributed.all_sum = new_all_sum

        try:
            grads = [mx.ones(10) for i in range(10)]
            new_grads = average_gradients(grads)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 1)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=4 * 50)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

            n_calls = 0
            new_grads = average_gradients(grads, all_reduce_size=0)
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 10)

            n_calls = 0
            xtype = mx.float16
            new_grads = average_gradients(
                grads, all_reduce_size=2 * 50, communication_type=mx.float16
            )
            mx.eval(new_grads)
            self.assertEqual(len(new_grads), 10)
            self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
            self.assertTrue(all(mx.all(g == 1) for g in new_grads))
            self.assertEqual(n_calls, 2)

        finally:
            mx.distributed.all_sum = original_all_sum


if __name__ == "__main__":
    unittest.main()
