diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index c232c13e7..3223832e5 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -32,8 +32,29 @@ array ensure_row_contiguous(const array& arr) { } } +template +void simple_sum( + void* input, + void* accumulator, + int* len, + MPI_Datatype* datatype) { + T* in = (T*)input; + T* acc = (T*)accumulator; + int N = *len; + + while (N-- > 0) { + *acc += *in; + acc++; + in++; + } +} +template void simple_sum(void*, void*, int*, MPI_Datatype*); +template void simple_sum(void*, void*, int*, MPI_Datatype*); + struct MPIWrapper { MPIWrapper() { + initialized_ = false; + libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL); if (libmpi_handle_ == nullptr) { return; @@ -50,6 +71,9 @@ struct MPIWrapper { LOAD_SYMBOL(MPI_Allgather, all_gather); LOAD_SYMBOL(MPI_Send, send); LOAD_SYMBOL(MPI_Recv, recv); + LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous); + LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit); + LOAD_SYMBOL(MPI_Op_create, mpi_op_create); // Objects LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_); @@ -79,7 +103,24 @@ struct MPIWrapper { if (!is_available()) { return false; } - return init(nullptr, nullptr) == MPI_SUCCESS; + bool success = init(nullptr, nullptr) == MPI_SUCCESS; + + // Initialize custom types and ops + if (success && !initialized_) { + // Custom float16 dtypes + mpi_type_contiguous(2, mpi_uint8_, &mpi_float16_); + mpi_type_commit(&mpi_float16_); + mpi_type_contiguous(2, mpi_uint8_, &mpi_bfloat16_); + mpi_type_commit(&mpi_bfloat16_); + + // Custom sum ops + mpi_op_create(&simple_sum, 1, &op_sum_f16_); + mpi_op_create(&simple_sum, 1, &op_sum_bf16_); + + initialized_ = true; + } + + return success; } void finalize_safe() { @@ -117,13 +158,21 @@ struct MPIWrapper { case complex64: return mpi_complex_; case float16: + return mpi_float16_; case bfloat16: - throw std::runtime_error("MPI doesn't support 16-bit floats"); + return mpi_bfloat16_; } } - MPI_Op op_sum() { - return op_sum_; + MPI_Op op_sum(const array& arr) { + switch (arr.dtype()) { + case float16: + return op_sum_f16_; + case bfloat16: + return op_sum_bf16_; + default: + return op_sum_; + } } void* libmpi_handle_; @@ -152,6 +201,8 @@ struct MPIWrapper { // Ops MPI_Op op_sum_; + MPI_Op op_sum_f16_; + MPI_Op op_sum_bf16_; // Datatypes MPI_Datatype mpi_bool_; @@ -165,6 +216,16 @@ struct MPIWrapper { MPI_Datatype mpi_uint64_; MPI_Datatype mpi_float_; MPI_Datatype mpi_complex_; + MPI_Datatype mpi_float16_; + MPI_Datatype mpi_bfloat16_; + + private: + bool initialized_; + + // Private API + int (*mpi_type_contiguous)(int, MPI_Datatype, MPI_Datatype*); + int (*mpi_type_commit)(MPI_Datatype*); + int (*mpi_op_create)(MPI_User_function*, int, MPI_Op*); }; MPIWrapper& mpi() { @@ -276,7 +337,7 @@ void all_sum(Group group, const array& input_, array& output) { output.data(), input.size(), mpi().datatype(input), - mpi().op_sum(), + mpi().op_sum(input), to_comm(group)); } diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py index 8c5e4d462..6cc799a7c 100644 --- a/python/mlx/nn/utils.py +++ b/python/mlx/nn/utils.py @@ -1,10 +1,11 @@ # Copyright © 2023-2024 Apple Inc. -from functools import wraps -from typing import Callable, Optional +from functools import reduce, wraps +from typing import Any, Callable, Optional import mlx.core as mx +from ..utils import tree_flatten, tree_map, tree_unflatten from .layers.base import Module @@ -68,3 +69,93 @@ def wrapped_checkpointed_fn(*args, **kwargs): return checkpointed_fn(module.trainable_parameters(), *args, **kwargs) return wrapped_checkpointed_fn + + +def average_gradients( + gradients: Any, + group: Optional[mx.distributed.Group] = None, + all_reduce_size: int = 32 * 1024**2, + communication_type: Optional[mx.Dtype] = None, +): + """Average the gradients across the distributed processes in the passed group. + + This helper enables concatenating several gradients of small arrays to one + big all reduce call for better networking performance. + + Args: + gradients (Any): The Python tree containing the gradients (it should + have the same structure across processes) + group (Optional[mlx.core.distributed.Group]): The group of processes to + average the gradients. If set to ``None`` the global group is used. + Default: ``None``. + all_reduce_size (int): Group arrays until their size in bytes exceeds + this number. Perform one communication step per group of arrays. If + less or equal to 0 array grouping is disabled. Default: ``32MiB``. + communication_type (Optional[mlx.core.Dtype]): If provided cast to this + type before performing the communication. Typically cast to a + smaller float to reduce the communication size. Default: ``None``. + """ + group = group or mx.distributed.init() + N = group.size() + + if N == 1: + return gradients + + def _average(x): + dt = x.dtype + x = x.astype(communication_type) if communication_type is not None else x + return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N + + if all_reduce_size <= 0: + return tree_map(_average, gradients) + + else: + flat_grads = tree_flatten(gradients) + if len(flat_grads) == 0: + return gradients + + # Extract some info for the gradient + keys = [k for k, _ in flat_grads] + shapes = [v.shape for _, v in flat_grads] + sizes = [v.size for _, v in flat_grads] + dtypes = [v.dtype for _, v in flat_grads] + + # We can't group them if they have mixed types + if not all(dt == dtypes[0] for dt in dtypes): + return average_gradients(gradients, group, 0, communication_type) + itemsize = ( + communication_type.size + if communication_type is not None + else dtypes[0].size + ) + + # Gather the gradients in groups that are just above or equal to all_reduce_size + grad_groups = [] + grad_group = [] + grad_group_size = 0 + for i in range(len(keys)): + grad_group.append(i) + grad_group_size += sizes[i] * itemsize + if grad_group_size >= all_reduce_size: + grad_groups.append(grad_group) + grad_group = [] + grad_group_size = 0 + if grad_group: + grad_groups.append(grad_group) + grad_group = [] + + # Concatenate-reduce-split + new_flat_grads = [] + for grad_group in grad_groups: + indices = reduce(lambda x, y: x + [x[-1] + sizes[y]], grad_group, [0]) + big_grad = mx.concatenate( + [flat_grads[i][1].reshape(-1) for i in grad_group] + ) + big_grad = _average(big_grad) + big_grad = mx.split(big_grad, indices[1:-1]) + new_flat_grads.extend( + (keys[j], big_grad[i].reshape(shapes[j])) + for i, j in enumerate(grad_group) + ) + + return tree_unflatten(new_flat_grads) diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py index 44f3fd4ce..aa261a6e5 100644 --- a/python/tests/mpi_test_distributed.py +++ b/python/tests/mpi_test_distributed.py @@ -4,6 +4,7 @@ import mlx.core as mx import mlx_tests +from mlx.nn.utils import average_gradients class TestDistributed(mlx_tests.MLXTestCase): @@ -110,6 +111,59 @@ def test_send_recv(self): self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512))) + def test_average_gradients(self): + original_all_sum = mx.distributed.all_sum + n_calls = 0 + xtype = None + + def new_all_sum(x, **kwargs): + nonlocal n_calls + nonlocal xtype + + n_calls += 1 + if xtype is not None: + self.assertEqual(xtype, x.dtype) + + return original_all_sum(x, **kwargs) + + mx.distributed.all_sum = new_all_sum + + try: + grads = [mx.ones(10) for i in range(10)] + new_grads = average_gradients(grads) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 1) + + n_calls = 0 + new_grads = average_gradients(grads, all_reduce_size=4 * 50) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 2) + + n_calls = 0 + new_grads = average_gradients(grads, all_reduce_size=0) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 10) + + n_calls = 0 + xtype = mx.float16 + new_grads = average_gradients( + grads, all_reduce_size=2 * 50, communication_type=mx.float16 + ) + mx.eval(new_grads) + self.assertEqual(len(new_grads), 10) + self.assertTrue(all(g.dtype == mx.float32 for g in new_grads)) + self.assertTrue(all(mx.all(g == 1) for g in new_grads)) + self.assertEqual(n_calls, 2) + + finally: + mx.distributed.all_sum = original_all_sum + if __name__ == "__main__": unittest.main()