Commit
Added ResNet model definitions with batchnorm set to inference mode.
jlinder2 committed Aug 5, 2022
1 parent 4dc24e5 commit b427e89
Showing 2 changed files with 384 additions and 0 deletions.
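
Both files define the same plasmid ResNet; the only functional difference is in the second file's res block, where the two BatchNormalization calls inside _resblock_func are pinned to inference mode so the layers always normalize with their moving statistics rather than the current batch's statistics:

batch_norm_0_out = batch_norm_0(input_tensor, training=False)
batch_norm_1_out = batch_norm_1(conv_0_out, training=False)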
192 changes: 192 additions & 0 deletions aparent/model/aparent_model_plasmid_resnet_var_batch_size.py
@@ -0,0 +1,192 @@
from __future__ import print_function
import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, LocallyConnected2D, Conv1D, MaxPooling1D, LocallyConnected1D, LSTM, ConvLSTM2D, BatchNormalization
from keras.layers import Concatenate, Reshape
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import regularizers
from keras import backend as K
import keras.losses

import tensorflow as tf

import isolearn.keras as iso

from aparent.losses import *

def make_resblock(n_channels=64, window_size=8, dilation_rate=1, group_ix=0, layer_ix=0, drop_rate=0.0) :

    #Initialize res block layers
    batch_norm_0 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_0')

    relu_0 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_0 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_0')

    batch_norm_1 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_1')

    relu_1 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_1')

    skip_1 = Lambda(lambda x: x[0] + x[1], name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_skip_1')

    drop_1 = None
    if drop_rate > 0.0 :
        drop_1 = Dropout(drop_rate)

    #Execute res block
    def _resblock_func(input_tensor) :
        batch_norm_0_out = batch_norm_0(input_tensor)
        relu_0_out = relu_0(batch_norm_0_out)
        conv_0_out = conv_0(relu_0_out)

        batch_norm_1_out = batch_norm_1(conv_0_out)
        relu_1_out = relu_1(batch_norm_1_out)

        if drop_rate > 0.0 :
            conv_1_out = drop_1(conv_1(relu_1_out))
        else :
            conv_1_out = conv_1(relu_1_out)

        skip_1_out = skip_1([conv_1_out, input_tensor])

        return skip_1_out

    return _resblock_func

def load_residual_network(n_groups=1, n_resblocks_per_group=4, n_channels=32, window_size=8, dilation_rates=[1], drop_rate=0.0) :

    #Residual network definition
    conv_0 = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_conv_0')

    skip_convs = []
    resblock_groups = []
    for group_ix in range(n_groups) :

        skip_convs.append(Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_skip_conv_' + str(group_ix)))

        resblocks = []
        for layer_ix in range(n_resblocks_per_group) :
            resblocks.append(make_resblock(n_channels=n_channels, window_size=window_size, dilation_rate=dilation_rates[group_ix], group_ix=group_ix, layer_ix=layer_ix, drop_rate=drop_rate))

        resblock_groups.append(resblocks)

    last_block_conv = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_last_block_conv')

    skip_add = Lambda(lambda x: x[0] + x[1], name='aparent_skip_add')

    final_conv = Conv2D(1, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_final_conv')

    extend_tensor = Lambda(lambda x: K.concatenate([x, K.zeros((K.shape(x)[0], 1, 1, 1))], axis=2), name='aparent_extend_tensor')

    expand_lib = Lambda(lambda x: K.tile(K.expand_dims(K.expand_dims(x, axis=1), axis=2), (1, 1, 206, 1)), name='aparent_expand_lib')

    lib_conv = LocallyConnected2D(1, (1, 1), strides=(1, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='aparent_lib_conv')

    lib_add = Lambda(lambda x: x[0] + x[1], name='aparent_lib_add')

    def _net_func(sequence_input, lib_input) :
        conv_0_out = conv_0(sequence_input)

        #Connect the groups of res blocks
        output_tensor = conv_0_out

        #Res block group execution
        skip_conv_outs = []
        for group_ix in range(n_groups) :
            skip_conv_out = skip_convs[group_ix](output_tensor)
            skip_conv_outs.append(skip_conv_out)

            for layer_ix in range(n_resblocks_per_group) :
                output_tensor = resblock_groups[group_ix][layer_ix](output_tensor)

        #Last res block extra conv
        last_block_conv_out = last_block_conv(output_tensor)

        skip_add_out = last_block_conv_out
        for group_ix in range(n_groups) :
            skip_add_out = skip_add([skip_add_out, skip_conv_outs[group_ix]])

        #Final conv out
        final_conv_out = extend_tensor(final_conv(skip_add_out))

        #Add library bias
        lib_conv_out = lib_add([final_conv_out, lib_conv(expand_lib(lib_input))])

        return lib_conv_out

    return _net_func

def load_aparent_model(batch_size, use_sample_weights=False, drop_rate=0.25) :

    #APARENT parameters
    seq_input_shape = (1, 205, 4)
    lib_input_shape = (13,)
    num_outputs_iso = 1
    num_outputs_cut = 206

    #Plasmid model definition

    #Resnet function
    resnet = load_residual_network(
        n_groups=7,
        n_resblocks_per_group=4,
        n_channels=32,
        window_size=3,
        dilation_rates=[1, 2, 4, 8, 4, 2, 1],
        drop_rate=drop_rate
    )

    #Inputs
    seq_input = Input(shape=seq_input_shape)
    lib_input = Input(shape=lib_input_shape)
    plasmid_count = Input(shape=(1,))

    #Target outputs (fed as inputs to the loss model)
    true_iso = Input(shape=(num_outputs_iso,))
    true_cut = Input(shape=(num_outputs_cut,))

    cut_score = resnet(seq_input, lib_input)

    cut_prob = Lambda(lambda x: K.softmax(x[:, 0, :, 0], axis=-1))(cut_score)
    iso_prob = Lambda(lambda cl: K.expand_dims(K.sum(cl[:, 80:80+30], axis=-1), axis=-1))(cut_prob)

    plasmid_model = Model(
        inputs=[
            seq_input,
            lib_input
        ],
        outputs=[
            iso_prob,
            cut_prob
        ]
    )

    #Loss model definition
    sigmoid_kl_divergence = get_sigmoid_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
    kl_divergence = get_kl_divergence(batch_size, use_sample_weights=use_sample_weights)

    plasmid_loss_iso = Lambda(sigmoid_kl_divergence, output_shape=(1,))([true_iso, iso_prob, plasmid_count])
    plasmid_loss_cut = Lambda(kl_divergence, output_shape=(1,))([true_cut, cut_prob, plasmid_count])

    total_loss = Lambda(
        lambda l: 0.5 * l[0] + 0.5 * l[1],
        output_shape=(1,)
    )(
        [
            plasmid_loss_iso,
            plasmid_loss_cut
        ]
    )

    loss_model = Model([
        seq_input,
        lib_input,
        plasmid_count,
        true_iso,
        true_cut
    ], total_loss)

    return [ ('plasmid_iso_cut_resnet', plasmid_model), ('loss', loss_model) ]
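
The loader returns (name, model) pairs for both the prediction network and the in-graph loss network. A minimal sketch of one way they might be wired up for training, assuming placeholder data arrays and an illustrative optimizer choice (neither is part of this commit):

import numpy as np
import keras

batch_size = 32
(_, plasmid_model), (_, loss_model) = load_aparent_model(batch_size)

#The loss is computed inside the graph, so compile with a pass-through
#objective that simply returns the model output
loss_model.compile(
    optimizer=keras.optimizers.Adam(lr=0.0001),
    loss=lambda y_true, y_pred: y_pred
)

n = batch_size * 100
seq = np.zeros((n, 1, 205, 4))   #one-hot sequences (hypothetical placeholder)
lib = np.zeros((n, 13))          #library indicators
counts = np.ones((n, 1))         #plasmid read counts
iso = np.zeros((n, 1))           #measured isoform proportions
cut = np.zeros((n, 206))         #measured cut distributions
dummy = np.zeros((n, 1))         #ignored by the pass-through loss

#The KL losses were constructed for a fixed batch size, so fit with that size
loss_model.fit([seq, lib, counts, iso, cut], dummy, batch_size=batch_size, epochs=1)

#The prediction model shares its weights with the loss model
iso_pred, cut_pred = plasmid_model.predict([seq[:batch_size], lib[:batch_size]], batch_size=batch_size)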
192 changes: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
from __future__ import print_function
import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, LocallyConnected2D, Conv1D, MaxPooling1D, LocallyConnected1D, LSTM, ConvLSTM2D, BatchNormalization
from keras.layers import Concatenate, Reshape
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import regularizers
from keras import backend as K
import keras.losses

import tensorflow as tf

import isolearn.keras as iso

from aparent.losses import *

def make_resblock(n_channels=64, window_size=8, dilation_rate=1, group_ix=0, layer_ix=0, drop_rate=0.0) :

    #Initialize res block layers
    batch_norm_0 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_0')

    relu_0 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_0 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_0')

    batch_norm_1 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_1')

    relu_1 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_1')

    skip_1 = Lambda(lambda x: x[0] + x[1], name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_skip_1')

    drop_1 = None
    if drop_rate > 0.0 :
        drop_1 = Dropout(drop_rate)

    #Execute res block (batch norm pinned to inference mode: the layers always
    #normalize with their moving statistics, even during training)
    def _resblock_func(input_tensor) :
        batch_norm_0_out = batch_norm_0(input_tensor, training=False)
        relu_0_out = relu_0(batch_norm_0_out)
        conv_0_out = conv_0(relu_0_out)

        batch_norm_1_out = batch_norm_1(conv_0_out, training=False)
        relu_1_out = relu_1(batch_norm_1_out)

        if drop_rate > 0.0 :
            conv_1_out = drop_1(conv_1(relu_1_out))
        else :
            conv_1_out = conv_1(relu_1_out)

        skip_1_out = skip_1([conv_1_out, input_tensor])

        return skip_1_out

    return _resblock_func

def load_residual_network(n_groups=1, n_resblocks_per_group=4, n_channels=32, window_size=8, dilation_rates=[1], drop_rate=0.0) :

    #Residual network definition
    conv_0 = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_conv_0')

    skip_convs = []
    resblock_groups = []
    for group_ix in range(n_groups) :

        skip_convs.append(Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_skip_conv_' + str(group_ix)))

        resblocks = []
        for layer_ix in range(n_resblocks_per_group) :
            resblocks.append(make_resblock(n_channels=n_channels, window_size=window_size, dilation_rate=dilation_rates[group_ix], group_ix=group_ix, layer_ix=layer_ix, drop_rate=drop_rate))

        resblock_groups.append(resblocks)

    last_block_conv = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_last_block_conv')

    skip_add = Lambda(lambda x: x[0] + x[1], name='aparent_skip_add')

    final_conv = Conv2D(1, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_final_conv')

    extend_tensor = Lambda(lambda x: K.concatenate([x, K.zeros((K.shape(x)[0], 1, 1, 1))], axis=2), name='aparent_extend_tensor')

    expand_lib = Lambda(lambda x: K.tile(K.expand_dims(K.expand_dims(x, axis=1), axis=2), (1, 1, 206, 1)), name='aparent_expand_lib')

    lib_conv = LocallyConnected2D(1, (1, 1), strides=(1, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='aparent_lib_conv')

    lib_add = Lambda(lambda x: x[0] + x[1], name='aparent_lib_add')

    def _net_func(sequence_input, lib_input) :
        conv_0_out = conv_0(sequence_input)

        #Connect the groups of res blocks
        output_tensor = conv_0_out

        #Res block group execution
        skip_conv_outs = []
        for group_ix in range(n_groups) :
            skip_conv_out = skip_convs[group_ix](output_tensor)
            skip_conv_outs.append(skip_conv_out)

            for layer_ix in range(n_resblocks_per_group) :
                output_tensor = resblock_groups[group_ix][layer_ix](output_tensor)

        #Last res block extra conv
        last_block_conv_out = last_block_conv(output_tensor)

        skip_add_out = last_block_conv_out
        for group_ix in range(n_groups) :
            skip_add_out = skip_add([skip_add_out, skip_conv_outs[group_ix]])

        #Final conv out
        final_conv_out = extend_tensor(final_conv(skip_add_out))

        #Add library bias
        lib_conv_out = lib_add([final_conv_out, lib_conv(expand_lib(lib_input))])

        return lib_conv_out

    return _net_func

def load_aparent_model(batch_size, use_sample_weights=False, drop_rate=0.25) :

    #APARENT parameters
    seq_input_shape = (1, 205, 4)
    lib_input_shape = (13,)
    num_outputs_iso = 1
    num_outputs_cut = 206

    #Plasmid model definition

    #Resnet function
    resnet = load_residual_network(
        n_groups=7,
        n_resblocks_per_group=4,
        n_channels=32,
        window_size=3,
        dilation_rates=[1, 2, 4, 8, 4, 2, 1],
        drop_rate=drop_rate
    )

    #Inputs
    seq_input = Input(shape=seq_input_shape)
    lib_input = Input(shape=lib_input_shape)
    plasmid_count = Input(shape=(1,))

    #Target outputs (fed as inputs to the loss model)
    true_iso = Input(shape=(num_outputs_iso,))
    true_cut = Input(shape=(num_outputs_cut,))

    cut_score = resnet(seq_input, lib_input)

    cut_prob = Lambda(lambda x: K.softmax(x[:, 0, :, 0], axis=-1))(cut_score)
    iso_prob = Lambda(lambda cl: K.expand_dims(K.sum(cl[:, 80:80+30], axis=-1), axis=-1))(cut_prob)

    plasmid_model = Model(
        inputs=[
            seq_input,
            lib_input
        ],
        outputs=[
            iso_prob,
            cut_prob
        ]
    )

    #Loss model definition
    sigmoid_kl_divergence = get_sigmoid_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
    kl_divergence = get_kl_divergence(batch_size, use_sample_weights=use_sample_weights)

    plasmid_loss_iso = Lambda(sigmoid_kl_divergence, output_shape=(1,))([true_iso, iso_prob, plasmid_count])
    plasmid_loss_cut = Lambda(kl_divergence, output_shape=(1,))([true_cut, cut_prob, plasmid_count])

    total_loss = Lambda(
        lambda l: 0.5 * l[0] + 0.5 * l[1],
        output_shape=(1,)
    )(
        [
            plasmid_loss_iso,
            plasmid_loss_cut
        ]
    )

    loss_model = Model([
        seq_input,
        lib_input,
        plasmid_count,
        true_iso,
        true_cut
    ], total_loss)

    return [ ('plasmid_iso_cut_resnet', plasmid_model), ('loss', loss_model) ]
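
The effect of pinning batch norm to inference mode can be seen on a single layer. A minimal sketch, assuming a Keras 2.x build on the TF 1.x graph backend (matching the imports above), where the learning phase is fed as an extra input:

import numpy as np
from keras.layers import Input, BatchNormalization
from keras import backend as K

inp = Input(shape=(4,))
bn = BatchNormalization()
out_dynamic = bn(inp)                 #follows the Keras learning phase
out_frozen = bn(inp, training=False)  #pinned to inference statistics

f = K.function([inp, K.learning_phase()], [out_dynamic, out_frozen])
data = (np.random.randn(32, 4) * 5.0 + 3.0).astype('float32')

dyn, frz = f([data, 1])  #learning phase 1 = training
#In the training phase the dynamic output is normalized with the current
#batch statistics (mean near 0), while the frozen output still uses the
#moving averages (mean 0, variance 1 at initialization) and leaves the
#data roughly unchanged
print(dyn.mean(), frz.mean())  #expect roughly 0.0 vs 3.0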
