From ef753f414dd032760f12ded27209033298c2980e Mon Sep 17 00:00:00 2001
From: Yeounoh Chung
Date: Thu, 1 Feb 2024 14:18:33 -0800
Subject: [PATCH] Revert "Avoid fallback for avg_pool - (#6409)"

This reverts commit a0bae8210fb98ee92b8fea00775508ff3c71c982.
---
 test/cpp/test_tensor.cpp           | 12 ++---
 test/test_core_aten_ops.py         | 11 -----
 torch_xla/csrc/aten_xla_type.cpp   | 20 ++++++--
 torch_xla/csrc/ops/avg_pool_nd.cpp | 17 +------
 torch_xla/csrc/ops/avg_pool_nd.h   |  4 +-
 torch_xla/csrc/pooling.cpp         | 74 ++++++++++++------------------
 torch_xla/csrc/tensor_methods.cpp  |  6 +--
 torch_xla/csrc/tensor_methods.h    |  3 +-
 8 files changed, 55 insertions(+), 92 deletions(-)

diff --git a/test/cpp/test_tensor.cpp b/test/cpp/test_tensor.cpp
index eb3b52676eb..757aea62075 100644
--- a/test/cpp/test_tensor.cpp
+++ b/test/cpp/test_tensor.cpp
@@ -316,8 +316,7 @@ TEST_F(TensorTest, TestAvgPool2D) {
           /*kernel_size=*/{kernel_size, kernel_size},
           /*stride=*/{stride, stride},
           /*padding=*/{padding, padding},
-          /*ceil_mode=*/false, count_include_pad,
-          /*divisor_override=*/std::nullopt);
+          /*ceil_mode=*/false, count_include_pad);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
         XLATensorPtr dev_input = XLATensor::Create(input, device);
         XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
@@ -326,8 +325,7 @@ TEST_F(TensorTest, TestAvgPool2D) {
             /*kernel_size=*/{kernel_size, kernel_size},
             /*stride=*/{stride, stride},
             /*padding=*/{padding, padding},
-            /*ceil_mode=*/false, count_include_pad,
-            /*divisor_override=*/std::nullopt);
+            /*ceil_mode=*/false, count_include_pad);
         AllClose(output, dev_output);
       });
     }
   }
@@ -346,8 +344,7 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
           /*kernel_size=*/{kernel_size, kernel_size + 1},
           /*stride=*/{stride, stride + 1},
          /*padding=*/{padding, padding + 1}, /*ceil_mode=*/false,
-          /*count_include_pad=*/count_include_pad,
-          /*divisor_override=*/std::nullopt);
+          /*count_include_pad=*/count_include_pad);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
         XLATensorPtr dev_input = XLATensor::Create(input, device);
         XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
@@ -357,8 +354,7 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
             /*stride=*/{stride, stride + 1},
             /*padding=*/{padding, padding + 1},
             /*ceil_mode=*/false,
-            /*count_include_pad=*/count_include_pad,
-            /*divisor_override=*/std::nullopt);
+            /*count_include_pad=*/count_include_pad);
         AllClose(output, dev_output);
       });
     }
   }
diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py
index 220883e6b4e..ab357f57e13 100644
--- a/test/test_core_aten_ops.py
+++ b/test/test_core_aten_ops.py
@@ -737,17 +737,6 @@ def test_aten_avg_pool2d_1(self):
     kwargs = dict()
     run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs)
 
-  def test_aten_avg_pool2d_2(self):
-    args = (
-        torch.randn((1, 192, 40, 40)).to(torch.float32),
-        [3, 3],
-        [1, 1],
-        [1, 1],
-        True,
-    )
-    kwargs = dict()
-    run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs)
-
   def test_aten_avg_pool3d_0(self):
     args = (
         torch.randn((1, 3, 10, 10, 10)).to(torch.float32),
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 41ccfc7f4ec..770fab61ce4 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -760,11 +760,17 @@ at::Tensor XLANativeFunctions::avg_pool2d(
     at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+  if ((ceil_mode && count_include_pad) || divisor_override) {
+    return at::native::call_fallback_fn<
+        &xla_cpu_fallback, ATEN_OP(avg_pool2d)>::call(self, kernel_size, stride,
+                                                      padding, ceil_mode,
+                                                      count_include_pad,
+                                                      divisor_override);
+  }
   return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
       bridge::GetXlaTensor(self), /*spatial_dim_count=*/2,
       XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode, count_include_pad,
-      divisor_override));
+      XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
 }
 
 at::Tensor XLANativeFunctions::avg_pool2d_backward(
@@ -791,11 +797,17 @@ at::Tensor XLANativeFunctions::avg_pool3d(
     at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
     c10::optional<int64_t> divisor_override) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+  if ((ceil_mode && count_include_pad) || divisor_override) {
+    return at::native::call_fallback_fn<
+        &xla_cpu_fallback, ATEN_OP(avg_pool3d)>::call(self, kernel_size, stride,
+                                                      padding, ceil_mode,
+                                                      count_include_pad,
+                                                      divisor_override);
+  }
   return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
       bridge::GetXlaTensor(self), /*spatial_dim_count=*/3,
       XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
-      XlaHelpers::I64List(padding), ceil_mode, count_include_pad,
-      divisor_override));
+      XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
 }
 
 at::Tensor XLANativeFunctions::avg_pool3d_backward(
diff --git a/torch_xla/csrc/ops/avg_pool_nd.cpp b/torch_xla/csrc/ops/avg_pool_nd.cpp
index 783ba6599e0..8dfa259d45f 100644
--- a/torch_xla/csrc/ops/avg_pool_nd.cpp
+++ b/torch_xla/csrc/ops/avg_pool_nd.cpp
@@ -1,12 +1,10 @@
 #include "torch_xla/csrc/ops/avg_pool_nd.h"
 
 #include "absl/strings/str_join.h"
-#include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/ops/infer_output_shape.h"
 #include "torch_xla/csrc/pooling.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
-#include "torch_xla/csrc/shape_helper.h"
 
 namespace torch_xla {
 namespace {
@@ -47,8 +45,7 @@ c10::Symbol AvgPoolNdSymbol(int64_t spatial_dim_count) {
 AvgPoolNd::AvgPoolNd(const torch::lazy::Value& input,
                      int64_t spatial_dim_count, std::vector<int64_t> kernel_size,
                      std::vector<int64_t> stride, std::vector<int64_t> padding,
-                     bool ceil_mode, bool count_include_pad,
-                     std::optional<int64_t> divisor_override)
+                     bool ceil_mode, bool count_include_pad)
     : XlaNode(torch::lazy::OpKind(AvgPoolNdSymbol(spatial_dim_count)), {input},
               [&]() {
                 return NodeOutputShape(input, spatial_dim_count, kernel_size,
@@ -63,8 +60,7 @@ AvgPoolNd::AvgPoolNd(const torch::lazy::Value& input,
       stride_(std::move(stride)),
       padding_(std::move(padding)),
       ceil_mode_(ceil_mode),
-      count_include_pad_(count_include_pad),
-      divisor_override_(divisor_override) {}
+      count_include_pad_(count_include_pad) {}
 
 torch::lazy::NodePtr AvgPoolNd::Clone(torch::lazy::OpList operands) const {
   return torch::lazy::MakeNode<AvgPoolNd>(operands.at(0), spatial_dim_count_,
@@ -77,15 +73,6 @@ XlaOpVector AvgPoolNd::Lower(LoweringContext* loctx) const {
   xla::XlaOp output =
       BuildAvgPoolNd(input, spatial_dim_count_, kernel_size_, stride_,
                      padding_, ceil_mode_, count_include_pad_);
-  if (divisor_override_) {
-    auto dtype = ShapeHelper::ShapeOfXlaOp(output).element_type();
-    auto* builder = loctx->builder();
-    int size = std::accumulate(kernel_size_.begin(), kernel_size_.end(), 1,
-                               std::multiplies<int>());
-    output = xla::Div(
-        xla::Mul(output, XlaHelpers::ScalarValue(size, dtype, builder)),
-        XlaHelpers::ScalarValue(*divisor_override_, dtype, builder));
-  }
   return ReturnOp(output, loctx);
 }
 
diff --git a/torch_xla/csrc/ops/avg_pool_nd.h b/torch_xla/csrc/ops/avg_pool_nd.h
index eaada54fc8b..ace84aed59b 100644
--- a/torch_xla/csrc/ops/avg_pool_nd.h
+++ b/torch_xla/csrc/ops/avg_pool_nd.h
@@ -10,8 +10,7 @@ class AvgPoolNd : public XlaNode {
   AvgPoolNd(const torch::lazy::Value& input, int64_t spatial_dim_count,
             std::vector<int64_t> kernel_size, std::vector<int64_t> stride,
             std::vector<int64_t> padding, bool ceil_mode,
-            bool count_include_pad,
-            std::optional<int64_t> divisor_override = std::nullopt);
+            bool count_include_pad);
 
   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
 
@@ -41,7 +40,6 @@ class AvgPoolNd : public XlaNode {
   // Whether the counts used to compute the average should include the added
   // padding.
   bool count_include_pad_;
-  std::optional<int64_t> divisor_override_;
 };
 
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp
index f6295cd527f..dd202ccadaf 100644
--- a/torch_xla/csrc/pooling.cpp
+++ b/torch_xla/csrc/pooling.cpp
@@ -131,50 +131,47 @@ xla::XlaOp RemoveTrivialBatch(xla::XlaOp batch, int64_t original_rank,
 std::vector<std::pair<int64_t, int64_t>> CeilModePadding(
     absl::Span<const int64_t> padding, const xla::Shape& input_shape,
     absl::Span<const int64_t> kernel_size, absl::Span<const int64_t> stride,
-    bool ceil_mode, bool count_include_pad) {
+    bool ceil_mode) {
   std::vector<std::pair<int64_t, int64_t>> ceil_mode_padding;
   for (int i = 0; i < padding.size(); ++i) {
     int64_t left_padding = padding[i];
-    if (count_include_pad) {
-      // if count_include_pad; the padding is added as XLA ops
-      left_padding = 0;
-    }
     int64_t input_size = input_shape.dimensions(2 + i);
     int64_t output_size_rem = (input_size + 2 * left_padding - kernel_size[i]) % stride[i];
     int64_t right_padding = left_padding;
     if (ceil_mode && output_size_rem != 0) {
-      right_padding += stride[i];
+      int64_t extra_padding = stride[i] - output_size_rem;
+      int64_t new_output_size =
+          (input_size + left_padding + right_padding + extra_padding -
+           kernel_size[i] + stride[i] - 1) /
+              stride[i] +
+          1;
+      // Ensure that the last pooling starts inside the image.
+      if ((new_output_size - 1) * stride[i] < input_size + left_padding) {
+        right_padding += extra_padding;
+      }
     }
     ceil_mode_padding.emplace_back(left_padding, right_padding);
   }
   return ceil_mode_padding;
 }
 
-xla::PaddingConfig MakeXlaPaddingConfig(absl::Span<const int64_t> padding) {
-  xla::PaddingConfig padding_config;
-  for (int i = 0; i < 2; ++i) {
-    padding_config.add_dimensions();
-  }
-  for (int pad : padding) {
-    xla::PaddingConfig::PaddingConfigDimension* dims =
-        padding_config.add_dimensions();
-    dims->set_edge_padding_low(pad);
-    dims->set_edge_padding_high(pad);
-  }
-  return padding_config;
-}
-
 // Creates an XLA padding configuration from a padding attribute value.
-xla::PaddingConfig MakeXlaPaddingConfig(
-    std::vector<std::pair<int64_t, int64_t>> padding) {
+xla::PaddingConfig MakeXlaPaddingConfig(absl::Span<const int64_t> padding,
+                                        const xla::Shape& input_shape,
+                                        absl::Span<const int64_t> kernel_size,
+                                        absl::Span<const int64_t> stride,
+                                        bool ceil_mode) {
   xla::PaddingConfig padding_config;
   for (int i = 0; i < 2; ++i) {
     padding_config.add_dimensions();
   }
-  for (const auto& dim_padding : padding) {
+  auto ceil_mode_padding =
+      CeilModePadding(padding, input_shape, kernel_size, stride, ceil_mode);
+  for (int i = 0; i < padding.size(); ++i) {
     xla::PaddingConfig::PaddingConfigDimension* dims =
         padding_config.add_dimensions();
+    const auto dim_padding = ceil_mode_padding[i];
     dims->set_edge_padding_low(dim_padding.first);
     dims->set_edge_padding_high(dim_padding.second);
   }
   return padding_config;
@@ -448,9 +445,8 @@ MaxPoolResult BuildMaxPoolNd(xla::XlaOp input, int64_t spatial_dim_count,
   const xla::Shape& input_shape =
       ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input);
   xla::XlaOp init_value = xla::MinValue(builder, input_shape.element_type());
-  std::vector<std::pair<int64_t, int64_t>> ceil_padding = CeilModePadding(
-      padding, input_shape, kernel_size, stride, ceil_mode, false);
-  xla::PaddingConfig padding_config = MakeXlaPaddingConfig(ceil_padding);
+  xla::PaddingConfig padding_config = MakeXlaPaddingConfig(
+      padding, input_shape, kernel_size, stride, ceil_mode);
   xla::XlaOp padded_input =
       xla::Pad(batch_input_info.batch_input, init_value, padding_config);
   PoolingOpAttributes pooling_op_attributes =
@@ -489,8 +485,8 @@ xla::XlaOp BuildMaxPoolNdBackward(xla::XlaOp out_backprop, xla::XlaOp input,
       MakePoolingOpAttributes(/*kernel_size_attr=*/kernel_size,
                               /*stride_attr=*/stride);
   std::vector<std::pair<int64_t, int64_t>> window_padding;
-  const auto ceil_mode_padding = CeilModePadding(
-      padding, input_shape, kernel_size, stride, ceil_mode, false);
+  const auto ceil_mode_padding =
+      CeilModePadding(padding, input_shape, kernel_size, stride, ceil_mode);
   window_padding.resize(2);
   window_padding.insert(window_padding.end(), ceil_mode_padding.begin(),
                         ceil_mode_padding.end());
@@ -575,27 +571,15 @@ xla::XlaOp BuildAvgPoolNd(xla::XlaOp input, int64_t spatial_dim_count,
   BatchInput batch_input_info = CreateBatchInput(input, spatial_dim_count);
   const xla::Shape& input_shape =
       ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input);
-
-  if (count_include_pad) {
-    xla::PaddingConfig padding_config = MakeXlaPaddingConfig(padding);
-    auto dtype = ShapeHelper::ShapeOfXlaOp(input).element_type();
-    auto padding_value = XlaHelpers::ScalarValue(0, dtype, input.builder());
-    batch_input_info.batch_input =
-        xla::Pad(batch_input_info.batch_input, padding_value, padding_config);
-  }
-
-  const auto ceil_mode_padding = CeilModePadding(
-      padding, ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input),
-      kernel_size, stride, ceil_mode, count_include_pad);
-
+  const auto ceil_mode_padding =
+      CeilModePadding(padding, input_shape, kernel_size, stride, ceil_mode);
   xla::XlaOp batch_result = xla::AvgPool(
       /*operand=*/batch_input_info.batch_input,
       /*kernel_size=*/pooling_op_attributes.kernel_size,
       /*stride=*/pooling_op_attributes.stride,
       /*padding=*/ceil_mode_padding,
       /*data_format=*/MakeNCHWFormat(spatial_dim_count),
-      /*counts_include_padding=*/false);  // already compensated in XLA
-
+      /*counts_include_padding=*/count_include_pad);
   return RemoveTrivialBatch(/*batch=*/batch_result,
                             /*original_rank=*/batch_input_info.original_rank,
                             /*spatial_dim_count=*/spatial_dim_count);
@@ -613,8 +597,8 @@ xla::XlaOp BuildAvgPoolNdBackward(xla::XlaOp out_backprop, xla::XlaOp input,
   BatchInput batch_input_info = CreateBatchInput(input, spatial_dim_count);
   const xla::Shape& gradients_shape =
       ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input);
-  const auto ceil_mode_padding = CeilModePadding(
-      padding, gradients_shape, kernel_size, stride, ceil_mode, false);
+  const auto ceil_mode_padding =
+      CeilModePadding(padding, gradients_shape, kernel_size, stride, ceil_mode);
   BatchInput batch_out_backprop_info =
       CreateBatchInput(out_backprop, spatial_dim_count);
   xla::XlaOp batch_result = xla::AvgPoolGrad(
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 7cc804ec1c9..24ea088e275 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -823,15 +823,13 @@ XLATensorPtr avg_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
                          std::vector<int64_t> kernel_size,
                          std::vector<int64_t> stride,
                          std::vector<int64_t> padding, bool ceil_mode,
-                         bool count_include_pad,
-                         std::optional<int64_t> divisor_override) {
+                         bool count_include_pad) {
   kernel_size = CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
   stride = CheckIntList(stride, spatial_dim_count, "stride", kernel_size);
   padding = CheckIntList(padding, spatial_dim_count, "padding");
   return input->CreateFrom(torch::lazy::MakeNode<AvgPoolNd>(
       input->GetIrValue(), spatial_dim_count, std::move(kernel_size),
-      std::move(stride), std::move(padding), ceil_mode, count_include_pad,
-      divisor_override));
+      std::move(stride), std::move(padding), ceil_mode, count_include_pad));
 }
 
 XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 4edcee32911..05766ff7e7c 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -222,8 +222,7 @@ XLATensorPtr avg_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
                          std::vector<int64_t> kernel_size,
                          std::vector<int64_t> stride,
                          std::vector<int64_t> padding, bool ceil_mode,
-                         bool count_include_pad,
-                         std::optional<int64_t> divisor_override);
+                         bool count_include_pad);
 
 XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
                                   const XLATensorPtr& input,
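
Note (editor's sketch, not part of the patch): the AvgPoolNd::Lower block removed
in torch_xla/csrc/ops/avg_pool_nd.cpp implemented divisor_override by rescaling
XLA's mean: for an N-element kernel, sum / divisor == (sum / N) * N / divisor.
After this revert, calls with divisor_override set (or with ceil_mode and
count_include_pad together) take the CPU fallback instead. The identity can be
checked with stock PyTorch; the variable names below are illustrative only:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    kernel, divisor = (3, 3), 5
    # Reference semantics: each window sum is divided by divisor_override.
    expected = F.avg_pool2d(x, kernel, divisor_override=divisor)
    # The removed lowering's rescaling: mean * (kernel element count) / divisor.
    rescaled = F.avg_pool2d(x, kernel) * (kernel[0] * kernel[1]) / divisor
    assert torch.allclose(expected, rescaled, atol=1e-6)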
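Note (editor's sketch, not part of the patch): the CeilModePadding restored in
torch_xla/csrc/pooling.cpp only adds the extra right padding when the extra
window implied by ceil_mode actually starts inside the (left-padded) input;
otherwise ceil_mode would create a window consisting entirely of padding. The
same per-dimension arithmetic in standalone Python (helper name is ours):

    def ceil_mode_padding(pad, input_size, kernel, stride, ceil_mode):
        """Mirrors the restored C++ computation for one spatial dimension."""
        left = right = pad
        rem = (input_size + 2 * left - kernel) % stride
        if ceil_mode and rem != 0:
            extra = stride - rem
            new_output = (input_size + left + right + extra - kernel
                          + stride - 1) // stride + 1
            # Only pad if the last pooling window starts inside the image.
            if (new_output - 1) * stride < input_size + left:
                right += extra
        return left, right

    # Example: width 10, kernel 3, stride 2, pad 1 -> one extra column on the
    # right, so ceil mode yields 6 output positions instead of floor mode's 5.
    assert ceil_mode_padding(1, 10, 3, 2, True) == (1, 2)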