Skip to content

Commit

Permalink
Revert "Avoid fallback for avg_pool - (#6409)"
Browse files Browse the repository at this point in the history
This reverts commit a0bae82.
  • Loading branch information
yeounoh authored Feb 1, 2024
1 parent 52ca01e commit ef753f4
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 92 deletions.
12 changes: 4 additions & 8 deletions test/cpp/test_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,7 @@ TEST_F(TensorTest, TestAvgPool2D) {
/*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*ceil_mode=*/false, count_include_pad,
/*divisor_override=*/std::nullopt);
/*ceil_mode=*/false, count_include_pad);
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
XLATensorPtr dev_input = XLATensor::Create(input, device);
XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
Expand All @@ -326,8 +325,7 @@ TEST_F(TensorTest, TestAvgPool2D) {
/*kernel_size=*/{kernel_size, kernel_size},
/*stride=*/{stride, stride},
/*padding=*/{padding, padding},
/*ceil_mode=*/false, count_include_pad,
/*divisor_override=*/std::nullopt);
/*ceil_mode=*/false, count_include_pad);
AllClose(output, dev_output);
});
}
Expand All @@ -346,8 +344,7 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
/*kernel_size=*/{kernel_size, kernel_size + 1},
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1}, /*ceil_mode=*/false,
/*count_include_pad=*/count_include_pad,
/*divisor_override=*/std::nullopt);
/*count_include_pad=*/count_include_pad);
ForEachDevice([&](const torch::lazy::BackendDevice& device) {
XLATensorPtr dev_input = XLATensor::Create(input, device);
XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
Expand All @@ -357,8 +354,7 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
/*stride=*/{stride, stride + 1},
/*padding=*/{padding, padding + 1},
/*ceil_mode=*/false,
/*count_include_pad=*/count_include_pad,
/*divisor_override=*/std::nullopt);
/*count_include_pad=*/count_include_pad);
AllClose(output, dev_output);
});
}
Expand Down
11 changes: 0 additions & 11 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,17 +737,6 @@ def test_aten_avg_pool2d_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs)

def test_aten_avg_pool2d_2(self):
args = (
torch.randn((1, 192, 40, 40)).to(torch.float32),
[3, 3],
[1, 1],
[1, 1],
True,
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.avg_pool2d, args, kwargs)

def test_aten_avg_pool3d_0(self):
args = (
torch.randn((1, 3, 10, 10, 10)).to(torch.float32),
Expand Down
20 changes: 16 additions & 4 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,11 +760,17 @@ at::Tensor XLANativeFunctions::avg_pool2d(
at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if ((ceil_mode && count_include_pad) || divisor_override) {
return at::native::call_fallback_fn<
&xla_cpu_fallback, ATEN_OP(avg_pool2d)>::call(self, kernel_size, stride,
padding, ceil_mode,
count_include_pad,
divisor_override);
}
return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
bridge::GetXlaTensor(self), /*spatial_dim_count=*/2,
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding), ceil_mode, count_include_pad,
divisor_override));
XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
}

at::Tensor XLANativeFunctions::avg_pool2d_backward(
Expand All @@ -791,11 +797,17 @@ at::Tensor XLANativeFunctions::avg_pool3d(
at::IntArrayRef padding, bool ceil_mode, bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
if ((ceil_mode && count_include_pad) || divisor_override) {
return at::native::call_fallback_fn<
&xla_cpu_fallback, ATEN_OP(avg_pool3d)>::call(self, kernel_size, stride,
padding, ceil_mode,
count_include_pad,
divisor_override);
}
return bridge::AtenFromXlaTensor(tensor_methods::avg_pool_nd(
bridge::GetXlaTensor(self), /*spatial_dim_count=*/3,
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
XlaHelpers::I64List(padding), ceil_mode, count_include_pad,
divisor_override));
XlaHelpers::I64List(padding), ceil_mode, count_include_pad));
}

at::Tensor XLANativeFunctions::avg_pool3d_backward(
Expand Down
17 changes: 2 additions & 15 deletions torch_xla/csrc/ops/avg_pool_nd.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#include "torch_xla/csrc/ops/avg_pool_nd.h"

#include "absl/strings/str_join.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/pooling.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/shape_helper.h"

namespace torch_xla {
namespace {
Expand Down Expand Up @@ -47,8 +45,7 @@ c10::Symbol AvgPoolNdSymbol(int64_t spatial_dim_count) {
AvgPoolNd::AvgPoolNd(const torch::lazy::Value& input, int64_t spatial_dim_count,
std::vector<int64_t> kernel_size,
std::vector<int64_t> stride, std::vector<int64_t> padding,
bool ceil_mode, bool count_include_pad,
std::optional<int> divisor_override)
bool ceil_mode, bool count_include_pad)
: XlaNode(torch::lazy::OpKind(AvgPoolNdSymbol(spatial_dim_count)), {input},
[&]() {
return NodeOutputShape(input, spatial_dim_count, kernel_size,
Expand All @@ -63,8 +60,7 @@ AvgPoolNd::AvgPoolNd(const torch::lazy::Value& input, int64_t spatial_dim_count,
stride_(std::move(stride)),
padding_(std::move(padding)),
ceil_mode_(ceil_mode),
count_include_pad_(count_include_pad),
divisor_override_(divisor_override) {}
count_include_pad_(count_include_pad) {}

torch::lazy::NodePtr AvgPoolNd::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<AvgPoolNd>(operands.at(0), spatial_dim_count_,
Expand All @@ -77,15 +73,6 @@ XlaOpVector AvgPoolNd::Lower(LoweringContext* loctx) const {
xla::XlaOp output =
BuildAvgPoolNd(input, spatial_dim_count_, kernel_size_, stride_, padding_,
ceil_mode_, count_include_pad_);
if (divisor_override_) {
auto dtype = ShapeHelper::ShapeOfXlaOp(output).element_type();
auto* builder = loctx->builder();
int size = std::accumulate(kernel_size_.begin(), kernel_size_.end(), 1,
std::multiplies<int>());
output = xla::Div(
xla::Mul(output, XlaHelpers::ScalarValue(size, dtype, builder)),
XlaHelpers::ScalarValue(*divisor_override_, dtype, builder));
}
return ReturnOp(output, loctx);
}

Expand Down
4 changes: 1 addition & 3 deletions torch_xla/csrc/ops/avg_pool_nd.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class AvgPoolNd : public XlaNode {
AvgPoolNd(const torch::lazy::Value& input, int64_t spatial_dim_count,
std::vector<int64_t> kernel_size, std::vector<int64_t> stride,
std::vector<int64_t> padding, bool ceil_mode,
bool count_include_pad,
std::optional<int> divisor_override = std::nullopt);
bool count_include_pad);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

Expand Down Expand Up @@ -41,7 +40,6 @@ class AvgPoolNd : public XlaNode {
// Whether the counts used to compute the average should include the added
// padding.
bool count_include_pad_;
std::optional<int> divisor_override_;
};

} // namespace torch_xla
Expand Down
74 changes: 29 additions & 45 deletions torch_xla/csrc/pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,50 +131,47 @@ xla::XlaOp RemoveTrivialBatch(xla::XlaOp batch, int64_t original_rank,
std::vector<std::pair<int64_t, int64_t>> CeilModePadding(
absl::Span<const int64_t> padding, const xla::Shape& input_shape,
absl::Span<const int64_t> kernel_size, absl::Span<const int64_t> stride,
bool ceil_mode, bool count_include_pad) {
bool ceil_mode) {
std::vector<std::pair<int64_t, int64_t>> ceil_mode_padding;
for (int i = 0; i < padding.size(); ++i) {
int64_t left_padding = padding[i];
if (count_include_pad) {
// if count_include_pad; the padding is added as XLA ops
left_padding = 0;
}
int64_t input_size = input_shape.dimensions(2 + i);
int64_t output_size_rem =
(input_size + 2 * left_padding - kernel_size[i]) % stride[i];
int64_t right_padding = left_padding;
if (ceil_mode && output_size_rem != 0) {
right_padding += stride[i];
int64_t extra_padding = stride[i] - output_size_rem;
int64_t new_output_size =
(input_size + left_padding + right_padding + extra_padding -
kernel_size[i] + stride[i] - 1) /
stride[i] +
1;
// Ensure that the last pooling starts inside the image.
if ((new_output_size - 1) * stride[i] < input_size + left_padding) {
right_padding += extra_padding;
}
}
ceil_mode_padding.emplace_back(left_padding, right_padding);
}
return ceil_mode_padding;
}

xla::PaddingConfig MakeXlaPaddingConfig(absl::Span<const int64_t> padding) {
xla::PaddingConfig padding_config;
for (int i = 0; i < 2; ++i) {
padding_config.add_dimensions();
}
for (int pad : padding) {
xla::PaddingConfig::PaddingConfigDimension* dims =
padding_config.add_dimensions();
dims->set_edge_padding_low(pad);
dims->set_edge_padding_high(pad);
}
return padding_config;
}

// Creates an XLA padding configuration from a padding attribute value.
xla::PaddingConfig MakeXlaPaddingConfig(
std::vector<std::pair<int64_t, int64_t>> padding) {
xla::PaddingConfig MakeXlaPaddingConfig(absl::Span<const int64_t> padding,
const xla::Shape& input_shape,
absl::Span<const int64_t> kernel_size,
absl::Span<const int64_t> stride,
bool ceil_mode) {
xla::PaddingConfig padding_config;
for (int i = 0; i < 2; ++i) {
padding_config.add_dimensions();
}
for (const auto& dim_padding : padding) {
auto ceil_mode_padding =
CeilModePadding(padding, input_shape, kernel_size, stride, ceil_mode);
for (int i = 0; i < padding.size(); ++i) {
xla::PaddingConfig::PaddingConfigDimension* dims =
padding_config.add_dimensions();
const auto dim_padding = ceil_mode_padding[i];
dims->set_edge_padding_low(dim_padding.first);
dims->set_edge_padding_high(dim_padding.second);
}
Expand Down Expand Up @@ -448,9 +445,8 @@ MaxPoolResult BuildMaxPoolNd(xla::XlaOp input, int64_t spatial_dim_count,
const xla::Shape& input_shape =
ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input);
xla::XlaOp init_value = xla::MinValue(builder, input_shape.element_type());
std::vector<std::pair<int64_t, int64_t>> ceil_padding = CeilModePadding(
padding, input_shape, kernel_size, stride, ceil_mode, false);
xla::PaddingConfig padding_config = MakeXlaPaddingConfig(ceil_padding);
xla::PaddingConfig padding_config = MakeXlaPaddingConfig(
padding, input_shape, kernel_size, stride, ceil_mode);
xla::XlaOp padded_input =
xla::Pad(batch_input_info.batch_input, init_value, padding_config);
PoolingOpAttributes pooling_op_attributes =
Expand Down Expand Up @@ -489,8 +485,8 @@ xla::XlaOp BuildMaxPoolNdBackward(xla::XlaOp out_backprop, xla::XlaOp input,
MakePoolingOpAttributes(/*kernel_size_attr=*/kernel_size,
/*stride_attr=*/stride);
std::vector<std::pair<int64_t, int64_t>> window_padding;
const auto ceil_mode_padding = CeilModePadding(
padding, input_shape, kernel_size, stride, ceil_mode, false);
const auto ceil_mode_padding =
CeilModePadding(padding, input_shape, kernel_size, stride, ceil_mode);
window_padding.resize(2);
window_padding.insert(window_padding.end(), ceil_mode_padding.begin(),
ceil_mode_padding.end());
Expand Down Expand Up @@ -575,27 +571,15 @@ xla::XlaOp BuildAvgPoolNd(xla::XlaOp input, int64_t spatial_dim_count,
BatchInput batch_input_info = CreateBatchInput(input, spatial_dim_count);
const xla::Shape& input_shape =
ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input);

if (count_include_pad) {
xla::PaddingConfig padding_config = MakeXlaPaddingConfig(padding);
auto dtype = ShapeHelper::ShapeOfXlaOp(input).element_type();
auto padding_value = XlaHelpers::ScalarValue(0, dtype, input.builder());
batch_input_info.batch_input =
xla::Pad(batch_input_info.batch_input, padding_value, padding_config);
}

const auto ceil_mode_padding = CeilModePadding(
padding, ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input),
kernel_size, stride, ceil_mode, count_include_pad);

const auto ceil_mode_padding =
CeilModePadding(padding, input_shape, kernel_size, stride, ceil_mode);
xla::XlaOp batch_result = xla::AvgPool(
/*operand=*/batch_input_info.batch_input,
/*kernel_size=*/pooling_op_attributes.kernel_size,
/*stride=*/pooling_op_attributes.stride,
/*padding=*/ceil_mode_padding,
/*data_format=*/MakeNCHWFormat(spatial_dim_count),
/*counts_include_padding=*/false); // already compensated in XLA

/*counts_include_padding=*/count_include_pad);
return RemoveTrivialBatch(/*batch=*/batch_result,
/*original_rank=*/batch_input_info.original_rank,
/*spatial_dim_count=*/spatial_dim_count);
Expand All @@ -613,8 +597,8 @@ xla::XlaOp BuildAvgPoolNdBackward(xla::XlaOp out_backprop, xla::XlaOp input,
BatchInput batch_input_info = CreateBatchInput(input, spatial_dim_count);
const xla::Shape& gradients_shape =
ShapeHelper::ShapeOfXlaOp(batch_input_info.batch_input);
const auto ceil_mode_padding = CeilModePadding(
padding, gradients_shape, kernel_size, stride, ceil_mode, false);
const auto ceil_mode_padding =
CeilModePadding(padding, gradients_shape, kernel_size, stride, ceil_mode);
BatchInput batch_out_backprop_info =
CreateBatchInput(out_backprop, spatial_dim_count);
xla::XlaOp batch_result = xla::AvgPoolGrad(
Expand Down
6 changes: 2 additions & 4 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,15 +823,13 @@ XLATensorPtr avg_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
std::vector<int64_t> kernel_size,
std::vector<int64_t> stride,
std::vector<int64_t> padding, bool ceil_mode,
bool count_include_pad,
std::optional<int> divisor_override) {
bool count_include_pad) {
kernel_size = CheckIntList(kernel_size, spatial_dim_count, "kernel_size");
stride = CheckIntList(stride, spatial_dim_count, "stride", kernel_size);
padding = CheckIntList(padding, spatial_dim_count, "padding");
return input->CreateFrom(torch::lazy::MakeNode<AvgPoolNd>(
input->GetIrValue(), spatial_dim_count, std::move(kernel_size),
std::move(stride), std::move(padding), ceil_mode, count_include_pad,
divisor_override));
std::move(stride), std::move(padding), ceil_mode, count_include_pad));
}

XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
Expand Down
3 changes: 1 addition & 2 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,7 @@ XLATensorPtr avg_pool_nd(const XLATensorPtr& input, int64_t spatial_dim_count,
std::vector<int64_t> kernel_size,
std::vector<int64_t> stride,
std::vector<int64_t> padding, bool ceil_mode,
bool count_include_pad,
std::optional<int> divisor_override);
bool count_include_pad);

XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
const XLATensorPtr& input,
Expand Down

0 comments on commit ef753f4

Please sign in to comment.