diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h index 1ed82cb6a..3f7f16f61 100644 --- a/mlx/distributed/distributed.h +++ b/mlx/distributed/distributed.h @@ -32,6 +32,8 @@ struct Group { */ Group split(int color, int key = -1); + void barrier(); + const std::shared_ptr& raw_group() { return group_; } diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 3223832e5..0116c71ca 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -71,6 +71,7 @@ struct MPIWrapper { LOAD_SYMBOL(MPI_Allgather, all_gather); LOAD_SYMBOL(MPI_Send, send); LOAD_SYMBOL(MPI_Recv, recv); + LOAD_SYMBOL(MPI_Barrier, barrier); LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous); LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit); LOAD_SYMBOL(MPI_Op_create, mpi_op_create); @@ -195,6 +196,7 @@ struct MPIWrapper { int (*comm_free)(MPI_Comm*); int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm); int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*); + int (*barrier)(MPI_Comm); // Objects MPI_Comm comm_world_; @@ -263,6 +265,10 @@ struct MPIGroupImpl { return size_; } + void barrier() { + mpi().barrier(comm_); + } + private: MPI_Comm comm_; bool global_; @@ -298,6 +304,11 @@ Group Group::split(int color, int key) { return Group(std::make_shared(new_comm, false)); } +void Group::barrier() { + auto mpi_group = std::static_pointer_cast(group_); + mpi_group->barrier(); +} + bool is_available() { return mpi().is_available(); } diff --git a/mlx/distributed/no_distributed.cpp b/mlx/distributed/no_distributed.cpp index 009e3a715..ef4e472e7 100644 --- a/mlx/distributed/no_distributed.cpp +++ b/mlx/distributed/no_distributed.cpp @@ -17,6 +17,8 @@ Group Group::split(int color, int key) { throw std::runtime_error("Cannot split the distributed group further"); } +void Group::barrier() {} + bool is_available() { return false; }