From 25232c51dec230129ee7c4cc2e06d8fd73ed770b Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Fri, 3 May 2024 17:12:32 +0000
Subject: [PATCH 1/3] Handle multiple inplace update input output aliasing

---
 torch_xla/csrc/aten_xla_type.cpp | 35 +++++++++++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 8464d1320c2..6a51bb48a53 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -29,6 +29,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ops/as_strided.h"
 #include "torch_xla/csrc/ops/as_strided_view_update.h"
+#include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/diagonal_view_update.h"
 #include "torch_xla/csrc/ops/einsum_utilities.h"
 #include "torch_xla/csrc/ops/index_ops.h"
@@ -2538,7 +2539,39 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
   // 1) Aid XLA's InputOutputAlias.
   auto input_tensor = bridge::GetXlaTensor(input);
   auto output_tensor = bridge::GetXlaTensor(output);
-  output_tensor->data()->alias_id = input_tensor->GetUniqueId();
+  if (input_tensor->CurrentDataHandle() != nullptr ||
+      (input_tensor->CurrentIrValue().node != nullptr &&
+       torch_xla::DeviceData::Cast(
+           input_tensor->CurrentIrValue().node.get()))) {
+    /*
+    If the input has XLA data or holds a DeviceData node, set alias_id to
+    the input's tensor_id. Consider the case:
+
+    // x.tensor_id = 1, x.alias_id = 1
+    x = torch.randn(5,5).to(xla_device())
+    // x.tensor_id = 2, x.alias_id should be 1
+    x += 1
+    xm.mark_step()
+    // x.tensor_id = 3, x.alias_id should be 2 since input tensor id will be 2
+    for
+    // this graph
+    x *= 2
+    */
+    output_tensor->data()->alias_id = input_tensor->GetUniqueId();
+  } else {
+    /*
+    Consider the case:
+
+    // x.tensor_id = 1, x.alias_id = 1
+    x = torch.randn(5,5).to(xla_device())
+    // x.tensor_id = 2, x.alias_id should be 1
+    x += 1
+    // x.tensor_id = 3, x.alias_id should still be 1
+    x *= 2
+    xm.mark_step()
+    */
+    output_tensor->data()->alias_id = input_tensor->data()->alias_id;
+  }
 
   // 2) Aid SPMD.
   XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();

From 22d1fb4f321e4b3bcdb7df02d38d7ad7454b05a7 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Fri, 3 May 2024 17:29:06 +0000
Subject: [PATCH 2/3] add test for multiple in-place updates

---
 test/test_input_output_aliases.py | 33 +++++++++++++++++++++++++++++++
 torch_xla/csrc/aten_xla_type.cpp  |  3 +--
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py
index c7c04f781c3..9cfafd9f48a 100644
--- a/test/test_input_output_aliases.py
+++ b/test/test_input_output_aliases.py
@@ -38,6 +38,39 @@ def test_aliasing_with_cloned(self):
     torch.allclose(t1 - 1, t1_cloned)
     self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
 
+  def test_aliasing_with_multiple_inplace_update(self):
+    BATCH_SIZE = 1
+    SEQ_LEN = 128
+    NUM_KV_HEADS = 16
+    HEAD_SIZE = 256
+    BLOCK_SIZE = 16
+    DTYPE = torch.bfloat16
+    num_blocks = 1024
+    device = xm.xla_device()
+    key = torch.randn(
+        BATCH_SIZE * SEQ_LEN,
+        NUM_KV_HEADS,
+        HEAD_SIZE,
+        device=device,
+        dtype=DTYPE)
+    k_cache = torch.randn(
+        num_blocks * BLOCK_SIZE,
+        NUM_KV_HEADS,
+        HEAD_SIZE,
+        device=device,
+        dtype=DTYPE)
+    slot_mapping = torch.randint(
+        0, num_blocks, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.int64)
+    # materialize k_cache to device data
+    xm.mark_step()
+    met.clear_all()
+    for _ in range(10):
+      k_cache.index_copy_(0, slot_mapping.flatten(), key)
+    xm.mark_step()
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+    torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu())
+
 
 if __name__ == '__main__':
   test = unittest.main()

diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 6a51bb48a53..12a49a91ad9 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2553,8 +2553,7 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
     x += 1
     xm.mark_step()
     // x.tensor_id = 3, x.alias_id should be 2 since input tensor id will be 2
-    for
-    // this graph
+    // for this graph
     x *= 2
     */
     output_tensor->data()->alias_id = input_tensor->GetUniqueId();

From 8c1e37a7ef16c223fa26db53cb18f8d38a066c9e Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Fri, 3 May 2024 17:35:35 +0000
Subject: [PATCH 3/3] add another test for aliasing across mark_step

---
 test/test_input_output_aliases.py | 11 +++++++++++
 test/tpu/run_tests.sh             |  1 +
 2 files changed, 12 insertions(+)

diff --git a/test/test_input_output_aliases.py b/test/test_input_output_aliases.py
index 9cfafd9f48a..b2c5fc50b21 100644
--- a/test/test_input_output_aliases.py
+++ b/test/test_input_output_aliases.py
@@ -38,6 +38,17 @@ def test_aliasing_with_cloned(self):
     torch.allclose(t1 - 1, t1_cloned)
     self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
 
+  def test_aliasing_across_mark_step(self):
+    xla_device = xm.xla_device()
+    met.clear_all()
+    t1 = torch.randn(4, 5).to(xla_device)
+    t1 += 1
+    xm.mark_step()
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+    t1 *= 100
+    xm.mark_step()
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
+
   def test_aliasing_with_multiple_inplace_update(self):
     BATCH_SIZE = 1
     SEQ_LEN = 128

diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index b2a8fff33dc..a0eddebc3d5 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -21,6 +21,7 @@ python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
 python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
 python3 test/test_pallas.py
+python3 test/test_input_output_aliases.py
 python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
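
For orientation, a minimal Python sketch of the alias_id propagation rule that
the C++ change in PATCH 1/3 implements. All names below (FakeTensor,
propagate_alias_id, is_device_data) are illustrative stand-ins, not torch_xla
APIs; the is_device_data flag models the CurrentDataHandle()/DeviceData checks.

    class FakeTensor:
      """Illustrative stand-in for XLATensor's tensor_id/alias_id bookkeeping."""
      _next_id = 0

      def __init__(self, is_device_data):
        FakeTensor._next_id += 1
        self.tensor_id = FakeTensor._next_id  # models GetUniqueId()
        self.alias_id = self.tensor_id
        self.is_device_data = is_device_data

    def propagate_alias_id(input_tensor, output_tensor):
      if input_tensor.is_device_data:
        # Input is backed by device data, so it becomes a real parameter of
        # the next graph: alias the output to the input's own unique id.
        output_tensor.alias_id = input_tensor.tensor_id
      else:
        # Input is an intermediate of the current graph: keep chaining the
        # alias back to the tensor that actually enters the graph.
        output_tensor.alias_id = input_tensor.alias_id

    x0 = FakeTensor(is_device_data=True)   # x = torch.randn(...).to(device)
    x1 = FakeTensor(is_device_data=False)  # x += 1 (intermediate value)
    propagate_alias_id(x0, x1)
    x2 = FakeTensor(is_device_data=False)  # x *= 2, still before mark_step()
    propagate_alias_id(x1, x2)
    assert x1.alias_id == x0.tensor_id
    assert x2.alias_id == x0.tensor_id  # both updates alias the same graph input

This chaining is what lets several consecutive in-place updates in one traced
graph collapse to a single input/output alias, which the new
test_aliasing_with_multiple_inplace_update test checks via the
InputOutputAliasCount metric.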