diff --git a/src/baskerville/dataset.py b/src/baskerville/dataset.py index 1498556..f061b97 100644 --- a/src/baskerville/dataset.py +++ b/src/baskerville/dataset.py @@ -20,6 +20,7 @@ from natsort import natsorted import numpy as np import pandas as pd +from scipy.sparse import dok_matrix import tensorflow as tf gpu_devices = tf.config.experimental.list_physical_devices("GPU") @@ -310,6 +311,36 @@ def numpy( return targets +def make_strand_transform(targets_df, targets_strand_df): + """Make a sparse matrix to sum strand pairs. + + Args: + targets_df (pd.DataFrame): Targets DataFrame. + targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed. + + Returns: + scipy.sparse.csr_matrix: Sparse matrix to sum strand pairs. + """ + + # initialize sparse matrix + strand_transform = dok_matrix((targets_df.shape[0], targets_strand_df.shape[0])) + + # fill in matrix + ti = 0 + sti = 0 + for _, target in targets_df.iterrows(): + strand_transform[ti, sti] = True + if target.strand_pair == target.name: + sti += 1 + else: + if target.identifier[-1] == "-": + sti += 1 + ti += 1 + strand_transform = strand_transform.tocsr() + + return strand_transform + + def targets_prep_strand(targets_df): """Adjust targets table for merged stranded datasets. diff --git a/src/baskerville/scripts/hound_ism_bed.py b/src/baskerville/scripts/hound_ism_bed.py new file mode 100755 index 0000000..faf62b4 --- /dev/null +++ b/src/baskerville/scripts/hound_ism_bed.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser + +import gc +import json +import os +import pickle +from queue import Queue +import sys +from threading import Thread + +import h5py +import numpy as np +import pandas as pd +import tensorflow as tf + + +from baskerville import bed +from baskerville import dataset +from baskerville import dna +from baskerville import seqnn +from baskerville import snps + +""" +hound_ism_bed.py + +Perform an in silico saturation mutagenesis of sequences in a BED file. +""" + + +################################################################################ +# main +################################################################################ +def main(): + usage = "usage: %prog [options] <params_file> <model_file> <bed_file>" + parser = OptionParser(usage) + parser.add_option( + "-d", + dest="mut_down", + default=0, + type="int", + help="Nucleotides downstream of center sequence to mutate [Default: %default]", + ) + parser.add_option( + "-f", + dest="genome_fasta", + default=None, + help="Genome FASTA for sequences [Default: %default]", + ) + parser.add_option( + "-l", + dest="mut_len", + default=0, + type="int", + help="Length of center sequence to mutate [Default: %default]", + ) + parser.add_option( + "-o", + dest="out_dir", + default="sat_mut", + help="Output directory [Default: %default]", + ) + parser.add_option( + "-p", + dest="processes", + default=None, + type="int", + help="Number of processes, passed by multi script", + ) + parser.add_option( + "--rc", + dest="rc", + default=False, + action="store_true", + help="Ensemble forward and reverse complement predictions [Default: %default]", + ) + parser.add_option( + "--shifts", + dest="shifts", + default="0", + help="Ensemble prediction shifts [Default: %default]", + ) + parser.add_option( + "--stats", + dest="snp_stats", + default="logSUM", + 
help="Comma-separated list of stats to save. [Default: %default]", + ) + parser.add_option( + "-t", + dest="targets_file", + default=None, + type="str", + help="File specifying target indexes and labels in table format", + ) + parser.add_option( + "-u", + dest="mut_up", + default=0, + type="int", + help="Nucleotides upstream of center sequence to mutate [Default: %default]", + ) + parser.add_option( + "--untransform_old", + dest="untransform_old", + default=False, + action="store_true", + help="Untransform old models [Default: %default]", + ) + (options, args) = parser.parse_args() + + if len(args) == 3: + # single worker + params_file = args[0] + model_file = args[1] + bed_file = args[2] + else: + parser.error("Must provide parameter and model files and BED file") + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(",")] + options.snp_stats = [snp_stat for snp_stat in options.snp_stats.split(",")] + + if options.mut_up > 0 or options.mut_down > 0: + options.mut_len = options.mut_up + options.mut_down + else: + assert options.mut_len > 0 + options.mut_up = options.mut_len // 2 + options.mut_down = options.mut_len - options.mut_up + + ################################################################# + # read parameters and targets + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params["model"] + + # read targets + if options.targets_file is None: + parser.error("Must provide targets file to clarify stranded datasets") + targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0) + + # handle strand pairs + if "strand_pair" in targets_df.columns: + # prep strand + targets_strand_df = dataset.targets_prep_strand(targets_df) + + # set strand pairs (using new indexing) + orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) + targets_strand_pair = np.array( + [orig_new_index[ti] for 
ti in targets_df.strand_pair] + ) + params_model["strand_pair"] = [targets_strand_pair] + + # construct strand sum transform + strand_transform = dataset.make_strand_transform(targets_df, targets_strand_df) + else: + targets_strand_df = targets_df + strand_transform = None + num_targets = targets_strand_df.shape[0] + + ################################################################# + # setup model + + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_file) + seqnn_model.build_slice(targets_df.index) + seqnn_model.build_ensemble(options.rc) + + ################################################################# + # sequence dataset + + # read sequences from BED + seqs_dna, seqs_coords = bed.make_bed_seqs( + bed_file, options.genome_fasta, params_model["seq_length"], stranded=True + ) + num_seqs = len(seqs_dna) + + # determine mutation region limits + seq_mid = params_model["seq_length"] // 2 + mut_start = seq_mid - options.mut_up + mut_end = mut_start + options.mut_len + + ################################################################# + # setup output + + scores_h5_file = "%s/scores.h5" % options.out_dir + if os.path.isfile(scores_h5_file): + os.remove(scores_h5_file) + scores_h5 = h5py.File(scores_h5_file, "w") + scores_h5.create_dataset("seqs", dtype="bool", shape=(num_seqs, options.mut_len, 4)) + for snp_stat in options.snp_stats: + scores_h5.create_dataset( + snp_stat, dtype="float16", shape=(num_seqs, options.mut_len, 4, num_targets) + ) + + # store mutagenesis sequence coordinates + scores_chr = [] + scores_start = [] + scores_end = [] + scores_strand = [] + for seq_chr, seq_start, seq_end, seq_strand in seqs_coords: + scores_chr.append(seq_chr) + scores_strand.append(seq_strand) + if seq_strand == "+": + score_start = seq_start + mut_start + score_end = score_start + options.mut_len + else: + score_end = seq_end - mut_start + score_start = score_end - options.mut_len + scores_start.append(score_start) + scores_end.append(score_end) + + 
scores_h5.create_dataset("chr", data=np.array(scores_chr, dtype="S")) + scores_h5.create_dataset("start", data=np.array(scores_start)) + scores_h5.create_dataset("end", data=np.array(scores_end)) + scores_h5.create_dataset("strand", data=np.array(scores_strand, dtype="S")) + + ################################################################# + # predict scores, write output + + for si, seq_dna in enumerate(seqs_dna): + print("Predicting %d" % si, flush=True) + + # 1 hot code DNA + ref_1hot = dna.dna_1hot(seq_dna) + ref_1hot = np.expand_dims(ref_1hot, axis=0) + + # save sequence + scores_h5["seqs"][si] = ref_1hot[0, mut_start:mut_end].astype("bool") + + # predict reference + ref_preds = [] + for shift in options.shifts: + # shift sequence and predict + ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) + ref_preds_shift = seqnn_model.predict_transform( + ref_1hot_shift, + targets_df, + strand_transform, + options.untransform_old, + ) + ref_preds.append(ref_preds_shift) + ref_preds = np.array(ref_preds) + + # for mutation positions + for mi in range(mut_start, mut_end): + # for each nucleotide + for ni in range(4): + # if non-reference + if ref_1hot[0, mi, ni] == 0: + # copy and modify + alt_1hot = np.copy(ref_1hot) + alt_1hot[0, mi, :] = 0 + alt_1hot[0, mi, ni] = 1 + + # predict alternate + alt_preds = [] + for shift in options.shifts: + # shift sequence and predict + alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) + alt_preds_shift = seqnn_model.predict_transform( + alt_1hot_shift, + targets_df, + strand_transform, + options.untransform_old, + ) + alt_preds.append(alt_preds_shift) + alt_preds = np.array(alt_preds) + + ism_scores = snps.compute_scores( + ref_preds, alt_preds, options.snp_stats + ) + for snp_stat in options.snp_stats: + scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ + snp_stat + ] + + # close output HDF5 + scores_h5.close() + + +################################################################################ +# __main__ 
+################################################################################ +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_ism_snp.py b/src/baskerville/scripts/hound_ism_snp.py new file mode 100755 index 0000000..52332d5 --- /dev/null +++ b/src/baskerville/scripts/hound_ism_snp.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser +import json +import os +import pdb +import time + +import h5py +import numpy as np +import pandas as pd + +from baskerville import dataset +from baskerville import dna +from baskerville import seqnn +from baskerville import snps +from baskerville import vcf + +""" +hound_ism_snp.py + +Perform an in silico saturated mutagenesis of the sequences surrounding variants +given in a VCF file. 
+""" + + +################################################################################ +# main +################################################################################ +def main(): + usage = "usage: %prog [options] <params_file> <model_file> <vcf_file>" + parser = OptionParser(usage) + parser.add_option( + "-d", + dest="mut_down", + default=0, + type="int", + help="Nucleotides downstream of center sequence to mutate [Default: %default]", + ) + parser.add_option( + "-f", + dest="genome_fasta", + default=None, + help="Genome FASTA [Default: %default]", + ) + parser.add_option( + "-l", + dest="mut_len", + default=200, + type="int", + help="Length of centered sequence to mutate [Default: %default]", + ) + parser.add_option( + "-o", + dest="out_dir", + default="ism_snp_out", + help="Output directory [Default: %default]", + ) + parser.add_option( + "--rc", + dest="rc", + default=False, + action="store_true", + help="Ensemble forward and reverse complement predictions [Default: %default]", + ) + parser.add_option( + "--shifts", + dest="shifts", + default="0", + help="Ensemble prediction shifts [Default: %default]", + ) + parser.add_option( + "--stats", + dest="snp_stats", + default="logSUM", + help="Comma-separated list of stats to save. 
[Default: %default]", + ) + parser.add_option( + "-t", + dest="targets_file", + default=None, + type="str", + help="File specifying target indexes and labels in table format", + ) + parser.add_option( + "-u", + dest="mut_up", + default=0, + type="int", + help="Nucleotides upstream of center sequence to mutate [Default: %default]", + ) + parser.add_option( + "--untransform_old", + dest="untransform_old", + default=False, + action="store_true", + help="Untransform old models [Default: %default]", + ) + (options, args) = parser.parse_args() + + if len(args) != 3: + parser.error("Must provide parameters and model files and VCF") + else: + params_file = args[0] + model_file = args[1] + vcf_file = args[2] + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(",")] + options.snp_stats = [snp_stat for snp_stat in options.snp_stats.split(",")] + + if options.mut_up > 0 or options.mut_down > 0: + options.mut_len = options.mut_up + options.mut_down + else: + assert options.mut_len > 0 + options.mut_up = options.mut_len // 2 + options.mut_down = options.mut_len - options.mut_up + + ################################################################# + # read parameters and targets + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params["model"] + + # read targets + if options.targets_file is None: + parser.error("Must provide targets file to clarify stranded datasets") + targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0) + + # handle strand pairs + if "strand_pair" in targets_df.columns: + # prep strand + targets_strand_df = dataset.targets_prep_strand(targets_df) + + # set strand pairs (using new indexing) + orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) + targets_strand_pair = np.array( + [orig_new_index[ti] for ti in targets_df.strand_pair] + ) + params_model["strand_pair"] = 
[targets_strand_pair] + + # construct strand sum transform + strand_transform = dataset.make_strand_transform(targets_df, targets_strand_df) + else: + targets_strand_df = targets_df + strand_transform = None + num_targets = targets_strand_df.shape[0] + + ################################################################# + # setup model + + seqnn_model = seqnn.SeqNN(params_model) + seqnn_model.restore(model_file) + seqnn_model.build_slice(targets_df.index) + seqnn_model.build_ensemble(options.rc) + + ################################################################# + # SNP sequence dataset + + # load SNPs + variants = vcf.vcf_snps(vcf_file) + + # get one hot coded input sequences + seqs_1hot, seq_headers, variants, seqs_dna = vcf.snps_seq1( + variants, params_model["seq_length"], options.genome_fasta, return_seqs=True + ) + num_seqs = seqs_1hot.shape[0] + + # determine mutation region limits + seq_mid = params_model["seq_length"] // 2 + mut_start = seq_mid - options.mut_up + mut_end = mut_start + options.mut_len + + ################################################################# + # setup output + + scores_h5_file = "%s/scores.h5" % options.out_dir + if os.path.isfile(scores_h5_file): + os.remove(scores_h5_file) + scores_h5 = h5py.File(scores_h5_file, "w") + scores_h5.create_dataset("label", data=np.array(seq_headers, dtype="S")) + scores_h5.create_dataset("seqs", dtype="bool", shape=(num_seqs, options.mut_len, 4)) + for snp_stat in options.snp_stats: + scores_h5.create_dataset( + snp_stat, dtype="float16", shape=(num_seqs, options.mut_len, 4, num_targets) + ) + + ################################################################# + # predict scores and write output + + for si in range(seqs_1hot.shape[0]): + print("Predicting %d" % si, flush=True) + + # 1-hot encode reference + ref_1hot = np.expand_dims(seqs_1hot[si], axis=0) + + # save sequence + scores_h5["seqs"][si] = ref_1hot[0, mut_start:mut_end].astype("bool") + + # predict reference + ref_preds = [] + for 
shift in options.shifts: + # shift sequence and predict + ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) + ref_preds_shift = seqnn_model.predict_transform( + ref_1hot_shift, + targets_df, + strand_transform, + options.untransform_old, + ) + ref_preds.append(ref_preds_shift) + ref_preds = np.array(ref_preds) + + # for mutation positions + for mi in range(mut_start, mut_end): + # for each nucleotide + for ni in range(4): + # if non-reference + if ref_1hot[0, mi, ni] == 0: + # copy and modify + alt_1hot = np.copy(ref_1hot) + alt_1hot[0, mi, :] = 0 + alt_1hot[0, mi, ni] = 1 + + # predict alternate + alt_preds = [] + for shift in options.shifts: + # shift sequence and predict + alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) + alt_preds_shift = seqnn_model.predict_transform( + alt_1hot_shift, + targets_df, + strand_transform, + options.untransform_old, + ) + alt_preds.append(alt_preds_shift) + alt_preds = np.array(alt_preds) + + ism_scores = snps.compute_scores( + ref_preds, alt_preds, options.snp_stats + ) + for snp_stat in options.snp_stats: + scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ + snp_stat + ] + + # close output HDF5 + scores_h5.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index cf31a5d..f29bf75 100755 --- a/src/baskerville/scripts/hound_snp.py +++ b/src/baskerville/scripts/hound_snp.py @@ -15,7 +15,6 @@ # ========================================================================= from optparse import OptionParser import pdb -import pickle import os from baskerville.snps import score_snps import tempfile @@ -51,7 +50,7 @@ def main(): "-f", dest="genome_fasta", default=None, - help="Genome FASTA for sequences [Default: %default]", + help="Genome FASTA [Default: 
%default]", ) parser.add_option( "-o", @@ -173,6 +172,7 @@ def main(): options.snp_stats = options.snp_stats.split(",") if options.targets_file is None: parser.error("Must provide targets file") + ################################################################# # check if the program is run on GPU, else quit physical_devices = tf.config.list_physical_devices() @@ -185,6 +185,7 @@ def main(): print("Running on CPU") if options.require_gpu: raise SystemExit("Job terminated because it's running on CPU") + ################################################################# # download input files from gcs to a local file if options.gcs: @@ -199,6 +200,7 @@ def main(): options.targets_file = download_rename_inputs( options.targets_file, temp_dir ) + ################################################################# # calculate SAD scores: if options.processes is not None: diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 716bb7b..48aa300 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -13,14 +13,16 @@ # limitations under the License. # ========================================================================= import pdb +import gc import sys import time from natsort import natsorted import numpy as np import tensorflow as tf -import gc + from baskerville import blocks +from baskerville import dataset from baskerville import layers from baskerville import metrics @@ -967,6 +969,36 @@ def predict( return preds + def predict_transform( + self, + seq_1hot: np.array, + targets_df, + strand_transform: np.array = None, + untransform_old: bool = False, + ): + """Predict a single sequence and transform. + + Args: + seq_1hot (np.array): 1-hot encoded sequence. + targets_df (pd.DataFrame): Targets dataframe. + strand_transform (np.array): Strand merging transform. + untransform_old (bool): Apply old untransform. 
+ """ + # predict + preds = self(seq_1hot)[0] + + # untransform predictions + if untransform_old: + preds = dataset.untransform_preds1(preds, targets_df) + else: + preds = dataset.untransform_preds(preds, targets_df) + + # sum strand pairs + if strand_transform is not None: + preds = preds * strand_transform + + return preds + def restore(self, model_file, head_i=0, trunk=False): """Restore weights from saved model.""" if trunk: diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 2603863..6ecc4b4 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import pysam -from scipy.sparse import dok_matrix from scipy.special import rel_entr from tqdm import tqdm @@ -37,6 +36,9 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): params_model = params["model"] # read targets + if options.targets_file is None: + print("Must provide targets file to clarify stranded datasets", file=sys.stderr) + exit(1) targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0) # handle strand pairs @@ -52,7 +54,7 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): params_model["strand_pair"] = [targets_strand_pair] # construct strand sum transform - strand_transform = make_strand_transform(targets_df, targets_strand_df) + strand_transform = dataset.make_strand_transform(targets_df, targets_strand_df) else: targets_strand_df = targets_df strand_transform = None @@ -72,15 +74,7 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): # shift outside seqnn num_shifts = len(options.shifts) - targets_length = seqnn_model.target_lengths[0] - num_targets = seqnn_model.num_targets() - if options.targets_file is None: - target_ids = ["t%d" % ti for ti in range(num_targets)] - target_labels = [""] * len(target_ids) - targets_strand_df = pd.DataFrame( - {"identifier": target_ids, "description": target_labels} - ) 
################################################################# # load SNPs @@ -225,7 +219,10 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): if sum_length: write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) else: - write_snp_len(rp_snp, ap_snp, scores_out, si, options.snp_stats) + # write_snp_len(rp_snp, ap_snp, scores_out, si, options.snp_stats) + scores = compute_scores(rp_snp, ap_snp, options.snp_stats) + for snp_stat in options.snp_stats: + scores_out[snp_stat][si] = scores[snp_stat] # update SNP index si += 1 @@ -266,6 +263,167 @@ def cluster_snps(snps, seq_len: int, center_pct: float): return snp_clusters +def compute_scores(ref_preds, alt_preds, snp_stats): + """Compute SNP scores from reference and alternative predictions. + + Args: + ref_preds (np.array): Reference allele predictions. + alt_preds (np.array): Alternative allele predictions. + snp_stats [str]: List of SAD stats to compute. + """ + num_shifts, seq_length, num_targets = ref_preds.shape + + # log/sqrt + ref_preds_log = np.log2(ref_preds + 1) + alt_preds_log = np.log2(alt_preds + 1) + ref_preds_sqrt = np.sqrt(ref_preds) + alt_preds_sqrt = np.sqrt(alt_preds) + + # sum across length + ref_preds_sum = ref_preds.sum(axis=(0, 1)) + alt_preds_sum = alt_preds.sum(axis=(0, 1)) + ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1)) + alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1)) + ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1)) + alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1)) + + # difference + altref_diff = alt_preds - ref_preds + altref_adiff = np.abs(altref_diff) + altref_log_diff = alt_preds_log - ref_preds_log + altref_log_adiff = np.abs(altref_log_diff) + altref_sqrt_diff = alt_preds_sqrt - ref_preds_sqrt + altref_sqrt_adiff = np.abs(altref_sqrt_diff) + + # initialize scores dict + scores = {} + + # compare reference to alternative via sum subtraction + if "SUM" in snp_stats: + sad = alt_preds_sum - ref_preds_sum + sad = np.clip(sad, 
np.finfo(np.float16).min, np.finfo(np.float16).max) + scores["SUM"] = sad.astype("float16") + if "logSUM" in snp_stats: + log_sad = alt_preds_log_sum - ref_preds_log_sum + log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) + scores["logSUM"] = log_sad.astype("float16") + if "sqrtSUM" in snp_stats: + sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum + sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) + scores["sqrtSUM"] = sqrt_sad.astype("float16") + + # TEMP during name change + if "SAD" in snp_stats: + sad = alt_preds_sum - ref_preds_sum + sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) + scores["SAD"] = sad.astype("float16") + if "logSAD" in snp_stats: + log_sad = alt_preds_log_sum - ref_preds_log_sum + log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) + scores["logSAD"] = log_sad.astype("float16") + if "sqrtSAD" in snp_stats: + sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum + sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) + scores["sqrtSAD"] = sqrt_sad.astype("float16") + + # compare reference to alternative via max subtraction + if "SAX" in snp_stats: + sax = [] + for s in range(num_shifts): + max_i = np.argmax(altref_adiff[s], axis=0) + sax.append(altref_diff[s, max_i, np.arange(num_targets)]) + sax = np.array(sax).mean(axis=0) + scores["SAX"] = sax.astype("float16") + + # L1 norm of difference vector + if "D1" in snp_stats: + sad_d1 = altref_adiff.sum(axis=1) + sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) + sad_d1 = sad_d1.mean(axis=0) + scores["D1"] = sad_d1.astype("float16") + if "logD1" in snp_stats: + log_d1 = altref_log_adiff.sum(axis=1) + log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) + log_d1 = log_d1.mean(axis=0) + scores["logD1"] = log_d1.astype("float16") + if "sqrtD1" in snp_stats: + sqrt_d1 = altref_sqrt_adiff.sum(axis=1) + sqrt_d1 
= np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) + sqrt_d1 = sqrt_d1.mean(axis=0) + scores["sqrtD1"] = sqrt_d1.astype("float16") + + # L2 norm of difference vector + if "D2" in snp_stats: + altref_diff2 = np.power(altref_diff, 2) + sad_d2 = np.sqrt(altref_diff2.sum(axis=1)) + sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) + sad_d2 = sad_d2.mean(axis=0) + scores["D2"] = sad_d2.astype("float16") + if "logD2" in snp_stats: + altref_log_diff2 = np.power(altref_log_diff, 2) + log_d2 = np.sqrt(altref_log_diff2.sum(axis=1)) + log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) + log_d2 = log_d2.mean(axis=0) + scores["logD2"] = log_d2.astype("float16") + if "sqrtD2" in snp_stats: + altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2) + sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1)) + sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) + sqrt_d2 = sqrt_d2.mean(axis=0) + scores["sqrtD2"] = sqrt_d2.astype("float16") + + if "JS" in snp_stats: + # normalized scores + pseudocounts = np.percentile(ref_preds, 25, axis=1) + ref_preds_norm = ref_preds + pseudocounts + ref_preds_norm /= ref_preds_norm.sum(axis=1) + alt_preds_norm = alt_preds + pseudocounts + alt_preds_norm /= alt_preds_norm.sum(axis=1) + + # compare normalized JS + js_dist = [] + for s in range(num_shifts): + ref_alt_entr = rel_entr(ref_preds_norm[s], alt_preds_norm[s]).sum(axis=0) + alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0) + js_dist.append((ref_alt_entr + alt_ref_entr) / 2) + js_dist = np.mean(js_dist, axis=0) + scores["JS"] = js_dist.astype("float16") + if "logJS" in snp_stats: + # normalized scores + pseudocounts = np.percentile(ref_preds_log, 25, axis=0) + ref_preds_log_norm = ref_preds_log + pseudocounts + ref_preds_log_norm /= ref_preds_log_norm.sum(axis=0) + alt_preds_log_norm = alt_preds_log + pseudocounts + alt_preds_log_norm /= alt_preds_log_norm.sum(axis=0) + + # compare 
normalized JS + log_js_dist = [] + for s in range(num_shifts): + ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum( + axis=0 + ) + alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum( + axis=0 + ) + log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2) + log_js_dist = np.mean(log_js_dist, axis=0) + scores["logJS"] = log_js_dist.astype("float16") + + # predictions + if "REF" in snp_stats: + ref_preds = np.clip( + ref_preds, np.finfo(np.float16).min, np.finfo(np.float16).max + ) + scores["REF"] = ref_preds.astype("float16") + if "ALT" in snp_stats: + alt_preds = np.clip( + alt_preds, np.finfo(np.float16).min, np.finfo(np.float16).max + ) + scores["ALT"] = alt_preds.astype("float16") + + return scores + + def initialize_output_h5( out_dir, snp_stats, snps, targets_length, targets_df, num_shifts ): @@ -382,36 +540,6 @@ def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele): return alt_1hot -def make_strand_transform(targets_df, targets_strand_df): - """Make a sparse matrix to sum strand pairs. - - Args: - targets_df (pd.DataFrame): Targets DataFrame. - targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed. - - Returns: - scipy.sparse.csr_matrix: Sparse matrix to sum strand pairs. - """ - - # initialize sparse matrix - strand_transform = dok_matrix((targets_df.shape[0], targets_strand_df.shape[0])) - - # fill in matrix - ti = 0 - sti = 0 - for _, target in targets_df.iterrows(): - strand_transform[ti, sti] = True - if target.strand_pair == target.name: - sti += 1 - else: - if target.identifier[-1] == "-": - sti += 1 - ti += 1 - strand_transform = strand_transform.tocsr() - - return strand_transform - - def write_pct(scores_out, snp_stats): """Compute percentile values for each target and write to HDF5. 
diff --git a/src/baskerville/vcf.py b/src/baskerville/vcf.py index d455257..66e7b3e 100644 --- a/src/baskerville/vcf.py +++ b/src/baskerville/vcf.py @@ -212,8 +212,8 @@ def snp_seq1(snp, seq_len, genome_open): seq = genome_open.fetch(snp.chr, seq_start - 1, seq_end).upper() # extend to full length - if len(seq) < seq_end - seq_start: - seq += "N" * (seq_end - seq_start - len(seq)) + if len(seq) < seq_len: + seq += "N" * (seq_len - len(seq)) # verify that ref allele matches ref sequence seq_ref = seq[left_len : left_len + len(snp.ref_allele)] @@ -301,16 +301,17 @@ def snps_seq1(snps, seq_len, genome_fasta, return_seqs=False): # extract sequence as BED style if seq_start < 0: - seq = "N" * (-seq_start) + genome_open.fetch(snp.chr, 0, seq_end).upper() + seq = "N" * (1 - seq_start) + genome_open.fetch(snp.chr, 0, seq_end).upper() else: seq = genome_open.fetch(snp.chr, seq_start - 1, seq_end).upper() # extend to full length - if len(seq) < seq_end - seq_start: - seq += "N" * (seq_end - seq_start - len(seq)) + if len(seq) < seq_len: + seq += "N" * (seq_len - len(seq)) # verify that ref allele matches ref sequence seq_ref = seq[left_len : left_len + len(snp.ref_allele)] + ref_found = True if seq_ref != snp.ref_allele: # search for reference allele in alternatives ref_found = False @@ -336,13 +337,13 @@ def snps_seq1(snps, seq_len, genome_fasta, return_seqs=False): ) break - if not ref_found: - print( - "WARNING: %s - reference genome %s does not match any allele; skipping" - % (seq_ref, snp.rsid), - file=sys.stderr, - ) - continue + if not ref_found: + print( + "WARNING: %s - reference genome %s does not match any allele; skipping" + % (seq_ref, snp.rsid), + file=sys.stderr, + ) + break seq_snps.append(snp) diff --git a/tests/data/.gitignore b/tests/data/.gitignore index c79104e..26c4af7 100644 --- a/tests/data/.gitignore +++ b/tests/data/.gitignore @@ -1,2 +1,2 @@ -eval/eval_out train?/ +*/*_out/ \ No newline at end of file diff --git a/tests/data/hg38_1m.fa.gz 
b/tests/data/hg38_1m.fa.gz new file mode 100644 index 0000000..f2b9c20 Binary files /dev/null and b/tests/data/hg38_1m.fa.gz differ diff --git a/tests/data/hg38_1m.fa.gz.fai b/tests/data/hg38_1m.fa.gz.fai new file mode 100644 index 0000000..935da38 --- /dev/null +++ b/tests/data/hg38_1m.fa.gz.fai @@ -0,0 +1 @@ +chr1 1000000 6 50 51 diff --git a/tests/data/hg38_1m.fa.gz.gzi b/tests/data/hg38_1m.fa.gz.gzi new file mode 100644 index 0000000..1a8a1d4 Binary files /dev/null and b/tests/data/hg38_1m.fa.gz.gzi differ diff --git a/tests/data/ism/seqs.bed b/tests/data/ism/seqs.bed new file mode 100644 index 0000000..f78e9c8 --- /dev/null +++ b/tests/data/ism/seqs.bed @@ -0,0 +1,2 @@ +chr1 487600 487601 +chr1 959300 959301 diff --git a/tests/data/ism/snp.vcf b/tests/data/ism/snp.vcf new file mode 100644 index 0000000..983b7b1 --- /dev/null +++ b/tests/data/ism/snp.vcf @@ -0,0 +1,2 @@ +##fileformat=VCFv4.2 +chr1 14677 chr1_14677_G_A_b38 G A . . diff --git a/tests/data/snp/eqtl.vcf b/tests/data/snp/eqtl.vcf new file mode 100644 index 0000000..47c66c5 --- /dev/null +++ b/tests/data/snp/eqtl.vcf @@ -0,0 +1,9 @@ +##fileformat=VCFv4.2 +chr1 14677 chr1_14677_G_A_b38 G A . . +chr1 54490 chr1_54490_G_A_b38 G A . . +chr1 63671 chr1_63671_G_A_b38 G A . . +chr1 63697 chr1_63697_T_C_b38 T C . . +chr1 64764 chr1_64764_C_T_b38 C T . . +chr1 108230 chr1_108230_C_T_b38 C T . . +chr1 108826 chr1_108826_G_C_b38 G C . . +chr1 135203 chr1_135203_G_A_b38 G A . . diff --git a/tests/data/snp/eqtl_flip.vcf b/tests/data/snp/eqtl_flip.vcf new file mode 100644 index 0000000..e3a5763 --- /dev/null +++ b/tests/data/snp/eqtl_flip.vcf @@ -0,0 +1,9 @@ +##fileformat=VCFv4.2 +chr1 14677 chr1_14677_G_A_b38 A G . . +chr1 54490 chr1_54490_G_A_b38 A G . . +chr1 63671 chr1_63671_G_A_b38 A G . . +chr1 63697 chr1_63697_T_C_b38 C T . . +chr1 64764 chr1_64764_C_T_b38 T C . . +chr1 108230 chr1_108230_C_T_b38 T C . . +chr1 108826 chr1_108826_G_C_b38 C G . . +chr1 135203 chr1_135203_G_A_b38 A G . . 
diff --git a/tests/data/tiny/hg38/targets_rna.txt b/tests/data/tiny/hg38/targets_rna.txt new file mode 100644 index 0000000..cbbb073 --- /dev/null +++ b/tests/data/tiny/hg38/targets_rna.txt @@ -0,0 +1,41 @@ + identifier file clip clip_soft scale sum_stat strand_pair description +34 ENCFF520NFI+ /home/drk/tillage/datasets/human/rna/encode/ENCSR000AAR/summary/coverage+.w5 384 64 0.3 sum_sqrt 35 RNA:tracheal epithelial cell male adult (21 years) and male adult (68 years) +35 ENCFF520NFI- /home/drk/tillage/datasets/human/rna/encode/ENCSR000AAR/summary/coverage-.w5 384 64 0.3 sum_sqrt 34 RNA:tracheal epithelial cell male adult (21 years) and male adult (68 years) +36 ENCFF892OBT+ /home/drk/tillage/datasets/human/rna/encode/ENCSR000AFC/summary/coverage+.w5 384 64 0.3 sum_sqrt 37 RNA:lung tissue female embryo (20 weeks) and female embryo (24 weeks) +37 ENCFF892OBT- /home/drk/tillage/datasets/human/rna/encode/ENCSR000AFC/summary/coverage-.w5 384 64 0.3 sum_sqrt 36 RNA:lung tissue female embryo (20 weeks) and female embryo (24 weeks) +38 ENCFF782HFV+ /home/drk/tillage/datasets/human/rna/encode/ENCSR151NGC/summary/coverage+.w5 384 64 0.3 sum_sqrt 39 RNA:GM12878 +39 ENCFF782HFV- /home/drk/tillage/datasets/human/rna/encode/ENCSR151NGC/summary/coverage-.w5 384 64 0.3 sum_sqrt 38 RNA:GM12878 +40 ENCFF946ZPT /home/drk/tillage/datasets/human/rna/encode/ENCSR264LON/summary/coverage.w5 384 64 0.3 sum_sqrt 40 RNA:with multiple sclerosis; CD14-positive monocyte +41 ENCFF213LRI+ /home/drk/tillage/datasets/human/rna/encode/ENCSR474TRG/summary/coverage+.w5 384 64 0.3 sum_sqrt 42 RNA:esophagus squamous epithelium tissue male adult (54 years) +42 ENCFF213LRI- /home/drk/tillage/datasets/human/rna/encode/ENCSR474TRG/summary/coverage-.w5 384 64 0.3 sum_sqrt 41 RNA:esophagus squamous epithelium tissue male adult (54 years) +43 ENCFF709BGZ /home/drk/tillage/datasets/human/rna/encode/ENCSR535YOP/summary/coverage.w5 384 64 0.3 sum_sqrt 43 RNA:with multiple sclerosis; IgD-negative memory B cell 
+44 ENCFF489SFV /home/drk/tillage/datasets/human/rna/encode/ENCSR555BCP/summary/coverage.w5 384 64 0.3 sum_sqrt 44 RNA:adrenal gland tissue embryo (96 days) +45 ENCFF047ADC /home/drk/tillage/datasets/human/rna/encode/ENCSR563ZWI/summary/coverage.w5 384 64 0.3 sum_sqrt 45 RNA:K562 treated with 10 nM Bortezomib for 4 hours +46 ENCFF218MMR+ /home/drk/tillage/datasets/human/rna/encode/ENCSR569JKX/summary/coverage+.w5 384 64 0.3 sum_sqrt 47 RNA:SK-N-DZ cytosolic fraction +47 ENCFF218MMR- /home/drk/tillage/datasets/human/rna/encode/ENCSR569JKX/summary/coverage-.w5 384 64 0.3 sum_sqrt 46 RNA:SK-N-DZ cytosolic fraction +48 ENCFF733JPK+ /home/drk/tillage/datasets/human/rna/encode/ENCSR621PZI/summary/coverage+.w5 384 64 0.3 sum_sqrt 49 RNA:spleen tissue female adult (41 years) +49 ENCFF733JPK- /home/drk/tillage/datasets/human/rna/encode/ENCSR621PZI/summary/coverage-.w5 384 64 0.3 sum_sqrt 48 RNA:spleen tissue female adult (41 years) +50 ENCFF759QTQ /home/drk/tillage/datasets/human/rna/encode/ENCSR702IGQ/summary/coverage.w5 384 64 0.3 sum_sqrt 50 RNA:stomach tissue female embryo (107 days) +51 ENCFF751PQP+ /home/drk/tillage/datasets/human/rna/encode/ENCSR725TPW/summary/coverage+.w5 384 64 0.3 sum_sqrt 52 RNA:ovary tissue female adult (30 years) +52 ENCFF751PQP- /home/drk/tillage/datasets/human/rna/encode/ENCSR725TPW/summary/coverage-.w5 384 64 0.3 sum_sqrt 51 RNA:ovary tissue female adult (30 years) +53 ENCFF193HUI /home/drk/tillage/datasets/human/rna/encode/ENCSR755LFM/summary/coverage.w5 384 64 0.3 sum_sqrt 53 RNA:testis tissue male embryo +54 ENCFF295UWT /home/drk/tillage/datasets/human/rna/encode/ENCSR774SEX/summary/coverage.w5 384 64 0.3 sum_sqrt 54 RNA:stomach tissue embryo (101 days) +55 ENCFF107LVE /home/drk/tillage/datasets/human/rna/encode/ENCSR880XLM/summary/coverage.w5 384 64 0.3 sum_sqrt 55 RNA:placenta tissue female embryo (113 days) +56 ENCFF649RYB+ /home/drk/tillage/datasets/human/rna/encode/ENCSR892LBU/summary/coverage+.w5 384 64 0.3 sum_sqrt 57 RNA:kidney 
tissue female adult (47 years) +57 ENCFF649RYB- /home/drk/tillage/datasets/human/rna/encode/ENCSR892LBU/summary/coverage-.w5 384 64 0.3 sum_sqrt 56 RNA:kidney tissue female adult (47 years) +58 ENCFF734VZN+ /home/drk/tillage/datasets/human/rna/encode/ENCSR897KTO/summary/coverage+.w5 384 64 0.3 sum_sqrt 59 RNA:epithelial cell of alveolus of lung NONE and female embryo (21 weeks) +59 ENCFF734VZN- /home/drk/tillage/datasets/human/rna/encode/ENCSR897KTO/summary/coverage-.w5 384 64 0.3 sum_sqrt 58 RNA:epithelial cell of alveolus of lung NONE and female embryo (21 weeks) +60 ENCFF895JFS+ /home/drk/tillage/datasets/human/rna/encode/ENCSR997KDB/summary/coverage+.w5 384 64 0.3 sum_sqrt 61 RNA:heart right ventricle tissue female adult (59 years) +61 ENCFF895JFS- /home/drk/tillage/datasets/human/rna/encode/ENCSR997KDB/summary/coverage-.w5 384 64 0.3 sum_sqrt 60 RNA:heart right ventricle tissue female adult (59 years) +62 GTEX-14DAR-0726-SM-5RQIA.1 /home/drk/tillage/datasets/human/rna/recount3/esophagus/GTEX-14DAR-0726-SM-5RQIA.1/coverage.w5 384 64 0.02 sum_sqrt 62 RNA:esophagus +63 GTEX-1MUQO-1226-SM-E9TJK.1 /home/drk/tillage/datasets/human/rna/recount3/heart/GTEX-1MUQO-1226-SM-E9TJK.1/coverage.w5 384 64 0.02 sum_sqrt 63 RNA:heart +64 tabula174+ /home/drk/tillage/datasets/human/rna3/czi/tabula174/coverage+.w5 384 64 0.3059869193359048 sum_sqrt 65 RNA3:Lymph_Node, memory b cell +65 tabula174- /home/drk/tillage/datasets/human/rna3/czi/tabula174/coverage-.w5 384 64 0.3059869193359048 sum_sqrt 64 RNA3:Lymph_Node, memory b cell +66 tabula206+ /home/drk/tillage/datasets/human/rna3/czi/tabula206/coverage+.w5 384 64 0.133490386585116 sum_sqrt 67 RNA3:Muscle, mesenchymal stem cell +67 tabula206- /home/drk/tillage/datasets/human/rna3/czi/tabula206/coverage-.w5 384 64 0.133490386585116 sum_sqrt 66 RNA3:Muscle, mesenchymal stem cell +68 tabula208+ /home/drk/tillage/datasets/human/rna3/czi/tabula208/coverage+.w5 384 64 0.5075664849226661 sum_sqrt 69 RNA3:Muscle, skeletal muscle satellite 
stem cell +69 tabula208- /home/drk/tillage/datasets/human/rna3/czi/tabula208/coverage-.w5 384 64 0.5075664849226661 sum_sqrt 68 RNA3:Muscle, skeletal muscle satellite stem cell +70 tabula271+ /home/drk/tillage/datasets/human/rna3/czi/tabula271/coverage+.w5 384 64 1.0 sum_sqrt 71 RNA3:Skin, macrophage +71 tabula271- /home/drk/tillage/datasets/human/rna3/czi/tabula271/coverage-.w5 384 64 1.0 sum_sqrt 70 RNA3:Skin, macrophage +72 tabula399+ /home/drk/tillage/datasets/human/rna3/czi/tabula399/coverage+.w5 384 64 1.0 sum_sqrt 73 RNA3:Vasculature, epithelial cell +73 tabula399- /home/drk/tillage/datasets/human/rna3/czi/tabula399/coverage-.w5 384 64 1.0 sum_sqrt 72 RNA3:Vasculature, epithelial cell diff --git a/tests/test_eval.py b/tests/test_eval.py old mode 100644 new mode 100755 diff --git a/tests/test_ism.py b/tests/test_ism.py new file mode 100755 index 0000000..4a93e3d --- /dev/null +++ b/tests/test_ism.py @@ -0,0 +1,83 @@ +import h5py +import pdb +import subprocess + +import numpy as np + +stat_keys = ["logSUM", "logD2"] +fasta_file = "tests/data/hg38_1m.fa.gz" +targets_file = "tests/data/tiny/hg38/targets.txt" +params_file = "tests/data/eval/params.json" +model_file = "tests/data/eval/model.h5" +snp_out_dir = "tests/data/ism/snp_out" +bed_out_dir = "tests/data/ism/bed_out" + + +def test_snp(): + cmd = [ + "src/baskerville/scripts/hound_ism_snp.py", + "-f", + fasta_file, + "-l", + "6", + "-o", + snp_out_dir, + "--stats", + ",".join(stat_keys), + "-t", + targets_file, + params_file, + model_file, + "tests/data/ism/snp.vcf", + ] + print(" ".join(cmd)) + subprocess.run(cmd, check=True) + + with h5py.File(f"{snp_out_dir}/scores.h5", "r") as scores_h5: + for sk in stat_keys: + score = scores_h5[sk][:] + score_var = score.var(axis=2, dtype="float32") + + # verify shape + assert score.shape == (2, 6, 4, 47) + + # verify not NaN + assert not np.isnan(score).any() + + # verify variance + assert (score_var > 0).all() + + +def test_bed(): + cmd = [ + 
"src/baskerville/scripts/hound_ism_bed.py", + "-f", + fasta_file, + "-l", + "6", + "-o", + bed_out_dir, + "--stats", + ",".join(stat_keys), + "-t", + targets_file, + params_file, + model_file, + "tests/data/ism/seqs.bed", + ] + print(" ".join(cmd)) + subprocess.run(cmd, check=True) + + with h5py.File(f"{bed_out_dir}/scores.h5", "r") as scores_h5: + for sk in stat_keys: + score = scores_h5[sk][:] + score_var = score.var(axis=2, dtype="float32") + + # verify shape + assert score.shape == (2, 6, 4, 47) + + # verify not NaN + assert not np.isnan(score).any() + + # verify variance + assert (score_var > 0).all() diff --git a/tests/test_snp.py b/tests/test_snp.py new file mode 100755 index 0000000..68a0a9b --- /dev/null +++ b/tests/test_snp.py @@ -0,0 +1,125 @@ +import subprocess + +import h5py +import numpy as np +import pandas as pd + +from baskerville.dataset import targets_prep_strand + +stat_keys = ["logSUM", "logD2"] +fasta_file = "tests/data/hg38_1m.fa.gz" +targets_file = "tests/data/tiny/hg38/targets.txt" +params_file = "tests/data/eval/params.json" +model_file = "tests/data/eval/model.h5" +snp_out_dir = "tests/data/snp/eqtl_out" + + +def test_snp(): + cmd = [ + "src/baskerville/scripts/hound_snp.py", + "-f", + fasta_file, + "-o", + snp_out_dir, + "--stats", + ",".join(stat_keys), + "-t", + targets_file, + params_file, + model_file, + "tests/data/snp/eqtl.vcf", + ] + print(" ".join(cmd)) + subprocess.run(cmd, check=True) + + scores_file = "tests/data/snp/eqtl_out/scores.h5" + with h5py.File(scores_file, "r") as scores_h5: + for sk in stat_keys: + score = scores_h5[sk][:] + score_var = score.var(axis=0, dtype="float32") + + # verify shapes + assert score.shape == (8, 47) + + # verify not NaN + assert not np.isnan(score).any() + + # verify variance + assert (score_var > 0).all() + + +def test_flip(): + # score SNPs + flip_out_dir = "tests/data/snp/flip_out" + cmd = [ + "src/baskerville/scripts/hound_snp.py", + "-f", + fasta_file, + "-o", + flip_out_dir, + "--stats", 
+ ",".join(stat_keys), + "-t", + targets_file, + params_file, + model_file, + "tests/data/snp/eqtl_flip.vcf", + ] + print(" ".join(cmd)) + subprocess.run(cmd, check=True) + + scores_file = f"{snp_out_dir}/scores.h5" + with h5py.File(scores_file, "r") as scores_h5: + score_sum = scores_h5["logSUM"][:] + score_d2 = scores_h5["logD2"][:] + + scores_flip_file = f"{flip_out_dir}/scores.h5" + with h5py.File(scores_flip_file, "r") as scores_h5: + score_sum_flip = scores_h5["logSUM"][:] + score_d2_flip = scores_h5["logD2"][:] + + assert np.allclose(score_sum, -score_sum_flip) + assert np.allclose(score_d2, score_d2_flip) + + +def test_slice(): + # slice targets + targets_df = pd.read_csv(targets_file, sep="\t", index_col=0) + rna_mask = np.array([desc.startswith("RNA") for desc in targets_df.description]) + targets_rna_df = targets_df[rna_mask] + targets_rna_file = targets_file.replace(".txt", "_rna.txt") + targets_rna_df.to_csv(targets_rna_file, sep="\t") + + # score SNPs + slice_out_dir = "tests/data/snp/slice_out" + cmd = [ + "src/baskerville/scripts/hound_snp.py", + "-f", + fasta_file, + "-o", + slice_out_dir, + "--stats", + ",".join(stat_keys), + "-t", + targets_rna_file, + params_file, + model_file, + "tests/data/snp/eqtl.vcf", + ] + print(" ".join(cmd)) + subprocess.run(cmd, check=True) + + # stranded mask + targets_strand_df = targets_prep_strand(targets_df) + rna_strand_mask = np.array( + [desc.startswith("RNA") for desc in targets_strand_df.description] + ) + + # verify all close + for sk in stat_keys: + with h5py.File(f"{snp_out_dir}/scores.h5", "r") as scores_h5: + score_full = scores_h5[sk][:].astype("float32") + score_full = score_full[..., rna_strand_mask] + with h5py.File(f"{slice_out_dir}/scores.h5", "r") as scores_h5: + score_slice = scores_h5[sk][:].astype("float32") + assert np.allclose(score_full, score_slice) diff --git a/tests/test_train.py b/tests/test_train.py old mode 100644 new mode 100755