From 73d86f0a73400b8204ca524033c2cc337a7bc726 Mon Sep 17 00:00:00 2001 From: David Kelley Date: Wed, 24 Apr 2024 16:13:34 -0700 Subject: [PATCH] gtex genes eval --- .../scripts/westminster_gtex_folds.py | 14 +- .../scripts/westminster_gtexg_coef.py | 134 ++-- .../scripts/westminster_gtexg_folds.py | 435 +++++------ .../scripts/westminster_gtexsed_folds.py | 680 ++++++++++++++++++ 4 files changed, 968 insertions(+), 295 deletions(-) create mode 100755 src/westminster/scripts/westminster_gtexsed_folds.py diff --git a/src/westminster/scripts/westminster_gtex_folds.py b/src/westminster/scripts/westminster_gtex_folds.py index aa59dad..aee13b6 100755 --- a/src/westminster/scripts/westminster_gtex_folds.py +++ b/src/westminster/scripts/westminster_gtex_folds.py @@ -52,15 +52,15 @@ def main(): snp_options.add_option( "-f", dest="genome_fasta", - default="%s/data/hg38.fa" % os.environ["BASENJIDIR"], + default="%s/assembly/ucsc/hg38.fa" % os.environ["HG38"], help="Genome FASTA for sequences [Default: %default]", ) snp_options.add_option( - '--indel_stitch', - dest='indel_stitch', + "--indel_stitch", + dest="indel_stitch", default=True, - action='store_true', - help="Stitch indel compensation shifts [Default: %default]" + action="store_true", + help="Stitch indel compensation shifts [Default: %default]", ) snp_options.add_option( "-o", @@ -406,10 +406,10 @@ def main(): # fit classifiers # SNPs - cmd_base = 'westminster_classify.py -f 8 -i 100 -r 44 -s' + cmd_base = "westminster_classify.py -f 8 -i 100 -r 44 -s" # indels # cmd_base = 'westminster_classify.py -f 6 -i 64 -r 44 -s' - cmd_base += ' --msl %d' % options.msl + cmd_base += " --msl %d" % options.msl if options.class_targets_file is not None: cmd_base += " -t %s" % options.class_targets_file diff --git a/src/westminster/scripts/westminster_gtexg_coef.py b/src/westminster/scripts/westminster_gtexg_coef.py index 2973119..37a4d29 100755 --- a/src/westminster/scripts/westminster_gtexg_coef.py +++ b/src/westminster/scripts/westminster_gtexg_coef.py @@ -35,7 +35,7 @@ def main(): parser.add_argument( "-g", "--gtex_vcf_dir", - default="/home/drk/seqnn/data/gtex_fine/susie_pip90", + default="/home/drk/seqnn/data/gtex_fine/susie_pip90r", help="GTEx VCF directory", ) parser.add_argument( @@ -44,7 +44,7 @@ def main(): parser.add_argument( "-s", "--snp_stat", - default="logSED", + default="logSUM", help="SNP statistic. 
[Default: %(default)s]", ) parser.add_argument("-v", "--verbose", action="store_true") @@ -95,7 +95,7 @@ def main(): eqtl_df = read_eqtl(tissue, args.gtex_vcf_dir) if eqtl_df is not None: # read model predictions - gtex_scores_file = f"{args.gtex_dir}/{tissue}_pos/sed.h5" + gtex_scores_file = f"{args.gtex_dir}/{tissue}_pos/scores.h5" eqtl_df = add_scores( gtex_scores_file, keyword, eqtl_df, args.snp_stat, verbose=args.verbose ) @@ -107,9 +107,7 @@ def main(): coef_r = spearmanr(eqtl_df.coef, eqtl_df.score)[0] # classification AUROC - class_auroc = classify_auroc( - gtex_scores_file, keyword, eqtl_df, args.snp_stat - ) + class_auroc = classify_auroc(gtex_scores_file, keyword, args.snp_stat) if args.plot: eqtl_df.to_csv(f"{args.out_dir}/{tissue}.tsv", index=False, sep="\t") @@ -210,18 +208,17 @@ def add_scores( """ with h5py.File(gtex_scores_file, "r") as gtex_scores_h5: # read data - snp_i = gtex_scores_h5["si"][:] - snps = np.array([snp.decode("UTF-8") for snp in gtex_scores_h5["snp"]]) - ref_allele = np.array( - [ref.decode("UTF-8") for ref in gtex_scores_h5["ref_allele"]] - ) - genes = np.array([snp.decode("UTF-8") for snp in gtex_scores_h5["gene"]]) - target_ids = np.array( - [ref.decode("UTF-8") for ref in gtex_scores_h5["target_ids"]] - ) - target_labels = np.array( - [ref.decode("UTF-8") for ref in gtex_scores_h5["target_labels"]] - ) + snps = [snp.decode("UTF-8") for snp in gtex_scores_h5["snp"]] + ref_allele = [ref.decode("UTF-8") for ref in gtex_scores_h5["ref_allele"]] + target_ids = [tid.decode("UTF-8") for tid in gtex_scores_h5["target_ids"]] + target_labels = [tl.decode("UTF-8") for tl in gtex_scores_h5["target_labels"]] + gtex_scores = gtex_scores_h5[score_key] + + # map gene identifiers + gene_map = {} + for gene in gtex_scores_h5["gene"]: + gene = gene.decode("UTF-8") + gene_map[trim_dot(gene)] = gene # determine matching GTEx targets match_tis = [] @@ -236,62 +233,50 @@ def add_scores( match_tis.append(ti) match_tis = np.array(match_tis) - # read scores and take mean across targets - score_table = gtex_scores_h5[score_key][..., match_tis].mean( - axis=-1, dtype="float32" - ) - score_table = np.arcsinh(score_table) - - # hash scores to (snp,gene) - snpgene_scores = {} - for sgi in range(score_table.shape[0]): - snp = snps[snp_i[sgi]] - gene = trim_dot(genes[sgi]) - snpgene_scores[(snp, gene)] = score_table[sgi] - - # add scores to eQTL table - # flipping when allele1 doesn't match - snp_ref = dict(zip(snps, ref_allele)) - eqtl_df_scores = [] - for sgi, eqtl in eqtl_df.iterrows(): - sgs = snpgene_scores.get((eqtl.variant, eqtl.gene), 0) - if sgs != 0 and snp_ref[eqtl.variant] != eqtl.allele1: - sgs *= -1 - eqtl_df_scores.append(sgs) - eqtl_df["score"] = eqtl_df_scores + # add scores to eQTL table + # flipping when allele1 doesn't match + snp_ref = dict(zip(snps, ref_allele)) + eqtl_df_scores = [] + for _, eqtl in eqtl_df.iterrows(): + if eqtl.variant in gtex_scores: + snp_scores = gtex_scores[eqtl.variant] + egene = gene_map.get(eqtl.gene, "") + if egene in snp_scores.keys(): + sgs = snp_scores[egene][match_tis].mean(dtype="float32") + else: + sgs = 0 + else: + sgs = 0 + + # flip + if sgs != 0 and snp_ref[eqtl.variant] != eqtl.allele1: + sgs *= -1 + + eqtl_df_scores.append(sgs) + eqtl_df["score"] = eqtl_df_scores return eqtl_df -def classify_auroc( - gtex_scores_file: str, keyword: str, eqtl_df: pd.DataFrame, score_key: str = "SED" -): +def classify_auroc(gtex_scores_file: str, keyword: str, score_key: str = "SED"): """Read eQTL RNA predictions for negatives from the given 
tissue. Args: gtex_scores_file (str): Variant scores HDF5. tissue_keyword (str): tissue keyword, for matching GTEx targets - eqtl_df (pd.DataFrame): eQTL dataframe score_key (str): score key in HDF5 file verbose (bool): Print matching targets. Returns: class_auroc (float): Classification AUROC. """ - gtex_nscores_file = gtex_scores_file.replace("_pos", "_neg") - with h5py.File(gtex_nscores_file, "r") as gtex_scores_h5: - # read data - snp_i = gtex_scores_h5["si"][:] - snps = np.array([snp.decode("UTF-8") for snp in gtex_scores_h5["snp"]]) - genes = np.array([snp.decode("UTF-8") for snp in gtex_scores_h5["gene"]]) - target_ids = np.array( - [ref.decode("UTF-8") for ref in gtex_scores_h5["target_ids"]] - ) - target_labels = np.array( - [ref.decode("UTF-8") for ref in gtex_scores_h5["target_labels"]] - ) + # rescore positives using all genes + with h5py.File(gtex_scores_file, "r") as gtex_scores_h5: + gtex_scores = gtex_scores_h5[score_key] # determine matching GTEx targets + target_ids = [tid.decode("UTF-8") for tid in gtex_scores_h5["target_ids"]] + target_labels = [tl.decode("UTF-8") for tl in gtex_scores_h5["target_labels"]] match_tis = [] for ti in range(len(target_ids)): if ( @@ -302,26 +287,33 @@ def classify_auroc( match_tis.append(ti) match_tis = np.array(match_tis) - # read scores and take mean across targets - score_table = gtex_scores_h5[score_key][..., match_tis].mean( - axis=-1, dtype="float32" - ) - # score_table = np.arcsinh(score_table) + # aggregate across genes w/ sum abs + psnp_scores = {} + for snp in gtex_scores.keys(): + gtex_snp_scores = gtex_scores[snp] + psnp_scores[snp] = 0 + for gene in gtex_snp_scores.keys(): + sgs = gtex_snp_scores[gene][match_tis].mean(dtype="float32") + psnp_scores[snp] += np.abs(sgs) - # aggregate across genes w/ sum abs - nsnp_scores = {} - for sgi in range(score_table.shape[0]): - snp = snps[snp_i[sgi]] - nsnp_scores[snp] = nsnp_scores.get(snp, 0) + np.abs(score_table[sgi]) + # score negatives + gtex_nscores_file = gtex_scores_file.replace("_pos", "_neg") + with h5py.File(gtex_nscores_file, "r") as gtex_scores_h5: + gtex_scores = gtex_scores_h5[score_key] - psnp_scores = {} - for sgi, eqtl in eqtl_df.iterrows(): - snp = eqtl.variant - psnp_scores[snp] = psnp_scores.get(snp, 0) + np.abs(eqtl.score) + # aggregate across genes w/ sum abs + nsnp_scores = {} + for snp in gtex_scores.keys(): + gtex_snp_scores = gtex_scores[snp] + nsnp_scores[snp] = 0 + for gene in gtex_snp_scores.keys(): + sgs = gtex_snp_scores[gene][match_tis].mean(dtype="float32") + nsnp_scores[snp] += np.abs(sgs) # compute AUROC Xp = list(psnp_scores.values()) Xn = list(nsnp_scores.values()) + print(keyword, len(Xp), len(Xn)) X = Xp + Xn y = [1] * len(Xp) + [0] * len(Xn) diff --git a/src/westminster/scripts/westminster_gtexg_folds.py b/src/westminster/scripts/westminster_gtexg_folds.py index 0f02186..6471b5e 100755 --- a/src/westminster/scripts/westminster_gtexg_folds.py +++ b/src/westminster/scripts/westminster_gtexg_folds.py @@ -18,14 +18,13 @@ import pickle import pdb import os +import sys import h5py import numpy as np import slurm -from westminster.multi import nonzero_h5 - """ westminster_gtexg_folds.py @@ -41,75 +40,96 @@ def main(): parser = OptionParser(usage) # sed options - sed_options = OptionGroup(parser, "borzoi_sed.py options") - sed_options.add_option( - "-b", - dest="bedgraph", - default=False, - action="store_true", - help="Write ref/alt predictions as bedgraph [Default: %default]", + snp_options = OptionGroup(parser, "hound_snpgene.py options") + 
snp_options.add_option( + "-c", + dest="cluster_pct", + default=0, + type="float", + help="Cluster genes within a %% of the seq length to make a single ref pred [Default: %default]", ) - sed_options.add_option( + snp_options.add_option( "-f", dest="genome_fasta", - default="%s/data/hg38.fa" % os.environ["BASENJIDIR"], + default="%s/assembly/ucsc/hg38.fa" % os.environ["HG38"], help="Genome FASTA for sequences [Default: %default]", ) - sed_options.add_option( + snp_options.add_option( "-g", dest="genes_gtf", default="%s/genes/gencode41/gencode41_basic_nort.gtf" % os.environ["HG38"], help="GTF for gene definition [Default %default]", ) - sed_options.add_option( + snp_options.add_option( "-o", dest="out_dir", default="sed", help="Output directory for tables and plots [Default: %default]", ) - sed_options.add_option( + snp_options.add_option( "--rc", dest="rc", default=False, action="store_true", help="Average forward and reverse complement predictions [Default: %default]", ) - sed_options.add_option( + snp_options.add_option( "--shifts", dest="shifts", default="0", type="str", help="Ensemble prediction shifts [Default: %default]", ) - sed_options.add_option( + snp_options.add_option( "--span", dest="span", default=False, action="store_true", help="Aggregate entire gene span [Default: %default]", ) - sed_options.add_option( + snp_options.add_option( "--stats", - dest="sed_stats", - default="SED", + dest="snp_stats", + default="logSUM", help="Comma-separated list of stats to save. [Default: %default]", ) - sed_options.add_option( + snp_options.add_option( "-t", dest="targets_file", default=None, type="str", help="File specifying target indexes and labels in table format", ) - sed_options.add_option( + snp_options.add_option( "-u", dest="untransform_old", default=False, action="store_true" ) - parser.add_option_group(sed_options) + snp_options.add_option( + "--gcs", + dest="gcs", + default=False, + action="store_true", + help="Input and output are in gcs", + ) + snp_options.add_option( + "--require_gpu", + dest="require_gpu", + default=False, + action="store_true", + help="Only run on GPU", + ) + snp_options.add_option( + "--tensorrt", + dest="tensorrt", + default=False, + action="store_true", + help="Model type is tensorrt optimized", + ) + parser.add_option_group(snp_options) # cross-fold fold_options = OptionGroup(parser, "cross-fold options") fold_options.add_option( - "-c", + "--crosses", dest="crosses", default=1, type="int", @@ -193,7 +213,7 @@ def main(): gtex_out_dir = options.out_dir # split SNP stats - sed_stats = options.sed_stats.split(",") + snp_stats = options.snp_stats.split(",") # merge study/tissue variants mpos_vcf_file = "%s/pos_merge.vcf" % options.gtex_vcf_dir @@ -235,7 +255,7 @@ def main(): options_pkl.close() # create base fold command - cmd_fold = "%s time borzoi_sed.py %s %s %s" % ( + cmd_fold = "%s hound_snpgene.py %s %s %s" % ( cmd_base, options_pkl_file, params_file, @@ -243,8 +263,8 @@ def main(): ) for pi in range(options.processes): - sed_file = "%s/job%d/sed.h5" % (options.out_dir, pi) - if not nonzero_h5(sed_file, sed_stats): + scores_file = "%s/job%d/scores.h5" % (options.out_dir, pi) + if not nonzero_h5(scores_file, snp_stats): cmd_job = "%s %s %d" % (cmd_fold, mneg_vcf_file, pi) j = slurm.Job( cmd_job, @@ -272,7 +292,7 @@ def main(): options_pkl.close() # create base fold command - cmd_fold = "%s time borzoi_sed.py %s %s %s" % ( + cmd_fold = "%s hound_snpgene.py %s %s %s" % ( cmd_base, options_pkl_file, params_file, @@ -280,8 +300,8 @@ def main(): ) for pi in 
range(options.processes): - sed_file = "%s/job%d/sed.h5" % (options.out_dir, pi) - if not nonzero_h5(sed_file, sed_stats): + scores_file = "%s/job%d/scores.h5" % (options.out_dir, pi) + if not nonzero_h5(scores_file, snp_stats): cmd_job = "%s %s %d" % (cmd_fold, mpos_vcf_file, pi) j = slurm.Job( cmd_job, @@ -310,13 +330,15 @@ def main(): # collect negatives neg_out_dir = "%s/merge_neg" % it_out_dir - if not os.path.isfile("%s/sed.h5" % neg_out_dir): - collect_scores(neg_out_dir, options.processes, "sed.h5") + if not os.path.isfile("%s/scores.h5" % neg_out_dir): + print(f"Collecting {neg_out_dir}") + collect_scores(neg_out_dir, options.processes, "scores.h5") # collect positives pos_out_dir = "%s/merge_pos" % it_out_dir - if not os.path.isfile("%s/sed.h5" % pos_out_dir): - collect_scores(pos_out_dir, options.processes, "sed.h5") + if not os.path.isfile("%s/scores.h5" % pos_out_dir): + print(f"Collecting {pos_out_dir}") + collect_scores(pos_out_dir, options.processes, "scores.h5") ################################################################ # split study/tissue variants @@ -324,13 +346,13 @@ def main(): for ci in range(options.crosses): for fi in range(num_folds): it_out_dir = "%s/f%dc%d/%s" % (exp_dir, fi, ci, gtex_out_dir) - print(it_out_dir) + print(f"Splitting {it_out_dir}") # split positives - split_scores(it_out_dir, "pos", options.gtex_vcf_dir, sed_stats) + split_scores(it_out_dir, "pos", options.gtex_vcf_dir, snp_stats) # split negatives - split_scores(it_out_dir, "neg", options.gtex_vcf_dir, sed_stats) + split_scores(it_out_dir, "neg", options.gtex_vcf_dir, snp_stats) ################################################################ # ensemble @@ -344,36 +366,37 @@ def main(): os.mkdir(gtex_dir) for gtex_pos_vcf in glob.glob("%s/*_pos.vcf" % options.gtex_vcf_dir): + print(f"Ensembling {gtex_pos_vcf}") gtex_neg_vcf = gtex_pos_vcf.replace("_pos.", "_neg.") pos_base = os.path.splitext(os.path.split(gtex_pos_vcf)[1])[0] neg_base = os.path.splitext(os.path.split(gtex_neg_vcf)[1])[0] - # collect SED files - sed_pos_files = [] - sed_neg_files = [] + # collect score files + score_pos_files = [] + score_neg_files = [] for ci in range(options.crosses): for fi in range(num_folds): it_dir = "%s/f%dc%d" % (exp_dir, fi, ci) it_out_dir = "%s/%s" % (it_dir, gtex_out_dir) - sed_pos_file = "%s/%s/sed.h5" % (it_out_dir, pos_base) - sed_pos_files.append(sed_pos_file) + score_pos_file = "%s/%s/scores.h5" % (it_out_dir, pos_base) + score_pos_files.append(score_pos_file) - sed_neg_file = "%s/%s/sed.h5" % (it_out_dir, neg_base) - sed_neg_files.append(sed_neg_file) + score_neg_file = "%s/%s/scores.h5" % (it_out_dir, neg_base) + score_neg_files.append(score_neg_file) # ensemble ens_pos_dir = "%s/%s" % (gtex_dir, pos_base) os.makedirs(ens_pos_dir, exist_ok=True) - ens_pos_file = "%s/sed.h5" % (ens_pos_dir) + ens_pos_file = "%s/scores.h5" % (ens_pos_dir) if not os.path.isfile(ens_pos_file): - ensemble_h5(ens_pos_file, sed_pos_files, sed_stats) + ensemble_h5(ens_pos_file, score_pos_files, snp_stats) ens_neg_dir = "%s/%s" % (gtex_dir, neg_base) os.makedirs(ens_neg_dir, exist_ok=True) - ens_neg_file = "%s/sed.h5" % (ens_neg_dir) + ens_neg_file = "%s/scores.h5" % (ens_neg_dir) if not os.path.isfile(ens_neg_file): - ensemble_h5(ens_neg_file, sed_neg_files, sed_stats) + ensemble_h5(ens_neg_file, score_neg_files, snp_stats) ################################################################ # coefficient analysis @@ -386,9 +409,9 @@ def main(): it_dir = "%s/f%dc%d" % (exp_dir, fi, ci) it_out_dir = "%s/%s" % (it_dir, 
gtex_out_dir) - for sed_stat in sed_stats: - coef_out_dir = f"{it_out_dir}/coef-{sed_stat}" - cmd_coef = f"{cmd_base} -o {coef_out_dir} -s {sed_stat} {it_out_dir}" + for snp_stat in snp_stats: + coef_out_dir = f"{it_out_dir}/coef-{snp_stat}" + cmd_coef = f"{cmd_base} -o {coef_out_dir} -s {snp_stat} {it_out_dir}" j = slurm.Job( cmd_coef, "coef", @@ -403,9 +426,9 @@ def main(): # ensemble it_out_dir = f"{exp_dir}/ensemble/{gtex_out_dir}" - for sed_stat in sed_stats: - coef_out_dir = f"{it_out_dir}/coef-{sed_stat}" - cmd_coef = f"{cmd_base} -o {coef_out_dir} -s {sed_stat} {it_out_dir}" + for snp_stat in snp_stats: + coef_out_dir = f"{it_out_dir}/coef-{snp_stat}" + cmd_coef = f"{cmd_base} -o {coef_out_dir} -s {snp_stat} {it_out_dir}" j = slurm.Job( cmd_coef, "coef", @@ -421,7 +444,7 @@ def main(): slurm.multi_run(jobs, verbose=True) -def collect_scores(out_dir: str, num_jobs: int, h5f_name: str = "sad.h5"): +def collect_scores(out_dir: str, num_jobs: int, h5f_name: str = "scores.h5"): """Collect parallel SAD jobs' output into one HDF5. Args: @@ -430,144 +453,141 @@ def collect_scores(out_dir: str, num_jobs: int, h5f_name: str = "sad.h5"): """ # count variants num_variants = 0 - num_rows = 0 for pi in range(num_jobs): # open job job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5f_name) - job_h5_open = h5py.File(job_h5_file, "r") - num_variants += len(job_h5_open["snp"]) - num_rows += len(job_h5_open["si"]) - job_h5_open.close() - - # initialize final h5 - final_h5_file = "%s/%s" % (out_dir, h5f_name) - final_h5_open = h5py.File(final_h5_file, "w") - - # SNP stats - snp_stats = {} - - job0_h5_file = "%s/job0/%s" % (out_dir, h5f_name) - job0_h5_open = h5py.File(job0_h5_file, "r") - for key in job0_h5_open.keys(): - if key in ["target_ids", "target_labels"]: - # copy - final_h5_open.create_dataset(key, data=job0_h5_open[key]) - - elif key in ["snp", "chr", "pos", "ref_allele", "alt_allele", "gene"]: - snp_stats[key] = [] - - elif job0_h5_open[key].ndim == 1: - final_h5_open.create_dataset( - key, shape=(num_rows,), dtype=job0_h5_open[key].dtype - ) - - else: - num_targets = job0_h5_open[key].shape[1] - final_h5_open.create_dataset( - key, shape=(num_rows, num_targets), dtype=job0_h5_open[key].dtype - ) + with h5py.File(job_h5_file, "r") as job_h5_open: + num_variants += len(job_h5_open["snp"]) - job0_h5_open.close() + final_dict = {} - # set values - vgi = 0 - vi = 0 for pi in range(num_jobs): # open job job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5f_name) with h5py.File(job_h5_file, "r") as job_h5_open: - job_snps = len(job_h5_open["snp"]) - job_rows = job_h5_open["si"].shape[0] - - # append to final for key in job_h5_open.keys(): - try: - if key in ["target_ids", "target_labels"]: - # once is enough - pass - - elif key in [ - "snp", - "chr", - "pos", - "ref_allele", - "alt_allele", - "gene", - ]: - snp_stats[key] += list(job_h5_open[key]) - - elif key == "si": - # re-index SNPs - final_h5_open[key][vgi : vgi + job_rows] = ( - job_h5_open[key][:] + vi - ) - - else: - final_h5_open[key][vgi : vgi + job_rows] = job_h5_open[key] - - except TypeError as e: - print(e) - print( - f"{job_h5_file} {key} has the wrong shape. 
Remove this file and rerun" - ) - exit() - - vgi += job_rows - vi += job_snps - - # create final SNP stat datasets - for key in snp_stats: - if key == "pos": - final_h5_open.create_dataset(key, data=np.array(snp_stats[key])) - else: - final_h5_open.create_dataset(key, data=np.array(snp_stats[key], dtype="S")) + if key in ["target_ids", "target_labels"]: + final_dict[key] = job_h5_open[key][:] + elif key in ["snp", "chr", "pos", "ref_allele", "alt_allele", "gene"]: + final_dict.setdefault(key, []).append(job_h5_open[key][:]) + elif isinstance(job_h5_open[key], h5py.Group): + snp_stat = key + if snp_stat not in final_dict: + final_dict[snp_stat] = {} + for snp in job_h5_open[snp_stat].keys(): + final_dict[snp_stat][snp] = {} + for gene in job_h5_open[snp_stat][snp].keys(): + final_dict[snp_stat][snp][gene] = job_h5_open[snp_stat][ + snp + ][gene][:] + else: + print(f"During collection, unknown key {key}") - final_h5_open.close() + # initialize final h5 + final_h5_file = "%s/%s" % (out_dir, h5f_name) + with h5py.File(final_h5_file, "w") as final_h5_open: + for key in final_dict.keys(): + if key in ["target_ids", "target_labels"]: + final_h5_open.create_dataset(key, data=final_dict[key]) + elif key in ["snp", "chr", "pos", "ref_allele", "alt_allele", "gene"]: + fdv = np.concatenate(final_dict[key]) + final_h5_open.create_dataset(key, data=fdv) + + else: + snp_stat = key + final_h5_open.create_group(snp_stat) + for snp in final_dict[snp_stat].keys(): + final_h5_open[snp_stat].create_group(snp) + for gene in final_dict[snp_stat][snp].keys(): + final_h5_open[snp_stat][snp].create_dataset( + gene, data=final_dict[snp_stat][snp][gene] + ) -def ensemble_h5(ensemble_h5_file: str, scores_files: list, sed_stats: list): +def ensemble_h5(ensemble_h5_file: str, scores_files: list, snp_stats: list): """Ensemble scores from multiple files into a single file. Args: ensemble_h5_file (str): ensemble score HDF5. scores_files ([str]): list of replicate score HDFs. - sed_stats ([str]): SED stats to average over folds. + snp_stats ([str]): SNP stats to average over folds. 
""" # open ensemble ensemble_h5 = h5py.File(ensemble_h5_file, "w") - # transfer non-SED keys - sed_shapes = {} - scores0_h5 = h5py.File(scores_files[0], "r") - for key in scores0_h5.keys(): - if key not in sed_stats: - ensemble_h5.create_dataset(key, data=scores0_h5[key]) - else: - sed_shapes[key] = scores0_h5[key].shape - scores0_h5.close() + with h5py.File(scores_files[0], "r") as scores0_h5: + for key in scores0_h5.keys(): + if key in snp_stats: + ensemble_h5.create_group(key) + else: + ensemble_h5.create_dataset(key, data=scores0_h5[key]) # average stats num_folds = len(scores_files) - for sed_stat in sed_stats: - # initialize ensemble array - sed_values = np.zeros(shape=sed_shapes[sed_stat], dtype="float32") - - # read and add folds + for snp_stat in snp_stats: + # sum scores across folds + snpgene_scores = {} for scores_file in scores_files: with h5py.File(scores_file, "r") as scores_h5: - sed_values += scores_h5[sed_stat][:].astype("float32") + for snp in scores_h5[snp_stat].keys(): + if snp not in snpgene_scores: + snpgene_scores[snp] = {} + for gene in scores_h5[snp_stat][snp].keys(): + if gene not in snpgene_scores[snp]: + snpgene_scores[snp][gene] = scores_h5[snp_stat][snp][gene][ + : + ].astype("float32") + else: + snpgene_scores[snp][gene] += scores_h5[snp_stat][snp][gene][ + : + ].astype("float32") + + # write average score + for snp in snpgene_scores: + ensemble_h5[snp_stat].create_group(snp) + for gene in snpgene_scores[snp]: + ensemble_score = snpgene_scores[snp][gene] / num_folds + ensemble_h5[snp_stat][snp].create_dataset( + gene, data=ensemble_score.astype("float16") + ) - # normalize and downcast - sed_values /= num_folds - sed_values = sed_values.astype("float16") + ensemble_h5.close() - # save - ensemble_h5.create_dataset(sed_stat, data=sed_values) - ensemble_h5.close() +def nonzero_h5(h5_file: str, stat_keys): + """Verify the HDF5 exists, and there are nonzero values + for each stat key given. + + Args: + h5_file (str): HDF5 file name. + stat_keys ([str]): List of SNP stat keys. + """ + if os.path.isfile(h5_file): + try: + with h5py.File(h5_file, "r") as h5_open: + snps_all = set([snp.decode("UTF-8") for snp in h5_open["snp"]]) + for sk in stat_keys: + snps_stat = set(h5_open[sk].keys()) + snps_ovl = snps_all & snps_stat + if len(snps_ovl) == 0: + print(f"{h5_file}: {sk} empty.") + return False + else: + for snp in list(snps_ovl)[:5]: + for gene in h5_open[sk][snp].keys(): + score = h5_open[sk][snp][gene][:] + if score.var(dtype="float64") == 0: + print(f"{h5_file}: {sk} {snp} {gene} zero var.") + return False + return True + except: + print(f"{h5_file}: error", sys.exc_info()[0]) + return False + else: + return False -def split_scores(it_out_dir: str, posneg: str, vcf_dir: str, sed_stats): +def split_scores(it_out_dir: str, posneg: str, vcf_dir: str, snp_stats): """Split merged VCF predictions in HDF5 into tissue-specific predictions in HDF5. @@ -575,23 +595,14 @@ def split_scores(it_out_dir: str, posneg: str, vcf_dir: str, sed_stats): it_out_dir (str): output directory for iteration. posneg (str): 'pos' or 'neg'. vcf_dir (str): directory containing tissue-specific VCFs. - sed_stats ([str]]): list of SED stats. + snp_stats ([str]]): list of SED stats. 
""" - merge_h5_file = "%s/merge_%s/sed.h5" % (it_out_dir, posneg) + merge_h5_file = "%s/merge_%s/scores.h5" % (it_out_dir, posneg) merge_h5 = h5py.File(merge_h5_file, "r") - # read merged data - merge_si = merge_h5["si"][:] - merge_snps = [snp.decode("UTF-8") for snp in merge_h5["snp"]] - merge_gene = [gene.decode("UTF-8") for gene in merge_h5["gene"]] - merge_scores = {} - for ss in sed_stats: - merge_scores[ss] = merge_h5[ss][:] - - # hash snps to row indexes - snp_ri = {} - for ri, si in enumerate(merge_si): - snp_ri.setdefault(merge_snps[si], []).append(ri) + # hash scored SNPs + all_snps = set([snp.decode("UTF-8") for snp in merge_h5["snp"]]) + scored_snps = set([snp for snp in merge_h5[snp_stats[0]].keys()]) # for each tissue VCF vcf_glob = "%s/*_%s.vcf" % (vcf_dir, posneg) @@ -601,74 +612,64 @@ def split_scores(it_out_dir: str, posneg: str, vcf_dir: str, sed_stats): tissue_label = tissue_label.replace("_neg.vcf", "") # initialize HDF5 arrays - sed_snp = [] - sed_chr = [] - sed_pos = [] - sed_ref = [] - sed_alt = [] - sed_gene = [] - sed_snpi = [] - sed_scores = {} - for ss in sed_stats: - sed_scores[ss] = [] + snpg_snp = [] + snpg_chr = [] + snpg_pos = [] + snpg_ref = [] + snpg_alt = [] # fill HDF5 arrays with ordered SNPs - si = 0 for line in open(tissue_vcf_file): if not line.startswith("#"): a = line.split() chrm, pos, snp, ref, alt = a[:5] # SNPs w/o genes disappear - if snp in snp_ri: - sed_snp.append(snp) - sed_chr.append(chrm) - sed_pos.append(int(pos)) - sed_ref.append(ref) - sed_alt.append(alt) - - for ri in snp_ri[snp]: - sed_snpi.append(si) - sed_gene.append(merge_gene[ri]) - for ss in sed_stats: - sed_scores[ss].append(merge_scores[ss][ri]) - - si += 1 + if snp in all_snps: + snpg_snp.append(snp) + snpg_chr.append(chrm) + snpg_pos.append(int(pos)) + snpg_ref.append(ref) + snpg_alt.append(alt) # write tissue HDF5 tissue_dir = "%s/%s_%s" % (it_out_dir, tissue_label, posneg) os.makedirs(tissue_dir, exist_ok=True) - with h5py.File("%s/sed.h5" % tissue_dir, "w") as tissue_h5: + with h5py.File("%s/scores.h5" % tissue_dir, "w") as tissue_h5: # write SNPs - tissue_h5.create_dataset("snp", data=np.array(sed_snp, "S")) + tissue_h5.create_dataset("snp", data=np.array(snpg_snp, "S")) # write chr - tissue_h5.create_dataset("chr", data=np.array(sed_chr, "S")) + tissue_h5.create_dataset("chr", data=np.array(snpg_chr, "S")) # write SNP pos - tissue_h5.create_dataset("pos", data=np.array(sed_pos, dtype="uint32")) + tissue_h5.create_dataset("pos", data=np.array(snpg_pos, dtype="uint32")) # write ref allele - tissue_h5.create_dataset("ref_allele", data=np.array(sed_ref, dtype="S")) + tissue_h5.create_dataset("ref_allele", data=np.array(snpg_ref, dtype="S")) # write alt allele - tissue_h5.create_dataset("alt_allele", data=np.array(sed_alt, dtype="S")) - - # write SNP i - tissue_h5.create_dataset("si", data=np.array(sed_snpi)) - - # write gene - tissue_h5.create_dataset("gene", data=np.array(sed_gene, "S")) + tissue_h5.create_dataset("alt_allele", data=np.array(snpg_alt, dtype="S")) # write targets tissue_h5.create_dataset("target_ids", data=merge_h5["target_ids"]) tissue_h5.create_dataset("target_labels", data=merge_h5["target_labels"]) - # write sed stats - for ss in sed_stats: - tissue_h5.create_dataset( - ss, data=np.array(sed_scores[ss], dtype="float16") - ) + # write SNP stats + genes = set() + for ss in snp_stats: + tissue_h5.create_group(ss) + for snp in snpg_snp: + if snp in scored_snps: + tissue_h5[ss].create_group(snp) + for gene in merge_h5[ss][snp].keys(): + 
tissue_h5[ss][snp].create_dataset( + gene, data=merge_h5[ss][snp][gene][:] + ) + genes.add(gene) + + # write genes + tissue_h5.create_dataset("gene", data=np.array(sorted(genes), "S")) merge_h5.close() diff --git a/src/westminster/scripts/westminster_gtexsed_folds.py b/src/westminster/scripts/westminster_gtexsed_folds.py new file mode 100755 index 0000000..0f02186 --- /dev/null +++ b/src/westminster/scripts/westminster_gtexsed_folds.py @@ -0,0 +1,680 @@ +#!/usr/bin/env python +# Copyright 2023 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser, OptionGroup +import glob +import pickle +import pdb +import os + +import h5py +import numpy as np + +import slurm + +from westminster.multi import nonzero_h5 + +""" +westminster_gtexg_folds.py + +Benchmark Baskerville model replicates on GTEx eQTL coefficient task. +""" + + +################################################################################ +# main +################################################################################ +def main(): + usage = "usage: %prog [options] " + parser = OptionParser(usage) + + # sed options + sed_options = OptionGroup(parser, "borzoi_sed.py options") + sed_options.add_option( + "-b", + dest="bedgraph", + default=False, + action="store_true", + help="Write ref/alt predictions as bedgraph [Default: %default]", + ) + sed_options.add_option( + "-f", + dest="genome_fasta", + default="%s/data/hg38.fa" % os.environ["BASENJIDIR"], + help="Genome FASTA for sequences [Default: %default]", + ) + sed_options.add_option( + "-g", + dest="genes_gtf", + default="%s/genes/gencode41/gencode41_basic_nort.gtf" % os.environ["HG38"], + help="GTF for gene definition [Default %default]", + ) + sed_options.add_option( + "-o", + dest="out_dir", + default="sed", + help="Output directory for tables and plots [Default: %default]", + ) + sed_options.add_option( + "--rc", + dest="rc", + default=False, + action="store_true", + help="Average forward and reverse complement predictions [Default: %default]", + ) + sed_options.add_option( + "--shifts", + dest="shifts", + default="0", + type="str", + help="Ensemble prediction shifts [Default: %default]", + ) + sed_options.add_option( + "--span", + dest="span", + default=False, + action="store_true", + help="Aggregate entire gene span [Default: %default]", + ) + sed_options.add_option( + "--stats", + dest="sed_stats", + default="SED", + help="Comma-separated list of stats to save. 
[Default: %default]", + ) + sed_options.add_option( + "-t", + dest="targets_file", + default=None, + type="str", + help="File specifying target indexes and labels in table format", + ) + sed_options.add_option( + "-u", dest="untransform_old", default=False, action="store_true" + ) + parser.add_option_group(sed_options) + + # cross-fold + fold_options = OptionGroup(parser, "cross-fold options") + fold_options.add_option( + "-c", + dest="crosses", + default=1, + type="int", + help="Number of cross-fold rounds [Default:%default]", + ) + fold_options.add_option( + "-d", + dest="data_head", + default=None, + type="int", + help="Index for dataset/head [Default: %default]", + ) + fold_options.add_option( + "-e", + dest="conda_env", + default="tf210", + help="Anaconda environment [Default: %default]", + ) + fold_options.add_option( + "--gtex", + dest="gtex_vcf_dir", + default="/home/drk/seqnn/data/gtex_fine/susie_pip90", + ) + fold_options.add_option( + "--name", + dest="name", + default="gtex", + help="SLURM name prefix [Default: %default]", + ) + fold_options.add_option( + "--max_proc", + dest="max_proc", + default=None, + type="int", + help="Maximum concurrent processes [Default: %default]", + ) + fold_options.add_option( + "-p", + dest="processes", + default=None, + type="int", + help="Number of processes, passed by multi script. \ + (Unused, but needs to appear as dummy.)", + ) + fold_options.add_option( + "-q", + dest="queue", + default="geforce", + help="SLURM queue on which to run the jobs [Default: %default]", + ) + parser.add_option_group(fold_options) + + (options, args) = parser.parse_args() + + if len(args) != 2: + parser.error("Must provide parameters file and cross-fold directory") + else: + params_file = args[0] + exp_dir = args[1] + + ####################################################### + # prep work + + # count folds + num_folds = 0 + fold0_dir = "%s/f%dc0" % (exp_dir, num_folds) + model_file = "%s/train/model_best.h5" % fold0_dir + if options.data_head is not None: + model_file = "%s/train/model%d_best.h5" % (fold0_dir, options.data_head) + while os.path.isfile(model_file): + num_folds += 1 + fold0_dir = "%s/f%dc0" % (exp_dir, num_folds) + model_file = "%s/train/model_best.h5" % fold0_dir + if options.data_head is not None: + model_file = "%s/train/model%d_best.h5" % (fold0_dir, options.data_head) + print("Found %d folds" % num_folds) + if num_folds == 0: + exit(1) + + # extract output subdirectory name + gtex_out_dir = options.out_dir + + # split SNP stats + sed_stats = options.sed_stats.split(",") + + # merge study/tissue variants + mpos_vcf_file = "%s/pos_merge.vcf" % options.gtex_vcf_dir + mneg_vcf_file = "%s/neg_merge.vcf" % options.gtex_vcf_dir + + ################################################################ + # SED + + # SED command base + cmd_base = ". 
/home/drk/anaconda3/etc/profile.d/conda.sh;" + cmd_base += " conda activate %s;" % options.conda_env + cmd_base += " echo $HOSTNAME;" + + jobs = [] + + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = "%s/f%dc%d" % (exp_dir, fi, ci) + name = "%s-f%dc%d" % (options.name, fi, ci) + + # update output directory + it_out_dir = "%s/%s" % (it_dir, gtex_out_dir) + os.makedirs(it_out_dir, exist_ok=True) + + # choose model + model_file = "%s/train/model_best.h5" % it_dir + if options.data_head is not None: + model_file = "%s/train/model%d_best.h5" % (it_dir, options.data_head) + + ######################################## + # negative jobs + + # pickle options + options.out_dir = "%s/merge_neg" % it_out_dir + os.makedirs(options.out_dir, exist_ok=True) + options_pkl_file = "%s/options.pkl" % options.out_dir + options_pkl = open(options_pkl_file, "wb") + pickle.dump(options, options_pkl) + options_pkl.close() + + # create base fold command + cmd_fold = "%s time borzoi_sed.py %s %s %s" % ( + cmd_base, + options_pkl_file, + params_file, + model_file, + ) + + for pi in range(options.processes): + sed_file = "%s/job%d/sed.h5" % (options.out_dir, pi) + if not nonzero_h5(sed_file, sed_stats): + cmd_job = "%s %s %d" % (cmd_fold, mneg_vcf_file, pi) + j = slurm.Job( + cmd_job, + "%s_neg%d" % (name, pi), + "%s/job%d.out" % (options.out_dir, pi), + "%s/job%d.err" % (options.out_dir, pi), + "%s/job%d.sb" % (options.out_dir, pi), + queue=options.queue, + gpu=1, + cpu=2, + mem=30000, + time="7-0:0:0", + ) + jobs.append(j) + + ######################################## + # positive jobs + + # pickle options + options.out_dir = "%s/merge_pos" % it_out_dir + os.makedirs(options.out_dir, exist_ok=True) + options_pkl_file = "%s/options.pkl" % options.out_dir + options_pkl = open(options_pkl_file, "wb") + pickle.dump(options, options_pkl) + options_pkl.close() + + # create base fold command + cmd_fold = "%s time borzoi_sed.py %s %s %s" % ( + cmd_base, + options_pkl_file, + params_file, + model_file, + ) + + for pi in range(options.processes): + sed_file = "%s/job%d/sed.h5" % (options.out_dir, pi) + if not nonzero_h5(sed_file, sed_stats): + cmd_job = "%s %s %d" % (cmd_fold, mpos_vcf_file, pi) + j = slurm.Job( + cmd_job, + "%s_pos%d" % (name, pi), + "%s/job%d.out" % (options.out_dir, pi), + "%s/job%d.err" % (options.out_dir, pi), + "%s/job%d.sb" % (options.out_dir, pi), + queue=options.queue, + gpu=1, + cpu=2, + mem=30000, + time="7-0:0:0", + ) + jobs.append(j) + + slurm.multi_run( + jobs, max_proc=options.max_proc, verbose=True, launch_sleep=10, update_sleep=60 + ) + + ####################################################### + # collect output + + for ci in range(options.crosses): + for fi in range(num_folds): + it_out_dir = "%s/f%dc%d/%s" % (exp_dir, fi, ci, gtex_out_dir) + + # collect negatives + neg_out_dir = "%s/merge_neg" % it_out_dir + if not os.path.isfile("%s/sed.h5" % neg_out_dir): + collect_scores(neg_out_dir, options.processes, "sed.h5") + + # collect positives + pos_out_dir = "%s/merge_pos" % it_out_dir + if not os.path.isfile("%s/sed.h5" % pos_out_dir): + collect_scores(pos_out_dir, options.processes, "sed.h5") + + ################################################################ + # split study/tissue variants + + for ci in range(options.crosses): + for fi in range(num_folds): + it_out_dir = "%s/f%dc%d/%s" % (exp_dir, fi, ci, gtex_out_dir) + print(it_out_dir) + + # split positives + split_scores(it_out_dir, "pos", options.gtex_vcf_dir, sed_stats) + + # split negatives + 
split_scores(it_out_dir, "neg", options.gtex_vcf_dir, sed_stats) + + ################################################################ + # ensemble + + ensemble_dir = "%s/ensemble" % exp_dir + if not os.path.isdir(ensemble_dir): + os.mkdir(ensemble_dir) + + gtex_dir = "%s/%s" % (ensemble_dir, gtex_out_dir) + if not os.path.isdir(gtex_dir): + os.mkdir(gtex_dir) + + for gtex_pos_vcf in glob.glob("%s/*_pos.vcf" % options.gtex_vcf_dir): + gtex_neg_vcf = gtex_pos_vcf.replace("_pos.", "_neg.") + pos_base = os.path.splitext(os.path.split(gtex_pos_vcf)[1])[0] + neg_base = os.path.splitext(os.path.split(gtex_neg_vcf)[1])[0] + + # collect SED files + sed_pos_files = [] + sed_neg_files = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = "%s/f%dc%d" % (exp_dir, fi, ci) + it_out_dir = "%s/%s" % (it_dir, gtex_out_dir) + + sed_pos_file = "%s/%s/sed.h5" % (it_out_dir, pos_base) + sed_pos_files.append(sed_pos_file) + + sed_neg_file = "%s/%s/sed.h5" % (it_out_dir, neg_base) + sed_neg_files.append(sed_neg_file) + + # ensemble + ens_pos_dir = "%s/%s" % (gtex_dir, pos_base) + os.makedirs(ens_pos_dir, exist_ok=True) + ens_pos_file = "%s/sed.h5" % (ens_pos_dir) + if not os.path.isfile(ens_pos_file): + ensemble_h5(ens_pos_file, sed_pos_files, sed_stats) + + ens_neg_dir = "%s/%s" % (gtex_dir, neg_base) + os.makedirs(ens_neg_dir, exist_ok=True) + ens_neg_file = "%s/sed.h5" % (ens_neg_dir) + if not os.path.isfile(ens_neg_file): + ensemble_h5(ens_neg_file, sed_neg_files, sed_stats) + + ################################################################ + # coefficient analysis + + cmd_base = "westminster_gtexg_coef.py -g %s" % options.gtex_vcf_dir + + jobs = [] + for ci in range(options.crosses): + for fi in range(num_folds): + it_dir = "%s/f%dc%d" % (exp_dir, fi, ci) + it_out_dir = "%s/%s" % (it_dir, gtex_out_dir) + + for sed_stat in sed_stats: + coef_out_dir = f"{it_out_dir}/coef-{sed_stat}" + cmd_coef = f"{cmd_base} -o {coef_out_dir} -s {sed_stat} {it_out_dir}" + j = slurm.Job( + cmd_coef, + "coef", + f"{coef_out_dir}.out", + f"{coef_out_dir}.err", + queue="standard", + cpu=2, + mem=22000, + time="12:0:0", + ) + jobs.append(j) + + # ensemble + it_out_dir = f"{exp_dir}/ensemble/{gtex_out_dir}" + for sed_stat in sed_stats: + coef_out_dir = f"{it_out_dir}/coef-{sed_stat}" + cmd_coef = f"{cmd_base} -o {coef_out_dir} -s {sed_stat} {it_out_dir}" + j = slurm.Job( + cmd_coef, + "coef", + f"{coef_out_dir}.out", + f"{coef_out_dir}.err", + queue="standard", + cpu=2, + mem=22000, + time="12:0:0", + ) + jobs.append(j) + + slurm.multi_run(jobs, verbose=True) + + +def collect_scores(out_dir: str, num_jobs: int, h5f_name: str = "sad.h5"): + """Collect parallel SAD jobs' output into one HDF5. + + Args: + out_dir (str): Output directory. + num_jobs (int): Number of jobs to combine results from. 
+ """ + # count variants + num_variants = 0 + num_rows = 0 + for pi in range(num_jobs): + # open job + job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5f_name) + job_h5_open = h5py.File(job_h5_file, "r") + num_variants += len(job_h5_open["snp"]) + num_rows += len(job_h5_open["si"]) + job_h5_open.close() + + # initialize final h5 + final_h5_file = "%s/%s" % (out_dir, h5f_name) + final_h5_open = h5py.File(final_h5_file, "w") + + # SNP stats + snp_stats = {} + + job0_h5_file = "%s/job0/%s" % (out_dir, h5f_name) + job0_h5_open = h5py.File(job0_h5_file, "r") + for key in job0_h5_open.keys(): + if key in ["target_ids", "target_labels"]: + # copy + final_h5_open.create_dataset(key, data=job0_h5_open[key]) + + elif key in ["snp", "chr", "pos", "ref_allele", "alt_allele", "gene"]: + snp_stats[key] = [] + + elif job0_h5_open[key].ndim == 1: + final_h5_open.create_dataset( + key, shape=(num_rows,), dtype=job0_h5_open[key].dtype + ) + + else: + num_targets = job0_h5_open[key].shape[1] + final_h5_open.create_dataset( + key, shape=(num_rows, num_targets), dtype=job0_h5_open[key].dtype + ) + + job0_h5_open.close() + + # set values + vgi = 0 + vi = 0 + for pi in range(num_jobs): + # open job + job_h5_file = "%s/job%d/%s" % (out_dir, pi, h5f_name) + with h5py.File(job_h5_file, "r") as job_h5_open: + job_snps = len(job_h5_open["snp"]) + job_rows = job_h5_open["si"].shape[0] + + # append to final + for key in job_h5_open.keys(): + try: + if key in ["target_ids", "target_labels"]: + # once is enough + pass + + elif key in [ + "snp", + "chr", + "pos", + "ref_allele", + "alt_allele", + "gene", + ]: + snp_stats[key] += list(job_h5_open[key]) + + elif key == "si": + # re-index SNPs + final_h5_open[key][vgi : vgi + job_rows] = ( + job_h5_open[key][:] + vi + ) + + else: + final_h5_open[key][vgi : vgi + job_rows] = job_h5_open[key] + + except TypeError as e: + print(e) + print( + f"{job_h5_file} {key} has the wrong shape. Remove this file and rerun" + ) + exit() + + vgi += job_rows + vi += job_snps + + # create final SNP stat datasets + for key in snp_stats: + if key == "pos": + final_h5_open.create_dataset(key, data=np.array(snp_stats[key])) + else: + final_h5_open.create_dataset(key, data=np.array(snp_stats[key], dtype="S")) + + final_h5_open.close() + + +def ensemble_h5(ensemble_h5_file: str, scores_files: list, sed_stats: list): + """Ensemble scores from multiple files into a single file. + + Args: + ensemble_h5_file (str): ensemble score HDF5. + scores_files ([str]): list of replicate score HDFs. + sed_stats ([str]): SED stats to average over folds. 
+ """ + # open ensemble + ensemble_h5 = h5py.File(ensemble_h5_file, "w") + + # transfer non-SED keys + sed_shapes = {} + scores0_h5 = h5py.File(scores_files[0], "r") + for key in scores0_h5.keys(): + if key not in sed_stats: + ensemble_h5.create_dataset(key, data=scores0_h5[key]) + else: + sed_shapes[key] = scores0_h5[key].shape + scores0_h5.close() + + # average stats + num_folds = len(scores_files) + for sed_stat in sed_stats: + # initialize ensemble array + sed_values = np.zeros(shape=sed_shapes[sed_stat], dtype="float32") + + # read and add folds + for scores_file in scores_files: + with h5py.File(scores_file, "r") as scores_h5: + sed_values += scores_h5[sed_stat][:].astype("float32") + + # normalize and downcast + sed_values /= num_folds + sed_values = sed_values.astype("float16") + + # save + ensemble_h5.create_dataset(sed_stat, data=sed_values) + + ensemble_h5.close() + + +def split_scores(it_out_dir: str, posneg: str, vcf_dir: str, sed_stats): + """Split merged VCF predictions in HDF5 into tissue-specific + predictions in HDF5. + + Args: + it_out_dir (str): output directory for iteration. + posneg (str): 'pos' or 'neg'. + vcf_dir (str): directory containing tissue-specific VCFs. + sed_stats ([str]]): list of SED stats. + """ + merge_h5_file = "%s/merge_%s/sed.h5" % (it_out_dir, posneg) + merge_h5 = h5py.File(merge_h5_file, "r") + + # read merged data + merge_si = merge_h5["si"][:] + merge_snps = [snp.decode("UTF-8") for snp in merge_h5["snp"]] + merge_gene = [gene.decode("UTF-8") for gene in merge_h5["gene"]] + merge_scores = {} + for ss in sed_stats: + merge_scores[ss] = merge_h5[ss][:] + + # hash snps to row indexes + snp_ri = {} + for ri, si in enumerate(merge_si): + snp_ri.setdefault(merge_snps[si], []).append(ri) + + # for each tissue VCF + vcf_glob = "%s/*_%s.vcf" % (vcf_dir, posneg) + for tissue_vcf_file in glob.glob(vcf_glob): + tissue_label = tissue_vcf_file.split("/")[-1] + tissue_label = tissue_label.replace("_pos.vcf", "") + tissue_label = tissue_label.replace("_neg.vcf", "") + + # initialize HDF5 arrays + sed_snp = [] + sed_chr = [] + sed_pos = [] + sed_ref = [] + sed_alt = [] + sed_gene = [] + sed_snpi = [] + sed_scores = {} + for ss in sed_stats: + sed_scores[ss] = [] + + # fill HDF5 arrays with ordered SNPs + si = 0 + for line in open(tissue_vcf_file): + if not line.startswith("#"): + a = line.split() + chrm, pos, snp, ref, alt = a[:5] + + # SNPs w/o genes disappear + if snp in snp_ri: + sed_snp.append(snp) + sed_chr.append(chrm) + sed_pos.append(int(pos)) + sed_ref.append(ref) + sed_alt.append(alt) + + for ri in snp_ri[snp]: + sed_snpi.append(si) + sed_gene.append(merge_gene[ri]) + for ss in sed_stats: + sed_scores[ss].append(merge_scores[ss][ri]) + + si += 1 + + # write tissue HDF5 + tissue_dir = "%s/%s_%s" % (it_out_dir, tissue_label, posneg) + os.makedirs(tissue_dir, exist_ok=True) + with h5py.File("%s/sed.h5" % tissue_dir, "w") as tissue_h5: + # write SNPs + tissue_h5.create_dataset("snp", data=np.array(sed_snp, "S")) + + # write chr + tissue_h5.create_dataset("chr", data=np.array(sed_chr, "S")) + + # write SNP pos + tissue_h5.create_dataset("pos", data=np.array(sed_pos, dtype="uint32")) + + # write ref allele + tissue_h5.create_dataset("ref_allele", data=np.array(sed_ref, dtype="S")) + + # write alt allele + tissue_h5.create_dataset("alt_allele", data=np.array(sed_alt, dtype="S")) + + # write SNP i + tissue_h5.create_dataset("si", data=np.array(sed_snpi)) + + # write gene + tissue_h5.create_dataset("gene", data=np.array(sed_gene, "S")) + + # write targets + 
tissue_h5.create_dataset("target_ids", data=merge_h5["target_ids"]) + tissue_h5.create_dataset("target_labels", data=merge_h5["target_labels"]) + + # write sed stats + for ss in sed_stats: + tissue_h5.create_dataset( + ss, data=np.array(sed_scores[ss], dtype="float16") + ) + + merge_h5.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == "__main__": + main()