Skip to content

Commit

Permalink
unet block renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
davek44 committed Nov 3, 2023
1 parent c4f42eb commit 9db59ca
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 102 deletions.
105 changes: 12 additions & 93 deletions src/baskerville/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def conv_next(
return current


def fpn_unet(
def unet_conv(
inputs,
unet_repr,
activation="relu",
Expand All @@ -456,6 +456,7 @@ def fpn_unet(
bn_momentum=0.99,
kernel_size=1,
kernel_initializer="he_normal",
upsample_conv=False,
):
"""Construct a feature pyramid network block.
Expand All @@ -468,6 +469,7 @@ def fpn_unet(
dropout: Dropout rate probability
norm_type: Apply batch or layer normalization
bn_momentum: BatchNorm momentum
upsample_conv: Conv1D the upsampled input path
Returns:
[batch_size, seq_length, features] output sequence
Expand Down Expand Up @@ -499,11 +501,12 @@ def fpn_unet(
filters = inputs.shape[-1]

# dense
current1 = tf.keras.layers.Dense(
units=filters,
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
kernel_initializer=kernel_initializer,
)(current1)
if upsample_conv:
current1 = tf.keras.layers.Dense(
units=filters,
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
kernel_initializer=kernel_initializer,
)(current1)
current2 = tf.keras.layers.Dense(
units=filters,
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
Expand All @@ -514,7 +517,6 @@ def fpn_unet(
current1 = tf.keras.layers.UpSampling1D(size=stride)(current1)

# add
# current2 = layers.Scale(initializer='ones')(current2)
current = tf.keras.layers.Add()([current1, current2])

# normalize?
Expand All @@ -536,83 +538,7 @@ def fpn_unet(
return current


def fpn1_unet(
    inputs,
    unet_repr,
    activation="relu",
    stride=2,
    l2_scale=0,
    dropout=0,
    norm_type=None,
    bn_momentum=0.99,
    kernel_size=1,
    kernel_initializer="he_normal",
):
    """Construct a feature pyramid network (FPN) U-net merge block.

    Normalizes and activates the coarse input path, projects it to the
    skip connection's channel width, upsamples it, adds it to a
    learnably-scaled U-net skip representation, and refines the sum with
    a separable convolution.

    Args:
      inputs: [batch_size, seq_length, features] input sequence (coarse path)
      unet_repr: U-net skip representation to merge with; assumes shape
        [batch_size, seq_length*stride, unet_features] so the Add lines up
        after upsampling -- TODO confirm against caller
      kernel_size: SeparableConv1D kernel_size
      activation: relu/gelu/etc
      stride: UpSample stride
      l2_scale: L2 regularization weight.
      dropout: Dropout rate probability
      norm_type: Apply batch or layer normalization
      bn_momentum: BatchNorm momentum
      kernel_initializer: initializer for Dense/SeparableConv1D kernels
    Returns:
      [batch_size, seq_length*stride, unet_features] output sequence
    """

    # variables
    current1 = inputs  # coarse path (normalized, projected, upsampled)
    current2 = unet_repr  # fine skip path (scaled only)

    # normalize (coarse path only)
    if norm_type == "batch-sync":
        current1 = tf.keras.layers.BatchNormalization(
            momentum=bn_momentum, synchronized=True
        )(current1)
    elif norm_type == "batch":
        current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1)
    elif norm_type == "layer":
        current1 = tf.keras.layers.LayerNormalization()(current1)

    # activate (skip path deliberately left unactivated)
    current1 = layers.activate(current1, activation)
    # current2 = layers.activate(current2, activation)

    # dense: project coarse path to the skip connection's channel width
    current1 = tf.keras.layers.Dense(
        units=unet_repr.shape[-1],
        kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
        kernel_initializer=kernel_initializer,
    )(current1)

    # upsample coarse path to the skip connection's sequence length
    current1 = tf.keras.layers.UpSampling1D(size=stride)(current1)

    # add, with a learnable per-channel scale on the skip path
    current2 = layers.Scale(initializer="ones")(current2)
    current = tf.keras.layers.Add()([current1, current2])

    # convolution to refine the merged representation
    current = tf.keras.layers.SeparableConv1D(
        filters=unet_repr.shape[-1],
        kernel_size=kernel_size,
        padding="same",
        kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
        kernel_initializer=kernel_initializer,
    )(current)

    # dropout
    if dropout > 0:
        current = tf.keras.layers.Dropout(dropout)(current)

    return current


def upsample_unet(
def unet_concat(
inputs,
unet_repr,
activation="relu",
Expand Down Expand Up @@ -774,11 +700,6 @@ def tconv_nac(
return current


def concat_unet(inputs, unet_repr, **kwargs):
    """Merge a U-net skip representation by channel-wise concatenation.

    Args:
      inputs: [batch_size, seq_length, features] input sequence
      unet_repr: skip representation; assumes matching batch and
        seq_length so the concat is valid -- TODO confirm
      **kwargs: unused; accepted so the block matches the common
        U-net block signature

    Returns:
      [batch_size, seq_length, features + unet_features] sequence
    """
    return tf.keras.layers.Concatenate()([inputs, unet_repr])


def conv_block_2d(
inputs,
filters=128,
Expand Down Expand Up @@ -2040,7 +1961,6 @@ def final(
"center_average": center_average,
"concat_dist_2d": concat_dist_2d,
"concat_position": concat_position,
"concat_unet": concat_unet,
"conv_block": conv_block,
"conv_dna": conv_dna,
"conv_nac": conv_nac,
Expand All @@ -2067,10 +1987,9 @@ def final(
"tconv_nac": tconv_nac,
"transformer": transformer,
"transformer_tower": transformer_tower,
"unet_conv": unet_conv,
"unet_concat": unet_concat,
"upper_tri": upper_tri,
"fpn_unet": fpn_unet,
"fpn1_unet": fpn1_unet,
"upsample_unet": upsample_unet,
"wheeze_excite": wheeze_excite,
}

Expand Down
2 changes: 1 addition & 1 deletion src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def build_block(self, current, block_params):
block_args["reprs"] = self.reprs

# U-net helper
if block_name[-5:] == "_unet":
if block_name.startswith("unet_"):
# find matching representation
unet_repr = None
for seq_repr in reversed(self.reprs[:-1]):
Expand Down
10 changes: 6 additions & 4 deletions tests/data/eval/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@
"repeat": 2
},
{
"name": "fpn_unet",
"kernel_size": 3
"name": "unet_conv",
"kernel_size": 3,
"upsample_conv": true
},
{
"name": "fpn_unet",
"kernel_size": 3
"name": "unet_conv",
"kernel_size": 3,
"upsample_conv": true
},
{
"name": "Cropping1D",
Expand Down
10 changes: 6 additions & 4 deletions tests/data/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@
"repeat": 1
},
{
"name": "fpn_unet",
"kernel_size": 3
"name": "unet_conv",
"kernel_size": 3,
"upsample_conv": true
},
{
"name": "fpn_unet",
"kernel_size": 3
"name": "unet_conv",
"kernel_size": 3,
"upsample_conv": true
},
{
"name": "Cropping1D",
Expand Down

0 comments on commit 9db59ca

Please sign in to comment.