From 6a39d05d81d9c73fa1a17b3b025e2a7e764dde4e Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 13 Nov 2023 13:16:37 -0800 Subject: [PATCH] Transfer data directly to the device (#5772) * Transfer data directly to the device (#5752) * Remove `populate_fn` from `TensorSource` * Make TensorSource an interface * Re-enable pjrt_computation_client_test * server -> device * add comment * fix outbound data metric * formatting * implement byte_strides in TensorSource * more formatting * remove extra deps * add missing deps * Revert "server -> device" This reverts commit 63845167bbb81a42ecc73dc5835868befd1baa0f. * Use `at::Tensor`'s layout for byte strides * Downcast at::Tensor if required * formatting * Simplify AtenSource * fix build * formatting * fix typo that makes us ignore input type * Revert "Simplify AtenSource" This reverts commit 4225deb22b3823d46bf73d343ede07aa142d0480. * Skip hanging test * fix gil deadlock * formatting --- torch_xla/csrc/runtime/BUILD | 64 +++++++---- torch_xla/csrc/runtime/computation_client.h | 30 ++---- .../csrc/runtime/pjrt_computation_client.cc | 44 ++++---- .../csrc/runtime/pjrt_computation_client.h | 8 +- .../runtime/pjrt_computation_client_test.cc | 19 +--- torch_xla/csrc/runtime/tensor_source.h | 100 ++++++++++++++++++ torch_xla/csrc/tensor_util.cpp | 68 ++++++------ torch_xla/csrc/tensor_util.h | 6 ++ torch_xla/csrc/xla_graph_executor.cpp | 21 +--- torch_xla/csrc/xla_sharding_util.cpp | 13 +-- 10 files changed, 222 insertions(+), 151 deletions(-) create mode 100644 torch_xla/csrc/runtime/tensor_source.h diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 9b500f46da8..9dc3730299d 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -1,3 +1,8 @@ +load( + "//bazel:rules_def.bzl", + "ptxla_cc_test", +) + load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -46,6 +51,7 @@ cc_library( ":metrics_reader", ":metrics", ":sys_util", + ":tensor_source", ":types", ":util", ":xla_coordinator", @@ -78,6 +84,7 @@ cc_library( ":env_vars", ":multi_wait", ":stablehlo_helper", + ":tensor_source", ":tf_logging", ":thread_pool", ":xla_coordinator", @@ -264,6 +271,17 @@ cc_library( ], ) +cc_library( + name = "tensor_source", + hdrs = ["tensor_source.h"], + deps = [ + ":debug_macros", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@torch//:headers", + ] +) + cc_library( name = "types", hdrs = ["types.h"], @@ -339,25 +357,27 @@ ptxla_cc_test( ], ) -# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream. -# xla_cc_test( -# name = "pjrt_computation_client_test", -# srcs = ["pjrt_computation_client_test.cc"], -# deps = [ -# ":computation_client", -# "@xla//xla:literal", -# "@xla//xla:literal_util", -# "@xla//xla:shape_util", -# "@xla//xla:status", -# "@xla//xla:statusor", -# "@xla//xla/client:xla_builder", -# "@xla//xla/client:xla_computation", -# "@xla//xla/tests:literal_test_util", -# "@xla//xla/tools:hlo_module_loader", -# "@org_tensorflow//tensorflow/core/platform:logging", -# "@tsl//tsl/lib/core:status_test_util", -# "@tsl//tsl/platform:env", -# "@tsl//tsl/platform:test", -# "@tsl//tsl/platform:test_main", -# ], -# ) +ptxla_cc_test( + name = "pjrt_computation_client_test", + srcs = ["pjrt_computation_client_test.cc"], + deps = [ + ":computation_client", + ":pjrt_computation_client", + ":tensor_source", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + "@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 145a6d0aa09..9af461bfd56 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -1,6 +1,7 @@ #ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_ #define XLA_CLIENT_COMPUTATION_CLIENT_H_ +#include #include #include #include @@ -20,6 +21,7 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/metrics.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/types.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" @@ -192,25 +194,6 @@ class ComputationClient { using ComputationPtr = std::shared_ptr; - // The TensorSource provides a way for a client to populate a buffer allocated - // by the core computation client code. - struct TensorSource { - // The PopulateFn accepts a dense buffer is standard array layout - // (dim0-major) and deposits the source tensor data directly over the - // provided buffer. - using PopulateFn = std::function; - - TensorSource() = default; - TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn) - : shape(std::move(shape)), - device(std::move(device)), - populate_fn(std::move(populate_fn)) {} - - xla::Shape shape; - std::string device; - PopulateFn populate_fn; - }; - // TODO(wcromar): Should CompileInstance still exist? Should it be a subclass // of torch::lazy::Computation? struct CompileInstance { @@ -275,19 +258,22 @@ class ComputationClient { // Transfers local tensor values to the TPU devices and fetches the handles. virtual std::vector TransferToServer( - absl::Span tensors) = 0; + absl::Span> tensors) = 0; // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. virtual DataPtr TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) = 0; + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) = 0; // Copies `data->buffer` to `dst` device buffer. virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0; // Reads the tensor literal values stored at TPU server sites, behind the // supplied handles. + // Note: `TransferFromServer` call will block until the `DataPtrs` are ready + // if they were created by `TransferToServer` or `Execute*`. Calling this from + // python while holding the GIL can cause deadlocks! virtual std::vector TransferFromServer( absl::Span handles) = 0; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 94d504e1714..fba50dcb63d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -12,6 +12,7 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" @@ -303,7 +304,7 @@ std::optional PjRtComputationClient::GetDataSharding( } std::vector PjRtComputationClient::TransferToServer( - absl::Span tensors) { + absl::Span> tensors) { metrics::TimedSection timed(TransferToServerMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -311,31 +312,22 @@ std::vector PjRtComputationClient::TransferToServer( datas.reserve(tensors.size()); int64_t total_size = 0; for (auto& tensor : tensors) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device); - - auto literal = std::make_shared(tensor.shape); - tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes()); - std::vector byte_strides(literal->shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(), - absl::MakeSpan(byte_strides))); - total_size += literal->size_bytes(); - - // Avoid use-after-free on `literal` due to unsequenced move and use. - xla::Literal* literal_pointer = literal.get(); - std::shared_ptr buffer = std::move( - client_ - ->BufferFromHostBuffer( - literal_pointer->untyped_data(), - literal_pointer->shape().element_type(), - literal_pointer->shape().dimensions(), byte_strides, - xla::PjRtClient::HostBufferSemantics:: - kImmutableUntilTransferCompletes, - [literal{std::move(literal)}]() { /* frees literal */ }, - pjrt_device) - .value()); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); + + total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); + + std::shared_ptr buffer = + std::move(client_ + ->BufferFromHostBuffer( + tensor->data(), tensor->primitive_type(), + tensor->dimensions(), tensor->byte_strides(), + xla::PjRtClient::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [tensor]() { /* frees tensor */ }, pjrt_device) + .value()); ComputationClient::DataPtr data = - std::make_shared(tensor.device, tensor.shape, buffer); + std::make_shared(tensor->device(), tensor->shape(), buffer); datas.push_back(data); } OutboundDataMetric()->AddSample(total_size); @@ -345,8 +337,8 @@ std::vector PjRtComputationClient::TransferToServer( } ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) { + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "PjRtComputationClient::TransferShardsToServer", tsl::profiler::TraceMeLevel::kInfo); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index faebd4892b8..b66e4ff5097 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -37,7 +37,7 @@ class PjRtComputationClient : public ComputationClient { std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToServer( - absl::Span tensors) override; + absl::Span> tensors) override; // Use XLA replication to re-assemble the sharded data. DataPtr ReplicateShardedData(const DataPtr& handle); @@ -45,9 +45,9 @@ class PjRtComputationClient : public ComputationClient { std::vector TransferFromServer( absl::Span handles) override; - DataPtr TransferShardsToServer(absl::Span tensor_shards, - std::string device, xla::Shape shape, - xla::OpSharding sharding) override; + DataPtr TransferShardsToServer( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc index 24cbc4636a6..d6240f08e98 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc @@ -7,6 +7,8 @@ #include #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/pjrt_computation_client.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" @@ -32,17 +34,6 @@ tsl::StatusOr MakeComputation() { return builder.Build(); } -ComputationClient::TensorSource TensorSourceFromLiteral( - const std::string& device, const xla::Literal& literal) { - auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - std::memcpy(dest_buffer, literal.data().data(), - dest_buffer_size * sizeof(literal.data().data())); - }; - return ComputationClient::TensorSource(literal.shape(), device, - std::move(populate_fn)); -} - TEST(PjRtComputationClientTest, Init) { // Get a CPU client. tsl::setenv("PJRT_DEVICE", "CPU", true); @@ -69,9 +60,9 @@ TEST(PjRtComputationClientTest, Init) { // Copy inputs to device. ComputationClient::ExecuteComputationOptions options{}; - std::vector args = { - TensorSourceFromLiteral(device, literal_x), - TensorSourceFromLiteral(device, literal_y)}; + std::vector> args = { + std::make_shared(std::move(literal_x), device), + std::make_shared(std::move(literal_y), device)}; // Execute the graph. std::vector results = client->ExecuteComputation( diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h new file mode 100644 index 00000000000..11d4b2f71a5 --- /dev/null +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -0,0 +1,100 @@ +#ifndef XLA_CLIENT_TENSOR_SOURCE_H_ +#define XLA_CLIENT_TENSOR_SOURCE_H_ + +#include +#include + +#include + +#include "torch_xla/csrc/dtype.h" +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace torch_xla { +namespace runtime { + +// Owns a contiguous block of data with the shape and layout matching `shape()`. +class TensorSource { + public: + TensorSource(std::string device) : device_(std::move(device)){}; + + virtual const void* data() const = 0; + + virtual const xla::Shape& shape() const = 0; + + const std::string& device() const { return device_; } + + virtual std::vector byte_strides() const { + std::vector byte_strides(shape().dimensions_size()); + XLA_CHECK_OK( + xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); + return byte_strides; + } + + virtual std::vector dimensions() const { + auto dimensions = shape().dimensions(); + return {dimensions.begin(), dimensions.end()}; + } + + virtual xla::PrimitiveType primitive_type() const { + return shape().element_type(); + } + + private: + std::string device_; +}; + +class AtenSource : public TensorSource { + public: + AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device) + : TensorSource(std::move(device)), shape_(std::move(shape)) { + at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type()); + if (target_torch_type != tensor.type().scalarType()) { + TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1); + tensor_ = std::move(tensor.to(target_torch_type).contiguous()); + } else { + tensor_ = std::move(tensor.contiguous()); + } + } + + const void* data() const override { return tensor_.const_data_ptr(); } + + const xla::Shape& shape() const override { return shape_; } + + std::vector byte_strides() const override { + std::vector strides; + for (auto& stride : tensor_.strides()) { + strides.push_back(stride * tensor_.itemsize()); + } + return strides; + } + + std::vector dimensions() const override { + auto sizes = tensor_.sizes(); + return {sizes.begin(), sizes.end()}; + } + + private: + at::Tensor tensor_; + xla::Shape shape_; +}; + +class LiteralSource : public TensorSource { + public: + LiteralSource(xla::Literal literal, std::string device) + : TensorSource(std::move(device)), literal_(std::move(literal)) {} + + const void* data() const override { return literal_.untyped_data(); } + + const xla::Shape& shape() const override { return literal_.shape(); } + + private: + xla::Literal literal_; +}; + +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_CLIENT_COMPUTATION_CLIENT_H_ diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 3690d043055..6e46899aea3 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -35,7 +35,7 @@ namespace torch_xla { namespace { struct DataAsync { - std::vector source_tensors; + std::vector> source_tensors; std::vector async_datas; std::vector handle_unlockers; }; @@ -479,15 +479,9 @@ torch::lazy::BackendDataPtr TensorToXlaData( sharding_spec); } - auto populate_fn = - [&](const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensor, source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - - std::vector source_tensors; - source_tensors.emplace_back(shape, device.toString(), std::move(populate_fn)); + std::vector> source_tensors; + source_tensors.push_back( + std::make_shared(tensor, shape, device.toString())); auto handles = runtime::GetComputationClient()->TransferToServer(source_tensors); @@ -709,19 +703,12 @@ std::vector CreateTensorsData( return WrapXlaData(handles); } - std::vector source_tensors; + std::vector> source_tensors; for (size_t i = 0; i < tensors.size(); ++i) { torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + tensors[i], std::move(shape), devices[i])); } return WrapXlaData( runtime::GetComputationClient()->TransferToServer(source_tensors)); @@ -740,7 +727,8 @@ std::vector CreateTensorsData( torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - std::vector source_tensors; // in + std::vector> + source_tensors; // in std::vector new_handles; // out if (static_cast(device.type()) == XlaDeviceType::SPMD) { // GetLocalDevices returns the list of local devices specified by their @@ -756,15 +744,8 @@ std::vector CreateTensorsData( new_handles.push_back(ShardingUtil::CreateShardedData( local_shards, local_devices, shardings[i])); } else { - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + tensors[i], std::move(shape), devices[i])); new_handles = runtime::GetComputationClient()->TransferToServer(source_tensors); } @@ -790,12 +771,33 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape, return literal; } -std::vector XlaDataToTensors( - absl::Span xla_data, - at::ScalarType dest_element_type) { +std::vector ReleaseGilAndTransferData( + absl::Span xla_data) { + // HACK: This method may be called outside of python (mainly in C++ tests) or + // when the GIL is already released, so we must check both cases here. If + // possible, prefer to release the GIL in the python bindings before copying + // this pattern. + PyThreadState* save = nullptr; + // TODO(wcromar): Remove this setting when we are more confident + static const bool release_gil = + runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true); + if (release_gil && Py_IsInitialized() && PyGILState_Check()) { + save = PyEval_SaveThread(); + } std::vector literals = runtime::GetComputationClient()->TransferFromServer( UnwrapXlaData(xla_data)); + if (save) { + PyEval_RestoreThread(save); + } + + return literals; +} + +std::vector XlaDataToTensors( + absl::Span xla_data, + at::ScalarType dest_element_type) { + std::vector literals = ReleaseGilAndTransferData(xla_data); std::vector tensors; tensors.reserve(literals.size()); for (auto& literal : literals) { diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index 480a2e23f7a..81b4cd9a565 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -25,6 +25,12 @@ std::vector ComputeShapeStrides(const xla::Shape& shape); at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal, at::ScalarType dest_element_type); +// Execution and data transfer are async in PJRT, so TransferFromServer may +// block until `DataPtr`s are ready. Release the GIL so other threads can +// proceed and unblock any transfers or collective computations. +std::vector ReleaseGilAndTransferData( + absl::Span xla_data); + // TODO LTC @wonjoo - Migrate to upstream after Device -> BackendDevice std::vector XlaDataToTensors( absl::Span xla_data, diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp index 4686de452da..02ea28874e3 100644 --- a/torch_xla/csrc/xla_graph_executor.cpp +++ b/torch_xla/csrc/xla_graph_executor.cpp @@ -422,26 +422,7 @@ std::vector XLAGraphExecutor::GetTensors( async != nullptr ? async->tensors_data : absl::Span()); - // Execution is async in PJRT, so TransferFromServer may block until execution - // completes. Release the GIL so other threads can proceed and unblock any - // collective computations. - // HACK: This method may be called outside of python (mainly in C++ tests) or - // when the GIL is already released, so we must check both cases here. If - // possible, prefer to release the GIL in the python bindings before copying - // this pattern. - PyThreadState* save = nullptr; - // TODO(wcromar): Remove this setting when we are more confident - static const bool release_gil = - runtime::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true); - if (release_gil && Py_IsInitialized() && PyGILState_Check()) { - save = PyEval_SaveThread(); - } - std::vector literals = - runtime::GetComputationClient()->TransferFromServer( - UnwrapXlaData(tensors_data)); - if (save) { - PyEval_RestoreThread(save); - } + std::vector literals = ReleaseGilAndTransferData(tensors_data); return FetchTensors(tensors, literals, async != nullptr ? &async->indices : nullptr); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 4fb304d37dd..ae586316073 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -716,7 +716,7 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; - std::vector source_tensors; + std::vector> source_tensors; xla::Shape global_shape; xla::OpSharding sharding; if (sharding_spec == nullptr) { @@ -733,15 +733,8 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( auto shard_device = ParseDeviceString(devices[j]); auto shard_shape = CreateComputationShapeFromTensor(local_shards[j], &shard_device); - auto populate_fn = - [&, j, shard_device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(local_shards[j], source_tensor.shape, - dest_buffer, dest_buffer_size, shard_device); - }; - source_tensors.emplace_back(shard_shape, devices[j], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + local_shards[j], shard_shape, devices[j])); } return runtime::GetComputationClient()->TransferShardsToServer( source_tensors, GetVirtualDevice().toString(), global_shape, sharding);