diff --git a/test/cpp/test_aten_xla_tensor_4.cpp b/test/cpp/test_aten_xla_tensor_4.cpp index 1e61a6fa05a..0a0d84d463a 100644 --- a/test/cpp/test_aten_xla_tensor_4.cpp +++ b/test/cpp/test_aten_xla_tensor_4.cpp @@ -956,6 +956,18 @@ TEST_F(AtenXlaTensorTest, TestSqueezeMultipleDims) { }); } +TEST_F(AtenXlaTensorTest, TestSqueezeDimWithNegativeOne) { + torch::Tensor input = + torch::rand({2, 1, 3, 1}, torch::TensorOptions(torch::kFloat)); + std::vector dims = {-1}; + torch::Tensor output = torch::squeeze(input, dims); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_input = CopyToDevice(input, device); + torch::Tensor xla_output = torch::squeeze(xla_input, dims); + AllClose(output, xla_output); + }); +} + TEST_F(AtenXlaTensorTest, TestSqueezeOneInPlace) { int rank = 4; for (int dim = -rank; dim < rank; ++dim) { diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index e5425a93001..47ddf302939 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -175,6 +175,29 @@ std::vector BuildSqueezedDimensions( return output_dimensions; } +std::vector BuildSqueezedDimensions( + absl::Span dimensions, std::vector& squeeze_dims) { + std::sort(squeeze_dims.begin(), squeeze_dims.end()); + std::vector output_dimensions; + size_t i = 0; + std::vector tmp(dimensions.begin(), dimensions.end()); + std::cout << "in: " << tmp << " seq: " << squeeze_dims << std::endl; + for (size_t j = 0; j < dimensions.size(); j++) { + auto dim = dimensions[j]; + if (i == squeeze_dims.size() || j < squeeze_dims[i]) { + output_dimensions.push_back(dim); + continue; + } + // Checks to see if we need to squeeze the dim or not. + if (dim != 1) { + output_dimensions.push_back(dim); + } + i++; + } + std::cout << "output_dims: " << output_dimensions << std::endl; + return output_dimensions; +} + std::vector BuildUnsqueezeDimensions( absl::Span dimensions, int64_t dim) { XLA_CHECK_LE(dim, dimensions.size()); diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index e22821a7bb0..067a77abc23 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -49,6 +49,9 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, std::vector BuildSqueezedDimensions( absl::Span dimensions, int64_t squeeze_dim); +std::vector BuildSqueezedDimensions( + absl::Span dimensions, std::vector& squeeze_dim); + std::vector BuildUnsqueezeDimensions( absl::Span dimensions, int64_t dim); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index fa54741190d..4c6444b6565 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2471,16 +2471,17 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector dims) { std::vector input_dimensions = torch_xla::runtime::util::ToVector( input_shape.get().dimensions()); - std::vector output_dimensions; + std::vector squeeze_dims; for (int64_t dim : dims) { - if (dim >= input_dimensions.size()) { - continue; - } int64_t squeeze_dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_dimensions.size()); - output_dimensions = BuildSqueezedDimensions(input_dimensions, squeeze_dim); - input_dimensions = output_dimensions; + if (squeeze_dim >= input_dimensions.size()) { + continue; + } + squeeze_dims.push_back(squeeze_dim); } + std::vector output_dimensions = + BuildSqueezedDimensions(input_dimensions, squeeze_dims); return view(input, output_dimensions); }