diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 0d25b1488d8..bcad2fdd2af 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -297,8 +297,6 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
   for (auto& tensor : tensors) {
     xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());
 
-    total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());
-
     std::shared_ptr<xla::PjRtBuffer> buffer =
         std::move(client_
                       ->BufferFromHostBuffer(
@@ -310,7 +308,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
                       .value());
 
     ComputationClient::DataPtr data =
-        std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
+        std::make_shared<PjRtData>(tensor->device(), buffer);
+    total_size += xla::ShapeUtil::ByteSizeOf(data->shape());
     datas.push_back(data);
   }
   OutboundDataMetric()->AddSample(total_size);
diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h
index 11d4b2f71a5..c9c1df8c291 100644
--- a/torch_xla/csrc/runtime/tensor_source.h
+++ b/torch_xla/csrc/runtime/tensor_source.h
@@ -22,25 +22,13 @@ class TensorSource {
 
   virtual const void* data() const = 0;
 
-  virtual const xla::Shape& shape() const = 0;
+  virtual xla::PrimitiveType primitive_type() const = 0;
 
-  const std::string& device() const { return device_; }
-
-  virtual std::vector<int64_t> byte_strides() const {
-    std::vector<int64_t> byte_strides(shape().dimensions_size());
-    XLA_CHECK_OK(
-        xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
-    return byte_strides;
-  }
+  virtual std::vector<int64_t> dimensions() const = 0;
 
-  virtual std::vector<int64_t> dimensions() const {
-    auto dimensions = shape().dimensions();
-    return {dimensions.begin(), dimensions.end()};
-  }
+  virtual std::vector<int64_t> byte_strides() const = 0;
 
-  virtual xla::PrimitiveType primitive_type() const {
-    return shape().element_type();
-  }
+  const std::string& device() const { return device_; }
 
  private:
   std::string device_;
@@ -48,8 +36,8 @@ class TensorSource {
 
 class AtenSource : public TensorSource {
  public:
-  AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
-      : TensorSource(std::move(device)), shape_(std::move(shape)) {
+  AtenSource(const at::Tensor& tensor, xla::PrimitiveType target_type, std::string device)
+      : TensorSource(std::move(device)), target_type_(target_type) {
     at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type());
     if (target_torch_type != tensor.type().scalarType()) {
       TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
@@ -61,7 +49,12 @@ class AtenSource : public TensorSource {
 
   const void* data() const override { return tensor_.const_data_ptr(); }
 
-  const xla::Shape& shape() const override { return shape_; }
+  xla::PrimitiveType primitive_type() const override { return target_type_; }
+
+  std::vector<int64_t> dimensions() const override {
+    auto sizes = tensor_.sizes();
+    return {sizes.begin(), sizes.end()};
+  }
 
   std::vector<int64_t> byte_strides() const override {
     std::vector<int64_t> strides;
@@ -71,14 +64,9 @@ class AtenSource : public TensorSource {
     return strides;
   }
 
-  std::vector<int64_t> dimensions() const override {
-    auto sizes = tensor_.sizes();
-    return {sizes.begin(), sizes.end()};
-  }
-
 private:
   at::Tensor tensor_;
-  xla::Shape shape_;
+  xla::PrimitiveType target_type_;
 };
 
 class LiteralSource : public TensorSource {
@@ -88,7 +76,23 @@ class LiteralSource : public TensorSource {
 
   const void* data() const override { return literal_.untyped_data(); }
 
-  const xla::Shape& shape() const override { return literal_.shape(); }
+  const xla::Shape& shape() const { return literal_.shape(); }
+
+  xla::PrimitiveType primitive_type() const override {
+    return shape().element_type();
+  }
+
+  std::vector<int64_t> dimensions() const override {
+    auto dimensions = shape().dimensions();
+    return {dimensions.begin(), dimensions.end()};
+  }
+
+  std::vector<int64_t> byte_strides() const override {
+    std::vector<int64_t> byte_strides(shape().dimensions_size());
+    XLA_CHECK_OK(
+        xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
+    return byte_strides;
+  }
 
  private:
  xla::Literal literal_;
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index f0869f16e9a..60859e35b0e 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -480,7 +480,7 @@ torch::lazy::BackendDataPtr TensorToXlaData(
 
   std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
   source_tensors.push_back(
-      std::make_shared<runtime::AtenSource>(tensor, shape, device.toString()));
+      std::make_shared<runtime::AtenSource>(tensor, shape.element_type(), device.toString()));
 
   auto handles =
       runtime::GetComputationClient()->TransferToServer(source_tensors);
@@ -705,9 +705,8 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
   std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
   for (size_t i = 0; i < tensors.size(); ++i) {
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
-    xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);
     source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-        tensors[i], std::move(shape), devices[i]));
+        tensors[i], MaybeDowncastForDevice(tensors[i].type().scalarType(), device), devices[i]));
   }
   return WrapXlaData(
       runtime::GetComputationClient()->TransferToServer(source_tensors));
@@ -724,7 +723,6 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
   std::vector<runtime::ComputationClient::DataPtr> handles;
   for (size_t i = 0; i < tensors.size(); ++i) {
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
-    xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);
 
     std::vector<std::shared_ptr<const runtime::TensorSource>>
         source_tensors;  // in
@@ -744,7 +742,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
           local_shards, local_devices, shardings[i]));
     } else {
       source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-          tensors[i], std::move(shape), devices[i]));
+          tensors[i], MaybeDowncastForDevice(tensors[i].type().scalarType(), device), devices[i]));
       new_handles =
           runtime::GetComputationClient()->TransferToServer(source_tensors);
     }
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index 379f38bb890..4aef1076605 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -7,6 +7,7 @@
 
 #include "torch/csrc/lazy/core/ir_util.h"
 #include "torch_xla/csrc/device.h"
+#include "torch_xla/csrc/dtype.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
@@ -726,10 +727,8 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
   }
   for (int64_t j = 0; j < devices.size(); ++j) {
     auto shard_device = ParseDeviceString(devices[j]);
-    auto shard_shape =
-        CreateComputationShapeFromTensor(local_shards[j], &shard_device);
     source_tensors.push_back(std::make_shared<runtime::AtenSource>(
-        local_shards[j], shard_shape, devices[j]));
+        local_shards[j], MaybeDowncastForDevice(local_shards[j].type().scalarType(), shard_device), devices[j]));
   }
   return runtime::GetComputationClient()->TransferShardsToServer(
       source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
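
Note: below is a minimal usage sketch of the refactored TensorSource interface, assuming only the headers touched in this patch. The TransferExample function, the sample tensor, and the device string are illustrative and not part of the change; the point is that callers now pass a target xla::PrimitiveType instead of a full xla::Shape, and dimensions()/byte_strides() are reported from the at::Tensor itself.

// Illustrative sketch (not part of the diff): constructing an AtenSource
// with the new (tensor, target_type, device) constructor.
#include <memory>
#include <string>
#include <vector>

#include <ATen/ATen.h>

#include "torch_xla/csrc/runtime/tensor_source.h"

void TransferExample(const std::string& device) {
  at::Tensor tensor = at::ones({2, 3}, at::kFloat);

  // Previously the caller built an xla::Shape up front; now only the target
  // element type is supplied.
  auto source = std::make_shared<torch_xla::runtime::AtenSource>(
      tensor, xla::PrimitiveType::F32, device);

  // Layout information now comes straight from the at::Tensor.
  std::vector<int64_t> dims = source->dimensions();       // {2, 3}
  std::vector<int64_t> strides = source->byte_strides();  // {12, 4} for a contiguous float tensor
}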