Skip to content

Commit

Permalink
fix squeeze op lowering issue when dim is not in sorted order (#5751)
Browse files Browse the repository at this point in the history
* fix squeeze op lowering issue when dim is not in sorted order

* remove debug info

* remove debug info

* refactor BuildSqueezedDimensions
  • Loading branch information
zpcore authored and golechwierowicz committed Jan 12, 2024
1 parent 1ea485f commit 620da85
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 9 deletions.
12 changes: 12 additions & 0 deletions test/cpp/test_aten_xla_tensor_4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,18 @@ TEST_F(AtenXlaTensorTest, TestSqueezeMultipleDims) {
});
}

TEST_F(AtenXlaTensorTest, TestSqueezeDimWithNegativeOne) {
torch::Tensor input =
torch::rand({2, 1, 3, 1}, torch::TensorOptions(torch::kFloat));
std::vector<int64_t> dims = {-1};
torch::Tensor output = torch::squeeze(input, dims);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_input = CopyToDevice(input, device);
torch::Tensor xla_output = torch::squeeze(xla_input, dims);
AllClose(output, xla_output);
});
}

TEST_F(AtenXlaTensorTest, TestSqueezeOneInPlace) {
int rank = 4;
for (int dim = -rank; dim < rank; ++dim) {
Expand Down
27 changes: 24 additions & 3 deletions torch_xla/csrc/data_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,33 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,

std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, int64_t squeeze_dim) {
std::vector<int64_t> squeeze_dims({squeeze_dim});
return BuildSqueezedDimensions(dimensions, squeeze_dims);
}

std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, std::vector<int64_t>& squeeze_dims) {
std::sort(squeeze_dims.begin(), squeeze_dims.end());
std::vector<int64_t> output_dimensions;
for (int64_t i = 0; i < dimensions.size(); ++i) {
int64_t dim = dimensions[i];
if (dim != 1 || (i != squeeze_dim && squeeze_dim >= 0)) {
size_t i = 0;
for (size_t j = 0; j < dimensions.size(); j++) {
auto dim = dimensions[j];
if (squeeze_dims.size() == 1 && squeeze_dims[0] == -1) {
// Special case where squeeze_dims = {-1}.
if (dim != 1) {
output_dimensions.push_back(dim);
}
continue;
}
if (i == squeeze_dims.size() || j < squeeze_dims[i]) {
output_dimensions.push_back(dim);
continue;
}
// Checks to see if we need to squeeze the dim or not.
if (dim != 1) {
output_dimensions.push_back(dim);
}
i++;
}
return output_dimensions;
}
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/data_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, int64_t squeeze_dim);

std::vector<int64_t> BuildSqueezedDimensions(
absl::Span<const int64_t> dimensions, std::vector<int64_t>& squeeze_dim);

std::vector<int64_t> BuildUnsqueezeDimensions(
absl::Span<const int64_t> dimensions, int64_t dim);

Expand Down
13 changes: 7 additions & 6 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2471,16 +2471,17 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector<int64_t> dims) {
std::vector<int64_t> input_dimensions =
torch_xla::runtime::util::ToVector<int64_t>(
input_shape.get().dimensions());
std::vector<int64_t> output_dimensions;
std::vector<int64_t> squeeze_dims;
for (int64_t dim : dims) {
if (dim >= input_dimensions.size()) {
continue;
}
int64_t squeeze_dim =
torch::lazy::GetCanonicalDimensionIndex(dim, input_dimensions.size());
output_dimensions = BuildSqueezedDimensions(input_dimensions, squeeze_dim);
input_dimensions = output_dimensions;
if (squeeze_dim >= input_dimensions.size()) {
continue;
}
squeeze_dims.push_back(squeeze_dim);
}
std::vector<int64_t> output_dimensions =
BuildSqueezedDimensions(input_dimensions, squeeze_dims);
return view(input, output_dimensions);
}

Expand Down

0 comments on commit 620da85

Please sign in to comment.