Skip to content

Commit

Permalink
Handle multiple inplace update input output aliasing
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed May 3, 2024
1 parent d123585 commit 25232c5
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/as_strided.h"
#include "torch_xla/csrc/ops/as_strided_view_update.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal_view_update.h"
#include "torch_xla/csrc/ops/einsum_utilities.h"
#include "torch_xla/csrc/ops/index_ops.h"
Expand Down Expand Up @@ -2538,7 +2539,39 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
// 1) Aid XLA's InputOutputAlias.
auto input_tensor = bridge::GetXlaTensor(input);
auto output_tensor = bridge::GetXlaTensor(output);
output_tensor->data()->alias_id = input_tensor->GetUniqueId();
if (input_tensor->CurrentDataHandle() != nullptr ||
(input_tensor->CurrentIrValue().node != nullptr &&
torch_xla::DeviceData::Cast(
input_tensor->CurrentIrValue().node.get()))) {
/*
if input has a XLAData or holds a devicedata node, set alias_id to
tensor_id. Consider the case.
// x.tensor_id = 1, x.alias_id = 1
x = torch.randn(5,5).to(xla_device())
// x.tensor_id = 2, x.alias_id should be 1
x += 1
xm.mark_step()
// x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2
for
// this graph
x *= 1 of 1
*/
output_tensor->data()->alias_id = input_tensor->GetUniqueId();
} else {
/*
Consider the case
// x.tensor_id = 1, x.alias_id = 1
x = torch.randn(5,5).to(xla_device())
// x.tensor_id = 2, x.alias_id should be 1
x += 1
// x.tensor_id = 3, x.alias_id should still be 1
x * = 2
xm.mark_step()
*/
output_tensor->data()->alias_id = input_tensor->data()->alias_id;
}

// 2) Aid SPMD.
XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();
Expand Down

0 comments on commit 25232c5

Please sign in to comment.