diff --git a/src/baskerville/scripts/hound_ism_snp.py b/src/baskerville/scripts/hound_ism_snp.py index b98a02d..52332d5 100755 --- a/src/baskerville/scripts/hound_ism_snp.py +++ b/src/baskerville/scripts/hound_ism_snp.py @@ -36,6 +36,7 @@ given in a VCF file. """ + ################################################################################ # main ################################################################################ @@ -211,7 +212,7 @@ def main(): ref_1hot = np.expand_dims(seqs_1hot[si], axis=0) # save sequence - scores_h5['seqs'][si] = ref_1hot[0, mut_start:mut_end].astype('bool') + scores_h5["seqs"][si] = ref_1hot[0, mut_start:mut_end].astype("bool") # predict reference ref_preds = [] @@ -250,11 +251,15 @@ def main(): options.untransform_old, ) alt_preds.append(alt_preds_shift) - alt_preds = np.array(alt_preds) + alt_preds = np.array(alt_preds) - ism_scores = snps.compute_scores(ref_preds, alt_preds, options.snp_stats) + ism_scores = snps.compute_scores( + ref_preds, alt_preds, options.snp_stats + ) for snp_stat in options.snp_stats: - scores_h5[snp_stat][si, mi-mut_start, ni] = ism_scores[snp_stat] + scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ + snp_stat + ] # close output HDF5 scores_h5.close() diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index 626eda3..f29bf75 100755 --- a/src/baskerville/scripts/hound_snp.py +++ b/src/baskerville/scripts/hound_snp.py @@ -185,7 +185,7 @@ def main(): print("Running on CPU") if options.require_gpu: raise SystemExit("Job terminated because it's running on CPU") - + ################################################################# # download input files from gcs to a local file if options.gcs: @@ -200,7 +200,7 @@ def main(): options.targets_file = download_rename_inputs( options.targets_file, temp_dir ) - + ################################################################# # calculate SAD scores: if options.processes is not None: diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 4eb4eea..48aa300 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -970,14 +970,14 @@ def predict( return preds def predict_transform( - self, - seq_1hot: np.array, + self, + seq_1hot: np.array, targets_df, strand_transform: np.array = None, untransform_old: bool = False, ): """Predict a single sequence and transform. - + Args: seq_1hot (np.array): 1-hot encoded sequence. targets_df (pd.DataFrame): Targets dataframe. @@ -986,19 +986,19 @@ def predict_transform( """ # predict preds = self(seq_1hot)[0] - + # untransform predictions if untransform_old: preds = dataset.untransform_preds1(preds, targets_df) else: preds = dataset.untransform_preds(preds, targets_df) - + # sum strand pairs if strand_transform is not None: preds = preds * strand_transform return preds - + def restore(self, model_file, head_i=0, trunk=False): """Restore weights from saved model.""" if trunk: diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index f0d0813..6ecc4b4 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -37,7 +37,7 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): # read targets if options.targets_file is None: - print('Must provide targets file to clarify stranded datasets', file=sys.stderr) + print("Must provide targets file to clarify stranded datasets", file=sys.stderr) exit(1) targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0) diff --git a/tests/test_ism.py b/tests/test_ism.py index d0b23ea..4a93e3d 100755 --- a/tests/test_ism.py +++ b/tests/test_ism.py @@ -12,6 +12,7 @@ snp_out_dir = "tests/data/ism/snp_out" bed_out_dir = "tests/data/ism/bed_out" + def test_snp(): cmd = [ "src/baskerville/scripts/hound_ism_snp.py", @@ -46,6 +47,7 @@ def test_snp(): # verify variance assert (score_var > 0).all() + def test_bed(): cmd = [ "src/baskerville/scripts/hound_ism_bed.py", @@ -78,4 +80,4 @@ def test_bed(): assert not np.isnan(score).any() # verify variance - assert (score_var > 0).all() \ No newline at end of file + assert (score_var > 0).all()