Skip to content

Commit

Permalink
fix squeeze op lowering issue when dim is not in sorted order
Browse files Browse the repository at this point in the history
  • Loading branch information
zpcore committed Oct 31, 2023
1 parent 63ea76c commit 2e0804e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 6 deletions.
12 changes: 12 additions & 0 deletions test/cpp/test_aten_xla_tensor_4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,18 @@ TEST_F(AtenXlaTensorTest, TestSqueezeMultipleDims) {
});
}

TEST_F(AtenXlaTensorTest, TestSqueezeDimWithNegativeOne) {
  // Squeezing with dim = -1 must drop the trailing size-1 axis on XLA
  // exactly as eager PyTorch does.
  torch::Tensor cpu_input =
      torch::rand({2, 1, 3, 1}, torch::TensorOptions(torch::kFloat));
  std::vector<int64_t> squeeze_dims = {-1};
  torch::Tensor cpu_output = torch::squeeze(cpu_input, squeeze_dims);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor device_input = CopyToDevice(cpu_input, device);
    torch::Tensor device_output = torch::squeeze(device_input, squeeze_dims);
    AllClose(cpu_output, device_output);
  });
}

TEST_F(AtenXlaTensorTest, TestSqueezeOneInPlace) {
int rank = 4;
for (int dim = -rank; dim < rank; ++dim) {
Expand Down
23 changes: 23 additions & 0 deletions torch_xla/csrc/data_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,29 @@ std::vector<int64_t> BuildSqueezedDimensions(
return output_dimensions;
}

// Returns the dimension sizes that remain after squeezing `dimensions` at the
// canonical (non-negative) indices in `squeeze_dims`. A requested dimension is
// dropped only if its size is exactly 1; all other dimensions are kept in
// order. `squeeze_dims` is sorted in place so the caller may pass the dims in
// any order (e.g. the unsorted order a user gave to torch.squeeze).
std::vector<int64_t> BuildSqueezedDimensions(
    absl::Span<const int64_t> dimensions, std::vector<int64_t>& squeeze_dims) {
  // Sorting lets a single forward pass match each input axis against the next
  // pending squeeze candidate.
  std::sort(squeeze_dims.begin(), squeeze_dims.end());
  std::vector<int64_t> output_dimensions;
  output_dimensions.reserve(dimensions.size());
  size_t i = 0;  // Index of the next candidate in squeeze_dims.
  for (size_t j = 0; j < dimensions.size(); j++) {
    int64_t dim = dimensions[j];
    // Axis j is not a squeeze candidate: keep it unconditionally. The cast
    // avoids a signed/unsigned comparison (canonical dims are non-negative).
    if (i == squeeze_dims.size() ||
        j < static_cast<size_t>(squeeze_dims[i])) {
      output_dimensions.push_back(dim);
      continue;
    }
    // Axis j was requested: squeeze it only if its size is actually 1,
    // matching torch.squeeze semantics (non-1 dims are left untouched).
    if (dim != 1) {
      output_dimensions.push_back(dim);
    }
    i++;
  }
  return output_dimensions;
}

std::vector<int64_t> BuildUnsqueezeDimensions(
absl::Span<const int64_t> dimensions, int64_t dim) {
XLA_CHECK_LE(dim, dimensions.size());
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/data_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
// Returns the dimension sizes left after squeezing `dimensions` at the single
// canonical index `squeeze_dim` (the dim is dropped only when its size is 1).
std::vector<int64_t> BuildSqueezedDimensions(
    absl::Span<const int64_t> dimensions, int64_t squeeze_dim);

// Multi-dim overload: squeezes every canonical index in `squeeze_dim`.
// NOTE: `squeeze_dim` is taken by non-const reference because the
// implementation sorts it in place; callers must not rely on its order.
std::vector<int64_t> BuildSqueezedDimensions(
    absl::Span<const int64_t> dimensions, std::vector<int64_t>& squeeze_dim);

// Returns the dimension sizes produced by inserting a size-1 axis at `dim`.
std::vector<int64_t> BuildUnsqueezeDimensions(
    absl::Span<const int64_t> dimensions, int64_t dim);

Expand Down
13 changes: 7 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2471,16 +2471,17 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector<int64_t> dims) {
// Materialize the input's dimension sizes so the requested dims can be
// validated and canonicalized before lowering.
std::vector<int64_t> input_dimensions =
torch_xla::runtime::util::ToVector<int64_t>(
input_shape.get().dimensions());
// Collect every valid, canonicalized (non-negative) squeeze index; the order
// callers pass dims in does not matter because BuildSqueezedDimensions sorts
// the collected indices itself.
std::vector<int64_t> squeeze_dims;
for (int64_t dim : dims) {
// Skip dims beyond the tensor's rank instead of erroring.
// NOTE(review): `dim` is signed and size() is unsigned, so a negative dim
// is converted to a large unsigned value by this comparison and would be
// skipped before canonicalization — confirm negative dims (e.g. -1) pass
// through here as intended, or whether a signed cast of size() is needed.
if (dim >= input_dimensions.size()) {
continue;
}
// Map negative dims (e.g. -1) onto their non-negative canonical index.
int64_t squeeze_dim =
torch::lazy::GetCanonicalDimensionIndex(dim, input_dimensions.size());
// Defensive re-check after canonicalization; squeeze_dim is non-negative.
if (squeeze_dim >= input_dimensions.size()) {
continue;
}
squeeze_dims.push_back(squeeze_dim);
}
// Compute the post-squeeze shape in one pass, then lower squeeze as a view.
std::vector<int64_t> output_dimensions =
BuildSqueezedDimensions(input_dimensions, squeeze_dims);
return view(input, output_dimensions);
}

Expand Down

0 comments on commit 2e0804e

Please sign in to comment.