diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp
old mode 100644
new mode 100755
index fcdf4b11a6f..f190403dd87
--- a/test/cpp/test_aten_xla_tensor_2.cpp
+++ b/test/cpp/test_aten_xla_tensor_2.cpp
@@ -2117,6 +2117,19 @@ TEST_F(AtenXlaTensorTest, TestArgMinDimKeep) {
   ExpectCounterChanged("xla::argmin", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestArgMinDimKeepNoDim) {
+  torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::argmin(a, c10::nullopt, /*keepdim=*/true);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::argmin(xla_a, c10::nullopt, /*keepdim=*/true);
+    AllEqual(b, xla_b);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::argmin", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestArgMinSameValue) {
   torch::Tensor a = torch::ones({4, 4, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::argmin(a);
@@ -2188,6 +2201,19 @@ TEST_F(AtenXlaTensorTest, TestArgMaxDimKeep) {
   ExpectCounterChanged("xla::argmax", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestArgMaxDimKeepNoDim) {
+  torch::Tensor a = torch::rand({4, 4, 4}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::argmax(a, c10::nullopt, /*keepdim=*/true);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::argmax(xla_a, c10::nullopt, /*keepdim=*/true);
+    AllEqual(b, xla_b);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::argmax", cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestArgMaxSameValue) {
   torch::Tensor a = torch::ones({4, 4, 4}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::argmax(a, c10::nullopt, /*keepdim=*/false);
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
old mode 100644
new mode 100755
index 0f84734ddaa..50469820a50
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -124,7 +124,7 @@ torch_xla::XlaOpVector Argmax::Lower(LoweringContext* loctx) const {
     return ReturnOp(torch_xla::BuildArgMax(input, canonical_dim, keepdim),
                     loctx);
   } else {
-    return ReturnOp(torch_xla::BuildArgMax(input, -1, false), loctx);
+    return ReturnOp(torch_xla::BuildArgMax(input, -1, keepdim), loctx);
   }
 }
 
@@ -137,7 +137,7 @@ torch_xla::XlaOpVector Argmin::Lower(LoweringContext* loctx) const {
     return ReturnOp(torch_xla::BuildArgMin(input, canonical_dim, keepdim),
                     loctx);
   } else {
-    return ReturnOp(torch_xla::BuildArgMin(input, -1, false), loctx);
+    return ReturnOp(torch_xla::BuildArgMin(input, -1, keepdim), loctx);
   }
 }
 
diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
old mode 100644
new mode 100755
index 485fa3f95bf..9de0afe6b01
--- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
+++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -201,7 +201,7 @@ xla::Shape ArgmaxOutputShape(const torch::lazy::Value& input,
           dim.value(), input_shape.rank());
       return BuildArgMax(operands[0], {canonical_dim}, keepdim);
     } else {
-      return BuildArgMax(operands[0], {-1}, false);
+      return BuildArgMax(operands[0], {-1}, keepdim);
     }
   };
   return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
@@ -217,7 +217,7 @@ xla::Shape ArgminOutputShape(const torch::lazy::Value& input,
           dim.value(), input_shape.rank());
      return BuildArgMin(operands[0], {canonical_dim}, keepdim);
    } else {
-      return BuildArgMin(operands[0], {-1}, false);
+      return BuildArgMin(operands[0], {-1}, keepdim);
    }
  };
  return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn);
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index 94041ef6862..46c1ec5f35d 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -409,18 +409,28 @@ xla::XlaOp BuildMinInDims(xla::XlaOp input,
 xla::XlaOp BuildArgMax(xla::XlaOp input, int64_t dim, bool keepdim) {
   const xla::Shape* shape = &ShapeHelper::ShapeOfXlaOp(input);
   xla::XlaOp operand = input;
+  bool dim_is_none = false;
   if (dim < 0) {
     dim = 0;
+    dim_is_none = true;
     operand = XlaHelpers::DynamicReshape(operand,
                                          {xla::ShapeUtil::ElementsIn(*shape)});
-    shape = &ShapeHelper::ShapeOfXlaOp(operand);
+    if (!keepdim) {
+      shape = &ShapeHelper::ShapeOfXlaOp(operand);
+    }
   }
   xla::XlaOp result = xla::ArgMax(
       operand, GetXlaPrimitiveTypeForCurrentDevice(xla::PrimitiveType::S64),
       dim);
   if (keepdim) {
     auto dimensions = torch::lazy::ToVector<int64_t>(shape->dimensions());
-    dimensions[dim] = 1;
+    if (dim_is_none) {
+      for (auto& dim_it : dimensions) {
+        dim_it = 1;
+      }
+    } else {
+      dimensions[dim] = 1;
+    }
     result = XlaHelpers::DynamicReshape(result, dimensions);
   }
   return result;
@@ -429,18 +439,28 @@ xla::XlaOp BuildArgMax(xla::XlaOp input, int64_t dim, bool keepdim) {
 xla::XlaOp BuildArgMin(xla::XlaOp input, int64_t dim, bool keepdim) {
   const xla::Shape* shape = &ShapeHelper::ShapeOfXlaOp(input);
   xla::XlaOp operand = input;
+  bool dim_is_none = false;
   if (dim < 0) {
     dim = 0;
+    dim_is_none = true;
     operand = XlaHelpers::DynamicReshape(operand,
                                          {xla::ShapeUtil::ElementsIn(*shape)});
-    shape = &ShapeHelper::ShapeOfXlaOp(operand);
+    if (!keepdim) {
+      shape = &ShapeHelper::ShapeOfXlaOp(operand);
+    }
   }
   xla::XlaOp result = xla::ArgMin(
       operand, GetXlaPrimitiveTypeForCurrentDevice(xla::PrimitiveType::S64),
       dim);
   if (keepdim) {
     auto dimensions = torch::lazy::ToVector<int64_t>(shape->dimensions());
-    dimensions[dim] = 1;
+    if (dim_is_none) {
+      for (auto& dim_it : dimensions) {
+        dim_it = 1;
+      }
+    } else {
+      dimensions[dim] = 1;
+    }
     result = XlaHelpers::DynamicReshape(result, dimensions);
   }
   return result;
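
For reference, the behavior the new tests exercise is eager PyTorch's argmax/argmin semantics when dim is c10::nullopt: the reduction runs over the flattened input, and with keepdim=true every dimension of the original input is retained with size 1, which is what the patched BuildArgMax/BuildArgMin shapes now produce. A minimal standalone libtorch sketch of those expected CPU shapes (not part of the patch; tensor sizes are chosen arbitrarily for illustration):

#include <torch/torch.h>

#include <iostream>

int main() {
  // Illustrative only: eager (CPU) semantics the XLA lowering is compared against.
  torch::Tensor a = torch::rand({4, 4, 4});

  // dim omitted, keepdim=true: reduce over all elements, but keep the input
  // rank with every dimension collapsed to size 1.
  torch::Tensor kept = torch::argmax(a, c10::nullopt, /*keepdim=*/true);
  std::cout << kept.sizes() << "\n";  // [1, 1, 1]

  // dim omitted, keepdim=false: a 0-d tensor holding the index into the
  // flattened input.
  torch::Tensor flat = torch::argmin(a, c10::nullopt, /*keepdim=*/false);
  std::cout << flat.sizes() << "\n";  // []
  return 0;
}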