From c5b6724c0a200adb1c9e59b60ba6974861a874f8 Mon Sep 17 00:00:00 2001 From: hy395 Date: Wed, 3 Jul 2024 00:32:52 -0700 Subject: [PATCH] add se_adapter and locon --- src/baskerville/HY_helper.py | 75 --- src/baskerville/blocks.py | 43 +- .../{ => helpers}/transfer_helper.py | 551 ++++++++++-------- src/baskerville/layers.py | 100 +++- src/baskerville/scripts/hound_transfer.py | 228 +++----- src/baskerville/seqnn.py | 24 +- src/baskerville/trainer.py | 48 +- tests/test_transfer/test_ia3.ipynb | 2 +- 8 files changed, 538 insertions(+), 533 deletions(-) delete mode 100644 src/baskerville/HY_helper.py rename src/baskerville/{ => helpers}/transfer_helper.py (61%) diff --git a/src/baskerville/HY_helper.py b/src/baskerville/HY_helper.py deleted file mode 100644 index d4de926..0000000 --- a/src/baskerville/HY_helper.py +++ /dev/null @@ -1,75 +0,0 @@ -import numpy as np -from basenji import dna_io -import pysam -import pyBigWig - - -def make_seq_1hot(genome_open, chrm, start, end, seq_len): - if start < 0: - seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end) - else: - seq_dna = genome_open.fetch(chrm, start, end) - - #Extend to full length - if len(seq_dna) < seq_len: - seq_dna += 'N'*(seq_len-len(seq_dna)) - - seq_1hot = dna_io.dna_1hot(seq_dna) - return seq_1hot - -# Helper function to get (padded) one-hot -def process_sequence(fasta_file, chrom, start, end, seq_len=524288) : - - fasta_open = pysam.Fastafile(fasta_file) - seq_len_actual = end - start - - #Pad sequence to input window size - start -= (seq_len - seq_len_actual) // 2 - end += (seq_len - seq_len_actual) // 2 - - #Get one-hot - sequence_one_hot = make_seq_1hot(fasta_open, chrom, start, end, seq_len) - - return sequence_one_hot.astype('float32') - -def compute_cov(seqnn_model, chr, start, end): - seq_len = seqnn_model.model.layers[0].input.shape[1] - seq1hot = process_sequence('/home/yuanh/programs/genomes/hg38/hg38.fa', chr, start, end, seq_len=seq_len) - out = seqnn_model.model(seq1hot[None, ]) - return out.numpy() - -def write_bw(bw_file, chr, start, end, values, span=32): - bw_out = pyBigWig.open(bw_file, 'w') - header = [] - header.append((chr, end+1)) - bw_out.addHeader(header) - bw_out.addEntries(chr, start, values=values, span=span, step=span) - bw_out.close() - -def transform(seq_cov, clip=384, clip_soft=320, scale=0.3): - seq_cov = scale * seq_cov # scale - seq_cov = -1 + np.sqrt(1+seq_cov) # variant stabilize - clip_mask = (seq_cov > clip_soft) # soft clip - seq_cov[clip_mask] = clip_soft-1 + np.sqrt(seq_cov[clip_mask] - clip_soft+1) - seq_cov = np.clip(seq_cov, -clip, clip) # hard clip - return seq_cov - -def untransform(cov, scale=0.3, clip_soft=320, pool_width=32): - - # undo clip_soft - cov_unclipped = (cov - clip_soft + 1)**2 + clip_soft - 1 - unclip_mask = (cov > clip_soft) - cov[unclip_mask] = cov_unclipped[unclip_mask] - - # undo sqrt - cov = (cov +1)**2 - 1 - - # undo scale - cov = cov / scale - - # undo sum - cov = cov / pool_width - - return cov - - diff --git a/src/baskerville/blocks.py b/src/baskerville/blocks.py index 8e74e31..ffbdb75 100644 --- a/src/baskerville/blocks.py +++ b/src/baskerville/blocks.py @@ -149,8 +149,6 @@ def conv_dna( conv_type="standard", kernel_initializer="he_normal", padding="same", - transfer_se=False, - se_ratio=16, ): """Construct a single convolution block, assumed to be operating on DNA. 
@@ -197,19 +195,7 @@ def conv_dna( kernel_initializer=kernel_initializer, kernel_regularizer=tf.keras.regularizers.l2(l2_scale), )(current) - - # squeeze-excite for transfer - if transfer_se: - se_out = squeeze_excite(current, - activation=None, - additive=False, - bottleneck_ratio=se_ratio, - use_bias=False, - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), - scale_fun='tanh' - ) - current = current + se_out - + # squeeze-excite if se: current = squeeze_excite(current) @@ -281,8 +267,6 @@ def conv_nac( kernel_initializer="he_normal", padding="same", se=False, - transfer_se=False, - se_ratio=16, ): """Construct a single convolution block. @@ -342,18 +326,6 @@ def conv_nac( kernel_regularizer=tf.keras.regularizers.l2(l2_scale), )(current) - # squeeze-excite for transfer - if transfer_se: - se_out = squeeze_excite(current, - activation=None, - additive=False, - bottleneck_ratio=se_ratio, - use_bias=False, - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), - scale_fun='tanh' - ) - current = current + se_out - # squeeze-excite if se: current = squeeze_excite(current) @@ -484,8 +456,6 @@ def unet_conv( bn_momentum=0.99, kernel_size=1, kernel_initializer="he_normal", - transfer_se=False, - se_ratio=16, upsample_conv=False, ): """Construct a feature pyramid network block. @@ -561,17 +531,6 @@ def unet_conv( kernel_initializer=kernel_initializer, )(current) - if transfer_se: - se_out = squeeze_excite(current, - activation=None, - additive=False, - bottleneck_ratio=se_ratio, - use_bias=False, - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), - scale_fun='tanh' - ) - current = current + se_out - # dropout if dropout > 0: current = tf.keras.layers.Dropout(dropout)(current) diff --git a/src/baskerville/transfer_helper.py b/src/baskerville/helpers/transfer_helper.py similarity index 61% rename from src/baskerville/transfer_helper.py rename to src/baskerville/helpers/transfer_helper.py index 72acefc..aa178b6 100644 --- a/src/baskerville/transfer_helper.py +++ b/src/baskerville/helpers/transfer_helper.py @@ -46,33 +46,92 @@ def keras2dict(model): layer_parent_dict[layer_name].append(layer.name) return layer_parent_dict +# lora requires change model.h5 weight order. +# locon and ia3 don't modify model in place. +def var_reorder(weight_h5): + # assumes weight_h5 model saved with seqnn_model.save() + # [i.name for i in model.layers[30].weights] to check for multihead_attention layer weights order. + # model.load_weights() load weights sequencially, assuming h5 weights are in the right order. + # When inserting lora, multihead_attention layer weights order changed. + # multihead_attention layer weights order is saved inside f['model_weights']['multihead_attention'].attrs + # After saving the weight_merged model, we need to go into the weights.h5, and change the attrs in multihead attention. + var_init_order = ['r_w_bias:0:0', + 'r_r_bias:0:0', + 'q_layer/kernel:0', + 'k_layer/kernel:0', + 'v_layer/kernel:0', + 'embedding_layer/kernel:0', + 'embedding_layer/bias:0', + 'r_k_layer/kernel:0'] + + f = h5py.File(weight_h5, 'r+') + layers = [i for i in list(f['model_weights'].keys()) if 'multihead_attention' in i] + for l_name in layers: + new_name_order = [l_name+'/'+i for i in var_init_order] + f['model_weights'][l_name].attrs.modify(name='weight_names', value=new_name_order) + f.close() + + +# houlsby requires architecture change. +# thus we need to modify json. 
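(A minimal sketch, not part of the patch, of the entries that modify_json below writes into the output params.json; values are illustrative defaults and only these keys are touched:)

    params["model"]["adapter"]        = "houlsby"   # or "houlsby_se"
    params["model"]["adapter_latent"] = 8           # Houlsby bottleneck size
    params["model"]["se_rank"]        = 16          # houlsby_se only: squeeze-excite bottleneck units
    params["model"]["conv_select"]    = 4           # houlsby_se only: trailing conv1d layers to adapt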
+def modify_json(input_json, output_json, adapter, latent=8, se_rank=None, conv_select=None): + + with open(input_json) as params_open: + params = json.load(params_open) + + # houlsby + if adapter=='adapterHoulsby': + params["model"]['adapter']= 'houlsby' + params["model"]['adapter_latent']= latent + + # houlsby_se + elif adapter=='houlsby_se': + params["model"]['adapter']= 'houlsby_se' + params["model"]['adapter_latent']= latent + params["model"]['se_rank']= se_rank + params["model"]['conv_select']= conv_select + + else: + raise ValueError("adapter must be adapterHoulsby or houlsby_se") + + ### output + with open(output_json, 'w') as params_open: + json.dump(params, params_open, indent=4) + ###################### # add houlsby layers # ###################### -def add_houlsby(input_model, strand_pair, latent_size=16): +def add_houlsby(input_model, strand_pair, latent_size=8): # take seqnn_model as input # output a new seqnn_model object # only the adapter, and layer_norm are trainable - - model = tf.keras.Model(inputs=input_model.input, - outputs=input_model.layers[-2].output) # remove the switch_reverse layer - - # save current graph - layer_parent_dict_old = keras2dict(model) - - layer_output_dict_new = {} # the output tensor of each layer in the new graph - layer_output_dict_new.update({model.layers[0].name: model.input}) - - # remove switch_reverse + + ################## + # houlsby layers # + ################## + houlsby_layers = [] + for i in range(len(input_model.layers)-1): + layer = input_model.layers[i] + next_layer = input_model.layers[i+1] + if re.match('dropout', layer.name) and re.match('add', next_layer.name): + houlsby_layers += [next_layer.name] + + ################### + # construct model # + ################### + layer_parent_dict_old = keras2dict(input_model) + # remove switch_reverse_layer to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] for i in to_fix: del layer_parent_dict_old[i] - + # create new graph + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({input_model.layers[0].name: input_model.input}) # Iterate over all layers after the input model_outputs = [] reverse_bool = None - - for layer in model.layers[1:]: + + for layer in input_model.layers[1:-1]: # parent layers parent_layers = layer_parent_dict_old[layer.name] @@ -84,14 +143,11 @@ def add_houlsby(input_model, strand_pair, latent_size=16): if re.match('stochastic_reverse_complement', layer.name): x, reverse_bool = layer(layer_input) - # insert adapter: - elif re.match('add', layer.name): - if any([re.match('dropout', i) for i in parent_layers]): - print('adapter added before:%s'%layer.name) - x = layers.AdapterHoulsby(latent_size=latent_size)(layer_input[1]) - x = layer([layer_input[0], x]) - else: - x = layer(layer_input) + # insert houlsby layer: + elif layer.name in houlsby_layers: + print('adapter added before:%s'%layer.name) + x = layers.AdapterHoulsby(latent_size=latent_size)(layer_input[1]) + x = layer([layer_input[0], x]) else: x = layer(layer_input) @@ -99,12 +155,10 @@ def add_houlsby(input_model, strand_pair, latent_size=16): # save the output tensor of every layer layer_output_dict_new.update({layer.name: x}) - final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[model.layers[-1].name], reverse_bool]) - model_adapter = tf.keras.Model(inputs=model.inputs, outputs=final) + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) + model_adapter = 
tf.keras.Model(inputs=input_model.inputs, outputs=final) - ################# # set trainable # - ################# for l in model_adapter.layers[:-2]: # trunk if re.match('layer_normalization|adapter_houlsby', l.name): l.trainable = True @@ -122,10 +176,10 @@ def add_houlsby(input_model, strand_pair, latent_size=16): return model_adapter -################### -# add lora layers # -################### -def add_lora(input_model, rank=8, alpha=16, mode='default'): +############### +# lora layers # +############### +def add_lora(input_model, rank=8, alpha=16, mode='default', report_param=True): # take seqnn.model as input # replace _q_layer, _v_layer in multihead_attention # optionally replace _k_layer, _embedding_layer @@ -175,25 +229,81 @@ def add_lora(input_model, rank=8, alpha=16, mode='default'): params_added += param_count(l._k_layer.up_layer) params_added += param_count(l._embedding_layer.down_layer) params_added += param_count(l._embedding_layer.up_layer) + + if report_param: + print('params added/unfrozen by lora: %d'%params_added) + +############### +# lora layers # +############### +def add_lora_conv(input_model, conv_select=None): + + # add lora layers + add_lora(input_model, rank=8, alpha=16, mode='default', report_param=False) + + # list all conv layers + conv_layers = [] + for layer in input_model.layers: + if re.match('conv1d', layer.name): + conv_layers += [layer.name] + if conv_select is None: + conv_select = len(conv_layers) + if conv_select > len(conv_layers): + raise ValueError("conv_select must be less than number of conv layers %d."%len(conv_layers)) + + # set conv layers trainable + trainable_conv = conv_layers[-conv_select:] + for layer in input_model.layers: + if layer.name in trainable_conv: + layer.trainable=True - print('params added/unfrozen by lora: %d'%params_added) + # expected number of trainable params added/unfrozen: + params_added = 0 + for l in input_model.layers: + if re.match('multihead_attention', l.name): + params_added += param_count(l._q_layer.down_layer) + params_added += param_count(l._q_layer.up_layer) + params_added += param_count(l._v_layer.down_layer) + params_added += param_count(l._v_layer.up_layer) + elif l.name in trainable_conv: + params_added += param_count(l) + + print('params added/unfrozen by lora_conv: %d'%params_added) -################## -# add ia3 layers # -################## +# merge lora weights # +def merge_lora_layer(lora_layer): + down_weights = lora_layer.down_layer.kernel + up_weights = lora_layer.up_layer.kernel + increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * lora_layer.scale + lora_layer.original_layer.kernel.assign_add(increment_weights) + return lora_layer.original_layer + +def merge_lora(input_model): + for layer in input_model.layers: + if 'multihead_attention' in layer.name: + if isinstance(layer._q_layer, layers.Lora): + layer._q_layer = merge_lora_layer(layer._q_layer) + if isinstance(layer._v_layer, layers.Lora): + layer._v_layer = merge_lora_layer(layer._v_layer) + if isinstance(layer._k_layer, layers.Lora): + layer._k_layer = merge_lora_layer(layer._k_layer) + if isinstance(layer._embedding_layer, layers.Lora): + layer._embedding_layer = merge_lora_layer(layer._embedding_layer) + input_model(input_model.input) + + +############## +# IA3 layers # +############## def add_ia3(input_model, strand_pair): - #################### # add to kv layers # - #################### for layer in input_model.layers: if re.match('multihead_attention', layer.name): layer._k_layer = layers.IA3(layer._k_layer, 
trainable=True) layer._v_layer = layers.IA3(layer._v_layer, trainable=True) - ################### # add to ff layer # - ################### # save old graph to dictionary layer_parent_dict_old = keras2dict(input_model) @@ -231,9 +341,7 @@ def add_ia3(input_model, strand_pair): final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final) - ################# # set trainable # - ################# for layer in model_adapter._flatten_layers(): lst_of_sublayers = list(layer._flatten_layers()) if len(lst_of_sublayers) == 1: @@ -264,88 +372,6 @@ def add_ia3(input_model, strand_pair): return model_adapter - -############### -# modify json # -############### -# houlsby and squeeze-excite -def modify_json(input_json, output_json, adapter='adapterHoulsby', latent=None, conv=None, se_ratio=None): - - with open(input_json) as params_open: - params = json.load(params_open) - - # houlsby # - if adapter=='adapterHoulsby': - params["model"]["trunk"][2]['adapter']= 'houlsby' - params["model"]["trunk"][2]['latent']= latent - - # squeeze-excite # - if conv=='se_all' or conv=='se_all_bn': - for i in [0, 1, 3, 4]: - params['model']['trunk'][i]['transfer_se']=True - params['model']['trunk'][i]['se_ratio']=se_ratio - - elif conv=='se' or conv=='se_bn': - for i in [0, 1]: - params['model']['trunk'][i]['transfer_se']=True - params['model']['trunk'][i]['se_ratio']=se_ratio - - else: - pass - - ### output - with open(output_json, 'w') as params_open: - json.dump(params, params_open, indent=4) - - -###################### -# merge lora weights # -###################### -def merge_lora_layer(lora_layer): - down_weights = lora_layer.down_layer.kernel - up_weights = lora_layer.up_layer.kernel - increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * lora_layer.scale - lora_layer.original_layer.kernel.assign_add(increment_weights) - return lora_layer.original_layer - -def merge_lora(input_model, mode='default'): - for layer in input_model.layers: - if 'multihead_attention' in layer.name: - # default loRA - layer._q_layer = merge_lora_layer(layer._q_layer) - layer._v_layer = merge_lora_layer(layer._v_layer) - if mode=='full': - layer._k_layer = merge_lora_layer(layer._k_layer) - layer._embedding_layer = merge_lora_layer(layer._embedding_layer) - input_model(input_model.input) - -# correct weights.h5 weight order -def var_reorder(weight_h5): - # assumes weight_h5 model saved with seqnn_model.save() - # [i.name for i in model.layers[30].weights] to check for multihead_attention layer weights order. - # model.load_weights() load weights sequencially, assuming h5 weights are in the right order. - # When inserting lora/ia3, multihead_attention layer weights order changed. - # multihead_attention layer weights order is saved inside f['model_weights']['multihead_attention'].attrs - # After saving the weight_merged model, we need to go into the weights.h5, and change the attrs in multihead attention. 
- var_init_order = ['r_w_bias:0:0', - 'r_r_bias:0:0', - 'q_layer/kernel:0', - 'k_layer/kernel:0', - 'v_layer/kernel:0', - 'embedding_layer/kernel:0', - 'embedding_layer/bias:0', - 'r_k_layer/kernel:0'] - - f = h5py.File(weight_h5, 'r+') - layers = [i for i in list(f['model_weights'].keys()) if 'multihead_attention' in i] - for l_name in layers: - new_name_order = [l_name+'/'+i for i in var_init_order] - f['model_weights'][l_name].attrs.modify(name='weight_names', value=new_name_order) - f.close() - -##################### -# merge ia3 weights # -##################### def merge_ia3(original_model, ia3_model): # original model contains pre-trained weights # ia3 model is the fine-tuned ia3 model @@ -366,134 +392,187 @@ def merge_ia3(original_model, ia3_model): else: layer.set_weights(ia3_model.layers[i].get_weights()) -''' -###################### -# add squeeze excite # -###################### -def add_se(input_model, strand_pair, bottleneck_ratio=8, insert_mode='pre_att', unfreeze_bn=False): - # add squeeze-excitation blocks after conv - # input_model should be properly frozen - # pre_att: add se_block to pre-attention conv1d - # all: add se_block to pre-attention conv1d and post-attention separable_conv1d - - if insert_mode not in ['pre_att','all']: - raise ValueError("insert_mode must be pre_att or all") +############# +# add locon # +############# +def add_locon(input_model, strand_pair, conv_select=None, rank=4, alpha=1): - model = tf.keras.Model(inputs=input_model.input, - outputs=input_model.layers[-2].output) # remove the switch_reverse layer + # first add lora to attention + add_lora(input_model, report_param=False) - # save current graph - layer_parent_dict_old = keras2dict(model) + # decide: + # 1. whether conv1 is trainable + # 2. which conv layers to add loRA - layer_output_dict_new = {} # the output tensor of each layer in the new graph - layer_output_dict_new.update({model.layers[0].name: model.input}) + # all conv layers + conv_layers = [] + for layer in input_model.layers: + if re.match('conv1d', layer.name): + conv_layers += [layer.name] + + if conv_select is None: + conv_select = len(conv_layers) + + if conv_select > len(conv_layers): + raise ValueError("conv_select must be less than number of conv layers %d."%len(conv_layers)) + + locon_layers = [] + conv1_tune = False + if conv_select == len(conv_layers): + locon_layers = conv_layers[1:] + conv1_tune = True + else: + locon_layers = conv_layers[-conv_select:] + + layer_parent_dict_old = keras2dict(input_model) - # remove switch_reverse + # remove switch_reverse_layer to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] for i in to_fix: del layer_parent_dict_old[i] + + # create new graph + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({input_model.layers[0].name: input_model.input}) # Iterate over all layers after the input model_outputs = [] reverse_bool = None + for layer in input_model.layers[1:-1]: - for layer in model.layers[1:]: - - # parent layers + # get layer inputs parent_layers = layer_parent_dict_old[layer.name] - - # layer inputs layer_input = [layer_output_dict_new[parent] for parent in parent_layers] if len(layer_input) == 1: layer_input = layer_input[0] - if layer.name.startswith("stochastic_reverse_complement"): + # construct + if re.match('stochastic_reverse_complement', layer.name): x, reverse_bool = layer(layer_input) - - # insert squeeze-excite layer: - elif layer.name.startswith("conv1d"): - se_layer = 
layers.SqueezeExcite( - activation=None, # no activation before squeezing - additive=False, # use sigmoid multiplicative scaling - bottleneck_ratio=bottleneck_ratio, # bottleneck ratio - use_bias=False, # ignore bias - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization - scale_fun='tanh' - ) - x = layer(layer_input) - x = x + se_layer(x) - - elif layer.name.startswith("separable_conv1d"): - if insert_mode=='all': - se_layer = layers.SqueezeExcite( - activation=None, # no activation before squeezing - additive=False, # use sigmoid multiplicative scaling - bottleneck_ratio=bottleneck_ratio, # bottleneck ratio - use_bias=False, # ignore bias - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization - scale_fun='tanh' - ) - x = layer(layer_input) - x = x + se_layer(x) - else: - x = layer(layer_input) - + elif layer.name in locon_layers: + x = layers.Locon(layer, trainable=True, rank=rank, alpha=alpha)(layer_input) else: x = layer(layer_input) - # save the output tensor of every layer + # save layers to dictionary layer_output_dict_new.update({layer.name: x}) - final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[model.layers[-1].name], reverse_bool]) - model_final = tf.keras.Model(inputs=model.inputs, outputs=final) - - # unfreeze layers - for l in model_final.layers: # set trunk - if l.name.startswith("squeeze_excite"): l.trainable = True + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) + model_adapter = tf.keras.Model(inputs=input_model.inputs, outputs=final) - if unfreeze_bn: - for l in model_final.layers: - if l.name.startswith("batch_normalization"): l.trainable=True + if conv1_tune: + model_adapter.get_layer(name=conv_layers[0]).trainable = True # expected number of trainable params added/unfrozen: params_added = 0 - for l in model_final.layers: - if l.name.startswith("squeeze_excite"): - params_added += param_count(l) - elif l.name.startswith("batch_normalization"): - if unfreeze_bn: params_added += param_count(l, type='trainable') - print('params added/unfrozen by se_block: %d'%params_added) - - return model_final + if conv1_tune: + params_added += param_count(model_adapter.get_layer(name=conv_layers[0])) + for l in model_adapter.layers: + if re.match('multihead_attention', l.name): + params_added += param_count(l._q_layer.down_layer) + params_added += param_count(l._q_layer.up_layer) + params_added += param_count(l._v_layer.down_layer) + params_added += param_count(l._v_layer.up_layer) + if l.name in locon_layers: + params_added += param_count(l.down_layer) + params_added += param_count(l.up_layer) + + print('params added/unfrozen by lora: %d'%params_added) + return model_adapter -def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, bottleneck_ratio=8, insert_mode='pre_att', unfreeze_bn=False): +#### functions to merge locon +def lora_increment(layer): + down_weights = layer.down_layer.kernel + up_weights = layer.up_layer.kernel + increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * layer.scale + return increment_weights + +def locon_increment(layer): + down_weights = layer.down_layer.kernel + up_weights = layer.up_layer.kernel[0] + increment_weights = tf.einsum("abc,cd->abd", down_weights, up_weights) * layer.scale + return increment_weights + +def merge_locon(original_model, locon_model): + # original model contains pre-trained weights + for i, layer in 
enumerate(original_model.layers): + + # lora layers + if re.match('multihead_attention', layer.name): + q = locon_model.layers[i]._q_layer + k = locon_model.layers[i]._k_layer + v = locon_model.layers[i]._v_layer + e = locon_model.layers[i]._embedding_layer + if isinstance(q, layers.Lora): + increment_weights = lora_increment(q) + layer._q_layer.kernel.assign_add(increment_weights) + if isinstance(v, layers.Lora): + increment_weights = lora_increment(v) + layer._v_layer.kernel.assign_add(increment_weights) + if isinstance(k, layers.Lora): + increment_weights = lora_increment(k) + layer._k_layer.kernel.assign_add(increment_weights) + if isinstance(e, layers.Lora): + increment_weights = lora_increment(e) + layer._embedding_layer.kernel.assign_add(increment_weights) + + # locon layers + elif isinstance(locon_model.layers[i], layers.Locon): + increment_weights = locon_increment(locon_model.layers[i]) + layer.kernel.assign_add(increment_weights) + + else: + layer.set_weights(locon_model.layers[i].get_weights()) + + +############## +# houlsby_se # +############## +def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, conv_select=None, se_rank=16): # add squeeze-excitation blocks after conv # input_model should be properly frozen # pre_att: add se_block to pre-attention conv1d # all: add se_block to pre-attention conv1d and post-attention separable_conv1d - - if insert_mode not in ['pre_att','all']: - raise ValueError("insert_mode must be pre_att or all") - model = tf.keras.Model(inputs=input_model.input, - outputs=input_model.layers[-2].output) # remove the switch_reverse layer - - # save current graph - layer_parent_dict_old = keras2dict(model) - - layer_output_dict_new = {} # the output tensor of each layer in the new graph - layer_output_dict_new.update({model.layers[0].name: model.input}) - - # remove switch_reverse + ################## + # houlsby layers # + ################## + houlsby_layers = [] + for i in range(len(input_model.layers)-1): + layer = input_model.layers[i] + next_layer = input_model.layers[i+1] + if re.match('dropout', layer.name) and re.match('add', next_layer.name): + houlsby_layers += [next_layer.name] + + ############# + # SE layers # + ############# + conv_layers = [] + for layer in input_model.layers: + if re.match('conv1d', layer.name): + conv_layers += [layer.name] + if conv_select is None: + se_layers = conv_layers[1:] + if conv_select >= len(conv_layers): + raise ValueError("conv_select must be less than number of conv layers %d."%len(conv_layers)) + se_layers = conv_layers[-conv_select:] + + ################### + # construct model # + ################### + layer_parent_dict_old = keras2dict(input_model) + # remove switch_reverse_layer to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] for i in to_fix: del layer_parent_dict_old[i] - + # create new graph + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({input_model.layers[0].name: input_model.input}) # Iterate over all layers after the input model_outputs = [] reverse_bool = None - for layer in model.layers[1:]: + for layer in input_model.layers[1:-1]: # parent layers parent_layers = layer_parent_dict_old[layer.name] @@ -505,42 +584,24 @@ def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, bottleneck_ratio= if layer.name.startswith("stochastic_reverse_complement"): x, reverse_bool = layer(layer_input) - # insert houlsby: - elif re.match('add', layer.name): - if any([re.match('dropout', i) for i in 
parent_layers]): - print('adapter added before:%s'%layer.name) - x = layers.AdapterHoulsby(latent_size=houlsby_latent)(layer_input[1]) - x = layer([layer_input[0], x]) - else: - x = layer(layer_input) + # insert houlsby layer: + elif layer.name in houlsby_layers: + print('adapter added before:%s'%layer.name) + x = layers.AdapterHoulsby(latent_size=houlsby_latent)(layer_input[1]) + x = layer([layer_input[0], x]) # insert squeeze-excite layer: - elif layer.name.startswith("conv1d"): + elif layer.name in se_layers: se_layer = layers.SqueezeExcite( activation=None, # no activation before squeezing additive=False, # use sigmoid multiplicative scaling - bottleneck_ratio=bottleneck_ratio, # bottleneck ratio + rank=se_rank, # bottleneck ratio use_bias=False, # ignore bias kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization scale_fun='tanh' ) x = layer(layer_input) x = x + se_layer(x) - - elif layer.name.startswith("separable_conv1d"): - if insert_mode=='all': - se_layer = layers.SqueezeExcite( - activation=None, # no activation before squeezing - additive=False, # use sigmoid multiplicative scaling - bottleneck_ratio=bottleneck_ratio, # bottleneck ratio - use_bias=False, # ignore bias - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization - scale_fun='tanh' - ) - x = layer(layer_input) - x = x + se_layer(x) - else: - x = layer(layer_input) else: x = layer(layer_input) @@ -548,8 +609,8 @@ def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, bottleneck_ratio= # save the output tensor of every layer layer_output_dict_new.update({layer.name: x}) - final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[model.layers[-1].name], reverse_bool]) - model_final = tf.keras.Model(inputs=model.inputs, outputs=final) + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[input_model.layers[-2].name], reverse_bool]) + model_final = tf.keras.Model(inputs=input_model.inputs, outputs=final) # set trainable for l in model_final.layers[:-2]: # trunk @@ -561,22 +622,14 @@ def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, bottleneck_ratio= for l in model_final.layers: # set trunk if l.name.startswith("squeeze_excite"): l.trainable = True - if unfreeze_bn: - for l in model_final.layers: - if l.name.startswith("batch_normalization"): l.trainable=True - # expected number of trainable params added/unfrozen: params_added = 0 for l in model_final.layers: - if l.name.startswith("squeeze_excite"): - params_added += param_count(l) - elif l.name.startswith("batch_normalization"): - if unfreeze_bn: params_added += param_count(l, type='trainable') - elif l.name.startswith("adapter_houlsby"): + if re.match('squeeze_excite|adapter_houlsby', l.name): params_added += param_count(l) elif l.name.startswith("layer_normalization"): params_added += param_count(l, type='trainable') - print('params added/unfrozen by se_block: %d'%params_added) + print('params added/unfrozen by houlsby_se: %d'%params_added) return model_final -''' + diff --git a/src/baskerville/layers.py b/src/baskerville/layers.py index 2bfb5cc..c8acef3 100644 --- a/src/baskerville/layers.py +++ b/src/baskerville/layers.py @@ -149,7 +149,7 @@ def __init__(self, use_bias=False, kernel_initializer=tf.keras.initializers.HeUniform(), #kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1 / self.rank), - trainable=trainable, + trainable=True, name="lora_a" ) @@ -157,7 +157,7 @@ def __init__(self, units=self.output_dim, 
use_bias=False, kernel_initializer=tf.keras.initializers.Zeros(), - trainable=trainable, + trainable=True, name="lora_b" ) @@ -176,6 +176,83 @@ def get_config(self): ) return config +class Locon(tf.keras.layers.Layer): + # LoRA for conv-layer, adapted from: + # https://arxiv.org/pdf/2309.14859#page=23.84 + # https://github.com/KohakuBlueleaf/LyCORIS/blob/main/lycoris/modules/locon.py + # use default alpha and rank for locon + + def __init__(self, + original_layer, + rank=4, + alpha=1, + trainable=False, + **kwargs): + + # keep the name of this layer the same as the original conv layer. + original_layer_config = original_layer.get_config() + name = original_layer_config["name"] + kwargs.pop("name", None) + super().__init__(name=name, trainable=trainable, **kwargs) + + self.input_dim = original_layer.input_shape[-1] + self.output_dim = original_layer_config["filters"] + + if rank > self.output_dim: + raise ValueError(f"LoRA rank {rank} must be less or equal than {self.output_dim}") + + self.rank = rank + self.alpha = alpha + self.scale = alpha / rank + self.original_layer = original_layer + self.original_layer.trainable = False + + input_dim = original_layer.input_shape[-1] + output_dim = original_layer_config["filters"] + kernel_size = original_layer_config['kernel_size'][0] + stride = original_layer_config['strides'][0] + dilation_rate = original_layer_config["dilation_rate"][0] + + # Note: the original paper mentions that normal distribution was + # used for initialization. However, the official LoRA implementation + # uses "Kaiming/He Initialization". + + self.down_layer = tf.keras.layers.Conv1D( + filters=rank, + kernel_size=kernel_size, + strides=stride, + padding="same", + use_bias=False, + dilation_rate=dilation_rate, + kernel_initializer=tf.keras.initializers.HeUniform(), + name='locon_down' + ) + + self.up_layer = tf.keras.layers.Conv1D( + filters=output_dim, + kernel_size=1, + strides=stride, + padding="same", + use_bias=False, + kernel_initializer=tf.keras.initializers.Zeros(), + name='locon_up' + ) + + def call(self, inputs): + original_output = self.original_layer(inputs) + lora_output = self.up_layer(self.down_layer(inputs)) * self.scale + return original_output + lora_output + + def get_config(self): + config = super().get_config().copy() + config.update( + { + "rank": self.rank, + "alpha": self.alpha + } + ) + return config + class AdapterHoulsby(tf.keras.layers.Layer): # https://arxiv.org/abs/1902.00751 # adapted from: https://github.com/jain-harshil/Adapter-BERT @@ -227,7 +304,6 @@ def get_config(self): # Basic ############################################################ - class Scale(tf.keras.layers.Layer): """Scale the input by a learned value. @@ -678,7 +754,8 @@ def call(self, inputs, training=False): q *= self._key_size**-0.5 # [B, H, T', T] - content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True) + #content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True) + content_logits = tf.matmul(q + tf.cast(self._r_w_bias, dtype=inputs.dtype), k, transpose_b=True) if self._num_position_features == 0: logits = content_logits @@ -714,10 +791,12 @@ def call(self, inputs, training=False): # Add shifted relative logits to content logits. 
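            # Note (assumption, based on the mixed_float16 policy enabled in hound_transfer.py):
            # under mixed precision Keras keeps variables such as _r_w_bias and _r_r_bias in
            # float32 while q, k, and r_k are computed in float16, so casting the biases to
            # inputs.dtype here prevents a dtype mismatch in these matmuls.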
if self._content_position_bias: # [B, H, T', 2T-1] - relative_logits = tf.matmul(q + self._r_r_bias, r_k, transpose_b=True) + #relative_logits = tf.matmul(q + self._r_r_bias, r_k, transpose_b=True) + relative_logits = tf.matmul(q + tf.cast(self._r_r_bias, dtype=inputs.dtype), r_k, transpose_b=True) else: # [1, H, 1, 2T-1] - relative_logits = tf.matmul(self._r_r_bias, r_k, transpose_b=True) + #relative_logits = tf.matmul(self._r_r_bias, r_k, transpose_b=True) + relative_logits = tf.matmul(tf.cast(self._r_r_bias, dtype=inputs.dtype), r_k, transpose_b=True) # [1, H, T', 2T-1] relative_logits = tf.broadcast_to( relative_logits, @@ -804,7 +883,7 @@ def __init__( self, activation='relu', additive=False, - bottleneck_ratio=8, + rank=8, norm_type=None, bn_momentum=0.9, use_bias=True, @@ -817,7 +896,7 @@ def __init__( self.additive = additive self.norm_type = norm_type self.bn_momentum = bn_momentum - self.bottleneck_ratio = bottleneck_ratio + self.rank = rank self.kernel_initializer=kernel_initializer self.bias_initializer=bias_initializer self.use_bias=use_bias @@ -851,7 +930,7 @@ def build(self, input_shape): exit(1) self.dense1 = tf.keras.layers.Dense( - units=self.num_channels // self.bottleneck_ratio, + units=self.rank, activation="relu", use_bias=self.use_bias, kernel_initializer=self.kernel_initializer, @@ -900,8 +979,7 @@ def get_config(self): "use_bias":self.use_bias, "norm_type": self.norm_type, "bn_momentum": self.bn_momentum, - "bottleneck_ratio": self.bottleneck_ratio, - 'bottleneck_size': self.num_channels // self.bottleneck_ratio, + "rank": self.rank } ) return config diff --git a/src/baskerville/scripts/hound_transfer.py b/src/baskerville/scripts/hound_transfer.py index 85864f3..06f184c 100755 --- a/src/baskerville/scripts/hound_transfer.py +++ b/src/baskerville/scripts/hound_transfer.py @@ -28,7 +28,7 @@ from baskerville import seqnn from baskerville import trainer from baskerville import layers -from baskerville import transfer_helper +from baskerville.helpers import transfer_helper """ hound_transfer.py @@ -79,27 +79,38 @@ def main(): "--att_adapter", default=None, type=str, - help="attention layer module [adapterHoulsby, lora, lora_full, ia3]", + help="attention layer module [adapterHoulsby, lora, lora_full, ia3, locon]", ) parser.add_argument( "--att_latent", type=int, - default=16, + default=8, help="attention adapter latent size.", - ) - parser.add_argument( - "--conv_adapter", - default=None, - type=str, - help="conv layer module [conv, bn, conv_bn, squez_excit]", ) - parser.add_argument( - "--se_ratio", + "--lora_alpha", type=int, default=16, - help="se bottleneck ratio.", + help="lora alpha.", ) + parser.add_argument( + "--conv_select", + default=None, + type=int, + help="# of conv layers to insert locon/se.", + ) + parser.add_argument( + "--conv_rank", + type=int, + default=4, + help="locon/se rank.", + ) + parser.add_argument( + "--locon_alpha", + type=int, + default=1, + help="locon_alpha.", + ) parser.add_argument( "--tfr_train", default=None, @@ -171,8 +182,7 @@ def main(): params_model["strand_pair"] = strand_pairs if args.mixed_precision: - policy = mixed_precision.Policy('mixed_float16') - mixed_precision.set_global_policy(policy) + mixed_precision.set_global_policy('mixed_float16') if params_train.get("num_gpu", 1) == 1: ######################################## @@ -206,127 +216,58 @@ def main(): # attention adapter if args.att_adapter is not None: if args.att_adapter=='adapterHoulsby': - if args.conv_adapter not in ['se', 'se_bn', 'se_all','se_all_bn']: - # when 
att_adapter=='Houlsby' and conv_adapter=='se', do nothing. - # see conv_adapter section. - seqnn_model.model = transfer_helper.add_houlsby(seqnn_model.model, - strand_pairs[0], - latent_size=args.att_latent) + seqnn_model.model = transfer_helper.add_houlsby(seqnn_model.model, + strand_pairs[0], + latent_size=args.att_latent) elif args.att_adapter=='lora': transfer_helper.add_lora(seqnn_model.model, rank=args.att_latent, + alpha=args.lora_alpha, mode='default') elif args.att_adapter=='lora_full': transfer_helper.add_lora(seqnn_model.model, rank=args.att_latent, + alpha=args.lora_alpha, mode='full') elif args.att_adapter=='ia3': - seqnn_model.model = transfer_helper.add_ia3(seqnn_model.model, strand_pairs[0]) - - ''' - # conv adapter - # assume seqnn_model is appropriately frozen - if args.conv_adapter is not None: - if args.conv_adapter=='conv': - params_added = 0 - for l in seqnn_model.model.layers: - if l.name.startswith(("conv1d","separable_conv1d")): - l.trainable=True - params_added += transfer_helper.param_count(l, type='trainable') - print('params added/unfrozen by conv: %d'%params_added) - - elif args.conv_adapter=='conv_bn': - params_added = 0 - for l in seqnn_model.model.layers: - if l.name.startswith(("conv1d","separable_conv1d","batch_normalization")): - l.trainable=True - params_added += transfer_helper.param_count(l, type='trainable') - print('params added/unfrozen by conv_bn: %d'%params_added) - - elif args.conv_adapter=='bn': - params_added = 0 - for l in seqnn_model.model.layers: - if l.name.startswith("batch_normalization"): - l.trainable=True - params_added += transfer_helper.param_count(l, type='trainable') - print('params added/unfrozen by bn: %d'%params_added) - - ################## - # squeeze-excite # - ################## - elif args.conv_adapter in ['se','se_bn','se_all','se_all_bn']: - if args.att_adapter=='adapterHoulsby': - if args.conv_adapter=='se': - seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, - strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - bottleneck_ratio=args.se_ratio, - insert_mode='pre_att', - unfreeze_bn=False) - elif args.conv_adapter=='se_bn': - seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, - strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - bottleneck_ratio=args.se_ratio, - insert_mode='pre_att', - unfreeze_bn=True) - elif args.conv_adapter=='se_all': - seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, - strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - bottleneck_ratio=args.se_ratio, - insert_mode='all', - unfreeze_bn=False) - elif args.conv_adapter=='se_all_bn': - seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, - strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - bottleneck_ratio=args.se_ratio, - insert_mode='all', - unfreeze_bn=True) - else: - if args.conv_adapter=='se': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, - strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - bottleneck_ratio=args.se_ratio, - insert_mode='pre_att', - unfreeze_bn=False) - elif args.conv_adapter=='se_bn': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, - strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - bottleneck_ratio=args.se_ratio, - insert_mode='pre_att', - unfreeze_bn=True) - elif args.conv_adapter=='se_all': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, - strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - 
bottleneck_ratio=args.se_ratio, - insert_mode='all', - unfreeze_bn=False) - elif args.conv_adapter=='se_all_bn': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, + seqnn_model.model = transfer_helper.add_ia3(seqnn_model.model, + strand_pairs[0]) + + elif args.att_adapter=='locon': # lora on conv+att + seqnn_model.model = transfer_helper.add_locon(seqnn_model.model, + strand_pairs[0], + conv_select=args.conv_select, + rank=args.conv_rank, + alpha=args.locon_alpha) + + elif args.att_adapter=='lora_conv': # lora on att, unfreeze_conv + transfer_helper.add_lora_conv(seqnn_model.model, conv_select=args.conv_select) + + elif args.att_adapter=='houlsby_se': # adapter on conv+att + seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, strand_pair=strand_pairs[0], - houlsby_latent=args.att_latent, - bottleneck_ratio=args.se_ratio, - insert_mode='pre_att', - unfreeze_bn=True) - ''' + conv_select=args.conv_select, + se_rank=args.conv_rank) ################# # final summary # ################# seqnn_model.model.summary() - - # initialize trainer - seqnn_trainer = trainer.Trainer( - params_train, train_data, eval_data, args.out_dir - ) - + + if args.mixed_precision: + # add additional activation to cast float16 output to float32 + seqnn_model.append_activation() + # run with loss scaling + seqnn_trainer = trainer.Trainer( + params_train, train_data, eval_data, args.out_dir, loss_scale=True + ) + else: + seqnn_trainer = trainer.Trainer( + params_train, train_data, eval_data, args.out_dir + ) + # compile model seqnn_trainer.compile(seqnn_model) @@ -344,31 +285,28 @@ def main(): ############################# if args.transfer_mode=='sparse': - # overwrite json file when needed - # for: adapterHoulsby and squeeze-excite - transfer_helper.modify_json(input_json=args.params_file, - output_json='%s/params.json'%args.out_dir, - adapter=args.att_adapter, - latent=args.att_latent, - conv=args.conv_adapter, - se_ratio=args.se_ratio) - - # merge weights when needed - # for: lora and ia3 - # save weight to: model_best.mergeW.h5 - if args.att_adapter=='lora': - seqnn_model.model.load_weights('%s/model_best.h5'%args.out_dir) - transfer_helper.merge_lora(seqnn_model.model, mode='default') - seqnn_model.save('%s/model_best.mergeW.h5'%args.out_dir) - transfer_helper.var_reorder('%s/model_best.mergeW.h5'%args.out_dir) + # for: adapterHoulsby and houlsby_se, overwrite json file + if args.att_adapter=='adapterHoulsby': + transfer_helper.modify_json(input_json=args.params_file, + output_json='%s/params.json'%args.out_dir, + adapter=args.att_adapter, + latent=args.att_latent) + + if args.att_adapter=='houlsby_se': + transfer_helper.modify_json(input_json=args.params_file, + output_json='%s/params.json'%args.out_dir, + adapter=args.att_adapter, + conv_select=args.conv_select, + se_rank=args.conv_rank + ) - if args.att_adapter=='lora_full': + # for lora, ia3, locon, save weight to: model_best.mergeW.h5 + if args.att_adapter in ['lora', 'lora_full', 'lora_conv']: seqnn_model.model.load_weights('%s/model_best.h5'%args.out_dir) - transfer_helper.merge_lora(seqnn_model.model, mode='full') + transfer_helper.merge_lora(seqnn_model.model) seqnn_model.save('%s/model_best.mergeW.h5'%args.out_dir) transfer_helper.var_reorder('%s/model_best.mergeW.h5'%args.out_dir) - - # merge ia3 weights to original, save weight to: model_best_mergeweight.h5 + if args.att_adapter=='ia3': # ia3 model ia3_model = seqnn_model.model @@ -381,6 +319,18 @@ def main(): transfer_helper.merge_ia3(original_model, ia3_model) 
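            # original_model now holds the fine-tuned IA3 weights merged into the pre-trained
            # model, so the save below writes model_best.mergeW.h5, which can be reloaded with
            # the standard (adapter-free) architecture.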
original_model.save('%s/model_best.mergeW.h5'%args.out_dir) + if args.att_adapter=='locon': + # locon model + locon_model = seqnn_model.model + locon_model.load_weights('%s/model_best.h5'%args.out_dir) + # original model + seqnn_model2 = seqnn.SeqNN(params_model) + seqnn_model2.restore(args.restore, trunk=args.trunk) + original_model = seqnn_model2.model + # merge weights into original model + transfer_helper.merge_locon(original_model, locon_model) + original_model.save('%s/model_best.mergeW.h5'%args.out_dir) + else: ######################################## # multi GPU diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 1ffca86..82db788 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -25,7 +25,7 @@ from baskerville import dataset from baskerville import layers from baskerville import metrics - +from baskerville.helpers import transfer_helper class SeqNN: """Sequence neural network model. @@ -198,6 +198,13 @@ def build_model(self, save_reprs: bool = True): for ho in self.head_output: self.models.append(tf.keras.Model(inputs=sequence, outputs=ho)) self.model = self.models[0] + + # add adapter + if hasattr(self, 'adapter'): + for hi, head in enumerate(self.heads): + self.models[hi] = self.insert_adapter(self.models[hi]) + self.model = self.models[0] + if self.verbose: print(self.model.summary()) @@ -1093,3 +1100,18 @@ def track_sequence(self, sequence): print("model_strides", self.model_strides) print("target_lengths", self.target_lengths) print("target_crops", self.target_crops) + + # method for inserting adapter for transfer learning + def insert_adapter(self, model): + if self.adapter=='houlsby': + output_model = transfer_helper.add_houlsby(model, + self.strand_pair[0], + latent_size=self.adapter_latent) + elif self.adapter=='houlsby_se': + output_model = transfer_helper.add_houlsby_se(model, + self.strand_pair[0], + houlsby_latent=self.adapter_latent, + conv_select=self.conv_select, + se_rank=self.se_rank) + return output_model + diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py index d7c048e..3136642 100644 --- a/src/baskerville/trainer.py +++ b/src/baskerville/trainer.py @@ -531,22 +531,40 @@ def fit_tape(self, seqnn_model): if self.strategy is None: - @tf.function - def train_step(x, y): - with tf.GradientTape() as tape: - pred = model(x, training=True) - loss = self.loss_fn(y, pred) + sum(model.losses) - train_loss(loss) - train_r(y, pred) - train_r2(y, pred) - gradients = tape.gradient(loss, model.trainable_variables) - if self.agc_clip is not None: - gradients = adaptive_clip_grad( - model.trainable_variables, gradients, self.agc_clip + if self.loss_scale: + + @tf.function + def train_step(x, y): + with tf.GradientTape() as tape: + pred = model(x, training=True) + loss = self.loss_fn(y, pred) + sum(model.losses) + scaled_loss = self.optimizer.get_scaled_loss(loss) + train_loss(loss) + train_r(y, pred) + train_r2(y, pred) + scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables) + gradients = self.optimizer.get_unscaled_gradients(scaled_gradients) + self.optimizer.apply_gradients( + zip(gradients, model.trainable_variables) + ) + else: + + @tf.function + def train_step(x, y): + with tf.GradientTape() as tape: + pred = model(x, training=True) + loss = self.loss_fn(y, pred) + sum(model.losses) + train_loss(loss) + train_r(y, pred) + train_r2(y, pred) + gradients = tape.gradient(loss, model.trainable_variables) + if self.agc_clip is not None: + gradients = adaptive_clip_grad( + model.trainable_variables, 
gradients, self.agc_clip + ) + self.optimizer.apply_gradients( + zip(gradients, model.trainable_variables) ) - self.optimizer.apply_gradients( - zip(gradients, model.trainable_variables) - ) @tf.function def eval_step(x, y): diff --git a/tests/test_transfer/test_ia3.ipynb b/tests/test_transfer/test_ia3.ipynb index b51fdd0..7ac40ba 100644 --- a/tests/test_transfer/test_ia3.ipynb +++ b/tests/test_transfer/test_ia3.ipynb @@ -24,7 +24,7 @@ "import tensorflow as tf\n", "from baskerville import seqnn\n", "from baskerville import layers\n", - "from baskerville import transfer_helper" + "from baskerville.helpers import transfer_helper" ] }, {
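For reference, the weight merges performed above by merge_lora_layer and locon_increment both amount to adding a scaled low-rank product back into the frozen kernel. A minimal standalone sketch with illustrative shapes (plain TensorFlow, not part of the patch; the alpha / rank scale is how the Locon layer defines .scale):

    import tensorflow as tf

    # Dense case (merge_lora_layer): kernel W is (d_in, d_out), LoRA factors are
    # A = down_layer.kernel (d_in, r) and B = up_layer.kernel (r, d_out).
    d_in, d_out, rank, alpha = 768, 768, 8, 16
    W = tf.Variable(tf.random.normal([d_in, d_out]))
    A = tf.random.normal([d_in, rank]) * 1e-2      # near-zero init, as in the patch
    B = tf.zeros([rank, d_out])                    # zero init, so the adapter starts as a no-op
    W.assign_add(tf.einsum("ab,bc->ac", A, B) * (alpha / rank))

    # Conv1D case (locon_increment): the down kernel is (kernel_size, d_in, r) and the
    # 1x1 up kernel is (1, r, d_out); indexing up_kernel[0] gives (r, d_out).
    ksize, lrank, lalpha = 5, 4, 1
    K = tf.Variable(tf.random.normal([ksize, d_in, d_out]))
    A_conv = tf.random.normal([ksize, d_in, lrank]) * 1e-2
    B_conv = tf.zeros([lrank, d_out])
    K.assign_add(tf.einsum("abc,cd->abd", A_conv, B_conv) * (lalpha / lrank))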