diff --git a/src/baskerville/layers.py b/src/baskerville/layers.py
index 1d28c2c..028a6a5 100644
--- a/src/baskerville/layers.py
+++ b/src/baskerville/layers.py
@@ -26,8 +26,75 @@
 #####################
 # transfer learning #
 #####################
+
+class Lora(tf.keras.layers.Layer):
+    # https://arxiv.org/abs/2106.09685
+    # adapted from:
+    # https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/
+    # https://github.com/Elvenson/stable-diffusion-keras-ft/blob/main/layers.py
+
+    def __init__(self,
+                 original_layer,
+                 rank=8,
+                 alpha=16,
+                 trainable=True,
+                 **kwargs):
+
+        # keep the name of this layer the same as the original dense layer
+        original_layer_config = original_layer.get_config()
+        name = original_layer_config["name"]
+        kwargs.pop("name", None)
+        super().__init__(name=name, trainable=trainable, **kwargs)
+
+        self.output_dim = original_layer_config["units"]
+
+        if rank > self.output_dim:
+            raise ValueError(f"LoRA rank {rank} must be less than or equal to {self.output_dim}")
+
+        self.rank = rank
+        self.alpha = alpha
+        self.scale = alpha / rank
+        self.original_layer = original_layer
+        self.original_layer.trainable = False
+
+        # Note: the original paper mentions that a normal distribution was
+        # used for initialization. However, the official LoRA implementation
+        # uses Kaiming/He initialization.
+        self.down_layer = tf.keras.layers.Dense(
+            units=rank,
+            use_bias=False,
+            kernel_initializer=tf.keras.initializers.HeUniform(),
+            # kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1 / self.rank),
+            trainable=trainable,
+            name="lora_a"
+        )
+
+        self.up_layer = tf.keras.layers.Dense(
+            units=self.output_dim,
+            use_bias=False,
+            kernel_initializer=tf.keras.initializers.Zeros(),
+            trainable=trainable,
+            name="lora_b"
+        )
+
+    def call(self, inputs):
+        original_output = self.original_layer(inputs)
+        lora_output = self.up_layer(self.down_layer(inputs)) * self.scale
+        return original_output + lora_output
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update(
+            {
+                "rank": self.rank,
+                "alpha": self.alpha
+            }
+        )
+        return config
+
 class AdapterHoulsby(tf.keras.layers.Layer):
-    ### Houlsby et al. 2019 implementation
+    # https://arxiv.org/abs/1902.00751
+    # adapted from: https://github.com/jain-harshil/Adapter-BERT
 
     def __init__(
         self,
diff --git a/src/baskerville/scripts/hound_train.py b/src/baskerville/scripts/hound_train.py
index d5a754d..871ff51 100755
--- a/src/baskerville/scripts/hound_train.py
+++ b/src/baskerville/scripts/hound_train.py
@@ -71,7 +71,7 @@ def main():
     parser.add_argument(
         "--transfer_mode",
         default="full",
-        help="transfer method. [full, linear, adapter]",
+        help="transfer method. [full, linear, adapterHoulsby, lora, lora_full]",
     )
     parser.add_argument(
         "--latent",
@@ -185,6 +185,16 @@ def main():
         seqnn_model.model_trunk = None
         seqnn_model.model.summary()
 
+    elif args.transfer_mode == 'lora':
+        seqnn_model.model_trunk.trainable = False
+        add_lora(seqnn_model.model, rank=args.latent, mode='default')
+        seqnn_model.model.summary()
+
+    elif args.transfer_mode == 'lora_full':
+        seqnn_model.model_trunk.trainable = False
+        add_lora(seqnn_model.model, rank=args.latent, mode='full')
+        seqnn_model.model.summary()
+
     # initialize trainer
     seqnn_trainer = trainer.Trainer(
         params_train, train_data, eval_data, args.out_dir
@@ -312,6 +322,25 @@ def make_adapter_model(input_model, strand_pair, latent_size=16):
         if re.match('layer_normalization', l.name):
             l.trainable = True
     return model_adapter
+
+def add_lora(input_model, rank=8, alpha=16, mode='default'):
+    # take seqnn.model as input
+    # replace _q_layer, _v_layer in multihead_attention
+    # optionally replace _k_layer, _embedding_layer
+    if mode not in ['default', 'full']:
+        raise ValueError("mode must be default or full")
+
+    for layer in input_model.layers:
+        if re.match('multihead_attention', layer.name):
+            # default LoRA
+            layer._q_layer = layers.Lora(layer._q_layer, rank=rank, alpha=alpha)
+            layer._v_layer = layers.Lora(layer._v_layer, rank=rank, alpha=alpha)
+            # full LoRA
+            if mode == 'full':
+                layer._k_layer = layers.Lora(layer._k_layer, rank=rank, alpha=alpha)
+                layer._embedding_layer = layers.Lora(layer._embedding_layer, rank=rank, alpha=alpha)
+
+
 ################################################################################
 # __main__
 ################################################################################
diff --git a/src/baskerville/scripts/westminster_train_folds_copy.py b/src/baskerville/scripts/westminster_train_folds_copy.py
index 6f27ec5..777d784 100755
--- a/src/baskerville/scripts/westminster_train_folds_copy.py
+++ b/src/baskerville/scripts/westminster_train_folds_copy.py
@@ -62,6 +62,18 @@ def main():
       help='Evaluation TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
   parser.add_option_group(train_options)
 
+  # transfer options
+  transfer_options = OptionGroup(parser, 'transfer options')
+  transfer_options.add_option('--transfer', dest='transfer',
+      default=False, action='store_true',
+      help='whether to do transfer learning.')
+  transfer_options.add_option('--pretrain', dest='pretrain',
+      default=None, help='path to pretrained model trunk.')
+  transfer_options.add_option('--transfer_mode', dest='transfer_mode',
+      default='linear', help='transfer method.')
+  transfer_options.add_option('--latent', dest='latent', type='int',
+      default=0, help='latent size.')
+
   # eval
   eval_options = OptionGroup(parser, 'hound_eval.py options')
   eval_options.add_option('--rank', dest='rank_corr',
@@ -87,7 +99,7 @@ def main():
       default=False, action='store_true',
       help='Restart training from checkpoint [Default: %default]')
   rep_options.add_option('-e', dest='conda_env',
-      default='tf12',
+      default='tf2.12',
       help='Anaconda environment [Default: %default]')
   rep_options.add_option('-f', dest='fold_subset',
       default=None, type='int',
@@ -175,7 +187,7 @@ def main():
     exit(0)
 
   cmd_source = 'source /home/yuanh/.bashrc;'
-  hound_train = '/home/yuanh/programs/source/python_packages/baskerville/scripts/hound_train.py'
+  hound_train = 'hound_train.py'
 
   #######################################################
   # train
@@ -205,6 +217,15 @@ def main():
       cmd += ' %s' % hound_train
       cmd += ' %s' % options_string(options, train_options, rep_dir)
+
+      # transfer learning options
+      if options.transfer:
+        cmd += ' --restore %s/f%dc%d.h5' % (options.pretrain, fi, ci)
+        cmd += ' --trunk'
+        cmd += ' --transfer_mode %s' % options.transfer_mode
+        if options.latent != 0:
+          cmd += ' --latent %d' % options.latent
+
       cmd += ' %s %s' % (params_file, ' '.join(rep_data_dirs))
 
       name = '%s-train-f%dc%d' % (options.name, fi, ci)
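
Usage note (not part of the patch): a minimal sketch of how the new Lora wrapper is expected to behave once the diff above is applied. Only the layers.Lora(original_layer, rank, alpha) signature and the lora_a/lora_b sub-layer names come from the patch; the wrapped Dense layer, its name, and the tensor shapes below are illustrative assumptions.

    import tensorflow as tf
    from baskerville import layers

    # stand-in for one of the frozen attention projections (e.g. _q_layer)
    dense = tf.keras.layers.Dense(units=64, name="q_proj")

    # wrap it: the Dense weights are frozen; a rank-8 update scaled by
    # alpha/rank = 2.0 is added on top of the original output
    lora_dense = layers.Lora(dense, rank=8, alpha=16)

    x = tf.random.normal((2, 16, 128))
    y = lora_dense(x)
    print(y.shape)  # (2, 16, 64), same shape as the original Dense output

    # only the two low-rank kernels (lora_a: 128x8, lora_b: 8x64) are trainable;
    # the wrapped Dense kernel/bias stay frozen
    print([w.name for w in lora_dense.trainable_weights])

Because lora_b is zero-initialized, the wrapped layer reproduces the frozen layer's output exactly at the start of fine-tuning. add_lora() in hound_train.py applies the same wrapping to _q_layer and _v_layer of every multihead_attention layer, and additionally to _k_layer and _embedding_layer in 'full' mode.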