When modifying IR node, make sure to not lose the read_only bit (#8505) (#8508)

mcuiaws authored Dec 20, 2024
1 parent ef85771 commit 38ed80e
Showing 3 changed files with 36 additions and 2 deletions.
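
For context: the read_only bit from the commit title lives on the DeviceDataInfo object that the lazy tensor core attaches to a backend data handle. As the new test below illustrates, device data handed out of the per-device DataCache can be shared by several tensors, so it must stay marked read_only or input/output aliasing may donate a buffer that is still cached; the bug was that rebuilding state from such a handle hard-coded the flag back to false. The struct looks roughly like the following (paraphrased from upstream PyTorch's torch/csrc/lazy/core/lazy_graph_executor.h; treat the exact layout as an approximation):

// Paraphrase of LazyGraphExecutor::DeviceDataInfo (approximate layout).
struct DeviceDataInfo : public torch::lazy::BackendData::Info {
  DeviceDataInfo(int64_t tensor_id, bool read_only)
      : tensor_id(tensor_id), read_only(read_only) {}

  int64_t tensor_id = 0;   // unique id of the owning XLATensor
  bool read_only = false;  // true => buffer must not be donated/aliased
};

Both C++ hunks below read this flag back off the existing handle instead of passing a literal false.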
test/test_input_output_aliases.py: 26 additions, 0 deletions
@@ -185,6 +185,32 @@ def test_xm_save_no_aliasing(self):
 
     self.assertEqual(t2.item(), 3)
 
+  def test_device_data_cache_no_aliasing(self):
+    """
+    Test that device data in the DataCache is not aliased.
+    """
+    xla_device = xm.xla_device()
+
+    t0 = torch.tensor(42, device=xla_device)
+    # drops the read-only bit on t0's device_data
+    xm.mark_step()
+
+    # the cached value of 42 is donated
+    t0.add_(1)
+    xm.mark_step()
+
+    # t1 gets the cached device_data, which was donated
+    t1 = torch.tensor(42, device=xla_device)
+    xm.mark_step()
+
+    t1.add_(1)
+    # XLA crashes here because the parameter is a donated buffer...
+    xm.mark_step()
+
+    # ...if it doesn't crash, the value here would be 44.
+    self.assertEqual(t1.item(), 43)
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
torch_xla/csrc/tensor.cpp: 5 additions, 1 deletion
@@ -386,7 +386,11 @@ torch::lazy::Value XLATensor::GetIrValue() const {
     // will still collapse them all into a single XLA parameter op). So call
     // which wants the XLA data will still find it, w/out having to fetch it
     // via a computation client from-server call.
-    AssignIrValue(CreateTensorNode(handle, /*read_only=*/false));
+    auto* data_info =
+        static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
+            handle->info());
+    bool read_only = data_info != nullptr && data_info->read_only;
+    AssignIrValue(CreateTensorNode(handle, read_only));
     // CreateTensorNode will set the data info of the tensor to the current
     // unique_id. Here the alias id needs to be updated so that input output
     // alias can correctly work on the xla's custom inplace operation.
torch_xla/csrc/xla_graph_executor.cpp: 5 additions, 1 deletion
@@ -659,9 +659,13 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
         // XlaData from the DeviceData Node and reset the IR. We also want
         // to update XlaData's tensorID to make it match with the current
         // XLATensor.
+        auto* data_info =
+            static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
+                device_data->data()->info());
+        bool read_only = data_info != nullptr && data_info->read_only;
         tensors[i]->GetXlaData()->SetInfo(
             std::make_shared<LazyGraphExecutor::DeviceDataInfo>(
-                tensors[i]->GetUniqueId(), /*=read_only=*/false));
+                tensors[i]->GetUniqueId(), read_only));
       } else {
         // Add only tensors which need to be synced.
         coll.hash = torch::lazy::HashCombine(coll.hash, ir_value.hash());
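
The same cast-and-check now appears at both call sites above. If it ever needs a third home, it could be factored into a small helper along these lines (a hypothetical sketch, not part of this commit; HandleIsReadOnly is an invented name, and the include paths assume upstream PyTorch's lazy core headers):

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>

// Hypothetical helper: recover the read_only bit stored on a backend data
// handle, defaulting to writable when no DeviceDataInfo is attached.
static bool HandleIsReadOnly(const torch::lazy::BackendDataPtr& handle) {
  auto* data_info =
      static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(
          handle->info());
  return data_info != nullptr && data_info->read_only;
}

With such a helper, GetIrValue() would call CreateTensorNode(handle, HandleIsReadOnly(handle)), and CollectSyncTensors() would pass HandleIsReadOnly(device_data->data()) when rebuilding the DeviceDataInfo.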
