Skip to content

Commit

Permalink
Transfer data directly to the device (#5772)
Browse files Browse the repository at this point in the history
* Transfer data directly to the device (#5752)

* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.

* Use `at::Tensor`'s layout for byte strides

* Downcast at::Tensor if required

* formatting

* Simplify AtenSource

* fix build

* formatting

* fix typo that makes us ignore input type

* Revert "Simplify AtenSource"

This reverts commit 4225deb.

* Skip hanging test

* fix gil deadlock

* formatting
  • Loading branch information
will-cromar authored Nov 13, 2023
1 parent e60428d commit 05a3cdd
Show file tree
Hide file tree
Showing 10 changed files with 222 additions and 151 deletions.
64 changes: 42 additions & 22 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
load(
"//bazel:rules_def.bzl",
"ptxla_cc_test",
)

load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
Expand Down Expand Up @@ -46,6 +51,7 @@ cc_library(
":metrics_reader",
":metrics",
":sys_util",
":tensor_source",
":types",
":util",
":xla_coordinator",
Expand Down Expand Up @@ -78,6 +84,7 @@ cc_library(
":env_vars",
":multi_wait",
":stablehlo_helper",
":tensor_source",
":tf_logging",
":thread_pool",
":xla_coordinator",
Expand Down Expand Up @@ -264,6 +271,17 @@ cc_library(
],
)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
deps = [
":debug_macros",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@torch//:headers",
]
)

cc_library(
name = "types",
hdrs = ["types.h"],
Expand Down Expand Up @@ -339,25 +357,27 @@ ptxla_cc_test(
],
)

# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream.
# xla_cc_test(
# name = "pjrt_computation_client_test",
# srcs = ["pjrt_computation_client_test.cc"],
# deps = [
# ":computation_client",
# "@xla//xla:literal",
# "@xla//xla:literal_util",
# "@xla//xla:shape_util",
# "@xla//xla:status",
# "@xla//xla:statusor",
# "@xla//xla/client:xla_builder",
# "@xla//xla/client:xla_computation",
# "@xla//xla/tests:literal_test_util",
# "@xla//xla/tools:hlo_module_loader",
# "@org_tensorflow//tensorflow/core/platform:logging",
# "@tsl//tsl/lib/core:status_test_util",
# "@tsl//tsl/platform:env",
# "@tsl//tsl/platform:test",
# "@tsl//tsl/platform:test_main",
# ],
# )
ptxla_cc_test(
name = "pjrt_computation_client_test",
srcs = ["pjrt_computation_client_test.cc"],
deps = [
":computation_client",
":pjrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
30 changes: 8 additions & 22 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_
#define XLA_CLIENT_COMPUTATION_CLIENT_H_

#include <ATen/Tensor.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/hash.h>
Expand All @@ -20,6 +21,7 @@
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/types.h"
#include "torch_xla/csrc/runtime/util.h"
#include "xla/client/xla_computation.h"
Expand Down Expand Up @@ -192,25 +194,6 @@ class ComputationClient {

using ComputationPtr = std::shared_ptr<Computation>;

// The TensorSource provides a way for a client to populate a buffer allocated
// by the core computation client code.
struct TensorSource {
// The PopulateFn accepts a dense buffer is standard array layout
// (dim0-major) and deposits the source tensor data directly over the
// provided buffer.
using PopulateFn = std::function<void(const TensorSource&, void*, size_t)>;

TensorSource() = default;
TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn)
: shape(std::move(shape)),
device(std::move(device)),
populate_fn(std::move(populate_fn)) {}

xla::Shape shape;
std::string device;
PopulateFn populate_fn;
};

// TODO(wcromar): Should CompileInstance still exist? Should it be a subclass
// of torch::lazy::Computation?
struct CompileInstance {
Expand Down Expand Up @@ -275,19 +258,22 @@ class ComputationClient {

// Transfers local tensor values to the TPU devices and fetches the handles.
virtual std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) = 0;
absl::Span<const std::shared_ptr<const TensorSource>> tensors) = 0;

// Transfers local sharded tensor values to the TPU devices and returns a
// `PjRtShardedData`.
virtual DataPtr TransferShardsToServer(
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) = 0;
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) = 0;

// Copies `data->buffer` to `dst` device buffer.
virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0;

// Reads the tensor literal values stored at TPU server sites, behind the
// supplied handles.
// Note: `TransferFromServer` call will block until the `DataPtrs` are ready
// if they were created by `TransferToServer` or `Execute*`. Calling this from
// python while holding the GIL can cause deadlocks!
virtual std::vector<xla::Literal> TransferFromServer(
absl::Span<const DataPtr> handles) = 0;

Expand Down
44 changes: 18 additions & 26 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
Expand Down Expand Up @@ -303,39 +304,30 @@ std::optional<xla::OpSharding> PjRtComputationClient::GetDataSharding(
}

std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
absl::Span<const TensorSource> tensors) {
absl::Span<const std::shared_ptr<const TensorSource>> tensors) {
metrics::TimedSection timed(TransferToServerMetric());
tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::DataPtr> datas;
datas.reserve(tensors.size());
int64_t total_size = 0;
for (auto& tensor : tensors) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device);

auto literal = std::make_shared<xla::Literal>(tensor.shape);
tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes());
std::vector<int64_t> byte_strides(literal->shape().dimensions_size());
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(),
absl::MakeSpan(byte_strides)));
total_size += literal->size_bytes();

// Avoid use-after-free on `literal` due to unsequenced move and use.
xla::Literal* literal_pointer = literal.get();
std::shared_ptr<xla::PjRtBuffer> buffer = std::move(
client_
->BufferFromHostBuffer(
literal_pointer->untyped_data(),
literal_pointer->shape().element_type(),
literal_pointer->shape().dimensions(), byte_strides,
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[literal{std::move(literal)}]() { /* frees literal */ },
pjrt_device)
.value());
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());

total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());

std::shared_ptr<xla::PjRtBuffer> buffer =
std::move(client_
->BufferFromHostBuffer(
tensor->data(), tensor->primitive_type(),
tensor->dimensions(), tensor->byte_strides(),
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[tensor]() { /* frees tensor */ }, pjrt_device)
.value());

ComputationClient::DataPtr data =
std::make_shared<PjRtData>(tensor.device, tensor.shape, buffer);
std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
datas.push_back(data);
}
OutboundDataMetric()->AddSample(total_size);
Expand All @@ -345,8 +337,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
}

ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer(
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) {
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) {
tsl::profiler::TraceMe activity(
"PjRtComputationClient::TransferShardsToServer",
tsl::profiler::TraceMeLevel::kInfo);
Expand Down
8 changes: 4 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ class PjRtComputationClient : public ComputationClient {
std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;

std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) override;
absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;

// Use XLA replication to re-assemble the sharded data.
DataPtr ReplicateShardedData(const DataPtr& handle);

std::vector<xla::Literal> TransferFromServer(
absl::Span<const DataPtr> handles) override;

DataPtr TransferShardsToServer(absl::Span<const TensorSource> tensor_shards,
std::string device, xla::Shape shape,
xla::OpSharding sharding) override;
DataPtr TransferShardsToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) override;

DataPtr CopyToDevice(DataPtr data, std::string dst) override;

Expand Down
19 changes: 5 additions & 14 deletions torch_xla/csrc/runtime/pjrt_computation_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <vector>

#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"
Expand All @@ -32,17 +34,6 @@ tsl::StatusOr<xla::XlaComputation> MakeComputation() {
return builder.Build();
}

ComputationClient::TensorSource TensorSourceFromLiteral(
const std::string& device, const xla::Literal& literal) {
auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor,
void* dest_buffer, size_t dest_buffer_size) {
std::memcpy(dest_buffer, literal.data<float>().data(),
dest_buffer_size * sizeof(literal.data<float>().data()));
};
return ComputationClient::TensorSource(literal.shape(), device,
std::move(populate_fn));
}

TEST(PjRtComputationClientTest, Init) {
// Get a CPU client.
tsl::setenv("PJRT_DEVICE", "CPU", true);
Expand All @@ -69,9 +60,9 @@ TEST(PjRtComputationClientTest, Init) {

// Copy inputs to device.
ComputationClient::ExecuteComputationOptions options{};
std::vector<ComputationClient::TensorSource> args = {
TensorSourceFromLiteral(device, literal_x),
TensorSourceFromLiteral(device, literal_y)};
std::vector<std::shared_ptr<const TensorSource>> args = {
std::make_shared<LiteralSource>(std::move(literal_x), device),
std::make_shared<LiteralSource>(std::move(literal_y), device)};

// Execute the graph.
std::vector<ComputationClient::DataPtr> results = client->ExecuteComputation(
Expand Down
100 changes: 100 additions & 0 deletions torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#ifndef XLA_CLIENT_TENSOR_SOURCE_H_
#define XLA_CLIENT_TENSOR_SOURCE_H_

#include <ATen/Tensor.h>
#include <torch/csrc/lazy/core/metrics.h>

#include <vector>

#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/shape_util.h"

namespace torch_xla {
namespace runtime {

// Owns a contiguous block of data with the shape and layout matching `shape()`.
class TensorSource {
public:
TensorSource(std::string device) : device_(std::move(device)){};

virtual const void* data() const = 0;

virtual const xla::Shape& shape() const = 0;

const std::string& device() const { return device_; }

virtual std::vector<int64_t> byte_strides() const {
std::vector<int64_t> byte_strides(shape().dimensions_size());
XLA_CHECK_OK(
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
return byte_strides;
}

virtual std::vector<int64_t> dimensions() const {
auto dimensions = shape().dimensions();
return {dimensions.begin(), dimensions.end()};
}

virtual xla::PrimitiveType primitive_type() const {
return shape().element_type();
}

private:
std::string device_;
};

class AtenSource : public TensorSource {
public:
AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
: TensorSource(std::move(device)), shape_(std::move(shape)) {
at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type());
if (target_torch_type != tensor.type().scalarType()) {
TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
tensor_ = std::move(tensor.to(target_torch_type).contiguous());
} else {
tensor_ = std::move(tensor.contiguous());
}
}

const void* data() const override { return tensor_.const_data_ptr(); }

const xla::Shape& shape() const override { return shape_; }

std::vector<int64_t> byte_strides() const override {
std::vector<int64_t> strides;
for (auto& stride : tensor_.strides()) {
strides.push_back(stride * tensor_.itemsize());
}
return strides;
}

std::vector<int64_t> dimensions() const override {
auto sizes = tensor_.sizes();
return {sizes.begin(), sizes.end()};
}

private:
at::Tensor tensor_;
xla::Shape shape_;
};

class LiteralSource : public TensorSource {
public:
LiteralSource(xla::Literal literal, std::string device)
: TensorSource(std::move(device)), literal_(std::move(literal)) {}

const void* data() const override { return literal_.untyped_data(); }

const xla::Shape& shape() const override { return literal_.shape(); }

private:
xla::Literal literal_;
};

} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_COMPUTATION_CLIENT_H_
Loading

0 comments on commit 05a3cdd

Please sign in to comment.