From 337b0533f46a35fc141cf8b1a3d12d26b82158a3 Mon Sep 17 00:00:00 2001 From: hy395 Date: Tue, 19 Sep 2023 16:29:33 -0700 Subject: [PATCH] transfer_learn --- setup.cfg | 44 +- src/baskerville/HY_helper.py | 74 +++ src/baskerville/layers.py | 49 ++ src/baskerville/scripts/hound_train.py | 138 ++++- .../scripts/westminster_train_folds_copy.py | 509 ++++++++++++++++++ src/baskerville/trainer.py | 79 ++- 6 files changed, 853 insertions(+), 40 deletions(-) create mode 100644 src/baskerville/HY_helper.py create mode 100755 src/baskerville/scripts/westminster_train_folds_copy.py diff --git a/setup.cfg b/setup.cfg index 43fda39..0f1e198 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,31 +19,31 @@ package_dir = packages = find: python_requires = >=3.8, <3.11 install_requires = - h5py~=3.7.0 - intervaltree~=3.1.0 - joblib~=1.1.1 - matplotlib~=3.7.1 - google-cloud-storage~=2.0.0 - natsort~=7.1.1 - networkx~=2.8.4 - numpy~=1.24.3 - pandas~=1.5.3 - pybigwig~=0.3.18 - pysam~=0.21.0 - pybedtools~=0.9.0 - qnorm~=0.8.1 - seaborn~=0.12.2 - scikit-learn~=1.2.2 - scipy~=1.9.1 - statsmodels~=0.13.5 - tabulate~=0.8.10 - tensorflow~=2.12.0 - tqdm~=4.65.0 + h5py>=3.7.0 + intervaltree>=3.1.0 + joblib>=1.1.1 + matplotlib>=3.7.1 + google-cloud-storage>=2.0.0 + natsort>=7.1.1 + networkx>=2.8.4 + numpy>=1.24.3 + pandas>=1.5.3 + pybigwig>=0.3.18 + pysam>=0.21.0 + pybedtools>=0.9.0 + qnorm>=0.8.1 + seaborn>=0.12.2 + scikit-learn>=1.2.2 + scipy>=1.9.1 + statsmodels>=0.13.5 + tabulate>=0.8.10 + tensorflow>=2.12.0 + tqdm>=4.65.0 [options.extras_require] dev = - black==22.3.0 - pytest==7.1.2 + black>=22.3.0 + pytest>=7.1.2 [options.packages.find] where = src diff --git a/src/baskerville/HY_helper.py b/src/baskerville/HY_helper.py new file mode 100644 index 0000000..f4f7878 --- /dev/null +++ b/src/baskerville/HY_helper.py @@ -0,0 +1,74 @@ +import numpy as np +from basenji import dna_io +import pysam +import pyBigWig + +def make_seq_1hot(genome_open, chrm, start, end, seq_len): + if start < 0: + seq_dna = 'N'*(-start) + genome_open.fetch(chrm, 0, end) + else: + seq_dna = genome_open.fetch(chrm, start, end) + + #Extend to full length + if len(seq_dna) < seq_len: + seq_dna += 'N'*(seq_len-len(seq_dna)) + + seq_1hot = dna_io.dna_1hot(seq_dna) + return seq_1hot + +#Helper function to get (padded) one-hot +def process_sequence(fasta_file, chrom, start, end, seq_len=524288) : + + fasta_open = pysam.Fastafile(fasta_file) + seq_len_actual = end - start + + #Pad sequence to input window size + start -= (seq_len - seq_len_actual) // 2 + end += (seq_len - seq_len_actual) // 2 + + #Get one-hot + sequence_one_hot = make_seq_1hot(fasta_open, chrom, start, end, seq_len) + + return sequence_one_hot.astype('float32') + +def compute_cov(seqnn_model, chr, start, end): + seq_len = seqnn_model.model.layers[0].input.shape[1] + seq1hot = process_sequence('/home/yuanh/programs/genomes/hg38/hg38.fa', chr, start, end, seq_len=seq_len) + out = seqnn_model.model(seq1hot[None, ]) + return out.numpy() + +def write_bw(bw_file, chr, start, end, values, span=32): + bw_out = pyBigWig.open(bw_file, 'w') + header = [] + header.append((chr, end+1)) + bw_out.addHeader(header) + bw_out.addEntries(chr, start, values=values, span=span, step=span) + bw_out.close() + +def transform(seq_cov, clip=384, clip_soft=320, scale=0.3): + seq_cov = scale * seq_cov # scale + seq_cov = -1 + np.sqrt(1+seq_cov) # variant stabilize + clip_mask = (seq_cov > clip_soft) # soft clip + seq_cov[clip_mask] = clip_soft-1 + np.sqrt(seq_cov[clip_mask] - clip_soft+1) + seq_cov = np.clip(seq_cov, -clip, clip) # hard clip + return seq_cov + +def untransform(cov, scale=0.3, clip_soft=320, pool_width=32): + + # undo clip_soft + cov_unclipped = (cov - clip_soft + 1)**2 + clip_soft - 1 + unclip_mask = (cov > clip_soft) + cov[unclip_mask] = cov_unclipped[unclip_mask] + + # undo sqrt + cov = (cov +1)**2 - 1 + + # undo scale + cov = cov / scale + + # undo sum + cov = cov / pool_width + + return cov + + diff --git a/src/baskerville/layers.py b/src/baskerville/layers.py index acde2e5..1d28c2c 100644 --- a/src/baskerville/layers.py +++ b/src/baskerville/layers.py @@ -23,6 +23,55 @@ for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True) +##################### +# transfer learning # +##################### +class AdapterHoulsby(tf.keras.layers.Layer): + ### Houlsby et al. 2019 implementation + + def __init__( + self, + latent_size, + activation=tf.keras.layers.ReLU(), + **kwargs): + super(AdapterHoulsby, self).__init__(**kwargs) + self.latent_size = latent_size + self.activation = activation + + def build(self, input_shape): + self.down_project = tf.keras.layers.Dense( + units=self.latent_size, + activation="linear", + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), + bias_initializer="zeros", + name='adapter_down' + ) + + self.up_project = tf.keras.layers.Dense( + units=input_shape[-1], + activation="linear", + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), + bias_initializer="zeros", + name='adapter_up' + ) + + def call(self, inputs): + projected_down = self.down_project(inputs) + activated = self.activation(projected_down) + projected_up = self.up_project(activated) + output = projected_up + inputs + return output + + def get_config(self): + config = super().get_config().copy() + config.update( + { + "latent_size": self.latent_size, + "activation": self.activation + } + ) + return config + ############################################################ # Basic ############################################################ diff --git a/src/baskerville/scripts/hound_train.py b/src/baskerville/scripts/hound_train.py index e7ec150..d5a754d 100755 --- a/src/baskerville/scripts/hound_train.py +++ b/src/baskerville/scripts/hound_train.py @@ -17,6 +17,7 @@ import json import os import shutil +import re import numpy as np import pandas as pd @@ -26,6 +27,7 @@ from baskerville import dataset from baskerville import seqnn from baskerville import trainer +from baskerville import layers """ hound_train.py @@ -33,7 +35,6 @@ Train Hound model using given parameters and data. """ - def main(): parser = argparse.ArgumentParser(description="Train a model.") parser.add_argument( @@ -67,6 +68,17 @@ def main(): default=False, help="Restore only model trunk [Default: %(default)s]", ) + parser.add_argument( + "--transfer_mode", + default="full", + help="transfer method. [full, linear, adapter]", + ) + parser.add_argument( + "--latent", + type=int, + default=16, + help="adapter latent size.", + ) parser.add_argument( "--tfr_train", default=None, @@ -131,31 +143,65 @@ def main(): tfr_pattern=args.tfr_eval, ) ) - + params_model["strand_pair"] = strand_pairs if args.mixed_precision: - mixed_precision.set_global_policy("mixed_float16") - + policy = mixed_precision.Policy('mixed_float16') + mixed_precision.set_global_policy(policy) + if params_train.get("num_gpu", 1) == 1: ######################################## # one GPU # initialize model seqnn_model = seqnn.SeqNN(params_model) - + # restore if args.restore: seqnn_model.restore(args.restore, trunk=args.trunk) + # transfer learning strategies + if args.transfer_mode=='full': + seqnn_model.model.trainable=True + + elif args.transfer_mode=='batch_norm': + seqnn_model.model_trunk.trainable=False + for l in seqnn_model.model.layers: + if l.name.startswith("batch_normalization"): + l.trainable=True + seqnn_model.model.summary() + + elif args.transfer_mode=='linear': + seqnn_model.model_trunk.trainable=False + seqnn_model.model.summary() + + elif args.transfer_mode=='adapterHoulsby': + seqnn_model.model_trunk.trainable=False + strand_pair = strand_pairs[0] + adapter_model = make_adapter_model(seqnn_model.model, strand_pair, args.latent) + seqnn_model.model = adapter_model + seqnn_model.models[0] = seqnn_model.model + seqnn_model.model_trunk = None + seqnn_model.model.summary() + # initialize trainer seqnn_trainer = trainer.Trainer( params_train, train_data, eval_data, args.out_dir ) - + # compile model seqnn_trainer.compile(seqnn_model) + # train model + if args.keras_fit: + seqnn_trainer.fit_keras(seqnn_model) + else: + if len(args.data_dirs) == 1: + seqnn_trainer.fit_tape(seqnn_model) + else: + seqnn_trainer.fit2(seqnn_model) + else: ######################################## # multi GPU @@ -163,6 +209,7 @@ def main(): strategy = tf.distribute.MirroredStrategy() with strategy.scope(): + if not args.keras_fit: # distribute data for di in range(len(args.data_dirs)): @@ -190,16 +237,81 @@ def main(): # compile model seqnn_trainer.compile(seqnn_model) - # train model - if args.keras_fit: - seqnn_trainer.fit_keras(seqnn_model) - else: - if len(args.data_dirs) == 1: - seqnn_trainer.fit_tape(seqnn_model) + # train model + if args.keras_fit: + seqnn_trainer.fit_keras(seqnn_model) else: - seqnn_trainer.fit2(seqnn_model) + if len(args.data_dirs) == 1: + seqnn_trainer.fit_tape(seqnn_model) + else: + seqnn_trainer.fit2(seqnn_model) +def make_adapter_model(input_model, strand_pair, latent_size=16): + # take seqnn_model as input + # output a new seqnn_model object + # only the adapter, and layer_norm are trainable + + model = tf.keras.Model(inputs=input_model.input, + outputs=input_model.layers[-2].output) # remove the switch_reverse layer + + # save current graph + layer_parent_dict_old = {} # the parent layers of each layer in the old graph + for layer in model.layers: + for node in layer._outbound_nodes: + layer_name = node.outbound_layer.name + if layer_name not in layer_parent_dict_old: + layer_parent_dict_old.update({layer_name: [layer.name]}) + else: + if layer.name not in layer_parent_dict_old[layer_name]: + layer_parent_dict_old[layer_name].append(layer.name) + + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({model.layers[0].name: model.input}) + + # remove switch_reverse + to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] + for i in to_fix: + del layer_parent_dict_old[i] + + # Iterate over all layers after the input + model_outputs = [] + reverse_bool = None + + for layer in model.layers[1:]: + + # parent layers + parent_layers = layer_parent_dict_old[layer.name] + + # layer inputs + layer_input = [layer_output_dict_new[parent] for parent in parent_layers] + if len(layer_input) == 1: layer_input = layer_input[0] + + if re.match('stochastic_reverse_complement', layer.name): + x, reverse_bool = layer(layer_input) + + # insert adapter: + elif re.match('add', layer.name): + if any([re.match('dropout', i) for i in parent_layers]): + print('adapter added before:%s'%layer.name) + x = layers.AdapterHoulsby(latent_size=latent_size)(layer_input[1]) + x = layer([layer_input[0], x]) + else: + x = layer(layer_input) + + else: + x = layer(layer_input) + + # save the output tensor of every layer + layer_output_dict_new.update({layer.name: x}) + + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[model.layers[-1].name], reverse_bool]) + model_adapter = tf.keras.Model(inputs=model.inputs, outputs=final) + + # set layer_norm layers to trainable + for l in model_adapter.layers: + if re.match('layer_normalization', l.name): l.trainable = True + return model_adapter ################################################################################ # __main__ ################################################################################ diff --git a/src/baskerville/scripts/westminster_train_folds_copy.py b/src/baskerville/scripts/westminster_train_folds_copy.py new file mode 100755 index 0000000..6f27ec5 --- /dev/null +++ b/src/baskerville/scripts/westminster_train_folds_copy.py @@ -0,0 +1,509 @@ +#!/usr/bin/env python +# Copyright 2019 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from optparse import OptionParser, OptionGroup +import glob +import json +import os +import pdb +import shutil + +from natsort import natsorted + +import slurm + +""" +westminster_train_folds.py + +Train baskerville model replicates on cross folds using given parameters and data. +""" + +################################################################################ +# main +################################################################################ +def main(): + usage = 'usage: %prog [options] ...' + parser = OptionParser(usage) + + # train + train_options = OptionGroup(parser, 'houndtrain.py options') + train_options.add_option('-k', dest='keras_fit', + default=False, action='store_true', + help='Train with Keras fit method [Default: %default]') + train_options.add_option('-m', dest='mixed_precision', + default=False, action='store_true', + help='Train with mixed precision [Default: %default]') + train_options.add_option('-o', dest='out_dir', + default='train_out', + help='Training output directory [Default: %default]') + train_options.add_option('--restore', dest='restore', + help='Restore model and continue training, from existing fold train dir [Default: %default]') + train_options.add_option('--trunk', dest='trunk', + default=False, action='store_true', + help='Restore only model trunk [Default: %default]') + train_options.add_option('--tfr_train', dest='tfr_train_pattern', + default=None, + help='Training TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]') + train_options.add_option('--tfr_eval', dest='tfr_eval_pattern', + default=None, + help='Evaluation TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]') + parser.add_option_group(train_options) + + # eval + eval_options = OptionGroup(parser, 'hound_eval.py options') + eval_options.add_option('--rank', dest='rank_corr', + default=False, action='store_true', + help='Compute Spearman rank correlation [Default: %default]') + eval_options.add_option('--rc', dest='rc', + default=False, action='store_true', + help='Average forward and reverse complement predictions [Default: %default]') + eval_options.add_option('--shifts', dest='shifts', + default='0', type='str', + help='Ensemble prediction shifts [Default: %default]') + parser.add_option('--step', dest='step', + default=1, type='int', + help='Spatial step for specificity/spearmanr [Default: %default]') + parser.add_option_group(eval_options) + + # multi + rep_options = OptionGroup(parser, 'replication options') + rep_options.add_option('-c', dest='crosses', + default=1, type='int', + help='Number of cross-fold rounds [Default:%default]') + rep_options.add_option('--checkpoint', dest='checkpoint', + default=False, action='store_true', + help='Restart training from checkpoint [Default: %default]') + rep_options.add_option('-e', dest='conda_env', + default='tf12', + help='Anaconda environment [Default: %default]') + rep_options.add_option('-f', dest='fold_subset', + default=None, type='int', + help='Run a subset of folds [Default:%default]') + rep_options.add_option('--name', dest='name', + default='fold', help='SLURM name prefix [Default: %default]') + rep_options.add_option('-p', dest='processes', + default=None, type='int', + help='Number of processes, passed by multi script') + rep_options.add_option('-q', dest='queue', + default='titan_rtx', + help='SLURM queue on which to run the jobs [Default: %default]') + rep_options.add_option('-r', '--restart', dest='restart', + default=False, action='store_true') + rep_options.add_option('--setup', dest='setup', + default=False, action='store_true', + help='Setup folds data directory only [Default: %default]') + rep_options.add_option('--spec_off', dest='spec_off', + default=False, action='store_true') + rep_options.add_option('--eval_off', dest='eval_off', + default=False, action='store_true') + rep_options.add_option('--eval_train_off', dest='eval_train_off', + default=False, action='store_true') + parser.add_option_group(rep_options) + + (options, args) = parser.parse_args() + + if len(args) < 2: + parser.error('Must provide parameters and data directory.') + else: + params_file = os.path.abspath(args[0]) + data_dirs = [os.path.abspath(arg) for arg in args[1:]] + + ####################################################### + # prep work + + if not options.restart and os.path.isdir(options.out_dir): + print('Output directory %s exists. Please remove.' % options.out_dir) + exit(1) + os.makedirs(options.out_dir, exist_ok=True) + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_train = params['train'] + + # copy params into output directory + shutil.copy(params_file, '%s/params.json' % options.out_dir) + + # read data parameters + num_data = len(data_dirs) + data_stats_file = '%s/statistics.json' % data_dirs[0] + with open(data_stats_file) as data_stats_open: + data_stats = json.load(data_stats_open) + + # count folds + num_folds = len([dkey for dkey in data_stats if dkey.startswith('fold')]) + + # subset folds + if options.fold_subset is not None: + num_folds = min(options.fold_subset, num_folds) + + if options.queue == 'standard': + num_cpu = 8 + num_gpu = 0 + time_base = 64 + else: + num_cpu = 2 + num_gpu = 1 + time_base = 24 + + # arrange data + for ci in range(options.crosses): + for fi in range(num_folds): + rep_dir = '%s/f%dc%d' % (options.out_dir, fi, ci) + os.makedirs(rep_dir, exist_ok=True) + + # make data directories + for di in range(num_data): + rep_data_dir = '%s/data%d' % (rep_dir, di) + if not os.path.isdir(rep_data_dir): + make_rep_data(data_dirs[di], rep_data_dir, fi, ci) + + if options.setup: + exit(0) + + cmd_source = 'source /home/yuanh/.bashrc;' + hound_train = '/home/yuanh/programs/source/python_packages/baskerville/scripts/hound_train.py' + ####################################################### + # train + + jobs = [] + + for ci in range(options.crosses): + for fi in range(num_folds): + rep_dir = '%s/f%dc%d' % (options.out_dir, fi, ci) + + train_dir = '%s/train' % rep_dir + if options.restart and not options.checkpoint and os.path.isdir(train_dir): + print('%s found and skipped.' % rep_dir) + + else: + # collect data directories + rep_data_dirs = [] + for di in range(num_data): + rep_data_dirs.append('%s/data%d' % (rep_dir, di)) + + # if options.checkpoint: + # os.rename('%s/train.out' % rep_dir, '%s/train1.out' % rep_dir) + + # train command + cmd = cmd_source + cmd += ' conda activate %s;' % options.conda_env + cmd += ' echo $HOSTNAME;' + + cmd += ' %s' %hound_train + cmd += ' %s' % options_string(options, train_options, rep_dir) + cmd += ' %s %s' % (params_file, ' '.join(rep_data_dirs)) + + name = '%s-train-f%dc%d' % (options.name, fi, ci) + sbf = os.path.abspath('%s/train.sb' % rep_dir) + outf = os.path.abspath('%s/train.%%j.out' % rep_dir) + errf = os.path.abspath('%s/train.%%j.err' % rep_dir) + + j = slurm.Job(cmd, name, + outf, errf, sbf, + queue=options.queue, + cpu=4, + gpu=params_train.get('num_gpu',1), + mem=30000, time='60-0:0:0') + jobs.append(j) + + slurm.multi_run(jobs, max_proc=options.processes, verbose=True, + launch_sleep=10, update_sleep=60) + + + ####################################################### + # evaluate training set + + jobs = [] + + if not options.eval_train_off: + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (options.out_dir, fi, ci) + + for di in range(num_data): + if num_data == 1: + out_dir = '%s/eval_train' % it_dir + model_file = '%s/train/model_check.h5' % it_dir + else: + out_dir = '%s/eval%d_train' % (it_dir, di) + model_file = '%s/train/model%d_check.h5' % (it_dir, di) + + # check if done + acc_file = '%s/acc.txt' % out_dir + if os.path.isfile(acc_file): + print('%s already generated.' % acc_file) + else: + # hound evaluate + cmd = cmd_source + cmd += ' conda activate %s;' % options.conda_env + cmd += ' echo $HOSTNAME;' + cmd += ' hound_eval.py' + cmd += ' --head %d' % di + cmd += ' -o %s' % out_dir + if options.rc: + cmd += ' --rc' + if options.shifts: + cmd += ' --shifts %s' % options.shifts + cmd += ' --split train' + cmd += ' %s' % params_file + cmd += ' %s' % model_file + cmd += ' %s/data%d' % (it_dir, di) + + name = '%s-evaltr-f%dc%d' % (options.name, fi, ci) + job = slurm.Job(cmd, + name=name, + out_file='%s.out'%out_dir, + err_file='%s.err'%out_dir, + queue=options.queue, + cpu=num_cpu, gpu=num_gpu, + mem=30000, + time='%d:00:00' % (3*time_base)) + jobs.append(job) + + + ####################################################### + # evaluate test set + + if not options.eval_off: + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (options.out_dir, fi, ci) + + for di in range(num_data): + if num_data == 1: + out_dir = '%s/eval' % it_dir + model_file = '%s/train/model_best.h5' % it_dir + else: + out_dir = '%s/eval%d' % (it_dir, di) + model_file = '%s/train/model%d_best.h5' % (it_dir, di) + + # check if done + acc_file = '%s/acc.txt' % out_dir + if os.path.isfile(acc_file): + print('%s already generated.' % acc_file) + else: + cmd = cmd_source + cmd += ' conda activate %s;' % options.conda_env + cmd += ' echo $HOSTNAME;' + cmd += ' hound_eval.py' + cmd += ' --head %d' % di + cmd += ' -o %s' % out_dir + if options.rc: + cmd += ' --rc' + if options.shifts: + cmd += ' --shifts %s' % options.shifts + if options.rank_corr: + cmd += ' --rank' + cmd += ' --step %d' % options.step + cmd += ' %s' % params_file + cmd += ' %s' % model_file + cmd += ' %s/data%d' % (it_dir, di) + + name = '%s-eval-f%dc%d' % (options.name, fi, ci) + job = slurm.Job(cmd, + name=name, + out_file='%s.out'%out_dir, + err_file='%s.err'%out_dir, + queue=options.queue, + cpu=num_cpu, gpu=num_gpu, + mem=30000, + time='%d:00:00' % time_base) + jobs.append(job) + + ####################################################### + # evaluate test specificity + + if not options.spec_off: + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = '%s/f%dc%d' % (options.out_dir, fi, ci) + + for di in range(num_data): + if num_data == 1: + out_dir = '%s/eval_spec' % it_dir + model_file = '%s/train/model_best.h5' % it_dir + else: + out_dir = '%s/eval%d_spec' % (it_dir, di) + model_file = '%s/train/model%d_best.h5' % (it_dir, di) + + # check if done + acc_file = '%s/acc.txt' % out_dir + if os.path.isfile(acc_file): + print('%s already generated.' % acc_file) + else: + cmd = cmd_source + cmd += ' conda activate %s;' % options.conda_env + cmd += ' echo $HOSTNAME;' + cmd += ' hound_eval_spec.py' + cmd += ' --head %d' % di + cmd += ' -o %s' % out_dir + cmd += ' --step %d' % options.step + if options.rc: + cmd += ' --rc' + if options.shifts: + cmd += ' --shifts %s' % options.shifts + cmd += ' %s' % params_file + cmd += ' %s' % model_file + cmd += ' %s/data%d' % (it_dir, di) + + name = '%s-spec-f%dc%d' % (options.name, fi, ci) + job = slurm.Job(cmd, + name=name, + out_file='%s.out'%out_dir, + err_file='%s.err'%out_dir, + queue=options.queue, + cpu=num_cpu, gpu=num_gpu, + mem=150000, + time='%d:00:00' % (5*time_base)) + jobs.append(job) + + slurm.multi_run(jobs, max_proc=options.processes, verbose=True, + launch_sleep=10, update_sleep=60) + + +def make_rep_data(data_dir, rep_data_dir, fi, ci): + # read data parameters + data_stats_file = '%s/statistics.json' % data_dir + with open(data_stats_file) as data_stats_open: + data_stats = json.load(data_stats_open) + + # sequences per fold + fold_seqs = [] + dfi = 0 + while 'fold%d_seqs'%dfi in data_stats: + fold_seqs.append(data_stats['fold%d_seqs'%dfi]) + del data_stats['fold%d_seqs'%dfi] + dfi += 1 + num_folds = dfi + + # split folds into train/valid/test + test_fold = fi + valid_fold = (fi+1+ci) % num_folds + train_folds = [fold for fold in range(num_folds) if fold not in [valid_fold,test_fold]] + + # clear existing directory + if os.path.isdir(rep_data_dir): + shutil.rmtree(rep_data_dir) + + # make data directory + os.makedirs(rep_data_dir, exist_ok=True) + + # dump data stats + data_stats['test_seqs'] = fold_seqs[test_fold] + data_stats['valid_seqs'] = fold_seqs[valid_fold] + data_stats['train_seqs'] = sum([fold_seqs[tf] for tf in train_folds]) + with open('%s/statistics.json'%rep_data_dir, 'w') as data_stats_open: + json.dump(data_stats, data_stats_open, indent=4) + + # set sequence tvt + try: + seqs_bed_out = open('%s/sequences.bed'%rep_data_dir, 'w') + for line in open('%s/sequences.bed'%data_dir): + a = line.split() + sfi = int(a[-1].replace('fold','')) + if sfi == test_fold: + a[-1] = 'test' + elif sfi == valid_fold: + a[-1] = 'valid' + else: + a[-1] = 'train' + print('\t'.join(a), file=seqs_bed_out) + seqs_bed_out.close() + except (ValueError, FileNotFoundError): + pass + + # copy targets + shutil.copy('%s/targets.txt'%data_dir, '%s/targets.txt'%rep_data_dir) + + # sym link tfrecords + rep_tfr_dir = '%s/tfrecords' % rep_data_dir + os.mkdir(rep_tfr_dir) + + # test tfrecords + ti = 0 + test_tfrs = natsorted(glob.glob('%s/tfrecords/fold%d-*.tfr' % (data_dir, test_fold))) + for test_tfr in test_tfrs: + test_tfr = os.path.abspath(test_tfr) + test_rep_tfr = '%s/test-%d.tfr' % (rep_tfr_dir, ti) + os.symlink(test_tfr, test_rep_tfr) + ti += 1 + + # valid tfrecords + ti = 0 + valid_tfrs = natsorted(glob.glob('%s/tfrecords/fold%d-*.tfr' % (data_dir, valid_fold))) + for valid_tfr in valid_tfrs: + valid_tfr = os.path.abspath(valid_tfr) + valid_rep_tfr = '%s/valid-%d.tfr' % (rep_tfr_dir, ti) + os.symlink(valid_tfr, valid_rep_tfr) + ti += 1 + + # train tfrecords + ti = 0 + train_tfrs = [] + for tfi in train_folds: + train_tfrs += natsorted(glob.glob('%s/tfrecords/fold%d-*.tfr' % (data_dir, tfi))) + for train_tfr in train_tfrs: + train_tfr = os.path.abspath(train_tfr) + train_rep_tfr = '%s/train-%d.tfr' % (rep_tfr_dir, ti) + os.symlink(train_tfr, train_rep_tfr) + ti += 1 + + +def options_string(options, train_options, rep_dir): + options_str = '' + + for opt in train_options.option_list: + opt_str = opt.get_opt_string() + opt_value = options.__dict__[opt.dest] + + # wrap askeriks in "" + if type(opt_value) == str and opt_value.find('*') != -1: + opt_value = '"%s"' % opt_value + + # no value for bools + elif type(opt_value) == bool: + if not opt_value: + opt_str = '' + opt_value = '' + + # skip Nones + elif opt_value is None: + opt_str = '' + opt_value = '' + + # modify + elif opt.dest == 'out_dir': + opt_value = '%s/train' % rep_dir + + # find matching restore + elif opt.dest == 'restore': + fold_dir_mid = rep_dir.split('/')[-1] + if options.trunk: + opt_value = '%s/%s/train/model_trunk.h5' % (opt_value, fold_dir_mid) + else: + opt_value = '%s/%s/train/model_best.h5' % (opt_value, fold_dir_mid) + + options_str += ' %s %s' % (opt_str, opt_value) + + return options_str + + +################################################################################ +# __main__ +################################################################################ +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py index 5c55f52..6503815 100644 --- a/src/baskerville/trainer.py +++ b/src/baskerville/trainer.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========================================================================= +# modified fit2 to: +# show progress bar during training +# save gpu memory information + import time import pdb @@ -19,6 +23,7 @@ import tensorflow as tf from baskerville import metrics +from tensorflow.keras import mixed_precision def parse_loss( @@ -53,6 +58,14 @@ def parse_loss( loss_fn = metrics.PoissonMultinomial( total_weight, reduction=tf.keras.losses.Reduction.NONE ) + elif loss_label == "poisson_kl": + loss_fn = metrics.PoissonKL( + spec_weight, reduction=tf.keras.losses.Reduction.NONE + ) + elif loss_label == "mse_udot": + loss_fn = metrics.MeanSquaredErrorUDot( + spec_weight, reduction=tf.keras.losses.Reduction.NONE + ) else: loss_fn = tf.keras.losses.Poisson(reduction=tf.keras.losses.Reduction.NONE) else: @@ -94,6 +107,7 @@ def __init__( strategy=None, num_gpu: int = 1, keras_fit: bool = False, + loss_scale: bool = False, ): self.params = params self.train_data = train_data @@ -107,6 +121,7 @@ def __init__( self.num_gpu = num_gpu self.batch_size = self.train_data[0].batch_size self.compiled = False + self.loss_scale = loss_scale # early stopping self.patience = self.params.get("patience", 20) @@ -133,7 +148,7 @@ def __init__( ) # optimizer - self.make_optimizer() + self.make_optimizer(loss_scale=loss_scale) def compile(self, seqnn_model): for model in seqnn_model.models: @@ -396,6 +411,11 @@ def eval_step1_distr(xd, yd): ################################################################ # training loop + gpu_memory_callback = GPUMemoryUsageCallback() + file_path='%s/gpu_mem.txt' % self.out_dir + with open(file_path, 'w') as file: + file.write('epoch\tbatch\tgpu_mem(GB)\n') + first_step = True for ei in range(epoch_start, self.train_epochs_max): if ei >= self.train_epochs_min and np.min(unimproved) > self.patience: @@ -406,10 +426,11 @@ def eval_step1_distr(xd, yd): # get iterators train_data_iters = [iter(td.dataset) for td in self.train_data] - + # train t0 = time.time() - for di in self.dataset_indexes: + prog_bar = tf.keras.utils.Progbar(len(self.dataset_indexes)) # Create Keras Progbar + for didx, di in enumerate(self.dataset_indexes): x, y = safe_next(train_data_iters[di]) if self.strategy is None: if di == 0: @@ -424,7 +445,13 @@ def eval_step1_distr(xd, yd): if first_step: print("Successful first step!", flush=True) first_step = False - + prog_bar.add(1) + + if (ei == epoch_start) and (didx < 1000) and (didx%100 == 1): + mem=gpu_memory_callback.on_batch_end() + file = open(file_path, 'a') + file.write("%d\t%d\t%.2f\n"%(ei, didx, mem)) + print("Epoch %d - %ds" % (ei, (time.time() - t0))) for di in range(self.num_datasets): print(" Data %d" % di, end="") @@ -486,6 +513,7 @@ def eval_step1_distr(xd, yd): valid_r[di].reset_states() valid_r2[di].reset_states() + def fit_tape(self, seqnn_model): """Train the model using a custom tf.GradientTape loop.""" if not self.compiled: @@ -588,6 +616,11 @@ def eval_step_distr(xd, yd): unimproved = 0 # training loop + gpu_memory_callback = GPUMemoryUsageCallback() + file_path='%s/gpu_mem.txt' % self.out_dir + with open(file_path, 'w') as file: + file.write('epoch\tbatch\tgpu_mem(GB)\n') + for ei in range(epoch_start, self.train_epochs_max): if ei >= self.train_epochs_min and unimproved > self.patience: break @@ -604,6 +637,12 @@ def eval_step_distr(xd, yd): if ei == epoch_start and si == 0: print("Successful first step!", flush=True) + # print gpu memory usage + if (ei == epoch_start) and (si < 1000) and (si%100 == 1): + mem=gpu_memory_callback.on_batch_end() + with open(file_path, 'a') as file: + file.write("%d\t%d\t%.2f\n"%(ei, si, mem)) + # evaluate for x, y in self.eval_data[0].dataset: if self.strategy is not None: @@ -660,7 +699,7 @@ def eval_step_distr(xd, yd): valid_r.reset_states() valid_r2.reset_states() - def make_optimizer(self): + def make_optimizer(self, loss_scale=False): """Make optimizer object from given parameters.""" cyclical1 = True for lrs_param in [ @@ -715,12 +754,17 @@ def make_optimizer(self): # optimizer optimizer_type = self.params.get("optimizer", "sgd").lower() if optimizer_type == "adam": + if loss_scale: + epsilon_value = 1e-04 + else: + epsilon_value = 1e-07 self.optimizer = tf.keras.optimizers.Adam( learning_rate=lr_schedule, beta_1=self.params.get("adam_beta1", 0.9), beta_2=self.params.get("adam_beta2", 0.999), clipnorm=clip_norm, global_clipnorm=global_clipnorm, + epsilon=epsilon_value, amsgrad=False, ) # reduces performance in my experience @@ -747,6 +791,9 @@ def make_optimizer(self): print("Cannot recognize optimization algorithm %s" % optimizer_type) exit(1) + if loss_scale: + self.optimizer = mixed_precision.LossScaleOptimizer(self.optimizer) + ################################################################ # AGC @@ -964,3 +1011,25 @@ def safe_next(data_iter, retry=5, sleep=10): d = next(data_iter) return d + + +def CheckGradientNA(gradients): + for grad in gradients: + if grad is not None: + if tf.reduce_any(tf.math.is_nan(grad)): + raise ValueError("NaN gradient detected.") + +# Define a custom callback class to track GPU memory usage +class GPUMemoryUsageCallback(tf.keras.callbacks.Callback): + def on_train_begin(self, logs=None): + # Enable memory growth to avoid GPU memory allocation issues + physical_devices = tf.config.experimental.list_physical_devices('GPU') + if physical_devices: + for device in physical_devices: + tf.config.experimental.set_memory_growth(device, True) + + def on_batch_end(self, logs=None): + gpu_memory = tf.config.experimental.get_memory_info('GPU:0') + current_memory = gpu_memory['peak'] / 1e9 # Convert to GB + return current_memory +