From a93e14bf154273450b78108cd6fdc3e5999ebdc2 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Thu, 2 Nov 2023 12:06:23 -0700 Subject: [PATCH] Revert "Transfer data directly to the device (#5752)" (#5765) This reverts commit 5ca36cf5fe3f1d75e7311c0371019b7d91343e44. --- torch_xla/csrc/runtime/BUILD | 64 ++++++----------- torch_xla/csrc/runtime/computation_client.h | 27 +++++-- .../csrc/runtime/pjrt_computation_client.cc | 44 +++++++----- .../csrc/runtime/pjrt_computation_client.h | 8 +-- .../runtime/pjrt_computation_client_test.cc | 19 +++-- torch_xla/csrc/runtime/tensor_source.h | 70 ------------------- torch_xla/csrc/tensor_util.cpp | 41 ++++++++--- torch_xla/csrc/xla_sharding_util.cpp | 13 +++- 8 files changed, 128 insertions(+), 158 deletions(-) delete mode 100644 torch_xla/csrc/runtime/tensor_source.h diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index c17e1a886ec..da12492d2cd 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -1,8 +1,3 @@ -load( - "//bazel:rules_def.bzl", - "ptxla_cc_test", -) - load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -66,7 +61,6 @@ cc_library( ":metrics_reader", ":metrics", ":sys_util", - ":tensor_source", ":types", ":util", ":xla_coordinator", @@ -98,7 +92,6 @@ cc_library( ":env_vars", ":multi_wait", ":stablehlo_helper", - ":tensor_source", ":tf_logging", ":thread_pool", ":xla_coordinator", @@ -299,17 +292,6 @@ cc_library( ], ) -cc_library( - name = "tensor_source", - hdrs = ["tensor_source.h"], - deps = [ - ":debug_macros", - "@xla//xla:literal", - "@xla//xla:shape_util", - "@torch//:headers", - ] -) - cc_library( name = "types", hdrs = ["types.h"], @@ -394,27 +376,25 @@ cc_test( ], ) -ptxla_cc_test( - name = "pjrt_computation_client_test", - srcs = ["pjrt_computation_client_test.cc"], - deps = [ - ":computation_client", - ":pjrt_computation_client", - ":tensor_source", - "@xla//xla:literal", - "@xla//xla:literal_util", - "@xla//xla:shape_util", - "@xla//xla:status", - "@xla//xla:statusor", - "@xla//xla/client:xla_builder", - "@xla//xla/client:xla_computation", - "@xla//xla/tests:literal_test_util", - "@xla//xla/tools:hlo_module_loader", - "@tsl//tsl/lib/core:status_test_util", - "@tsl//tsl/platform:env", - "@tsl//tsl/platform:errors", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:test", - "@tsl//tsl/platform:test_main", - ], -) +# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream. 
+# xla_cc_test(
+#     name = "pjrt_computation_client_test",
+#     srcs = ["pjrt_computation_client_test.cc"],
+#     deps = [
+#         ":computation_client",
+#         "@xla//xla:literal",
+#         "@xla//xla:literal_util",
+#         "@xla//xla:shape_util",
+#         "@xla//xla:status",
+#         "@xla//xla:statusor",
+#         "@xla//xla/client:xla_builder",
+#         "@xla//xla/client:xla_computation",
+#         "@xla//xla/tests:literal_test_util",
+#         "@xla//xla/tools:hlo_module_loader",
+#         "@org_tensorflow//tensorflow/core/platform:logging",
+#         "@tsl//tsl/lib/core:status_test_util",
+#         "@tsl//tsl/platform:env",
+#         "@tsl//tsl/platform:test",
+#         "@tsl//tsl/platform:test_main",
+#     ],
+# )
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index eee5a18e3f9..145a6d0aa09 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -1,7 +1,6 @@
 #ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_
 #define XLA_CLIENT_COMPUTATION_CLIENT_H_
 
-#include 
 #include 
 #include 
 #include 
@@ -21,7 +20,6 @@
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/metrics.h"
-#include "torch_xla/csrc/runtime/tensor_source.h"
 #include "torch_xla/csrc/runtime/types.h"
 #include "torch_xla/csrc/runtime/util.h"
 #include "xla/client/xla_computation.h"
@@ -194,6 +192,25 @@ class ComputationClient {
 
   using ComputationPtr = std::shared_ptr<Computation>;
 
+  // The TensorSource provides a way for a client to populate a buffer allocated
+  // by the core computation client code.
+  struct TensorSource {
+    // The PopulateFn accepts a dense buffer in standard array layout
+    // (dim0-major) and deposits the source tensor data directly over the
+    // provided buffer.
+    using PopulateFn = std::function<void(const TensorSource&, void*, size_t)>;
+
+    TensorSource() = default;
+    TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn)
+        : shape(std::move(shape)),
+          device(std::move(device)),
+          populate_fn(std::move(populate_fn)) {}
+
+    xla::Shape shape;
+    std::string device;
+    PopulateFn populate_fn;
+  };
+
   // TODO(wcromar): Should CompileInstance still exist? Should it be a subclass
   // of torch::lazy::Computation?
   struct CompileInstance {
@@ -258,13 +275,13 @@ class ComputationClient {
 
   // Transfers local tensor values to the TPU devices and fetches the handles.
   virtual std::vector<DataPtr> TransferToServer(
-      absl::Span<const std::shared_ptr<const TensorSource>> tensors) = 0;
+      absl::Span<const TensorSource> tensors) = 0;
 
   // Transfers local sharded tensor values to the TPU devices and returns a
   // `PjRtShardedData`.
   virtual DataPtr TransferShardsToServer(
-      absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
-      std::string device, xla::Shape shape, xla::OpSharding sharding) = 0;
+      absl::Span<const TensorSource> tensor_shards, std::string device,
+      xla::Shape shape, xla::OpSharding sharding) = 0;
 
   // Copies `data->buffer` to `dst` device buffer.
   virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 81dc1271129..c003f4f9706 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -12,7 +12,6 @@
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
-#include "torch_xla/csrc/runtime/tensor_source.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
 #include "torch_xla/csrc/runtime/xla_coordinator.h"
@@ -287,7 +286,7 @@ std::optional<xla::OpSharding> PjRtComputationClient::GetDataSharding(
 }
 
 std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
-    absl::Span<const std::shared_ptr<const TensorSource>> tensors) {
+    absl::Span<const TensorSource> tensors) {
   metrics::TimedSection timed(TransferToServerMetric());
   tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer",
                                   tsl::profiler::TraceMeLevel::kInfo);
@@ -295,22 +294,31 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
   datas.reserve(tensors.size());
   int64_t total_size = 0;
   for (auto& tensor : tensors) {
-    xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());
-
-    total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());
-
-    std::shared_ptr<xla::PjRtBuffer> buffer =
-        std::move(client_
-                      ->BufferFromHostBuffer(
-                          tensor->data(), tensor->shape().element_type(),
-                          tensor->shape().dimensions(), tensor->byte_strides(),
-                          xla::PjRtClient::HostBufferSemantics::
-                              kImmutableUntilTransferCompletes,
-                          [tensor]() { /* frees tensor */ }, pjrt_device)
-                      .value());
+    xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device);
+
+    auto literal = std::make_shared<xla::Literal>(tensor.shape);
+    tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes());
+    std::vector<int64_t> byte_strides(literal->shape().dimensions_size());
+    XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(),
+                                             absl::MakeSpan(byte_strides)));
+    total_size += literal->size_bytes();
+
+    // Avoid use-after-free on `literal` due to unsequenced move and use.
+    xla::Literal* literal_pointer = literal.get();
+    std::shared_ptr<xla::PjRtBuffer> buffer = std::move(
+        client_
+            ->BufferFromHostBuffer(
+                literal_pointer->untyped_data(),
+                literal_pointer->shape().element_type(),
+                literal_pointer->shape().dimensions(), byte_strides,
+                xla::PjRtClient::HostBufferSemantics::
+                    kImmutableUntilTransferCompletes,
+                [literal{std::move(literal)}]() { /* frees literal */ },
+                pjrt_device)
+            .value());
     ComputationClient::DataPtr data =
-        std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
+        std::make_shared<PjRtData>(tensor.device, tensor.shape, buffer);
     datas.push_back(data);
   }
   OutboundDataMetric()->AddSample(total_size);
@@ -320,8 +328,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
 }
 
 ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer(
-    absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
-    std::string device, xla::Shape shape, xla::OpSharding sharding) {
+    absl::Span<const TensorSource> tensor_shards, std::string device,
+    xla::Shape shape, xla::OpSharding sharding) {
   tsl::profiler::TraceMe activity(
       "PjRtComputationClient::TransferShardsToServer",
       tsl::profiler::TraceMeLevel::kInfo);
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index b66e4ff5097..faebd4892b8 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -37,7 +37,7 @@ class PjRtComputationClient : public ComputationClient {
   std::optional<xla::OpSharding> GetDataSharding(DataPtr handle) override;
 
   std::vector<DataPtr> TransferToServer(
-      absl::Span<const std::shared_ptr<const TensorSource>> tensors) override;
+      absl::Span<const TensorSource> tensors) override;
 
   // Use XLA replication to re-assemble the sharded data.
   DataPtr ReplicateShardedData(const DataPtr& handle);
@@ -45,9 +45,9 @@ class PjRtComputationClient : public ComputationClient {
   std::vector<xla::Literal> TransferFromServer(
       absl::Span<const DataPtr> handles) override;
 
-  DataPtr TransferShardsToServer(
-      absl::Span<const std::shared_ptr<const TensorSource>> tensor_shards,
-      std::string device, xla::Shape shape, xla::OpSharding sharding) override;
+  DataPtr TransferShardsToServer(absl::Span<const TensorSource> tensor_shards,
+                                 std::string device, xla::Shape shape,
+                                 xla::OpSharding sharding) override;
 
   DataPtr CopyToDevice(DataPtr data, std::string dst) override;
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc
index d6240f08e98..24cbc4636a6 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc
@@ -7,8 +7,6 @@
 #include 
 
 #include "torch_xla/csrc/runtime/computation_client.h"
-#include "torch_xla/csrc/runtime/pjrt_computation_client.h"
-#include "torch_xla/csrc/runtime/tensor_source.h"
 #include "tsl/lib/core/status_test_util.h"
 #include "tsl/platform/env.h"
 #include "tsl/platform/logging.h"
@@ -34,6 +32,17 @@ tsl::StatusOr<xla::XlaComputation> MakeComputation() {
   return builder.Build();
 }
 
+ComputationClient::TensorSource TensorSourceFromLiteral(
+    const std::string& device, const xla::Literal& literal) {
+  auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor,
+                         void* dest_buffer, size_t dest_buffer_size) {
+    std::memcpy(dest_buffer, literal.data<float>().data(),
+                dest_buffer_size * sizeof(literal.data<float>().data()));
+  };
+  return ComputationClient::TensorSource(literal.shape(), device,
+                                         std::move(populate_fn));
+}
+
 TEST(PjRtComputationClientTest, Init) {
   // Get a CPU client.
   tsl::setenv("PJRT_DEVICE", "CPU", true);
@@ -60,9 +69,9 @@ TEST(PjRtComputationClientTest, Init) {
 
   // Copy inputs to device.
   ComputationClient::ExecuteComputationOptions options{};
-  std::vector<std::shared_ptr<const TensorSource>> args = {
-      std::make_shared<LiteralSource>(std::move(literal_x), device),
-      std::make_shared<LiteralSource>(std::move(literal_y), device)};
+  std::vector<ComputationClient::TensorSource> args = {
+      TensorSourceFromLiteral(device, literal_x),
+      TensorSourceFromLiteral(device, literal_y)};
 
   // Execute the graph.
   std::vector<ComputationClient::DataPtr> results = client->ExecuteComputation(
diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h
deleted file mode 100644
index 4f24084e019..00000000000
--- a/torch_xla/csrc/runtime/tensor_source.h
+++ /dev/null
@@ -1,70 +0,0 @@
-#ifndef XLA_CLIENT_TENSOR_SOURCE_H_
-#define XLA_CLIENT_TENSOR_SOURCE_H_
-
-#include 
-
-#include 
-
-#include "torch_xla/csrc/runtime/debug_macros.h"
-#include "xla/literal.h"
-#include "xla/shape.h"
-#include "xla/shape_util.h"
-
-namespace torch_xla {
-namespace runtime {
-
-// Owns a contiguous block of data with the shape and layout matching `shape()`.
-class TensorSource {
- public:
-  TensorSource(std::string device) : device_(std::move(device)){};
-
-  virtual const void* data() const = 0;
-
-  virtual const xla::Shape& shape() const = 0;
-
-  const std::string& device() const { return device_; }
-
-  std::vector<int64_t> byte_strides() const {
-    std::vector<int64_t> byte_strides(shape().dimensions_size());
-    XLA_CHECK_OK(
-        xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
-    return byte_strides;
-  }
-
- private:
-  std::string device_;
-};
-
-class AtenSource : public TensorSource {
- public:
-  AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
-      : TensorSource(std::move(device)),
-        tensor_(std::move(tensor.contiguous())),
-        shape_(std::move(shape)) {}
-
-  const void* data() const override { return tensor_.const_data_ptr(); }
-
-  const xla::Shape& shape() const override { return shape_; }
-
- private:
-  at::Tensor tensor_;
-  xla::Shape shape_;
-};
-
-class LiteralSource : public TensorSource {
- public:
-  LiteralSource(xla::Literal literal, std::string device)
-      : TensorSource(std::move(device)), literal_(std::move(literal)) {}
-
-  const void* data() const override { return literal_.untyped_data(); }
-
-  const xla::Shape& shape() const override { return literal_.shape(); }
-
- private:
-  xla::Literal literal_;
-};
-
-}  // namespace runtime
-}  // namespace torch_xla
-
-#endif  // XLA_CLIENT_COMPUTATION_CLIENT_H_
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index f3749a1ecd2..a419bd98b7e 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -34,7 +34,7 @@ namespace torch_xla {
 namespace {
 
 struct DataAsync {
-  std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
+  std::vector<runtime::ComputationClient::TensorSource> source_tensors;
   std::vector<runtime::ComputationClient::DataPtr> async_datas;
   std::vector<torch::lazy::ExceptionCleanup> handle_unlockers;
 };
@@ -587,9 +587,15 @@ torch::lazy::BackendDataPtr TensorToXlaData(
         sharding_spec);
   }
 
-  std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
-  source_tensors.push_back(
-      std::make_shared<runtime::AtenSource>(tensor, shape, device.toString()));
+  auto populate_fn =
+      [&](const runtime::ComputationClient::TensorSource& source_tensor,
+          void* dest_buffer, size_t dest_buffer_size) {
+        PopulateTensorBuffer(tensor, source_tensor.shape, dest_buffer,
+                             dest_buffer_size, device);
+      };
+
+  std::vector<runtime::ComputationClient::TensorSource> source_tensors;
+  source_tensors.emplace_back(shape, device.toString(), std::move(populate_fn));
 
   auto handles =
       runtime::GetComputationClient()->TransferToServer(source_tensors);
@@ -811,12 +817,19 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     return WrapXlaData(handles);
   }
 
-  std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
+  std::vector<runtime::ComputationClient::TensorSource> source_tensors;
   for (size_t i = 0; i < tensors.size(); ++i) {
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
     xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);
-    source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-        tensors[i], std::move(shape), devices[i]));
+    auto populate_fn =
+        [&, i, device](
+            const runtime::ComputationClient::TensorSource& source_tensor,
+            void* dest_buffer, size_t dest_buffer_size) {
+          PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer,
+                               dest_buffer_size, device);
+        };
+    source_tensors.emplace_back(std::move(shape), devices[i],
+                                std::move(populate_fn));
   }
   return WrapXlaData(
       runtime::GetComputationClient()->TransferToServer(source_tensors));
@@ -835,8 +848,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
     xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);
 
-    std::vector<std::shared_ptr<const runtime::TensorSource>>
-        source_tensors;  // in
+    std::vector<runtime::ComputationClient::TensorSource> source_tensors;  // in
     std::vector<runtime::ComputationClient::DataPtr> new_handles;  // out
     if (static_cast<XlaDeviceType>(device.type()) == XlaDeviceType::SPMD) {
       // GetLocalDevices returns the list of local devices specified by their
       new_handles.push_back(ShardingUtil::CreateShardedData(
           local_shards, local_devices, shardings[i]));
     } else {
-      source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-          tensors[i], std::move(shape), devices[i]));
+      auto populate_fn =
+          [&, i, device](
+              const runtime::ComputationClient::TensorSource& source_tensor,
+              void* dest_buffer, size_t dest_buffer_size) {
+            PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer,
+                                 dest_buffer_size, device);
+          };
+      source_tensors.emplace_back(std::move(shape), devices[i],
+                                  std::move(populate_fn));
       new_handles =
           runtime::GetComputationClient()->TransferToServer(source_tensors);
     }
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index e7fc8f7fdea..cde74256eee 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -711,7 +711,7 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
     const XLATensor::ShardingSpecPtr& sharding_spec) {
   XLA_CHECK(local_shards.size() == devices.size())
       << "A device must be specified for each shard";
-  std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
+  std::vector<runtime::ComputationClient::TensorSource> source_tensors;
   xla::Shape global_shape;
   xla::OpSharding sharding;
   if (sharding_spec == nullptr) {
@@ -728,8 +728,15 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
     auto shard_device = ParseDeviceString(devices[j]);
     auto shard_shape =
         CreateComputationShapeFromTensor(local_shards[j], &shard_device);
-    source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-        local_shards[j], shard_shape, devices[j]));
+    auto populate_fn =
+        [&, j, shard_device](
+            const runtime::ComputationClient::TensorSource& source_tensor,
+            void* dest_buffer, size_t dest_buffer_size) {
+          PopulateTensorBuffer(local_shards[j], source_tensor.shape,
+                               dest_buffer, dest_buffer_size, shard_device);
+        };
+    source_tensors.emplace_back(shard_shape, devices[j],
+                                std::move(populate_fn));
   }
   return runtime::GetComputationClient()->TransferShardsToServer(
       source_tensors, GetVirtualDevice().toString(), global_shape, sharding);