Handle multiple inplace update input output aliasing #7023

Merged: 3 commits, May 3, 2024
44 changes: 44 additions & 0 deletions test/test_input_output_aliases.py
@@ -38,6 +38,50 @@ def test_aliasing_with_cloned(self):
     torch.allclose(t1 - 1, t1_cloned)
     self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
 
+  def test_aliasing_across_mark_step(self):
+    xla_device = xm.xla_device()
+    met.clear_all()
+    t1 = torch.randn(4, 5).to(xla_device)
+    t1 += 1
+    xm.mark_step()
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+    t1 *= 100
+    xm.mark_step()
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
+
+  def test_aliasing_with_multiple_inplace_update(self):
+    BATCH_SIZE = 1
+    SEQ_LEN = 128
+    NUM_KV_HEADS = 16
+    HEAD_SIZE = 256
+    BLOCK_SIZE = 16
+    DTYPE = torch.bfloat16
+    num_blocks = 1024
+    device = xm.xla_device()
+    key = torch.randn(
+        BATCH_SIZE * SEQ_LEN,
+        NUM_KV_HEADS,
+        HEAD_SIZE,
+        device=device,
+        dtype=DTYPE)
+    k_cache = torch.randn(
+        num_blocks * BLOCK_SIZE,
+        NUM_KV_HEADS,
+        HEAD_SIZE,
+        device=device,
+        dtype=DTYPE)
+    slot_mapping = torch.randint(
+        0, num_blocks, (BATCH_SIZE, SEQ_LEN), device=device, dtype=torch.int64)
+    # materialize k_cache to device data
+    xm.mark_step()
+    met.clear_all()
+    for _ in range(10):
+      k_cache.index_copy_(0, slot_mapping.flatten(), key)
+    xm.mark_step()
+    xm.wait_device_ops()
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+    self.assertTrue(torch.allclose(k_cache[slot_mapping[0][0]].cpu(), key[0].cpu()))
+
 
 if __name__ == '__main__':
   test = unittest.main()
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -21,6 +21,7 @@ python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
 python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
 python3 test/test_pallas.py
+python3 test/test_input_output_aliases.py
 python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
34 changes: 33 additions & 1 deletion torch_xla/csrc/aten_xla_type.cpp
@@ -29,6 +29,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ops/as_strided.h"
 #include "torch_xla/csrc/ops/as_strided_view_update.h"
+#include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/diagonal_view_update.h"
 #include "torch_xla/csrc/ops/einsum_utilities.h"
 #include "torch_xla/csrc/ops/index_ops.h"
@@ -2538,7 +2539,38 @@ void XLANativeFunctions::_propagate_xla_data(const at::Tensor& input,
   // 1) Aid XLA's InputOutputAlias.
   auto input_tensor = bridge::GetXlaTensor(input);
   auto output_tensor = bridge::GetXlaTensor(output);
-  output_tensor->data()->alias_id = input_tensor->GetUniqueId();
+  if (input_tensor->CurrentDataHandle() != nullptr ||
+      (input_tensor->CurrentIrValue().node != nullptr &&
+       torch_xla::DeviceData::Cast(
+           input_tensor->CurrentIrValue().node.get()))) {
+    /*
+      If the input has XLA device data or holds a DeviceData node, set
+      alias_id to the input's tensor_id. Consider this case:
+
+      // x.tensor_id = 1, x.alias_id = 1
+      x = torch.randn(5,5).to(xla_device())
+      // x.tensor_id = 2, x.alias_id should be 1
+      x += 1
+      xm.mark_step()
+      // x.tensor_id = 3, x.alias_id should be 2 since the input tensor
+      // id will be 2 for this graph
+      x *= 1
+    */
+    output_tensor->data()->alias_id = input_tensor->GetUniqueId();
+  } else {
+    /*
+      Consider this case:
+
+      // x.tensor_id = 1, x.alias_id = 1
+      x = torch.randn(5,5).to(xla_device())
+      // x.tensor_id = 2, x.alias_id should be 1
+      x += 1
+      // x.tensor_id = 3, x.alias_id should still be 1
+      x *= 2
+      xm.mark_step()
+    */
+    output_tensor->data()->alias_id = input_tensor->data()->alias_id;
+  }
 
   // 2) Aid SPMD.
   XLATensor::ShardingSpecPtr sharding = input_tensor->sharding_spec();

Review thread on the new `if` condition:
Collaborator:

I guess we can always use alias_id?

JackCaoG (Collaborator, Author) — May 3, 2024:

Haha, that's what I thought too, but actually no. Look at my example down below:

    // x.tensor_id = 1, x.alias_id = 1
    x = torch.randn(5,5).to(xla_device())
    // x.tensor_id = 2, x.alias_id should be 1
    x += 1
    xm.mark_step()
    // x.tensor_id = 3, x.alias_id should be 2 since input tensor id will be 2
    // for this graph
    x *= 1
    xm.mark_step()

If we always used alias_id, the alias_id of x in the second graph would be 1, but we need it to be 2.

JackCaoG (Author):

In the second execution the input tensor id is 2; we need the alias id to always match the input tensor id. In other words, we should not carry alias_id across mark_step.
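
A minimal repro of this point, mirroring test_aliasing_across_mark_step from this PR (the expected metric value comes from that test):

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    t1 = torch.randn(4, 5).to(xm.xla_device())
    met.clear_all()
    t1 += 1
    xm.mark_step()  # graph 1: t1 aliases the original device buffer
    t1 *= 100
    xm.mark_step()  # graph 2: t1 must alias the buffer graph 1 produced
    assert met.metric_data("InputOutputAliasCount")[1] == 2.0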

JackCaoG (Author):

This is a bit tricky: even if the underlying buffer is aliased, we still create a new PjRtBuffer object for x after the first mark_step. That DeviceData object (a wrapper around the PjRtBuffer) will have data_info with tensor_id 2, since x's tensor id is 2 after the first mark_step.
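
To make the two branches concrete, here is a self-contained toy model of this bookkeeping; ToyTensor, has_device_data, and propagate_xla_data are illustrative stand-ins, not torch_xla API:

    from dataclasses import dataclass
    from itertools import count

    _next_id = count(1)

    @dataclass
    class ToyTensor:
      tensor_id: int         # unique id, bumped for every new tensor
      alias_id: int          # which graph input this tensor aliases
      has_device_data: bool  # stands in for CurrentDataHandle()/DeviceData

    def make_tensor(on_device: bool) -> ToyTensor:
      tid = next(_next_id)
      return ToyTensor(tid, tid, on_device)

    def propagate_xla_data(inp: ToyTensor, out: ToyTensor) -> None:
      # Mirrors the new if/else in _propagate_xla_data.
      if inp.has_device_data:
        out.alias_id = inp.tensor_id  # graph input: alias its own id
      else:
        out.alias_id = inp.alias_id   # pending IR: carry the alias through

    x = make_tensor(on_device=True)   # x = torch.randn(5,5).to(xla_device())
    y = make_tensor(on_device=False)  # x += 1
    propagate_xla_data(x, y)          # if branch -> alias_id 1
    z = make_tensor(on_device=False)  # x *= 2, still before mark_step
    propagate_xla_data(y, z)          # else branch -> alias_id still 1
    assert y.alias_id == z.alias_id == 1
    z.has_device_data = True          # xm.mark_step() materializes z
    w = make_tensor(on_device=False)  # another in-place op, next graph
    propagate_xla_data(z, w)          # if branch -> alias_id 3
    assert w.alias_id == z.tensor_id == 3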

alanwaketan (Collaborator) — May 3, 2024:

I guess resetting alias_id after mark_step is probably very complicated. This is more like a simplified way to achieve that, assuming IR/outputs become DeviceData/inputs.

JackCaoG (Author):

We could do that too (reset alias_id to the tensor id after processing the input_output_alias info). That might make this code less confusing, haha.
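
A sketch of that proposed follow-up in terms of the toy model above; this is hypothetical, not actual torch_xla code:

    # Hypothetical: after a graph's input/output alias info has been
    # processed, reset each live tensor's alias_id to its own tensor_id
    # so the next graph starts from a clean baseline.
    def reset_alias_ids(live_tensors):
      for t in live_tensors:
        t.alias_id = t.tensor_id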

Collaborator:

That sounds like a good follow-up, but feel free to skip it.
