From c17f1f685353857fa540e31a443a6d8dff3fb335 Mon Sep 17 00:00:00 2001 From: Pei Zhang Date: Tue, 7 Nov 2023 15:09:07 -0800 Subject: [PATCH] fix squeeze op lowering issue when dim is not in sorted order (#5751) * fix squeeze op lowering issue when dim is not in sorted order * remove debug info * remove debug info * refactor BuildSqueezedDimensions --- test/cpp/test_aten_xla_tensor_4.cpp | 12 ++++++++++++ torch_xla/csrc/data_ops.cpp | 27 ++++++++++++++++++++++++--- torch_xla/csrc/data_ops.h | 3 +++ torch_xla/csrc/tensor_methods.cpp | 13 +++++++------ 4 files changed, 46 insertions(+), 9 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index 1e61a6fa05a..0a0d84d463a 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -956,6 +956,18 @@ TEST_F(AtenXlaTensorTest, TestSqueezeMultipleDims) { }); } +TEST_F(AtenXlaTensorTest, TestSqueezeDimWithNegativeOne) { + torch::Tensor input = + torch::rand({2, 1, 3, 1}, torch::TensorOptions(torch::kFloat)); + std::vector dims = {-1}; + torch::Tensor output = torch::squeeze(input, dims); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input, device); + torch::Tensor xla_output = torch::squeeze(xla_input, dims); + AllClose(output, xla_output); + }); +} + TEST_F(AtenXlaTensorTest, TestSqueezeOneInPlace) { int rank = 4; for (int dim = -rank; dim < rank; ++dim) { diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index e5425a93001..aad682dec70 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -165,12 +165,33 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, std::vector BuildSqueezedDimensions( absl::Span dimensions, int64_t squeeze_dim) { + std::vector squeeze_dims({squeeze_dim}); + return BuildSqueezedDimensions(dimensions, squeeze_dims); +} + +std::vector BuildSqueezedDimensions( + absl::Span dimensions, std::vector& squeeze_dims) { + std::sort(squeeze_dims.begin(), squeeze_dims.end()); std::vector output_dimensions; - for (int64_t i = 0; i < dimensions.size(); ++i) { - int64_t dim = dimensions[i]; - if (dim != 1 || (i != squeeze_dim && squeeze_dim >= 0)) { + size_t i = 0; + for (size_t j = 0; j < dimensions.size(); j++) { + auto dim = dimensions[j]; + if (squeeze_dims.size() == 1 && squeeze_dims[0] == -1) { + // Special case where squeeze_dims = {-1}. + if (dim != 1) { + output_dimensions.push_back(dim); + } + continue; + } + if (i == squeeze_dims.size() || j < squeeze_dims[i]) { + output_dimensions.push_back(dim); + continue; + } + // Checks to see if we need to squeeze the dim or not. + if (dim != 1) { output_dimensions.push_back(dim); } + i++; } return output_dimensions; } diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index e22821a7bb0..067a77abc23 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -49,6 +49,9 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, std::vector BuildSqueezedDimensions( absl::Span dimensions, int64_t squeeze_dim); +std::vector BuildSqueezedDimensions( + absl::Span dimensions, std::vector& squeeze_dim); + std::vector BuildUnsqueezeDimensions( absl::Span dimensions, int64_t dim); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index fa54741190d..4c6444b6565 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2471,16 +2471,17 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector dims) { std::vector input_dimensions = torch_xla::runtime::util::ToVector( input_shape.get().dimensions()); - std::vector output_dimensions; + std::vector squeeze_dims; for (int64_t dim : dims) { - if (dim >= input_dimensions.size()) { - continue; - } int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_dimensions.size()); - output_dimensions = BuildSqueezedDimensions(input_dimensions, squeeze_dim); - input_dimensions = output_dimensions; + if (squeeze_dim >= input_dimensions.size()) { + continue; + } + squeeze_dims.push_back(squeeze_dim); } + std::vector output_dimensions = + BuildSqueezedDimensions(input_dimensions, squeeze_dims); return view(input, output_dimensions); }