diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index da12492d2cd..c17e1a886ec 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -1,3 +1,8 @@ +load( + "//bazel:rules_def.bzl", + "ptxla_cc_test", +) + load( "@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -61,6 +66,7 @@ cc_library( ":metrics_reader", ":metrics", ":sys_util", + ":tensor_source", ":types", ":util", ":xla_coordinator", @@ -92,6 +98,7 @@ cc_library( ":env_vars", ":multi_wait", ":stablehlo_helper", + ":tensor_source", ":tf_logging", ":thread_pool", ":xla_coordinator", @@ -292,6 +299,17 @@ cc_library( ], ) +cc_library( + name = "tensor_source", + hdrs = ["tensor_source.h"], + deps = [ + ":debug_macros", + "@xla//xla:literal", + "@xla//xla:shape_util", + "@torch//:headers", + ] +) + cc_library( name = "types", hdrs = ["types.h"], @@ -376,25 +394,27 @@ cc_test( ], ) -# TODO(goranpetrovic): reenable when `xla_cc_test` is fixed upstream. -# xla_cc_test( -# name = "pjrt_computation_client_test", -# srcs = ["pjrt_computation_client_test.cc"], -# deps = [ -# ":computation_client", -# "@xla//xla:literal", -# "@xla//xla:literal_util", -# "@xla//xla:shape_util", -# "@xla//xla:status", -# "@xla//xla:statusor", -# "@xla//xla/client:xla_builder", -# "@xla//xla/client:xla_computation", -# "@xla//xla/tests:literal_test_util", -# "@xla//xla/tools:hlo_module_loader", -# "@org_tensorflow//tensorflow/core/platform:logging", -# "@tsl//tsl/lib/core:status_test_util", -# "@tsl//tsl/platform:env", -# "@tsl//tsl/platform:test", -# "@tsl//tsl/platform:test_main", -# ], -# ) +ptxla_cc_test( + name = "pjrt_computation_client_test", + srcs = ["pjrt_computation_client_test.cc"], + deps = [ + ":computation_client", + ":pjrt_computation_client", + ":tensor_source", + "@xla//xla:literal", + "@xla//xla:literal_util", + "@xla//xla:shape_util", + "@xla//xla:status", + "@xla//xla:statusor", + "@xla//xla/client:xla_builder", + 
"@xla//xla/client:xla_computation", + "@xla//xla/tests:literal_test_util", + "@xla//xla/tools:hlo_module_loader", + "@tsl//tsl/lib/core:status_test_util", + "@tsl//tsl/platform:env", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 145a6d0aa09..eee5a18e3f9 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -1,6 +1,7 @@ #ifndef XLA_CLIENT_COMPUTATION_CLIENT_H_ #define XLA_CLIENT_COMPUTATION_CLIENT_H_ +#include #include #include #include @@ -20,6 +21,7 @@ #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/metrics.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/types.h" #include "torch_xla/csrc/runtime/util.h" #include "xla/client/xla_computation.h" @@ -192,25 +194,6 @@ class ComputationClient { using ComputationPtr = std::shared_ptr; - // The TensorSource provides a way for a client to populate a buffer allocated - // by the core computation client code. - struct TensorSource { - // The PopulateFn accepts a dense buffer is standard array layout - // (dim0-major) and deposits the source tensor data directly over the - // provided buffer. - using PopulateFn = std::function; - - TensorSource() = default; - TensorSource(xla::Shape shape, std::string device, PopulateFn populate_fn) - : shape(std::move(shape)), - device(std::move(device)), - populate_fn(std::move(populate_fn)) {} - - xla::Shape shape; - std::string device; - PopulateFn populate_fn; - }; - // TODO(wcromar): Should CompileInstance still exist? Should it be a subclass // of torch::lazy::Computation? struct CompileInstance { @@ -275,13 +258,13 @@ class ComputationClient { // Transfers local tensor values to the TPU devices and fetches the handles. 
virtual std::vector TransferToServer( - absl::Span tensors) = 0; + absl::Span> tensors) = 0; // Transfers local sharded tensor values to the TPU devices and returns a // `PjRtShardedData`. virtual DataPtr TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) = 0; + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) = 0; // Copies `data->buffer` to `dst` device buffer. virtual DataPtr CopyToDevice(DataPtr data, std::string dst) = 0; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index c003f4f9706..81dc1271129 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -12,6 +12,7 @@ #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/thread_pool.h" #include "torch_xla/csrc/runtime/xla_coordinator.h" @@ -286,7 +287,7 @@ std::optional PjRtComputationClient::GetDataSharding( } std::vector PjRtComputationClient::TransferToServer( - absl::Span tensors) { + absl::Span> tensors) { metrics::TimedSection timed(TransferToServerMetric()); tsl::profiler::TraceMe activity("PjRtComputationClient::TransferToServer", tsl::profiler::TraceMeLevel::kInfo); @@ -294,31 +295,22 @@ std::vector PjRtComputationClient::TransferToServer( datas.reserve(tensors.size()); int64_t total_size = 0; for (auto& tensor : tensors) { - xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor.device); - - auto literal = std::make_shared(tensor.shape); - tensor.populate_fn(tensor, literal->untyped_data(), literal->size_bytes()); - std::vector byte_strides(literal->shape().dimensions_size()); - XLA_CHECK_OK(xla::ShapeUtil::ByteStrides(literal->shape(), - 
absl::MakeSpan(byte_strides))); - total_size += literal->size_bytes(); - - // Avoid use-after-free on `literal` due to unsequenced move and use. - xla::Literal* literal_pointer = literal.get(); - std::shared_ptr buffer = std::move( - client_ - ->BufferFromHostBuffer( - literal_pointer->untyped_data(), - literal_pointer->shape().element_type(), - literal_pointer->shape().dimensions(), byte_strides, - xla::PjRtClient::HostBufferSemantics:: - kImmutableUntilTransferCompletes, - [literal{std::move(literal)}]() { /* frees literal */ }, - pjrt_device) - .value()); + xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); + + total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); + + std::shared_ptr buffer = + std::move(client_ + ->BufferFromHostBuffer( + tensor->data(), tensor->shape().element_type(), + tensor->shape().dimensions(), tensor->byte_strides(), + xla::PjRtClient::HostBufferSemantics:: + kImmutableUntilTransferCompletes, + [tensor]() { /* frees tensor */ }, pjrt_device) + .value()); ComputationClient::DataPtr data = - std::make_shared(tensor.device, tensor.shape, buffer); + std::make_shared(tensor->device(), tensor->shape(), buffer); datas.push_back(data); } OutboundDataMetric()->AddSample(total_size); @@ -328,8 +320,8 @@ std::vector PjRtComputationClient::TransferToServer( } ComputationClient::DataPtr PjRtComputationClient::TransferShardsToServer( - absl::Span tensor_shards, std::string device, - xla::Shape shape, xla::OpSharding sharding) { + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) { tsl::profiler::TraceMe activity( "PjRtComputationClient::TransferShardsToServer", tsl::profiler::TraceMeLevel::kInfo); diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index faebd4892b8..b66e4ff5097 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -37,7 +37,7 @@ class 
PjRtComputationClient : public ComputationClient { std::optional GetDataSharding(DataPtr handle) override; std::vector TransferToServer( - absl::Span tensors) override; + absl::Span> tensors) override; // Use XLA replication to re-assemble the sharded data. DataPtr ReplicateShardedData(const DataPtr& handle); @@ -45,9 +45,9 @@ class PjRtComputationClient : public ComputationClient { std::vector TransferFromServer( absl::Span handles) override; - DataPtr TransferShardsToServer(absl::Span tensor_shards, - std::string device, xla::Shape shape, - xla::OpSharding sharding) override; + DataPtr TransferShardsToServer( + absl::Span> tensor_shards, + std::string device, xla::Shape shape, xla::OpSharding sharding) override; DataPtr CopyToDevice(DataPtr data, std::string dst) override; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc index 24cbc4636a6..d6240f08e98 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client_test.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client_test.cc @@ -7,6 +7,8 @@ #include #include "torch_xla/csrc/runtime/computation_client.h" +#include "torch_xla/csrc/runtime/pjrt_computation_client.h" +#include "torch_xla/csrc/runtime/tensor_source.h" #include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" @@ -32,17 +34,6 @@ tsl::StatusOr MakeComputation() { return builder.Build(); } -ComputationClient::TensorSource TensorSourceFromLiteral( - const std::string& device, const xla::Literal& literal) { - auto populate_fn = [&](const ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - std::memcpy(dest_buffer, literal.data().data(), - dest_buffer_size * sizeof(literal.data().data())); - }; - return ComputationClient::TensorSource(literal.shape(), device, - std::move(populate_fn)); -} - TEST(PjRtComputationClientTest, Init) { // Get a CPU client. 
tsl::setenv("PJRT_DEVICE", "CPU", true); @@ -69,9 +60,9 @@ TEST(PjRtComputationClientTest, Init) { // Copy inputs to device. ComputationClient::ExecuteComputationOptions options{}; - std::vector args = { - TensorSourceFromLiteral(device, literal_x), - TensorSourceFromLiteral(device, literal_y)}; + std::vector> args = { + std::make_shared(std::move(literal_x), device), + std::make_shared(std::move(literal_y), device)}; // Execute the graph. std::vector results = client->ExecuteComputation( diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h new file mode 100644 index 00000000000..4f24084e019 --- /dev/null +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -0,0 +1,70 @@ +#ifndef XLA_CLIENT_TENSOR_SOURCE_H_ +#define XLA_CLIENT_TENSOR_SOURCE_H_ + +#include + +#include + +#include "torch_xla/csrc/runtime/debug_macros.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace torch_xla { +namespace runtime { + +// Owns a contiguous block of data with the shape and layout matching `shape()`. 
+class TensorSource { + public: + TensorSource(std::string device) : device_(std::move(device)){}; + + virtual const void* data() const = 0; + + virtual const xla::Shape& shape() const = 0; + + const std::string& device() const { return device_; } + + std::vector byte_strides() const { + std::vector byte_strides(shape().dimensions_size()); + XLA_CHECK_OK( + xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); + return byte_strides; + } + + private: + std::string device_; +}; + +class AtenSource : public TensorSource { + public: + AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device) + : TensorSource(std::move(device)), + tensor_(std::move(tensor.contiguous())), + shape_(std::move(shape)) {} + + const void* data() const override { return tensor_.const_data_ptr(); } + + const xla::Shape& shape() const override { return shape_; } + + private: + at::Tensor tensor_; + xla::Shape shape_; +}; + +class LiteralSource : public TensorSource { + public: + LiteralSource(xla::Literal literal, std::string device) + : TensorSource(std::move(device)), literal_(std::move(literal)) {} + + const void* data() const override { return literal_.untyped_data(); } + + const xla::Shape& shape() const override { return literal_.shape(); } + + private: + xla::Literal literal_; +}; + +} // namespace runtime +} // namespace torch_xla + +#endif // XLA_CLIENT_TENSOR_SOURCE_H_ diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index a419bd98b7e..f3749a1ecd2 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -34,7 +34,7 @@ namespace torch_xla { namespace { struct DataAsync { - std::vector source_tensors; + std::vector> source_tensors; std::vector async_datas; std::vector handle_unlockers; }; @@ -587,15 +587,9 @@ torch::lazy::BackendDataPtr TensorToXlaData( sharding_spec); } - auto populate_fn = - [&](const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t
dest_buffer_size) { - PopulateTensorBuffer(tensor, source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - - std::vector source_tensors; - source_tensors.emplace_back(shape, device.toString(), std::move(populate_fn)); + std::vector> source_tensors; + source_tensors.push_back( + std::make_shared(tensor, shape, device.toString())); auto handles = runtime::GetComputationClient()->TransferToServer(source_tensors); @@ -817,19 +811,12 @@ std::vector CreateTensorsData( return WrapXlaData(handles); } - std::vector source_tensors; + std::vector> source_tensors; for (size_t i = 0; i < tensors.size(); ++i) { torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + tensors[i], std::move(shape), devices[i])); } return WrapXlaData( runtime::GetComputationClient()->TransferToServer(source_tensors)); @@ -848,7 +835,8 @@ std::vector CreateTensorsData( torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); - std::vector source_tensors; // in + std::vector> + source_tensors; // in std::vector new_handles; // out if (static_cast(device.type()) == XlaDeviceType::SPMD) { // GetLocalDevices returns the list of local devices specified by their @@ -864,15 +852,8 @@ std::vector CreateTensorsData( new_handles.push_back(ShardingUtil::CreateShardedData( local_shards, local_devices, shardings[i])); } else { - auto populate_fn = - [&, i, device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, 
size_t dest_buffer_size) { - PopulateTensorBuffer(tensors[i], source_tensor.shape, dest_buffer, - dest_buffer_size, device); - }; - source_tensors.emplace_back(std::move(shape), devices[i], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + tensors[i], std::move(shape), devices[i])); new_handles = runtime::GetComputationClient()->TransferToServer(source_tensors); } diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index cde74256eee..e7fc8f7fdea 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -711,7 +711,7 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( const XLATensor::ShardingSpecPtr& sharding_spec) { XLA_CHECK(local_shards.size() == devices.size()) << "A device must be speficied for each shard"; - std::vector source_tensors; + std::vector> source_tensors; xla::Shape global_shape; xla::OpSharding sharding; if (sharding_spec == nullptr) { @@ -728,15 +728,8 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( auto shard_device = ParseDeviceString(devices[j]); auto shard_shape = CreateComputationShapeFromTensor(local_shards[j], &shard_device); - auto populate_fn = - [&, j, shard_device]( - const runtime::ComputationClient::TensorSource& source_tensor, - void* dest_buffer, size_t dest_buffer_size) { - PopulateTensorBuffer(local_shards[j], source_tensor.shape, - dest_buffer, dest_buffer_size, shard_device); - }; - source_tensors.emplace_back(shard_shape, devices[j], - std::move(populate_fn)); + source_tensors.push_back(std::make_shared( + local_shards[j], shard_shape, devices[j])); } return runtime::GetComputationClient()->TransferShardsToServer( source_tensors, GetVirtualDevice().toString(), global_shape, sharding);