diff --git a/official/vision/beta/configs/decoders.py b/official/vision/beta/configs/decoders.py
index 4afaca1a073..0ba2873c875 100644
--- a/official/vision/beta/configs/decoders.py
+++ b/official/vision/beta/configs/decoders.py
@@ -32,6 +32,7 @@ class Identity(hyperparams.Config):
 class FPN(hyperparams.Config):
   """FPN config."""
   num_filters: int = 256
+  fusion_type: str = 'sum'
   use_separable_conv: bool = False
diff --git a/official/vision/beta/modeling/decoders/fpn.py b/official/vision/beta/modeling/decoders/fpn.py
index 9d16b383603..aa2c79c6d18 100644
--- a/official/vision/beta/modeling/decoders/fpn.py
+++ b/official/vision/beta/modeling/decoders/fpn.py
@@ -42,6 +42,7 @@ def __init__(
       min_level: int = 3,
       max_level: int = 7,
       num_filters: int = 256,
+      fusion_type: str = 'sum',
       use_separable_conv: bool = False,
       activation: str = 'relu',
       use_sync_bn: bool = False,
@@ -59,6 +60,8 @@ def __init__(
       min_level: An `int` of minimum level in FPN output feature maps.
       max_level: An `int` of maximum level in FPN output feature maps.
       num_filters: An `int` number of filters in FPN layers.
+      fusion_type: A `str` of `sum` or `concat`. Whether performing sum or
+        concat for feature fusion.
       use_separable_conv: A `bool`. If True use separable convolution for
         convolution in FPN layers.
       activation: A `str` name of the activation function.
@@ -77,6 +80,7 @@ def __init__(
         'min_level': min_level,
         'max_level': max_level,
         'num_filters': num_filters,
+        'fusion_type': fusion_type,
         'use_separable_conv': use_separable_conv,
         'activation': activation,
         'use_sync_bn': use_sync_bn,
@@ -122,8 +126,16 @@ def __init__(
     # Build top-down path.
     feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
     for level in range(backbone_max_level - 1, min_level - 1, -1):
-      feats[str(level)] = spatial_transform_ops.nearest_upsampling(
-          feats[str(level + 1)], 2) + feats_lateral[str(level)]
+      feat_a = spatial_transform_ops.nearest_upsampling(
+          feats[str(level + 1)], 2)
+      feat_b = feats_lateral[str(level)]
+
+      if fusion_type == 'sum':
+        feats[str(level)] = feat_a + feat_b
+      elif fusion_type == 'concat':
+        feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)
+      else:
+        raise ValueError('Fusion type {} not supported.'.format(fusion_type))
 
     # TODO(xianzhi): consider to remove bias in conv2d.
     # Build post-hoc 3x3 convolution kernel.
@@ -224,6 +236,7 @@ def build_fpn_decoder(
       min_level=model_config.min_level,
       max_level=model_config.max_level,
       num_filters=decoder_cfg.num_filters,
+      fusion_type=decoder_cfg.fusion_type,
       use_separable_conv=decoder_cfg.use_separable_conv,
       activation=norm_activation_config.activation,
       use_sync_bn=norm_activation_config.use_sync_bn,
diff --git a/official/vision/beta/modeling/decoders/fpn_test.py b/official/vision/beta/modeling/decoders/fpn_test.py
index dc4aeeb2ebb..1aef30011ab 100644
--- a/official/vision/beta/modeling/decoders/fpn_test.py
+++ b/official/vision/beta/modeling/decoders/fpn_test.py
@@ -27,11 +27,11 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.parameters(
-      (256, 3, 7, False),
-      (256, 3, 7, True),
+      (256, 3, 7, False, 'sum'),
+      (256, 3, 7, True, 'concat'),
   )
   def test_network_creation(self, input_size, min_level, max_level,
-                            use_separable_conv):
+                            use_separable_conv, fusion_type):
     """Test creation of FPN."""
     tf.keras.backend.set_image_data_format('channels_last')
@@ -42,6 +42,7 @@ def test_network_creation(self, input_size, min_level, max_level,
         input_specs=backbone.output_specs,
         min_level=min_level,
         max_level=max_level,
+        fusion_type=fusion_type,
         use_separable_conv=use_separable_conv)
 
     endpoints = backbone(inputs)
@@ -87,6 +88,7 @@ def test_serialize_deserialize(self):
         min_level=3,
         max_level=7,
         num_filters=256,
+        fusion_type='sum',
         use_separable_conv=False,
         use_sync_bn=False,
         activation='relu',
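Usage note (not part of the diff): a minimal sketch of how the new fusion_type argument could be exercised, assuming the resnet backbone and fpn decoder modules from official/vision/beta that this change already touches. With fusion_type='concat', the upsampled top-down feature and the lateral feature are concatenated along the channel axis instead of summed, so the fused maps carry twice the channels before the post-hoc 3x3 convolutions.

# Hedged sketch; the input shape and backbone choice are illustrative only.
import tensorflow as tf

from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import fpn

backbone = resnet.ResNet(model_id=50)
decoder = fpn.FPN(
    input_specs=backbone.output_specs,
    min_level=3,
    max_level=7,
    num_filters=256,
    fusion_type='concat',  # new option; 'sum' keeps the previous behavior
    use_separable_conv=False)

inputs = tf.keras.Input(shape=(256, 256, 3))
endpoints = backbone(inputs)      # dict of multilevel backbone features
feature_maps = decoder(endpoints)  # dict keyed by level: '3' ... '7'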