diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 8464d1320c2..6a51bb48a53 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -29,6 +29,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ops/as_strided.h" #include "torch_xla/csrc/ops/as_strided_view_update.h" +#include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/ops/diagonal_view_update.h" #include "torch_xla/csrc/ops/einsum_utilities.h" #include "torch_xla/csrc/ops/index_ops.h" @@ -2538,7 +2539,39 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input, // 1) Aid XLA's InputOutputAlias. auto input_tensor = bridge::GetXlaTensor(input); auto output_tensor = bridge::GetXlaTensor(output); - output_tensor->data()->alias_id = input_tensor->GetUniqueId(); + if (input_tensor->CurrentDataHandle() != nullptr || + (input_tensor->CurrentIrValue().node != nullptr && + torch_xla::DeviceData::Cast( + input_tensor->CurrentIrValue().node.get()))) { + /* + if input has a XLAData or holds a devicedata node, set alias_id to + tensor_id. Consider the case. + + // x.tensor_id = 1, x.alias_id = 1 + x = torch.randn(5,5).to(xla_device()) + // x.tensor_id = 2, x.alias_id should be 1 + x += 1 + xm.mark_step() + // x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2 + for + // this graph + x *= 1 of 1 + */ + output_tensor->data()->alias_id = input_tensor->GetUniqueId(); + } else { + /* + Consider the case + + // x.tensor_id = 1, x.alias_id = 1 + x = torch.randn(5,5).to(xla_device()) + // x.tensor_id = 2, x.alias_id should be 1 + x += 1 + // x.tensor_id = 3, x.alias_id should still be 1 + x * = 2 + xm.mark_step() + */ + output_tensor->data()->alias_id = input_tensor->data()->alias_id; + } // 2) Aid SPMD. XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();