Skip to content

Commit

Permalink
implement byte_strides in TensorSource
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 1, 2023
1 parent ff2ffb9 commit ab88d50
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
deps = [
":debug_macros",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@torch//:headers",
Expand Down
5 changes: 1 addition & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,16 +271,13 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToDevice(
for (auto& tensor : tensors) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());

std::vector<int64_t> byte_strides(tensor->shape().dimensions_size());
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(tensor->shape(),
absl::MakeSpan(byte_strides)));
total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());

std::shared_ptr<xla::PjRtBuffer> buffer =
std::move(client_
->BufferFromHostBuffer(
tensor->data(), tensor->shape().element_type(),
tensor->shape().dimensions(), byte_strides,
tensor->shape().dimensions(), tensor->byte_strides(),
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[tensor]() { /* frees tensor */ }, pjrt_device)
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef XLA_CLIENT_TENSOR_SOURCE_H_
#define XLA_CLIENT_TENSOR_SOURCE_H_

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "xla/literal.h"
#include "xla/shape.h"

Expand All @@ -18,6 +19,13 @@ class TensorSource {

const std::string& device() const { return device_; }

std::vector<int64_t> byte_strides() const {
std::vector<int64_t> byte_strides(shape().dimensions_size());
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(shape(),
absl::MakeSpan(byte_strides)));
return byte_strides;
}

private:
std::string device_;
};
Expand Down

0 comments on commit ab88d50

Please sign in to comment.