diff --git a/aparent/model/aparent_model_plasmid_resnet_var_batch_size.py b/aparent/model/aparent_model_plasmid_resnet_var_batch_size.py
new file mode 100644
index 0000000..fd5e0cf
--- /dev/null
+++ b/aparent/model/aparent_model_plasmid_resnet_var_batch_size.py
@@ -0,0 +1,192 @@
+from __future__ import print_function
+import keras
+from keras.models import Sequential, Model, load_model
+from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
+from keras.layers import Conv2D, MaxPooling2D, LocallyConnected2D, Conv1D, MaxPooling1D, LocallyConnected1D, LSTM, ConvLSTM2D, BatchNormalization
+from keras.layers import Concatenate, Reshape
+from keras.callbacks import ModelCheckpoint, EarlyStopping
+from keras import regularizers
+from keras import backend as K
+import keras.losses
+
+import tensorflow as tf
+
+import isolearn.keras as iso
+
+from aparent.losses import *
+
+def make_resblock(n_channels=64, window_size=8, dilation_rate=1, group_ix=0, layer_ix=0, drop_rate=0.0) :
+
+    #Initialize res block layers
+    batch_norm_0 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_0')
+
+    relu_0 = Lambda(lambda x: K.relu(x, alpha=0.0))
+
+    conv_0 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_0')
+
+    batch_norm_1 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_1')
+
+    relu_1 = Lambda(lambda x: K.relu(x, alpha=0.0))
+
+    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_1')
+
+    skip_1 = Lambda(lambda x: x[0] + x[1], name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_skip_1')
+
+    drop_1 = None
+    if drop_rate > 0.0 :
+        drop_1 = Dropout(drop_rate)
+
+    #Execute res block
+    def _resblock_func(input_tensor) :
+        batch_norm_0_out = batch_norm_0(input_tensor)
+        relu_0_out = relu_0(batch_norm_0_out)
+        conv_0_out = conv_0(relu_0_out)
+
+        batch_norm_1_out = batch_norm_1(conv_0_out)
+        relu_1_out = relu_1(batch_norm_1_out)
+
+        if drop_rate > 0.0 :
+            conv_1_out = drop_1(conv_1(relu_1_out))
+        else :
+            conv_1_out = conv_1(relu_1_out)
+
+        skip_1_out = skip_1([conv_1_out, input_tensor])
+
+        return skip_1_out
+
+    return _resblock_func
+
+def load_residual_network(n_groups=1, n_resblocks_per_group=4, n_channels=32, window_size=8, dilation_rates=[1], drop_rate=0.0) :
+
+    #Residual network definition
+    conv_0 = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_conv_0')
+
+    skip_convs = []
+    resblock_groups = []
+    for group_ix in range(n_groups) :
+
+        skip_convs.append(Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_skip_conv_' + str(group_ix)))
+
+        resblocks = []
+        for layer_ix in range(n_resblocks_per_group) :
+            resblocks.append(make_resblock(n_channels=n_channels, window_size=window_size, dilation_rate=dilation_rates[group_ix], group_ix=group_ix, layer_ix=layer_ix, drop_rate=drop_rate))
+
+        resblock_groups.append(resblocks)
+
+    last_block_conv = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_last_block_conv')
+
+    skip_add = Lambda(lambda x: x[0] + x[1], name='aparent_skip_add')
+
+    final_conv = Conv2D(1, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_final_conv')
+
+    extend_tensor = Lambda(lambda x: K.concatenate([x, K.zeros((K.shape(x)[0], 1, 1, 1))], axis=2), name='aparent_extend_tensor')
+
+    expand_lib = Lambda(lambda x: K.tile(K.expand_dims(K.expand_dims(x, axis=1), axis=2), (1, 1, 206, 1)), name='aparent_expand_lib')
+
+    lib_conv = LocallyConnected2D(1, (1, 1), strides=(1, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='aparent_lib_conv')
+
+    lib_add = Lambda(lambda x: x[0] + x[1], name='aparent_lib_add')
+
+    def _net_func(sequence_input, lib_input) :
+        conv_0_out = conv_0(sequence_input)
+
+        #Connect group of res blocks
+        output_tensor = conv_0_out
+
+        #Res block group execution
+        skip_conv_outs = []
+        for group_ix in range(n_groups) :
+            skip_conv_out = skip_convs[group_ix](output_tensor)
+            skip_conv_outs.append(skip_conv_out)
+
+            for layer_ix in range(n_resblocks_per_group) :
+                output_tensor = resblock_groups[group_ix][layer_ix](output_tensor)
+
+        #Last res block extra conv
+        last_block_conv_out = last_block_conv(output_tensor)
+
+        skip_add_out = last_block_conv_out
+        for group_ix in range(n_groups) :
+            skip_add_out = skip_add([skip_add_out, skip_conv_outs[group_ix]])
+
+        #Final conv out
+        final_conv_out = extend_tensor(final_conv(skip_add_out))
+
+        #Add library bias
+        lib_conv_out = lib_add([final_conv_out, lib_conv(expand_lib(lib_input))])
+
+        return lib_conv_out
+
+    return _net_func
+
+def load_aparent_model(batch_size, use_sample_weights=False, drop_rate=0.25) :
+
+    #APARENT parameters
+    seq_input_shape = (1, 205, 4)
+    lib_input_shape = (13,)
+    num_outputs_iso = 1
+    num_outputs_cut = 206
+
+    #Plasmid model definition
+
+    #Resnet function
+    resnet = load_residual_network(
+        n_groups=7,
+        n_resblocks_per_group=4,
+        n_channels=32,
+        window_size=3,
+        dilation_rates=[1, 2, 4, 8, 4, 2, 1],
+        drop_rate=drop_rate
+    )
+
+    #Inputs
+    seq_input = Input(shape=seq_input_shape)
+    lib_input = Input(shape=lib_input_shape)
+    plasmid_count = Input(shape=(1,))
+
+    #Targets
+    true_iso = Input(shape=(num_outputs_iso,))
+    true_cut = Input(shape=(num_outputs_cut,))
+
+    cut_score = resnet(seq_input, lib_input)
+
+    cut_prob = Lambda(lambda x: K.softmax(x[:, 0, :, 0], axis=-1))(cut_score)
+    iso_prob = Lambda(lambda cl: K.expand_dims(K.sum(cl[:, 80:80+30], axis=-1), axis=-1))(cut_prob)
+
+    plasmid_model = Model(
+        inputs=[
+            seq_input,
+            lib_input
+        ],
+        outputs=[
+            iso_prob,
+            cut_prob
+        ]
+    )
+
+    #Loss model definition
+    sigmoid_kl_divergence = get_sigmoid_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
+    kl_divergence = get_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
+
+    plasmid_loss_iso = Lambda(sigmoid_kl_divergence, output_shape = (1,))([true_iso, iso_prob, plasmid_count])
+    plasmid_loss_cut = Lambda(kl_divergence, output_shape = (1,))([true_cut, cut_prob, plasmid_count])
+
+    total_loss = Lambda(
+        lambda l: 0.5 * l[0] + 0.5 * l[1],
+        output_shape = (1,)
+    )(
+        [
+            plasmid_loss_iso,
+            plasmid_loss_cut
+        ]
+    )
+
+    loss_model = Model([
+        seq_input,
+        lib_input,
+        plasmid_count,
+        true_iso,
+        true_cut
+    ], total_loss)
+
+    return [ ('plasmid_iso_cut_resnet', plasmid_model), ('loss', loss_model) ]
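Note: the inference-mode module below is identical to the training module above except inside _resblock_func, where each BatchNormalization layer is called with training=False so that the stored moving statistics (rather than per-batch statistics) are used when the graph is rebuilt for prediction or interpretation.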
diff --git a/aparent/model/aparent_model_plasmid_resnet_var_batch_size_inference_mode.py b/aparent/model/aparent_model_plasmid_resnet_var_batch_size_inference_mode.py
new file mode 100644
index 0000000..1c6b79e
--- /dev/null
+++ b/aparent/model/aparent_model_plasmid_resnet_var_batch_size_inference_mode.py
@@ -0,0 +1,192 @@
+from __future__ import print_function
+import keras
+from keras.models import Sequential, Model, load_model
+from keras.layers import Dense, Dropout, Activation, Flatten, Input, Lambda
+from keras.layers import Conv2D, MaxPooling2D, LocallyConnected2D, Conv1D, MaxPooling1D, LocallyConnected1D, LSTM, ConvLSTM2D, BatchNormalization
+from keras.layers import Concatenate, Reshape
+from keras.callbacks import ModelCheckpoint, EarlyStopping
+from keras import regularizers
+from keras import backend as K
+import keras.losses
+
+import tensorflow as tf
+
+import isolearn.keras as iso
+
+from aparent.losses import *
+
+def make_resblock(n_channels=64, window_size=8, dilation_rate=1, group_ix=0, layer_ix=0, drop_rate=0.0) :
+
+    #Initialize res block layers
+    batch_norm_0 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_0')
+
+    relu_0 = Lambda(lambda x: K.relu(x, alpha=0.0))
+
+    conv_0 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_0')
+
+    batch_norm_1 = BatchNormalization(name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_batch_norm_1')
+
+    relu_1 = Lambda(lambda x: K.relu(x, alpha=0.0))
+
+    conv_1 = Conv2D(n_channels, (1, window_size), dilation_rate=dilation_rate, strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_conv_1')
+
+    skip_1 = Lambda(lambda x: x[0] + x[1], name='aparent_resblock_' + str(group_ix) + '_' + str(layer_ix) + '_skip_1')
+
+    drop_1 = None
+    if drop_rate > 0.0 :
+        drop_1 = Dropout(drop_rate)
+
+    #Execute res block
+    def _resblock_func(input_tensor) :
+        batch_norm_0_out = batch_norm_0(input_tensor, training=False)
+        relu_0_out = relu_0(batch_norm_0_out)
+        conv_0_out = conv_0(relu_0_out)
+
+        batch_norm_1_out = batch_norm_1(conv_0_out, training=False)
+        relu_1_out = relu_1(batch_norm_1_out)
+
+        if drop_rate > 0.0 :
+            conv_1_out = drop_1(conv_1(relu_1_out))
+        else :
+            conv_1_out = conv_1(relu_1_out)
+
+        skip_1_out = skip_1([conv_1_out, input_tensor])
+
+        return skip_1_out
+
+    return _resblock_func
+
+def load_residual_network(n_groups=1, n_resblocks_per_group=4, n_channels=32, window_size=8, dilation_rates=[1], drop_rate=0.0) :
+
+    #Residual network definition
+    conv_0 = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_conv_0')
+
+    skip_convs = []
+    resblock_groups = []
+    for group_ix in range(n_groups) :
+
+        skip_convs.append(Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_skip_conv_' + str(group_ix)))
+
+        resblocks = []
+        for layer_ix in range(n_resblocks_per_group) :
+            resblocks.append(make_resblock(n_channels=n_channels, window_size=window_size, dilation_rate=dilation_rates[group_ix], group_ix=group_ix, layer_ix=layer_ix, drop_rate=drop_rate))
+
+        resblock_groups.append(resblocks)
+
+    last_block_conv = Conv2D(n_channels, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_last_block_conv')
+
+    skip_add = Lambda(lambda x: x[0] + x[1], name='aparent_skip_add')
+
+    final_conv = Conv2D(1, (1, 1), strides=(1, 1), padding='same', activation='linear', kernel_initializer='glorot_normal', name='aparent_final_conv')
+
+    extend_tensor = Lambda(lambda x: K.concatenate([x, K.zeros((K.shape(x)[0], 1, 1, 1))], axis=2), name='aparent_extend_tensor')
+
+    expand_lib = Lambda(lambda x: K.tile(K.expand_dims(K.expand_dims(x, axis=1), axis=2), (1, 1, 206, 1)), name='aparent_expand_lib')
+
+    lib_conv = LocallyConnected2D(1, (1, 1), strides=(1, 1), padding='valid', activation='linear', kernel_initializer='glorot_normal', name='aparent_lib_conv')
+
+    lib_add = Lambda(lambda x: x[0] + x[1], name='aparent_lib_add')
+
+    def _net_func(sequence_input, lib_input) :
+        conv_0_out = conv_0(sequence_input)
+
+        #Connect group of res blocks
+        output_tensor = conv_0_out
+
+        #Res block group execution
+        skip_conv_outs = []
+        for group_ix in range(n_groups) :
+            skip_conv_out = skip_convs[group_ix](output_tensor)
+            skip_conv_outs.append(skip_conv_out)
+
+            for layer_ix in range(n_resblocks_per_group) :
+                output_tensor = resblock_groups[group_ix][layer_ix](output_tensor)
+
+        #Last res block extra conv
+        last_block_conv_out = last_block_conv(output_tensor)
+
+        skip_add_out = last_block_conv_out
+        for group_ix in range(n_groups) :
+            skip_add_out = skip_add([skip_add_out, skip_conv_outs[group_ix]])
+
+        #Final conv out
+        final_conv_out = extend_tensor(final_conv(skip_add_out))
+
+        #Add library bias
+        lib_conv_out = lib_add([final_conv_out, lib_conv(expand_lib(lib_input))])
+
+        return lib_conv_out
+
+    return _net_func
+
+def load_aparent_model(batch_size, use_sample_weights=False, drop_rate=0.25) :
+
+    #APARENT parameters
+    seq_input_shape = (1, 205, 4)
+    lib_input_shape = (13,)
+    num_outputs_iso = 1
+    num_outputs_cut = 206
+
+    #Plasmid model definition
+
+    #Resnet function
+    resnet = load_residual_network(
+        n_groups=7,
+        n_resblocks_per_group=4,
+        n_channels=32,
+        window_size=3,
+        dilation_rates=[1, 2, 4, 8, 4, 2, 1],
+        drop_rate=drop_rate
+    )
+
+    #Inputs
+    seq_input = Input(shape=seq_input_shape)
+    lib_input = Input(shape=lib_input_shape)
+    plasmid_count = Input(shape=(1,))
+
+    #Targets
+    true_iso = Input(shape=(num_outputs_iso,))
+    true_cut = Input(shape=(num_outputs_cut,))
+
+    cut_score = resnet(seq_input, lib_input)
+
+    cut_prob = Lambda(lambda x: K.softmax(x[:, 0, :, 0], axis=-1))(cut_score)
+    iso_prob = Lambda(lambda cl: K.expand_dims(K.sum(cl[:, 80:80+30], axis=-1), axis=-1))(cut_prob)
+
+    plasmid_model = Model(
+        inputs=[
+            seq_input,
+            lib_input
+        ],
+        outputs=[
+            iso_prob,
+            cut_prob
+        ]
+    )
+
+    #Loss model definition
+    sigmoid_kl_divergence = get_sigmoid_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
+    kl_divergence = get_kl_divergence(batch_size, use_sample_weights=use_sample_weights)
+
+    plasmid_loss_iso = Lambda(sigmoid_kl_divergence, output_shape = (1,))([true_iso, iso_prob, plasmid_count])
+    plasmid_loss_cut = Lambda(kl_divergence, output_shape = (1,))([true_cut, cut_prob, plasmid_count])
+
+    total_loss = Lambda(
+        lambda l: 0.5 * l[0] + 0.5 * l[1],
+        output_shape = (1,)
+    )(
+        [
+            plasmid_loss_iso,
+            plasmid_loss_cut
+        ]
+    )
+
+    loss_model = Model([
+        seq_input,
+        lib_input,
+        plasmid_count,
+        true_iso,
+        true_cut
+    ], total_loss)
+
+    return [ ('plasmid_iso_cut_resnet', plasmid_model), ('loss', loss_model) ]
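For reviewers who want to exercise these modules end-to-end, a minimal usage sketch follows. It assumes Keras 2.x on the TensorFlow backend and that aparent.losses exports the two loss getters imported above; the optimizer choice, the identity-loss compile trick, and the dummy arrays are illustrative assumptions and not part of this commit.

# Minimal usage sketch (all shapes taken from load_aparent_model above;
# everything else here is illustrative only).
import numpy as np
import keras

from aparent.model.aparent_model_plasmid_resnet_var_batch_size import load_aparent_model

batch_size = 32
(_, plasmid_model), (_, loss_model) = load_aparent_model(batch_size, use_sample_weights=False, drop_rate=0.25)

# The loss model already emits the 50/50 blend of the iso and cut KL losses as its
# single output, so it can be compiled with an identity loss against a dummy target.
loss_model.compile(optimizer=keras.optimizers.SGD(lr=0.1, momentum=0.9), loss=lambda y_true, y_pred: y_pred)

# Dummy tensors matching the declared input shapes: one-hot sequence (1, 205, 4),
# 13-dim library indicator, per-plasmid read count, and target iso / cut distributions.
seq = np.zeros((batch_size, 1, 205, 4))
lib = np.zeros((batch_size, 13))
counts = np.ones((batch_size, 1))
true_iso = np.zeros((batch_size, 1))
true_cut = np.full((batch_size, 206), 1.0 / 206.0)

loss_model.fit(
    [seq, lib, counts, true_iso, true_cut],
    np.zeros((batch_size, 1)),
    batch_size=batch_size,
    epochs=1
)

# Predictions come from the prediction model, which shares its layers (and weights)
# with the loss model.
iso_pred, cut_pred = plasmid_model.predict([seq, lib], batch_size=batch_size)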