diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h index 11d4b2f71a55..ba82c5c8e42e 100644 --- a/torch_xla/csrc/runtime/tensor_source.h +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -53,10 +53,9 @@ class AtenSource : public TensorSource { at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type()); if (target_torch_type != tensor.type().scalarType()) { TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1); - tensor_ = std::move(tensor.to(target_torch_type).contiguous()); - } else { - tensor_ = std::move(tensor.contiguous()); } + tensor_ = std::move(tensor.to(target_torch_type, /*non_blocking=*/false, + /*copy=*/true, at::MemoryFormat::Contiguous)); } const void* data() const override { return tensor_.const_data_ptr(); }