Skip to content

Commit

Permalink
add gene eval
Browse files Browse the repository at this point in the history
  • Loading branch information
hy395 committed Jan 29, 2024
1 parent 9db2136 commit 6502e52
Show file tree
Hide file tree
Showing 7 changed files with 1,283 additions and 149 deletions.
54 changes: 52 additions & 2 deletions src/baskerville/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def conv_dna(
conv_type="standard",
kernel_initializer="he_normal",
padding="same",
transfer_se=False,
se_ratio=16,
):
"""Construct a single convolution block, assumed to be operating on DNA.
Expand Down Expand Up @@ -196,6 +198,18 @@ def conv_dna(
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
)(current)

# squeeze-excite for transfer
if transfer_se:
se_out = squeeze_excite(current,
activation=None,
additive=False,
bottleneck_ratio=se_ratio,
use_bias=False,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3),
scale_fun='tanh'
)
current = current + se_out

# squeeze-excite
if se:
current = squeeze_excite(current)
Expand Down Expand Up @@ -267,6 +281,8 @@ def conv_nac(
kernel_initializer="he_normal",
padding="same",
se=False,
transfer_se=False,
se_ratio=16,
):
"""Construct a single convolution block.
Expand Down Expand Up @@ -326,6 +342,18 @@ def conv_nac(
kernel_regularizer=tf.keras.regularizers.l2(l2_scale),
)(current)

# squeeze-excite for transfer
if transfer_se:
se_out = squeeze_excite(current,
activation=None,
additive=False,
bottleneck_ratio=se_ratio,
use_bias=False,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3),
scale_fun='tanh'
)
current = current + se_out

# squeeze-excite
if se:
current = squeeze_excite(current)
Expand Down Expand Up @@ -456,6 +484,8 @@ def fpn_unet(
bn_momentum=0.99,
kernel_size=1,
kernel_initializer="he_normal",
transfer_se=False,
se_ratio=16,
):
"""Construct a feature pyramid network block.
Expand Down Expand Up @@ -529,6 +559,17 @@ def fpn_unet(
kernel_initializer=kernel_initializer,
)(current)

if transfer_se:
se_out = squeeze_excite(current,
activation=None,
additive=False,
bottleneck_ratio=se_ratio,
use_bias=False,
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3),
scale_fun='tanh'
)
current = current + se_out

# dropout
if dropout > 0:
current = tf.keras.layers.Dropout(dropout)(current)
Expand Down Expand Up @@ -1528,11 +1569,20 @@ def squeeze_excite(
additive=False,
norm_type=None,
bn_momentum=0.9,
kernel_initializer='glorot_uniform',
use_bias=True,
scale_fun='sigmoid',
**kwargs,
):
return layers.SqueezeExcite(
activation, additive, bottleneck_ratio, norm_type, bn_momentum
)(inputs)
activation=activation,
additive=additive,
bottleneck_ratio=bottleneck_ratio,
norm_type=norm_type,
bn_momentum=bn_momentum,
kernel_initializer=kernel_initializer,
scale_fun=scale_fun,
use_bias=use_bias)(inputs)


def wheeze_excite(inputs, pool_size, **kwargs):
Expand Down
15 changes: 14 additions & 1 deletion src/baskerville/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ def __init__(
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
scale_fun='sigmoid',
):
super(SqueezeExcite, self).__init__()
self.activation = activation
Expand All @@ -766,6 +767,7 @@ def __init__(
self.kernel_initializer=kernel_initializer
self.bias_initializer=bias_initializer
self.use_bias=use_bias
self.scale_fun=scale_fun

def build(self, input_shape):
self.num_channels = input_shape[-1]
Expand All @@ -783,6 +785,17 @@ def build(self, input_shape):
)
exit(1)

if self.scale_fun=='sigmoid':
self.scale_f = tf.keras.activations.sigmoid
elif self.scale_fun=='tanh': # set to tanh for transfer
self.scale_f = tf.keras.activations.tanh
else:
print(
"scale function must be sigmoid or tanh",
file=sys.stderr,
)
exit(1)

self.dense1 = tf.keras.layers.Dense(
units=self.num_channels // self.bottleneck_ratio,
activation="relu",
Expand Down Expand Up @@ -819,7 +832,7 @@ def call(self, x):
if self.additive:
xs = x + excite
else:
excite = tf.keras.activations.sigmoid(excite)
excite = self.scale_f(excite)
xs = x * excite

return xs
Expand Down
Loading

0 comments on commit 6502e52

Please sign in to comment.