Commit
Added ResNet model definitions with batchnorm set to inference mode.
Showing 2 changed files with 384 additions and 0 deletions.
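The functional difference between the two files in this commit is that the second file calls every BatchNormalization layer with training=False, pinning batch norm to inference mode: the layer then normalizes with its stored moving mean and variance rather than the statistics of the current batch. A minimal standalone sketch of that behavior (illustrative only, not part of the commit):

from keras.layers import Input, BatchNormalization
from keras.models import Model

x = Input(shape=(8,))
bn = BatchNormalization()

y_phase = bn(x)                   #follows the Keras learning phase (batch stats in training)
y_frozen = bn(x, training=False)  #always uses the stored moving mean/variance

#Both outputs share the same layer weights; only the normalization statistics differ.
model = Model(x, [y_phase, y_frozen])

Freezing the statistics this way is a common choice when fine-tuning a pretrained network with small or variable batch sizes, where per-batch statistics would be noisy.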
aparent/model/aparent_model_plasmid_resnet_var_batch_size.py (192 additions, 0 deletions)
@@ -0,0 +1,192 @@
from __future__ import print_function
import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, LocallyConnected2D, Conv1D, MaxPooling1D, LocallyConnected1D, LSTM, ConvLSTM2D, BatchNormalization
from keras.layers import Concatenate, Reshape
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import regularizers
from keras import backend as K
import keras.losses

import tensorflow as tf

import isolearn.keras as iso

from aparent.losses import *

def make_resblock(n_channels=64, window_size=8, dilation_rate=1, group_ix=0, layer_ix=0, drop_rate=0.0) :

    #Initialize res block layers (pre-activation residual block: BN -> ReLU -> Conv, twice)
    batch_norm_0 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_0')

    relu_0 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_0 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_0')

    batch_norm_1 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_1')

    relu_1 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_1')

    skip_1 = Lambda(lambda x: x[0] + x[1], name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_skip_1')

    drop_1 = None
    if drop_rate > 0.0 :
        drop_1 = Dropout(drop_rate)

    #Execute res block
    def _resblock_func(input_tensor) :
        batch_norm_0_out = batch_norm_0(input_tensor)
        relu_0_out = relu_0(batch_norm_0_out)
        conv_0_out = conv_0(relu_0_out)

        batch_norm_1_out = batch_norm_1(conv_0_out)
        relu_1_out = relu_1(batch_norm_1_out)

        if drop_rate > 0.0 :
            conv_1_out = drop_1(conv_1(relu_1_out))
        else :
            conv_1_out = conv_1(relu_1_out)

        #Residual connection: add the block input back onto the block output
        skip_1_out = skip_1([conv_1_out, input_tensor])

        return skip_1_out

    return _resblock_func

def load_residual_network(n_groups=1, n_resblocks_per_group=4, n_channels=32, window_size=8, dilation_rates=[1], drop_rate=0.0) :

    #Discriminator network definition
    conv_0 = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_conv_0')

    skip_convs = []
    resblock_groups = []
    for group_ix in range(n_groups) :

        skip_convs.append(Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_skip_conv_' + str(group_ix)))

        resblocks = []
        for layer_ix in range(n_resblocks_per_group) :
            resblocks.append(make_resblock(n_channels=n_channels, window_size=window_size, dilation_rate=dilation_rates[group_ix], group_ix=group_ix, layer_ix=layer_ix, drop_rate=drop_rate))

        resblock_groups.append(resblocks)

    last_block_conv = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_last_block_conv')

    skip_add = Lambda(lambda x: x[0] + x[1], name='aparent_skip_add')

    final_conv = Conv2D(1, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_final_conv')

    #Append one zero-scored position, extending the cut track from 205 to 206 bins
    extend_tensor = Lambda(lambda x: K.concatenate([x, K.zeros((K.shape(x)[0], 1, 1, 1))], axis=2), name='aparent_extend_tensor')

    #Broadcast the 13-dim library indicator across all 206 positions
    expand_lib = Lambda(lambda x: K.tile(K.expand_dims(K.expand_dims(x, axis=1), axis=2), (1, 1, 206, 1)), name='aparent_expand_lib')

    #Per-position linear map from the library channels to a scalar bias
    lib_conv = LocallyConnected2D(1, (1, 1), strides=(1, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='aparent_lib_conv')

    lib_add = Lambda(lambda x: x[0] + x[1], name='aparent_lib_add')

    def _net_func(sequence_input, lib_input) :
        conv_0_out = conv_0(sequence_input)

        #Connect groups of res blocks
        output_tensor = conv_0_out

        #Res block group execution
        skip_conv_outs = []
        for group_ix in range(n_groups) :
            skip_conv_out = skip_convs[group_ix](output_tensor)
            skip_conv_outs.append(skip_conv_out)

            for layer_ix in range(n_resblocks_per_group) :
                output_tensor = resblock_groups[group_ix][layer_ix](output_tensor)

        #Last res block extra conv
        last_block_conv_out = last_block_conv(output_tensor)

        skip_add_out = last_block_conv_out
        for group_ix in range(n_groups) :
            skip_add_out = skip_add([skip_add_out, skip_conv_outs[group_ix]])

        #Final conv out
        final_conv_out = extend_tensor(final_conv(skip_add_out))

        #Add library bias
        lib_conv_out = lib_add([final_conv_out, lib_conv(expand_lib(lib_input))])

        return lib_conv_out

    return _net_func

def load_aparent_model(batch_size, use_sample_weights=False, drop_rate=0.25) :

    #APARENT parameters
    seq_input_shape = (1, 205, 4)
    lib_input_shape = (13,)
    num_outputs_iso = 1
    num_outputs_cut = 206

    #Plasmid model definition

    #Resnet function
    resnet = load_residual_network(
        n_groups=7,
        n_resblocks_per_group=4,
        n_channels=32,
        window_size=3,
        dilation_rates=[1, 2, 4, 8, 4, 2, 1],
        drop_rate=drop_rate
    )

    #Inputs
    seq_input = Input(shape=seq_input_shape)
    lib_input = Input(shape=lib_input_shape)
    plasmid_count = Input(shape=(1,))

    #Targets (fed as extra inputs to the loss model)
    true_iso = Input(shape=(num_outputs_iso,))
    true_cut = Input(shape=(num_outputs_cut,))

    cut_score = resnet(seq_input, lib_input)

    #Cut probabilities via softmax over the 206 positions; isoform probability = probability mass in the proximal window
    cut_prob = Lambda(lambda x: K.softmax(x[:, 0, :, 0], axis=-1))(cut_score)
    iso_prob = Lambda(lambda cl: K.expand_dims(K.sum(cl[:, 80:80+30], axis=-1), axis=-1))(cut_prob)

    plasmid_model = Model(
        inputs=[
            seq_input,
            lib_input
        ],
        outputs=[
            iso_prob,
            cut_prob
        ]
    )

    #Loss model definition
    sigmoid_kl_divergence = get_sigmoid_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
    kl_divergence = get_kl_divergence(batch_size, use_sample_weights=use_sample_weights)

    plasmid_loss_iso = Lambda(sigmoid_kl_divergence, output_shape=(1,))([true_iso, iso_prob, plasmid_count])
    plasmid_loss_cut = Lambda(kl_divergence, output_shape=(1,))([true_cut, cut_prob, plasmid_count])

    total_loss = Lambda(
        lambda l: 0.5 * l[0] + 0.5 * l[1],
        output_shape=(1,)
    )(
        [
            plasmid_loss_iso,
            plasmid_loss_cut
        ]
    )

    loss_model = Model([
        seq_input,
        lib_input,
        plasmid_count,
        true_iso,
        true_cut
    ], total_loss)

    return [ ('plasmid_iso_cut_resnet', plasmid_model), ('loss', loss_model) ]
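A usage sketch for the model above (illustrative, not part of the commit; the input values are made up, and batch_size only parameterizes the loss closures imported from aparent.losses):

import numpy as np

models = load_aparent_model(batch_size=32)
plasmid_model = models[0][1]  #('plasmid_iso_cut_resnet', ...)
loss_model = models[1][1]     #('loss', ...)

#Dummy inputs: one-hot sequences of shape (1, 205, 4) and a 13-way library indicator
onehot_seq = np.zeros((32, 1, 205, 4))
onehot_seq[..., 0] = 1.0
lib = np.zeros((32, 13))
lib[:, 5] = 1.0

iso_prob, cut_prob = plasmid_model.predict([onehot_seq, lib])
print(iso_prob.shape, cut_prob.shape)  #(32, 1) and (32, 206)

One common pattern for training loss models like loss_model, whose single output is already the scalar loss, is to compile them with a pass-through objective such as lambda y_true, y_pred: y_pred and feed zeros as the dummy target.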
aparent/model/aparent_model_plasmid_resnet_var_batch_size_inference_mode.py (192 additions, 0 deletions)
@@ -0,0 +1,192 @@
from __future__ import print_function
import keras
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, LocallyConnected2D, Conv1D, MaxPooling1D, LocallyConnected1D, LSTM, ConvLSTM2D, BatchNormalization
from keras.layers import Concatenate, Reshape
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras import regularizers
from keras import backend as K
import keras.losses

import tensorflow as tf

import isolearn.keras as iso

from aparent.losses import *

def make_resblock(n_channels=64, window_size=8, dilation_rate=1, group_ix=0, layer_ix=0, drop_rate=0.0) :

    #Initialize res block layers (pre-activation residual block: BN -> ReLU -> Conv, twice)
    batch_norm_0 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_0')

    relu_0 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_0 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_0')

    batch_norm_1 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_1')

    relu_1 = Lambda(lambda x: K.relu(x, alpha=0.0))

    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_1')

    skip_1 = Lambda(lambda x: x[0] + x[1], name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_skip_1')

    drop_1 = None
    if drop_rate > 0.0 :
        drop_1 = Dropout(drop_rate)

    #Execute res block
    def _resblock_func(input_tensor) :
        #training=False pins batch norm to inference mode (stored moving statistics)
        batch_norm_0_out = batch_norm_0(input_tensor, training=False)
        relu_0_out = relu_0(batch_norm_0_out)
        conv_0_out = conv_0(relu_0_out)

        batch_norm_1_out = batch_norm_1(conv_0_out, training=False)
        relu_1_out = relu_1(batch_norm_1_out)

        if drop_rate > 0.0 :
            conv_1_out = drop_1(conv_1(relu_1_out))
        else :
            conv_1_out = conv_1(relu_1_out)

        #Residual connection: add the block input back onto the block output
        skip_1_out = skip_1([conv_1_out, input_tensor])

        return skip_1_out

    return _resblock_func

def load_residual_network(n_groups=1, n_resblocks_per_group=4, n_channels=32, window_size=8, dilation_rates=[1], drop_rate=0.0) :

    #Discriminator network definition
    conv_0 = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_conv_0')

    skip_convs = []
    resblock_groups = []
    for group_ix in range(n_groups) :

        skip_convs.append(Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_skip_conv_' + str(group_ix)))

        resblocks = []
        for layer_ix in range(n_resblocks_per_group) :
            resblocks.append(make_resblock(n_channels=n_channels, window_size=window_size, dilation_rate=dilation_rates[group_ix], group_ix=group_ix, layer_ix=layer_ix, drop_rate=drop_rate))

        resblock_groups.append(resblocks)

    last_block_conv = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_last_block_conv')

    skip_add = Lambda(lambda x: x[0] + x[1], name='aparent_skip_add')

    final_conv = Conv2D(1, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_final_conv')

    #Append one zero-scored position, extending the cut track from 205 to 206 bins
    extend_tensor = Lambda(lambda x: K.concatenate([x, K.zeros((K.shape(x)[0], 1, 1, 1))], axis=2), name='aparent_extend_tensor')

    #Broadcast the 13-dim library indicator across all 206 positions
    expand_lib = Lambda(lambda x: K.tile(K.expand_dims(K.expand_dims(x, axis=1), axis=2), (1, 1, 206, 1)), name='aparent_expand_lib')

    #Per-position linear map from the library channels to a scalar bias
    lib_conv = LocallyConnected2D(1, (1, 1), strides=(1, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='aparent_lib_conv')

    lib_add = Lambda(lambda x: x[0] + x[1], name='aparent_lib_add')

    def _net_func(sequence_input, lib_input) :
        conv_0_out = conv_0(sequence_input)

        #Connect groups of res blocks
        output_tensor = conv_0_out

        #Res block group execution
        skip_conv_outs = []
        for group_ix in range(n_groups) :
            skip_conv_out = skip_convs[group_ix](output_tensor)
            skip_conv_outs.append(skip_conv_out)

            for layer_ix in range(n_resblocks_per_group) :
                output_tensor = resblock_groups[group_ix][layer_ix](output_tensor)

        #Last res block extra conv
        last_block_conv_out = last_block_conv(output_tensor)

        skip_add_out = last_block_conv_out
        for group_ix in range(n_groups) :
            skip_add_out = skip_add([skip_add_out, skip_conv_outs[group_ix]])

        #Final conv out
        final_conv_out = extend_tensor(final_conv(skip_add_out))

        #Add library bias
        lib_conv_out = lib_add([final_conv_out, lib_conv(expand_lib(lib_input))])

        return lib_conv_out

    return _net_func

def load_aparent_model(batch_size, use_sample_weights=False, drop_rate=0.25) :

    #APARENT parameters
    seq_input_shape = (1, 205, 4)
    lib_input_shape = (13,)
    num_outputs_iso = 1
    num_outputs_cut = 206

    #Plasmid model definition

    #Resnet function
    resnet = load_residual_network(
        n_groups=7,
        n_resblocks_per_group=4,
        n_channels=32,
        window_size=3,
        dilation_rates=[1, 2, 4, 8, 4, 2, 1],
        drop_rate=drop_rate
    )

    #Inputs
    seq_input = Input(shape=seq_input_shape)
    lib_input = Input(shape=lib_input_shape)
    plasmid_count = Input(shape=(1,))

    #Targets (fed as extra inputs to the loss model)
    true_iso = Input(shape=(num_outputs_iso,))
    true_cut = Input(shape=(num_outputs_cut,))

    cut_score = resnet(seq_input, lib_input)

    #Cut probabilities via softmax over the 206 positions; isoform probability = probability mass in the proximal window
    cut_prob = Lambda(lambda x: K.softmax(x[:, 0, :, 0], axis=-1))(cut_score)
    iso_prob = Lambda(lambda cl: K.expand_dims(K.sum(cl[:, 80:80+30], axis=-1), axis=-1))(cut_prob)

    plasmid_model = Model(
        inputs=[
            seq_input,
            lib_input
        ],
        outputs=[
            iso_prob,
            cut_prob
        ]
    )

    #Loss model definition
    sigmoid_kl_divergence = get_sigmoid_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
    kl_divergence = get_kl_divergence(batch_size, use_sample_weights=use_sample_weights)

    plasmid_loss_iso = Lambda(sigmoid_kl_divergence, output_shape=(1,))([true_iso, iso_prob, plasmid_count])
    plasmid_loss_cut = Lambda(kl_divergence, output_shape=(1,))([true_cut, cut_prob, plasmid_count])

    total_loss = Lambda(
        lambda l: 0.5 * l[0] + 0.5 * l[1],
        output_shape=(1,)
    )(
        [
            plasmid_loss_iso,
            plasmid_loss_cut
        ]
    )

    loss_model = Model([
        seq_input,
        lib_input,
        plasmid_count,
        true_iso,
        true_cut
    ], total_loss)

    return [ ('plasmid_iso_cut_resnet', plasmid_model), ('loss', loss_model) ]
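Apart from the training=False arguments on the batch norm calls, this file is identical to the previous one. With drop_rate=0.0 no phase-dependent layer remains, so the model should compute the same outputs whether Keras is in its training or test phase. A hypothetical check of that property (not part of the commit):

import numpy as np
from keras import backend as K

models = load_aparent_model(batch_size=8, drop_rate=0.0)
plasmid_model = models[0][1]

seq = np.random.rand(8, 1, 205, 4)
lib = np.eye(13)[np.random.randint(13, size=8)]

#Evaluate the cut distribution in training phase (1) and test phase (0)
predict_fn = K.function(plasmid_model.inputs + [K.learning_phase()], plasmid_model.outputs)
cut_train = predict_fn([seq, lib, 1])[1]
cut_test = predict_fn([seq, lib, 0])[1]

#True here; for the first file's model the phases would disagree,
#since its batch norm layers follow the learning phase
print(np.allclose(cut_train, cut_test))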