From 586876e3a433346bd187c5f3839a66cc31bda1fe Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Wed, 18 Sep 2024 15:31:10 -0700 Subject: [PATCH 1/4] Fix "same" padding torch issue --- keras/src/backend/torch/nn.py | 2 +- keras/src/layers/pooling/average_pooling_test.py | 1 + keras/src/ops/nn_test.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index de931db47d4..4d38401d375 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -174,7 +174,7 @@ def _apply_same_padding( spatial_shape[i], kernel_size[i], strides[i], dilation_rate[i] ) mode = "constant" - padding = (padding_size,) + padding + padding = padding + (padding_size,) if all([left == right for left, right in padding]): return inputs, [left for left, _ in padding] diff --git a/keras/src/layers/pooling/average_pooling_test.py b/keras/src/layers/pooling/average_pooling_test.py index 3e56cfdadf2..02bbdd30198 100644 --- a/keras/src/layers/pooling/average_pooling_test.py +++ b/keras/src/layers/pooling/average_pooling_test.py @@ -174,6 +174,7 @@ def test_average_pooling1d( (2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)), ((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)), ((2, 3), (2, 2), "same", "channels_last", (3, 5, 5, 4), (3, 3, 3, 4)), + ((2, 3), (3, 3), "same", "channels_first", (3, 5, 5, 4), (3, 5, 2, 2)), ) def test_average_pooling2d( self, diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index 4d4262e830f..d001d68667c 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -1381,6 +1381,16 @@ def test_average_pool_same_padding(self): knn.average_pool(x, 2, (2, 1), padding="same"), np_avgpool2d(x, 2, (2, 1), padding="same", data_format=data_format), ) + # Test 2D average pooling with different pool size. + if data_format == "channels_last": + input_shape = (2, 10, 9, 3) + else: + input_shape = (2, 3, 10, 9) + x = np.arange(540, dtype=float).reshape(input_shape) + self.assertAllClose( + knn.average_pool(x, [2, 3], (3, 3), padding="same"), + np_avgpool2d(x, [2, 3], (3, 3), padding="same", data_format=data_format), + ) @parameterized.product( strides=(1, 2, 3), From 36ea062729e349daf2776740f718c4de3312803c Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Wed, 18 Sep 2024 15:38:37 -0700 Subject: [PATCH 2/4] format --- keras/src/ops/nn_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index d001d68667c..bbeda79bdba 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -1389,7 +1389,9 @@ def test_average_pool_same_padding(self): x = np.arange(540, dtype=float).reshape(input_shape) self.assertAllClose( knn.average_pool(x, [2, 3], (3, 3), padding="same"), - np_avgpool2d(x, [2, 3], (3, 3), padding="same", data_format=data_format), + np_avgpool2d( + x, [2, 3], (3, 3), padding="same", data_format=data_format + ), ) @parameterized.product( From a6614f4ba1c31f4aa179b02ef32b48dd666ca7dd Mon Sep 17 00:00:00 2001 From: sachin prasad Date: Wed, 18 Sep 2024 15:47:13 -0700 Subject: [PATCH 3/4] fix type --- keras/src/ops/nn_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py index bbeda79bdba..0eededaf0bd 100644 --- a/keras/src/ops/nn_test.py +++ b/keras/src/ops/nn_test.py @@ -1388,9 +1388,9 @@ def test_average_pool_same_padding(self): input_shape = (2, 3, 10, 9) x = np.arange(540, dtype=float).reshape(input_shape) self.assertAllClose( - knn.average_pool(x, [2, 3], (3, 3), padding="same"), + knn.average_pool(x, (2, 3), (3, 3), padding="same"), np_avgpool2d( - x, [2, 3], (3, 3), padding="same", data_format=data_format + x, (2, 3), (3, 3), padding="same", data_format=data_format ), ) From b5e23e2eeee0b80aa66c5930a62876d2aeb738f4 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Tue, 24 Sep 2024 16:11:32 -0700 Subject: [PATCH 4/4] add condition for channels first and last --- keras/src/backend/torch/nn.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index 4d38401d375..5051e5ed5cf 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -141,7 +141,7 @@ def _compute_padding_length( def _apply_same_padding( - inputs, kernel_size, strides, operation_type, dilation_rate=1 + inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1 ): """Apply same padding to the input tensor. @@ -174,7 +174,10 @@ def _apply_same_padding( spatial_shape[i], kernel_size[i], strides[i], dilation_rate[i] ) mode = "constant" - padding = padding + (padding_size,) + if data_format == "channels_last": + padding = (padding_size,) + padding + else: + padding = padding + (padding_size,) if all([left == right for left, right in padding]): return inputs, [left for left, _ in padding] @@ -252,7 +255,7 @@ def max_pool( # Torch does not natively support `"same"` padding, we need to manually # apply the right amount of padding to `inputs`. inputs, padding = _apply_same_padding( - inputs, pool_size, strides, operation_type="pooling" + inputs, pool_size, strides, data_format, operation_type="pooling" ) else: padding = 0 @@ -312,7 +315,7 @@ def average_pool( # Torch does not natively support `"same"` padding, we need to manually # apply the right amount of padding to `inputs`. inputs, padding = _apply_same_padding( - inputs, pool_size, strides, operation_type="pooling" + inputs, pool_size, strides, data_format, operation_type="pooling" ) else: padding = 0 @@ -377,6 +380,7 @@ def conv( inputs, kernel.shape[2:], strides, + data_format, operation_type="conv", dilation_rate=dilation_rate, )