From 6502e5285c0d1c6100fc3aabef16b6ba424a9d40 Mon Sep 17 00:00:00 2001 From: hy395 Date: Mon, 29 Jan 2024 09:53:54 -0800 Subject: [PATCH] add gene eval --- src/baskerville/blocks.py | 54 +- src/baskerville/layers.py | 15 +- src/baskerville/pygene.py | 324 +++++++++++ src/baskerville/scripts/borzoi_test_genes.py | 550 +++++++++++++++++++ src/baskerville/scripts/hound_eval_spec.py | 39 +- src/baskerville/scripts/hound_transfer.py | 150 +++-- src/baskerville/transfer_helper.py | 300 +++++++--- 7 files changed, 1283 insertions(+), 149 deletions(-) create mode 100755 src/baskerville/pygene.py create mode 100755 src/baskerville/scripts/borzoi_test_genes.py diff --git a/src/baskerville/blocks.py b/src/baskerville/blocks.py index 82e9066..19e64ce 100644 --- a/src/baskerville/blocks.py +++ b/src/baskerville/blocks.py @@ -149,6 +149,8 @@ def conv_dna( conv_type="standard", kernel_initializer="he_normal", padding="same", + transfer_se=False, + se_ratio=16, ): """Construct a single convolution block, assumed to be operating on DNA. @@ -196,6 +198,18 @@ def conv_dna( kernel_regularizer=tf.keras.regularizers.l2(l2_scale), )(current) + # squeeze-excite for transfer + if transfer_se: + se_out = squeeze_excite(current, + activation=None, + additive=False, + bottleneck_ratio=se_ratio, + use_bias=False, + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), + scale_fun='tanh' + ) + current = current + se_out + # squeeze-excite if se: current = squeeze_excite(current) @@ -267,6 +281,8 @@ def conv_nac( kernel_initializer="he_normal", padding="same", se=False, + transfer_se=False, + se_ratio=16, ): """Construct a single convolution block. @@ -326,6 +342,18 @@ def conv_nac( kernel_regularizer=tf.keras.regularizers.l2(l2_scale), )(current) + # squeeze-excite for transfer + if transfer_se: + se_out = squeeze_excite(current, + activation=None, + additive=False, + bottleneck_ratio=se_ratio, + use_bias=False, + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), + scale_fun='tanh' + ) + current = current + se_out + # squeeze-excite if se: current = squeeze_excite(current) @@ -456,6 +484,8 @@ def fpn_unet( bn_momentum=0.99, kernel_size=1, kernel_initializer="he_normal", + transfer_se=False, + se_ratio=16, ): """Construct a feature pyramid network block. 
@@ -529,6 +559,17 @@ def fpn_unet( kernel_initializer=kernel_initializer, )(current) + if transfer_se: + se_out = squeeze_excite(current, + activation=None, + additive=False, + bottleneck_ratio=se_ratio, + use_bias=False, + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), + scale_fun='tanh' + ) + current = current + se_out + # dropout if dropout > 0: current = tf.keras.layers.Dropout(dropout)(current) @@ -1528,11 +1569,20 @@ def squeeze_excite( additive=False, norm_type=None, bn_momentum=0.9, + kernel_initializer='glorot_uniform', + use_bias=True, + scale_fun='sigmoid', **kwargs, ): return layers.SqueezeExcite( - activation, additive, bottleneck_ratio, norm_type, bn_momentum - )(inputs) + activation=activation, + additive=additive, + bottleneck_ratio=bottleneck_ratio, + norm_type=norm_type, + bn_momentum=bn_momentum, + kernel_initializer=kernel_initializer, + scale_fun=scale_fun, + use_bias=use_bias)(inputs) def wheeze_excite(inputs, pool_size, **kwargs): diff --git a/src/baskerville/layers.py b/src/baskerville/layers.py index 8f6af73..d0513dc 100644 --- a/src/baskerville/layers.py +++ b/src/baskerville/layers.py @@ -756,6 +756,7 @@ def __init__( use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', + scale_fun='sigmoid', ): super(SqueezeExcite, self).__init__() self.activation = activation @@ -766,6 +767,7 @@ def __init__( self.kernel_initializer=kernel_initializer self.bias_initializer=bias_initializer self.use_bias=use_bias + self.scale_fun=scale_fun def build(self, input_shape): self.num_channels = input_shape[-1] @@ -783,6 +785,17 @@ def build(self, input_shape): ) exit(1) + if self.scale_fun=='sigmoid': + self.scale_f = tf.keras.activations.sigmoid + elif self.scale_fun=='tanh': # set to tanh for transfer + self.scale_f = tf.keras.activations.tanh + else: + print( + "scale function must be sigmoid or tanh", + file=sys.stderr, + ) + exit(1) + self.dense1 = tf.keras.layers.Dense( units=self.num_channels // self.bottleneck_ratio, activation="relu", @@ -819,7 +832,7 @@ def call(self, x): if self.additive: xs = x + excite else: - excite = tf.keras.activations.sigmoid(excite) + excite = self.scale_f(excite) xs = x * excite return xs diff --git a/src/baskerville/pygene.py b/src/baskerville/pygene.py new file mode 100755 index 0000000..86cae4f --- /dev/null +++ b/src/baskerville/pygene.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python +from optparse import OptionParser + +import gzip +import pdb + +''' +pygene + +Classes and methods to manage genes in GTF format. 
+''' + +################################################################################ +# Classes +################################################################################ +class GenomicInterval: + def __init__(self, start, end, chrom=None, strand=None): + self.start = start + self.end = end + self.chrom = chrom + self.strand = strand + + def __eq__(self, other): + return self.start == other.start + + def __lt__(self, other): + return self.start < other.start + + def __cmp__(self, x): + if self.start < x.start: + return -1 + elif self.start > x.start: + return 1 + else: + return 0 + + def __str__(self): + if self.chrom is None: + label = '[%d-%d]' % (self.start, self.end) + else: + label = '%s:%d-%d' % (self.chrom, self.start, self.end) + return label + + +class Transcript: + def __init__(self, chrom, strand, kv): + self.chrom = chrom + self.strand = strand + self.kv = kv + self.exons = [] + self.cds = [] + self.utrs3 = [] + self.utrs5 = [] + self.sorted = False + self.utrs_defined = False + + def add_cds(self, start, end): + self.cds.append(GenomicInterval(start,end)) + + def add_exon(self, start, end): + self.exons.append(GenomicInterval(start,end)) + + def define_utrs(self): + self.utrs_defined = True + + if len(self.cds) == 0: + self.utrs3 = self.exons + + else: + assert(self.sorted) + + # reset UTR lists + self.utrs5 = [] + self.utrs3 = [] + + # match up exons and CDS + ci = 0 + for ei in range(len(self.exons)): + # left initial + if self.exons[ei].end < self.cds[ci].start: + utr = GenomicInterval(self.exons[ei].start, self.exons[ei].end) + if self.strand == '+': + self.utrs5.append(utr) + else: + self.utrs3.append(utr) + + # right initial + elif self.cds[ci].end < self.exons[ei].start: + utr = GenomicInterval(self.exons[ei].start, self.exons[ei].end) + if self.strand == '+': + self.utrs3.append(utr) + else: + self.utrs5.append(utr) + + # overlap + else: + # left overlap + if self.exons[ei].start < self.cds[ci].start: + utr = GenomicInterval(self.exons[ei].start, self.cds[ci].start-1) + if self.strand == '+': + self.utrs5.append(utr) + else: + self.utrs3.append(utr) + + # right overlap + if self.cds[ci].end < self.exons[ei].end: + utr = GenomicInterval(self.cds[ci].end+1, self.exons[ei].end) + if self.strand == '+': + self.utrs3.append(utr) + else: + self.utrs5.append(utr) + + # increment up to last + ci = min(ci+1, len(self.cds)-1) + + def fasta_cds(self, fasta_open, stranded=False): + assert(self.sorted) + gene_seq = '' + for exon in self.cds: + exon_seq = fasta_open.fetch(self.chrom, exon.start-1, exon.end) + gene_seq += exon_seq + if stranded and self.strand == '-': + gene_seq = rc(gene_seq) + return gene_seq + + def fasta_exons(self, fasta_open, stranded=False): + assert(self.sorted) + gene_seq = '' + for exon in self.exons: + exon_seq = fasta_open.fetch(self.chrom, exon.start-1, exon.end) + gene_seq += exon_seq + if stranded and self.strand == '-': + gene_seq = rc(gene_seq) + return gene_seq + + def sort_exons(self): + self.sorted = True + if len(self.exons) > 1: + self.exons.sort() + if len(self.cds) > 1: + self.cds.sort() + + def span(self): + exon_starts = [exon.start for exon in self.exons] + exon_ends = [exon.end for exon in self.exons] + return min(exon_starts), max(exon_ends) + + def tss(self): + if self.strand == '-': + return self.exons[-1].end + else: + return self.exons[0].start + + def write_gtf(self, gtf_out, write_cds=False, write_utrs=False): + for ex in self.exons: + cols = [self.chrom, 'pygene', 'exon', str(ex.start), str(ex.end)] + cols += ['.', 
self.strand, '.', kv_gtf(self.kv)] + print('\t'.join(cols), file=gtf_out) + if write_cds: + for cds in self.cds: + cols = [self.chrom, 'pygene', 'CDS', str(cds.start), str(cds.end)] + cols += ['.', self.strand, '.', kv_gtf(self.kv)] + print('\t'.join(cols), file=gtf_out) + if write_utrs: + assert(self.utrs_defined) + for utr in self.utrs5: + cols = [self.chrom, 'pygene', '5\'UTR', str(utr.start), str(utr.end)] + cols += ['.', self.strand, '.', kv_gtf(self.kv)] + print('\t'.join(cols), file=gtf_out) + for utr in self.utrs3: + cols = [self.chrom, 'pygene', '3\'UTR', str(utr.start), str(utr.end)] + cols += ['.', self.strand, '.', kv_gtf(self.kv)] + print('\t'.join(cols), file=gtf_out) + + def __str__(self): + return '%s %s %s %s' % (self.chrom, self.strand, kv_gtf(self.kv), ','.join([ex.__str__() for ex in self.exons])) + + +class Gene: + def __init__(self): + self.transcripts = {} + self.chrom = None + self.strand = None + self.start = None + self.end = None + + def add_transcript(self, tx_id, tx): + self.transcripts[tx_id] = tx + self.chrom = tx.chrom + self.strand = tx.strand + self.kv = tx.kv + + def span(self): + tx_spans = [tx.span() for tx in self.transcripts.values()] + tx_starts, tx_ends = zip(*tx_spans) + self.start = min(tx_starts) + self.end = max(tx_ends) + return self.start, self.end + + +class GTF: + def __init__(self, gtf_file, trim_dot=False): + self.gtf_file = gtf_file + self.genes = {} + self.transcripts = {} + self.utrs_defined = False + self.trim_dot = trim_dot + + self.read_gtf() + + def define_utrs(self): + self.utrs_defined = True + for tx in self.transcripts.values(): + tx.define_utrs() + + def read_gtf(self): + if self.gtf_file[-3:] == '.gz': + gtf_in = gzip.open(self.gtf_file, 'rt') + else: + gtf_in = open(self.gtf_file) + + # ignore header + line = gtf_in.readline() + while line[0] == '#': + line = gtf_in.readline() + + while line: + a = line.split('\t') + if a[2] in ['exon','CDS']: + chrom = a[0] + interval_type = a[2] + start = int(a[3]) + end = int(a[4]) + strand = a[6] + kv = gtf_kv(a[8]) + + # add/get transcript + tx_id = kv['transcript_id'] + if self.trim_dot: + tx_id = trim_dot(tx_id) + if not tx_id in self.transcripts: + self.transcripts[tx_id] = Transcript(chrom, strand, kv) + tx = self.transcripts[tx_id] + + # add/get gene + gene_id = kv['gene_id'] + if self.trim_dot: + gene_id = trim_dot(gene_id) + if not gene_id in self.genes: + self.genes[gene_id] = Gene() + self.genes[gene_id].add_transcript(tx_id, tx) + + # add exons + if interval_type == 'exon': + tx.add_exon(start, end) + elif interval_type == 'CDS': + tx.add_cds(start, end) + + line = gtf_in.readline() + + gtf_in.close() + + # sort transcript exons + for tx in self.transcripts.values(): + tx.sort_exons() + + def write_gtf(self, out_gtf_file, write_cds=False, write_utrs=False): + if write_utrs and not self.utrs_defined: + self.define_utrs() + + gtf_out = open(out_gtf_file, 'w') + for tx in self.transcripts.values(): + tx.write_gtf(gtf_out, write_cds, write_utrs) + gtf_out.close() + + +################################################################################ +# Methods +################################################################################ +def gtf_kv(s): + """Convert the last gtf section of key/value pairs into a dict.""" + d = {} + + a = s.split(';') + for key_val in a: + if key_val.strip(): + eq_i = key_val.find('=') + if eq_i != -1 and key_val[eq_i-1] != '"': + kvs = key_val.split('=') + else: + kvs = key_val.split() + + key = kvs[0] + if kvs[1][0] == '"' and kvs[-1][-1] == '"': 
+ val = (' '.join(kvs[1:]))[1:-1].strip() + else: + val = (' '.join(kvs[1:])).strip() + + d[key] = val + + return d + +def kv_gtf(d): + """Convert a kv hash to str gtf representation.""" + s = '' + + if 'gene_id' in d.keys(): + s += '%s "%s"; ' % ('gene_id',d['gene_id']) + + if 'transcript_id' in d.keys(): + s += '%s "%s"; ' % ('transcript_id',d['transcript_id']) + + for key in sorted(d.keys()): + if key not in ['gene_id','transcript_id']: + s += '%s "%s"; ' % (key,d[key]) + + return s + +def trim_dot(gene_id): + """Trim the final dot suffix off a gene_id.""" + dot_i = gene_id.rfind('.') + if dot_i != -1: + gene_id = gene_id[:dot_i] + return gene_id \ No newline at end of file diff --git a/src/baskerville/scripts/borzoi_test_genes.py b/src/baskerville/scripts/borzoi_test_genes.py new file mode 100755 index 0000000..1e2b853 --- /dev/null +++ b/src/baskerville/scripts/borzoi_test_genes.py @@ -0,0 +1,550 @@ +#!/usr/bin/env python +# Copyright 2021 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser +import gc +import json +import os +import time + +from intervaltree import IntervalTree +import numpy as np +import pandas as pd +import pybedtools +import pyranges as pr +from qnorm import quantile_normalize +from scipy.stats import pearsonr +from sklearn.metrics import explained_variance_score + +from baskerville import pygene +from baskerville import dataset +from baskerville import seqnn + +""" +borzoi_test_genes.py + +Measure accuracy at gene-level. 
+""" + +################################################################################ +# main +################################################################################ +def main(): + usage = "usage: %prog [options] " + parser = OptionParser(usage) + parser.add_option( + "--head", + dest="head_i", + default=0, + type="int", + help="Parameters head [Default: %default]", + ) + parser.add_option( + "-o", + dest="out_dir", + default="testg_out", + help="Output directory for predictions [Default: %default]", + ) + parser.add_option( + "--rc", + dest="rc", + default=False, + action="store_true", + help="Average the fwd and rc predictions [Default: %default]", + ) + parser.add_option( + "--shifts", + dest="shifts", + default="0", + help="Ensemble prediction shifts [Default: %default]", + ) + parser.add_option( + "--span", + dest="span", + default=False, + action="store_true", + help="Aggregate entire gene span [Default: %default]", + ) + parser.add_option( + "-t", + dest="targets_file", + default=None, + type="str", + help="File specifying target indexes and labels in table format", + ) + parser.add_option( + "--split", + dest="split_label", + default="test", + help="Dataset split label for eg TFR pattern [Default: %default]", + ) + parser.add_option( + "--tfr", + dest="tfr_pattern", + default=None, + help="TFR pattern string appended to data_dir/tfrecords for subsetting [Default: %default]", + ) + parser.add_option( + "-u", + dest="untransform_old", + default=False, + action="store_true", + help="Untransform old models [Default: %default]", + ) + (options, args) = parser.parse_args() + + if len(args) != 4: + parser.error("Must provide parameters, model, data directory, and genes GTF") + else: + params_file = args[0] + model_file = args[1] + data_dir = args[2] + genes_gtf_file = args[3] + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + # parse shifts to integers + options.shifts = [int(shift) for shift in options.shifts.split(",")] + + ####################################################### + # inputs + + # read targets + if options.targets_file is None: + options.targets_file = "%s/targets.txt" % data_dir + targets_df = pd.read_csv(options.targets_file, index_col=0, sep="\t") + + # prep strand + targets_strand_df = dataset.targets_prep_strand(targets_df) + num_targets = targets_df.shape[0] + num_targets_strand = targets_strand_df.shape[0] + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params["model"] + params_train = params["train"] + + # set strand pairs (using new indexing) + orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) + targets_strand_pair = np.array( + [orig_new_index[ti] for ti in targets_df.strand_pair] + ) + params_model["strand_pair"] = [targets_strand_pair] + + # construct eval data + eval_data = dataset.SeqDataset( + data_dir, + split_label=options.split_label, + batch_size=params_train["batch_size"], + mode="eval", + tfr_pattern=options.tfr_pattern, + ) + + # initialize model + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_file, options.head_i) + seqnn_model.build_slice(targets_df.index) + seqnn_model.build_ensemble(options.rc, options.shifts) + + ####################################################### + # sequence intervals + + # read data parameters + with open("%s/statistics.json" % data_dir) as data_open: + data_stats = json.load(data_open) + crop_bp = data_stats["crop_bp"] + pool_width = data_stats["pool_width"] + + # 
read sequence positions + seqs_df = pd.read_csv( + "%s/sequences.bed" % data_dir, + sep="\t", + names=["Chromosome", "Start", "End", "Name"], + ) + seqs_df = seqs_df[seqs_df.Name == options.split_label] + seqs_pr = pr.PyRanges(seqs_df) + + ####################################################### + # make gene BED + + t0 = time.time() + print("Making gene BED...", end="") + genes_bed_file = "%s/genes.bed" % options.out_dir + if options.span: + make_genes_span(genes_bed_file, genes_gtf_file, options.out_dir) + else: + make_genes_exon(genes_bed_file, genes_gtf_file, options.out_dir) + + genes_pr = pr.read_bed(genes_bed_file) + print("DONE in %ds" % (time.time() - t0)) + + # count gene normalization lengths + gene_lengths = {} + gene_strand = {} + for line in open(genes_bed_file): + a = line.rstrip().split("\t") + gene_id = a[3] + gene_seg_len = int(a[2]) - int(a[1]) + gene_lengths[gene_id] = gene_lengths.get(gene_id, 0) + gene_seg_len + gene_strand[gene_id] = a[5] + + ####################################################### + # intersect genes w/ preds, targets + + # intersect seqs, genes + t0 = time.time() + print("Intersecting sequences w/ genes...", end="") + seqs_genes_pr = seqs_pr.join(genes_pr) + print("DONE in %ds" % (time.time() - t0), flush=True) + + # hash preds/targets by gene_id + gene_preds_dict = {} + gene_targets_dict = {} + + si = 0 + for x, y in eval_data.dataset: + # predict only if gene overlaps + yh = None + y = y.numpy()[..., targets_df.index] + + t0 = time.time() + print("Sequence %d..." % si, end="") + for bsi in range(x.shape[0]): + seq = seqs_df.iloc[si + bsi] + + cseqs_genes_df = seqs_genes_pr[seq.Chromosome].df + if cseqs_genes_df.shape[0] == 0: + # empty. no genes on this chromosome + seq_genes_df = cseqs_genes_df + else: + seq_genes_df = cseqs_genes_df[cseqs_genes_df.Start == seq.Start] + + for _, seq_gene in seq_genes_df.iterrows(): + gene_id = seq_gene.Name_b + gene_start = seq_gene.Start_b + gene_end = seq_gene.End_b + seq_start = seq_gene.Start + + # clip boundaries + gene_seq_start = max(0, gene_start - seq_start) + gene_seq_end = max(0, gene_end - seq_start) + + # requires >50% overlap + bin_start = int(np.round(gene_seq_start / pool_width)) + bin_end = int(np.round(gene_seq_end / pool_width)) + + # predict + if yh is None: + yh = seqnn_model(x) + + # slice gene region + yhb = yh[bsi, bin_start:bin_end].astype("float16") + yb = y[bsi, bin_start:bin_end].astype("float16") + + if len(yb) > 0: + gene_preds_dict.setdefault(gene_id, []).append(yhb) + gene_targets_dict.setdefault(gene_id, []).append(yb) + + # advance sequence table index + si += x.shape[0] + print("DONE in %ds" % (time.time() - t0), flush=True) + if si % 128 == 0: + gc.collect() + + # aggregate gene bin values into arrays + gene_targets = [] + gene_preds = [] + gene_ids = sorted(gene_targets_dict.keys()) + gene_within = [] + gene_wvar = [] + + for gene_id in gene_ids: + gene_preds_gi = np.concatenate(gene_preds_dict[gene_id], axis=0).astype( + "float32" + ) + gene_targets_gi = np.concatenate(gene_targets_dict[gene_id], axis=0).astype( + "float32" + ) + + # slice strand + if gene_strand[gene_id] == "+": + gene_strand_mask = (targets_df.strand != "-").to_numpy() + else: + gene_strand_mask = (targets_df.strand != "+").to_numpy() + gene_preds_gi = gene_preds_gi[:, gene_strand_mask] + gene_targets_gi = gene_targets_gi[:, gene_strand_mask] + + if gene_targets_gi.shape[0] == 0: + print(gene_id, gene_targets_gi.shape, gene_preds_gi.shape) + + # untransform + if options.untransform_old: + gene_preds_gi = 
dataset.untransform_preds1(gene_preds_gi, targets_strand_df) + gene_targets_gi = dataset.untransform_preds1(gene_targets_gi, targets_strand_df) + else: + gene_preds_gi = dataset.untransform_preds(gene_preds_gi, targets_strand_df) + gene_targets_gi = dataset.untransform_preds(gene_targets_gi, targets_strand_df) + + # compute within gene correlation before dropping length axis + gene_corr_gi = np.zeros(num_targets_strand) + for ti in range(num_targets_strand): + if ( + gene_preds_gi[:, ti].var() > 1e-6 + and gene_targets_gi[:, ti].var() > 1e-6 + ): + preds_log = np.log2(gene_preds_gi[:, ti] + 1) + targets_log = np.log2(gene_targets_gi[:, ti] + 1) + gene_corr_gi[ti] = pearsonr(preds_log, targets_log)[0] + # gene_corr_gi[ti] = pearsonr(gene_preds_gi[:,ti], gene_targets_gi[:,ti])[0] + else: + gene_corr_gi[ti] = np.nan + gene_within.append(gene_corr_gi) + gene_wvar.append(gene_targets_gi.var(axis=0)) + + # TEMP: save gene preds/targets + # os.makedirs('%s/gene_within' % options.out_dir, exist_ok=True) + # np.save('%s/gene_within/%s_preds.npy' % (options.out_dir, gene_id), gene_preds_gi.astype('float16')) + # np.save('%s/gene_within/%s_targets.npy' % (options.out_dir, gene_id), gene_targets_gi.astype('float16')) + + # mean coverage + gene_preds_gi = gene_preds_gi.mean(axis=0) + gene_targets_gi = gene_targets_gi.mean(axis=0) + + # scale by gene length + gene_preds_gi *= gene_lengths[gene_id] + gene_targets_gi *= gene_lengths[gene_id] + + gene_preds.append(gene_preds_gi) + gene_targets.append(gene_targets_gi) + + gene_targets = np.array(gene_targets) + gene_preds = np.array(gene_preds) + gene_within = np.array(gene_within) + gene_wvar = np.array(gene_wvar) + + # log2 transform + gene_targets = np.log2(gene_targets + 1) + gene_preds = np.log2(gene_preds + 1) + + # save values + genes_targets_df = pd.DataFrame( + gene_targets, index=gene_ids, columns=targets_strand_df.identifier + ) + genes_targets_df.to_csv("%s/gene_targets.tsv" % options.out_dir, sep="\t") + genes_preds_df = pd.DataFrame( + gene_preds, index=gene_ids, columns=targets_strand_df.identifier + ) + genes_preds_df.to_csv("%s/gene_preds.tsv" % options.out_dir, sep="\t") + genes_within_df = pd.DataFrame( + gene_within, index=gene_ids, columns=targets_strand_df.identifier + ) + genes_within_df.to_csv("%s/gene_within.tsv" % options.out_dir, sep="\t") + genes_var_df = pd.DataFrame( + gene_wvar, index=gene_ids, columns=targets_strand_df.identifier + ) + genes_var_df.to_csv("%s/gene_var.tsv" % options.out_dir, sep="\t") + + # quantile and mean normalize + gene_targets_norm = quantile_normalize(gene_targets, ncpus=2) + gene_targets_norm = gene_targets_norm - gene_targets_norm.mean( + axis=-1, keepdims=True + ) + gene_preds_norm = quantile_normalize(gene_preds, ncpus=2) + gene_preds_norm = gene_preds_norm - gene_preds_norm.mean(axis=-1, keepdims=True) + + ####################################################### + # accuracy stats + + wvar_t = np.percentile(gene_wvar, 80, axis=0) + + acc_pearsonr = [] + acc_r2 = [] + acc_npearsonr = [] + acc_nr2 = [] + acc_wpearsonr = [] + for ti in range(num_targets_strand): + r_ti = pearsonr(gene_targets[:, ti], gene_preds[:, ti])[0] + acc_pearsonr.append(r_ti) + r2_ti = explained_variance_score(gene_targets[:, ti], gene_preds[:, ti]) + acc_r2.append(r2_ti) + nr_ti = pearsonr(gene_targets_norm[:, ti], gene_preds_norm[:, ti])[0] + acc_npearsonr.append(nr_ti) + nr2_ti = explained_variance_score( + gene_targets_norm[:, ti], gene_preds_norm[:, ti] + ) + acc_nr2.append(nr2_ti) + var_mask = gene_wvar[:, ti] > 
wvar_t[ti] + wr_ti = gene_within[var_mask].mean() + acc_wpearsonr.append(wr_ti) + + acc_df = pd.DataFrame( + { + "identifier": targets_strand_df.identifier, + "pearsonr": acc_pearsonr, + "r2": acc_r2, + "pearsonr_norm": acc_npearsonr, + "r2_norm": acc_nr2, + "pearsonr_gene": acc_wpearsonr, + "description": targets_strand_df.description, + } + ) + acc_df.to_csv("%s/acc.txt" % options.out_dir, sep="\t") + + print("%d genes" % gene_targets.shape[0]) + print("Overall PearsonR: %.4f" % np.mean(acc_df.pearsonr)) + print("Overall R2: %.4f" % np.mean(acc_df.r2)) + print("Normalized PearsonR: %.4f" % np.mean(acc_df.pearsonr_norm)) + print("Normalized R2: %.4f" % np.mean(acc_df.r2_norm)) + print("Within-gene PearsonR: %.4f" % np.mean(acc_df.pearsonr_gene)) + + +def genes_aggregate(genes_bed_file, values_bedgraph): + """Aggregate values across genes. + + Args: + genes_bed_file (str): BED file of genes. + values_bedgraph (str): BedGraph file of values. + + Returns: + gene_values (dict): Dictionary of gene values. + """ + values_bt = pybedtools.BedTool(values_bedgraph) + genes_bt = pybedtools.BedTool(genes_bed_file) + + gene_values = {} + + for overlap in genes_bt.intersect(values_bt, wo=True): + gene_id = overlap[3] + value = overlap[7] + gene_values[gene_id] = gene_values.get(gene_id, 0) + value + + return gene_values + + +def make_genes_exon(genes_bed_file: str, genes_gtf_file: str, out_dir: str): + """Make a BED file with each genes' exons, excluding exons overlapping + across genes. + + Args: + genes_bed_file (str): Output BED file of genes. + genes_gtf_file (str): Input GTF file of genes. + out_dir (str): Output directory for temporary files. + """ + # read genes + genes_gtf = pygene.GTF(genes_gtf_file) + + # write gene exons + agenes_bed_file = "%s/genes_all.bed" % out_dir + agenes_bed_out = open(agenes_bed_file, "w") + for gene_id, gene in genes_gtf.genes.items(): + # collect exons + gene_intervals = IntervalTree() + for tx_id, tx in gene.transcripts.items(): + for exon in tx.exons: + gene_intervals[exon.start - 1 : exon.end] = True + + # union + gene_intervals.merge_overlaps() + + # write + for interval in sorted(gene_intervals): + cols = [ + gene.chrom, + str(interval.begin), + str(interval.end), + gene_id, + ".", + gene.strand, + ] + print("\t".join(cols), file=agenes_bed_out) + agenes_bed_out.close() + + # find overlapping exons + genes1_bt = pybedtools.BedTool(agenes_bed_file) + genes2_bt = pybedtools.BedTool(agenes_bed_file) + overlapping_exons = set() + for overlap in genes1_bt.intersect(genes2_bt, s=True, wo=True): + gene1_id = overlap[3] + gene1_start = int(overlap[1]) + gene1_end = int(overlap[2]) + overlapping_exons.add((gene1_id, gene1_start, gene1_end)) + + gene2_id = overlap[9] + gene2_start = int(overlap[7]) + gene2_end = int(overlap[8]) + overlapping_exons.add((gene2_id, gene2_start, gene2_end)) + + # filter for nonoverlapping exons + genes_bed_out = open(genes_bed_file, "w") + for line in open(agenes_bed_file): + a = line.split() + start = int(a[1]) + end = int(a[2]) + gene_id = a[-1] + if (gene_id, start, end) not in overlapping_exons: + print(line, end="", file=genes_bed_out) + genes_bed_out.close() + + +def make_genes_span( + genes_bed_file: str, genes_gtf_file: str, out_dir: str, stranded: bool = True +): + """Make a BED file with the span of each gene. + + Args: + genes_bed_file (str): Output BED file of genes. + genes_gtf_file (str): Input GTF file of genes. + out_dir (str): Output directory for temporary files. + stranded (bool): Perform stranded intersection. 
+ """ + # read genes + genes_gtf = pygene.GTF(genes_gtf_file) + + # write all gene spans + agenes_bed_file = "%s/genes_all.bed" % out_dir + agenes_bed_out = open(agenes_bed_file, "w") + for gene_id, gene in genes_gtf.genes.items(): + start, end = gene.span() + cols = [gene.chrom, str(start - 1), str(end), gene_id, ".", gene.strand] + print("\t".join(cols), file=agenes_bed_out) + agenes_bed_out.close() + + # find overlapping genes + genes1_bt = pybedtools.BedTool(agenes_bed_file) + genes2_bt = pybedtools.BedTool(agenes_bed_file) + overlapping_genes = set() + for overlap in genes1_bt.intersect(genes2_bt, s=stranded, wo=True): + gene1_id = overlap[3] + gene2_id = overlap[7] + if gene1_id != gene2_id: + overlapping_genes.add(gene1_id) + overlapping_genes.add(gene2_id) + + # filter for nonoverlapping genes + genes_bed_out = open(genes_bed_file, "w") + for line in open(agenes_bed_file): + gene_id = line.split()[-1] + if gene_id not in overlapping_genes: + print(line, end="", file=genes_bed_out) + genes_bed_out.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_eval_spec.py b/src/baskerville/scripts/hound_eval_spec.py index 43d908c..ad66fa3 100755 --- a/src/baskerville/scripts/hound_eval_spec.py +++ b/src/baskerville/scripts/hound_eval_spec.py @@ -45,7 +45,7 @@ def main(): parser.add_option( "-c", dest="class_min", - default=100, + default=5, type="int", help="Minimum target class size to consider [Default: %default]", ) @@ -97,6 +97,13 @@ def main(): type="str", help="File specifying target indexes and labels in table format", ) + parser.add_option( + "--target_classes", + dest="target_classes", + default=None, + type="str", + help="comma separated string of target classes", + ) parser.add_option( "--split", dest="split_label", @@ -142,19 +149,25 @@ def main(): # classify target_classes = [] - for ti in range(num_targets): - description = targets_df.iloc[ti].description - if description.find(":") == -1: - tc = "*" - else: - desc_split = description.split(":") - if desc_split[0] == "CHIP": - tc = "/".join(desc_split[:2]) + + if options.target_classes is None: + for ti in range(num_targets): + description = targets_df.iloc[ti].description + if description.find(":") == -1: + tc = "*" else: - tc = desc_split[0] - target_classes.append(tc) - targets_df["class"] = target_classes - target_classes = sorted(set(target_classes)) + desc_split = description.split(":") + if desc_split[0] == "CHIP": + tc = "/".join(desc_split[:2]) + else: + tc = desc_split[0] + target_classes.append(tc) + targets_df["class"] = target_classes + target_classes = sorted(set(target_classes)) + else: + targets_df["class"] = targets_df['description'].str.replace(':.*','',regex=True) + target_classes = options.target_classes.split(',') + print(target_classes) ####################################################### diff --git a/src/baskerville/scripts/hound_transfer.py b/src/baskerville/scripts/hound_transfer.py index 0af7997..592a846 100755 --- a/src/baskerville/scripts/hound_transfer.py +++ b/src/baskerville/scripts/hound_transfer.py @@ -91,7 +91,7 @@ def main(): "--conv_adapter", default=None, type=str, - help="conv layer module [conv, batch_norm, squez_excit]", + help="conv layer module [conv, bn, conv_bn, squez_excit]", ) parser.add_argument( @@ -206,9 +206,12 @@ def main(): # attention adapter if args.att_adapter 
is not None: if args.att_adapter=='adapterHoulsby': - seqnn_model.model = transfer_helper.add_houlsby(seqnn_model.model, - strand_pairs[0], - latent_size=args.att_latent) + if args.conv_adapter not in ['se', 'se_bn', 'se_all','se_all_bn']: + # when att_adapter=='Houlsby' and conv_adapter=='se', do nothing. + # see conv_adapter section. + seqnn_model.model = transfer_helper.add_houlsby(seqnn_model.model, + strand_pairs[0], + latent_size=args.att_latent) elif args.att_adapter=='lora': transfer_helper.add_lora(seqnn_model.model, rank=args.att_latent, @@ -228,57 +231,89 @@ def main(): if args.conv_adapter=='conv': params_added = 0 for l in seqnn_model.model.layers: - if l.name.startswith("conv1d"): + if l.name.startswith(("conv1d","separable_conv1d")): l.trainable=True params_added += transfer_helper.param_count(l, type='trainable') print('params added/unfrozen by conv: %d'%params_added) - if args.conv_adapter=='conv_all': + elif args.conv_adapter=='conv_bn': params_added = 0 for l in seqnn_model.model.layers: - if l.name.startswith(("conv1d","separable_conv1d")): + if l.name.startswith(("conv1d","separable_conv1d","batch_normalization")): l.trainable=True params_added += transfer_helper.param_count(l, type='trainable') - print('params added/unfrozen by conv_all: %d'%params_added) + print('params added/unfrozen by conv_bn: %d'%params_added) - elif args.conv_adapter=='batch_norm': + elif args.conv_adapter=='bn': params_added = 0 for l in seqnn_model.model.layers: if l.name.startswith("batch_normalization"): l.trainable=True params_added += transfer_helper.param_count(l, type='trainable') - print('params added/unfrozen by batch_norm: %d'%params_added) + print('params added/unfrozen by bn: %d'%params_added) ################## # squeeze-excite # ################## - elif args.conv_adapter=='se': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, - strand_pair=strand_pairs[0], - bottleneck_ratio=args.se_ratio, - insert_mode='pre_att', - unfreeze_bn=False) - - elif args.conv_adapter=='se_bn': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, - strand_pair=strand_pairs[0], - bottleneck_ratio=args.se_ratio, - insert_mode='pre_att', - unfreeze_bn=True) - - elif args.conv_adapter=='se_all': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, - strand_pair=strand_pairs[0], - bottleneck_ratio=args.se_ratio, - insert_mode='all', - unfreeze_bn=False) - - elif args.conv_adapter=='se_all_bn': - seqnn_model.model = transfer_helper.add_se(seqnn_model.model, - strand_pair=strand_pairs[0], - bottleneck_ratio=args.se_ratio, - insert_mode='all', - unfreeze_bn=True) + elif args.conv_adapter in ['se','se_bn','se_all','se_all_bn']: + if args.att_adapter=='adapterHoulsby': + if args.conv_adapter=='se': + seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='pre_att', + unfreeze_bn=False) + elif args.conv_adapter=='se_bn': + seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='pre_att', + unfreeze_bn=True) + elif args.conv_adapter=='se_all': + seqnn_model.model = transfer_helper.add_houlsby_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='all', + unfreeze_bn=False) + elif args.conv_adapter=='se_all_bn': + seqnn_model.model = 
transfer_helper.add_houlsby_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='all', + unfreeze_bn=True) + else: + if args.conv_adapter=='se': + seqnn_model.model = transfer_helper.add_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='pre_att', + unfreeze_bn=False) + elif args.conv_adapter=='se_bn': + seqnn_model.model = transfer_helper.add_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='pre_att', + unfreeze_bn=True) + elif args.conv_adapter=='se_all': + seqnn_model.model = transfer_helper.add_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='all', + unfreeze_bn=False) + elif args.conv_adapter=='se_all_bn': + seqnn_model.model = transfer_helper.add_se(seqnn_model.model, + strand_pair=strand_pairs[0], + houlsby_latent=args.att_latent, + bottleneck_ratio=args.se_ratio, + insert_mode='pre_att', + unfreeze_bn=True) ################# # final summary # @@ -307,36 +342,37 @@ def main(): ############################# if args.transfer_mode=='sparse': - # Houlsby adapter requires architecture change, overwrite params.json file with new one - if args.att_adapter=='adapterHoulsby': - transfer_helper.modify_json(input_json=args.params_file, - output_json=args.out_dir, - adapter='houlsby', - latent_size=args.att_latent) - - # merge lora weights to original, save weight to: model_best.mergeW.h5 - # use original params.json + # overwrite json file when needed + # for: adapterHoulsby and squeeze-excite + transfer_helper.modify_json(input_json=args.params_file, + output_json='%s/params.json'%args.out_dir, + adapter=args.att_adapter, + latent=args.att_latent, + conv=args.conv_adapter, + se_ratio=args.se_ratio) + + # merge weights when needed + # for: lora and ia3 + # save weight to: model_best.mergeW.h5 if args.att_adapter=='lora': - seqnn_model.model.load_weights('%s/model_best.h5'args.out_dir) + seqnn_model.model.load_weights('%s/model_best.h5'%args.out_dir) transfer_helper.merge_lora(seqnn_model.model, mode='default') - seqnn_model.save('%s/model_best.mergeW.h5'args.out_dir) - transfer_helper.var_reorder('%s/model_best.mergeW.h5'args.out_dir) + seqnn_model.save('%s/model_best.mergeW.h5'%args.out_dir) + transfer_helper.var_reorder('%s/model_best.mergeW.h5'%args.out_dir) if args.att_adapter=='lora_full': - seqnn_model.model.load_weights('%s/model_best.h5'args.out_dir) + seqnn_model.model.load_weights('%s/model_best.h5'%args.out_dir) transfer_helper.merge_lora(seqnn_model.model, mode='full') - seqnn_model.save('%s/model_best.mergeW.h5'args.out_dir) - transfer_helper.var_reorder('%s/model_best.mergeW.h5'args.out_dir) + seqnn_model.save('%s/model_best.mergeW.h5'%args.out_dir) + transfer_helper.var_reorder('%s/model_best.mergeW.h5'%args.out_dir) # merge ia3 weights to original, save weight to: model_best_mergeweight.h5 if args.att_adapter=='ia3': - seqnn_model.model.load_weights('%s/model_best.h5'args.out_dir) + seqnn_model.model.load_weights('%s/model_best.h5'%args.out_dir) transfer_helper.merge_ia3(seqnn_model.model) - seqnn_model.save('%s/model_best.mergeW.h5'args.out_dir) - transfer_helper.var_reorder('%s/model_best.mergeW.h5'args.out_dir) - + seqnn_model.save('%s/model_best.mergeW.h5'%args.out_dir) + transfer_helper.var_reorder('%s/model_best.mergeW.h5'%args.out_dir) 
- else: ######################################## # multi GPU diff --git a/src/baskerville/transfer_helper.py b/src/baskerville/transfer_helper.py index d6cd851..401cd3a 100644 --- a/src/baskerville/transfer_helper.py +++ b/src/baskerville/transfer_helper.py @@ -34,7 +34,6 @@ def param_summary(model): print('trainable params:%d' %trainable) print('non-trainable params:%d' %non_trainable) - ###################### # add houlsby layers # ###################### @@ -117,19 +116,6 @@ def add_houlsby(input_model, strand_pair, latent_size=16): return model_adapter -# save Houlsby json -def modify_json(input_json, output_json, adapter, latent=None): - - with open(input_json) as params_open: - params = json.load(params_open) - - params["model"]["trunk"][2]['adapter']= adapter - params["model"]["trunk"][2]['latent']= latent - - ### output - with open(output_json, 'w') as params_open: - json.dump(params, params_open, indent=4) - ################### # add lora layers # ################### @@ -186,49 +172,6 @@ def add_lora(input_model, rank=8, alpha=16, mode='default'): print('params added/unfrozen by lora: %d'%params_added) -# merge lora weights -def merge_lora_layer(lora_layer): - down_weights = lora_layer.down_layer.kernel - up_weights = lora_layer.up_layer.kernel - increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * lora_layer.scale - lora_layer.original_layer.kernel.assign_add(increment_weights) - return lora_layer.original_layer - -def merge_lora(input_model, mode='default'): - for layer in input_model.layers: - if 'multihead_attention' in layer.name: - # default loRA - layer._q_layer = merge_lora_layer(layer._q_layer) - layer._v_layer = merge_lora_layer(layer._v_layer) - if mode=='full': - layer._k_layer = merge_lora_layer(layer._k_layer) - layer._embedding_layer = merge_lora_layer(layer._embedding_layer) - input_model(input_model.input) - -# correct weights.h5 weight order -def var_reorder(weight_h5): - # assumes weight_h5 model saved with seqnn_model.save() - # [i.name for i in model.layers[30].weights] to check for multihead_attention layer weights order. - # model.load_weights() load weights sequencially, assuming layer weights are in the right order. - # When inserting lora/ia3, multihead_attention layer weights order changed. - # multihead_attention layer weights order is saved inside f['model_weights']['multihead_attention'].attrs - # After saving the weight_merged model, we need to go into the weights.h5, and change the attrs in multihead attention. 
- var_init_order = ['r_w_bias:0:0', - 'r_r_bias:0:0', - 'q_layer/kernel:0', - 'k_layer/kernel:0', - 'v_layer/kernel:0', - 'embedding_layer/kernel:0', - 'embedding_layer/bias:0', - 'r_k_layer/kernel:0'] - - f = h5py.File(weight_h5, 'r+') - layers = [i for i in list(f['model_weights'].keys()) if 'multihead_attention' in i] - for l_name in layers: - new_name_order = [l_name+'/'+i for i in var_init_order] - f['model_weights'][l_name].attrs.modify(name='weight_names', value=new_name_order) - f.close() - ################## # add ia3 layers # ################## @@ -270,23 +213,6 @@ def add_ia3(input_model): print('params added/unfrozen by ia3: %d'%params_added) -# merge lora weights -def merge_ia3_layer(ia3_layer, type='kv'): - scaler = ia3_layer._ia3_layer.kernel[0] - ia3_layer.original_layer.kernel.assign(ia3_layer.original_layer.kernel * scaler) - if type=='embedding': - ia3_layer.original_layer.bias.assign(ia3_layer.original_layer.bias * scaler) - return ia3_layer.original_layer - -def merge_ia3(input_model): - for layer in input_model.layers: - if 'multihead_attention' in layer.name: - layer._k_layer = merge_ia3_layer(layer._k_layer, type='kv') - layer._v_layer = merge_ia3_layer(layer._v_layer, type='kv') - layer._embedding_layer = merge_ia3_layer(layer._embedding_layer, type='embedding') - input_model(input_model.input) - - ###################### # add squeeze excite # ###################### @@ -344,7 +270,8 @@ def add_se(input_model, strand_pair, bottleneck_ratio=8, insert_mode='pre_att', additive=False, # use sigmoid multiplicative scaling bottleneck_ratio=bottleneck_ratio, # bottleneck ratio use_bias=False, # ignore bias - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3) # near-zero weight initialization + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization + scale_fun='tanh' ) x = layer(layer_input) x = x + se_layer(x) @@ -356,7 +283,8 @@ def add_se(input_model, strand_pair, bottleneck_ratio=8, insert_mode='pre_att', additive=False, # use sigmoid multiplicative scaling bottleneck_ratio=bottleneck_ratio, # bottleneck ratio use_bias=False, # ignore bias - kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3) # near-zero weight initialization + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization + scale_fun='tanh' ) x = layer(layer_input) x = x + se_layer(x) @@ -390,3 +318,223 @@ def add_se(input_model, strand_pair, bottleneck_ratio=8, insert_mode='pre_att', print('params added/unfrozen by se_block: %d'%params_added) return model_final + + +def add_houlsby_se(input_model, strand_pair, houlsby_latent=8, bottleneck_ratio=8, insert_mode='pre_att', unfreeze_bn=False): + # add squeeze-excitation blocks after conv + # input_model should be properly frozen + # pre_att: add se_block to pre-attention conv1d + # all: add se_block to pre-attention conv1d and post-attention separable_conv1d + + if insert_mode not in ['pre_att','all']: + raise ValueError("insert_mode must be pre_att or all") + + model = tf.keras.Model(inputs=input_model.input, + outputs=input_model.layers[-2].output) # remove the switch_reverse layer + + # save current graph + layer_parent_dict_old = {} # the parent layers of each layer in the old graph + for layer in model.layers: + for node in layer._outbound_nodes: + layer_name = node.outbound_layer.name + if layer_name not in layer_parent_dict_old: + layer_parent_dict_old.update({layer_name: [layer.name]}) + else: + if layer.name not in 
layer_parent_dict_old[layer_name]: + layer_parent_dict_old[layer_name].append(layer.name) + + layer_output_dict_new = {} # the output tensor of each layer in the new graph + layer_output_dict_new.update({model.layers[0].name: model.input}) + + # remove switch_reverse + to_fix = [i for i in layer_parent_dict_old if re.match('switch_reverse', i)] + for i in to_fix: + del layer_parent_dict_old[i] + + # Iterate over all layers after the input + model_outputs = [] + reverse_bool = None + + for layer in model.layers[1:]: + + # parent layers + parent_layers = layer_parent_dict_old[layer.name] + + # layer inputs + layer_input = [layer_output_dict_new[parent] for parent in parent_layers] + if len(layer_input) == 1: layer_input = layer_input[0] + + if layer.name.startswith("stochastic_reverse_complement"): + x, reverse_bool = layer(layer_input) + + # insert houlsby: + elif re.match('add', layer.name): + if any([re.match('dropout', i) for i in parent_layers]): + print('adapter added before:%s'%layer.name) + x = layers.AdapterHoulsby(latent_size=houlsby_latent)(layer_input[1]) + x = layer([layer_input[0], x]) + else: + x = layer(layer_input) + + # insert squeeze-excite layer: + elif layer.name.startswith("conv1d"): + se_layer = layers.SqueezeExcite( + activation=None, # no activation before squeezing + additive=False, # use sigmoid multiplicative scaling + bottleneck_ratio=bottleneck_ratio, # bottleneck ratio + use_bias=False, # ignore bias + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization + scale_fun='tanh' + ) + x = layer(layer_input) + x = x + se_layer(x) + + elif layer.name.startswith("separable_conv1d"): + if insert_mode=='all': + se_layer = layers.SqueezeExcite( + activation=None, # no activation before squeezing + additive=False, # use sigmoid multiplicative scaling + bottleneck_ratio=bottleneck_ratio, # bottleneck ratio + use_bias=False, # ignore bias + kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=1e-3), # near-zero weight initialization + scale_fun='tanh' + ) + x = layer(layer_input) + x = x + se_layer(x) + else: + x = layer(layer_input) + + else: + x = layer(layer_input) + + # save the output tensor of every layer + layer_output_dict_new.update({layer.name: x}) + + final = layers.SwitchReverse(strand_pair)([layer_output_dict_new[model.layers[-1].name], reverse_bool]) + model_final = tf.keras.Model(inputs=model.inputs, outputs=final) + + # set trainable + for l in model_final.layers[:-2]: # trunk + if re.match('layer_normalization|adapter_houlsby', l.name): + l.trainable = True + else: + l.trainable = False + + for l in model_final.layers: # set trunk + if l.name.startswith("squeeze_excite"): l.trainable = True + + if unfreeze_bn: + for l in model_final.layers: + if l.name.startswith("batch_normalization"): l.trainable=True + + # expected number of trainable params added/unfrozen: + params_added = 0 + for l in model_final.layers: + if l.name.startswith("squeeze_excite"): + params_added += param_count(l) + elif l.name.startswith("batch_normalization"): + if unfreeze_bn: params_added += param_count(l, type='trainable') + elif l.name.startswith("adapter_houlsby"): + params_added += param_count(l) + elif l.name.startswith("layer_normalization"): + params_added += param_count(l, type='trainable') + print('params added/unfrozen by se_block: %d'%params_added) + + return model_final + +############### +# modify json # +############### +# houlsby and squeeze-excite +def modify_json(input_json, output_json, 
adapter='adapterHoulsby', latent=None, conv=None, se_ratio=None): + + with open(input_json) as params_open: + params = json.load(params_open) + + # houlsby # + if adapter=='adapterHoulsby': + params["model"]["trunk"][2]['adapter']= 'houlsby' + params["model"]["trunk"][2]['latent']= latent + + # squeeze-excite # + if conv=='se_all' or conv=='se_all_bn': + for i in [0, 1, 3, 4]: + params['model']['trunk'][i]['transfer_se']=True + params['model']['trunk'][i]['se_ratio']=se_ratio + + elif conv=='se' or conv=='se_bn': + for i in [0, 1]: + params['model']['trunk'][i]['transfer_se']=True + params['model']['trunk'][i]['se_ratio']=se_ratio + + else: + pass + + ### output + with open(output_json, 'w') as params_open: + json.dump(params, params_open, indent=4) + + +###################### +# merge lora weights # +###################### +def merge_lora_layer(lora_layer): + down_weights = lora_layer.down_layer.kernel + up_weights = lora_layer.up_layer.kernel + increment_weights = tf.einsum("ab,bc->ac", down_weights, up_weights) * lora_layer.scale + lora_layer.original_layer.kernel.assign_add(increment_weights) + return lora_layer.original_layer + +def merge_lora(input_model, mode='default'): + for layer in input_model.layers: + if 'multihead_attention' in layer.name: + # default loRA + layer._q_layer = merge_lora_layer(layer._q_layer) + layer._v_layer = merge_lora_layer(layer._v_layer) + if mode=='full': + layer._k_layer = merge_lora_layer(layer._k_layer) + layer._embedding_layer = merge_lora_layer(layer._embedding_layer) + input_model(input_model.input) + +# correct weights.h5 weight order +def var_reorder(weight_h5): + # assumes weight_h5 model saved with seqnn_model.save() + # [i.name for i in model.layers[30].weights] to check for multihead_attention layer weights order. + # model.load_weights() load weights sequencially, assuming layer weights are in the right order. + # When inserting lora/ia3, multihead_attention layer weights order changed. + # multihead_attention layer weights order is saved inside f['model_weights']['multihead_attention'].attrs + # After saving the weight_merged model, we need to go into the weights.h5, and change the attrs in multihead attention. + var_init_order = ['r_w_bias:0:0', + 'r_r_bias:0:0', + 'q_layer/kernel:0', + 'k_layer/kernel:0', + 'v_layer/kernel:0', + 'embedding_layer/kernel:0', + 'embedding_layer/bias:0', + 'r_k_layer/kernel:0'] + + f = h5py.File(weight_h5, 'r+') + layers = [i for i in list(f['model_weights'].keys()) if 'multihead_attention' in i] + for l_name in layers: + new_name_order = [l_name+'/'+i for i in var_init_order] + f['model_weights'][l_name].attrs.modify(name='weight_names', value=new_name_order) + f.close() + +##################### +# merge ia3 weights # +##################### +def merge_ia3_layer(ia3_layer, type='kv'): + scaler = ia3_layer._ia3_layer.kernel[0] + ia3_layer.original_layer.kernel.assign(ia3_layer.original_layer.kernel * scaler) + if type=='embedding': + ia3_layer.original_layer.bias.assign(ia3_layer.original_layer.bias * scaler) + return ia3_layer.original_layer + +def merge_ia3(input_model): + for layer in input_model.layers: + if 'multihead_attention' in layer.name: + layer._k_layer = merge_ia3_layer(layer._k_layer, type='kv') + layer._v_layer = merge_ia3_layer(layer._v_layer, type='kv') + layer._embedding_layer = merge_ia3_layer(layer._embedding_layer, type='embedding') + input_model(input_model.input) +
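
The recurring pattern in the blocks.py, layers.py, and transfer_helper.py changes above is a residual squeeze-excite adapter: a bottlenecked excitation MLP with no bias, near-zero TruncatedNormal kernels, and a tanh gate (scale_fun='tanh') whose output is added back onto the frozen convolution activations, so the block starts out close to an identity. Below is a minimal standalone Keras sketch of that pattern, not the repo's layers.SqueezeExcite class; the layer choices, the (1024, 768) input shape, and se_ratio=16 are illustrative assumptions.

import tensorflow as tf

def se_adapter(x, se_ratio=16):
    channels = x.shape[-1]
    init = tf.keras.initializers.TruncatedNormal(stddev=1e-3)
    # squeeze: average over the length axis
    s = tf.keras.layers.GlobalAveragePooling1D()(x)
    # excite: bottleneck MLP with no bias and near-zero weights
    s = tf.keras.layers.Dense(channels // se_ratio, activation="relu",
                              use_bias=False, kernel_initializer=init)(s)
    s = tf.keras.layers.Dense(channels, activation="tanh",
                              use_bias=False, kernel_initializer=init)(s)
    s = tf.keras.layers.Reshape((1, channels))(s)
    # residual multiplicative gate: x + x * tanh(...) starts near the identity
    scaled = tf.keras.layers.Multiply()([x, s])
    return tf.keras.layers.Add()([x, scaled])

seq = tf.keras.Input(shape=(1024, 768))
adapter_model = tf.keras.Model(seq, se_adapter(seq))

Because the gate starts near zero, inserting this block into a frozen trunk initially leaves predictions unchanged, which is why the patch can add it everywhere transfer_se=True without re-tuning the pretrained weights first.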
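
The new pygene module added by this patch is a small GTF parser used by borzoi_test_genes.py. Based on the classes above, usage looks roughly like the following sketch; the file paths are hypothetical placeholders.

from baskerville import pygene

gtf = pygene.GTF("genes.gtf.gz", trim_dot=True)   # hypothetical annotation path

# gene spans (GTF-style 1-based coordinates)
for gene_id, gene in gtf.genes.items():
    start, end = gene.span()
    print(gene_id, gene.chrom, gene.strand, start, end)

# transcript TSS positions
for tx_id, tx in gtf.transcripts.items():
    print(tx_id, tx.chrom, tx.strand, tx.tss())

# annotate UTRs and write exons/CDS/UTRs back out
gtf.define_utrs()
gtf.write_gtf("genes_utrs.gtf", write_cds=True, write_utrs=True)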
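
borzoi_test_genes.py writes log2(x+1) gene-level coverage tables (gene_targets.tsv, gene_preds.tsv) and per-target statistics (acc.txt) to its output directory. A hedged example of loading those tables and recomputing the gene-level PearsonR column, assuming the script has already been run with its default output directory (-o testg_out):

import numpy as np
import pandas as pd
from scipy.stats import pearsonr

out_dir = "testg_out"  # the script's default -o output directory
gene_targets = pd.read_csv(f"{out_dir}/gene_targets.tsv", sep="\t", index_col=0)
gene_preds = pd.read_csv(f"{out_dir}/gene_preds.tsv", sep="\t", index_col=0)

# per-target gene-level correlation, matching the 'pearsonr' column of acc.txt
per_target_r = [
    pearsonr(gene_targets.iloc[:, ti], gene_preds.iloc[:, ti])[0]
    for ti in range(gene_targets.shape[1])
]
print("mean gene-level PearsonR: %.4f" % np.mean(per_target_r))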