diff --git a/test/run_tests.sh b/test/run_tests.sh index 712d1d94cd1..1553d53e409 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -141,8 +141,7 @@ function run_torch_op_tests { run_test "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA run_dynamic "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA run_dynamic "$CDIR/../../test/nn/test_dropout.py" "$@" -v TestDropoutNNDeviceTypeXLA - # TODO: Disable and foward fix, regression due to https://github.com/pytorch/xla/pull/6409 - #run_dynamic "$CDIR/../../test/nn/test_pooling.py" "$@" -v TestPoolingNNDeviceTypeXLA + run_dynamic "$CDIR/../../test/nn/test_pooling.py" "$@" -v TestPoolingNNDeviceTypeXLA run_dynamic "$CDIR/../../test/nn/test_embedding.py" "$@" -v TestEmbeddingNNDeviceTypeXLA run_dynamic "$CDIR/../../test/nn/test_convolution.py" "$@" -v TestConvolutionNNDeviceTypeXLA run_dynamic "$CDIR/../../test/nn/test_multihead_attention.py" "$@" -v TestMultiheadAttentionNNDeviceTypeXLA diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp index f6295cd527f..f3c47b34f36 100644 --- a/torch_xla/csrc/pooling.cpp +++ b/torch_xla/csrc/pooling.cpp @@ -144,7 +144,23 @@ std::vector> CeilModePadding( (input_size + 2 * left_padding - kernel_size[i]) % stride[i]; int64_t right_padding = left_padding; if (ceil_mode && output_size_rem != 0) { - right_padding += stride[i]; + int64_t extra_padding = stride[i] - output_size_rem; + int64_t new_output_size = + (input_size + left_padding + right_padding + extra_padding - + kernel_size[i] + stride[i] - 1) / + stride[i] + + 1; + // Ensure that the last pooling starts inside the image. + int64_t size_to_compare = input_size + left_padding; + if (count_include_pad) { + // here left padding is reset to 0; + // but input size already includes both left_padding and + // right padding so we need to substract padding[i] + size_to_compare = input_size - padding[i]; + } + if ((new_output_size - 1) * stride[i] < size_to_compare) { + right_padding += extra_padding; + } } ceil_mode_padding.emplace_back(left_padding, right_padding); }