Skip to content

Commit

Permalink
Handle multiple inplace update input output aliasing (#7023)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored and jeffhataws committed May 31, 2024
1 parent cc2f89b commit 2fc7912
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 1 deletion.
44 changes: 44 additions & 0 deletions test/test_input_output_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,50 @@ def test_non_view(self):

self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 4.0)

def test_aliasing_across_mark_step(self):
  """Check the alias metric after in-place updates split by mark_step.

  One in-place update is issued before each of two `xm.mark_step()` calls;
  index 1 of the "InputOutputAliasCount" metric data is expected to read
  1.0 after the first step and 2.0 after the second.
  """
  device = xm.xla_device()
  met.clear_all()
  tensor = torch.randn(4, 5).to(device)

  # First in-place update followed by a step.
  tensor += 1
  xm.mark_step()
  alias_count = met.metric_data("InputOutputAliasCount")[1]
  self.assertEqual(alias_count, 1.0)

  # Second in-place update on the same tensor, in a separate step.
  tensor *= 100
  xm.mark_step()
  alias_count = met.metric_data("InputOutputAliasCount")[1]
  self.assertEqual(alias_count, 2.0)

def test_aliasing_with_multiple_inplace_update(self):
  """Check aliasing when one tensor is updated in place many times per step.

  `k_cache` receives ten `index_copy_` calls before a single
  `xm.mark_step()`. Index 1 of the "InputOutputAliasCount" metric data is
  expected to be 1.0 afterwards, and the cache row addressed by the first
  slot must hold the first key row.
  """
  BATCH_SIZE = 1
  SEQ_LEN = 128
  NUM_KV_HEADS = 16
  HEAD_SIZE = 256
  BLOCK_SIZE = 16
  DTYPE = torch.bfloat16
  num_blocks = 1024
  device = xm.xla_device()
  key = torch.randn(
      BATCH_SIZE * SEQ_LEN,
      NUM_KV_HEADS,
      HEAD_SIZE,
      device=device,
      dtype=DTYPE)
  k_cache = torch.randn(
      num_blocks * BLOCK_SIZE,
      NUM_KV_HEADS,
      HEAD_SIZE,
      device=device,
      dtype=DTYPE)
  # Draw the slots without replacement. With duplicate slots the row that
  # ends up at slot_mapping[0][0] could come from a different key row, which
  # would make the content check below nondeterministic.
  slot_mapping = torch.randperm(num_blocks)[:BATCH_SIZE * SEQ_LEN].reshape(
      BATCH_SIZE, SEQ_LEN).to(device)
  # materialize k_cache to device data
  xm.mark_step()
  met.clear_all()
  for _ in range(10):
    k_cache.index_copy_(0, slot_mapping.flatten(), key)
  xm.mark_step()
  xm.wait_device_ops()
  self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
  # The comparison result must be asserted; a bare torch.allclose() call
  # checks nothing.
  self.assertTrue(
      torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu()))


if __name__ == '__main__':
test = unittest.main()
Expand Down
28 changes: 28 additions & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/bin/bash
# Runs the listed Python test files one after another.
#
# Shell options:
#   -x  print each command before it is executed
#   -u  treat expansion of an unset variable as an error
#   -e  stop at the first command that exits non-zero, so later tests do not
#       run after a failure
#
# All paths are relative (test/...), so this presumably must be invoked from
# the repository root -- TODO confirm against the CI caller.
set -xue

# TODO: merge with other run_tests
python3 test/test_operations.py -v
python3 test/pjrt/test_runtime_tpu.py
python3 test/pjrt/test_collective_ops_tpu.py
python3 test/spmd/test_xla_sharding.py
python3 test/spmd/test_xla_virtual_device.py
python3 test/spmd/test_xla_distributed_checkpoint.py
python3 test/spmd/test_train_spmd_linear_model.py
python3 test/spmd/test_xla_spmd_python_api_interaction.py
python3 test/spmd/test_xla_auto_sharding.py
# The two dynamic-shape tests run with XLA_EXPERIMENTAL set for that single
# command only (env-prefix form); the variable is not exported to later tests.
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
python3 test/test_grad_checkpoint.py
python3 test/dynamo/test_dynamo.py
python3 test/spmd/test_spmd_debugging.py
python3 test/pjrt/test_dtypes.py
python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
python3 test/test_pallas.py
python3 test/test_input_output_aliases.py
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py
34 changes: 33 additions & 1 deletion torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/ops/as_strided.h"
#include "torch_xla/csrc/ops/as_strided_view_update.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal_view_update.h"
#include "torch_xla/csrc/ops/einsum_utilities.h"
#include "torch_xla/csrc/ops/index_ops.h"
Expand Down Expand Up @@ -2405,7 +2406,38 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
// 1) Aid XLA's InputOutputAlias.
auto input_tensor = bridge::GetXlaTensor(input);
auto output_tensor = bridge::GetXlaTensor(output);
output_tensor->data()->alias_id = input_tensor->GetUniqueId();
if (input_tensor->CurrentDataHandle() != nullptr ||
(input_tensor->CurrentIrValue().node != nullptr &&
torch_xla::DeviceData::Cast(
input_tensor->CurrentIrValue().node.get()))) {
/*
If the input has XLAData or holds a DeviceData node, set alias_id to the
input's tensor_id. Consider the case:
// x.tensor_id = 1, x.alias_id = 1
x = torch.randn(5,5).to(xla_device())
// x.tensor_id = 2, x.alias_id should be 1
x += 1
xm.mark_step()
// x.tensor_id =3, x.alias_id should be 2 since input tensor id will be 2
// for this graph
x *= 1
*/
output_tensor->data()->alias_id = input_tensor->GetUniqueId();
} else {
/*
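Otherwise (no data handle and not a DeviceData node), propagate the input's
alias_id to the output. Consider the case: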
// x.tensor_id = 1, x.alias_id = 1
x = torch.randn(5,5).to(xla_device())
// x.tensor_id = 2, x.alias_id should be 1
x += 1
// x.tensor_id = 3, x.alias_id should still be 1
x *= 2
xm.mark_step()
*/
output_tensor->data()->alias_id = input_tensor->data()->alias_id;
}

// 2) Aid SPMD.
if (input_tensor->sharding_spec()) {
Expand Down

0 comments on commit 2fc7912

Please sign in to comment.