Transfer data directly to the device (#5752)
* Remove `populate_fn` from `TensorSource`

* Make TensorSource an interface

* Re-enable pjrt_computation_client_test

* server -> device

* add comment

* fix outbound data metric

* formatting

* implement byte_strides in TensorSource

* more formatting

* remove extra deps

* add missing deps

* Revert "server -> device"

This reverts commit 6384516.
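Taken together, these commits replace the callback-based TensorSource struct with a small class hierarchy that exposes the source buffer directly, so the runtime can hand host data to PJRT without staging it in an intermediate xla::Literal. A minimal usage sketch of the reworked API, pieced together from the diffs below — `client` is assumed to be an initialized ComputationClient*, and the device string and values are illustrative:

// Sketch only (torch_xla::runtime namespace assumed): wrap a literal in the
// new LiteralSource and transfer it without a populate_fn or an extra copy.
xla::Literal literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
std::vector<std::shared_ptr<const TensorSource>> args = {
    std::make_shared<LiteralSource>(std::move(literal), "CPU:0")};
std::vector<ComputationClient::DataPtr> handles = client->TransferToServer(args);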
will-cromar authored Nov 2, 2023
1 parent b20a082 commit 5ca36cf
Showing 8 changed files with 158 additions and 128 deletions.
64 changes: 42 additions & 22 deletions torch_xla/csrc/runtime/BUILD
@@ -1,3 +1,8 @@
load(
"//bazel:rules_def.bzl",
"ptxla_cc_test",
)

load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
@@ -61,6 +66,7 @@ cc_library(
":metrics_reader",
":metrics",
":sys_util",
":tensor_source",
":types",
":util",
":xla_coordinator",
@@ -92,6 +98,7 @@ cc_library(
":env_vars",
":multi_wait",
":stablehlo_helper",
":tensor_source",
":tf_logging",
":thread_pool",
":xla_coordinator",
@@ -292,6 +299,17 @@ cc_library(
],
)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
deps = [
":debug_macros",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@torch//:headers",
]
)

cc_library(
name = "types",
hdrs = ["types.h"],
@@ -376,25 +394,27 @@ cc_test(
],
)

# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream.
# xla_cc_test(
# name = "pjrt_computation_client_test",
# srcs = ["pjrt_computation_client_test.cc"],
# deps = [
# ":computation_client",
# "@xla//xla:literal",
# "@xla//xla:literal_util",
# "@xla//xla:shape_util",
# "@xla//xla:status",
# "@xla//xla:statusor",
# "@xla//xla/client:xla_builder",
# "@xla//xla/client:xla_computation",
# "@xla//xla/tests:literal_test_util",
# "@xla//xla/tools:hlo_module_loader",
# "@org_tensorflow//tensorflow/core/platform:logging",
# "@tsl//tsl/lib/core:status_test_util",
# "@tsl//tsl/platform:env",
# "@tsl//tsl/platform:test",
# "@tsl//tsl/platform:test_main",
# ],
# )
ptxla_cc_test(
name = "pjrt_computation_client_test",
srcs = ["pjrt_computation_client_test.cc"],
deps = [
":computation_client",
":pjrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
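With the new tensor_source library split out, targets that use the interface declare the dependency explicitly, as computation_client and pjrt_computation_client do above. A hypothetical downstream consumer would follow the same pattern (target and file names are illustrative, not part of this commit):

cc_library(
    name = "my_transfer_util",  # hypothetical consumer, for illustration
    srcs = ["my_transfer_util.cc"],
    deps = [
        ":tensor_source",
        "@xla//xla:shape_util",
    ],
)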
27 changes: 5 additions & 22 deletions torch_xla/csrc/runtime/computation_client.h
@@ -1,6 +1,7 @@
#ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_
#define XLA_CLIENT_COMPUTATION_CLIENT_H_

#include <ATen/Tensor.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/hash.h>
@@ -20,6 +21,7 @@
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/types.h"
#include "torch_xla/csrc/runtime/util.h"
#include "xla/client/xla_computation.h"
@@ -192,25 +194,6 @@ class ComputationClient {

using ComputationPtr = std::shared_ptr<Computation>;

// The TensorSource provides a way for a client to populate a buffer allocated
// by the core computation client code.
struct TensorSource {
// The PopulateFn accepts a dense buffer in standard array layout
// (dim0-major) and deposits the source tensor data directly over the
// provided buffer.
using PopulateFn = std::function<void(const TensorSource&, void*, size_t)>;

TensorSource() = default;
TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn)
: shape(std::move(shape)),
device(std::move(device)),
populate_fn(std::move(populate_fn)) {}

xla::Shape shape;
std::string device;
PopulateFn populate_fn;
};

// TODO(wcromar): Should CompileInstance still exist? Should it be a subclass
// of torch::lazy::Computation?
struct CompileInstance {
@@ -275,13 +258,13 @@ class ComputationClient {

// Transfers local tensor values to the TPU devices and fetches the handles.
virtual std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) = 0;
absl::Span<const std::shared_ptr<const TensorSource>> tensors) = 0;

// Transfers local sharded tensor values to the TPU devices and returns a
// `PjRtShardedData`.
virtual DataPtr TransferShardsToServer(
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) = 0;
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) = 0;

// Copies `data->buffer` to `dst` device buffer.
virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0;
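Since TensorSource is now an abstract interface (defined in tensor_source.h, shown later in this diff), callers are no longer tied to the removed populate_fn callback: anything that owns dense, dim0-major host data can implement it. A hedged sketch of a custom implementation — the class name and members are hypothetical, not part of this commit:

// Hypothetical adapter from a raw host buffer to the new interface. The
// buffer must stay valid until the transfer completes, and must be dense
// and dim0-major, matching shape().
class RawBufferSource : public TensorSource {
 public:
  RawBufferSource(const void* data, xla::Shape shape, std::string device)
      : TensorSource(std::move(device)),
        data_(data),
        shape_(std::move(shape)) {}

  const void* data() const override { return data_; }
  const xla::Shape& shape() const override { return shape_; }

 private:
  const void* data_;
  xla::Shape shape_;
};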
44 changes: 18 additions & 26 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -12,6 +12,7 @@
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
@@ -286,39 +287,30 @@ std::optional<xla::OpSharding> PjRtComputationClient::GetDataSharding(
}

std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
absl::Span<const TensorSource> tensors) {
absl::Span<const std::shared_ptr<const TensorSource>> tensors) {
metrics::TimedSection timed(TransferToServerMetric());
tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::DataPtr> datas;
datas.reserve(tensors.size());
int64_t total_size = 0;
for (auto& tensor : tensors) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device);

auto literal = std::make_shared<xla::Literal>(tensor.shape);
tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes());
std::vector<int64_t> byte_strides(literal->shape().dimensions_size());
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(),
absl::MakeSpan(byte_strides)));
total_size += literal->size_bytes();

// Avoid use-after-free on `literal` due to unsequenced move and use.
xla::Literal* literal_pointer = literal.get();
std::shared_ptr<xla::PjRtBuffer> buffer = std::move(
client_
->BufferFromHostBuffer(
literal_pointer->untyped_data(),
literal_pointer->shape().element_type(),
literal_pointer->shape().dimensions(), byte_strides,
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[literal{std::move(literal)}]() { /* frees literal */ },
pjrt_device)
.value());
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());

total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());

std::shared_ptr<xla::PjRtBuffer> buffer =
std::move(client_
->BufferFromHostBuffer(
tensor->data(), tensor->shape().element_type(),
tensor->shape().dimensions(), tensor->byte_strides(),
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[tensor]() { /* frees tensor */ }, pjrt_device)
.value());

ComputationClient::DataPtr data =
std::make_shared<PjRtData>(tensor.device, tensor.shape, buffer);
std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
datas.push_back(data);
}
OutboundDataMetric()->AddSample(total_size);
@@ -328,8 +320,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
}

ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer(
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) {
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) {
tsl::profiler::TraceMe activity(
"PjRtComputationClient::TransferShardsToServer",
tsl::profiler::TraceMeLevel::kInfo);
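The lambda passed to BufferFromHostBuffer is what makes the shared_ptr signature necessary: PJRT runs it once the host-to-device copy completes, and capturing `tensor` by value keeps the source buffer alive until then, even if the caller drops its own reference as soon as TransferToServer returns. A reduced sketch of that hand-off (values illustrative):

// The on-done callback co-owns `tensor`; when PJRT invokes it after the
// copy finishes, this reference is released and the host data may be freed.
auto tensor = std::make_shared<const LiteralSource>(
    xla::LiteralUtil::CreateR0<float>(42.0f), "CPU:0");
auto on_done = [tensor]() { /* last reference dropped after the H2D copy */ };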
8 changes: 4 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -37,17 +37,17 @@ class PjRtComputationClient : public ComputationClient {
std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;

std::vector<DataPtr> TransferToServer(
absl::Span<const TensorSource> tensors) override;
absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;

// Use XLA replication to re-assemble the sharded data.
DataPtr ReplicateShardedData(const DataPtr& handle);

std::vector<xla::Literal> TransferFromServer(
absl::Span<const DataPtr> handles) override;

DataPtr TransferShardsToServer(absl::Span<const TensorSource> tensor_shards,
std::string device, xla::Shape shape,
xla::OpSharding sharding) override;
DataPtr TransferShardsToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) override;

DataPtr CopyToDevice(DataPtr data, std::string dst) override;

19 changes: 5 additions & 14 deletions torch_xla/csrc/runtime/pjrt_computation_client_test.cc
@@ -7,6 +7,8 @@
#include <vector>

#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"
@@ -32,17 +34,6 @@ tsl::StatusOr<xla::XlaComputation> MakeComputation() {
return builder.Build();
}

ComputationClient::TensorSource TensorSourceFromLiteral(
const std::string& device, const xla::Literal& literal) {
auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor,
void* dest_buffer, size_t dest_buffer_size) {
std::memcpy(dest_buffer, literal.data<float>().data(),
dest_buffer_size * sizeof(literal.data<float>().data()));
};
return ComputationClient::TensorSource(literal.shape(), device,
std::move(populate_fn));
}

TEST(PjRtComputationClientTest, Init) {
// Get a CPU client.
tsl::setenv("PJRT_DEVICE", "CPU", true);
@@ -69,9 +60,9 @@ TEST(PjRtComputationClientTest, Init) {

// Copy inputs to device.
ComputationClient::ExecuteComputationOptions options{};
std::vector<ComputationClient::TensorSource> args = {
TensorSourceFromLiteral(device, literal_x),
TensorSourceFromLiteral(device, literal_y)};
std::vector<std::shared_ptr<const TensorSource>> args = {
std::make_shared<LiteralSource>(std::move(literal_x), device),
std::make_shared<LiteralSource>(std::move(literal_y), device)};

// Execute the graph.
std::vector<ComputationClient::DataPtr> results = client->ExecuteComputation(
70 changes: 70 additions & 0 deletions torch_xla/csrc/runtime/tensor_source.h
@@ -0,0 +1,70 @@
#ifndef XLA_CLIENT_TENSOR_SOURCE_H_
#define XLA_CLIENT_TENSOR_SOURCE_H_

#include <ATen/Tensor.h>

#include <vector>

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/shape_util.h"

namespace torch_xla {
namespace runtime {

// Owns a contiguous block of data with the shape and layout matching `shape()`.
class TensorSource {
public:
TensorSource(std::string device) : device_(std::move(device)){};

virtual const void* data() const = 0;

virtual const xla::Shape& shape() const = 0;

const std::string& device() const { return device_; }

std::vector<int64_t> byte_strides() const {
std::vector<int64_t> byte_strides(shape().dimensions_size());
XLA_CHECK_OK(
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
return byte_strides;
}

private:
std::string device_;
};

class AtenSource : public TensorSource {
public:
AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
: TensorSource(std::move(device)),
tensor_(std::move(tensor.contiguous())),
shape_(std::move(shape)) {}

const void* data() const override { return tensor_.const_data_ptr(); }

const xla::Shape& shape() const override { return shape_; }

private:
at::Tensor tensor_;
xla::Shape shape_;
};

class LiteralSource : public TensorSource {
public:
LiteralSource(xla::Literal literal, std::string device)
: TensorSource(std::move(device)), literal_(std::move(literal)) {}

const void* data() const override { return literal_.untyped_data(); }

const xla::Shape& shape() const override { return literal_.shape(); }

private:
xla::Literal literal_;
};

} // namespace runtime
} // namespace torch_xla

#endif  // XLA_CLIENT_TENSOR_SOURCE_H_
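byte_strides() relies on the dense, dim0-major layout promised in the class comment. As a worked example of what xla::ShapeUtil::ByteStrides produces: for f32[2, 3], stepping along the last dimension advances 4 bytes (one float) and stepping along the first advances 3 * 4 = 12 bytes, so the strides are {12, 4}:

// Worked example for the default dim0-major layout of f32[2, 3].
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
std::vector<int64_t> strides(shape.dimensions_size());
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(shape, absl::MakeSpan(strides)));
// strides == {12, 4}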