From d39e51b8c06f88a8b623e83189ddbeac621e6039 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 27 Nov 2023 12:33:42 -0800 Subject: [PATCH] Copy input tensors before async transfer (#5830) * Copy input tensors before transfer * clone tensors before passing them to test case * formatting * Update test_utils.py --- torch_xla/csrc/runtime/tensor_source.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h index 11d4b2f71a5..ba82c5c8e42 100644 --- a/torch_xla/csrc/runtime/tensor_source.h +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -53,10 +53,9 @@ class AtenSource : public TensorSource { at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type()); if (target_torch_type != tensor.type().scalarType()) { TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1); - tensor_ = std::move(tensor.to(target_torch_type).contiguous()); - } else { - tensor_ = std::move(tensor.contiguous()); } + tensor_ = std::move(tensor.to(target_torch_type, /*non_blocking=*/false, + /*copy=*/true, at::MemoryFormat::Contiguous)); } const void* data() const override { return tensor_.const_data_ptr(); }