Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token
jeffhataws committed Nov 28, 2023
1 parent ec27f90 commit f9f3a71
Showing 6 changed files with 48 additions and 42 deletions.
11 changes: 7 additions & 4 deletions test/test_torch_distributed_xla_backend.py
@@ -14,12 +14,14 @@
 
 from datetime import timedelta
 
+
 def get_process_group_xla(rank, size):
   pg_xla_creator = dist.Backend._plugins['XLA'].creator_fn
   pg_xla = pg_xla_creator(
       prefix_store=None, rank=rank, size=size, timeout=timedelta(minutes=1))
   return pg_xla
 
+
 def hlo_matches(hlo, expected_pattern, match_times=1):
   matches = re.findall(expected_pattern, hlo)
   assert len(list(matches)) == match_times, hlo
@@ -104,10 +106,10 @@ def test_allgather_coalesced(self):
     output_tensors2 = [torch.zeros_like(tensor2)] * 8
     # because we set os.environ[xenv.WORLD_SIZE] = '1', here the outputs'
     # shapes will be same as the inputs' shapes.
+    # Ex: %all-gather.26 = (s64[2]{0}, s64[5]{0}) all-gather(s64[2]{0} %get-tuple-element.24, s64[5]{0} %get-tuple-element.25), replica_groups={}, dimensions={0}
     all_gather_pattern = (
-        r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
-        r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+, '
-        r's64\[] %.+\.\d+\)')
+        r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}\) '
+        r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+\)')
     pg_xla.allgather_coalesced([output_tensors, output_tensors2],
                                [tensor, tensor2])
     hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
@@ -134,7 +136,8 @@ def test_reduce_scatter(self):
     hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
     hlo_matches(hlo, reduce_scatter_pattern)
 
-  @skipIf(xr.device_type() == 'CPU', "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
+  @skipIf(xr.device_type() == 'CPU',
+          "UNIMPLEMENTED: ReduceScatter is not implemented on CPU.")
   def test_reduce_scatter_coalesced(self):
     device = xm.xla_device()
     tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
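For reference, the collapsed body of test_reduce_scatter_coalesced asserts on the lowered HLO the same way the all-gather test above does. A minimal sketch of such a check under the new lowering, assuming the same two-element tuple layout; the exact regex and shapes are illustrative, not the committed test code:

# Hypothetical analogue of the all_gather_pattern check above: after this
# commit, the coalesced reduce-scatter emits a plain tuple of shards with no
# trailing s64[] token element.
reduce_scatter_pattern = (
    r'%reduce-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}\) '
    r'reduce-scatter\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+\)')
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, reduce_scatter_pattern)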
16 changes: 9 additions & 7 deletions torch_xla/core/xla_model.py
@@ -582,8 +582,8 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
     torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
     return output
 
-  result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count,
-                                           groups or [], pin_layout)
+  result = torch_xla._XLAC._xla_all_gather(value, dim, shard_count, groups or
+                                           [], pin_layout)
   return result
 
   # Now the input should be a list of Tensors.
@@ -594,8 +594,8 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
   result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
                                                      shard_count, groups or [],
                                                      pin_layout)
-  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
-  return result[0]
+  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
+  return result[:-1]
 
 
 def all_to_all(value,
@@ -773,14 +773,14 @@ def reduce_scatter(reduce_type,
     new_token = torch_xla._XLAC._xla_reduce_scatter_out(
         reduce_type, output, input, token, scale, scatter_dim, shard_count,
         groups or [], pin_layout)
-    devctx.all_reduce_token = new_token
+    torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
     return output
 
   result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
                                                scale, scatter_dim,
                                                shard_count, groups or [],
                                                pin_layout)
-  devctx.all_reduce_token = result[1]
+  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
   return result[0]
 
   # Now the input should be a list of Tensors.
@@ -801,7 +801,9 @@
   result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
       reduce_type, output or [], input, token, scale, scatter_dim, shard_count,
       groups or [], pin_layout)
-  devctx.all_reduce_token = result[-1]
+  #torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
+  #return result[0]
+  torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
   return result[:-1]
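A minimal usage sketch of the coalesced path after this change (the multi-device launch context, world size, and tensor values are assumptions for illustration, not part of the commit). The C++ lowering appends the IR token as the last list element; the Python wrapper above consumes it via result[-1] and returns only the real outputs:

import torch
import torch_xla.core.xla_model as xm

# Assumed to run on each replica, e.g. under xmp.spawn(); shapes illustrative.
device = xm.xla_device()
world = xm.xrt_world_size()
t1 = torch.ones(2 * world, device=device)
t2 = torch.ones(4 * world, device=device)
# Coalesced path: a list of tensors in, a list of shards out; the trailing
# token never surfaces to the caller.
shards = xm.reduce_scatter(xm.REDUCE_SUM, [t1, t2], scale=1.0,
                           scatter_dim=0, shard_count=world)
assert shards[0].shape[0] == 2 and shards[1].shape[0] == 4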


35 changes: 16 additions & 19 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -210,10 +210,11 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
   return {reduce_result, token_handler.GetNewToken(reduce_result)};
 }
 
-AllGatherResult BuildAllGather(
-    absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
-    int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
-    bool pin_layout) {
+AllGatherResult BuildAllGather(absl::Span<const xla::XlaOp> inputs,
+                               xla::XlaOp token, int64_t dim,
+                               int64_t shard_count,
+                               const std::vector<std::vector<int64_t>>& groups,
+                               bool pin_layout) {
   std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
   TokenHandler token_handler(token);
   // TODO: We use pseudo-tokens ATM, which are real values. This need to be
@@ -234,13 +235,12 @@ AllGatherResult BuildAllGather(
       xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
                      dim, shard_count, cc_groups);
     }
     if (type_ctx.second.indices.size() > 1) {
       for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
         size_t op_idx = type_ctx.second.indices[i];
         result[op_idx] = xla::GetTupleElement(all_gather_result, i);
       }
-    }
-    else {
+    } else {
       result[0] = all_gather_result;
     }
   }
@@ -285,22 +285,18 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
   return {result, new_token};
 }
 
-std::vector<xla::XlaOp> BuildReduceScatter(
+ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
   std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
+  TokenHandler token_handler(token);
   // TODO: We use pseudo-tokens ATM, which are real values. This need to be
   // switched to use the real XLA Token once support has been added to XLA
   // ReduceScatter().
-  xla::XlaOp chained_token = token;
   ReduceContext cc_ctx = GetReduceContext(inputs);
   std::vector<xla::XlaOp> result(inputs.size());
   for (auto& type_ctx : cc_ctx.contexts) {
-    xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
-    type_ctx.second.ops.push_back(token_op);
-    type_ctx.second.operand_shapes.push_back(
-        ShapeHelper::ShapeOfXlaOp(token_op));
     xla::XlaOp reduce_result;
     if (pin_layout) {
       reduce_result = xla::ReduceScatter(
@@ -317,7 +313,12 @@ std::vector<xla::XlaOp> BuildReduceScatter(
     }
     for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
       size_t op_idx = type_ctx.second.indices[i];
-      xla::XlaOp gte = xla::GetTupleElement(reduce_result, i);
+      xla::XlaOp gte;
+      if (ShapeHelper::ShapeOfXlaOp(reduce_result).rank() == 0) {
+        gte = xla::GetTupleElement(reduce_result, i);
+      } else {
+        gte = reduce_result;
+      }
       if (scale != 1.0) {
         xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
             scale, type_ctx.second.operand_shapes[i].element_type(),
@@ -326,12 +327,8 @@
       }
       result[op_idx] = gte;
     }
-    chained_token =
-        xla::GetTupleElement(reduce_result, type_ctx.second.indices.size());
   }
-  result.push_back(
-      MaybeConvertTo(chained_token, XlaHelpers::TypeOfXlaOp(token)));
-  return result;
+  return {result, token_handler.GetNewToken(result[0])};
 }
 
 // moved from torch_xla/csrc/ops/all_reduce.cpp
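The shape of this change, modeled as a runnable Python toy (illustration only; the names mirror the C++ above but nothing here touches torch_xla internals): the old BuildReduceScatter chained a pseudo-token through the op's tuple and appended it to the result vector, while the new one threads the token through TokenHandler and returns it alongside the shards:

from dataclasses import dataclass

# Toy stand-ins for the C++ types above (hypothetical, for illustration).
@dataclass
class ReduceScatterResult:
  result: list  # one lowered shard per input
  token: str    # ordering token, kept out of the HLO tuple

def build_reduce_scatter(inputs, token):
  shards = [f'reduce-scatter({x})' for x in inputs]
  # Old convention: shards.append(chained_token); callers peeled it off.
  # New convention: the token travels beside the shards, not inside them.
  return ReduceScatterResult(result=shards, token=f'token-after({shards[0]})')

rs = build_reduce_scatter(['s64[2]', 's64[5]'], 'token0')
assert len(rs.result) == 2 and rs.token.startswith('token-after')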
11 changes: 5 additions & 6 deletions torch_xla/csrc/cross_replica_reduces.h
@@ -61,10 +61,10 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
                              const std::vector<std::vector<int64_t>>& groups,
                              bool pin_layout);
 
-AllGatherResult BuildAllGather(
-    absl::Span<const xla::XlaOp>, xla::XlaOp token, int64_t dim,
-    int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
-    bool pin_layout);
+AllGatherResult BuildAllGather(absl::Span<const xla::XlaOp>, xla::XlaOp token,
+                               int64_t dim, int64_t shard_count,
+                               const std::vector<std::vector<int64_t>>& groups,
+                               bool pin_layout);
 
 CollectivePermuteResult BuildCollectivePermute(
     xla::XlaOp input, xla::XlaOp token,
@@ -76,8 +76,7 @@ SendResult BuildSendWithToken(xla::XlaOp input, xla::XlaOp token,
 RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
                               int64_t channel_id);
 
-//ReduceScatterResult BuildReduceScatter(
-std::vector<xla::XlaOp> BuildReduceScatter(
+ReduceScatterResult BuildReduceScatter(
     AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
     xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& groups, bool pin_layout);
1 change: 0 additions & 1 deletion torch_xla/csrc/ops/all_gather.cpp
@@ -32,7 +32,6 @@ xla::Shape NodeOutputShape(c10::ArrayRef<torch::lazy::Value> inputs,
   }
   input_shapes.emplace_back(GetXlaShape(token));
   return InferOutputShape(input_shapes, shape_fn);
-
 }
 
 }  // namespace
16 changes: 11 additions & 5 deletions torch_xla/csrc/ops/reduce_scatter.cpp
@@ -18,10 +18,15 @@ xla::Shape NodeOutputShape(AllReduceType reduce_type,
                            const std::vector<std::vector<int64_t>>& groups,
                            bool pin_layout) {
   auto shape_fn = [&](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
-    std::vector<xla::XlaOp> result = BuildReduceScatter(
+    ReduceScatterResult result = BuildReduceScatter(
         reduce_type, operands.subspan(0, operands.size() - 1), operands.back(),
         scale, scatter_dim, shard_count, groups, pin_layout);
-    return xla::Tuple(operands[0].builder(), result);
+    std::vector<xla::XlaOp> outputs;
+    for (size_t i = 0; i < result.result.size(); ++i) {
+      outputs.emplace_back(result.result[i]);
+    }
+    outputs.emplace_back(result.token);
+    return xla::Tuple(operands[0].builder(), outputs);
   };
   std::vector<xla::Shape> input_shapes;
   for (const auto& input : inputs) {
@@ -70,10 +75,11 @@ XlaOpVector ReduceScatter::Lower(LoweringContext* loctx) const {
     inputs.push_back(loctx->GetOutputOp(operand_list[i]));
   }
   xla::XlaOp token = loctx->GetOutputOp(operand_list.back());
-  return ReturnOps(
+  ReduceScatterResult result =
       BuildReduceScatter(reduce_type_, inputs, token, scale_, scatter_dim_,
-                         shard_count_, groups_, pin_layout_),
-      loctx);
+                         shard_count_, groups_, pin_layout_);
+  result.result.push_back(result.token);
+  return ReturnOps(result.result, loctx);
 }
 
 std::string ReduceScatter::ToString() const {
