Separate out the reduce-scatter-coalesce changes into a separate PR
jeffhataws committed Nov 29, 2023
1 parent f9f3a71 commit 5843b43
Showing 13 changed files with 118 additions and 374 deletions.
21 changes: 0 additions & 21 deletions test/test_torch_distributed_xla_backend.py
@@ -136,27 +136,6 @@ def test_reduce_scatter(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
hlo_matches(hlo, reduce_scatter_pattern)

@skipIf(xr.device_type() == 'CPU',
"UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
def test_reduce_scatter_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
input_tensors_list = [[tensor, tensor], [tensor2, tensor2]]
output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)]
pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0]))
opts = dist.ReduceScatterOptions()
opts.reduceOp = dist.ReduceOp.SUM
reduce_scatter_pattern = (
r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, '
r's64\[] %.+\.\d+\)')
pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
hlo_matches(hlo, reduce_scatter_pattern)
# purge all computations attached the device.
xm.mark_step()

@patch_world(0, 6)
def test_send(self):
device = xm.xla_device()
3 changes: 0 additions & 3 deletions test/test_zero1.py
@@ -12,7 +12,6 @@

class XlaZeRO1Test(TestCase):

@unittest.skipIf(xr.device_type() == 'CPU', "Crash on CPU")
@unittest.skipIf(xr.device_type() == 'TPU', "Crash on TPU")
@unittest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"TODO(alanwaketan): Fix it for the token change.")
@@ -27,13 +26,11 @@ def test_zero1(self):
y.backward()

opt1 = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# pin_layout=False to workaround "Check failed: has_layout() element_type"
opt2 = ZeroRedundancyOptimizer(
model.parameters(),
torch.optim.SGD,
lr=0.01,
momentum=0.9,
pin_layout=False,
grad_clipping=False)

opt1.step()
63 changes: 17 additions & 46 deletions torch_xla/core/xla_model.py
@@ -493,8 +493,6 @@ def _all_gather_using_all_reduce(value, dim=0, groups=None, pin_layout=True):
Args:
value (torch.Tensor): The input tensor.
value (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
dim (int): The gather dimension.
Default: 0
groups (list, optional): A list of list, representing the replica groups for
@@ -741,19 +739,16 @@ def reduce_scatter(reduce_type,
reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
``xm.REDUCE_MAX``.
input: (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
input: A single `torch.Tensor` to apply the reduce + scatter op to.
scale (float): A default scaling value to be applied after the reduce.
scatter_dim (int): Dimension number to which apply scatter operation.
shard_count (int): The number of ways to split up the scatter_dim in.
groups (list): A list of list, representing the replica groups for
the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
output: Optional output tensor
output: Optional output tensor if `input` is a torch.Tensor or a list of
torch.Tensor if `input` is a list of torch.Tensor.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication runs a slightly different program, but it might
@@ -766,45 +761,21 @@
the same as the input.
"""
token, devctx = _get_all_reduce_token()

if isinstance(input, torch.Tensor):
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(
reduce_type, output, input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
scale, scatter_dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]

# Now the input should be a list of Tensors.
if not isinstance(input, list) or any(
not isinstance(v, torch.Tensor) for v in input):
raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
f"given {type(input)}.")
if output != None:
if not isinstance(output, list) or any(
not isinstance(v, torch.Tensor) for v in output):
raise TypeError(
f"`output` needs to be a list of Tensors, but given {type(output)}."
)
if len(output) != len(input):
raise ValueError("`output` length doesn't match `input` length: "
f"{len(output)} vs {len(input)}.")

result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, output or [], input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
#torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
#return result[0]
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output,
input, token, scale,
scatter_dim,
shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]


def add_step_closure(closure, args=(), run_async=False):
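For readers skimming the diff, here is a minimal usage sketch of the single-tensor `xm.reduce_scatter` path this commit keeps (the coalesced list-of-tensors variant is what moves to the separate PR). It is not part of the change: the replica count, tensor shapes, and the assumption that it runs inside a multi-replica launch (e.g. `xmp.spawn`) are all illustrative.

# Hypothetical sketch, not from this commit: reduce-scatter one tensor
# across all replicas, summing and splitting along dim 0.
import torch
import torch_xla.core.xla_model as xm

def _mp_fn(index):
  device = xm.xla_device()
  shard_count = xm.xrt_world_size()  # number of participating replicas
  full = torch.arange(4 * shard_count, dtype=torch.float32, device=device)

  # Plain form: each replica receives a shard of size 4 holding the sum.
  shard = xm.reduce_scatter(
      xm.REDUCE_SUM, full, scale=1.0, scatter_dim=0, shard_count=shard_count)

  # `output=` form: write the result into a preallocated buffer instead.
  out = torch.zeros(4, dtype=torch.float32, device=device)
  xm.reduce_scatter(
      xm.REDUCE_SUM, full, scale=1.0, scatter_dim=0,
      shard_count=shard_count, output=out)

  xm.mark_step()  # materialize the pending collectives

# Typically launched once per replica, e.g. via
# torch_xla.distributed.xla_multiprocessing.spawn(_mp_fn).

The list-of-tensors (coalesced) overload exercised by the removed test above is the piece being split into the follow-up PR.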
69 changes: 29 additions & 40 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -286,49 +286,38 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
}

ReduceScatterResult BuildReduceScatter(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
TokenHandler token_handler(token);
// TODO: We use pseudo-tokens ATM, which are real values. This needs to be
// switched to use the real XLA Token once support has been added to XLA
// ReduceScatter().
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());
for (auto& type_ctx : cc_ctx.contexts) {
xla::XlaOp reduce_result;
if (pin_layout) {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups, /*channel_id=*/absl::nullopt,
/*layout=*/
MakeReduceShape(type_ctx.second.operand_shapes).layout());
} else {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups);
}
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
xla::XlaOp gte;
if (ShapeHelper::ShapeOfXlaOp(reduce_result).rank() == 0) {
gte = xla::GetTupleElement(reduce_result, i);
} else {
gte = reduce_result;
}
if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, type_ctx.second.operand_shapes[i].element_type(),
gte.builder());
gte = gte * scaling_value;
}
result[op_idx] = gte;
}
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
xla::XlaOp reduce_result;
if (pin_layout) {
torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
input_shape.dimensions(), input_shape.dynamic_dimensions(),
input_shape.element_type(),
static_cast<XlaDeviceType>(xla_device.type()));
reduce_result = xla::ReduceScatter(
token_handler.GetInput(input, &input_shape),
GetReduceComutation(reduce_type, input_shape.element_type()),
scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
/*layout=*/reduce_shape.layout());
} else {
reduce_result = xla::ReduceScatter(
token_handler.GetInput(input, &input_shape),
GetReduceComutation(reduce_type, input_shape.element_type()),
scatter_dim, shard_count, reduce_groups);
}
return {result, token_handler.GetNewToken(result[0])};

if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, input_shape.element_type(), input.builder());
reduce_result = reduce_result * scaling_value;
}

return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

// moved from torch_xla/csrc/ops/all_reduce.cpp
6 changes: 3 additions & 3 deletions torch_xla/csrc/cross_replica_reduces.h
@@ -46,7 +46,7 @@ struct RecvResult {
};

struct ReduceScatterResult {
std::vector<xla::XlaOp> result;
xla::XlaOp result;
xla::XlaOp token;
};

@@ -77,8 +77,8 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
int64_t channel_id);

ReduceScatterResult BuildReduceScatter(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout);

std::vector<torch::lazy::Value> GetOperandList(
47 changes: 0 additions & 47 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -214,29 +214,6 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> ReduceScatter(
std::make_shared<torch::lazy::Value>(new_token));
}

std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
ReduceScatterCoalesced(const std::string& reduce_type,
const std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token,
double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups,
bool pin_layout) {
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
std::vector<XLATensorPtr> result;
torch::lazy::Value new_token;
std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced(
xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
scatter_dim, shard_count, replica_groups, pin_layout);
std::vector<at::Tensor> aten_result;
for (auto& xt : result) {
aten_result.emplace_back(bridge::AtenFromXlaTensor(std::move(xt)));
}
return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> ReduceScatterOut(
const std::string& reduce_type, at::Tensor& output, const at::Tensor& input,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
@@ -1286,30 +1263,6 @@ void InitXlaModuleBindings(py::module m) {
result_tuple[1] = new_token;
return result_tuple;
});
m.def("_xla_reduce_scatter_coalesced",
[](const std::string& reduce_type, std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, double scale,
int64_t scatter_dim, int64_t shard_count, const py::list& groups,
bool pin_layout) {
std::vector<std::vector<int64_t>> replica_groups =
CreateReduceGroups(groups);
std::vector<at::Tensor> result;
std::shared_ptr<torch::lazy::Value> new_token;
{
NoGilSection nogil;
std::tie(result, new_token) = ReduceScatterCoalesced(
reduce_type, outputs, inputs, token, scale, scatter_dim,
shard_count, replica_groups, pin_layout);
}
auto result_list = py::list(result.size() + 1);
for (int i = 0; i < result.size(); ++i) {
result_list[i] = torch::autograd::make_variable(
result[i], /*requires_grad=*/result[i].requires_grad());
}
result_list[result.size()] = new_token;
return result_list;
});
m.def("_xla_reduce_scatter_out",
[](const std::string& reduce_type, at::Tensor& output,
const at::Tensor& input,
51 changes: 18 additions & 33 deletions torch_xla/csrc/ops/reduce_scatter.cpp
@@ -12,45 +12,37 @@ namespace torch_xla {
namespace {

xla::Shape NodeOutputShape(AllReduceType reduce_type,
c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value input,
const torch::lazy::Value& token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups,
bool pin_layout) {
auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
ReduceScatterResult result = BuildReduceScatter(
reduce_type, operands.subspan(0, operands.size() - 1), operands.back(),
scale, scatter_dim, shard_count, groups, pin_layout);
std::vector<xla::XlaOp> outputs;
for (size_t i = 0; i < result.result.size(); ++i) {
outputs.emplace_back(result.result[i]);
}
outputs.emplace_back(result.token);
return xla::Tuple(operands[0].builder(), outputs);
xla::XlaOp inputOp = operands[0];
xla::XlaOp tokenOp = operands[1];
ReduceScatterResult result =
BuildReduceScatter(reduce_type, inputOp, tokenOp, scale, scatter_dim,
shard_count, groups, pin_layout);
return xla::Tuple(operands[0].builder(), {result.result, result.token});
};
std::vector<xla::Shape> input_shapes;
for (const auto& input : inputs) {
input_shapes.emplace_back(GetXlaShape(input));
}
input_shapes.emplace_back(GetXlaShape(token));
return InferOutputShape(input_shapes, shape_fn);
return InferOutputShape({GetXlaShape(input), GetXlaShape(token)}, shape_fn);
}

} // namespace

ReduceScatter::ReduceScatter(AllReduceType reduce_type,
c10::ArrayRef<torch::lazy::Value> inputs,
const torch::lazy::Value& input,
const torch::lazy::Value& token, double scale,
int64_t scatter_dim, int64_t shard_count,
std::vector<std::vector<int64_t>> groups,
bool pin_layout)
: XlaNode(xla_reduce_scatter, GetOperandList(inputs, token),
: XlaNode(xla_reduce_scatter, {input, token},
[&]() {
return NodeOutputShape(reduce_type, inputs, token, scale,
return NodeOutputShape(reduce_type, input, token, scale,
scatter_dim, shard_count, groups,
pin_layout);
},
/*num_outputs=*/inputs.size() + 1,
/*num_outputs=*/2,
torch::lazy::MHash(torch::lazy::GetEnumValue(reduce_type), scale,
scatter_dim, shard_count, groups, pin_layout)),
reduce_type_(reduce_type),
@@ -61,25 +53,18 @@ ReduceScatter::ReduceScatter(AllReduceType reduce_type,
pin_layout_(pin_layout) {}

torch::lazy::NodePtr ReduceScatter::Clone(torch::lazy::OpList operands) const {
std::vector<torch::lazy::Value> inputs(operands.begin(), operands.end() - 1);
return torch::lazy::MakeNode<ReduceScatter>(
reduce_type_, inputs, operands.back(), scale_, scatter_dim_, shard_count_,
groups_, pin_layout_);
reduce_type_, operands.at(0), operands.at(1), scale_, scatter_dim_,
shard_count_, groups_, pin_layout_);
}

XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
auto& operand_list = operands();
std::vector<xla::XlaOp> inputs;
inputs.reserve(operand_list.size());
for (size_t i = 0; i + 1 < operand_list.size(); ++i) {
inputs.push_back(loctx->GetOutputOp(operand_list[i]));
}
xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::XlaOp token = loctx->GetOutputOp(operand(1));
ReduceScatterResult result =
BuildReduceScatter(reduce_type_, inputs, token, scale_, scatter_dim_,
BuildReduceScatter(reduce_type_, input, token, scale_, scatter_dim_,
shard_count_, groups_, pin_layout_);
result.result.push_back(result.token);
return ReturnOps(result.result, loctx);
return ReturnOps({result.result, result.token}, loctx);
}

std::string ReduceScatter::ToString() const {