diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index e6f064158ab..4feb4154166 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -302,6 +302,7 @@ cc_library( name = "tensor_source", hdrs = ["tensor_source.h"], deps = [ + ":debug_macros", "@xla//xla:literal", "@xla//xla:shape_util", "@torch//:headers", diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index de158582615..cbe1937df00 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -271,16 +271,13 @@ std::vector PjRtComputationClient::TransferToDevice( for (auto& tensor : tensors) { xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); - std::vector byte_strides(tensor->shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(tensor->shape(), - absl::MakeSpan(byte_strides))); total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); std::shared_ptr buffer = std::move(client_ ->BufferFromHostBuffer( tensor->data(), tensor->shape().element_type(), - tensor->shape().dimensions(), byte_strides, + tensor->shape().dimensions(), tensor->byte_strides(), xla::PjRtClient::HostBufferSemantics:: kImmutableUntilTransferCompletes, [tensor]() { /* frees tensor */ }, pjrt_device) diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h index 14772289d36..d3798231777 100644 --- a/torch_xla/csrc/runtime/tensor_source.h +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -1,6 +1,7 @@ #ifndef XLA_CLIENT_TENSOR_SOURCE_H_ #define XLA_CLIENT_TENSOR_SOURCE_H_ +#include "torch_xla/csrc/runtime/debug_macros.h" #include "xla/literal.h" #include "xla/shape.h" @@ -18,6 +19,13 @@ class TensorSource { const std::string& device() const { return device_; } + std::vector byte_strides() const { + std::vector byte_strides(shape().dimensions_size()); + XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(shape(), + absl::MakeSpan(byte_strides))); + return byte_strides; + } + private: std::string device_; };