diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py index 47d0ae72139..cefcd687d10 100644 --- a/test/test_torch_distributed_xla_backend.py +++ b/test/test_torch_distributed_xla_backend.py @@ -136,27 +136,6 @@ def test_reduce_scatter(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([output]) hlo_matches(hlo, reduce_scatter_pattern) - @skipIf(xr.device_type() == 'CPU', - "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.") - def test_reduce_scatter_coalesced(self): - device = xm.xla_device() - tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank() - tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank() - input_tensors_list = [[tensor, tensor], [tensor2, tensor2]] - output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)] - pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0])) - opts = dist.ReduceScatterOptions() - opts.reduceOp = dist.ReduceOp.SUM - reduce_scatter_pattern = ( - r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) ' - r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, ' - r's64\[] %.+\.\d+\)') - pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts) - hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list) - hlo_matches(hlo, reduce_scatter_pattern) - # purge all computations attached the device. - xm.mark_step() - @patch_world(0, 6) def test_send(self): device = xm.xla_device() diff --git a/test/test_zero1.py b/test/test_zero1.py index a4bdb55876c..e9c3a3eeee6 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -12,7 +12,6 @@ class XlaZeRO1Test(TestCase): - @unittest.skipIf(xr.device_type() == 'CPU', "Crash on CPU") @unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU") @unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'), "TODO(alanwaketan): Fix it for the token change.") @@ -27,13 +26,11 @@ def test_zero1(self): y.backward() opt1 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - # pin_layout=False to workaround "Check failed: has_layout() element_type" opt2 = ZeroRedundancyOptimizer( model.parameters(), torch.optim.SGD, lr=0.01, momentum=0.9, - pin_layout=False, grad_clipping=False) opt1.step() diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index f66c99d1c43..63186bcb06f 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -493,8 +493,6 @@ def _all_gather_using_all_reduce(value, dim=0, groups=None, pin_layout=True): Args: value (torch.Tensor): The input tensor. - value (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then - it will also be the output. dim (int): The gather dimension. Default: 0 groups (list, optional): A list of list, representing the replica groups for @@ -741,19 +739,16 @@ def reduce_scatter(reduce_type, reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``, ``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and ``xm.REDUCE_MAX``. - input: (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then - it will also be the output. + input: A single `torch.Tensor` all reduce + scatter op to. scale (float): A default scaling value to be applied after the reduce. scatter_dim (int): Dimension number to which apply scatter operation. shard_count (int): The number of ways to split up the scatter_dim in. groups (list): A list of list, representing the replica groups for - the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]` + the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]` defines two groups, one with the `[0, 1, 2, 3]` replicas and one with the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with all the replicas in it. output: Optional output tensor - output: Optional output tensor if `input` is a torch.Tensor or a list of - torch.Tensor if `input` is a list of torch.Tensor. pin_layout (bool, optional): whether to pin the layout for this communication op. Layout pining can prevent potential data corruption when each process that participate in the communication has slightly different program, but it might @@ -766,45 +761,21 @@ def reduce_scatter(reduce_type, the same as the input. """ token, devctx = _get_all_reduce_token() - - if isinstance(input, torch.Tensor): - if output != None: - # Call the out of place version of the reduce_scatter - new_token = torch_xla._XLAC._xla_reduce_scatter_out( - reduce_type, output, input, token, scale, scatter_dim, shard_count, - groups or [], pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) - return output - - result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, - scale, scatter_dim, - shard_count, groups or [], - pin_layout) - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1]) - return result[0] - - # Now the input should be a list of Tensors. - if not isinstance(input, list) or any( - not isinstance(v, torch.Tensor) for v in input): - raise TypeError("`input` needs to be a Tensor or a list of Tensors, but " - f"given {type(input)}.") - if output != None: - if not isinstance(output, list) or any( - not isinstance(v, torch.Tensor) for v in output): - raise TypeError( - f"`output` needs to be a list of Tensors, but given {type(output)}." - ) - if len(output) != len(input): - raise ValueError("`output` length doesn't match `input` length: " - f"{len(output)} vs {len(input)}.") - - result = torch_xla._XLAC._xla_reduce_scatter_coalesced( - reduce_type, output or [], input, token, scale, scatter_dim, shard_count, - groups or [], pin_layout) - #torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1]) - #return result[0] - torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1]) - return result[:-1] + if output != None: + # Call the out of place version of the reduce_scatter + new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output, + input, token, scale, + scatter_dim, + shard_count, groups or + [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token) + return output + + result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale, + scatter_dim, shard_count, + groups or [], pin_layout) + torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1]) + return result[0] def add_step_closure(closure, args=(), run_async=False): diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index f70c8dbccf6..9512cf1512a 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -286,49 +286,38 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape, } ReduceScatterResult BuildReduceScatter( - AllReduceType reduce_type, absl::Span inputs, - xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, + AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale, + int64_t scatter_dim, int64_t shard_count, const std::vector>& groups, bool pin_layout) { - std::vector cc_groups = CreateReduceGroups(groups); + std::vector reduce_groups = CreateReduceGroups(groups); TokenHandler token_handler(token); - // TODO: We use pseudo-tokens ATM, which are real values. This need to be - // switched to use the real XLA Token once support has been added to XLA - // ReduceScatter(). - ReduceContext cc_ctx = GetReduceContext(inputs); - std::vector result(inputs.size()); - for (auto& type_ctx : cc_ctx.contexts) { - xla::XlaOp reduce_result; - if (pin_layout) { - reduce_result = xla::ReduceScatter( - xla::Tuple(inputs[0].builder(), type_ctx.second.ops), - GetReduceComutation(reduce_type, type_ctx.first), scatter_dim, - shard_count, cc_groups, /*channel_id=*/absl::nullopt, - /*layout=*/ - MakeReduceShape(type_ctx.second.operand_shapes).layout()); - } else { - reduce_result = xla::ReduceScatter( - xla::Tuple(inputs[0].builder(), type_ctx.second.ops), - GetReduceComutation(reduce_type, type_ctx.first), scatter_dim, - shard_count, cc_groups); - } - for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) { - size_t op_idx = type_ctx.second.indices[i]; - xla::XlaOp gte; - if (ShapeHelper::ShapeOfXlaOp(reduce_result).rank() == 0) { - gte = xla::GetTupleElement(reduce_result, i); - } else { - gte = reduce_result; - } - if (scale != 1.0) { - xla::XlaOp scaling_value = XlaHelpers::ScalarValue( - scale, type_ctx.second.operand_shapes[i].element_type(), - gte.builder()); - gte = gte * scaling_value; - } - result[op_idx] = gte; - } + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); + xla::XlaOp reduce_result; + if (pin_layout) { + torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); + xla::Shape reduce_shape = MakeArrayShapeFromDimensions( + input_shape.dimensions(), input_shape.dynamic_dimensions(), + input_shape.element_type(), + static_cast(xla_device.type())); + reduce_result = xla::ReduceScatter( + token_handler.GetInput(input, &input_shape), + GetReduceComutation(reduce_type, input_shape.element_type()), + scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt, + /*layout=*/reduce_shape.layout()); + } else { + reduce_result = xla::ReduceScatter( + token_handler.GetInput(input, &input_shape), + GetReduceComutation(reduce_type, input_shape.element_type()), + scatter_dim, shard_count, reduce_groups); } - return {result, token_handler.GetNewToken(result[0])}; + + if (scale != 1.0) { + xla::XlaOp scaling_value = XlaHelpers::ScalarValue( + scale, input_shape.element_type(), input.builder()); + reduce_result = reduce_result * scaling_value; + } + + return {reduce_result, token_handler.GetNewToken(reduce_result)}; } // moved from torch_xla/csrc/ops/all_reduce.cpp diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h index d0692080b35..0445892e457 100644 --- a/torch_xla/csrc/cross_replica_reduces.h +++ b/torch_xla/csrc/cross_replica_reduces.h @@ -46,7 +46,7 @@ struct RecvResult { }; struct ReduceScatterResult { - std::vector result; + xla::XlaOp result; xla::XlaOp token; }; @@ -77,8 +77,8 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape, int64_t channel_id); ReduceScatterResult BuildReduceScatter( - AllReduceType reduce_type, absl::Span inputs, - xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count, + AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale, + int64_t scatter_dim, int64_t shard_count, const std::vector>& groups, bool pin_layout); std::vector GetOperandList( diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 5bb773f2618..df4fd679746 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -214,29 +214,6 @@ std::pair> ReduceScatter( std::make_shared(new_token)); } -std::pair, std::shared_ptr> -ReduceScatterCoalesced(const std::string& reduce_type, - const std::vector& outputs, - const std::vector& inputs, - const std::shared_ptr& token, - double scale, int64_t scatter_dim, int64_t shard_count, - const std::vector>& replica_groups, - bool pin_layout) { - std::vector xtensors_out = - GetXlaTensors(outputs, /*want_all=*/true); - std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); - std::vector result; - torch::lazy::Value new_token; - std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced( - xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale, - scatter_dim, shard_count, replica_groups, pin_layout); - std::vector aten_result; - for (auto& xt : result) { - aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt))); - } - return {aten_result, std::make_shared(new_token)}; -} - std::shared_ptr ReduceScatterOut( const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, const std::shared_ptr& token, double scale, @@ -1286,30 +1263,6 @@ void InitXlaModuleBindings(py::module m) { result_tuple[1] = new_token; return result_tuple; }); - m.def("_xla_reduce_scatter_coalesced", - [](const std::string& reduce_type, std::vector& outputs, - const std::vector& inputs, - const std::shared_ptr& token, double scale, - int64_t scatter_dim, int64_t shard_count, const py::list& groups, - bool pin_layout) { - std::vector> replica_groups = - CreateReduceGroups(groups); - std::vector result; - std::shared_ptr new_token; - { - NoGilSection nogil; - std::tie(result, new_token) = ReduceScatterCoalesced( - reduce_type, outputs, inputs, token, scale, scatter_dim, - shard_count, replica_groups, pin_layout); - } - auto result_list = py::list(result.size() + 1); - for (int i = 0; i < result.size(); ++i) { - result_list[i] = torch::autograd::make_variable( - result[i], /*requires_grad=*/result[i].requires_grad()); - } - result_list[result.size()] = new_token; - return result_list; - }); m.def("_xla_reduce_scatter_out", [](const std::string& reduce_type, at::Tensor& output, const at::Tensor& input, diff --git a/torch_xla/csrc/ops/reduce_scatter.cpp b/torch_xla/csrc/ops/reduce_scatter.cpp index a1f2c912457..91c0f5d66e2 100644 --- a/torch_xla/csrc/ops/reduce_scatter.cpp +++ b/torch_xla/csrc/ops/reduce_scatter.cpp @@ -12,45 +12,37 @@ namespace torch_xla { namespace { xla::Shape NodeOutputShape(AllReduceType reduce_type, - c10::ArrayRef inputs, + const torch::lazy::Value input, const torch::lazy::Value& token, double scale, int64_t scatter_dim, int64_t shard_count, const std::vector>& groups, bool pin_layout) { auto shape_fn = [&](absl::Span operands) -> xla::XlaOp { - ReduceScatterResult result = BuildReduceScatter( - reduce_type, operands.subspan(0, operands.size() - 1), operands.back(), - scale, scatter_dim, shard_count, groups, pin_layout); - std::vector outputs; - for (size_t i = 0; i < result.result.size(); ++i) { - outputs.emplace_back(result.result[i]); - } - outputs.emplace_back(result.token); - return xla::Tuple(operands[0].builder(), outputs); + xla::XlaOp inputOp = operands[0]; + xla::XlaOp tokenOp = operands[1]; + ReduceScatterResult result = + BuildReduceScatter(reduce_type, inputOp, tokenOp, scale, scatter_dim, + shard_count, groups, pin_layout); + return xla::Tuple(operands[0].builder(), {result.result, result.token}); }; - std::vector input_shapes; - for (const auto& input : inputs) { - input_shapes.emplace_back(GetXlaShape(input)); - } - input_shapes.emplace_back(GetXlaShape(token)); - return InferOutputShape(input_shapes, shape_fn); + return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn); } } // namespace ReduceScatter::ReduceScatter(AllReduceType reduce_type, - c10::ArrayRef inputs, + const torch::lazy::Value& input, const torch::lazy::Value& token, double scale, int64_t scatter_dim, int64_t shard_count, std::vector> groups, bool pin_layout) - : XlaNode(xla_reduce_scatter, GetOperandList(inputs, token), + : XlaNode(xla_reduce_scatter, {input, token}, [&]() { - return NodeOutputShape(reduce_type, inputs, token, scale, + return NodeOutputShape(reduce_type, input, token, scale, scatter_dim, shard_count, groups, pin_layout); }, - /*num_outputs=*/inputs.size() + 1, + /*num_outputs=*/2, torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale, scatter_dim, shard_count, groups, pin_layout)), reduce_type_(reduce_type), @@ -61,25 +53,18 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type, pin_layout_(pin_layout) {} torch::lazy::NodePtr ReduceScatter::Clone(torch::lazy::OpList operands) const { - std::vector inputs(operands.begin(), operands.end() - 1); return torch::lazy::MakeNode( - reduce_type_, inputs, operands.back(), scale_, scatter_dim_, shard_count_, - groups_, pin_layout_); + reduce_type_, operands.at(0), operands.at(1), scale_, scatter_dim_, + shard_count_, groups_, pin_layout_); } XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const { - auto& operand_list = operands(); - std::vector inputs; - inputs.reserve(operand_list.size()); - for (size_t i = 0; i + 1 < operand_list.size(); ++i) { - inputs.push_back(loctx->GetOutputOp(operand_list[i])); - } - xla::XlaOp token = loctx->GetOutputOp(operand_list.back()); + xla::XlaOp input = loctx->GetOutputOp(operand(0)); + xla::XlaOp token = loctx->GetOutputOp(operand(1)); ReduceScatterResult result = - BuildReduceScatter(reduce_type_, inputs, token, scale_, scatter_dim_, + BuildReduceScatter(reduce_type_, input, token, scale_, scatter_dim_, shard_count_, groups_, pin_layout_); - result.result.push_back(result.token); - return ReturnOps(result.result, loctx); + return ReturnOps({result.result, result.token}, loctx); } std::string ReduceScatter::ToString() const { diff --git a/torch_xla/csrc/ops/reduce_scatter.h b/torch_xla/csrc/ops/reduce_scatter.h index 8e4a9e97275..0c888ce0fde 100644 --- a/torch_xla/csrc/ops/reduce_scatter.h +++ b/torch_xla/csrc/ops/reduce_scatter.h @@ -8,8 +8,7 @@ namespace torch_xla { class ReduceScatter : public XlaNode { public: - ReduceScatter(AllReduceType reduce_type, - c10::ArrayRef inputs, + ReduceScatter(AllReduceType reduce_type, const torch::lazy::Value& input, const torch::lazy::Value& token, double scale, int64_t scatter_dim, int64_t shard_count, std::vector> groups, bool pin_layout); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 3f233d1c2fa..6522dad9024 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -390,33 +390,6 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output, return torch::lazy::Value(node, 1); } -std::pair, torch::lazy::Value> -reduce_scatter_coalesced(const std::vector& outputs, - const std::vector& inputs, - const torch::lazy::Value& token, - AllReduceType reduce_type, double scale, - int64_t scatter_dim, int64_t shard_count, - std::vector> groups, - bool pin_layout) { - XLA_CHECK(outputs.empty() || outputs.size() == inputs.size()); - std::vector input_values; - input_values.reserve(inputs.size()); - for (auto& input : inputs) { - input_values.push_back(input->GetIrValue()); - } - torch::lazy::NodePtr node = torch::lazy::MakeNode( - reduce_type, input_values, token, scale, scatter_dim, shard_count, - std::move(groups), pin_layout); - std::vector result; - for (size_t i = 0; i < inputs.size(); ++i) { - result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i))); - if (!outputs.empty()) { - outputs[i]->SetIrValue(torch::lazy::Value(node, i)); - } - } - return {result, torch::lazy::Value(node, inputs.size())}; -} - std::pair all_to_all( const XLATensorPtr& input, const torch::lazy::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 14150bb91d2..800320bec78 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -33,15 +33,6 @@ torch::lazy::Value reduce_scatter_out(XLATensorPtr& output, std::vector> groups, bool pin_layout); -std::pair, torch::lazy::Value> -reduce_scatter_coalesced(const std::vector& outputs, - const std::vector& inputs, - const torch::lazy::Value& token, - AllReduceType reduce_type, double scale, - int64_t scatter_dim, int64_t shard_count, - std::vector> groups, - bool pin_layout); - std::pair all_to_all( const XLATensorPtr& input, const torch::lazy::Value& token, int64_t split_dimension, int64_t concat_dimension, int64_t split_count, diff --git a/torch_xla/distributed/fsdp/utils.py b/torch_xla/distributed/fsdp/utils.py index 2a80e39fd39..ee79bc4a966 100644 --- a/torch_xla/distributed/fsdp/utils.py +++ b/torch_xla/distributed/fsdp/utils.py @@ -59,66 +59,21 @@ def dummy_all_reduce(reduce_type, inputs, scale=1.0, groups=None): return [t.mul_(scale) for t in inputs] -class DummyReduceScatter: +def dummy_reduce_scatter(reduce_type, + input, + scale, + scatter_dim, + shard_count, + groups=None): """A dummy op for debugging with the same output shape as reduce_scatter""" - - def __init__(self, shard_count): - assert shard_count == xm.xrt_world_size() - self.scale = 1.0 - - def __call__(self, input, callback): - full_size = input.size(0) - shard_size = full_size // xm.xrt_world_size() - begin = shard_size * xm.get_ordinal() - end = begin + shard_size - slices = [None] * input.dim() - slices[0] = slice(begin, end) - callback(input[tuple(slices)]) - - def flush(self): - pass - - -class BucketizedReduceScatter: - """A reduce_scatter op that group input tensors before reduce-scattering them.""" - - def __init__(self, bucket_size_mb, shard_count, groups, pin_layout) -> None: - self.bucket_size_bytes = bucket_size_mb * 1024 * 1024 - self.shard_count = shard_count - self.groups = groups - self.pin_layout = pin_layout - self.scale = 1.0 - - self.callbacks = [] - self.bucket = [] - self.bucket_watermark = 0 - - def __call__(self, input, callback): - input_byte_size = input.element_size() * input.numel() - self.bucket.append(input) - self.callbacks.append(callback) - self.bucket_watermark += input_byte_size - if self.bucket_watermark > self.bucket_size_bytes: - self.flush() - - def flush(self): - if not self.bucket: - return - - results = xm.reduce_scatter( - xm.REDUCE_SUM, - self.bucket, - scale=self.scale, - scatter_dim=0, - shard_count=self.shard_count, - groups=self.groups, - pin_layout=self.pin_layout) - for cb, result in zip(self.callbacks, results): - cb(result) - - self.bucket.clear() - self.callbacks.clear() - self.bucket_watermark = 0 + assert shard_count == xm.xrt_world_size() + full_size = input.size(scatter_dim) + shard_size = full_size // xm.xrt_world_size() + begin = shard_size * xm.get_ordinal() + end = begin + shard_size + slices = [None] * input.dim() + slices[scatter_dim] = slice(begin, end) + return input[tuple(slices)] * scale class XLAPatchedLinear(torch.autograd.Function): diff --git a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py index 506804316cd..dae259b6fb7 100644 --- a/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py +++ b/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py @@ -35,14 +35,7 @@ import torch_xla.core.xla_model as xm from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper -from .utils import ( - BucketizedReduceScatter, - DummyReduceScatter, - dummy_all_gather, - dummy_all_reduce, - apply_xla_patch_to_nn_linear, -) - +from .utils import dummy_all_gather, dummy_all_reduce, dummy_reduce_scatter, apply_xla_patch_to_nn_linear from .wrap import recursive_wrap from ._init_utils import _materialize_module @@ -303,7 +296,6 @@ def __init__( shard_param_on_dim_0: bool = False, pin_layout_in_collective_ops: bool = True, coalesce_all_gather_ops: bool = False, - reduce_scatter_bucket_size_mb: Optional[int] = 0, auto_wrap_policy: Optional[Callable] = None, auto_wrapper_callable: Optional[Callable] = None, param_init_fn: Optional[Callable[[nn.Module], None]] = None, @@ -421,13 +413,10 @@ def __init__( self.all_reduce_op = functools.partial( xm.all_reduce, pin_layout=pin_layout_in_collective_ops) if _debug_dummy_reduce_scatter_op: - self.reduce_scatter_op = DummyReduceScatter(shard_count=self.world_size) + self.reduce_scatter_op = dummy_reduce_scatter else: - self.reduce_scatter_op = BucketizedReduceScatter( - reduce_scatter_bucket_size_mb, - shard_count=self.world_size, - groups=self.sharding_groups, - pin_layout=pin_layout_in_collective_ops) + self.reduce_scatter_op = functools.partial( + xm.reduce_scatter, pin_layout=pin_layout_in_collective_ops) if _debug_dummy_optimization_barrier_op: self.optimization_barrier_op = lambda *args: None else: @@ -565,10 +554,6 @@ def set_gradient_divide_factors(self, pre: float, post: float, module.set_gradient_divide_factors(pre, post, False) self.gradient_predivide_factor = pre self.gradient_postdivide_factor = post - if (pre, post) == (1, 1): - self.reduce_scatter_op.scale = 1.0 / self.world_size - else: - self.reduce_scatter_op.scale = 1.0 @property def module(self) -> XlaFlattenParamsWrapper: @@ -1159,7 +1144,6 @@ def _register_post_backward_hooks(self) -> None: """ if not torch.is_grad_enabled(): return # don't register grad hooks if grad isn't enabled - self._post_backward_hooks_to_call = 0 for p in self.full_params: if p.requires_grad: if hasattr(p, "_shard_bwd_hook"): @@ -1173,7 +1157,6 @@ def _register_post_backward_hooks(self) -> None: handle = grad_acc.register_hook( functools.partial(self._post_backward_hook, p)) p._shard_bwd_hook = (grad_acc, handle) - self._post_backward_hooks_to_call += 1 @torch.no_grad() def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: @@ -1200,10 +1183,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # then subsequent hook callbacks will see POST state. self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST]) self.training_state = TrainingState.BACKWARD_POST - self._post_backward_hooks_to_call -= 1 if param.grad is None: - if self._post_backward_hooks_to_call == 0: - self.reduce_scatter_op.flush() return assert param.grad is not None, param.shape @@ -1224,8 +1204,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: apply_opt_barrier=self.optimization_barrier_in_backward) if not self._require_backward_grad_sync: - if self._post_backward_hooks_to_call == 0: - self.reduce_scatter_op.flush() return if self.gradient_predivide_factor > 1: @@ -1241,37 +1219,38 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: self.optimization_barrier_op([grad_flat]) if grad_flat.dtype != torch.float32 and self.fp32_reduce_scatter: grad_flat = grad_flat.to(torch.float32) - - def reduce_scatter_done(reduced_grad): - if reduced_grad.dtype != torch.float32: - reduced_grad = reduced_grad.to(torch.float32) - if self.optimization_barrier_in_backward: - self.optimization_barrier_op([reduced_grad]) - if self.gradient_postdivide_factor > 1: - # Average grad by world_size for consistency with PyTorch DDP. - reduced_grad.data.div_(self.gradient_postdivide_factor) - - grad._has_full_param = True - grad_flat._has_full_param = True - self._free_full_params( - [grad, grad_flat], - dependency_tensors=[reduced_grad], - apply_opt_barrier=self.optimization_barrier_in_backward) - self._try_adding_to_backward_opt_barrier_lists(reduced_grad) - - # Accumulate into the gradient shard. - assert hasattr(param, "_sharded_param") - p_shard = param._sharded_param - if p_shard.grad is None: - p_shard.grad = reduced_grad.data - else: - assert p_shard.grad.shape == reduced_grad.shape - assert p_shard.grad.device == reduced_grad.device - p_shard.grad.data += reduced_grad.data - - self.reduce_scatter_op(grad_flat.detach(), reduce_scatter_done) - if self._post_backward_hooks_to_call == 0: - self.reduce_scatter_op.flush() + reduced_grad = self.reduce_scatter_op( + xm.REDUCE_SUM, + grad_flat.detach(), + scale=1.0, + scatter_dim=0, + shard_count=self.world_size, + groups=self.sharding_groups) + if reduced_grad.dtype != torch.float32: + reduced_grad = reduced_grad.to(torch.float32) + if self.optimization_barrier_in_backward: + self.optimization_barrier_op([reduced_grad]) + if self.gradient_postdivide_factor > 1: + # Average grad by world_size for consistency with PyTorch DDP. + reduced_grad.div_(self.gradient_postdivide_factor) + + grad._has_full_param = True + grad_flat._has_full_param = True + self._free_full_params( + [grad, grad_flat], + dependency_tensors=[reduced_grad], + apply_opt_barrier=self.optimization_barrier_in_backward) + self._try_adding_to_backward_opt_barrier_lists(reduced_grad) + + # Accumulate into the gradient shard. + assert hasattr(param, "_sharded_param") + p_shard = param._sharded_param + if p_shard.grad is None: + p_shard.grad = reduced_grad + else: + assert p_shard.grad.shape == reduced_grad.shape + assert p_shard.grad.device == reduced_grad.device + p_shard.grad += reduced_grad def _queue_wait_for_post_backward(self) -> None: """ diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 9e980997f8d..0f5d6b1164d 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -125,33 +125,6 @@ def reduce_scatter(self, output_tensors, input_tensors_list, opts): return _ret_work(output_tensors) - def reduce_scatter_coalesced(self, output_tensors, input_tensors_list, opts): - input_tensor_list = [] - for input_tensors in input_tensors_list: - # Ensure all inputs have the same shape. - first_shape = input_tensors[0].shape - for i, t in enumerate(input_tensors[1:]): - if first_shape != t.shape: - raise ValueError(f"Input {i+1}'s shape is different from input 0: " - f"{t.shape} vs {first_shape}") - input_tensor = torch.cat(input_tensors) - input_tensor_list.append(input_tensor) - - reduce_type = self._get_reduce_type(opts.reduceOp) - groups = self._mesh - shard_count = len(groups[0]) if groups else self.size() - xm.reduce_scatter( - reduce_type, - input_tensor_list, - scatter_dim=0, - shard_count=shard_count, - scale=1, - groups=groups, - output=output_tensors, - pin_layout=False) - - return _ret_work(output_tensors) - # Call site: # https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L2683 def barrier(self, opts):