From bb992f3446ca0f533b07d220ba4014d4a95b2a8b Mon Sep 17 00:00:00 2001 From: David Kelley Date: Tue, 11 Jun 2024 13:19:56 -0700 Subject: [PATCH 1/4] pred bed updates --- src/baskerville/scripts/hound_predbed.py | 85 +++++++++--------------- src/baskerville/seqnn.py | 8 ++- 2 files changed, 37 insertions(+), 56 deletions(-) diff --git a/src/baskerville/scripts/hound_predbed.py b/src/baskerville/scripts/hound_predbed.py index 4112048..518b2aa 100755 --- a/src/baskerville/scripts/hound_predbed.py +++ b/src/baskerville/scripts/hound_predbed.py @@ -26,9 +26,10 @@ import tensorflow as tf from baskerville import bed +from baskerville import dataset from baskerville import dna from baskerville import seqnn -from baskerville import stream + """ hound_predbed.py @@ -120,35 +121,18 @@ def main(): default=None, help="File specifying target indexes and labels in table format", ) - + parser.add_argument( + "-u", + "--untransform_old", + default=False, + action="store_true", + help="Untransform old models [Default: %default]", + ) parser.add_argument("params_file", help="Parameters file") parser.add_argument("model_file", help="Model file") parser.add_argument("bed_file", help="BED file") args = parser.parse_args() - if len(args) == 3: - params_file = args[0] - model_file = args[1] - bed_file = args[2] - - elif len(args) == 5: - # multi worker - options_pkl_file = args[0] - params_file = args[1] - model_file = args[2] - bed_file = args[3] - worker_index = int(args[4]) - - # load options - options_pkl = open(options_pkl_file, "rb") - options = pickle.load(options_pkl) - options_pkl.close() - - # update output directory - args.out_dir = "%s/job%d" % (args.out_dir, worker_index) - else: - parser.error("Must provide parameter and model files and BED file") - os.makedirs(args.out_dir, exist_ok=True) args.shifts = [int(shift) for shift in args.shifts.split(",")] @@ -166,7 +150,7 @@ def main(): ################################################################# # read parameters and collet target information - with open(params_file) as params_open: + with open(args.params_file) as params_open: params = json.load(params_open) params_model = params["model"] @@ -181,7 +165,7 @@ def main(): # initialize model seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file, args.head) + seqnn_model.restore(args.model_file, args.head) seqnn_model.build_slice(target_slice) seqnn_model.build_ensemble(args.rc, args.shifts) @@ -205,27 +189,8 @@ def main(): # construct model sequences model_seqs_dna, model_seqs_coords = bed.make_bed_seqs( - bed_file, args.genome_fasta, params_model["seq_length"], stranded=False + args.bed_file, args.genome_fasta, params_model["seq_length"], stranded=False ) - - # construct site coordinates - site_seqs_coords = bed.read_bed_coords(bed_file, args.site_length) - - # filter for worker SNPs - if args.processes is not None: - worker_bounds = np.linspace( - 0, len(model_seqs_dna), args.processes + 1, dtype="int" - ) - model_seqs_dna = model_seqs_dna[ - worker_bounds[worker_index] : worker_bounds[worker_index + 1] - ] - model_seqs_coords = model_seqs_coords[ - worker_bounds[worker_index] : worker_bounds[worker_index + 1] - ] - site_seqs_coords = site_seqs_coords[ - worker_bounds[worker_index] : worker_bounds[worker_index + 1] - ] - num_seqs = len(model_seqs_dna) ################################################################# @@ -258,6 +223,7 @@ def main(): ) # store site coordinates + site_seqs_coords = bed.read_bed_coords(args.bed_file, args.site_length) site_seqs_chr, site_seqs_start, site_seqs_end = zip(*site_seqs_coords) site_seqs_chr = np.array(site_seqs_chr, dtype="S") site_seqs_start = np.array(site_seqs_start) @@ -268,7 +234,7 @@ def main(): ################################################################# # predict scores, write output - + """ # define sequence generator def seqs_gen(): for seq_dna in model_seqs_dna: @@ -278,18 +244,29 @@ def seqs_gen(): preds_stream = stream.PredStreamGen( seqnn_model, seqs_gen(), params["train"]["batch_size"] ) + """ + + for si, seq_dna in enumerate(model_seqs_dna): + seq_1hot = np.expand_dims(dna.dna_1hot(seq_dna), axis=0) + preds_seq = seqnn_model.predict(seq_1hot)[0] - for si in range(num_seqs): - preds_seq = preds_stream[si] + if args.untransform_old: + preds_seq = dataset.untransform_preds1(preds_seq, targets_df) + else: + preds_seq = dataset.untransform_preds(preds_seq, targets_df) # slice site preds_site = preds_seq[site_preds_start:site_preds_end, :] - # write + # optionally, sum if args.sum: - out_h5["preds"][si] = preds_site.sum(axis=0) - else: - out_h5["preds"][si] = preds_site + preds_site = np.sum(preds_site, axis=0) + + # clip to float16 max + preds_write = np.clip(preds_site, 0, np.finfo(np.float16).max) + + # write + out_h5["preds"][si] = preds_write # write bigwig for ti in args.bigwig_indexes: diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 48aa300..43329b4 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -912,7 +912,11 @@ def __call__(self, x, head_i=None, dtype="float32"): else: model = self.model - return model(x).numpy().astype(dtype) + if isinstance(x, np.ndarray): + preds = model(x).numpy().astype(dtype) + else: + preds = model(x) + return preds def predict( self, @@ -955,7 +959,7 @@ def predict( preds = model.predict_generator(dataset, **kwargs).astype(dtype) elif stream: preds = [] - for x, y in seq_data.dataset: + for x, y in dataset: yh = model.predict(x, **kwargs) if step > 1: yh = yh[:, step_i, :] From 248a7e9f242393ffdcd29719cd12ffb966e549f6 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Sun, 16 Jun 2024 13:50:44 -0700 Subject: [PATCH 2/4] allow stitching from arbitrary indel positions --- src/baskerville/scripts/hound_snp.py | 7 ++-- src/baskerville/scripts/hound_snpgene.py | 11 ++++-- src/baskerville/snps.py | 45 +++++++++++++++--------- 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index e543c43..847cd16 100755 --- a/src/baskerville/scripts/hound_snp.py +++ b/src/baskerville/scripts/hound_snp.py @@ -140,8 +140,8 @@ def main(): out_dir = temp_dir + "/output_dir" options.out_dir = out_dir - if not os.path.isdir(options.out_dir): - os.mkdir(options.out_dir) + # is this here for GCS? + os.makedirs(options.out_dir, exist_ok=True) if len(args) == 3: # single worker @@ -200,9 +200,6 @@ def main(): if options.targets_file is None: parser.error("Must provide targets file") - if options.cluster_snps_pct > 0 and options.indel_stitch: - parser.error("Cannot use --cluster_snps_pct and --indel_stitch together") - ################################################################# # check if the program is run on GPU, else quit physical_devices = tf.config.list_physical_devices() diff --git a/src/baskerville/scripts/hound_snpgene.py b/src/baskerville/scripts/hound_snpgene.py index 7543d67..f3c4ffa 100755 --- a/src/baskerville/scripts/hound_snpgene.py +++ b/src/baskerville/scripts/hound_snpgene.py @@ -71,6 +71,13 @@ def main(): default="%s/genes/gencode41/gencode41_basic_nort.gtf" % os.environ["HG38"], help="GTF for gene definition [Default %default]", ) + parser.add_option( + "--indel_stitch", + dest="indel_stitch", + default=False, + action="store_true", + help="Stitch indel compensation shifts [Default: %default]", + ) parser.add_option( "-o", dest="out_dir", @@ -155,8 +162,8 @@ def main(): out_dir = temp_dir + "/output_dir" options.out_dir = out_dir - if not os.path.isdir(options.out_dir): - os.mkdir(options.out_dir) + # is this here for GCS? + os.makedirs(options.out_dir, exist_ok=True) if len(args) == 3: # single worker diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index d9d2acc..b4fce23 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -89,6 +89,9 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): # shift outside seqnn num_shifts = len(options.shifts) targets_length = seqnn_model.target_lengths[0] + targets_length = seqnn_model.target_lengths[0] + model_stride = seqnn_model.model_strides[0] + model_crop = seqnn_model.target_crops[0] * model_stride ################################################################# # load SNPs @@ -194,16 +197,11 @@ def score_write(ref_preds, alt_preds, si): for ai, alt_1hot in enumerate(snp_1hot_list[1:]): alt_1hot = np.expand_dims(alt_1hot, axis=0) - # add compensation shifts for indels + # add left/right shifts for indels indel_size = sc.snps[ai].indel_size() if indel_size == 0: alt_shifts = options.shifts else: - # repeat reference predictions, unless stitching - if not options.indel_stitch: - ref_preds = np.repeat(ref_preds, 2, axis=0) - - # add compensation shifts alt_shifts = [] for shift in options.shifts: alt_shifts.append(shift) @@ -229,9 +227,11 @@ def score_write(ref_preds, alt_preds, si): ref_preds = [rpsf.result() for rpsf in ref_preds] alt_preds = [apsf.result() for apsf in alt_preds] - # stitch indel compensation shifts + # stitch indel shifts if indel_size != 0 and options.indel_stitch: - alt_preds = stitch_preds(alt_preds, options.shifts) + snp_seq_pos = sc.snps[ai].pos - sc.start - model_crop + snp_seq_bin = snp_seq_pos // model_stride + alt_preds = stitch_preds(alt_preds, options.shifts, snp_seq_bin) # flip reference and alternate if snps[si].flipped: @@ -241,6 +241,10 @@ def score_write(ref_preds, alt_preds, si): rp_snp = np.array(ref_preds) ap_snp = np.array(alt_preds) + # repeat reference predictions for indels w/o stitching + if indel_size != 0 and not options.indel_stitch: + rp_snp = np.repeat(rp_snp, 2, axis=0) + # write SNP if sum_length: write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) @@ -420,16 +424,11 @@ def score_write(ref_preds, alt_preds, gene_id, snp_id): for ai, alt_1hot in enumerate(snp_1hot_list[1:]): alt_1hot = np.expand_dims(alt_1hot, axis=0) - # add compensation shifts for indels + # add left/right shifts for indels indel_size = gsc.snps[ai].indel_size() if indel_size == 0: alt_shifts = options.shifts else: - # repeat reference predictions, unless stitching - if not options.indel_stitch: - ref_preds = np.repeat(ref_preds, 2, axis=0) - - # add compensation shifts alt_shifts = [] for shift in options.shifts: alt_shifts.append(shift) @@ -455,6 +454,12 @@ def score_write(ref_preds, alt_preds, gene_id, snp_id): ref_preds = [rpsf.result() for rpsf in ref_preds] alt_preds = [apsf.result() for apsf in alt_preds] + # stitch indel shifts + if indel_size != 0 and options.indel_stitch: + snp_seq_pos = gsc.snps[ai].pos - gsc.start - model_crop + snp_seq_bin = snp_seq_pos // model_stride + alt_preds = stitch_preds(alt_preds, options.shifts, snp_seq_bin) + # flip reference and alternate if gsc.snps[ai].flipped: rp_snp = np.array(alt_preds) @@ -463,6 +468,10 @@ def score_write(ref_preds, alt_preds, gene_id, snp_id): rp_snp = np.array(ref_preds) ap_snp = np.array(alt_preds) + # repeat reference predictions for indels w/o stitching + if indel_size != 0 and not options.indel_stitch: + rp_snp = np.repeat(rp_snp, 2, axis=0) + for gene in gsc.genes: # slice gene positions gene_slice = gene.output_slice( @@ -864,19 +873,21 @@ def map_snps_genes(snps, genesnp_clusters): genesnp_clusters[gi].add_snp(snps[si]) -def stitch_preds(preds, shifts): +def stitch_preds(preds, shifts, pos=None): """Stitch indel left and right compensation shifts. Args: preds [np.array]: List of predictions. shifts [int]: List of shifts. + pos (int): SNP position to stitch at. """ - cp = preds[0].shape[0] // 2 + if pos is None: + pos = preds[0].shape[0] // 2 preds_stitch = [] for hi, shift in enumerate(shifts): hil = 2 * hi hir = hil + 1 - preds_stitch_i = np.concatenate((preds[hil][:cp], preds[hir][cp:]), axis=0) + preds_stitch_i = np.concatenate((preds[hil][:pos], preds[hir][pos:]), axis=0) preds_stitch.append(preds_stitch_i) return preds_stitch From 877bdb7ca401f943f7db434108ae6ebec98a7097 Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Mon, 17 Jun 2024 16:28:34 -0700 Subject: [PATCH 3/4] change pybedtools --- pyproject.toml | 2 +- tests/test_snp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ed4bec0..a8f08d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "numpy~=1.24.3", "pandas~=1.5.3", "pybigwig~=0.3.18", - "pybedtools~=0.9.0", + "pybedtools~=0.10.0", "pysam~=0.22.0", "qnorm~=0.8.1", "seaborn~=0.12.2", diff --git a/tests/test_snp.py b/tests/test_snp.py index 68a0a9b..f4cff69 100755 --- a/tests/test_snp.py +++ b/tests/test_snp.py @@ -6,7 +6,7 @@ from baskerville.dataset import targets_prep_strand -stat_keys = ["logSUM", "logD2"] +stat_keys = ["logSUM", "logD2", "logSAD"] fasta_file = "tests/data/hg38_1m.fa.gz" targets_file = "tests/data/tiny/hg38/targets.txt" params_file = "tests/data/eval/params.json" From a4202d71409257d2b3d257fa9197d8d4e41cb1bf Mon Sep 17 00:00:00 2001 From: lruizcalico Date: Mon, 17 Jun 2024 16:47:00 -0700 Subject: [PATCH 4/4] fix test --- tests/test_snp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_snp.py b/tests/test_snp.py index f4cff69..68a0a9b 100755 --- a/tests/test_snp.py +++ b/tests/test_snp.py @@ -6,7 +6,7 @@ from baskerville.dataset import targets_prep_strand -stat_keys = ["logSUM", "logD2", "logSAD"] +stat_keys = ["logSUM", "logD2"] fasta_file = "tests/data/hg38_1m.fa.gz" targets_file = "tests/data/tiny/hg38/targets.txt" params_file = "tests/data/eval/params.json"