diff --git a/src/baskerville/blocks.py b/src/baskerville/blocks.py index fab9a4b..2490f2e 100644 --- a/src/baskerville/blocks.py +++ b/src/baskerville/blocks.py @@ -18,6 +18,7 @@ from baskerville import layers + ############################################################ # Convolution ############################################################ @@ -892,7 +893,7 @@ def conv_tower( divisible_by=1, repeat=1, reprs=[], - **kwargs + **kwargs, ): """Construct a reducing convolution block. @@ -943,7 +944,7 @@ def conv_tower_nac( divisible_by=1, repeat=1, reprs=[], - **kwargs + **kwargs, ): """Construct a reducing convolution block. @@ -1000,7 +1001,7 @@ def res_tower( repeat=1, num_convs=2, reprs=[], - **kwargs + **kwargs, ): """Construct a reducing convolution block. @@ -1087,7 +1088,7 @@ def convnext_tower( repeat=1, num_convs=2, reprs=[], - **kwargs + **kwargs, ): """Abc. @@ -1129,7 +1130,7 @@ def _round(x): filters=rep_filters_int, kernel_size=kernel_size, dropout=dropout, - **kwargs + **kwargs, ) current0 = current @@ -1141,7 +1142,7 @@ def _round(x): filters=rep_filters_int, kernel_size=kernel_size, dropout=dropout, - **kwargs + **kwargs, ) # residual add @@ -1187,7 +1188,7 @@ def transformer( qkv_width=1, mha_initializer="he_normal", kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a transformer block. @@ -1255,7 +1256,7 @@ def transformer_split( qkv_width=1, mha_initializer="he_normal", kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a transformer block. @@ -1393,7 +1394,7 @@ def transformer2( dropout=0.25, dense_expansion=2.0, qkv_width=1, - **kwargs + **kwargs, ): """Construct a transformer block, with length-wise pooling before returning to full length. @@ -1416,7 +1417,7 @@ def transformer2( filters=min(4 * key_size, inputs.shape[-1]), kernel_size=3, pool_size=2, - **kwargs + **kwargs, ) # layer norm @@ -1517,7 +1518,7 @@ def squeeze_excite( additive=False, norm_type=None, bn_momentum=0.9, - **kwargs + **kwargs, ): return layers.SqueezeExcite( activation, additive, bottleneck_ratio, norm_type, bn_momentum @@ -1545,7 +1546,7 @@ def dilated_dense( conv_type="standard", dropout=0, repeat=1, - **kwargs + **kwargs, ): """Construct a residual dilated dense block. @@ -1570,7 +1571,7 @@ def dilated_dense( kernel_size=kernel_size, dilation_rate=int(np.round(dilation_rate)), conv_type=conv_type, - **kwargs + **kwargs, ) # dense concat @@ -1592,7 +1593,7 @@ def dilated_residual( conv_type="standard", norm_type=None, round=False, - **kwargs + **kwargs, ): """Construct a residual dilated convolution block. 
@@ -1619,7 +1620,7 @@ def dilated_residual( conv_type=conv_type, norm_type=norm_type, norm_gamma="ones", - **kwargs + **kwargs, ) # return @@ -1629,7 +1630,7 @@ def dilated_residual( dropout=dropout, norm_type=norm_type, norm_gamma="zeros", - **kwargs + **kwargs, ) # InitZero @@ -1672,7 +1673,7 @@ def dilated_residual_nac( filters=filters, kernel_size=kernel_size, dilation_rate=int(np.round(dilation_rate)), - **kwargs + **kwargs, ) # return @@ -1697,7 +1698,7 @@ def dilated_residual_2d( dropout=0, repeat=1, symmetric=True, - **kwargs + **kwargs, ): """Construct a residual dilated convolution block.""" @@ -1717,7 +1718,7 @@ def dilated_residual_2d( kernel_size=kernel_size, dilation_rate=int(np.round(dilation_rate)), norm_gamma="ones", - **kwargs + **kwargs, ) # return @@ -1726,7 +1727,7 @@ def dilated_residual_2d( filters=rep_input.shape[-1], dropout=dropout, norm_gamma="zeros", - **kwargs + **kwargs, ) # residual add @@ -1818,7 +1819,7 @@ def dense_block( bn_momentum=0.99, norm_gamma=None, kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a single convolution block. @@ -1909,7 +1910,7 @@ def dense_nac( bn_momentum=0.99, norm_gamma=None, kernel_initializer="he_normal", - **kwargs + **kwargs, ): """Construct a single convolution block. @@ -1991,7 +1992,7 @@ def final( kernel_initializer="he_normal", l2_scale=0, l1_scale=0, - **kwargs + **kwargs, ): """Final simple transformation before comparison to targets. diff --git a/src/baskerville/metrics.py b/src/baskerville/metrics.py index 8e947ba..29c0d99 100644 --- a/src/baskerville/metrics.py +++ b/src/baskerville/metrics.py @@ -24,6 +24,7 @@ for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True) + ################################################################################ # Losses ################################################################################ diff --git a/src/baskerville/scripts/hound_eval_spec.py b/src/baskerville/scripts/hound_eval_spec.py index e69d146..43d908c 100755 --- a/src/baskerville/scripts/hound_eval_spec.py +++ b/src/baskerville/scripts/hound_eval_spec.py @@ -35,6 +35,7 @@ Test the accuracy of a trained model on targets/predictions normalized across targets. """ + ################################################################################ # main ################################################################################ diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index cfbd676..0eb8a7d 100755 --- a/src/baskerville/scripts/hound_snp.py +++ b/src/baskerville/scripts/hound_snp.py @@ -17,7 +17,7 @@ import pdb import pickle import os -from baskerville.snps import calculate_sad +from baskerville.snps import score_snps """ hound_snp.py @@ -25,12 +25,20 @@ Compute variant effect predictions for SNPs in a VCF file. """ + ################################################################################ # main ################################################################################ def main(): usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>" parser = OptionParser(usage) + parser.add_option( + "-c", + dest="cluster_snps_pct", + default=0, + type="float", + help="Cluster SNPs within a %% of the seq length to make a single ref pred [Default: %default]", + ) parser.add_option( "-f", dest="genome_fasta", @@ -66,8 +74,8 @@ def main(): ) parser.add_option( "--stats", - dest="sad_stats", - default="SAD", + dest="snp_stats", + default="logSAD", help="Comma-separated list of stats to save.
[Default: %default]", ) parser.add_option( @@ -129,17 +137,20 @@ def main(): else: parser.error("Must provide parameters and model files and QTL VCF file") + if options.targets_file is None: + parser.error("Must provide targets file") + if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) options.shifts = [int(shift) for shift in options.shifts.split(",")] - options.sad_stats = options.sad_stats.split(",") + options.snp_stats = options.snp_stats.split(",") # calculate SAD scores: if options.processes is not None: - calculate_sad(params_file, model_file, vcf_file, worker_index, options) + score_snps(params_file, model_file, vcf_file, worker_index, options) else: - calculate_sad(params_file, model_file, vcf_file, 0, options) + score_snps(params_file, model_file, vcf_file, 0, options) ################################################################################ diff --git a/src/baskerville/scripts/hound_snp_slurm.py b/src/baskerville/scripts/hound_snp_slurm.py index dd4aa9e..e2be5fd 100755 --- a/src/baskerville/scripts/hound_snp_slurm.py +++ b/src/baskerville/scripts/hound_snp_slurm.py @@ -33,6 +33,7 @@ parallelized across a slurm cluster. """ + ################################################################################ # main ################################################################################ @@ -41,6 +42,13 @@ def main(): parser = OptionParser(usage) # snp + parser.add_option( + "-c", + dest="cluster_snps_pct", + default=0, + type="float", + help="Cluster SNPs within a %% of the seq length to make a single ref pred [Default: %default]", + ) parser.add_option( "-f", dest="genome_fasta", @@ -69,8 +77,8 @@ def main(): ) parser.add_option( "--stats", - dest="sad_stats", - default="SAD", + dest="snp_stats", + default="logSAD", help="Comma-separated list of stats to save. [Default: %default]", ) parser.add_option( @@ -80,12 +88,19 @@ def main(): type="str", help="File specifying target indexes and labels in table format", ) + parser.add_option( + "-u", + dest="untransform_old", + default=False, + action="store_true", + help="Untransform old models [Default: %default]", + ) # multi parser.add_option( "-e", dest="conda_env", - default="tf210", + default="tf12", help="Anaconda environment [Default: %default]", ) parser.add_option( diff --git a/src/baskerville/scripts/hound_train.py b/src/baskerville/scripts/hound_train.py index beec2e2..e7ec150 100755 --- a/src/baskerville/scripts/hound_train.py +++ b/src/baskerville/scripts/hound_train.py @@ -163,7 +163,6 @@ def main(): strategy = tf.distribute.MirroredStrategy() with strategy.scope(): - if not args.keras_fit: # distribute data for di in range(len(args.data_dirs)): diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 12592ca..a00c843 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -524,7 +524,7 @@ def predict( stream: bool = False, step: int = 1, dtype: str = "float32", - **kwargs + **kwargs, ): """Predict targets for SeqDataset, with more options. 
diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 0c3d81d..d89a6ab 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -1,4 +1,7 @@ import json +import pdb +import sys + import h5py import numpy as np import pandas as pd @@ -7,24 +10,24 @@ from scipy.special import rel_entr from tqdm import tqdm +from baskerville import dna from baskerville import dataset from baskerville import seqnn from baskerville import vcf as bvcf -def calculate_sad(params_file, model_file, vcf_file, worker_index, options): +def score_snps(params_file, model_file, vcf_file, worker_index, options): """ - write SAD output - :param params_file: DNN model params - :param model_file: DNN model - :param vcf_file: input vcf + Score SNPs in a VCF file with a SeqNN model. + + :param params_file: Model parameters + :param model_file: Saved model weights + :param vcf_file: VCF :param worker_index :param options: options from cmd args :return: """ - ################################################################# - # read parameters and targets ################################################################# # read parameters and targets @@ -32,57 +35,37 @@ def calculate_sad(params_file, model_file, vcf_file, worker_index, options): with open(params_file) as params_open: params = json.load(params_open) params_model = params["model"] - params_train = params["train"] - - if options.targets_file is None: - target_slice = None - sum_strand = False - else: - targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0) - target_slice = targets_df.index - if "strand_pair" in targets_df.columns: - sum_strand = True + # read targets + targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0) - # prep strand - targets_strand_df = dataset.targets_prep_strand(targets_df) + # handle strand pairs + if "strand_pair" in targets_df.columns: + # prep strand + targets_strand_df = dataset.targets_prep_strand(targets_df) - # set strand pairs (using new indexing) - orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) - targets_strand_pair = np.array( - [orig_new_index[ti] for ti in targets_df.strand_pair] - ) - params_model["strand_pair"] = [targets_strand_pair] - - # construct strand sum transform - strand_transform = dok_matrix( - (targets_df.shape[0], targets_strand_df.shape[0]) - ) - ti = 0 - sti = 0 - for _, target in targets_df.iterrows(): - strand_transform[ti, sti] = True - if target.strand_pair == target.name: - sti += 1 - else: - if target.identifier[-1] == "-": - sti += 1 - ti += 1 - strand_transform = strand_transform.tocsr() + # set strand pairs (using new indexing) + orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) + targets_strand_pair = np.array( + [orig_new_index[ti] for ti in targets_df.strand_pair] + ) + params_model["strand_pair"] = [targets_strand_pair] - else: - targets_strand_df = targets_df - sum_strand = False + # construct strand sum transform + strand_transform = make_strand_transform(targets_df, targets_strand_df) + else: + targets_strand_df = targets_df + strand_transform = None ################################################################# # setup model # can we sum on GPU? 
- sum_length = options.sad_stats == "SAD" + sum_length = options.snp_stats == "SAD" seqnn_model = seqnn.SeqNN(params_model) seqnn_model.restore(model_file) - seqnn_model.build_slice(target_slice) + seqnn_model.build_slice(targets_df.index) if sum_length: seqnn_model.build_sad() seqnn_model.build_ensemble(options.rc, options.shifts) @@ -99,22 +82,43 @@ def calculate_sad(params_file, model_file, vcf_file, worker_index, options): ################################################################# # load SNPs + # clustering SNPs requires sorted VCF and no reference flips + snps_clustered = options.cluster_snps_pct > 0 + # filter for worker SNPs - if options.processes is not None: + if options.processes is None: + start_i = None + end_i = None + else: # determine boundaries num_snps = bvcf.vcf_count(vcf_file) worker_bounds = np.linspace(0, num_snps, options.processes + 1, dtype="int") + start_i = worker_bounds[worker_index] + end_i = worker_bounds[worker_index + 1] + + # read SNPs + snps = bvcf.vcf_snps( + vcf_file, + require_sorted=snps_clustered, + flip_ref=not snps_clustered, + validate_ref_fasta=options.genome_fasta, + start_i=start_i, + end_i=end_i, + ) - # read SNPs form VCF - snps = bvcf.vcf_snps( - vcf_file, - start_i=worker_bounds[worker_index], - end_i=worker_bounds[worker_index + 1], + # cluster SNPs + if snps_clustered: + snp_clusters = cluster_snps( + snps, params_model["seq_length"], options.cluster_snps_pct ) else: - # read SNPs form VCF - snps = bvcf.vcf_snps(vcf_file) + snp_clusters = [] + for snp in snps: + snp_clusters.append(SNPCluster()) + snp_clusters[-1].add_snp(snp) + + # delimit sequence boundaries + [sc.delimit(params_model["seq_length"]) for sc in snp_clusters] # open genome FASTA genome_open = pysam.Fastafile(options.genome_fasta) @@ -123,75 +127,128 @@ def calculate_sad(params_file, model_file, vcf_file, worker_index, options): # predict SNP scores, write output # setup output - sad_out = initialize_output_h5( - options.out_dir, options.sad_stats, snps, targets_length, targets_strand_df + scores_out = initialize_output_h5( + options.out_dir, options.snp_stats, snps, targets_length, targets_strand_df ) - for si, snp in tqdm(enumerate(snps), total=len(snps)): - # get SNP sequences - snp_1hot_list = bvcf.snp_seq1(snp, params_model["seq_length"], genome_open) - snps_1hot = np.array(snp_1hot_list) + # SNP index + si = 0 - # get predictions - if params_train["batch_size"] == 1: - ref_preds = seqnn_model(snps_1hot[:1])[0] - alt_preds = seqnn_model(snps_1hot[1:])[0] - else: - snp_preds = seqnn_model(snps_1hot) - ref_preds, alt_preds = snp_preds[0], snp_preds[1] + for sc in tqdm(snp_clusters): + snp_1hot_list = sc.get_1hots(genome_open) + + # predict reference + ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) + ref_preds = seqnn_model(ref_1hot)[0] # untransform predictions if options.targets_file is not None: if options.untransform_old: ref_preds = dataset.untransform_preds1(ref_preds, targets_df) - alt_preds = dataset.untransform_preds1(alt_preds, targets_df) else: ref_preds = dataset.untransform_preds(ref_preds, targets_df) - alt_preds = dataset.untransform_preds(alt_preds, targets_df) # sum strand pairs - if sum_strand: + if strand_transform is not None: ref_preds = ref_preds * strand_transform - alt_preds = alt_preds * strand_transform - # process SNP - if sum_length: - write_snp(ref_preds, alt_preds, sad_out, si, options.sad_stats) - else: - write_snp_len(ref_preds, alt_preds, sad_out, si, options.sad_stats) + for alt_1hot in snp_1hot_list[1:]: + alt_1hot =
np.expand_dims(alt_1hot, axis=0) + + # predict alternate + alt_preds = seqnn_model(alt_1hot)[0] + + # untransform predictions + if options.targets_file is not None: + if options.untransform_old: + alt_preds = dataset.untransform_preds1(alt_preds, targets_df) + else: + alt_preds = dataset.untransform_preds(alt_preds, targets_df) + + # sum strand pairs + if strand_transform is not None: + alt_preds = alt_preds * strand_transform + + # flip reference and alternate + if snps[si].flipped: + rp_snp = alt_preds + ap_snp = ref_preds + else: + rp_snp = ref_preds + ap_snp = alt_preds + + # write SNP + if sum_length: + write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) + else: + write_snp_len(rp_snp, ap_snp, scores_out, si, options.snp_stats) + + # update SNP index + si += 1 # close genome genome_open.close() - ################################################### # compute SAD distributions across variants + write_pct(scores_out, options.snp_stats) + scores_out.close() + + +def cluster_snps(snps, seq_len: int, center_pct: float): + """Cluster a sorted list of SNPs into regions that will satisfy + the required center_pct. + + Args: + snps [SNP]: List of SNPs. + seq_len (int): Sequence length. + center_pct (float): Percent of sequence length to cluster SNPs. + """ + valid_snp_distance = int(seq_len * center_pct) + + snp_clusters = [] + cluster_chr = None + + for snp in snps: + if snp.chr == cluster_chr and snp.pos < cluster_pos0 + valid_snp_distance: + # append to latest cluster + snp_clusters[-1].add_snp(snp) + else: + # initialize new cluster + snp_clusters.append(SNPCluster()) + snp_clusters[-1].add_snp(snp) + cluster_chr = snp.chr + cluster_pos0 = snp.pos - write_pct(sad_out, options.sad_stats) - sad_out.close() + return snp_clusters -def initialize_output_h5(out_dir, sad_stats, snps, targets_length, targets_df): - """Initialize an output HDF5 file for SAD stats.""" +def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df): + """Initialize an output HDF5 file for SAD stats. + + Args: + out_dir (str): Output directory. + snp_stats [str]: List of SAD stats to compute. + snps [SNP]: List of SNPs. + targets_length (int): Targets' sequence length. + targets_df (pd.DataFrame): Targets DataFrame.
+ """ num_targets = targets_df.shape[0] num_snps = len(snps) - sad_out = h5py.File("%s/scores.h5" % out_dir, "w") + scores_out = h5py.File("%s/scores.h5" % out_dir, "w") # write SNPs snp_ids = np.array([snp.rsid for snp in snps], "S") - sad_out.create_dataset("snp", data=snp_ids) + scores_out.create_dataset("snp", data=snp_ids) # write SNP chr snp_chr = np.array([snp.chr for snp in snps], "S") - sad_out.create_dataset("chr", data=snp_chr) + scores_out.create_dataset("chr", data=snp_chr) # write SNP pos snp_pos = np.array([snp.pos for snp in snps], dtype="uint32") - sad_out.create_dataset("pos", data=snp_pos) - - # check flips - snp_flips = [snp.flipped for snp in snps] + scores_out.create_dataset("pos", data=snp_pos) # write SNP reference allele snp_refs = [] @@ -205,30 +262,114 @@ def initialize_output_h5(out_dir, sad_stats, snps, targets_length, targets_df): snp_alts.append(snp.alt_alleles[0]) snp_refs = np.array(snp_refs, "S") snp_alts = np.array(snp_alts, "S") - sad_out.create_dataset("ref_allele", data=snp_refs) - sad_out.create_dataset("alt_allele", data=snp_alts) + scores_out.create_dataset("ref_allele", data=snp_refs) + scores_out.create_dataset("alt_allele", data=snp_alts) # write targets - sad_out.create_dataset("target_ids", data=np.array(targets_df.identifier, "S")) - sad_out.create_dataset("target_labels", data=np.array(targets_df.description, "S")) + scores_out.create_dataset("target_ids", data=np.array(targets_df.identifier, "S")) + scores_out.create_dataset( + "target_labels", data=np.array(targets_df.description, "S") + ) # initialize SAD stats - for sad_stat in sad_stats: - if sad_stat in ["REF", "ALT"]: - sad_out.create_dataset( - sad_stat, shape=(num_snps, targets_length, num_targets), dtype="float16" + for snp_stat in snp_stats: + if snp_stat in ["REF", "ALT"]: + scores_out.create_dataset( + snp_stat, shape=(num_snps, targets_length, num_targets), dtype="float16" ) else: - sad_out.create_dataset( - sad_stat, shape=(num_snps, num_targets), dtype="float16" + scores_out.create_dataset( + snp_stat, shape=(num_snps, num_targets), dtype="float16" ) - return sad_out + return scores_out + + +def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele): + """Return alternative allele one hot coding. + + Args: + ref_1hot (np.array): Reference allele one hot coding. + snp_seq_pos (int): SNP position in sequence. + ref_allele (str): Reference allele. + alt_allele (str): Alternative allele. + + Returns: + np.array: Alternative allele one hot coding. + """ + ref_n = len(ref_allele) + alt_n = len(alt_allele) + + # copy reference + alt_1hot = np.copy(ref_1hot) + + if alt_n == ref_n: + # SNP + dna.hot1_set(alt_1hot, snp_seq_pos, alt_allele) + + elif ref_n > alt_n: + # deletion + delete_len = ref_n - alt_n + if ref_allele[0] == alt_allele[0]: + dna.hot1_delete(alt_1hot, snp_seq_pos + 1, delete_len) + else: + print( + "WARNING: Deletion first nt does not match: %s %s" + % (ref_allele, alt_allele), + file=sys.stderr, + ) + + else: + # insertion + if ref_allele[0] == alt_allele[0]: + dna.hot1_insert(alt_1hot, snp_seq_pos + 1, alt_allele[1:]) + else: + print( + "WARNING: Insertion first nt does not match: %s %s" + % (ref_allele, alt_allele), + file=sys.stderr, + ) -def write_pct(sad_out, sad_stats): - """Compute percentile values for each target and write to HDF5.""" + return alt_1hot + +def make_strand_transform(targets_df, targets_strand_df): + """Make a sparse matrix to sum strand pairs. + + Args: + targets_df (pd.DataFrame): Targets DataFrame.
+ targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed. + + Returns: + scipy.sparse.csr_matrix: Sparse matrix to sum strand pairs. + """ + + # initialize sparse matrix + strand_transform = dok_matrix((targets_df.shape[0], targets_strand_df.shape[0])) + + # fill in matrix + ti = 0 + sti = 0 + for _, target in targets_df.iterrows(): + strand_transform[ti, sti] = True + if target.strand_pair == target.name: + sti += 1 + else: + if target.identifier[-1] == "-": + sti += 1 + ti += 1 + strand_transform = strand_transform.tocsr() + + return strand_transform + + +def write_pct(scores_out, snp_stats): + """Compute percentile values for each target and write to HDF5. + + Args: + scores_out (h5py.File): Output HDF5 file. + snp_stats [str]: List of SAD stats to compute. + """ # define percentiles d_fine = 0.001 d_coarse = 0.01 @@ -237,34 +378,49 @@ def write_pct(sad_out, sad_stats): percentiles_pos = np.arange(0.9, 1, d_fine) percentiles = np.concatenate([percentiles_neg, percentiles_base, percentiles_pos]) - sad_out.create_dataset("percentiles", data=percentiles) - pct_len = len(percentiles) + scores_out.create_dataset("percentiles", data=percentiles) - for sad_stat in sad_stats: - if sad_stat not in ["REF", "ALT"]: - sad_stat_pct = "%s_pct" % sad_stat + for snp_stat in snp_stats: + if snp_stat not in ["REF", "ALT"]: + snp_stat_pct = "%s_pct" % snp_stat # compute - sad_pct = np.percentile(sad_out[sad_stat], 100 * percentiles, axis=0).T + sad_pct = np.percentile(scores_out[snp_stat], 100 * percentiles, axis=0).T sad_pct = sad_pct.astype("float16") # save - sad_out.create_dataset(sad_stat_pct, data=sad_pct, dtype="float16") + scores_out.create_dataset(snp_stat_pct, data=sad_pct, dtype="float16") -def write_snp(ref_preds_sum, alt_preds_sum, sad_out, si, sad_stats): +def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats): """Write SNP predictions to HDF, assuming the length dimension has - been collapsed.""" + been collapsed. + + Args: + ref_preds_sum (np.array): Reference allele predictions. + alt_preds_sum (np.array): Alternative allele predictions. + scores_out (h5py.File): Output HDF5 file. + si (int): SNP index. + snp_stats [str]: List of SAD stats to compute. + """ # compare reference to alternative via mean subtraction - if "SAD" in sad_stats: + if "SAD" in snp_stats: sad = alt_preds_sum - ref_preds_sum - sad_out["SAD"][si, :] = sad.astype("float16") + scores_out["SAD"][si, :] = sad.astype("float16") -def write_snp_len(ref_preds, alt_preds, sad_out, si, sad_stats): +def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): """Write SNP predictions to HDF, assuming the length dimension has - been maintained.""" + been maintained. + + Args: + ref_preds (np.array): Reference allele predictions. + alt_preds (np.array): Alternative allele predictions. + scores_out (h5py.File): Output HDF5 file. + si (int): SNP index. + snp_stats [str]: List of SAD stats to compute. 
+ """ seq_length, num_targets = ref_preds.shape # log/sqrt @@ -290,57 +446,57 @@ def write_snp_len(ref_preds, alt_preds, sad_out, si, sad_stats): altref_sqrt_adiff = np.abs(altref_sqrt_diff) # compare reference to alternative via sum subtraction - if "SAD" in sad_stats: + if "SAD" in snp_stats: sad = alt_preds_sum - ref_preds_sum sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["SAD"][si] = sad.astype("float16") - if "logSAD" in sad_stats: + scores_out["SAD"][si] = sad.astype("float16") + if "logSAD" in snp_stats: log_sad = alt_preds_log_sum - ref_preds_log_sum log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["logSAD"][si] = log_sad.astype("float16") - if "sqrtSAD" in sad_stats: + scores_out["logSAD"][si] = log_sad.astype("float16") + if "sqrtSAD" in snp_stats: sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["sqrtSAD"][si] = sqrt_sad.astype("float16") + scores_out["sqrtSAD"][si] = sqrt_sad.astype("float16") # compare reference to alternative via max subtraction - if "SAX" in sad_stats: + if "SAX" in snp_stats: max_i = np.argmax(altref_adiff, axis=0) sax = altref_diff[max_i, np.arange(num_targets)] - sad_out["SAX"][si] = sax.astype("float16") + scores_out["SAX"][si] = sax.astype("float16") # L1 norm of difference vector - if "D1" in sad_stats: + if "D1" in snp_stats: sad_d1 = altref_adiff.sum(axis=0) sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["D1"][si] = sad_d1.astype("float16") - if "logD1" in sad_stats: + scores_out["D1"][si] = sad_d1.astype("float16") + if "logD1" in snp_stats: log_d1 = altref_log_adiff.sum(axis=0) log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["logD1"][si] = log_d1.astype("float16") - if "sqrtD1" in sad_stats: + scores_out["logD1"][si] = log_d1.astype("float16") + if "sqrtD1" in snp_stats: sqrt_d1 = altref_sqrt_adiff.sum(axis=0) sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["sqrtD1"][si] = sqrt_d1.astype("float16") + scores_out["sqrtD1"][si] = sqrt_d1.astype("float16") # L2 norm of difference vector - if "D2" in sad_stats: + if "D2" in snp_stats: altref_diff2 = np.power(altref_diff, 2) sad_d2 = np.sqrt(altref_diff2.sum(axis=0)) sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["D2"][si] = sad_d2.astype("float16") - if "logD2" in sad_stats: + scores_out["D2"][si] = sad_d2.astype("float16") + if "logD2" in snp_stats: altref_log_diff2 = np.power(altref_log_diff, 2) log_d2 = np.sqrt(altref_log_diff2.sum(axis=0)) log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["logD2"][si] = log_d2.astype("float16") - if "sqrtD2" in sad_stats: + scores_out["logD2"][si] = log_d2.astype("float16") + if "sqrtD2" in snp_stats: altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2) sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=0)) sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_out["sqrtD2"][si] = sqrt_d2.astype("float16") + scores_out["sqrtD2"][si] = sqrt_d2.astype("float16") - if "JS" in sad_stats: + if "JS" in snp_stats: # normalized scores pseudocounts = np.percentile(ref_preds, 25, axis=0) ref_preds_norm = ref_preds + pseudocounts @@ -352,8 +508,8 @@ def write_snp_len(ref_preds, alt_preds, sad_out, si, sad_stats): ref_alt_entr = rel_entr(ref_preds_norm, alt_preds_norm).sum(axis=0) alt_ref_entr = 
rel_entr(alt_preds_norm, ref_preds_norm).sum(axis=0) js_dist = (ref_alt_entr + alt_ref_entr) / 2 - sad_out["JS"][si] = js_dist.astype("float16") - if "logJS" in sad_stats: + scores_out["JS"][si] = js_dist.astype("float16") + if "logJS" in snp_stats: # normalized scores pseudocounts = np.percentile(ref_preds_log, 25, axis=0) ref_preds_log_norm = ref_preds_log + pseudocounts @@ -365,16 +521,83 @@ def write_snp_len(ref_preds, alt_preds, sad_out, si, sad_stats): ref_alt_entr = rel_entr(ref_preds_log_norm, alt_preds_log_norm).sum(axis=0) alt_ref_entr = rel_entr(alt_preds_log_norm, ref_preds_log_norm).sum(axis=0) log_js_dist = (ref_alt_entr + alt_ref_entr) / 2 - sad_out["logJS"][si] = log_js_dist.astype("float16") + scores_out["logJS"][si] = log_js_dist.astype("float16") # predictions - if "REF" in sad_stats: + if "REF" in snp_stats: ref_preds = np.clip( ref_preds, np.finfo(np.float16).min, np.finfo(np.float16).max ) - sad_out["REF"][si] = ref_preds.astype("float16") - if "ALT" in sad_stats: + scores_out["REF"][si] = ref_preds.astype("float16") + if "ALT" in snp_stats: alt_preds = np.clip( alt_preds, np.finfo(np.float16).min, np.finfo(np.float16).max ) - sad_out["ALT"][si] = alt_preds.astype("float16") + scores_out["ALT"][si] = alt_preds.astype("float16") + + +class SNPCluster: + def __init__(self): + self.snps = [] + self.chr = None + self.start = None + self.end = None + + def add_snp(self, snp): + """Add SNP to cluster.""" + self.snps.append(snp) + + def delimit(self, seq_len): + """Delimit sequence boundaries.""" + positions = [snp.pos for snp in self.snps] + pos_min = np.min(positions) + pos_max = np.max(positions) + pos_mid = (pos_min + pos_max) // 2 + + self.chr = self.snps[0].chr + self.start = pos_mid - seq_len // 2 + self.end = self.start + seq_len + + for snp in self.snps: + snp.seq_pos = snp.pos - 1 - self.start + + def get_1hots(self, genome_open): + """Get list of one hot coded sequences.""" + seqs1_list = [] + + # extract reference + if self.start < 0: + ref_seq = ( + "N" * (-self.start) + genome_open.fetch(self.chr, 0, self.end).upper() + ) + else: + ref_seq = genome_open.fetch(self.chr, self.start, self.end).upper() + + # extend to full length + if len(ref_seq) < self.end - self.start: + ref_seq += "N" * (self.end - self.start - len(ref_seq)) + + # verify reference alleles + for snp in self.snps: + ref_n = len(snp.ref_allele) + ref_snp = ref_seq[snp.seq_pos : snp.seq_pos + ref_n] + if snp.ref_allele != ref_snp: + print( + "ERROR: %s does not match reference %s" % (snp, ref_snp), + file=sys.stderr, + ) + exit(1) + + # 1 hot code reference sequence + ref_1hot = dna.dna_1hot(ref_seq) + seqs1_list = [ref_1hot] + + # make alternative 1 hot coded sequences + # (assuming SNP is 1-based indexed) + for snp in self.snps: + alt_1hot = make_alt_1hot( + ref_1hot, snp.seq_pos, snp.ref_allele, snp.alt_alleles[0] + ) + seqs1_list.append(alt_1hot) + + return seqs1_list diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py index d55c636..5c55f52 100644 --- a/src/baskerville/trainer.py +++ b/src/baskerville/trainer.py @@ -783,7 +783,7 @@ def adaptive_clip_grad( ): """Adaptive gradient clipping.""" new_grads = [] - for (params, grads) in zip(parameters, gradients): + for params, grads in zip(parameters, gradients): p_norm = unitwise_norm(params) max_norm = tf.math.maximum(p_norm, eps) * clip_factor grad_norm = unitwise_norm(grads) diff --git a/src/baskerville/vcf.py b/src/baskerville/vcf.py index 95b67db..802b0b5 100644 --- a/src/baskerville/vcf.py +++ b/src/baskerville/vcf.py @@ 
-219,13 +219,11 @@ def snp_seq1(snp, seq_len, genome_open): seq_ref = seq[left_len : left_len + len(snp.ref_allele)] ref_found = True if seq_ref != snp.ref_allele: - # search for reference allele in alternatives ref_found = False # for each alternative allele for alt_al in snp.alt_alleles: - # grab reference sequence matching alt length seq_ref_alt = seq[left_len : left_len + len(alt_al)] if seq_ref_alt == alt_al: @@ -314,13 +312,11 @@ def snps_seq1(snps, seq_len, genome_fasta, return_seqs=False): # verify that ref allele matches ref sequence seq_ref = seq[left_len : left_len + len(snp.ref_allele)] if seq_ref != snp.ref_allele: - # search for reference allele in alternatives ref_found = False # for each alternative allele for alt_al in snp.alt_alleles: - # grab reference sequence matching alt length seq_ref_alt = seq[left_len : left_len + len(alt_al)] if seq_ref_alt == alt_al: @@ -604,6 +600,7 @@ def vcf_snps( ref_n = len(snps[-1].ref_allele) snp_pos = snps[-1].pos - 1 ref_snp = genome_open.fetch(snps[-1].chr, snp_pos, snp_pos + ref_n) + if snps[-1].ref_allele != ref_snp: if not flip_ref: # bail
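Usage note (a minimal sketch, not part of the patch): the renamed score_snps() entry point introduced above can be driven directly, with the new -c clustering behavior enabled via cluster_snps_pct. The attribute names below are those read by score_snps() in this diff (cluster_snps_pct, genome_fasta, out_dir, processes, rc, shifts, snp_stats, targets_file, untransform_old); the file paths and option values are hypothetical placeholders.

    # Hypothetical driver for score_snps(); paths and values are placeholders.
    from optparse import Values

    from baskerville.snps import score_snps

    options = Values()
    options.cluster_snps_pct = 0.2  # new -c option: cluster SNPs within 20% of seq length
    options.genome_fasta = "hg38.fa"  # -f: genome FASTA for sequences
    options.out_dir = "scores_out"  # output directory (flag assumed from options.out_dir)
    options.processes = None  # single process, so worker_index 0 below
    options.rc = False  # ensemble reverse-complement predictions (flag assumed)
    options.shifts = [0]  # ensemble shifts, already split to ints as in hound_snp.py
    options.snp_stats = ["logSAD"]  # renamed from sad_stats; default is now logSAD
    options.targets_file = "targets.txt"  # -t: now required by hound_snp.py
    options.untransform_old = False  # new -u flag for old models

    score_snps("params.json", "model.h5", "snps.vcf", 0, options)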