Skip to content

Commit

Permalink
Some cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhataws committed Nov 29, 2023
1 parent 5843b43 commit 8e6715e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ AllGatherResult BuildAllGather(absl::Span<const xla::XlaOp> inputs,
xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
dim, shard_count, cc_groups);
}
if (type_ctx.second.indices.size() > 1) {
if (ShapeHelper::ShapeOfXlaOp(all_gather_result).rank() == 0) {
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
result[op_idx] = xla::GetTupleElement(all_gather_result, i);
Expand Down
8 changes: 2 additions & 6 deletions torch_xla/csrc/ops/all_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ xla::Shape NodeOutputShape(c10::ArrayRef<torch::lazy::Value> inputs,
AllGatherResult result =
BuildAllGather(operands.subspan(0, operands.size() - 1),
operands.back(), dim, shard_count, groups, pin_layout);
std::vector<xla::XlaOp> outputs;
for (size_t i = 0; i < result.result.size(); ++i) {
outputs.emplace_back(result.result[i]);
}
outputs.emplace_back(result.token);
return xla::Tuple(operands[0].builder(), outputs);
result.result.emplace_back(result.token);
return xla::Tuple(operands[0].builder(), result.result);
};
std::vector<xla::Shape> input_shapes;
for (const auto& input : inputs) {
Expand Down

0 comments on commit 8e6715e

Please sign in to comment.