Add out-of-place all-gather coalesced (pytorch#6059)
Co-authored-by: Arjun Balasubramanian <[email protected]>
2 people authored and chunnienc committed Dec 14, 2023
1 parent 1ddef24 commit c41eccf
Showing 5 changed files with 145 additions and 32 deletions.
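In short, xm.all_gather now accepts a list of tensors and gathers them in a single coalesced collective, either returning fresh results or writing into caller-provided buffers (the new out-of-place path). A minimal usage sketch, assuming a spawned torch_xla multiprocessing context; the tensor names are illustrative:

    import torch
    import torch_xla.core.xla_model as xm

    def _mp_fn(index):
      device = xm.xla_device()
      world_size = xm.xrt_world_size()
      # Two per-replica tensors, gathered together in one coalesced op.
      inputs = [
          torch.tensor([index], dtype=torch.float).to(device) for _ in range(2)
      ]
      # Functional form: fresh gathered tensors are returned.
      gathered = xm.all_gather(inputs, dim=0, pin_layout=False)
      # Out-of-place form: results land in pre-allocated buffers.
      buffers = [
          torch.zeros([world_size], dtype=torch.float).to(device)
          for _ in range(2)
      ]
      gathered = xm.all_gather(inputs, dim=0, output=buffers, pin_layout=False)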
42 changes: 41 additions & 1 deletion test/test_mp_all_gather.py
@@ -13,7 +13,8 @@ def all_gather(tensor, dim):
 def _mp_fn(index):
   device = xm.xla_device()
   world_size = xm.xrt_world_size()
-  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM'):
+  input_list_size = 5
+  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
     # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor, dim=0)
@@ -57,6 +58,45 @@ def _mp_fn(index):
           f'Failed to create two replica groups with {world_size} replicas',
           file=sys.stderr)
 
+    # Testing with a single replica group and tensor list as input
+    ordinal_tensors = [
+        torch.tensor([i * 1000 + index], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output=None
+    result_list = xm.all_gather(ordinal_tensors, dim=0, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+
+    # Testing with a single replica group, tensor list as input and
+    # output!=None (out-of-place); reuse ordinal_tensors from previous test
+    output_tensors = [
+        torch.zeros([world_size], dtype=torch.float).to(device)
+        for i in range(input_list_size)
+    ]
+    # TODO: add support for list input with pin_layout=True and output!=None
+    result_list = xm.all_gather(
+        ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)
+
+    for i, result in enumerate(result_list):
+      cpu_result = result.cpu()
+      expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
+      if not cpu_result.allclose(expected):
+        print(
+            f'xm.all_gather() produced wrong reductions for item {i} in result list',
+            file=sys.stderr)
+        print(f'[{index}] {cpu_result}', file=sys.stderr)
+        sys.exit(1)
+    # TODO: add test for torch.compile when support for list input is ready
+
   else:
     print(f'{device} is not a TPU or GPU device', file=sys.stderr)
 
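The expected values above follow directly from the gather arithmetic: for list item i, replica r contributes i * 1000 + r, and gathering along dim 0 concatenates one scalar per replica. A worked example of the check, assuming world_size == 4:

    import torch

    world_size = 4  # assumed for illustration
    i = 2           # list item index
    # Replica r contributes tensor([2000. + r]), so the gathered result is:
    expected = i * 1000 + torch.arange(world_size, dtype=torch.float)
    print(expected)  # tensor([2000., 2001., 2002., 2003.])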
22 changes: 21 additions & 1 deletion torch_xla/core/xla_model.py
@@ -553,7 +553,8 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
     A tensor which has, in the ``dim`` dimension, all the values from the
     participating replicas.
   """
-  if pin_layout and (output == None or xla_device_hw(value.device) == 'NEURON'):
+  # _all_gather_using_all_reduce does not support list of tensors as input
+  if pin_layout and output == None and isinstance(value, torch.Tensor):
     # There is not an easy way to pin the all_gather layout on TPU, GPU and NEURON,
     # use all_reduce based all_gather for this purpose.
     return _all_gather_using_all_reduce(
@@ -587,6 +588,25 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
   # Now the input should be a list of Tensors.
   elif isinstance(value, list) and all(
       isinstance(v, torch.Tensor) for v in value):
+    if pin_layout:
+      raise RuntimeError(
+          "For xm.all_gather with list of tensors input, pin_layout=True is not yet supported."
+      )
+    if output != None:
+      if not isinstance(output, list) or any(
+          not isinstance(v, torch.Tensor) for v in output):
+        raise TypeError(
+            f"`output` needs to be a list of Tensors, but given {type(output)}."
+        )
+      if len(output) != len(value):
+        raise ValueError("`output` length doesn't match `input` length: "
+                         f"{len(output)} vs {len(value)}.")
+      # Call the out-of-place version of the all_gather
+      new_token = torch_xla._XLAC._xla_all_gather_coalesced_out(
+          output, value, token, dim, shard_count, groups or [], pin_layout)
+      torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
+      return output
+
     result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
                                                        shard_count, groups or
                                                        [], pin_layout)
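With these checks, a list input currently requires pin_layout=False, and any output list must mirror the input list. A sketch of the observable failure modes under those assumptions:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    tensors = [torch.ones(1).to(device) for _ in range(3)]

    # pin_layout defaults to True, so a plain list input raises.
    try:
      xm.all_gather(tensors, dim=0)
    except RuntimeError:
      pass  # "... pin_layout=True is not yet supported."

    # An output list of the wrong length fails fast.
    outputs = [torch.zeros(xm.xrt_world_size()).to(device) for _ in range(2)]
    try:
      xm.all_gather(tensors, dim=0, output=outputs, pin_layout=False)
    except ValueError:
      pass  # "`output` length doesn't match `input` length: 2 vs 3."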
45 changes: 38 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -294,6 +294,19 @@ at::Tensor AllGather(const at::Tensor& input, int64_t dim, int64_t shard_count,
   return bridge::AtenFromXlaTensor(std::move(result));
 }
 
+std::shared_ptr<torch::lazy::Value> AllGatherOut(
+    at::Tensor& output, const at::Tensor& input,
+    const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
+    int64_t shard_count,
+    const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
+  XLATensorPtr out = bridge::GetXlaTensor(output);
+  torch::lazy::Value new_token;
+  new_token = tensor_methods::all_gather_out(out, bridge::GetXlaTensor(input),
+                                             *token, dim, shard_count,
+                                             replica_groups, pin_layout);
+  return std::make_shared<torch::lazy::Value>(new_token);
+}
+
 std::pair<std::vector<at::Tensor>, std::shared_ptr<torch::lazy::Value>>
 AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
                    const std::shared_ptr<torch::lazy::Value>& token,
@@ -304,7 +317,7 @@ AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
       GetXlaTensors(tensors, /*want_all=*/true);
   std::vector<XLATensorPtr> result;
   torch::lazy::Value new_token;
-  std::tie(result, new_token) = tensor_methods::all_gather(
+  std::tie(result, new_token) = tensor_methods::all_gather_coalesced(
       xtensors, *token, dim, shard_count, replica_groups, pin_layout);
   std::vector<at::Tensor> aten_result;
   for (auto& xt : result) {
Expand All @@ -313,16 +326,18 @@ AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
return {aten_result, std::make_shared<torch::lazy::Value>(new_token)};
}

std::shared_ptr<torch::lazy::Value> AllGatherOut(
at::Tensor& output, const at::Tensor& input,
std::shared_ptr<torch::lazy::Value> AllGatherCoalescedOut(
std::vector<at::Tensor>& outputs, const std::vector<at::Tensor>& inputs,
const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
XLATensorPtr out = bridge::GetXlaTensor(output);
std::vector<XLATensorPtr> xtensors_out =
GetXlaTensors(outputs, /*want_all=*/true);
std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
torch::lazy::Value new_token;
new_token = tensor_methods::all_gather_out(out, bridge::GetXlaTensor(input),
*token, dim, shard_count,
replica_groups, pin_layout);
new_token = tensor_methods::all_gather_coalesced_out(
xtensors_out, xtensors, *token, dim, shard_count, replica_groups,
pin_layout);
return std::make_shared<torch::lazy::Value>(new_token);
}

@@ -1288,6 +1303,22 @@ void InitXlaModuleBindings(py::module m) {
           result_list[results.size()] = new_token;
           return result_list;
         });
+  m.def("_xla_all_gather_coalesced_out",
+        [](std::vector<at::Tensor>& outputs,
+           const std::vector<at::Tensor>& inputs,
+           const std::shared_ptr<torch::lazy::Value>& token, int64_t dim,
+           int64_t shard_count, const py::list& groups, bool pin_layout) {
+          std::vector<std::vector<int64_t>> replica_groups =
+              CreateReduceGroups(groups);
+          std::shared_ptr<torch::lazy::Value> new_token;
+          {
+            NoGilSection nogil;
+            new_token =
+                AllGatherCoalescedOut(outputs, inputs, token, dim, shard_count,
+                                      replica_groups, pin_layout);
+          }
+          return new_token;
+        });
   m.def("_xla_collective_permute",
         [](const at::Tensor& input,
            const std::shared_ptr<torch::lazy::Value>& token,
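CreateReduceGroups converts the `groups` argument into replica-ordinal lists before the collective is issued; from Python this is the same `groups=` parameter that the single-tensor path already accepts. A hedged sketch of coalesced gathering within two independent replica groups, assuming four replicas:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    # Replicas {0, 1} gather among themselves, as do {2, 3}.
    groups = [[0, 1], [2, 3]]
    tensors = [torch.ones(1).to(device) for _ in range(2)]
    gathered = xm.all_gather(tensors, dim=0, groups=groups, pin_layout=False)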
53 changes: 35 additions & 18 deletions torch_xla/csrc/tensor_methods.cpp
@@ -445,24 +445,6 @@ std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
           torch::lazy::Value(node, 1)};
 }
 
-std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather(
-    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
-    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
-    bool pin_layout) {
-  std::vector<torch::lazy::Value> input_values;
-  input_values.reserve(inputs.size());
-  for (auto& input : inputs) {
-    input_values.push_back(input->GetIrValue());
-  }
-  torch::lazy::NodePtr node = torch::lazy::MakeNode<AllGatherCoalesced>(
-      input_values, token, dim, shard_count, std::move(groups), pin_layout);
-  std::vector<XLATensorPtr> result;
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
-  }
-  return {result, torch::lazy::Value(node, inputs.size())};
-}
-
 XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
                         int64_t shard_count,
                         std::vector<std::vector<int64_t>> groups,
@@ -488,6 +470,41 @@ torch::lazy::Value all_gather_out(XLATensorPtr& output,
   return torch::lazy::Value(node, 1);
 }
 
+std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather_coalesced(
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
+    bool pin_layout) {
+  std::vector<torch::lazy::Value> input_values;
+  input_values.reserve(inputs.size());
+  for (auto& input : inputs) {
+    input_values.push_back(input->GetIrValue());
+  }
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<AllGatherCoalesced>(
+      input_values, token, dim, shard_count, std::move(groups), pin_layout);
+  std::vector<XLATensorPtr> result;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    result.emplace_back(inputs[i]->CreateFrom(torch::lazy::Value(node, i)));
+  }
+  return {result, torch::lazy::Value(node, inputs.size())};
+}
+
+torch::lazy::Value all_gather_coalesced_out(
+    std::vector<XLATensorPtr>& outputs, const std::vector<XLATensorPtr>& inputs,
+    const torch::lazy::Value& token, int64_t dim, int64_t shard_count,
+    std::vector<std::vector<int64_t>> groups, bool pin_layout) {
+  std::vector<torch::lazy::Value> input_values;
+  input_values.reserve(inputs.size());
+  for (auto& input : inputs) {
+    input_values.push_back(input->GetIrValue());
+  }
+  torch::lazy::NodePtr node = torch::lazy::MakeNode<AllGatherCoalesced>(
+      input_values, token, dim, shard_count, std::move(groups), pin_layout);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    outputs[i]->SetIrValue(torch::lazy::Value(node, i));
+  }
+  return torch::lazy::Value(node, inputs.size());
+}
+
 std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     std::vector<std::pair<int64_t, int64_t>> source_target_pairs) {
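Both variants build the same AllGatherCoalesced IR node, which has inputs.size() + 1 outputs: one gathered value per input plus a trailing token that sequences collectives. The only difference is binding: the functional form creates new tensors from the node's outputs (CreateFrom), while the out-of-place form rebinds the caller's tensors to them (SetIrValue). Seen from Python (a sketch reusing the names from the test above):

    # Functional: new tensors are created from IR outputs (node, 0..n-1).
    results = xm.all_gather(ordinal_tensors, dim=0, pin_layout=False)

    # Out-of-place: the same IR outputs are bound to pre-existing tensors,
    # and the `output` list itself is returned.
    results = xm.all_gather(
        ordinal_tensors, dim=0, output=output_tensors, pin_layout=False)
    assert results is output_tensors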
15 changes: 10 additions & 5 deletions torch_xla/csrc/tensor_methods.h
@@ -53,11 +53,6 @@ std::pair<XLATensorPtr, torch::lazy::Value> all_to_all(
     int64_t split_dimension, int64_t concat_dimension, int64_t split_count,
     std::vector<std::vector<int64_t>> groups, bool pin_layout);
 
-std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather(
-    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
-    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
-    bool pin_layout);
-
 XLATensorPtr all_gather(const XLATensorPtr& input, int64_t dim,
                         int64_t shard_count,
                         std::vector<std::vector<int64_t>> groups,
@@ -70,6 +65,16 @@ torch::lazy::Value all_gather_out(XLATensorPtr& output,
                                   std::vector<std::vector<int64_t>> groups,
                                   bool pin_layout);
 
+std::pair<std::vector<XLATensorPtr>, torch::lazy::Value> all_gather_coalesced(
+    const std::vector<XLATensorPtr>& inputs, const torch::lazy::Value& token,
+    int64_t dim, int64_t shard_count, std::vector<std::vector<int64_t>> groups,
+    bool pin_layout);
+
+torch::lazy::Value all_gather_coalesced_out(
+    std::vector<XLATensorPtr>& outputs, const std::vector<XLATensorPtr>& inputs,
+    const torch::lazy::Value& token, int64_t dim, int64_t shard_count,
+    std::vector<std::vector<int64_t>> groups, bool pin_layout);
+
 std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
     const XLATensorPtr& input, const torch::lazy::Value& token,
     std::vector<std::pair<int64_t, int64_t>> source_target_pairs);