Commit: lora

hy395 committed Sep 21, 2023
1 parent 6e79986 commit ae31d2f
Showing 3 changed files with 121 additions and 4 deletions.
69 changes: 68 additions & 1 deletion src/baskerville/layers.py
@@ -26,8 +26,75 @@
#####################
# transfer learning #
#####################

class Lora(tf.keras.layers.Layer):
    # https://arxiv.org/abs/2106.09685
    # adapted from:
    # https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/
    # https://github.com/Elvenson/stable-diffusion-keras-ft/blob/main/layers.py

    def __init__(self,
                 original_layer,
                 rank=8,
                 alpha=16,
                 trainable=True,
                 **kwargs):

        # keep the name of this layer the same as the original dense layer.
        original_layer_config = original_layer.get_config()
        name = original_layer_config["name"]
        kwargs.pop("name", None)
        super().__init__(name=name, trainable=trainable, **kwargs)

        self.output_dim = original_layer_config["units"]

        if rank > self.output_dim:
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {self.output_dim}")

        self.rank = rank
        self.alpha = alpha
        self.scale = alpha / rank
        self.original_layer = original_layer
        self.original_layer.trainable = False

        # Note: the original paper mentions that a normal distribution was
        # used for initialization. However, the official LoRA implementation
        # uses Kaiming/He initialization.
        self.down_layer = tf.keras.layers.Dense(
            units=rank,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.HeUniform(),
            # kernel_initializer=tf.keras.initializers.RandomNormal(stddev=1 / self.rank),
            trainable=trainable,
            name="lora_a"
        )

        self.up_layer = tf.keras.layers.Dense(
            units=self.output_dim,
            use_bias=False,
            kernel_initializer=tf.keras.initializers.Zeros(),
            trainable=trainable,
            name="lora_b"
        )

    def call(self, inputs):
        original_output = self.original_layer(inputs)
        lora_output = self.up_layer(self.down_layer(inputs)) * self.scale
        return original_output + lora_output

    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "rank": self.rank,
                "alpha": self.alpha
            }
        )
        return config
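
As a minimal usage sketch (illustrative only, not part of this diff; the layer size and name are assumptions), wrapping a Dense layer freezes its kernel and adds only the rank-r "lora_a"/"lora_b" kernels as trainable weights; the call returns original_layer(x) + (alpha / rank) * up(down(x)):

import tensorflow as tf

dense = tf.keras.layers.Dense(units=512, name="example_dense")  # hypothetical layer to adapt
lora_dense = Lora(dense, rank=8, alpha=16)
_ = lora_dense(tf.zeros((1, 512)))  # builds lora_a (512x8) and lora_b (8x512)

# the original 512x512 kernel stays frozen; only the LoRA kernels are trainable
print([w.name for w in lora_dense.trainable_weights])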

class AdapterHoulsby(tf.keras.layers.Layer):
    ### Houlsby et al. 2019 implementation
    # https://arxiv.org/abs/1902.00751
    # adapted from: https://github.com/jain-harshil/Adapter-BERT

    def __init__(
        self,
31 changes: 30 additions & 1 deletion src/baskerville/scripts/hound_train.py
@@ -71,7 +71,7 @@ def main():
    parser.add_argument(
        "--transfer_mode",
        default="full",
        help="transfer method. [full, linear, adapter]",
        help="transfer method. [full, linear, adapterHoulsby, lora, lora_full]",
    )
    parser.add_argument(
        "--latent",
@@ -185,6 +185,16 @@
        seqnn_model.model_trunk = None
        seqnn_model.model.summary()

    elif args.transfer_mode=='lora':
        seqnn_model.model_trunk.trainable=False
        add_lora(seqnn_model.model, rank=args.latent, mode='default')
        seqnn_model.model.summary()

    elif args.transfer_mode=='lora_full':
        seqnn_model.model_trunk.trainable=False
        add_lora(seqnn_model.model, rank=args.latent, mode='full')
        seqnn_model.model.summary()
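
    # Editorial sketch, not part of this commit: these branches are typically reached with
    # a command along the lines of
    #   hound_train.py --restore pretrain.h5 --trunk --transfer_mode lora --latent 8 params.json data_dir
    # where --latent supplies the LoRA rank forwarded to add_lora(); the flag values are illustrative.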

    # initialize trainer
    seqnn_trainer = trainer.Trainer(
        params_train, train_data, eval_data, args.out_dir
@@ -312,6 +322,25 @@ def make_adapter_model(input_model, strand_pair, latent_size=16):
        if re.match('layer_normalization', l.name): l.trainable = True

    return model_adapter

def add_lora(input_model, rank=8, alpha=16, mode='default'):
    # take seqnn.model as input
    # replace _q_layer, _v_layer in multihead_attention
    # optionally replace _k_layer, _embedding_layer
    if mode not in ['default', 'full']:
        raise ValueError("mode must be 'default' or 'full'")

    for layer in input_model.layers:
        if re.match('multihead_attention', layer.name):
            # default LoRA: query and value projections
            layer._q_layer = layers.Lora(layer._q_layer, rank=rank, alpha=alpha)
            layer._v_layer = layers.Lora(layer._v_layer, rank=rank, alpha=alpha)
            # full LoRA: also key and embedding projections
            if mode=='full':
                layer._k_layer = layers.Lora(layer._k_layer, rank=rank, alpha=alpha)
                layer._embedding_layer = layers.Lora(layer._embedding_layer, rank=rank, alpha=alpha)


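As an illustrative sketch of how add_lora is meant to be driven (mirroring the lora branch of main() above; the SeqNN constructor, restore signature, params layout, and file paths are assumptions, not confirmed by this commit):

import json
from baskerville import seqnn

with open("params.json") as params_open:          # placeholder params file
    params = json.load(params_open)

seqnn_model = seqnn.SeqNN(params["model"])        # assumed params layout
seqnn_model.restore("pretrain/f0c0.h5", trunk=True)  # placeholder checkpoint

seqnn_model.model_trunk.trainable = False         # freeze the pretrained trunk
add_lora(seqnn_model.model, rank=8, alpha=16, mode="default")  # LoRA on q/v projections
seqnn_model.model.summary()                       # LoRA kernels are the trainable weights
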
################################################################################
# __main__
################################################################################
25 changes: 23 additions & 2 deletions src/baskerville/scripts/westminster_train_folds_copy.py
@@ -62,6 +62,18 @@ def main():
        help='Evaluation TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]')
    parser.add_option_group(train_options)

    # transfer options
    transfer_options = OptionGroup(parser, 'transfer options')
    transfer_options.add_option('--transfer', dest='transfer',
        default=False, action='store_true',
        help='Whether to do transfer learning.')
    transfer_options.add_option('--pretrain', dest='pretrain',
        default=None, help='Path to pretrained model trunk.')
    transfer_options.add_option('--transfer_mode', dest='transfer_mode',
        default='linear', help='Transfer method.')
    transfer_options.add_option('--latent', dest='latent', type='int',
        default=0, help='Latent size.')

    # eval
    eval_options = OptionGroup(parser, 'hound_eval.py options')
    eval_options.add_option('--rank', dest='rank_corr',
@@ -87,7 +99,7 @@
        default=False, action='store_true',
        help='Restart training from checkpoint [Default: %default]')
    rep_options.add_option('-e', dest='conda_env',
        default='tf12',
        default='tf2.12',
        help='Anaconda environment [Default: %default]')
    rep_options.add_option('-f', dest='fold_subset',
        default=None, type='int',
@@ -175,7 +187,7 @@ def main():
        exit(0)

    cmd_source = 'source /home/yuanh/.bashrc;'
    hound_train = '/home/yuanh/programs/source/python_packages/baskerville/scripts/hound_train.py'
    hound_train = 'hound_train.py'
    #######################################################
    # train

@@ -205,6 +217,15 @@

            cmd += ' %s' % hound_train
            cmd += ' %s' % options_string(options, train_options, rep_dir)

            # transfer learning options
            if options.transfer:
                cmd += ' --restore %s/f%dc%d.h5' % (options.pretrain, fi, ci)
                cmd += ' --trunk'
                cmd += ' --transfer_mode %s' % options.transfer_mode
                if options.latent != 0:
                    cmd += ' --latent %d' % options.latent

            cmd += ' %s %s' % (params_file, ' '.join(rep_data_dirs))

            name = '%s-train-f%dc%d' % (options.name, fi, ci)
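For concreteness, a sketch of the cmd value these lines assemble for fold 0, cross 0 when --transfer is set (the conda/source prefix and the options_string output are elided, and the paths are assumptions, not from this commit):

# illustrative value of cmd for fi=0, ci=0
cmd = ('hound_train.py'
       ' --restore /path/to/pretrain/f0c0.h5 --trunk'
       ' --transfer_mode lora --latent 8'
       ' params.json train_out/f0c0/data0')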
