diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index bcad2fdd2afa..0d25b1488d81 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -297,6 +297,8 @@ std::vector PjRtComputationClient::TransferToServer( for (auto& tensor : tensors) { xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device()); + total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape()); + std::shared_ptr buffer = std::move(client_ ->BufferFromHostBuffer( @@ -308,8 +310,7 @@ std::vector PjRtComputationClient::TransferToServer( .value()); ComputationClient::DataPtr data = - std::make_shared(tensor->device(), buffer); - total_size += xla::ShapeUtil::ByteSizeOf(data->shape()); + std::make_shared(tensor->device(), tensor->shape(), buffer); datas.push_back(data); } OutboundDataMetric()->AddSample(total_size); diff --git a/torch_xla/csrc/runtime/tensor_source.h b/torch_xla/csrc/runtime/tensor_source.h index cc10728b0ce8..11d4b2f71a55 100644 --- a/torch_xla/csrc/runtime/tensor_source.h +++ b/torch_xla/csrc/runtime/tensor_source.h @@ -22,13 +22,25 @@ class TensorSource { virtual const void* data() const = 0; - virtual xla::PrimitiveType primitive_type() const = 0; + virtual const xla::Shape& shape() const = 0; - virtual std::vector dimensions() const = 0; + const std::string& device() const { return device_; } - virtual std::vector byte_strides() const = 0; + virtual std::vector byte_strides() const { + std::vector byte_strides(shape().dimensions_size()); + XLA_CHECK_OK( + xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); + return byte_strides; + } - const std::string& device() const { return device_; } + virtual std::vector dimensions() const { + auto dimensions = shape().dimensions(); + return {dimensions.begin(), dimensions.end()}; + } + + virtual xla::PrimitiveType primitive_type() const { + return shape().element_type(); + } private: std::string device_; @@ -36,9 +48,8 @@ class TensorSource { class AtenSource : public TensorSource { public: - AtenSource(const at::Tensor& tensor, xla::PrimitiveType target_type, - std::string device) - : TensorSource(std::move(device)), target_type_(target_type) { + AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device) + : TensorSource(std::move(device)), shape_(std::move(shape)) { at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type()); if (target_torch_type != tensor.type().scalarType()) { TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1); @@ -50,12 +61,7 @@ class AtenSource : public TensorSource { const void* data() const override { return tensor_.const_data_ptr(); } - xla::PrimitiveType primitive_type() const override { return target_type_; } - - std::vector dimensions() const override { - auto sizes = tensor_.sizes(); - return {sizes.begin(), sizes.end()}; - } + const xla::Shape& shape() const override { return shape_; } std::vector byte_strides() const override { std::vector strides; @@ -65,9 +71,14 @@ class AtenSource : public TensorSource { return strides; } + std::vector dimensions() const override { + auto sizes = tensor_.sizes(); + return {sizes.begin(), sizes.end()}; + } + private: at::Tensor tensor_; - xla::PrimitiveType target_type_; + xla::Shape shape_; }; class LiteralSource : public TensorSource { @@ -77,23 +88,7 @@ class LiteralSource : public TensorSource { const void* data() const override { return literal_.untyped_data(); } - const xla::Shape& shape() const { return literal_.shape(); } - - xla::PrimitiveType primitive_type() const override { - return shape().element_type(); - } - - std::vector dimensions() const override { - auto dimensions = shape().dimensions(); - return {dimensions.begin(), dimensions.end()}; - } - - std::vector byte_strides() const override { - std::vector byte_strides(shape().dimensions_size()); - XLA_CHECK_OK( - xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides))); - return byte_strides; - } + const xla::Shape& shape() const override { return literal_.shape(); } private: xla::Literal literal_; diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index fbda300e884f..f0869f16e9a1 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -479,8 +479,8 @@ torch::lazy::BackendDataPtr TensorToXlaData( } std::vector> source_tensors; - source_tensors.push_back(std::make_shared( - tensor, shape.element_type(), device.toString())); + source_tensors.push_back( + std::make_shared(tensor, shape, device.toString())); auto handles = runtime::GetComputationClient()->TransferToServer(source_tensors); @@ -705,10 +705,9 @@ std::vector CreateTensorsData( std::vector> source_tensors; for (size_t i = 0; i < tensors.size(); ++i) { torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); + xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); source_tensors.push_back(std::make_shared( - tensors[i], - MaybeDowncastForDevice(tensors[i].type().scalarType(), device), - devices[i])); + tensors[i], std::move(shape), devices[i])); } return WrapXlaData( runtime::GetComputationClient()->TransferToServer(source_tensors)); @@ -725,6 +724,7 @@ std::vector CreateTensorsData( std::vector handles; for (size_t i = 0; i < tensors.size(); ++i) { torch::lazy::BackendDevice device = ParseDeviceString(devices[i]); + xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device); std::vector> source_tensors; // in @@ -744,9 +744,7 @@ std::vector CreateTensorsData( local_shards, local_devices, shardings[i])); } else { source_tensors.push_back(std::make_shared( - tensors[i], - MaybeDowncastForDevice(tensors[i].type().scalarType(), device), - devices[i])); + tensors[i], std::move(shape), devices[i])); new_handles = runtime::GetComputationClient()->TransferToServer(source_tensors); } diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 8d7ef17b3648..d0912997f0b3 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -727,11 +727,10 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( } for (int64_t j = 0; j < devices.size(); ++j) { auto shard_device = ParseDeviceString(devices[j]); + auto shard_shape = + CreateComputationShapeFromTensor(local_shards[j], &shard_device); source_tensors.push_back(std::make_shared( - local_shards[j], - MaybeDowncastForDevice(local_shards[j].type().scalarType(), - shard_device), - devices[j])); + local_shards[j], shard_shape, devices[j])); } return runtime::GetComputationClient()->TransferShardsToServer( source_tensors, GetVirtualDevice().toString(), global_shape, sharding);