diff --git a/test/test_mp_reduce_scatter.py b/test/test_mp_reduce_scatter.py index 9e4a87363f1..531c3e5650a 100644 --- a/test/test_mp_reduce_scatter.py +++ b/test/test_mp_reduce_scatter.py @@ -10,6 +10,7 @@ def _mp_fn(index): scale = 1 / world_size scatter_dim = 1 shard_size = 2 + input_list_size = 5 if xm.xla_device_hw(device) == 'TPU': rand = torch.rand((32, shard_size * world_size, 32)) @@ -25,8 +26,35 @@ def _mp_fn(index): expected = expected_world.cpu().index_select(scatter_dim, slice_idx) assert res.cpu().allclose(expected) - xm.rendezvous('test_reduce_scatter') + + # Testing reduce-scatter with list input + rand_list = [ + torch.rand((32, shard_size * world_size, 32)) + for _ in range(input_list_size) + ] + xrand_list = [rand.to(device) for rand in rand_list] + + # TODO: fix the broken case with pin_layout=True + res_list = xm.reduce_scatter( + xm.REDUCE_SUM, + xrand_list, + scale, + scatter_dim, + world_size, + pin_layout=False) + + for i, res in enumerate(res_list): + expected_world = xm.all_reduce(xm.REDUCE_SUM, xrand_list[i], scale) + xm.mark_step() + + slice_idx = torch.tensor( + list(range(index * shard_size, (index + 1) * shard_size))) + expected = expected_world.cpu().index_select(scatter_dim, slice_idx) + assert res.cpu().allclose(expected) + + xm.rendezvous('test_reduce_scatter_list_input') + else: print( 'Default device {} is not a TPU device'.format(device), file=sys.stderr) diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index 991e4669d1d..8a818b4211a 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -134,6 +134,27 @@ def test_reduce_scatter(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([output]) hlo_matches(hlo, reduce_scatter_pattern) + @skipIf(xr.device_type() == 'CPU', + "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.") + def test_reduce_scatter_coalesced(self): + device = xm.xla_device() + tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() + tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() + input_tensors_list = [[tensor, tensor], [tensor2, tensor2]] + output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)] + pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0])) + opts = dist.ReduceScatterOptions() + opts.reduceOp = dist.ReduceOp.SUM + reduce_scatter_pattern = ( + r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) ' + r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, ' + r's64\[] %.+\.\d+\)') + pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts) + hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list) + hlo_matches(hlo, reduce_scatter_pattern) + # purge all computations attached the device. + xm.mark_step() + @patch_world(0, 6) def test_send(self): device = xm.xla_device() diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 7174ec9514d..4063a279016 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -749,16 +749,18 @@ def reduce_scatter(reduce_type, reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``, ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and ``xm.REDUCE_MAX``. - input: A single `torch.Tensor` all reduce + scatter op to. + input: (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then + it will also be the output. scale (float): A default scaling value to be applied after the reduce. scatter_dim (int): Dimension number to which apply scatter operation. shard_count (int): The number of ways to split up the scatter_dim in. groups (list): A list of list, representing the replica groups for - the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]` + the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]` defines two groups, one with the `[0, 1, 2, 3]` replicas and one with the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all the replicas in it. - output: Optional output tensor + output: Optional output tensor if `input` is a torch.Tensor, or a list of + torch.Tensor if `input` is a list of torch.Tensor. pin_layout (bool, optional): whether to pin the layout for this communication op. Layout pining can prevent potential data corruption when each process that participate in the communication has slightly different program, but it might @@ -771,21 +773,39 @@ def reduce_scatter(reduce_type, the same as the input. """ token, devctx = _get_all_reduce_token() - if output != None: - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output, - input, token, scale, - scatter_dim, - shard_count, groups or - [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output - - result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale, - scatter_dim, shard_count, - groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1]) - return result[0] + + if isinstance(input, torch.Tensor): + if output != None: + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_reduce_scatter_out( + reduce_type, output, input, token, scale, scatter_dim, shard_count, + groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + + result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, + scale, scatter_dim, + shard_count, groups or [], + pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1]) + return result[0] + + # Now the input should be a list of Tensors. + elif isinstance(input, list) and all( + isinstance(v, torch.Tensor) for v in input): + if output != None: + raise RuntimeError( + "For xm.reduce_scatter with list of tensors input, output != None is not yet supported." + ) + + result = torch_xla._XLAC._xla_reduce_scatter_coalesced( + reduce_type, output or [], input, token, scale, scatter_dim, + shard_count, groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) + return result[:-1] + else: + raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " + f"given {type(input)}.") def add_step_closure(closure, args=(), run_async=False): diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index a06d916d1e1..c98c603b3d3 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -348,6 +348,54 @@ ReduceScatterResult BuildReduceScatter( return {reduce_result, token_handler.GetNewToken(reduce_result)}; } +ReduceScatterResultCoalesced BuildReduceScatterCoalesced( + AllReduceType reduce_type, absl::Span inputs, + xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, + const std::vector>& groups, bool pin_layout) { + std::vector cc_groups = CreateReduceGroups(groups); + TokenHandler token_handler(token); + // TODO: We use pseudo-tokens ATM, which are real values. This need to be + // switched to use the real XLA Token once support has been added to XLA + // ReduceScatter(). + ReduceContext cc_ctx = GetReduceContext(inputs); + std::vector result(inputs.size()); + for (auto& type_ctx : cc_ctx.contexts) { + xla::XlaOp reduce_result; + type_ctx.second.ops[0] = token_handler.GetInput( + type_ctx.second.ops[0], &type_ctx.second.operand_shapes[0]); + if (pin_layout) { + reduce_result = xla::ReduceScatter( + xla::Tuple(inputs[0].builder(), type_ctx.second.ops), + GetReduceComutation(reduce_type, type_ctx.first), scatter_dim, + shard_count, cc_groups, /*channel_id=*/absl::nullopt, + /*layout=*/ + MakeReduceShape(type_ctx.second.operand_shapes).layout()); + } else { + reduce_result = xla::ReduceScatter( + xla::Tuple(inputs[0].builder(), type_ctx.second.ops), + GetReduceComutation(reduce_type, type_ctx.first), scatter_dim, + shard_count, cc_groups); + } + for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) { + size_t op_idx = type_ctx.second.indices[i]; + xla::XlaOp gte; + if (ShapeHelper::ShapeOfXlaOp(reduce_result).rank() == 0) { + gte = xla::GetTupleElement(reduce_result, i); + } else { + gte = reduce_result; + } + if (scale != 1.0) { + xla::XlaOp scaling_value = XlaHelpers::ScalarValue( + scale, type_ctx.second.operand_shapes[i].element_type(), + gte.builder()); + gte = gte * scaling_value; + } + result[op_idx] = gte; + } + } + return {result, token_handler.GetNewToken(result[0])}; +} + std::vector GetOperandListWithToken( c10::ArrayRef operands, const torch::lazy::Value& token) { diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h index 715363ea7a5..ade1a0fa00e 100644 --- a/torch_xla/csrc/cross_replica_reduces.h +++ b/torch_xla/csrc/cross_replica_reduces.h @@ -55,6 +55,11 @@ struct ReduceScatterResult { xla::XlaOp token; }; +struct ReduceScatterResultCoalesced { + std::vector result; + xla::XlaOp token; +}; + std::vector BuildAllReduce( AllReduceType reduce_type, absl::Span operands, xla::XlaOp token, double scale, @@ -91,6 +96,11 @@ ReduceScatterResult BuildReduceScatter( int64_t scatter_dim, int64_t shard_count, const std::vector>& groups, bool pin_layout); +ReduceScatterResultCoalesced BuildReduceScatterCoalesced( + AllReduceType reduce_type, absl::Span inputs, + xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, + const std::vector>& groups, bool pin_layout); + std::vector GetOperandListWithToken( c10::ArrayRef operands, const torch::lazy::Value& token); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 2c421d2dbfc..75a0ab29cf5 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -211,6 +211,29 @@ std::pair> ReduceScatter( std::make_shared(new_token)); } +std::pair, std::shared_ptr> +ReduceScatterCoalesced(const std::string& reduce_type, + const std::vector& outputs, + const std::vector& inputs, + const std::shared_ptr& token, + double scale, int64_t scatter_dim, int64_t shard_count, + const std::vector>& replica_groups, + bool pin_layout) { + std::vector xtensors_out = + GetXlaTensors(outputs, /*want_all=*/true); + std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + std::vector result; + torch::lazy::Value new_token; + std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced( + xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale, + scatter_dim, shard_count, replica_groups, pin_layout); + std::vector aten_result; + for (auto& xt : result) { + aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt))); + } + return {aten_result, std::make_shared(new_token)}; +} + std::shared_ptr ReduceScatterOut( const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, const std::shared_ptr& token, double scale, @@ -1287,6 +1310,30 @@ void InitXlaModuleBindings(py::module m) { result_tuple[1] = new_token; return result_tuple; }); + m.def("_xla_reduce_scatter_coalesced", + [](const std::string& reduce_type, std::vector& outputs, + const std::vector& inputs, + const std::shared_ptr& token, double scale, + int64_t scatter_dim, int64_t shard_count, const py::list& groups, + bool pin_layout) { + std::vector> replica_groups = + CreateReduceGroups(groups); + std::vector result; + std::shared_ptr new_token; + { + NoGilSection nogil; + std::tie(result, new_token) = ReduceScatterCoalesced( + reduce_type, outputs, inputs, token, scale, scatter_dim, + shard_count, replica_groups, pin_layout); + } + auto result_list = py::list(result.size() + 1); + for (int i = 0; i < result.size(); ++i) { + result_list[i] = torch::autograd::make_variable( + result[i], /*requires_grad=*/result[i].requires_grad()); + } + result_list[result.size()] = new_token; + return result_list; + }); m.def("_xla_reduce_scatter_out", [](const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, diff --git a/torch_xla/csrc/ops/reduce_scatter.cpp b/torch_xla/csrc/ops/reduce_scatter.cpp index 91c0f5d66e2..f4e82696cde 100644 --- a/torch_xla/csrc/ops/reduce_scatter.cpp +++ b/torch_xla/csrc/ops/reduce_scatter.cpp @@ -28,6 +28,26 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type, return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn); } +xla::Shape NodeOutputShapeCoalesced( + AllReduceType reduce_type, c10::ArrayRef inputs, + const torch::lazy::Value& token, double scale, int64_t scatter_dim, + int64_t shard_count, const std::vector>& groups, + bool pin_layout) { + auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { + ReduceScatterResultCoalesced result = BuildReduceScatterCoalesced( + reduce_type, operands.subspan(0, operands.size() - 1), operands.back(), + scale, scatter_dim, shard_count, groups, pin_layout); + result.result.emplace_back(result.token); + return xla::Tuple(operands[0].builder(), result.result); + }; + std::vector input_shapes; + for (const auto& input : inputs) { + input_shapes.emplace_back(GetXlaShape(input)); + } + input_shapes.emplace_back(GetXlaShape(token)); + return InferOutputShape(input_shapes, shape_fn); +} + } // namespace ReduceScatter::ReduceScatter(AllReduceType reduce_type, @@ -52,12 +72,41 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type, groups_(std::move(groups)), pin_layout_(pin_layout) {} +ReduceScatterCoalesced::ReduceScatterCoalesced( + AllReduceType reduce_type, c10::ArrayRef inputs, + const torch::lazy::Value& token, double scale, int64_t scatter_dim, + int64_t shard_count, std::vector> groups, + bool pin_layout) + : XlaNode(xla_reduce_scatter, GetOperandListWithToken(inputs, token), + [&]() { + return NodeOutputShapeCoalesced(reduce_type, inputs, token, + scale, scatter_dim, shard_count, + groups, pin_layout); + }, + /*num_outputs=*/inputs.size() + 1, + torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale, + scatter_dim, shard_count, groups, pin_layout)), + reduce_type_(reduce_type), + scale_(scale), + scatter_dim_(scatter_dim), + shard_count_(shard_count), + groups_(std::move(groups)), + pin_layout_(pin_layout) {} + torch::lazy::NodePtr ReduceScatter::Clone(torch::lazy::OpList operands) const { return torch::lazy::MakeNode( reduce_type_, operands.at(0), operands.at(1), scale_, scatter_dim_, shard_count_, groups_, pin_layout_); } +torch::lazy::NodePtr ReduceScatterCoalesced::Clone( + torch::lazy::OpList operands) const { + std::vector inputs(operands.begin(), operands.end() - 1); + return torch::lazy::MakeNode( + reduce_type_, inputs, operands.back(), scale_, scatter_dim_, shard_count_, + groups_, pin_layout_); +} + XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const { xla::XlaOp input = loctx->GetOutputOp(operand(0)); xla::XlaOp token = loctx->GetOutputOp(operand(1)); @@ -67,6 +116,21 @@ XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const { return ReturnOps({result.result, result.token}, loctx); } +XlaOpVector ReduceScatterCoalesced::Lower(LoweringContext* loctx) const { + auto& operand_list = operands(); + std::vector inputs; + inputs.reserve(operand_list.size()); + for (size_t i = 0; i + 1 < operand_list.size(); ++i) { + inputs.push_back(loctx->GetOutputOp(operand_list[i])); + } + xla::XlaOp token = loctx->GetOutputOp(operand_list.back()); + ReduceScatterResultCoalesced result = BuildReduceScatterCoalesced( + reduce_type_, inputs, token, scale_, scatter_dim_, shard_count_, groups_, + pin_layout_); + result.result.push_back(result.token); + return ReturnOps(result.result, loctx); +} + std::string ReduceScatter::ToString() const { std::stringstream ss; ss << XlaNode::ToString() @@ -82,4 +146,18 @@ std::string ReduceScatter::ToString() const { return ss.str(); } +std::string ReduceScatterCoalesced::ToString() const { + std::stringstream ss; + ss << XlaNode::ToString() + << ", reduce_type=" << torch::lazy::GetEnumValue(reduce_type_) + << ", scale=" << scale_ << ", scatter_dim=" << scatter_dim_ + << ", shard_count=" << shard_count_ << ", pin_layout=" << pin_layout_ + << ", groups=("; + for (size_t i = 0; i < groups_.size(); ++i) { + ss << (i == 0 ? "(" : ",("); + ss << absl::StrJoin(groups_[i], ", ") << ")"; + } + ss << ")"; + return ss.str(); +} } // namespace torch_xla diff --git a/torch_xla/csrc/ops/reduce_scatter.h b/torch_xla/csrc/ops/reduce_scatter.h index 0c888ce0fde..2a752788fc4 100644 --- a/torch_xla/csrc/ops/reduce_scatter.h +++ b/torch_xla/csrc/ops/reduce_scatter.h @@ -36,6 +36,38 @@ class ReduceScatter : public XlaNode { bool pin_layout_; }; +class ReduceScatterCoalesced : public XlaNode { + public: + ReduceScatterCoalesced(AllReduceType reduce_type, + c10::ArrayRef inputs, + const torch::lazy::Value& token, double scale, + int64_t scatter_dim, int64_t shard_count, + std::vector> groups, + bool pin_layout); + + std::string ToString() const override; + + torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override; + + XlaOpVector Lower(LoweringContext* loctx) const override; + + AllReduceType reduce_type() const { return reduce_type_; } + + double scale() const { return scale_; } + + const std::vector>& groups() const { return groups_; } + + bool pin_layout() const { return pin_layout_; } + + private: + AllReduceType reduce_type_; + double scale_; + int64_t scatter_dim_; + int64_t shard_count_; + std::vector> groups_; + bool pin_layout_; +}; + } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_OPS_REDUCE_SCATTER_H_ diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 54231614db2..fefa770d432 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -391,6 +391,33 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output, return torch::lazy::Value(node, 1); } +std::pair, torch::lazy::Value> +reduce_scatter_coalesced(const std::vector& outputs, + const std::vector& inputs, + const torch::lazy::Value& token, + AllReduceType reduce_type, double scale, + int64_t scatter_dim, int64_t shard_count, + std::vector> groups, + bool pin_layout) { + XLA_CHECK(outputs.empty() || outputs.size() == inputs.size()); + std::vector input_values; + input_values.reserve(inputs.size()); + for (auto& input : inputs) { + input_values.push_back(input->GetIrValue()); + } + torch::lazy::NodePtr node = torch::lazy::MakeNode( + reduce_type, input_values, token, scale, scatter_dim, shard_count, + std::move(groups), pin_layout); + std::vector result; + for (size_t i = 0; i < inputs.size(); ++i) { + result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i))); + if (!outputs.empty()) { + outputs[i]->SetIrValue(torch::lazy::Value(node, i)); + } + } + return {result, torch::lazy::Value(node, inputs.size())}; +} + std::pair all_to_all( const XLATensorPtr& input, const torch::lazy::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 47ba3d36799..aeabc96604f 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -33,6 +33,15 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output, std::vector> groups, bool pin_layout); +std::pair, torch::lazy::Value> +reduce_scatter_coalesced(const std::vector& outputs, + const std::vector& inputs, + const torch::lazy::Value& token, + AllReduceType reduce_type, double scale, + int64_t scatter_dim, int64_t shard_count, + std::vector> groups, + bool pin_layout); + std::pair all_to_all( const XLATensorPtr& input, const torch::lazy::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 75b909cb848..b4a1ff6e142 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -122,6 +122,33 @@ def reduce_scatter(self, output_tensors, input_tensors_list, opts): return _ret_work(output_tensors) + def reduce_scatter_coalesced(self, output_tensors, input_tensors_list, opts): + input_tensor_list = [] + for input_tensors in input_tensors_list: + # Ensure all inputs have the same shape. + first_shape = input_tensors[0].shape + for i, t in enumerate(input_tensors[1:]): + if first_shape != t.shape: + raise ValueError(f"Input {i+1}'s shape is different from input 0: " + f"{t.shape} vs {first_shape}") + input_tensor = torch.cat(input_tensors) + input_tensor_list.append(input_tensor) + + reduce_type = self._get_reduce_type(opts.reduceOp) + groups = self._mesh + shard_count = len(groups[0]) if groups else self.size() + xm.reduce_scatter( + reduce_type, + input_tensor_list, + scatter_dim=0, + shard_count=shard_count, + scale=1, + groups=groups, + output=output_tensors, + pin_layout=False) + + return _ret_work(output_tensors) + # Call site: # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683 def barrier(self, opts):