From cdca4a850bb12b473686390deef840e033b5af1c Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Wed, 8 Nov 2023 00:10:09 +0000
Subject: [PATCH] better names

---
 test/cpp/cpp_test_util.cpp              |  2 +-
 torch_xla/csrc/convert_ops.cpp          |  2 +-
 torch_xla/csrc/dtype.cpp                | 17 ++++++++---------
 torch_xla/csrc/dtype.h                  | 11 ++++++-----
 torch_xla/csrc/helpers.cpp              |  2 +-
 torch_xla/csrc/init_python_bindings.cpp |  4 ++--
 torch_xla/csrc/tensor.cpp               |  2 +-
 torch_xla/csrc/tensor_methods.cpp       | 10 +++++-----
 torch_xla/csrc/tensor_util.cpp          |  8 ++++----
 torch_xla/csrc/xla_graph_executor.cpp   |  2 +-
 torch_xla/csrc/xla_sharding_util.cpp    |  2 +-
 11 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp
index 707b273a588f..8efc9c90b348 100644
--- a/test/cpp/cpp_test_util.cpp
+++ b/test/cpp/cpp_test_util.cpp
@@ -307,7 +307,7 @@ std::vector<at::Tensor> Fetch(
   std::vector<at::Tensor> tensors;
   for (auto& literal : literals) {
     tensors.push_back(MakeTensorFromXlaLiteral(
-        literal, GetHostScalarType(literal.shape().element_type())));
+        literal, MaybeUpcastForHost(literal.shape().element_type())));
   }
   return tensors;
 }
diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp
index 366e3b7a0640..3dd53faefaf6 100644
--- a/torch_xla/csrc/convert_ops.cpp
+++ b/torch_xla/csrc/convert_ops.cpp
@@ -104,7 +104,7 @@ xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) {
   if (from == xla::PrimitiveType::PRED) {
     torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
     op = ConvertTo(op, from,
-                   GetDevicePrimitiveType(xla::PrimitiveType::U8, xla_device),
+                   MaybeDowncastForDevice(xla::PrimitiveType::U8, xla_device),
                    &xla_device);
   }
   return op;
diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp
index d047beb12c52..fcb5a7541095 100644
--- a/torch_xla/csrc/dtype.cpp
+++ b/torch_xla/csrc/dtype.cpp
@@ -154,7 +154,7 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
   }
 }
 
-xla::PrimitiveType GetDevicePrimitiveType(
+xla::PrimitiveType MaybeDowncastForDevice(
     xla::PrimitiveType type, const torch::lazy::BackendDevice& device) {
   XlaDeviceType hw_type = static_cast<XlaDeviceType>(device.type());
   switch (type) {
@@ -197,7 +197,13 @@ xla::PrimitiveType GetDevicePrimitiveType(
   }
 }
 
-at::ScalarType GetHostScalarType(xla::PrimitiveType xla_type) {
+xla::PrimitiveType MaybeDowncastForDevice(
+    at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) {
+  xla::PrimitiveType xla_type = XlaTypeFromTorchType(scalar_type);
+  return MaybeDowncastForDevice(xla_type, device);
+}
+
+at::ScalarType MaybeUpcastForHost(xla::PrimitiveType xla_type) {
   at::ScalarType scalar_type = TorchTypeFromXlaType(xla_type);
   switch (scalar_type) {
     case at::ScalarType::BFloat16:
@@ -214,11 +220,4 @@ at::ScalarType GetHostScalarType(xla::PrimitiveType xla_type) {
   }
 }
 
-xla::PrimitiveType GetXlaTypeFromTensorType(
-    at::ScalarType scalar_type, const torch::lazy::BackendDevice& device) {
-  // XlaDeviceType hw_type = static_cast<XlaDeviceType>(device.type());
-  xla::PrimitiveType xla_type = XlaTypeFromTorchType(scalar_type);
-  return GetDevicePrimitiveType(xla_type, device);
-}
-
 }
diff --git a/torch_xla/csrc/dtype.h b/torch_xla/csrc/dtype.h
index 10eb242fc16e..084d7803e231 100644
--- a/torch_xla/csrc/dtype.h
+++ b/torch_xla/csrc/dtype.h
@@ -10,15 +10,16 @@ at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type);
 
 xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type);
 
-// TODO better names
-xla::PrimitiveType GetDevicePrimitiveType(
+// Downcast type to be compatible with device if necessary.
+xla::PrimitiveType MaybeDowncastForDevice(
     xla::PrimitiveType type, const torch::lazy::BackendDevice& device);
 
-at::ScalarType GetHostScalarType(xla::PrimitiveType xla_type);
-
-xla::PrimitiveType GetXlaTypeFromTensorType(
+xla::PrimitiveType MaybeDowncastForDevice(
     at::ScalarType scalar_type, const torch::lazy::BackendDevice& device);
 
+// Upcast type to original PyTorch type.
+at::ScalarType MaybeUpcastForHost(xla::PrimitiveType xla_type);
+
 }
 
 #endif  // XLA_TORCH_XLA_CSRC_DTYPE_H_
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index ead820d03b40..6c3a4a64d892 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -681,7 +681,7 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
 }
 
 torch::lazy::Shape XlaHelpers::ConvertXlaShapeToLazy(const xla::Shape& shape) {
-  at::ScalarType scalar_type = GetHostScalarType(shape.element_type());
+  at::ScalarType scalar_type = MaybeUpcastForHost(shape.element_type());
   c10::optional<std::vector<bool>> is_symbolic = c10::nullopt;
   if (shape.is_dynamic()) {
     std::vector<bool> xla_dynamic_dimensions =
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 4cd5e5b201c8..6e81f39080d7 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -801,7 +801,7 @@ class PyLoweringContext {
       xla::Literal& literal = literals[i];
       xla::XlaOp op = lowering_ctx.GetParameter(device_data[i]);
       at::ScalarType dtype =
-          GetHostScalarType(literal.shape().element_type());
+          MaybeUpcastForHost(literal.shape().element_type());
       at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
       results[param_ids[i]] = input;
     }
@@ -1760,7 +1760,7 @@ void InitXlaModuleBindings(py::module m) {
           shards.push_back(
               XlaDataToTensors(
                   {shard_handle},
-                  GetHostScalarType(shard_handle->shape().element_type()))
+                  MaybeUpcastForHost(shard_handle->shape().element_type()))
                   .front());
           str_devices.push_back(shard_handle->device());
         }
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 80e24b6aaaca..11d3c8d17bb5 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -159,7 +159,7 @@ int64_t XLATensor::size(int64_t dim) const {
 at::ScalarType XLATensor::dtype() const {
   return data()->logical_element_type
              ? *data()->logical_element_type
-             : GetHostScalarType(shape().get().element_type());
+             : MaybeUpcastForHost(shape().get().element_type());
 }
 
 c10::optional<at::ScalarType> XLATensor::dtype_optional() const {
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index bd9b7ead6cfc..100484bd9bff 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -1027,9 +1027,9 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
   bool input_is_float = xla::primitive_util::IsFloatingPointType(input_type);
   bool other_is_float = xla::primitive_util::IsFloatingPointType(other_type);
   if (input_is_float && !other_is_float) {
-    scalar_type = GetHostScalarType(input_type);
+    scalar_type = MaybeUpcastForHost(input_type);
   } else if (!input_is_float && other_is_float) {
-    scalar_type = GetHostScalarType(other_type);
+    scalar_type = MaybeUpcastForHost(other_type);
   }
   // We need to cast both input and other to float to perform true divide, floor
   // divide and trunc divide.
@@ -1074,7 +1074,7 @@ XLATensorPtr div(const XLATensorPtr& input, const at::Scalar& other) {
   xla::PrimitiveType input_type = input->shape().get().element_type();
   bool input_is_float = xla::primitive_util::IsFloatingPointType(input_type);
   if (input_is_float) {
-    scalar_type = GetHostScalarType(input_type);
+    scalar_type = MaybeUpcastForHost(input_type);
   }
   torch::lazy::Value input_value = GetFloatingIrValue(input, scalar_type);
   torch::lazy::Value other_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
@@ -1182,7 +1182,7 @@ XLATensorPtr eye(int64_t lines, int64_t cols,
 void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols) {
   out->SetIrValue(
       Identity(lines, cols >= 0 ? cols : lines,
-               GetDevicePrimitiveType(out->shape().get().element_type(),
+               MaybeDowncastForDevice(out->shape().get().element_type(),
                                       out->GetDevice())));
 }
 
@@ -2056,7 +2056,7 @@ XLATensorPtr pow(const at::Scalar& input, const XLATensorPtr& exponent) {
   torch::lazy::NodePtr pow_node = Pow(input_node, exponent->GetIrValue());
   at::ScalarType input_dtype = GetScalarType(input);
   at::ScalarType exp_dtype = exponent->dtype();
-  at::ScalarType promoted_dtype = GetHostScalarType(XlaHelpers::PromoteType(
+  at::ScalarType promoted_dtype = MaybeUpcastForHost(XlaHelpers::PromoteType(
       XlaTypeFromTorchType(input_dtype), XlaTypeFromTorchType(exp_dtype)));
   return exponent->CreateFrom(pow_node, promoted_dtype);
 }
diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp
index 5e98b2b5b6b5..f0869f16e9a1 100644
--- a/torch_xla/csrc/tensor_util.cpp
+++ b/torch_xla/csrc/tensor_util.cpp
@@ -386,7 +386,7 @@ void TensorToBuffer(const at::Tensor& tensor, const xla::Shape& dest_shape,
   at::Tensor contiguous_tensor = tensor.contiguous();
   xla::Shape src_shape = MakeTorchTensorLayout(
       XlaHelpers::I64List(contiguous_tensor.sizes()), /*dynamic_dimensions=*/{},
-      GetXlaTypeFromTensorType(contiguous_tensor.type().scalarType(), device));
+      MaybeDowncastForDevice(contiguous_tensor.type().scalarType(), device));
   CopyTensors(contiguous_tensor.data_ptr(), src_shape, dest_buffer,
               dest_buffer_size, dest_shape);
 }
@@ -761,7 +761,7 @@ xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
     auto dimensions = XlaHelpers::I64List(tensor.sizes());
     computed_shape = MakeTorchTensorLayout(
         dimensions, /*dynamic_dimensions=*/{},
-        GetXlaTypeFromTensorType(tensor.type().scalarType(), xla_device));
+        MaybeDowncastForDevice(tensor.type().scalarType(), xla_device));
     shape = &computed_shape;
   }
   xla::Literal literal(*shape);
@@ -859,13 +859,13 @@ xla::Shape CreateComputationShapeFromTensor(
 xla::PrimitiveType GetXlaPrimitiveTypeForCurrentDevice(
     xla::PrimitiveType xla_type) {
   torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
-  return GetDevicePrimitiveType(xla_type, xla_device);
+  return MaybeDowncastForDevice(xla_type, xla_device);
 }
 
 xla::PrimitiveType MakeXlaPrimitiveType(
     at::ScalarType scalar_type, const torch::lazy::BackendDevice* device) {
   torch::lazy::BackendDevice xla_device = bridge::GetDeviceOrCurrent(device);
-  return GetXlaTypeFromTensorType(scalar_type, xla_device);
+  return MaybeDowncastForDevice(scalar_type, xla_device);
 }
 
 xla::Shape MakeXlaShapeFromLazyShape(torch::lazy::Shape shape,
diff --git a/torch_xla/csrc/xla_graph_executor.cpp b/torch_xla/csrc/xla_graph_executor.cpp
index 776eb03e5d7f..7a91c446e595 100644
--- a/torch_xla/csrc/xla_graph_executor.cpp
+++ b/torch_xla/csrc/xla_graph_executor.cpp
@@ -217,7 +217,7 @@ torch::lazy::Value XLAGraphExecutor::GetDeviceDataIrValue(
     const at::Scalar& value, xla::PrimitiveType type,
     const torch::lazy::BackendDevice& device) {
   torch::lazy::BackendDataPtr data =
-      GetDeviceData(value, GetHostScalarType(type), device);
+      GetDeviceData(value, MaybeUpcastForHost(type), device);
   data->SetInfo(
       std::make_shared<torch::lazy::LazyGraphExecutor::DeviceDataInfo>(
           /*tensor_id=*/-1, /*read_only=*/true));
diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp
index ffa6d402a0ea..53b51ceb8851 100644
--- a/torch_xla/csrc/xla_sharding_util.cpp
+++ b/torch_xla/csrc/xla_sharding_util.cpp
@@ -358,7 +358,7 @@ std::vector<runtime::ComputationClient::DataPtr> ShardingUtil::OutputHandler(
         // Reshards replicated output if `sharding` is present.
         std::vector<at::Tensor> tensors = XlaDataToTensors(
             {sharded_results[0][i]},
-            GetHostScalarType(sharding->shape.element_type()));
+            MaybeUpcastForHost(sharding->shape.element_type()));
         outputs.push_back(
             std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
                 CreateTensorsData(
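
Reviewer note (not part of the patch): a minimal sketch of the round trip the
renamed helpers express, under the patch's own declarations in dtype.h. The
RoundTripExample() wrapper is hypothetical, and the aten_xla_bridge.h include
is assumed for bridge::GetCurrentDevice(), which this diff already uses in
convert_ops.cpp and tensor_util.cpp.

    #include "torch_xla/csrc/aten_xla_bridge.h"
    #include "torch_xla/csrc/dtype.h"

    namespace torch_xla {

    // Hypothetical illustration: host dtype -> device type -> host dtype.
    at::ScalarType RoundTripExample() {
      torch::lazy::BackendDevice device = bridge::GetCurrentDevice();

      // Host -> device: convert the PyTorch dtype to an XLA type, narrowing
      // it if the backend cannot represent it natively (per the new header
      // comment, "Downcast type to be compatible with device if necessary").
      xla::PrimitiveType device_type =
          MaybeDowncastForDevice(at::ScalarType::Double, device);

      // Device -> host: map the XLA element type back to the PyTorch dtype
      // the tensor should surface as ("Upcast type to original PyTorch
      // type"), widening again when the device type was only a stand-in.
      return MaybeUpcastForHost(device_type);
    }

    }  // namespace torch_xla

The Maybe* prefix makes the contract clearer than the old Get* names: both
calls are identity whenever the device supports the requested type natively.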