diff --git a/src/baskerville/blocks.py b/src/baskerville/blocks.py index 1ba7028..3330ea0 100644 --- a/src/baskerville/blocks.py +++ b/src/baskerville/blocks.py @@ -445,7 +445,7 @@ def conv_next( return current -def fpn_unet( +def unet_conv( inputs, unet_repr, activation="relu", @@ -456,6 +456,7 @@ def fpn_unet( bn_momentum=0.99, kernel_size=1, kernel_initializer="he_normal", + upsample_conv=False, ): """Construct a feature pyramid network block. @@ -468,6 +469,7 @@ def fpn_unet( dropout: Dropout rate probability norm_type: Apply batch or layer normalization bn_momentum: BatchNorm momentum + upsample_conv: Conv1D the upsampled input path Returns: [batch_size, seq_length, features] output sequence @@ -499,11 +501,12 @@ def fpn_unet( filters = inputs.shape[-1] # dense - current1 = tf.keras.layers.Dense( - units=filters, - kernel_regularizer=tf.keras.regularizers.l2(l2_scale), - kernel_initializer=kernel_initializer, - )(current1) + if upsample_conv: + current1 = tf.keras.layers.Dense( + units=filters, + kernel_regularizer=tf.keras.regularizers.l2(l2_scale), + kernel_initializer=kernel_initializer, + )(current1) current2 = tf.keras.layers.Dense( units=filters, kernel_regularizer=tf.keras.regularizers.l2(l2_scale), @@ -514,7 +517,6 @@ def fpn_unet( current1 = tf.keras.layers.UpSampling1D(size=stride)(current1) # add - # current2 = layers.Scale(initializer='ones')(current2) current = tf.keras.layers.Add()([current1, current2]) # normalize? @@ -536,83 +538,7 @@ def fpn_unet( return current -def fpn1_unet( - inputs, - unet_repr, - activation="relu", - stride=2, - l2_scale=0, - dropout=0, - norm_type=None, - bn_momentum=0.99, - kernel_size=1, - kernel_initializer="he_normal", -): - """Construct a feature pyramid network block. - - Args: - inputs: [batch_size, seq_length, features] input sequence - kernel_size: Conv1D kernel_size - activation: relu/gelu/etc - stride: UpSample stride - l2_scale: L2 regularization weight. - dropout: Dropout rate probability - norm_type: Apply batch or layer normalization - bn_momentum: BatchNorm momentum - - Returns: - [batch_size, seq_length, features] output sequence - """ - - # variables - current1 = inputs - current2 = unet_repr - - # normalize - if norm_type == "batch-sync": - current1 = tf.keras.layers.BatchNormalization( - momentum=bn_momentum, synchronized=True - )(current1) - elif norm_type == "batch": - current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1) - elif norm_type == "layer": - current1 = tf.keras.layers.LayerNormalization()(current1) - - # activate - current1 = layers.activate(current1, activation) - # current2 = layers.activate(current2, activation) - - # dense - current1 = tf.keras.layers.Dense( - units=unet_repr.shape[-1], - kernel_regularizer=tf.keras.regularizers.l2(l2_scale), - kernel_initializer=kernel_initializer, - )(current1) - - # upsample - current1 = tf.keras.layers.UpSampling1D(size=stride)(current1) - - # add - current2 = layers.Scale(initializer="ones")(current2) - current = tf.keras.layers.Add()([current1, current2]) - - # convolution - current = tf.keras.layers.SeparableConv1D( - filters=unet_repr.shape[-1], - kernel_size=kernel_size, - padding="same", - kernel_regularizer=tf.keras.regularizers.l2(l2_scale), - kernel_initializer=kernel_initializer, - )(current) - - # dropout - if dropout > 0: - current = tf.keras.layers.Dropout(dropout)(current) - - return current - - -def upsample_unet( +def unet_concat( inputs, unet_repr, activation="relu", @@ -774,11 +700,6 @@ def tconv_nac( return current -def concat_unet(inputs, unet_repr, **kwargs): - current = tf.keras.layers.Concatenate()([inputs, unet_repr]) - return current - - def conv_block_2d( inputs, filters=128, @@ -2040,7 +1961,6 @@ def final( "center_average": center_average, "concat_dist_2d": concat_dist_2d, "concat_position": concat_position, - "concat_unet": concat_unet, "conv_block": conv_block, "conv_dna": conv_dna, "conv_nac": conv_nac, @@ -2067,10 +1987,9 @@ def final( "tconv_nac": tconv_nac, "transformer": transformer, "transformer_tower": transformer_tower, + "unet_conv": unet_conv, + "unet_concat": unet_concat, "upper_tri": upper_tri, - "fpn_unet": fpn_unet, - "fpn1_unet": fpn1_unet, - "upsample_unet": upsample_unet, "wheeze_excite": wheeze_excite, } diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 36377fd..716bb7b 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -99,7 +99,7 @@ def build_block(self, current, block_params): block_args["reprs"] = self.reprs # U-net helper - if block_name[-5:] == "_unet": + if block_name.startswith("unet_"): # find matching representation unet_repr = None for seq_repr in reversed(self.reprs[:-1]): diff --git a/tests/data/eval/params.json b/tests/data/eval/params.json index 7387223..2decd6f 100644 --- a/tests/data/eval/params.json +++ b/tests/data/eval/params.json @@ -54,12 +54,14 @@ "repeat": 2 }, { - "name": "fpn_unet", - "kernel_size": 3 + "name": "unet_conv", + "kernel_size": 3, + "upsample_conv": true }, { - "name": "fpn_unet", - "kernel_size": 3 + "name": "unet_conv", + "kernel_size": 3, + "upsample_conv": true }, { "name": "Cropping1D", diff --git a/tests/data/params.json b/tests/data/params.json index 24d3d0a..34ac510 100644 --- a/tests/data/params.json +++ b/tests/data/params.json @@ -54,12 +54,14 @@ "repeat": 1 }, { - "name": "fpn_unet", - "kernel_size": 3 + "name": "unet_conv", + "kernel_size": 3, + "upsample_conv": true }, { - "name": "fpn_unet", - "kernel_size": 3 + "name": "unet_conv", + "kernel_size": 3, + "upsample_conv": true }, { "name": "Cropping1D",