Skip to content

Commit

Permalink
Revert "Transfer data directly to the device (#5752)" (#5765)
Browse files Browse the repository at this point in the history
This reverts commit 5ca36cf.
  • Loading branch information
will-cromar authored and ManfeiBai committed Nov 29, 2023
1 parent dfe84f3 commit eba3afc
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 158 deletions.
64 changes: 22 additions & 42 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
load(
"//bazel:rules_def.bzl",
"ptxla_cc_test",
)

load(
"@tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
Expand Down Expand Up @@ -66,7 +61,6 @@ cc_library(
":metrics_reader",
":metrics",
":sys_util",
":tensor_source",
":types",
":util",
":xla_coordinator",
Expand Down Expand Up @@ -98,7 +92,6 @@ cc_library(
":env_vars",
":multi_wait",
":stablehlo_helper",
":tensor_source",
":tf_logging",
":thread_pool",
":xla_coordinator",
Expand Down Expand Up @@ -299,17 +292,6 @@ cc_library(
],
)

cc_library(
name = "tensor_source",
hdrs = ["tensor_source.h"],
deps = [
":debug_macros",
"@xla//xla:literal",
"@xla//xla:shape_util",
"@torch//:headers",
]
)

cc_library(
name = "types",
hdrs = ["types.h"],
Expand Down Expand Up @@ -394,27 +376,25 @@ cc_test(
],
)

ptxla_cc_test(
name = "pjrt_computation_client_test",
srcs = ["pjrt_computation_client_test.cc"],
deps = [
":computation_client",
":pjrt_computation_client",
":tensor_source",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla:shape_util",
"@xla//xla:status",
"@xla//xla:statusor",
"@xla//xla/client:xla_builder",
"@xla//xla/client:xla_computation",
"@xla//xla/tests:literal_test_util",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/lib/core:status_test_util",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_main",
],
)
# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream.
# xla_cc_test(
# name = "pjrt_computation_client_test",
# srcs = ["pjrt_computation_client_test.cc"],
# deps = [
# ":computation_client",
# "@xla//xla:literal",
# "@xla//xla:literal_util",
# "@xla//xla:shape_util",
# "@xla//xla:status",
# "@xla//xla:statusor",
# "@xla//xla/client:xla_builder",
# "@xla//xla/client:xla_computation",
# "@xla//xla/tests:literal_test_util",
# "@xla//xla/tools:hlo_module_loader",
# "@org_tensorflow//tensorflow/core/platform:logging",
# "@tsl//tsl/lib/core:status_test_util",
# "@tsl//tsl/platform:env",
# "@tsl//tsl/platform:test",
# "@tsl//tsl/platform:test_main",
# ],
# )
27 changes: 22 additions & 5 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_
#define XLA_CLIENT_COMPUTATION_CLIENT_H_

#include <ATen/Tensor.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/hash.h>
Expand All @@ -21,7 +20,6 @@
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/types.h"
#include "torch_xla/csrc/runtime/util.h"
#include "xla/client/xla_computation.h"
Expand Down Expand Up @@ -194,6 +192,25 @@ class ComputationClient {

using ComputationPtr = std::shared_ptr<Computation>;

// The TensorSource provides a way for a client to populate a buffer allocated
// by the core computation client code.
struct TensorSource {
// The PopulateFn accepts a dense buffer is standard array layout
// (dim0-major) and deposits the source tensor data directly over the
// provided buffer.
using PopulateFn = std::function<void(const TensorSource&, void*, size_t)>;

TensorSource() = default;
TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn)
: shape(std::move(shape)),
device(std::move(device)),
populate_fn(std::move(populate_fn)) {}

xla::Shape shape;
std::string device;
PopulateFn populate_fn;
};

// TODO(wcromar): Should CompileInstance still exist? Should it be a subclass
// of torch::lazy::Computation?
struct CompileInstance {
Expand Down Expand Up @@ -258,13 +275,13 @@ class ComputationClient {

// Transfers local tensor values to the TPU devices and fetches the handles.
virtual std::vector<DataPtr> TransferToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensors) = 0;
absl::Span<const TensorSource> tensors) = 0;

// Transfers local sharded tensor values to the TPU devices and returns a
// `PjRtShardedData`.
virtual DataPtr TransferShardsToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) = 0;
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) = 0;

// Copies `data->buffer` to `dst` device buffer.
virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0;
Expand Down
44 changes: 26 additions & 18 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/thread_pool.h"
#include "torch_xla/csrc/runtime/xla_coordinator.h"
Expand Down Expand Up @@ -287,30 +286,39 @@ std::optional<xla::OpSharding> PjRtComputationClient::GetDataSharding(
}

std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensors) {
absl::Span<const TensorSource> tensors) {
metrics::TimedSection timed(TransferToServerMetric());
tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<ComputationClient::DataPtr> datas;
datas.reserve(tensors.size());
int64_t total_size = 0;
for (auto& tensor : tensors) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());

total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());

std::shared_ptr<xla::PjRtBuffer> buffer =
std::move(client_
->BufferFromHostBuffer(
tensor->data(), tensor->shape().element_type(),
tensor->shape().dimensions(), tensor->byte_strides(),
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[tensor]() { /* frees tensor */ }, pjrt_device)
.value());
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device);

auto literal = std::make_shared<xla::Literal>(tensor.shape);
tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes());
std::vector<int64_t> byte_strides(literal->shape().dimensions_size());
XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(),
absl::MakeSpan(byte_strides)));
total_size += literal->size_bytes();

// Avoid use-after-free on `literal` due to unsequenced move and use.
xla::Literal* literal_pointer = literal.get();
std::shared_ptr<xla::PjRtBuffer> buffer = std::move(
client_
->BufferFromHostBuffer(
literal_pointer->untyped_data(),
literal_pointer->shape().element_type(),
literal_pointer->shape().dimensions(), byte_strides,
xla::PjRtClient::HostBufferSemantics::
kImmutableUntilTransferCompletes,
[literal{std::move(literal)}]() { /* frees literal */ },
pjrt_device)
.value());

ComputationClient::DataPtr data =
std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
std::make_shared<PjRtData>(tensor.device, tensor.shape, buffer);
datas.push_back(data);
}
OutboundDataMetric()->AddSample(total_size);
Expand All @@ -320,8 +328,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
}

ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) {
absl::Span<const TensorSource> tensor_shards, std::string device,
xla::Shape shape, xla::OpSharding sharding) {
tsl::profiler::TraceMe activity(
"PjRtComputationClient::TransferShardsToServer",
tsl::profiler::TraceMeLevel::kInfo);
Expand Down
8 changes: 4 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ class PjRtComputationClient : public ComputationClient {
std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;

std::vector<DataPtr> TransferToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;
absl::Span<const TensorSource> tensors) override;

// Use XLA replication to re-assemble the sharded data.
DataPtr ReplicateShardedData(const DataPtr& handle);

std::vector<xla::Literal> TransferFromServer(
absl::Span<const DataPtr> handles) override;

DataPtr TransferShardsToServer(
absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
std::string device, xla::Shape shape, xla::OpSharding sharding) override;
DataPtr TransferShardsToServer(absl::Span<const TensorSource> tensor_shards,
std::string device, xla::Shape shape,
xla::OpSharding sharding) override;

DataPtr CopyToDevice(DataPtr data, std::string dst) override;

Expand Down
19 changes: 14 additions & 5 deletions torch_xla/csrc/runtime/pjrt_computation_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#include <vector>

#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/tensor_source.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
#include "tsl/platform/logging.h"
Expand All @@ -34,6 +32,17 @@ tsl::StatusOr<xla::XlaComputation> MakeComputation() {
return builder.Build();
}

ComputationClient::TensorSource TensorSourceFromLiteral(
const std::string& device, const xla::Literal& literal) {
auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor,
void* dest_buffer, size_t dest_buffer_size) {
std::memcpy(dest_buffer, literal.data<float>().data(),
dest_buffer_size * sizeof(literal.data<float>().data()));
};
return ComputationClient::TensorSource(literal.shape(), device,
std::move(populate_fn));
}

TEST(PjRtComputationClientTest, Init) {
// Get a CPU client.
tsl::setenv("PJRT_DEVICE", "CPU", true);
Expand All @@ -60,9 +69,9 @@ TEST(PjRtComputationClientTest, Init) {

// Copy inputs to device.
ComputationClient::ExecuteComputationOptions options{};
std::vector<std::shared_ptr<const TensorSource>> args = {
std::make_shared<LiteralSource>(std::move(literal_x), device),
std::make_shared<LiteralSource>(std::move(literal_y), device)};
std::vector<ComputationClient::TensorSource> args = {
TensorSourceFromLiteral(device, literal_x),
TensorSourceFromLiteral(device, literal_y)};

// Execute the graph.
std::vector<ComputationClient::DataPtr> results = client->ExecuteComputation(
Expand Down
70 changes: 0 additions & 70 deletions torch_xla/csrc/runtime/tensor_source.h

This file was deleted.

Loading

0 comments on commit eba3afc

Please sign in to comment.