diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index 93be0e5..32b5118 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.9", "3.10"] steps: - name: Checkout base repo diff --git a/pyproject.toml b/pyproject.toml index b98553e..a8f08d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,25 +1,53 @@ [build-system] -requires = [ - "setuptools>=45", - "wheel", - "setuptools_scm>=6.2" -] +requires = ["setuptools>=69.0.3", "setuptools_scm>=8.0.4"] build-backend = "setuptools.build_meta" [project] name = "baskerville" +description = "baskerville" authors = [ - {name = "Calico f(DNA)", email = "drk@calicolabs.com"} + {name = "Calico f(DNA)", email = "drk@calicolabs.com"}, ] readme = "README.md" -requires-python = ">=3.8, <3.11" -classifiers = ["License :: OSI Approved :: Apache License"] -dynamic = ["version", "description", "dependencies"] +classifiers = ["License :: OSI Approved :: MIT License"] +dynamic = ["version"] + +requires-python = ">=3.9" +dependencies = [ + "h5py~=3.10.0", + "intervaltree~=3.1.0", + "joblib~=1.1.1", + "matplotlib~=3.7.1", + "google-cloud-storage~=2.0.0", + "natsort~=7.1.1", + "networkx~=2.8.4", + "numpy~=1.24.3", + "pandas~=1.5.3", + "pybigwig~=0.3.18", + "pybedtools~=0.10.0", + "pysam~=0.22.0", + "qnorm~=0.8.1", + "seaborn~=0.12.2", + "scikit-learn~=1.2.2", + "scipy~=1.9.1", + "statsmodels~=0.13.5", + "tabulate~=0.8.10", + "tensorflow~=2.15.0", + "tqdm~=4.65.0", +] [project.optional-dependencies] dev = [ - "black==22.3.0", - "pytest==7.1.2" + "black~=23.12.1", + "pytest~=7.4.4", + "ruff~=0.1.11", ] -[tool.setuptools_scm] +gpu = [ + "tensorrt==8.6.1" +] + +[project.urls] +Homepage = "https://github.com/calico/baskerville" + +[tool.setuptools_scm] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 0f1e198..0000000 --- a/setup.cfg +++ /dev/null @@ -1,49 +0,0 @@ -[metadata] -name = baskerville -author = David Kelley -author_email = drk@calicolabs.com -description = Machine learning methods for DNA sequence analysis. -long_description = file: README.md -long_description_content_type = text/markdown -url = https://github.com/calico/baskerville -project_urls = - Bug Tracker = https://github.com/calico/baskerville/issues -classifiers = - Programming Language :: Python :: 3 - License :: OSI Approved :: Apache License - Operating System :: OS Independent - -[options] -package_dir = - = src -packages = find: -python_requires = >=3.8, <3.11 -install_requires = - h5py>=3.7.0 - intervaltree>=3.1.0 - joblib>=1.1.1 - matplotlib>=3.7.1 - google-cloud-storage>=2.0.0 - natsort>=7.1.1 - networkx>=2.8.4 - numpy>=1.24.3 - pandas>=1.5.3 - pybigwig>=0.3.18 - pysam>=0.21.0 - pybedtools>=0.9.0 - qnorm>=0.8.1 - seaborn>=0.12.2 - scikit-learn>=1.2.2 - scipy>=1.9.1 - statsmodels>=0.13.5 - tabulate>=0.8.10 - tensorflow>=2.12.0 - tqdm>=4.65.0 - -[options.extras_require] -dev = - black>=22.3.0 - pytest>=7.1.2 - -[options.packages.find] -where = src diff --git a/src/baskerville/data.py b/src/baskerville/data.py new file mode 100644 index 0000000..03fc4d0 --- /dev/null +++ b/src/baskerville/data.py @@ -0,0 +1,322 @@ +# Copyright 2023 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +import collections +import heapq +import math +import subprocess +import sys +import tempfile + +import numpy as np +import pysam + +""" +data.py + +Helper methods for hound_data* +""" + + +def annotate_unmap(mseqs, unmap_bed, seq_length, pool_width): + """Intersect the sequence segments with unmappable regions + and annotate the segments as NaN to possibly be ignored. + + Args: + mseqs: list of ModelSeq's + unmap_bed: unmappable regions BED file + seq_length: sequence length (after cropping) + pool_width: pooled bin width + + Returns: + seqs_unmap: NxL binary NA indicators + """ + + # print sequence segments to file + seqs_temp = tempfile.NamedTemporaryFile() + seqs_bed_file = seqs_temp.name + write_seqs_bed(seqs_bed_file, mseqs) + + # hash segments to indexes + chr_start_indexes = {} + for i in range(len(mseqs)): + chr_start_indexes[(mseqs[i].chr, mseqs[i].start)] = i + + # initialize unmappable array + pool_seq_length = seq_length // pool_width + seqs_unmap = np.zeros((len(mseqs), pool_seq_length), dtype="bool") + + # intersect with unmappable regions + p = subprocess.Popen( + "bedtools intersect -wo -a %s -b %s" % (seqs_bed_file, unmap_bed), + shell=True, + stdout=subprocess.PIPE, + ) + for line in p.stdout: + line = line.decode("utf-8") + a = line.split() + + seq_chrom = a[0] + seq_start = int(a[1]) + seq_end = int(a[2]) + seq_key = (seq_chrom, seq_start) + + unmap_start = int(a[4]) + unmap_end = int(a[5]) + + overlap_start = max(seq_start, unmap_start) + overlap_end = min(seq_end, unmap_end) + + pool_seq_unmap_start = math.floor((overlap_start - seq_start) / pool_width) + pool_seq_unmap_end = math.ceil((overlap_end - seq_start) / pool_width) + + # skip minor overlaps to the first + first_start = seq_start + pool_seq_unmap_start * pool_width + first_end = first_start + pool_width + first_overlap = first_end - overlap_start + if first_overlap < 0.1 * pool_width: + pool_seq_unmap_start += 1 + + # skip minor overlaps to the last + last_start = seq_start + (pool_seq_unmap_end - 1) * pool_width + last_overlap = overlap_end - last_start + if last_overlap < 0.1 * pool_width: + pool_seq_unmap_end -= 1 + + seqs_unmap[ + chr_start_indexes[seq_key], pool_seq_unmap_start:pool_seq_unmap_end + ] = True + assert ( + seqs_unmap[ + chr_start_indexes[seq_key], pool_seq_unmap_start:pool_seq_unmap_end + ].sum() + == pool_seq_unmap_end - pool_seq_unmap_start + ) + + return seqs_unmap + + +################################################################################ +def break_large_contigs(contigs, break_t, verbose=False): + """Break large contigs in half until all contigs are under + the size threshold.""" + + # initialize a heapq of contigs and lengths + contig_heapq = [] + for ctg in contigs: + ctg_len = ctg.end - ctg.start + heapq.heappush(contig_heapq, (-ctg_len, ctg)) + + ctg_len = break_t + 1 + while ctg_len > break_t: + + # pop largest contig + ctg_nlen, ctg = heapq.heappop(contig_heapq) + ctg_len = -ctg_nlen + + # if too large + if ctg_len > break_t: + if verbose: + print( + "Breaking %s:%d-%d (%d nt)" % 
(ctg.chr, ctg.start, ctg.end, ctg_len) + ) + + # break in two + ctg_mid = ctg.start + ctg_len // 2 + ctg_left = Contig(ctg.genome, ctg.chr, ctg.start, ctg_mid) + ctg_right = Contig(ctg.genome, ctg.chr, ctg_mid, ctg.end) + + # add left + ctg_left_len = ctg_left.end - ctg_left.start + heapq.heappush(contig_heapq, (-ctg_left_len, ctg_left)) + + # add right + ctg_right_len = ctg_right.end - ctg_right.start + heapq.heappush(contig_heapq, (-ctg_right_len, ctg_right)) + + # return to list + contigs = [len_ctg[1] for len_ctg in contig_heapq] + + return contigs + + +def contig_sequences(contigs, seq_length, stride, snap=1, label=None): + """Break up a list of Contig's into a list of model length + and stride sequence contigs.""" + mseqs = [] + + for ctg in contigs: + seq_start = int(np.ceil(ctg.start / snap) * snap) + seq_end = seq_start + seq_length + + while seq_end < ctg.end: + # record sequence + mseqs.append(ModelSeq(ctg.genome, ctg.chr, seq_start, seq_end, label)) + + # update + seq_start += stride + seq_end += stride + + return mseqs + + +def load_chromosomes(genome_file): + """Load genome segments from either a FASTA file or + chromosome length table.""" + + # is genome_file FASTA or (chrom,start,end) table? + file_fasta = open(genome_file).readline()[0] == ">" + + chrom_segments = {} + + if file_fasta: + fasta_open = pysam.Fastafile(genome_file) + for i in range(len(fasta_open.references)): + chrom_segments[fasta_open.references[i]] = [(0, fasta_open.lengths[i])] + fasta_open.close() + + else: + for line in open(genome_file): + a = line.split() + chrom_segments[a[0]] = [(0, int(a[1]))] + + return chrom_segments + + +def rejoin_large_contigs(contigs): + """Rejoin large contigs that were broken up before alignment comparison.""" + + # split list by genome/chromosome + gchr_contigs = {} + for ctg in contigs: + gchr = (ctg.genome, ctg.chr) + gchr_contigs.setdefault(gchr, []).append(ctg) + + contigs = [] + for gchr in gchr_contigs: + # sort within chromosome + gchr_contigs[gchr].sort(key=lambda x: x.start) + # gchr_contigs[gchr] = sorted(gchr_contigs[gchr], key=lambda ctg: ctg.start) + + ctg_ongoing = gchr_contigs[gchr][0] + for i in range(1, len(gchr_contigs[gchr])): + ctg_this = gchr_contigs[gchr][i] + if ctg_ongoing.end == ctg_this.start: + # join + # ctg_ongoing.end = ctg_this.end + ctg_ongoing = ctg_ongoing._replace(end=ctg_this.end) + else: + # conclude ongoing + contigs.append(ctg_ongoing) + + # move to next + ctg_ongoing = ctg_this + + # conclude final + contigs.append(ctg_ongoing) + + return contigs + + +def split_contigs(chrom_segments, gaps_file): + """Split the assembly up into contigs defined by the gaps. + + Args: + chrom_segments: dict mapping chromosome names to lists of (start,end) + gaps_file: file specifying assembly gaps + + Returns: + chrom_segments: same, with segments broken by the assembly gaps. 
+ """ + + chrom_events = {} + + # add known segments + for chrom in chrom_segments: + if len(chrom_segments[chrom]) > 1: + print( + "I've made a terrible mistake...regarding the length of chrom_segments[%s]" + % chrom, + file=sys.stderr, + ) + exit(1) + cstart, cend = chrom_segments[chrom][0] + chrom_events.setdefault(chrom, []).append((cstart, "Cstart")) + chrom_events[chrom].append((cend, "cend")) + + # add gaps + for line in open(gaps_file): + a = line.split() + chrom = a[0] + gstart = int(a[1]) + gend = int(a[2]) + + # consider only if its in our genome + if chrom in chrom_events: + chrom_events[chrom].append((gstart, "gstart")) + chrom_events[chrom].append((gend, "Gend")) + + for chrom in chrom_events: + # sort + chrom_events[chrom].sort() + + # read out segments + chrom_segments[chrom] = [] + for i in range(len(chrom_events[chrom]) - 1): + pos1, event1 = chrom_events[chrom][i] + pos2, event2 = chrom_events[chrom][i + 1] + + event1 = event1.lower() + event2 = event2.lower() + + shipit = False + if event1 == "cstart" and event2 == "cend": + shipit = True + elif event1 == "cstart" and event2 == "gstart": + shipit = True + elif event1 == "gend" and event2 == "gstart": + shipit = True + elif event1 == "gend" and event2 == "cend": + shipit = True + elif event1 == "gstart" and event2 == "gend": + pass + else: + print( + "I'm confused by this event ordering: %s - %s" % (event1, event2), + file=sys.stderr, + ) + exit(1) + + if shipit and pos1 < pos2: + chrom_segments[chrom].append((pos1, pos2)) + + return chrom_segments + + +def write_seqs_bed(bed_file, seqs, labels=False): + """Write sequences to BED file.""" + bed_out = open(bed_file, "w") + for i in range(len(seqs)): + line = "%s\t%d\t%d" % (seqs[i].chr, seqs[i].start, seqs[i].end) + if labels: + line += "\t%s" % seqs[i].label + print(line, file=bed_out) + bed_out.close() + + +################################################################################ +Contig = collections.namedtuple("Contig", ["genome", "chr", "start", "end"]) +ModelSeq = collections.namedtuple( + "ModelSeq", ["genome", "chr", "start", "end", "label"] +) diff --git a/src/baskerville/dataset.py b/src/baskerville/dataset.py index f061b97..c8360c3 100644 --- a/src/baskerville/dataset.py +++ b/src/baskerville/dataset.py @@ -319,7 +319,7 @@ def make_strand_transform(targets_df, targets_strand_df): targets_strand_df (pd.DataFrame): Targets DataFrame, with strand pairs collapsed. Returns: - scipy.sparse.csr_matrix: Sparse matrix to sum strand pairs. + scipy.sparse.dok_matrix: Sparse matrix to sum strand pairs. """ # initialize sparse matrix @@ -336,7 +336,6 @@ def make_strand_transform(targets_df, targets_strand_df): if target.identifier[-1] == "-": sti += 1 ti += 1 - strand_transform = strand_transform.tocsr() return strand_transform @@ -367,7 +366,7 @@ def targets_prep_strand(targets_df): return targets_strand_df -def untransform_preds(preds, targets_df, unscale=False): +def untransform_preds(preds, targets_df, unscale=False, unclip=True): """Undo the squashing transformations performed for the tasks. Args: @@ -378,9 +377,10 @@ def untransform_preds(preds, targets_df, unscale=False): preds (np.array): Untransformed predictions LxT. 
""" # clip soft - cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) - preds_unclip = cs - 1 + (preds - cs + 1) ** 2 - preds = np.where(preds > cs, preds_unclip, preds) + if unclip: + cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) + preds_unclip = cs - 1 + (preds - cs + 1) ** 2 + preds = np.where(preds > cs, preds_unclip, preds) # sqrt sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat]) @@ -394,7 +394,7 @@ def untransform_preds(preds, targets_df, unscale=False): return preds -def untransform_preds1(preds, targets_df, unscale=False): +def untransform_preds1(preds, targets_df, unscale=False, unclip=True): """Undo the squashing transformations performed for the tasks. Args: @@ -409,9 +409,10 @@ def untransform_preds1(preds, targets_df, unscale=False): preds = preds / scale # clip soft - cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) - preds_unclip = cs + (preds - cs) ** 2 - preds = np.where(preds > cs, preds_unclip, preds) + if unclip: + cs = np.expand_dims(np.array(targets_df.clip_soft), axis=0) + preds_unclip = cs + (preds - cs) ** 2 + preds = np.where(preds > cs, preds_unclip, preds) # ** 0.75 sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat]) diff --git a/src/baskerville/gene.py b/src/baskerville/gene.py index 759bf0e..0645385 100644 --- a/src/baskerville/gene.py +++ b/src/baskerville/gene.py @@ -13,6 +13,7 @@ # limitations under the License. # ========================================================================= +import gzip from intervaltree import IntervalTree import numpy as np import pybedtools @@ -77,9 +78,19 @@ def span(self): exon_ends = [exon.end for exon in self.exons] return min(exon_starts), max(exon_ends) - def output_slice(self, seq_start, seq_len, model_stride, span=False): + def output_slice( + self, seq_start, seq_len, model_stride, span=False, majority_overlap=False + ): gene_slice = [] + def clip_boundaries(slice_start, slice_end): + slice_max = int(seq_len / model_stride) + slice_start = min(slice_start, slice_max) + slice_end = min(slice_end, slice_max) + slice_start = max(slice_start, 0) + slice_end = max(slice_end, 0) + return slice_start, slice_end + if span: gene_start, gene_end = self.span() @@ -91,12 +102,12 @@ def output_slice(self, seq_start, seq_len, model_stride, span=False): slice_start = int(np.round(gene_seq_start / model_stride)) slice_end = int(np.round(gene_seq_end / model_stride)) - # clip right boundaries - slice_max = int(seq_len / model_stride) - slice_start = min(slice_start, slice_max) - slice_end = min(slice_end, slice_max) + # clip boundaries + slice_start, slice_end = clip_boundaries(slice_start, slice_end) - gene_slice = range(slice_start, slice_end) + # add to gene slice + if slice_start < slice_end: + gene_slice = range(slice_start, slice_end) else: for exon in self.get_exons(): @@ -104,18 +115,26 @@ def output_slice(self, seq_start, seq_len, model_stride, span=False): exon_seq_start = max(0, exon.begin - seq_start) exon_seq_end = max(0, exon.end - seq_start) - # requires >50% overlap - slice_start = int(np.round(exon_seq_start / model_stride)) - slice_end = int(np.round(exon_seq_end / model_stride)) + if majority_overlap: + # requires >50% overlap + slice_start = int(np.round(exon_seq_start / model_stride)) + slice_end = int(np.round(exon_seq_end / model_stride)) + else: + # any overlap + slice_start = int(np.floor(exon_seq_start / model_stride)) + slice_end = int(np.ceil(exon_seq_end / model_stride)) + + # clip boundaries + slice_start, slice_end = 
clip_boundaries(slice_start, slice_end) - # clip right boundaries - slice_max = int(seq_len / model_stride) - slice_start = min(slice_start, slice_max) - slice_end = min(slice_end, slice_max) + # add to gene slice + if slice_start < slice_end: + gene_slice.extend(range(slice_start, slice_end)) - gene_slice.extend(range(slice_start, slice_end)) + # collapse overlaps + gene_slice = np.unique(gene_slice) - return np.array(gene_slice) + return gene_slice class Transcriptome: diff --git a/src/baskerville/helpers/gcs_utils.py b/src/baskerville/helpers/gcs_utils.py index 35aad61..4184268 100644 --- a/src/baskerville/helpers/gcs_utils.py +++ b/src/baskerville/helpers/gcs_utils.py @@ -254,9 +254,9 @@ def download_rename_inputs(filepath: str, temp_dir: str, is_dir: bool = False) - Returns: new filepath in the local machine """ if is_dir: - download_folder_from_gcs(filepath, temp_dir) dir_name = filepath.split("/")[-1] - return temp_dir + download_folder_from_gcs(filepath, f"{temp_dir}/{dir_name}") + return f"{temp_dir}/{dir_name}" else: _, filename = split_gcs_uri(filepath) if "/" in filename: diff --git a/src/baskerville/helpers/tensorrt_helpers.py b/src/baskerville/helpers/tensorrt_helpers.py new file mode 100644 index 0000000..a586710 --- /dev/null +++ b/src/baskerville/helpers/tensorrt_helpers.py @@ -0,0 +1,138 @@ +import argparse +import json +import pdb +import time + +import numpy as np +import pandas as pd +import tensorflow as tf +from tensorflow.python.compiler.tensorrt import trt_convert as tf_trt + +from baskerville import seqnn + + +precision_dict = { + "FP32": tf_trt.TrtPrecisionMode.FP32, + "FP16": tf_trt.TrtPrecisionMode.FP16, + "INT8": tf_trt.TrtPrecisionMode.INT8, +} + + +class ModelOptimizer: + """ + Class of converter for tensorrt + Args: + input_saved_model_dir: Folder with saved model of the input model + """ + + def __init__(self, input_saved_model_dir, calibration_data=None): + self.input_saved_model_dir = input_saved_model_dir + self.calibration_data = None + if not calibration_data is None: + self.set_calibration_data(calibration_data) + + def set_calibration_data(self, calibration_data): + def calibration_input_fn(): + yield (tf.constant(calibration_data.astype("float32")),) + + self.calibration_data = calibration_input_fn + + def convert(self, precision="FP32"): + t0 = time.time() + print("Converting the model.") + + if precision == "INT8" and self.calibration_data is None: + raise (Exception("No calibration data set!")) + + trt_precision = precision_dict[precision] + conversion_params = tf_trt.DEFAULT_TRT_CONVERSION_PARAMS._replace( + precision_mode=trt_precision, + use_calibration=precision == "INT8", + max_workspace_size_bytes=8000000000, + ) + self.converter = tf_trt.TrtGraphConverterV2( + input_saved_model_dir=self.input_saved_model_dir, + conversion_params=conversion_params, + ) + + if precision == "INT8": + self.func = self.converter.convert( + calibration_input_fn=self.calibration_data + ) + else: + self.func = self.converter.convert() + print("Done in %ds" % (time.time() - t0)) + + def build(self, seq_length): + input_shape = (1, seq_length, 4) + t0 = time.time() + print("Building TRT engines for shape:", input_shape) + + def input_fn(): + x = np.random.random(input_shape).astype(np.float32) + x = tf.cast(x, tf.float32) + yield x + + self.converter.build(input_fn) + print("Done in %ds" % (time.time() - t0)) + + def build_func(self, seq_length): + input_shape = (1, seq_length, 4) + t0 = time.time() + print("Building TRT engines for shape:", input_shape) + x = 
np.random.random(input_shape) + x = tf.cast(x, tf.float32) + self.func(x) + print("Done in %ds" % (time.time() - t0)) + + def save(self, output_dir): + self.converter.save(output_saved_model_dir=output_dir) + + +def main(): + parser = argparse.ArgumentParser( + description="Convert a seqnn model to TensorRT model." + ) + parser.add_argument( + "-t", "--targets_file", default=None, help="Path to the target variants file" + ) + parser.add_argument( + "-o", + "--out_dir", + default="trt_out", + help="Output directory for storing saved models (original & converted)", + ) + parser.add_argument( + "params_file", type=str, help="Path to the JSON parameters file" + ) + parser.add_argument("model_file", help="Trained model HDF5.") + args = parser.parse_args() + + # Load parameters + with open(args.params_file) as params_open: + params = json.load(params_open) + + # Load keras model into seqnn class + seqnn_model = seqnn.SeqNN(params["model"]) + seqnn_model.restore(args.model_file) + + # Load target variants + if args.targets_file is not None: + targets_df = pd.read_csv(args.targets_file, sep="\t", index_col=0) + seqnn_model.build_slice(np.array(targets_df.index)) + + # ensemble rc + seqnn_model.build_ensemble(True) + + # save this model to a directory + seqnn_model.ensemble.save(f"{args.out_dir}/original") + + # Convert the model + opt_model = ModelOptimizer(f"{args.out_dir}/original") + opt_model.convert(precision="FP32") + # opt_model.build(seqnn_model.seq_length) + opt_model.save(f"{args.out_dir}/convert") + + +if __name__ == "__main__": + main() diff --git a/src/baskerville/helpers/trt_optimized_model.py b/src/baskerville/helpers/trt_optimized_model.py new file mode 100644 index 0000000..ad90325 --- /dev/null +++ b/src/baskerville/helpers/trt_optimized_model.py @@ -0,0 +1,45 @@ +import tensorflow as tf +from tensorflow.python.saved_model import tag_constants + + +class OptimizedModel: + """ + Class of model optimized with tensorrt + Args: + saved_model_dir: Folder with saved model + """ + + def __init__(self, saved_model_dir=None, strand_pair=[]): + self.loaded_model_fn = None + self.strand_pair = strand_pair + if not saved_model_dir is None: + self.load_model(saved_model_dir) + + def predict(self, input_data): + if self.loaded_model_fn is None: + raise (Exception("Haven't loaded a model")) + x = tf.cast(input_data, tf.float32) + labeling = self.loaded_model_fn(x) + try: + preds = labeling["predictions"].numpy() + except: + try: + preds = labeling["probs"].numpy() + except: + try: + preds = labeling[next(iter(labeling.keys()))] + except: + raise ( + Exception("Failed to get predictions from saved model object") + ) + return preds + + def load_model(self, saved_model_dir): + saved_model_loaded = tf.saved_model.load( + saved_model_dir, tags=[tag_constants.SERVING] + ) + wrapper_fp32 = saved_model_loaded.signatures["serving_default"] + self.loaded_model_fn = wrapper_fp32 + + def __call__(self, x): + return self.predict(x) diff --git a/src/baskerville/helpers/utils.py b/src/baskerville/helpers/utils.py index d1549dc..9892179 100644 --- a/src/baskerville/helpers/utils.py +++ b/src/baskerville/helpers/utils.py @@ -1,4 +1,62 @@ +import os import pickle +import sys +import subprocess +import time + + +def exec_par(cmds, max_proc=None, verbose=False): + """ + Execute the commands in the list 'cmds' in parallel, but + only running 'max_proc' at a time. 
+ Args: + cmds: list of commands to execute + max_proc: maximum number of processes to run in parallel + verbose: print command to stderr + """ + total = len(cmds) + finished = 0 + running = 0 + p = [] + + if max_proc == None: + max_proc = len(cmds) + + if max_proc == 1: + while finished < total: + if verbose: + print(cmds[finished], file=sys.stderr) + op = subprocess.Popen(cmds[finished], shell=True) + os.waitpid(op.pid, 0) + finished += 1 + + else: + while finished + running < total: + # launch jobs up to max + while running < max_proc and finished + running < total: + if verbose: + print(cmds[finished + running], file=sys.stderr) + p.append(subprocess.Popen(cmds[finished + running], shell=True)) + # print 'Running %d' % p[running].pid + running += 1 + + # are any jobs finished + new_p = [] + for i in range(len(p)): + if p[i].poll() != None: + running -= 1 + finished += 1 + else: + new_p.append(p[i]) + + # if none finished, sleep + if len(new_p) == len(p): + time.sleep(1) + p = new_p + + # wait for all to finish + for i in range(len(p)): + p[i].wait() def load_extra_options(options_pkl_file, options): diff --git a/src/baskerville/metrics.py b/src/baskerville/metrics.py index 29c0d99..b540086 100644 --- a/src/baskerville/metrics.py +++ b/src/baskerville/metrics.py @@ -29,7 +29,11 @@ # Losses ################################################################################ def mean_squared_error_udot(y_true, y_pred, udot_weight: float = 1): - """Mean squared error with mean-normalized specificity term.""" + """Mean squared error with mean-normalized specificity term. + + Args: + udot_weight: Weight of the mean-normalized specificity term. + """ mse_term = tf.keras.losses.mean_squared_error(y_true, y_pred) yn_true = y_true - tf.math.reduce_mean(y_true, axis=-1, keepdims=True) @@ -43,7 +47,7 @@ class MeanSquaredErrorUDot(LossFunctionWrapper): """Mean squared error with mean-normalized specificity term. Args: - udot_weight: Weight of the mean-normalized specificity term. + udot_weight: Weight of the mean-normalized specificity term. """ def __init__( @@ -59,7 +63,13 @@ def __init__( ) -def poisson_kl(y_true, y_pred, kl_weight=1, epsilon=1e-3): +def poisson_kl(y_true, y_pred, kl_weight=1, epsilon=1e-7): + """Poisson decomposition with KL specificity term. + + Args: + kl_weight (float): Weight of the KL specificity term. + epsilon (float): Added small value to avoid log(0). + """ # poisson loss poisson_term = tf.keras.losses.poisson(y_true, y_pred) @@ -96,41 +106,66 @@ def __init__( super(PoissonKL, self).__init__(pois_kl, name=name, reduction=reduction) +def poisson(yt, yp, epsilon: float = 1e-7): + """Poisson loss, without mean reduction.""" + return yp - yt * tf.math.log(yp + epsilon) + + def poisson_multinomial( y_true, y_pred, total_weight: float = 1, - epsilon: float = 1e-6, + weight_range: float = 1, + weight_exp: int = 4, + epsilon: float = 1e-7, rescale: bool = False, ): """Possion decomposition with multinomial specificity term. Args: - total_weight (float): Weight of the Poisson total term. - epsilon (float): Added small value to avoid log(0). + total_weight (float): Weight of the Poisson total term. + epsilon (float): Added small value to avoid log(0). + rescale (bool): Rescale loss after re-weighting. 
""" seq_len = y_true.shape[1] + + if weight_range < 1: + raise ValueError("Poisson Multinomial weight_range must be >=1") + elif weight_range == 1: + position_weights = tf.ones((1, seq_len, 1)) + else: + pos_start = -(seq_len / 2 - 0.5) + pos_end = seq_len / 2 + 0.5 + positions = tf.range(pos_start, pos_end, dtype=tf.float32) + sigma = -pos_start / (np.log(weight_range)) ** (1 / weight_exp) + position_weights = tf.exp(-((positions / sigma) ** weight_exp)) + position_weights /= tf.reduce_max(position_weights) + position_weights = tf.expand_dims(position_weights, axis=0) + position_weights = tf.expand_dims(position_weights, axis=-1) + + y_true = tf.math.multiply(y_true, position_weights) + y_pred = tf.math.multiply(y_pred, position_weights) + + # sum across lengths + s_true = tf.math.reduce_sum(y_true, axis=-2) # B x T + s_pred = tf.math.reduce_sum(y_pred, axis=-2) # B x T + + # total count poisson loss, mean across targets + poisson_term = poisson(s_true, s_pred) # B x T + poisson_term /= tf.reduce_sum(position_weights) # add epsilon to protect against tiny values y_true += epsilon y_pred += epsilon - # sum across lengths - s_true = tf.math.reduce_sum(y_true, axis=-2, keepdims=True) - s_pred = tf.math.reduce_sum(y_pred, axis=-2, keepdims=True) - # normalize to sum to one - p_pred = y_pred / s_pred - - # total count poisson loss - poisson_term = tf.keras.losses.poisson(s_true, s_pred) # B x T - poisson_term /= seq_len + p_pred = y_pred / tf.expand_dims(s_pred, axis=-2) # B x L x T # multinomial loss pl_pred = tf.math.log(p_pred) # B x L x T multinomial_dot = -tf.math.multiply(y_true, pl_pred) # B x L x T multinomial_term = tf.math.reduce_sum(multinomial_dot, axis=-2) # B x T - multinomial_term /= seq_len + multinomial_term /= tf.reduce_sum(position_weights) # normalize to scale of 1:1 term ratio loss_raw = multinomial_term + total_weight * poisson_term @@ -147,17 +182,19 @@ class PoissonMultinomial(LossFunctionWrapper): Args: total_weight (float): Weight of the Poisson total term. - epsilon (float): Added small value to avoid log(0). """ def __init__( self, - total_weight=1, + total_weight: float = 1, + weight_range: float = 1, + weight_exp: int = 4, reduction=losses_utils.ReductionV2.AUTO, name: str = "poisson_multinomial", ): - self.total_weight = total_weight - pois_mn = lambda yt, yp: poisson_multinomial(yt, yp, self.total_weight) + pois_mn = lambda yt, yp: poisson_multinomial( + yt, yp, total_weight, weight_range, weight_exp + ) super(PoissonMultinomial, self).__init__( pois_mn, name=name, reduction=reduction ) diff --git a/src/baskerville/scripts/hound_data.py b/src/baskerville/scripts/hound_data.py new file mode 100755 index 0000000..a95284a --- /dev/null +++ b/src/baskerville/scripts/hound_data.py @@ -0,0 +1,905 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ========================================================================= +from optparse import OptionParser +import gzip +import json +import pdb +import os +import random +import shutil +import subprocess +import sys +import tempfile + +import numpy as np +import pandas as pd + +from baskerville import data +from baskerville.helpers import utils + +try: + import slurm +except ModuleNotFoundError: + pass + +""" +hound_data.py + +Compute model sequences from the genome, extracting DNA coverage values. +""" + + +################################################################################ +def main(): + usage = "usage: %prog [options] " + parser = OptionParser(usage) + parser.add_option( + "-b", + dest="blacklist_bed", + help="Set blacklist nucleotides to a baseline value.", + ) + parser.add_option( + "--break", + dest="break_t", + default=786432, + type="int", + help="Break in half contigs above length [Default: %default]", + ) + parser.add_option( + "-c", + "--crop", + dest="crop_bp", + default=0, + type="int", + help="Crop bp off each end [Default: %default]", + ) + parser.add_option( + "-d", + dest="decimals", + default=None, + type="int", + help="Round values to given decimals [Default: %default]", + ) + parser.add_option( + "-f", + dest="folds", + default=None, + type="int", + help="Generate cross fold split [Default: %default]", + ) + parser.add_option( + "-g", dest="gaps_file", help="Genome assembly gaps BED [Default: %default]" + ) + parser.add_option( + "-i", + dest="interp_nan", + default=False, + action="store_true", + help="Interpolate NaNs [Default: %default]", + ) + parser.add_option( + "-l", + dest="seq_length", + default=131072, + type="int", + help="Sequence length [Default: %default]", + ) + parser.add_option( + "--limit", + dest="limit_bed", + help="Limit to segments that overlap regions in a BED file", + ) + parser.add_option( + "--local", + dest="run_local", + default=False, + action="store_true", + help="Run jobs locally as opposed to on SLURM [Default: %default]", + ) + parser.add_option( + "-o", + dest="out_dir", + default="data_out", + help="Output directory [Default: %default]", + ) + parser.add_option( + "-p", + dest="processes", + default=None, + type="int", + help="Number parallel processes [Default: %default]", + ) + parser.add_option( + "--peaks", + dest="peaks_only", + default=False, + action="store_true", + help="Create contigs only from peaks [Default: %default]", + ) + parser.add_option( + "-r", + dest="seqs_per_tfr", + default=256, + type="int", + help="Sequences per TFRecord file [Default: %default]", + ) + parser.add_option( + "--restart", + dest="restart", + default=False, + action="store_true", + help="Continue progress from midpoint. [Default: %default]", + ) + parser.add_option( + "-s", + dest="sample_pct", + default=1.0, + type="float", + help="Down-sample the segments", + ) + parser.add_option( + "--seed", + dest="seed", + default=44, + type="int", + help="Random seed [Default: %default]", + ) + parser.add_option( + "--snap", + dest="snap", + default=1, + type="int", + help="Snap sequences to multiple of the given value [Default: %default]", + ) + parser.add_option( + "--st", + "--split_test", + dest="split_test", + default=False, + action="store_true", + help="Exit after split. 
[Default: %default]", + ) + parser.add_option( + "--stride", + "--stride_train", + dest="stride_train", + default=1.0, + type="float", + help="Stride to advance train sequences [Default: seq_length]", + ) + parser.add_option( + "--stride_test", + dest="stride_test", + default=1.0, + type="float", + help="Stride to advance valid and test sequences [Default: seq_length]", + ) + parser.add_option( + "-t", + dest="test_pct_or_chr", + default=0.05, + type="str", + help="Proportion of the data for testing [Default: %default]", + ) + parser.add_option("-u", dest="umap_bed", help="Unmappable regions in BED format") + parser.add_option( + "--umap_t", + dest="umap_t", + default=0.5, + type="float", + help="Remove sequences with more than this unmappable bin % [Default: %default]", + ) + parser.add_option( + "--umap_clip", + dest="umap_clip", + default=1, + type="float", + help="Clip values at unmappable positions to distribution quantiles, eg 0.25. [Default: %default]", + ) + parser.add_option( + "--umap_tfr", + dest="umap_tfr", + default=False, + action="store_true", + help="Save umap array into TFRecords [Default: %default]", + ) + parser.add_option( + "-w", + dest="pool_width", + default=32, + type="int", + help="Sum pool width [Default: %default]", + ) + parser.add_option( + "-v", + dest="valid_pct_or_chr", + default=0.05, + type="str", + help="Proportion of the data for validation [Default: %default]", + ) + (options, args) = parser.parse_args() + + if len(args) != 2: + parser.error("Must provide FASTA and sample coverage labels and paths.") + else: + fasta_file = args[0] + targets_file = args[1] + + random.seed(options.seed) + np.random.seed(options.seed) + + if options.break_t is not None and options.break_t < options.seq_length: + print( + "Maximum contig length --break cannot be less than sequence length.", + file=sys.stderr, + ) + exit(1) + + # transform proportion strides to base pairs + if options.stride_train <= 1: + print("stride_train %.f" % options.stride_train, end="") + options.stride_train = options.stride_train * options.seq_length + print(" converted to %f" % options.stride_train) + options.stride_train = int(np.round(options.stride_train)) + if options.stride_test <= 1: + if options.folds is None: + print("stride_test %.f" % options.stride_test, end="") + options.stride_test = options.stride_test * options.seq_length + print(" converted to %f" % options.stride_test) + options.stride_test = int(np.round(options.stride_test)) + + # check snap + if options.snap is not None: + if np.mod(options.seq_length, options.snap) != 0: + raise ValueError("seq_length must be a multiple of snap") + if np.mod(options.stride_train, options.snap) != 0: + raise ValueError("stride_train must be a multiple of snap") + if np.mod(options.stride_test, options.snap) != 0: + raise ValueError("stride_test must be a multiple of snap") + + # setup output directory + if os.path.isdir(options.out_dir) and not options.restart: + print("Remove output directory %s or use --restart option." 
% options.out_dir) + exit(1) + elif not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + # read target datasets + targets_df = pd.read_csv(targets_file, index_col=0, sep="\t") + + ################################################################ + # define genomic contigs + ################################################################ + if not options.restart: + chrom_contigs = data.load_chromosomes(fasta_file) + + # remove gaps + if options.gaps_file: + chrom_contigs = data.split_contigs(chrom_contigs, options.gaps_file) + + # ditch the chromosomes for contigs + contigs = [] + for chrom in chrom_contigs: + contigs += [ + data.Contig(0, chrom, ctg_start, ctg_end) + for ctg_start, ctg_end in chrom_contigs[chrom] + ] + + # limit to a BED file + if options.limit_bed is not None: + contigs = limit_contigs(contigs, options.limit_bed) + + # limit to peaks + if options.peaks_only: + peaks_bed = curate_peaks( + targets_df, options.out_dir, options.pool_width, options.crop_bp + ) + contigs = limit_contigs(contigs, peaks_bed) + + # filter for large enough + seq_tlength = options.seq_length - 2 * options.crop_bp + contigs = [ctg for ctg in contigs if ctg.end - ctg.start >= seq_tlength] + + # break up large contigs + if options.break_t is not None: + contigs = data.break_large_contigs(contigs, options.break_t) + + # print contigs to BED file + # ctg_bed_file = '%s/contigs.bed' % options.out_dir + # write_seqs_bed(ctg_bed_file, contigs) + + ################################################################ + # divide between train/valid/test + ################################################################ + # label folds + if options.folds is not None: + fold_labels = ["fold%d" % fi for fi in range(options.folds)] + num_folds = options.folds + else: + fold_labels = ["train", "valid", "test"] + num_folds = 3 + + if not options.restart: + if options.folds is not None: + # divide by fold pct + fold_contigs = divide_contigs_folds(contigs, options.folds) + + else: + try: + # convert to float pct + valid_pct = float(options.valid_pct_or_chr) + test_pct = float(options.test_pct_or_chr) + assert 0 <= valid_pct <= 1 + assert 0 <= test_pct <= 1 + + # divide by pct + fold_contigs = divide_contigs_pct(contigs, test_pct, valid_pct) + + except (ValueError, AssertionError): + # divide by chr + valid_chrs = options.valid_pct_or_chr.split(",") + test_chrs = options.test_pct_or_chr.split(",") + fold_contigs = divide_contigs_chr(contigs, test_chrs, valid_chrs) + + # rejoin broken contigs within set + for fi in range(len(fold_contigs)): + fold_contigs[fi] = data.rejoin_large_contigs(fold_contigs[fi]) + + # write labeled contigs to BED file + ctg_bed_file = "%s/contigs.bed" % options.out_dir + ctg_bed_out = open(ctg_bed_file, "w") + for fi in range(len(fold_contigs)): + for ctg in fold_contigs[fi]: + line = "%s\t%d\t%d\t%s" % (ctg.chr, ctg.start, ctg.end, fold_labels[fi]) + print(line, file=ctg_bed_out) + ctg_bed_out.close() + + if options.split_test: + exit() + + ################################################################ + # define model sequences + ################################################################ + if not options.restart: + fold_mseqs = [] + for fi in range(num_folds): + if fold_labels[fi] in ["valid", "test"]: + stride_fold = options.stride_test + else: + stride_fold = options.stride_train + + # stride sequences across contig + fold_mseqs_fi = data.contig_sequences( + fold_contigs[fi], + seq_tlength, + stride_fold, + options.snap, + fold_labels[fi], + ) + 
fold_mseqs.append(fold_mseqs_fi) + + # shuffle + random.shuffle(fold_mseqs[fi]) + + # down-sample + if options.sample_pct < 1.0: + fold_mseqs[fi] = random.sample( + fold_mseqs[fi], int(options.sample_pct * len(fold_mseqs[fi])) + ) + + # merge into one list + mseqs = [ms for fm in fold_mseqs for ms in fm] + + ################################################################ + # mappability + ################################################################ + if not options.restart: + if options.umap_bed is not None: + if shutil.which("bedtools") is None: + print("Install Bedtools to annotate unmappable sites", file=sys.stderr) + exit(1) + + # annotate unmappable positions + mseqs_unmap = data.annotate_unmap( + mseqs, options.umap_bed, seq_tlength, options.pool_width + ) + + # filter unmappable + mseqs_map_mask = mseqs_unmap.mean(axis=1, dtype="float64") < options.umap_t + mseqs = [mseqs[i] for i in range(len(mseqs)) if mseqs_map_mask[i]] + mseqs_unmap = mseqs_unmap[mseqs_map_mask, :] + + # write to file + unmap_npy = "%s/mseqs_unmap.npy" % options.out_dir + np.save(unmap_npy, mseqs_unmap) + + # write sequences to BED + seqs_bed_file = "%s/sequences.bed" % options.out_dir + data.write_seqs_bed(seqs_bed_file, mseqs, True) + + else: + # read from directory + seqs_bed_file = "%s/sequences.bed" % options.out_dir + unmap_npy = "%s/mseqs_unmap.npy" % options.out_dir + mseqs = [] + fold_mseqs = [] + for fi in range(num_folds): + fold_mseqs.append([]) + for line in open(seqs_bed_file): + a = line.split() + msg = data.ModelSeq(0, a[0], int(a[1]), int(a[2]), a[3]) + mseqs.append(msg) + if a[3] == "train": + fi = 0 + elif a[3] == "valid": + fi = 1 + elif a[3] == "test": + fi = 2 + else: + fi = int(a[3].replace("fold", "")) + fold_mseqs[fi].append(msg) + + ################################################################ + # read sequence coverage values + ################################################################ + seqs_cov_dir = "%s/seqs_cov" % options.out_dir + if not os.path.isdir(seqs_cov_dir): + os.mkdir(seqs_cov_dir) + + read_jobs = [] + + for ti in range(targets_df.shape[0]): + genome_cov_file = targets_df["file"].iloc[ti] + seqs_cov_stem = "%s/%d" % (seqs_cov_dir, ti) + seqs_cov_file = "%s.h5" % seqs_cov_stem + + clip_ti = None + if "clip" in targets_df.columns: + clip_ti = targets_df["clip"].iloc[ti] + + clipsoft_ti = None + if "clip_soft" in targets_df.columns: + clipsoft_ti = targets_df["clip_soft"].iloc[ti] + + scale_ti = 1 + if "scale" in targets_df.columns: + scale_ti = targets_df["scale"].iloc[ti] + + if options.restart and os.path.isfile(seqs_cov_file): + print("Skipping existing %s" % seqs_cov_file, file=sys.stderr) + else: + cmd = "hound_data_read.py" + cmd += " -w %d" % options.pool_width + cmd += " -u %s" % targets_df["sum_stat"].iloc[ti] + if clip_ti is not None: + cmd += " -c %f" % clip_ti + if clipsoft_ti is not None: + cmd += " --clip_soft %f" % clipsoft_ti + cmd += " -s %f" % scale_ti + if options.blacklist_bed: + cmd += " -b %s" % options.blacklist_bed + if options.interp_nan: + cmd += " -i" + cmd += " %s" % genome_cov_file + cmd += " %s" % seqs_bed_file + cmd += " %s" % seqs_cov_file + + if options.run_local: + # breaks on some OS + # cmd += ' &> %s.err' % seqs_cov_stem + read_jobs.append(cmd) + else: + j = slurm.Job( + cmd, + name="read_t%d" % ti, + out_file="%s.out" % seqs_cov_stem, + err_file="%s.err" % seqs_cov_stem, + queue="standard", + mem=15000, + time="12:0:0", + ) + read_jobs.append(j) + + if options.run_local: + utils.exec_par(read_jobs, options.processes, 
verbose=True) + else: + slurm.multi_run( + read_jobs, options.processes, verbose=True, launch_sleep=1, update_sleep=5 + ) + + ################################################################ + # write TF Records + ################################################################ + # copy targets file + shutil.copy(targets_file, "%s/targets.txt" % options.out_dir) + + # initialize TF Records dir + tfr_dir = "%s/tfrecords" % options.out_dir + if not os.path.isdir(tfr_dir): + os.mkdir(tfr_dir) + + write_jobs = [] + + for fold_set in fold_labels: + fold_set_indexes = [i for i in range(len(mseqs)) if mseqs[i].label == fold_set] + fold_set_start = fold_set_indexes[0] + fold_set_end = fold_set_indexes[-1] + 1 + + tfr_i = 0 + tfr_start = fold_set_start + tfr_end = min(tfr_start + options.seqs_per_tfr, fold_set_end) + + while tfr_start <= fold_set_end: + tfr_stem = "%s/%s-%d" % (tfr_dir, fold_set, tfr_i) + + tfr_file = "%s.tfr" % tfr_stem + if options.restart and os.path.isfile(tfr_file): + print("Skipping existing %s" % tfr_file, file=sys.stderr) + else: + cmd = "hound_data_write.py" + cmd += " -s %d" % tfr_start + cmd += " -e %d" % tfr_end + cmd += " --umap_clip %f" % options.umap_clip + cmd += " -x %d" % options.crop_bp + if options.decimals is not None: + cmd += " -d %d" % options.decimals + if options.umap_tfr: + cmd += " --umap_tfr" + if options.umap_bed is not None: + cmd += " -u %s" % unmap_npy + + cmd += " %s" % fasta_file + cmd += " %s" % seqs_bed_file + cmd += " %s" % seqs_cov_dir + cmd += " %s" % tfr_file + + if options.run_local: + # breaks on some OS + # cmd += ' &> %s.err' % tfr_stem + write_jobs.append(cmd) + else: + j = slurm.Job( + cmd, + name="write_%s-%d" % (fold_set, tfr_i), + out_file="%s.out" % tfr_stem, + err_file="%s.err" % tfr_stem, + queue="standard", + mem=15000, + time="12:0:0", + ) + write_jobs.append(j) + + # update + tfr_i += 1 + tfr_start += options.seqs_per_tfr + tfr_end = min(tfr_start + options.seqs_per_tfr, fold_set_end) + + if options.run_local: + utils.exec_par(write_jobs, options.processes, verbose=True) + else: + slurm.multi_run( + write_jobs, options.processes, verbose=True, launch_sleep=1, update_sleep=5 + ) + + ################################################################ + # stats + ################################################################ + stats_dict = {} + stats_dict["num_targets"] = targets_df.shape[0] + stats_dict["seq_length"] = options.seq_length + stats_dict["seq_1hot"] = True + stats_dict["pool_width"] = options.pool_width + stats_dict["crop_bp"] = options.crop_bp + + target_length = options.seq_length - 2 * options.crop_bp + target_length = target_length // options.pool_width + stats_dict["target_length"] = target_length + + for fi in range(num_folds): + stats_dict["%s_seqs" % fold_labels[fi]] = len(fold_mseqs[fi]) + + with open("%s/statistics.json" % options.out_dir, "w") as stats_json_out: + json.dump(stats_dict, stats_json_out, indent=4) + + +################################################################################ +def curate_peaks(targets_df, out_dir, pool_width, crop_bp): + """Merge all peaks, round to nearest pool_width, and add cropped bp.""" + + # concatenate and extend peaks + cat_bed_file = "%s/peaks_cat.bed" % out_dir + cat_bed_out = open(cat_bed_file, "w") + for bed_file in targets_df.file: + if bed_file[-3:] == ".gz": + bed_in = gzip.open(bed_file, "rt") + else: + bed_in = open(bed_file, "r") + + for line in bed_in: + a = line.rstrip().split("\t") + chrm = a[0] + start = int(a[1]) + end = int(a[2]) + + # extend to 
pool width + length = end - start + if length < pool_width: + mid = (start + end) // 2 + start = mid - pool_width // 2 + end = start + pool_width + + # add cropped bp + start = max(0, start - crop_bp) + end += crop_bp + + # print + print("%s\t%d\t%d" % (chrm, start, end), file=cat_bed_out) + + bed_in.close() + cat_bed_out.close() + + # merge + merge_bed_file = "%s/peaks_merge.bed" % out_dir + bedtools_cmd = "bedtools sort -i %s" % cat_bed_file + bedtools_cmd += " | bedtools merge -i - > %s" % merge_bed_file + subprocess.call(bedtools_cmd, shell=True) + + # round and add crop_bp + full_bed_file = "%s/peaks_full.bed" % out_dir + full_bed_out = open(full_bed_file, "w") + + for line in open(merge_bed_file): + a = line.rstrip().split("\t") + chrm = a[0] + start = int(a[1]) + end = int(a[2]) + mid = (start + end) // 2 + length = end - start + + # round length to nearest pool_width + bins = int(np.round(length / pool_width)) + assert bins > 0 + start = mid - (bins * pool_width) // 2 + start = max(0, start) + end = start + (bins * pool_width) + + # write + print("%s\t%d\t%d" % (chrm, start, end), file=full_bed_out) + + full_bed_out.close() + + return full_bed_file + + +################################################################################ +def divide_contigs_chr(contigs, test_chrs, valid_chrs): + """Divide list of contigs into train/valid/test lists + by chromosome.""" + + # initialize current train/valid/test nucleotides + train_nt = 0 + valid_nt = 0 + test_nt = 0 + + # initialize train/valid/test contig lists + train_contigs = [] + valid_contigs = [] + test_contigs = [] + + # process contigs + for ctg in contigs: + ctg_len = ctg.end - ctg.start + + if ctg.chr in test_chrs: + test_contigs.append(ctg) + test_nt += ctg_len + elif ctg.chr in valid_chrs: + valid_contigs.append(ctg) + valid_nt += ctg_len + else: + train_contigs.append(ctg) + train_nt += ctg_len + + total_nt = train_nt + valid_nt + test_nt + + print("Contigs divided into") + print( + " Train: %5d contigs, %10d nt (%.4f)" + % (len(train_contigs), train_nt, train_nt / total_nt) + ) + print( + " Valid: %5d contigs, %10d nt (%.4f)" + % (len(valid_contigs), valid_nt, valid_nt / total_nt) + ) + print( + " Test: %5d contigs, %10d nt (%.4f)" + % (len(test_contigs), test_nt, test_nt / total_nt) + ) + + return [train_contigs, valid_contigs, test_contigs] + + +################################################################################ +def divide_contigs_folds(contigs, folds): + """Divide list of contigs into cross fold lists.""" + + # sort contigs descending by length + length_contigs = [(ctg.end - ctg.start, ctg) for ctg in contigs] + length_contigs.sort(reverse=True) + + # compute total nucleotides + total_nt = sum([lc[0] for lc in length_contigs]) + + # compute aimed fold nucleotides + fold_nt_aim = int(np.ceil(total_nt / folds)) + + # initialize current fold nucleotides + fold_nt = np.zeros(folds) + + # initialize fold contig lists + fold_contigs = [] + for fi in range(folds): + fold_contigs.append([]) + + # process contigs + for ctg_len, ctg in length_contigs: + + # compute gap between current and aim + fold_nt_gap = fold_nt_aim - fold_nt + fold_nt_gap = np.clip(fold_nt_gap, 0, np.inf) + + # compute sample probability + fold_prob = fold_nt_gap / fold_nt_gap.sum() + + # sample train/valid/test + fi = np.random.choice(folds, p=fold_prob) + fold_contigs[fi].append(ctg) + fold_nt[fi] += ctg_len + + print("Contigs divided into") + for fi in range(folds): + print( + " Fold%d: %5d contigs, %10d nt (%.4f)" + % (fi, 
len(fold_contigs[fi]), fold_nt[fi], fold_nt[fi] / total_nt) + ) + + return fold_contigs + + +################################################################################ +def divide_contigs_pct(contigs, test_pct, valid_pct, pct_abstain=0.2): + """Divide list of contigs into train/valid/test lists, + aiming for the specified nucleotide percentages.""" + + # sort contigs descending by length + length_contigs = [(ctg.end - ctg.start, ctg) for ctg in contigs] + length_contigs.sort(reverse=True) + + # compute total nucleotides + total_nt = sum([lc[0] for lc in length_contigs]) + + # compute aimed train/valid/test nucleotides + test_nt_aim = test_pct * total_nt + valid_nt_aim = valid_pct * total_nt + train_nt_aim = total_nt - valid_nt_aim - test_nt_aim + + # initialize current train/valid/test nucleotides + train_nt = 0 + valid_nt = 0 + test_nt = 0 + + # initialize train/valid/test contig lists + train_contigs = [] + valid_contigs = [] + test_contigs = [] + + # process contigs + for ctg_len, ctg in length_contigs: + + # compute gap between current and aim + test_nt_gap = max(0, test_nt_aim - test_nt) + valid_nt_gap = max(0, valid_nt_aim - valid_nt) + train_nt_gap = max(1, train_nt_aim - train_nt) + + # skip if too large + if ctg_len > pct_abstain * test_nt_gap: + test_nt_gap = 0 + if ctg_len > pct_abstain * valid_nt_gap: + valid_nt_gap = 0 + + # compute remaining % + gap_sum = train_nt_gap + valid_nt_gap + test_nt_gap + test_pct_gap = test_nt_gap / gap_sum + valid_pct_gap = valid_nt_gap / gap_sum + train_pct_gap = train_nt_gap / gap_sum + + # sample train/valid/test + ri = np.random.choice( + range(3), 1, p=[train_pct_gap, valid_pct_gap, test_pct_gap] + )[0] + if ri == 0: + train_contigs.append(ctg) + train_nt += ctg_len + elif ri == 1: + valid_contigs.append(ctg) + valid_nt += ctg_len + elif ri == 2: + test_contigs.append(ctg) + test_nt += ctg_len + else: + print("TVT random number beyond 0,1,2", file=sys.stderr) + exit(1) + + print("Contigs divided into") + print( + " Train: %5d contigs, %10d nt (%.4f)" + % (len(train_contigs), train_nt, train_nt / total_nt) + ) + print( + " Valid: %5d contigs, %10d nt (%.4f)" + % (len(valid_contigs), valid_nt, valid_nt / total_nt) + ) + print( + " Test: %5d contigs, %10d nt (%.4f)" + % (len(test_contigs), test_nt, test_nt / total_nt) + ) + + return [train_contigs, valid_contigs, test_contigs] + + +################################################################################ +def limit_contigs(contigs, filter_bed): + """Limit to contigs overlapping the given BED. 
+ + Args + contigs: list of Contigs + filter_bed: BED file to filter by + + Returns: + fcontigs: list of Contigs + """ + + # print ctgments to BED + ctg_fd, ctg_bed_file = tempfile.mkstemp() + ctg_bed_out = open(ctg_bed_file, "w") + for ctg in contigs: + print("%s\t%d\t%d" % (ctg.chr, ctg.start, ctg.end), file=ctg_bed_out) + ctg_bed_out.close() + + # intersect w/ filter_bed + fcontigs = [] + p = subprocess.Popen( + "bedtools intersect -a %s -b %s" % (ctg_bed_file, filter_bed), + shell=True, + stdout=subprocess.PIPE, + ) + for line in p.stdout: + a = line.decode("utf-8").split() + chrom = a[0] + ctg_start = int(a[1]) + ctg_end = int(a[2]) + fcontigs.append(data.Contig(0, chrom, ctg_start, ctg_end)) + + p.communicate() + + os.close(ctg_fd) + os.remove(ctg_bed_file) + + return fcontigs + + +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_data_align.py b/src/baskerville/scripts/hound_data_align.py new file mode 100755 index 0000000..d1ae93d --- /dev/null +++ b/src/baskerville/scripts/hound_data_align.py @@ -0,0 +1,978 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser +import collections +import gzip +import heapq +import pdb +import os +import random +import sys + +import networkx as nx +import numpy as np +import pybedtools + +from baskerville import data + +""" +hound_data_align.py + +Partition sequences from multiple aligned genomes into train/valid/test splits +that respect homology. 
+""" + + +################################################################################ +def main(): + usage = "usage: %prog [options] " + parser = OptionParser(usage) + parser.add_option( + "-a", dest="genome_labels", default=None, help="Genome labels in output" + ) + parser.add_option( + "--break", + dest="break_t", + default=None, + type="int", + help="Break in half contigs above length [Default: %default]", + ) + parser.add_option( + "-c", + "--crop", + dest="crop_bp", + default=0, + type="int", + help="Crop bp off each end [Default: %default]", + ) + parser.add_option( + "-f", + dest="folds", + default=None, + type="int", + help="Generate cross fold split [Default: %default]", + ) + parser.add_option( + "-g", + dest="gap_files", + help="Comma-separated list of assembly gaps BED files [Default: %default]", + ) + parser.add_option( + "-l", + dest="seq_length", + default=131072, + type="int", + help="Sequence length [Default: %default]", + ) + parser.add_option( + "--nf", + dest="net_fill_min", + default=100000, + type="int", + help="Alignment net fill size minimum [Default: %default]", + ) + parser.add_option( + "--no", + dest="net_olap_min", + default=1024, + type="int", + help="Alignment net and contig overlap minimum [Default: %default]", + ) + parser.add_option( + "-o", + dest="out_dir", + default="align_out", + help="Output directory [Default: %default]", + ) + parser.add_option( + "-s", + dest="sample_pct", + default=1.0, + type="float", + help="Down-sample the segments", + ) + parser.add_option( + "--seed", + dest="seed", + default=44, + type="int", + help="Random seed [Default: %default]", + ) + parser.add_option( + "--snap", + dest="snap", + default=1, + type="int", + help="Snap sequences to multiple of the given value [Default: %default]", + ) + parser.add_option( + "--stride", + "--stride_train", + dest="stride_train", + default=1.0, + type="float", + help="Stride to advance train sequences [Default: seq_length]", + ) + parser.add_option( + "--stride_test", + dest="stride_test", + default=1.0, + type="float", + help="Stride to advance valid and test sequences [Default: %default]", + ) + parser.add_option( + "-t", + dest="test_pct", + default=0.1, + type="float", + help="Proportion of the data for testing [Default: %default]", + ) + parser.add_option( + "-u", + dest="umap_beds", + help="Comma-separated genome unmappable segments to set to NA", + ) + parser.add_option( + "--umap_t", + dest="umap_t", + default=0.5, + type="float", + help="Remove sequences with more than this unmappable bin % [Default: %default]", + ) + parser.add_option( + "-w", + dest="pool_width", + default=32, + type="int", + help="Sum pool width [Default: %default]", + ) + parser.add_option( + "-v", + dest="valid_pct", + default=0.1, + type="float", + help="Proportion of the data for validation [Default: %default]", + ) + (options, args) = parser.parse_args() + + if len(args) != 2: + parser.error("Must provide alignment and FASTA files.") + else: + align_net_file = args[0] + fasta_files = args[1].split(",") + + # there is still some source of stochasticity + random.seed(options.seed) + np.random.seed(options.seed) + + # transform proportion strides to base pairs + if options.stride_train <= 1: + print("stride_train %.f" % options.stride_train, end="") + options.stride_train = options.stride_train * options.seq_length + print(" converted to %f" % options.stride_train) + options.stride_train = int(np.round(options.stride_train)) + if options.stride_test <= 1: + print("stride_test %.f" % options.stride_test, 
end="") + options.stride_test = options.stride_test * options.seq_length + print(" converted to %f" % options.stride_test) + options.stride_test = int(np.round(options.stride_test)) + + # check snap + if options.snap is not None: + if np.mod(options.seq_length, options.snap) != 0: + raise ValueError("seq_length must be a multiple of snap") + if np.mod(options.stride_train, options.snap) != 0: + raise ValueError("stride_train must be a multiple of snap") + if np.mod(options.stride_test, options.snap) != 0: + raise ValueError("stride_test must be a multiple of snap") + + # count genomes + num_genomes = len(fasta_files) + + # parse gap files + if options.gap_files is not None: + options.gap_files = options.gap_files.split(",") + assert len(options.gap_files) == num_genomes + + # parse unmappable files + if options.umap_beds is not None: + options.umap_beds = options.umap_beds.split(",") + assert len(options.umap_beds) == num_genomes + + # label genomes + if options.genome_labels is None: + options.genome_labels = ["genome%d" % (gi + 1) for gi in range(num_genomes)] + else: + options.genome_labels = options.genome_labels.split(",") + assert len(options.genome_labels) == num_genomes + + # create output directorys + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + genome_out_dirs = [] + for gi in range(num_genomes): + gout_dir = "%s/%s" % (options.out_dir, options.genome_labels[gi]) + if not os.path.isdir(gout_dir): + os.mkdir(gout_dir) + genome_out_dirs.append(gout_dir) + + ################################################################ + # define genomic contigs + ################################################################ + genome_chr_contigs = [] + for gi in range(num_genomes): + genome_chr_contigs.append(data.load_chromosomes(fasta_files[gi])) + + # remove gaps + if options.gap_files[gi]: + genome_chr_contigs[gi] = data.split_contigs( + genome_chr_contigs[gi], options.gap_files[gi] + ) + + # ditch the chromosomes + contigs = [] + for gi in range(num_genomes): + for chrom in genome_chr_contigs[gi]: + contigs += [ + data.Contig(gi, chrom, ctg_start, ctg_end) + for ctg_start, ctg_end in genome_chr_contigs[gi][chrom] + ] + + # filter for large enough + seq_tlength = options.seq_length - 2 * options.crop_bp + contigs = [ctg for ctg in contigs if ctg.end - ctg.start >= seq_tlength] + + # break up large contigs + if options.break_t is not None: + contigs = break_large_contigs(contigs, options.break_t) + + # print contigs to BED file + for gi in range(num_genomes): + contigs_i = [ctg for ctg in contigs if ctg.genome == gi] + ctg_bed_file = "%s/contigs.bed" % genome_out_dirs[gi] + data.write_seqs_bed(ctg_bed_file, contigs_i) + + ################################################################ + # divide between train/valid/test + ################################################################ + + # connect contigs across genomes by alignment + contig_components = connect_contigs( + contigs, + align_net_file, + options.net_fill_min, + options.net_olap_min, + options.out_dir, + genome_out_dirs, + ) + + if options.folds is not None: + # divide by fold + fold_contigs = divide_components_folds(contig_components, options.folds) + + else: + # divide by train/valid/test pct + fold_contigs = divide_components_pct( + contig_components, options.test_pct, options.valid_pct + ) + + # rejoin broken contigs within set + for fi in range(len(fold_contigs)): + fold_contigs[fi] = data.rejoin_large_contigs(fold_contigs[fi]) + + # label folds + if options.folds is not None: + fold_labels = 
["fold%d" % fi for fi in range(options.folds)] + num_folds = options.folds + else: + fold_labels = ["train", "valid", "test"] + num_folds = 3 + + if options.folds is None: + # quantify leakage across sets + quantify_leakage( + align_net_file, + fold_contigs[0], + fold_contigs[1], + fold_contigs[2], + options.out_dir, + ) + + ################################################################ + # define model sequences + ################################################################ + + fold_mseqs = [] + for fi in range(num_folds): + if fold_labels[fi] in ["valid", "test"]: + stride_fold = options.stride_test + else: + stride_fold = options.stride_train + + # stride sequences across contig + fold_mseqs_fi = data.contig_sequences( + fold_contigs[fi], seq_tlength, stride_fold, options.snap, fold_labels[fi] + ) + fold_mseqs.append(fold_mseqs_fi) + + # shuffle + random.shuffle(fold_mseqs[fi]) + + # down-sample + if options.sample_pct < 1.0: + fold_mseqs[fi] = random.sample( + fold_mseqs[fi], int(options.sample_pct * len(fold_mseqs[fi])) + ) + + # merge into one list + mseqs = [ms for fm in fold_mseqs for ms in fm] + + # separate by genome + mseqs_genome = [] + for gi in range(num_genomes): + mseqs_gi = [mseqs[si] for si in range(len(mseqs)) if mseqs[si].genome == gi] + mseqs_genome.append(mseqs_gi) + + ################################################################ + # filter for sufficient mappability + ################################################################ + for gi in range(num_genomes): + if options.umap_beds[gi] is not None: + # annotate unmappable positions + mseqs_unmap = data.annotate_unmap( + mseqs_genome[gi], options.umap_beds[gi], seq_tlength, options.pool_width + ) + + # filter unmappable + mseqs_map_mask = mseqs_unmap.mean(axis=1, dtype="float64") < options.umap_t + mseqs_genome[gi] = [ + mseqs_genome[gi][si] + for si in range(len(mseqs_genome[gi])) + if mseqs_map_mask[si] + ] + mseqs_unmap = mseqs_unmap[mseqs_map_mask, :] + + # write to file + unmap_npy_file = "%s/mseqs_unmap.npy" % genome_out_dirs[gi] + np.save(unmap_npy_file, mseqs_unmap) + + seqs_bed_files = [] + for gi in range(num_genomes): + # write sequences to BED + seqs_bed_files.append("%s/sequences.bed" % genome_out_dirs[gi]) + data.write_seqs_bed(seqs_bed_files[gi], mseqs_genome[gi], True) + + +################################################################################ +GraphSeq = collections.namedtuple("GraphSeq", ["genome", "net", "chr", "start", "end"]) + + +################################################################################ +def quantify_leakage( + align_net_file, train_contigs, valid_contigs, test_contigs, out_dir +): + """Quanitfy the leakage across sequence sets.""" + + def split_genome(contigs): + genome_contigs = [] + for ctg in contigs: + while len(genome_contigs) <= ctg.genome: + genome_contigs.append([]) + genome_contigs[ctg.genome].append((ctg.chr, ctg.start, ctg.end)) + genome_bedtools = [pybedtools.BedTool(ctgs) for ctgs in genome_contigs] + return genome_bedtools + + def bed_sum(overlaps): + osum = 0 + for overlap in overlaps: + osum += int(overlap[2]) - int(overlap[1]) + return osum + + train0_bt, train1_bt = split_genome(train_contigs) + valid0_bt, valid1_bt = split_genome(valid_contigs) + test0_bt, test1_bt = split_genome(test_contigs) + + assign0_sums = {} + assign1_sums = {} + + if os.path.splitext(align_net_file)[-1] == ".gz": + align_net_open = gzip.open(align_net_file, "rt") + else: + align_net_open = open(align_net_file, "r") + + for net_line in align_net_open: + if 
net_line.startswith("net"): + net_a = net_line.split() + chrom0 = net_a[1] + + elif net_line.startswith(" fill"): + net_a = net_line.split() + + # extract genome1 interval + start0 = int(net_a[1]) + size0 = int(net_a[2]) + end0 = start0 + size0 + align0_bt = pybedtools.BedTool([(chrom0, start0, end0)]) + + # extract genome2 interval + chrom1 = net_a[3] + start1 = int(net_a[5]) + size1 = int(net_a[6]) + end1 = start1 + size1 + align1_bt = pybedtools.BedTool([(chrom1, start1, end1)]) + + # count interval overlap + align0_train_bp = bed_sum(align0_bt.intersect(train0_bt)) + align0_valid_bp = bed_sum(align0_bt.intersect(valid0_bt)) + align0_test_bp = bed_sum(align0_bt.intersect(test0_bt)) + align0_max_bp = max(align0_train_bp, align0_valid_bp, align0_test_bp) + + align1_train_bp = bed_sum(align1_bt.intersect(train1_bt)) + align1_valid_bp = bed_sum(align1_bt.intersect(valid1_bt)) + align1_test_bp = bed_sum(align1_bt.intersect(test1_bt)) + align1_max_bp = max(align1_train_bp, align1_valid_bp, align1_test_bp) + + # assign to class + if align0_max_bp == 0: + assign0 = None + elif align0_train_bp == align0_max_bp: + assign0 = "train" + elif align0_valid_bp == align0_max_bp: + assign0 = "valid" + elif align0_test_bp == align0_max_bp: + assign0 = "test" + else: + print("Bad logic") + exit(1) + + if align1_max_bp == 0: + assign1 = None + elif align1_train_bp == align1_max_bp: + assign1 = "train" + elif align1_valid_bp == align1_max_bp: + assign1 = "valid" + elif align1_test_bp == align1_max_bp: + assign1 = "test" + else: + print("Bad logic") + exit(1) + + # increment + assign0_sums[(assign0, assign1)] = ( + assign0_sums.get((assign0, assign1), 0) + align0_max_bp + ) + assign1_sums[(assign0, assign1)] = ( + assign1_sums.get((assign0, assign1), 0) + align1_max_bp + ) + + # sum contigs + splits0_bp = {} + splits0_bp["train"] = bed_sum(train0_bt) + splits0_bp["valid"] = bed_sum(valid0_bt) + splits0_bp["test"] = bed_sum(test0_bt) + splits1_bp = {} + splits1_bp["train"] = bed_sum(train1_bt) + splits1_bp["valid"] = bed_sum(valid1_bt) + splits1_bp["test"] = bed_sum(test1_bt) + + leakage_out = open("%s/leakage.txt" % out_dir, "w") + print("Genome0", file=leakage_out) + for split0 in ["train", "valid", "test"]: + print(" %5s: %10d nt" % (split0, splits0_bp[split0]), file=leakage_out) + for split1 in ["train", "valid", "test", None]: + ss_bp = assign0_sums.get((split0, split1), 0) + print( + " %5s: %10d (%.5f)" % (split1, ss_bp, ss_bp / splits0_bp[split0]), + file=leakage_out, + ) + print("\nGenome1", file=leakage_out) + for split1 in ["train", "valid", "test"]: + print(" %5s: %10d nt" % (split1, splits1_bp[split1]), file=leakage_out) + for split0 in ["train", "valid", "test", None]: + ss_bp = assign1_sums.get((split0, split1), 0) + print( + " %5s: %10d (%.5f)" % (split0, ss_bp, ss_bp / splits1_bp[split1]), + file=leakage_out, + ) + leakage_out.close() + + +################################################################################ +def break_large_contigs(contigs, break_t, verbose=False): + """Break large contigs in half until all contigs are under + the size threshold.""" + + # initialize a heapq of contigs and lengths + contig_heapq = [] + for ctg in contigs: + ctg_len = ctg.end - ctg.start + heapq.heappush(contig_heapq, (-ctg_len, ctg)) + + ctg_len = break_t + 1 + while ctg_len > break_t: + + # pop largest contig + ctg_nlen, ctg = heapq.heappop(contig_heapq) + ctg_len = -ctg_nlen + + # if too large + if ctg_len > break_t: + if verbose: + print( + "Breaking %s:%d-%d (%d nt)" % (ctg.chr, ctg.start, 
ctg.end, ctg_len) + ) + + # break in two + ctg_mid = ctg.start + ctg_len // 2 + ctg_left = data.Contig(ctg.genome, ctg.chr, ctg.start, ctg_mid) + ctg_right = data.Contig(ctg.genome, ctg.chr, ctg_mid, ctg.end) + + # add left + ctg_left_len = ctg_left.end - ctg_left.start + heapq.heappush(contig_heapq, (-ctg_left_len, ctg_left)) + + # add right + ctg_right_len = ctg_right.end - ctg_right.start + heapq.heappush(contig_heapq, (-ctg_right_len, ctg_right)) + + # return to list + contigs = [len_ctg[1] for len_ctg in contig_heapq] + + return contigs + + +################################################################################ +def connect_contigs( + contigs, align_net_file, net_fill_min, net_olap_min, out_dir, genome_out_dirs +): + """Connect contigs across genomes by forming a graph that includes + net format aligning regions and contigs. Compute contig components + as connected components of that graph.""" + + # construct align net graph and write net BEDs + if align_net_file is None: + graph_contigs_nets = nx.Graph() + else: + graph_contigs_nets = make_net_graph(align_net_file, net_fill_min, out_dir) + + # add contig nodes + for ctg in contigs: + ctg_node = GraphSeq(ctg.genome, False, ctg.chr, ctg.start, ctg.end) + graph_contigs_nets.add_node(ctg_node) + + # intersect contigs BED w/ nets BED, adding graph edges. + intersect_contigs_nets( + graph_contigs_nets, 0, out_dir, genome_out_dirs[0], net_olap_min + ) + intersect_contigs_nets( + graph_contigs_nets, 1, out_dir, genome_out_dirs[1], net_olap_min + ) + + # find connected components + contig_components = [] + for contig_net_component in nx.connected_components(graph_contigs_nets): + # extract only the contigs + cc_contigs = [ + contig_or_net + for contig_or_net in contig_net_component + if contig_or_net.net is False + ] + + if cc_contigs: + # add to list + contig_components.append(cc_contigs) + + # write summary stats + comp_out = open("%s/contig_components.txt" % out_dir, "w") + for ctg_comp in contig_components: + ctg_comp0 = [ctg for ctg in ctg_comp if ctg.genome == 0] + ctg_comp1 = [ctg for ctg in ctg_comp if ctg.genome == 1] + ctg_comp0_nt = sum([ctg.end - ctg.start for ctg in ctg_comp0]) + ctg_comp1_nt = sum([ctg.end - ctg.start for ctg in ctg_comp1]) + ctg_comp_nt = ctg_comp0_nt + ctg_comp1_nt + cols = [len(ctg_comp), len(ctg_comp0), len(ctg_comp1)] + cols += [ctg_comp0_nt, ctg_comp1_nt, ctg_comp_nt] + cols = [str(c) for c in cols] + print("\t".join(cols), file=comp_out) + comp_out.close() + + return contig_components + + +################################################################################ +def contig_stats_genome(contigs): + """Compute contig statistics within each genome.""" + contigs_count_genome = [] + contigs_nt_genome = [] + + contigs_genome_found = True + gi = 0 + while contigs_genome_found: + contigs_genome = [ctg for ctg in contigs if ctg.genome == gi] + + if len(contigs_genome) == 0: + contigs_genome_found = False + + else: + contigs_nt = [ctg.end - ctg.start for ctg in contigs_genome] + + contigs_count_genome.append(len(contigs_genome)) + contigs_nt_genome.append(sum(contigs_nt)) + + gi += 1 + + return contigs_count_genome, contigs_nt_genome + + +################################################################################ +def divide_components_folds(contig_components, folds): + """Divide contig connected components into cross fold lists.""" + + # sort contig components descending by length + length_contig_components = [] + for cc_contigs in contig_components: + cc_len = sum([ctg.end - ctg.start for 
ctg in cc_contigs])
+        length_contig_components.append((cc_len, cc_contigs))
+    length_contig_components.sort(reverse=True)
+
+    # compute total nucleotides
+    total_nt = sum([lc[0] for lc in length_contig_components])
+
+    # compute aimed fold nucleotides
+    fold_nt_aim = int(np.ceil(total_nt / folds))
+
+    # initialize current fold nucleotides
+    fold_nt = np.zeros(folds)
+
+    # initialize fold contig lists
+    fold_contigs = []
+    for fi in range(folds):
+        fold_contigs.append([])
+
+    # process contigs
+    for ctg_comp_len, ctg_comp in length_contig_components:
+        # compute gap between current and aim
+        fold_nt_gap = fold_nt_aim - fold_nt
+        fold_nt_gap = np.clip(fold_nt_gap, 0, np.inf)
+
+        # compute sample probability
+        fold_prob = fold_nt_gap / fold_nt_gap.sum()
+
+        # sample train/valid/test
+        fi = np.random.choice(folds, p=fold_prob)
+        fold_nt[fi] += ctg_comp_len
+        for ctg in ctg_comp:
+            fold_contigs[fi].append(ctg)
+
+    # report genome-specific train/valid/test stats
+    report_divide_stats(fold_contigs)
+
+    return fold_contigs
+
+
+################################################################################
+def divide_components_pct(contig_components, test_pct, valid_pct, pct_abstain=0.5):
+    """Divide contig connected components into train/valid/test,
+    aiming for the specified nucleotide percentages."""
+
+    # sort contig components descending by length
+    length_contig_components = []
+    for cc_contigs in contig_components:
+        cc_len = sum([ctg.end - ctg.start for ctg in cc_contigs])
+        length_contig_components.append((cc_len, cc_contigs))
+    length_contig_components.sort(reverse=True)
+
+    # compute total nucleotides
+    total_nt = sum([lc[0] for lc in length_contig_components])
+
+    # compute aimed train/valid/test nucleotides
+    test_nt_aim = test_pct * total_nt
+    valid_nt_aim = valid_pct * total_nt
+    train_nt_aim = total_nt - valid_nt_aim - test_nt_aim
+
+    # initialize current train/valid/test nucleotides
+    train_nt = 0
+    valid_nt = 0
+    test_nt = 0
+
+    # initialize train/valid/test contig lists
+    train_contigs = []
+    valid_contigs = []
+    test_contigs = []
+
+    # process contigs
+    for ctg_comp_len, ctg_comp in length_contig_components:
+        # compute gap between current and aim
+        test_nt_gap = max(0, test_nt_aim - test_nt)
+        valid_nt_gap = max(0, valid_nt_aim - valid_nt)
+        train_nt_gap = max(1, train_nt_aim - train_nt)
+
+        # skip if too large
+        if ctg_comp_len > pct_abstain * test_nt_gap:
+            test_nt_gap = 0
+        if ctg_comp_len > pct_abstain * valid_nt_gap:
+            valid_nt_gap = 0
+
+        # compute remaining %
+        gap_sum = train_nt_gap + valid_nt_gap + test_nt_gap
+        test_pct_gap = test_nt_gap / gap_sum
+        valid_pct_gap = valid_nt_gap / gap_sum
+        train_pct_gap = train_nt_gap / gap_sum
+
+        # sample train/valid/test
+        ri = np.random.choice(
+            range(3), 1, p=[train_pct_gap, valid_pct_gap, test_pct_gap]
+        )[0]
+
+        # collect contigs (sorted is required for deterministic sequence order)
+        if ri == 0:
+            for ctg in sorted(ctg_comp):
+                train_contigs.append(ctg)
+            train_nt += ctg_comp_len
+        elif ri == 1:
+            for ctg in sorted(ctg_comp):
+                valid_contigs.append(ctg)
+            valid_nt += ctg_comp_len
+        elif ri == 2:
+            for ctg in sorted(ctg_comp):
+                test_contigs.append(ctg)
+            test_nt += ctg_comp_len
+        else:
+            print("TVT random number beyond 0,1,2", file=sys.stderr)
+            exit(1)
+
+    # report genome-specific train/valid/test stats
+    report_divide_stats([train_contigs, valid_contigs, test_contigs])
+
+    return [train_contigs, valid_contigs, test_contigs]
+
+
+################################################################################
+def 
intersect_contigs_nets( + graph_contigs_nets, genome_i, out_dir, genome_out_dir, min_olap=128 +): + """Intersect the contigs and nets from genome_i, adding the + overlaps as edges to graph_contigs_nets.""" + + contigs_file = "%s/contigs.bed" % genome_out_dir + nets_file = "%s/nets%d.bed" % (out_dir, genome_i) + + contigs_bed = pybedtools.BedTool(contigs_file) + nets_bed = pybedtools.BedTool(nets_file) + + for overlap in contigs_bed.intersect(nets_bed, wo=True): + ctg_chr = overlap[0] + ctg_start = int(overlap[1]) + ctg_end = int(overlap[2]) + net_chr = overlap[3] + net_start = int(overlap[4]) + net_end = int(overlap[5]) + olap_len = int(overlap[6]) + + if olap_len > min_olap: + # create node objects + ctg_node = GraphSeq(genome_i, False, ctg_chr, ctg_start, ctg_end) + net_node = GraphSeq(genome_i, True, net_chr, net_start, net_end) + + # add edge / verify we found nodes + gcn_size_pre = graph_contigs_nets.number_of_nodes() + graph_contigs_nets.add_edge(ctg_node, net_node) + gcn_size_post = graph_contigs_nets.number_of_nodes() + assert gcn_size_pre == gcn_size_post + + +################################################################################ +def make_net_graph(align_net_file, net_fill_min, out_dir): + """Construct a Graph with aligned net intervals connected + by edges.""" + + graph_nets = nx.Graph() + + nets1_bed_out = open("%s/nets0.bed" % out_dir, "w") + nets2_bed_out = open("%s/nets1.bed" % out_dir, "w") + + if os.path.splitext(align_net_file)[-1] == ".gz": + align_net_open = gzip.open(align_net_file, "rt") + else: + align_net_open = open(align_net_file, "r") + + for net_line in align_net_open: + if net_line.startswith("net"): + net_a = net_line.split() + chrom1 = net_a[1] + + elif net_line.startswith(" fill"): + net_a = net_line.split() + + # extract genome1 interval + start1 = int(net_a[1]) + size1 = int(net_a[2]) + end1 = start1 + size1 + + # extract genome2 interval + chrom2 = net_a[3] + start2 = int(net_a[5]) + size2 = int(net_a[6]) + end2 = start2 + size2 + + if min(size1, size2) >= net_fill_min: + # add edge + net1_node = GraphSeq(0, True, chrom1, start1, end1) + net2_node = GraphSeq(1, True, chrom2, start2, end2) + graph_nets.add_edge(net1_node, net2_node) + + # write interval1 + cols = [chrom1, str(start1), str(end1)] + print("\t".join(cols), file=nets1_bed_out) + + # write interval2 + cols = [chrom2, str(start2), str(end2)] + print("\t".join(cols), file=nets2_bed_out) + + nets1_bed_out.close() + nets2_bed_out.close() + + return graph_nets + + +################################################################################ +def report_divide_stats(fold_contigs): + """Report genome-specific statistics about the division of contigs + between sets.""" + + fold_counts_genome = [] + fold_nts_genome = [] + for fi in range(len(fold_contigs)): + fcg, fng = contig_stats_genome(fold_contigs[fi]) + fold_counts_genome.append(fcg) + fold_nts_genome.append(fng) + num_genomes = len(fold_counts_genome[0]) + + # sum nt across genomes + fold_nts = [sum(fng) for fng in fold_nts_genome] + total_nt = sum(fold_nts) + + # compute total sum nt per genome + total_nt_genome = [] + print("Total nt") + for gi in range(num_genomes): + total_nt_gi = sum([fng[gi] for fng in fold_nts_genome]) + total_nt_genome.append(total_nt_gi) + print(" Genome%d: %10d nt" % (gi, total_nt_gi)) + + # label folds and guess that 3 is train/valid/test + fold_labels = [] + if len(fold_contigs) == 3: + fold_labels = ["Train", "Valid", "Test"] + else: + fold_labels = ["Fold%d" % fi for fi in range(len(fold_contigs))] + + 
print("Contigs divided into") + for fi in range(len(fold_contigs)): + print( + " %s: %5d contigs, %10d nt (%.4f)" + % ( + fold_labels[fi], + len(fold_contigs[fi]), + fold_nts[fi], + fold_nts[fi] / total_nt, + ) + ) + for gi in range(num_genomes): + print( + " Genome%d: %5d contigs, %10d nt (%.4f)" + % ( + gi, + fold_counts_genome[fi][gi], + fold_nts_genome[fi][gi], + fold_nts_genome[fi][gi] / total_nt_genome[gi], + ) + ) + + +################################################################################ +def report_divide_stats_v1(train_contigs, valid_contigs, test_contigs): + """Report genome-specific statistics about the division of contigs + between train/valid/test sets.""" + + # compute genome-specific stats + train_count_genome, train_nt_genome = contig_stats_genome(train_contigs) + valid_count_genome, valid_nt_genome = contig_stats_genome(valid_contigs) + test_count_genome, test_nt_genome = contig_stats_genome(test_contigs) + num_genomes = len(train_count_genome) + + # sum nt across genomes + train_nt = sum(train_nt_genome) + valid_nt = sum(valid_nt_genome) + test_nt = sum(test_nt_genome) + total_nt = train_nt + valid_nt + test_nt + + # compute total sum nt per genome + total_nt_genome = [] + for gi in range(num_genomes): + total_nt_gi = train_nt_genome[gi] + valid_nt_genome[gi] + test_nt_genome[gi] + total_nt_genome.append(total_nt_gi) + + print("Contigs divided into") + print( + " Train: %5d contigs, %10d nt (%.4f)" + % (len(train_contigs), train_nt, train_nt / total_nt) + ) + for gi in range(num_genomes): + print( + " Genome%d: %5d contigs, %10d nt (%.4f)" + % ( + gi, + train_count_genome[gi], + train_nt_genome[gi], + train_nt_genome[gi] / total_nt_genome[gi], + ) + ) + + print( + " Valid: %5d contigs, %10d nt (%.4f)" + % (len(valid_contigs), valid_nt, valid_nt / total_nt) + ) + for gi in range(num_genomes): + print( + " Genome%d: %5d contigs, %10d nt (%.4f)" + % ( + gi, + valid_count_genome[gi], + valid_nt_genome[gi], + valid_nt_genome[gi] / total_nt_genome[gi], + ) + ) + + print( + " Test: %5d contigs, %10d nt (%.4f)" + % (len(test_contigs), test_nt, test_nt / total_nt) + ) + for gi in range(num_genomes): + print( + " Genome%d: %5d contigs, %10d nt (%.4f)" + % ( + gi, + test_count_genome[gi], + test_nt_genome[gi], + test_nt_genome[gi] / total_nt_genome[gi], + ) + ) + + +################################################################################ +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_data_read.py b/src/baskerville/scripts/hound_data_read.py new file mode 100755 index 0000000..5b6ec35 --- /dev/null +++ b/src/baskerville/scripts/hound_data_read.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python +# Copyright 2017 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =========================================================================
+from optparse import OptionParser
+
+import os
+import sys
+
+import h5py
+import intervaltree
+import numpy as np
+import pandas as pd
+import scipy.interpolate
+
+import pyBigWig
+
+from baskerville import data
+
+"""
+hound_data_read.py
+
+Read sequence values from coverage files.
+"""
+
+
+################################################################################
+# main
+################################################################################
+def main():
+    usage = "usage: %prog [options] <genome_cov_file> <seqs_bed_file> <seqs_cov_file>"
+    parser = OptionParser(usage)
+    parser.add_option(
+        "-b",
+        dest="blacklist_bed",
+        help="Set blacklist nucleotides to a baseline value.",
+    )
+    parser.add_option(
+        "--black_pct",
+        dest="blacklist_pct",
+        default=0.5,
+        type="float",
+        help="Clip blacklisted regions to this distribution value [Default: %default]",
+    )
+    parser.add_option(
+        "-c",
+        dest="clip",
+        default=None,
+        type="float",
+        help="Clip values post-summary to a maximum [Default: %default]",
+    )
+    parser.add_option(
+        "--clip_soft",
+        dest="clip_soft",
+        default=None,
+        type="float",
+        help="Soft clip values, applying sqrt to the excess above the threshold [Default: %default]",
+    )
+    parser.add_option(
+        "--clip_pct",
+        dest="clip_pct",
+        default=0.9999999,
+        type="float",
+        help="Clip extreme values to this distribution value [Default: %default]",
+    )
+    parser.add_option(
+        "--crop",
+        dest="crop_bp",
+        default=0,
+        type="int",
+        help="Crop bp off each end [Default: %default]",
+    )
+    parser.add_option(
+        "-i",
+        dest="interp_nan",
+        default=False,
+        action="store_true",
+        help="Interpolate NaNs [Default: %default]",
+    )
+    parser.add_option(
+        "-s",
+        dest="scale",
+        default=1.0,
+        type="float",
+        help="Scale values by [Default: %default]",
+    )
+    parser.add_option(
+        "-u",
+        dest="sum_stat",
+        default="sum",
+        help="Summary statistic to compute in windows [Default: %default]",
+    )
+    parser.add_option(
+        "-w",
+        dest="pool_width",
+        default=1,
+        type="int",
+        help="Average pooling width [Default: %default]",
+    )
+    (options, args) = parser.parse_args()
+
+    if len(args) != 3:
+        parser.error("Must provide genome coverage file, sequences BED file, and output HDF5.")
+    else:
+        genome_cov_file = args[0]
+        seqs_bed_file = args[1]
+        seqs_cov_file = args[2]
+
+    assert options.crop_bp >= 0
+
+    # read model sequences
+    model_seqs = []
+    for line in open(seqs_bed_file):
+        a = line.split()
+        model_seqs.append(data.ModelSeq(0, a[0], int(a[1]), int(a[2]), None))
+
+    # read blacklist regions
+    black_chr_trees = read_blacklist(options.blacklist_bed)
+
+    # compute dimensions
+    num_seqs = len(model_seqs)
+    seq_len_nt = model_seqs[0].end - model_seqs[0].start
+    seq_len_nt -= 2 * options.crop_bp
+    target_length = seq_len_nt // options.pool_width
+    assert target_length > 0
+
+    # collect targets
+    targets = []
+
+    # open genome coverage file
+    genome_cov_open = CovFace(genome_cov_file)
+
+    # for each model sequence
+    for si in range(num_seqs):
+        mseq = model_seqs[si]
+
+        # read coverage
+        seq_cov_nt = genome_cov_open.read(mseq.chr, mseq.start, mseq.end)
+        seq_cov_nt = seq_cov_nt.astype("float32")
+
+        # interpolate NaN
+        if options.interp_nan:
+            seq_cov_nt = interp_nan(seq_cov_nt)
+
+        # determine baseline coverage
+        if target_length >= 8:
+            baseline_cov = np.percentile(seq_cov_nt, 100 * options.blacklist_pct)
+            baseline_cov = np.nan_to_num(baseline_cov)
+        else:
+            baseline_cov = 0
+
+        # set blacklist to baseline
+        if mseq.chr in black_chr_trees:
+            for black_interval in black_chr_trees[mseq.chr][mseq.start : mseq.end]:
+                # adjust for 
sequence indexes + black_seq_start = black_interval.begin - mseq.start + black_seq_end = black_interval.end - mseq.start + black_seq_values = seq_cov_nt[black_seq_start:black_seq_end] + seq_cov_nt[black_seq_start:black_seq_end] = np.clip( + black_seq_values, -baseline_cov, baseline_cov + ) + # seq_cov_nt[black_seq_start:black_seq_end] = baseline_cov + + # set NaN's to baseline + if not options.interp_nan: + nan_mask = np.isnan(seq_cov_nt) + seq_cov_nt[nan_mask] = baseline_cov + + # crop + if options.crop_bp > 0: + seq_cov_nt = seq_cov_nt[options.crop_bp : -options.crop_bp] + + # scale + seq_cov_nt = options.scale * seq_cov_nt + + # sum pool + seq_cov = seq_cov_nt.reshape(target_length, options.pool_width) + if options.sum_stat == "sum": + seq_cov = seq_cov.sum(axis=1, dtype="float32") + elif options.sum_stat == "sum_sqrt": + seq_cov = seq_cov.sum(axis=1, dtype="float32") + seq_cov = -1 + np.sqrt(1 + seq_cov) + elif options.sum_stat == "sum_exp75": + seq_cov = seq_cov.sum(axis=1, dtype="float32") + seq_cov = -1 + (1 + seq_cov) ** 0.75 + elif options.sum_stat in ["mean", "avg"]: + seq_cov = seq_cov.mean(axis=1, dtype="float32") + elif options.sum_stat in ["mean_sqrt", "avg_sqrt"]: + seq_cov = seq_cov.mean(axis=1, dtype="float32") + seq_cov = -1 + np.sqrt(1 + seq_cov) + elif options.sum_stat == "median": + seq_cov = seq_cov.median(axis=1) + elif options.sum_stat == "max": + seq_cov = seq_cov.max(axis=1) + elif options.sum_stat == "peak": + seq_cov = seq_cov.mean(axis=1, dtype="float32") + seq_cov = np.clip(np.sqrt(seq_cov * 4), 0, 1) + else: + print( + 'ERROR: Unrecognized summary statistic "%s".' % options.sum_stat, + file=sys.stderr, + ) + exit(1) + + # clip + if options.clip_soft is not None: + clip_mask = seq_cov > options.clip_soft + seq_cov[clip_mask] = ( + options.clip_soft + - 1 + + np.sqrt(seq_cov[clip_mask] - options.clip_soft + 1) + ) + if options.clip is not None: + seq_cov = np.clip(seq_cov, -options.clip, options.clip) + + # clip float16 min/max + seq_cov = np.clip(seq_cov, np.finfo(np.float16).min, np.finfo(np.float16).max) + + # save + targets.append(seq_cov.astype("float16")) + + # close genome coverage file + genome_cov_open.close() + + # clip extreme values + targets = np.array(targets, dtype="float16") + extreme_clip = np.percentile(targets, 100 * options.clip_pct) + targets = np.clip(targets, -extreme_clip, extreme_clip) + print("Targets sum: %.3f" % targets.sum(dtype="float64")) + + # write + with h5py.File(seqs_cov_file, "w") as seqs_cov_open: + seqs_cov_open.create_dataset( + "targets", data=targets, dtype="float16", compression="gzip" + ) + + +def interp_nan(x, kind="linear"): + """Linearly interpolate to fill NaN.""" + + # pad zeroes + xp = np.zeros(len(x) + 2) + xp[1:-1] = x + + # find NaN + x_nan = np.isnan(xp) + + if np.sum(x_nan) == 0: + # unnecessary + return x + + else: + # interpolate + inds = np.arange(len(xp)) + interpolator = scipy.interpolate.interp1d( + inds[~x_nan], xp[~x_nan], kind=kind, bounds_error=False + ) + + loc = np.where(x_nan) + xp[loc] = interpolator(loc) + + # slice off pad + return xp[1:-1] + + +def read_blacklist(blacklist_bed, black_buffer=20): + """Construct interval trees of blacklist + regions for each chromosome.""" + black_chr_trees = {} + + if blacklist_bed is not None and os.path.isfile(blacklist_bed): + for line in open(blacklist_bed): + a = line.split() + chrm = a[0] + start = max(0, int(a[1]) - black_buffer) + end = int(a[2]) + black_buffer + + if chrm not in black_chr_trees: + black_chr_trees[chrm] = intervaltree.IntervalTree() 
+ + black_chr_trees[chrm][start:end] = True + + return black_chr_trees + + +class CovFace: + def __init__(self, cov_file): + self.cov_file = cov_file + self.bigwig = False + self.bed = False + + cov_ext = os.path.splitext(self.cov_file)[1].lower() + if cov_ext == ".gz": + cov_ext = os.path.splitext(self.cov_file[:-3])[1].lower() + + if cov_ext in [".bed", ".narrowpeak"]: + self.bed = True + self.preprocess_bed() + + elif cov_ext in [".bw", ".bigwig"]: + self.cov_open = pyBigWig.open(self.cov_file, "r") + self.bigwig = True + + elif cov_ext in [".h5", ".hdf5", ".w5", ".wdf5"]: + self.cov_open = h5py.File(self.cov_file, "r") + + else: + print( + 'Cannot identify coverage file extension "%s".' % cov_ext, + file=sys.stderr, + ) + exit(1) + + def preprocess_bed(self): + # read BED + bed_df = pd.read_csv( + self.cov_file, sep="\t", usecols=range(3), names=["chr", "start", "end"] + ) + + # for each chromosome + self.cov_open = {} + for chrm in bed_df.chr.unique(): + bed_chr_df = bed_df[bed_df.chr == chrm] + + # find max pos + pos_max = bed_chr_df.end.max() + + # initialize array + self.cov_open[chrm] = np.zeros(pos_max, dtype="bool") + + # set peaks + for peak in bed_chr_df.itertuples(): + self.cov_open[peak.chr][peak.start : peak.end] = 1 + + def read(self, chrm, start, end): + if self.bigwig: + cov = self.cov_open.values(chrm, start, end, numpy=True).astype("float16") + + else: + if chrm in self.cov_open: + cov = self.cov_open[chrm][start:end] + + # handle mysterious inf's + cov = np.clip(cov, np.finfo(np.float16).min, np.finfo(np.float16).max) + + # pad + pad_zeros = end - start - len(cov) + if pad_zeros > 0: + cov_pad = np.zeros(pad_zeros, dtype="bool") + cov = np.concatenate([cov, cov_pad]) + + else: + print( + "WARNING: %s doesn't see %s:%d-%d. Setting to all zeros." + % (self.cov_file, chrm, start, end), + file=sys.stderr, + ) + cov = np.zeros(end - start, dtype="float16") + + return cov + + def close(self): + if not self.bed: + self.cov_open.close() + + +################################################################################ +# __main__ +################################################################################ +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_data_write.py b/src/baskerville/scripts/hound_data_write.py new file mode 100755 index 0000000..bf55065 --- /dev/null +++ b/src/baskerville/scripts/hound_data_write.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python +# Copyright 2019 Calico LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser +import os +import sys + +import h5py +import numpy as np +import pdb +import pysam +import tensorflow as tf + +from baskerville import data +from baskerville import dna + +""" +basenji_data_write.py + +Write TF Records for batches of model sequences. 
+""" + + +################################################################################ +# main +################################################################################ +def main(): + usage = ( + "usage: %prog [options] " + ) + parser = OptionParser(usage) + parser.add_option( + "-d", + dest="decimals", + default=None, + type="int", + help="Round values to given decimals [Default: %default]", + ) + parser.add_option( + "-s", + dest="start_i", + default=0, + type="int", + help="Sequence start index [Default: %default]", + ) + parser.add_option( + "-e", + dest="end_i", + default=None, + type="int", + help="Sequence end index [Default: %default]", + ) + parser.add_option( + "--te", + dest="target_extend", + default=None, + type="int", + help="Extend targets vector [Default: %default]", + ) + parser.add_option("-u", dest="umap_npy", help="Unmappable array numpy file") + parser.add_option( + "--umap_clip", + dest="umap_clip", + default=1, + type="float", + help="Clip values at unmappable positions to distribution quantiles, eg 0.25. [Default: %default]", + ) + parser.add_option( + "--umap_tfr", + dest="umap_tfr", + default=False, + action="store_true", + help="Save umap array into TFRecords [Default: %default]", + ) + parser.add_option( + "-x", + dest="extend_bp", + default=0, + type="int", + help="Extend sequences on each side [Default: %default]", + ) + (options, args) = parser.parse_args() + + if len(args) != 4: + parser.error("Must provide input arguments.") + else: + fasta_file = args[0] + seqs_bed_file = args[1] + seqs_cov_dir = args[2] + tfr_file = args[3] + + ################################################################ + # read model sequences + + model_seqs = [] + for line in open(seqs_bed_file): + a = line.split() + model_seqs.append(data.ModelSeq(0, a[0], int(a[1]), int(a[2]), None)) + + if options.end_i is None: + options.end_i = len(model_seqs) + + num_seqs = options.end_i - options.start_i + + ################################################################ + # determine sequence coverage files + + seqs_cov_files = [] + ti = 0 + seqs_cov_file = "%s/%d.h5" % (seqs_cov_dir, ti) + while os.path.isfile(seqs_cov_file): + seqs_cov_files.append(seqs_cov_file) + ti += 1 + seqs_cov_file = "%s/%d.h5" % (seqs_cov_dir, ti) + + if len(seqs_cov_files) == 0: + print( + "Sequence coverage files not found, e.g. 
%s" % seqs_cov_file, + file=sys.stderr, + ) + exit(1) + + seq_pool_len = h5py.File(seqs_cov_files[0], "r")["targets"].shape[1] + num_targets = len(seqs_cov_files) + + ################################################################ + # read targets + + # initialize targets + targets = np.zeros((num_seqs, seq_pool_len, num_targets), dtype="float16") + + # read each target + for ti in range(num_targets): + seqs_cov_open = h5py.File(seqs_cov_files[ti], "r") + targets[:, :, ti] = seqs_cov_open["targets"][options.start_i : options.end_i, :] + seqs_cov_open.close() + + ################################################################ + # modify unmappable + + if options.umap_npy is not None and options.umap_clip < 1: + unmap_mask = np.load(options.umap_npy) + + for si in range(num_seqs): + msi = options.start_i + si + + # determine unmappable null value + seq_target_null = np.percentile( + targets[si], q=[100 * options.umap_clip], axis=0 + )[0] + + # set unmappable positions to null + targets[si, unmap_mask[msi, :], :] = np.minimum( + targets[si, unmap_mask[msi, :], :], seq_target_null + ) + + elif options.umap_npy is not None and options.umap_tfr: + unmap_mask = np.load(options.umap_npy) + + ################################################################ + # write TFRecords + + # open FASTA + fasta_open = pysam.Fastafile(fasta_file) + + # define options + tf_opts = tf.io.TFRecordOptions(compression_type="ZLIB") + + with tf.io.TFRecordWriter(tfr_file, tf_opts) as writer: + for si in range(num_seqs): + msi = options.start_i + si + mseq = model_seqs[msi] + mseq_start = mseq.start - options.extend_bp + mseq_end = mseq.end + options.extend_bp + + # read FASTA + seq_dna = fetch_dna(fasta_open, mseq.chr, mseq_start, mseq_end) + + # one hot code (N's as zero) + # seq_1hot = dna.dna_1hot(seq_dna, n_uniform=False, n_sample=False) + seq_1hot = dna.dna_1hot_index(seq_dna) # more efficient + + # truncate decimals (which aids compression) + if options.decimals is not None: + targets_si = targets[si].astype("float32") + targets_si = np.around(targets_si, decimals=options.decimals) + targets_si = targets_si.astype("float16") + # targets_si = rround(targets[si], decimals=options.decimals) + else: + targets_si = targets[si] + + assert np.isinf(targets_si).sum() == 0 + + # hash to bytes + features_dict = { + "sequence": feature_bytes(seq_1hot), + "target": feature_bytes(targets_si), + } + + # add unmappability + if options.umap_tfr: + features_dict["umap"] = feature_bytes(unmap_mask[msi, :]) + + # write example + example = tf.train.Example( + features=tf.train.Features(feature=features_dict) + ) + writer.write(example.SerializeToString()) + + fasta_open.close() + + +def tround(a, decimals): + """Truncate to the specified number of decimals.""" + return np.true_divide(np.floor(a * 10**decimals), 10**decimals) + + +def rround(a, decimals): + """Round to the specified number of decimals, randomly sampling + the last digit according to a bernoulli RV.""" + a_dtype = a.dtype + a = a.astype("float32") + dec_probs = (a - tround(a, decimals)) * 10**decimals + dec_bin = np.random.binomial(n=1, p=dec_probs) + a_dec = tround(a, decimals) + dec_bin / 10**decimals + return np.around(a_dec.astype(a_dtype), decimals) + + +def fetch_dna(fasta_open, chrm, start, end): + """Fetch DNA when start/end may reach beyond chromosomes.""" + + # initialize sequence + seq_len = end - start + seq_dna = "" + + # add N's for left over reach + if start < 0: + seq_dna = "N" * (-start) + start = 0 + + # get dna + seq_dna += fasta_open.fetch(chrm, 
start, end) + + # add N's for right over reach + if len(seq_dna) < seq_len: + seq_dna += "N" * (seq_len - len(seq_dna)) + + return seq_dna + + +def feature_bytes(values): + """Convert numpy arrays to bytes features.""" + values = values.flatten().tobytes() + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) + + +def feature_floats(values): + """Convert numpy arrays to floats features. + Requires more space than bytes for float16""" + values = values.flatten().tolist() + return tf.train.Feature(float_list=tf.train.FloatList(value=values)) + + +################################################################################ +# __main__ +################################################################################ +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_eval.py b/src/baskerville/scripts/hound_eval.py index b7fca0d..a82199c 100755 --- a/src/baskerville/scripts/hound_eval.py +++ b/src/baskerville/scripts/hound_eval.py @@ -101,7 +101,6 @@ def main(): parser.add_argument( "--split", default="test", - choices=["train", "valid", "test"], help="Dataset split label for eg TFR pattern [Default: %(default)s]", ) parser.add_argument( diff --git a/src/baskerville/scripts/hound_ism_bed.py b/src/baskerville/scripts/hound_ism_bed.py index 7449c18..7d4668e 100755 --- a/src/baskerville/scripts/hound_ism_bed.py +++ b/src/baskerville/scripts/hound_ism_bed.py @@ -276,7 +276,7 @@ def main(): alt_preds = np.array(alt_preds) ism_scores = snps.compute_scores( - ref_preds, alt_preds, options.snp_stats + ref_preds, alt_preds, options.snp_stats, None ) for snp_stat in options.snp_stats: scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ diff --git a/src/baskerville/scripts/hound_ism_snp.py b/src/baskerville/scripts/hound_ism_snp.py index 12b8ba5..0318765 100755 --- a/src/baskerville/scripts/hound_ism_snp.py +++ b/src/baskerville/scripts/hound_ism_snp.py @@ -156,6 +156,7 @@ def main(): else: targets_strand_df = targets_df strand_transform = None + num_targets = targets_strand_df.shape[0] ################################################################# @@ -249,7 +250,7 @@ def main(): alt_preds = np.array(alt_preds) ism_scores = snps.compute_scores( - ref_preds, alt_preds, options.snp_stats + ref_preds, alt_preds, options.snp_stats, None ) for snp_stat in options.snp_stats: scores_h5[snp_stat][si, mi - mut_start, ni] = ism_scores[ diff --git a/src/baskerville/scripts/hound_predbed.py b/src/baskerville/scripts/hound_predbed.py index 4112048..518b2aa 100755 --- a/src/baskerville/scripts/hound_predbed.py +++ b/src/baskerville/scripts/hound_predbed.py @@ -26,9 +26,10 @@ import tensorflow as tf from baskerville import bed +from baskerville import dataset from baskerville import dna from baskerville import seqnn -from baskerville import stream + """ hound_predbed.py @@ -120,35 +121,18 @@ def main(): default=None, help="File specifying target indexes and labels in table format", ) - + parser.add_argument( + "-u", + "--untransform_old", + default=False, + action="store_true", + help="Untransform old models [Default: %default]", + ) parser.add_argument("params_file", help="Parameters file") parser.add_argument("model_file", help="Model file") parser.add_argument("bed_file", help="BED file") args = parser.parse_args() - if len(args) == 3: - params_file = args[0] - model_file = args[1] - bed_file = args[2] - - elif len(args) == 5: - # multi worker - options_pkl_file = args[0] - params_file = args[1] - model_file = args[2] - bed_file = args[3] - worker_index = 
int(args[4]) - - # load options - options_pkl = open(options_pkl_file, "rb") - options = pickle.load(options_pkl) - options_pkl.close() - - # update output directory - args.out_dir = "%s/job%d" % (args.out_dir, worker_index) - else: - parser.error("Must provide parameter and model files and BED file") - os.makedirs(args.out_dir, exist_ok=True) args.shifts = [int(shift) for shift in args.shifts.split(",")] @@ -166,7 +150,7 @@ def main(): ################################################################# # read parameters and collet target information - with open(params_file) as params_open: + with open(args.params_file) as params_open: params = json.load(params_open) params_model = params["model"] @@ -181,7 +165,7 @@ def main(): # initialize model seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file, args.head) + seqnn_model.restore(args.model_file, args.head) seqnn_model.build_slice(target_slice) seqnn_model.build_ensemble(args.rc, args.shifts) @@ -205,27 +189,8 @@ def main(): # construct model sequences model_seqs_dna, model_seqs_coords = bed.make_bed_seqs( - bed_file, args.genome_fasta, params_model["seq_length"], stranded=False + args.bed_file, args.genome_fasta, params_model["seq_length"], stranded=False ) - - # construct site coordinates - site_seqs_coords = bed.read_bed_coords(bed_file, args.site_length) - - # filter for worker SNPs - if args.processes is not None: - worker_bounds = np.linspace( - 0, len(model_seqs_dna), args.processes + 1, dtype="int" - ) - model_seqs_dna = model_seqs_dna[ - worker_bounds[worker_index] : worker_bounds[worker_index + 1] - ] - model_seqs_coords = model_seqs_coords[ - worker_bounds[worker_index] : worker_bounds[worker_index + 1] - ] - site_seqs_coords = site_seqs_coords[ - worker_bounds[worker_index] : worker_bounds[worker_index + 1] - ] - num_seqs = len(model_seqs_dna) ################################################################# @@ -258,6 +223,7 @@ def main(): ) # store site coordinates + site_seqs_coords = bed.read_bed_coords(args.bed_file, args.site_length) site_seqs_chr, site_seqs_start, site_seqs_end = zip(*site_seqs_coords) site_seqs_chr = np.array(site_seqs_chr, dtype="S") site_seqs_start = np.array(site_seqs_start) @@ -268,7 +234,7 @@ def main(): ################################################################# # predict scores, write output - + """ # define sequence generator def seqs_gen(): for seq_dna in model_seqs_dna: @@ -278,18 +244,29 @@ def seqs_gen(): preds_stream = stream.PredStreamGen( seqnn_model, seqs_gen(), params["train"]["batch_size"] ) + """ + + for si, seq_dna in enumerate(model_seqs_dna): + seq_1hot = np.expand_dims(dna.dna_1hot(seq_dna), axis=0) + preds_seq = seqnn_model.predict(seq_1hot)[0] - for si in range(num_seqs): - preds_seq = preds_stream[si] + if args.untransform_old: + preds_seq = dataset.untransform_preds1(preds_seq, targets_df) + else: + preds_seq = dataset.untransform_preds(preds_seq, targets_df) # slice site preds_site = preds_seq[site_preds_start:site_preds_end, :] - # write + # optionally, sum if args.sum: - out_h5["preds"][si] = preds_site.sum(axis=0) - else: - out_h5["preds"][si] = preds_site + preds_site = np.sum(preds_site, axis=0) + + # clip to float16 max + preds_write = np.clip(preds_site, 0, np.finfo(np.float16).max) + + # write + out_h5["preds"][si] = preds_write # write bigwig for ti in args.bigwig_indexes: diff --git a/src/baskerville/scripts/hound_snp.py b/src/baskerville/scripts/hound_snp.py index 1ccad45..847cd16 100755 --- a/src/baskerville/scripts/hound_snp.py +++ 
b/src/baskerville/scripts/hound_snp.py @@ -49,6 +49,20 @@ def main(): default=None, help="Genome FASTA [Default: %default]", ) + parser.add_option( + "--float16", + dest="float16", + default=False, + action="store_true", + help="Use mixed float16 precision [Default: %default]", + ) + parser.add_option( + "--indel_stitch", + dest="indel_stitch", + default=False, + action="store_true", + help="Stitch indel compensation shifts [Default: %default]", + ) parser.add_option( "-o", dest="out_dir", @@ -79,7 +93,7 @@ def main(): parser.add_option( "--stats", dest="snp_stats", - default="logSAD", + default="logSUM", help="Comma-separated list of stats to save. [Default: %default]", ) parser.add_option( @@ -110,6 +124,13 @@ def main(): action="store_true", help="Only run on GPU", ) + parser.add_option( + "--tensorrt", + dest="tensorrt", + default=False, + action="store_true", + help="Model type is tensorrt optimized", + ) (options, args) = parser.parse_args() if options.gcs: @@ -117,10 +138,11 @@ def main(): gcs_output_dir = options.out_dir temp_dir = tempfile.mkdtemp() # create a temp dir for output out_dir = temp_dir + "/output_dir" - if not os.path.isdir(out_dir): - os.mkdir(out_dir) options.out_dir = out_dir + # is this here for GCS? + os.makedirs(options.out_dir, exist_ok=True) + if len(args) == 3: # single worker params_file = args[0] @@ -162,6 +184,14 @@ def main(): else: parser.error("Must provide parameters and model files and QTL VCF file") + # check if the model type is correct + if options.tensorrt: + if model_file.endswith(".h5"): + raise SystemExit("Model type is tensorrt but model file is keras") + is_dir_model = True + else: + is_dir_model = False + if not os.path.isdir(options.out_dir): os.mkdir(options.out_dir) @@ -173,11 +203,13 @@ def main(): ################################################################# # check if the program is run on GPU, else quit physical_devices = tf.config.list_physical_devices() - # Check if a GPU is available gpu_available = any(device.device_type == "GPU" for device in physical_devices) - if gpu_available: print("Running on GPU") + if options.float16: + print("Using mixed precision") + policy = tf.keras.mixed_precision.Policy("mixed_float16") + tf.keras.mixed_precision.set_global_policy(policy) else: print("Running on CPU") if options.require_gpu: @@ -188,7 +220,7 @@ def main(): if options.gcs: params_file = download_rename_inputs(params_file, temp_dir) vcf_file = download_rename_inputs(vcf_file, temp_dir) - model_file = download_rename_inputs(model_file, temp_dir) + model_file = download_rename_inputs(model_file, temp_dir, is_dir_model) if options.genome_fasta is not None: options.genome_fasta = download_rename_inputs( options.genome_fasta, temp_dir diff --git a/src/baskerville/scripts/hound_snp_slurm.py b/src/baskerville/scripts/hound_snp_slurm.py index e2be5fd..a980e17 100755 --- a/src/baskerville/scripts/hound_snp_slurm.py +++ b/src/baskerville/scripts/hound_snp_slurm.py @@ -55,6 +55,13 @@ def main(): default=None, help="Genome FASTA for sequences [Default: %default]", ) + parser.add_option( + "--float16", + dest="float16", + default=False, + action="store_true", + help="Use mixed float16 precision [Default: %default]", + ) parser.add_option( "-o", dest="out_dir", diff --git a/src/baskerville/scripts/hound_snpgene.py b/src/baskerville/scripts/hound_snpgene.py new file mode 100755 index 0000000..f3c4ffa --- /dev/null +++ b/src/baskerville/scripts/hound_snpgene.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python +# Copyright 2023 Calico LLC +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= +from optparse import OptionParser +import pdb +import os +import tempfile +import shutil +import tensorflow as tf + +from baskerville.snps import score_gene_snps +from baskerville.helpers.gcs_utils import ( + upload_folder_gcs, + download_rename_inputs, +) +from baskerville.helpers.utils import load_extra_options + +""" +hound_snpgene.py + +Compute variant effect predictions for SNPs in a VCF file, +with respect to gene exons in a GTF file +""" + + +def main(): + usage = "usage: %prog [options] " + parser = OptionParser(usage) + # parser.add_option( + # "-b", + # dest="bedgraph", + # default=False, + # action="store_true", + # help="Write ref/alt predictions as bedgraph [Default: %default]", + # ) + parser.add_option( + "-c", + dest="cluster_pct", + default=0, + type="float", + help="Cluster genes within a %% of the seq length to make a single ref pred [Default: %default]", + ) + parser.add_option( + "-f", + dest="genome_fasta", + default=None, + help="Genome FASTA [Default: %default]", + ) + parser.add_option( + "--float16", + dest="float16", + default=False, + action="store_true", + help="Use mixed float16 precision [Default: %default]", + ) + parser.add_option( + "-g", + dest="genes_gtf", + default="%s/genes/gencode41/gencode41_basic_nort.gtf" % os.environ["HG38"], + help="GTF for gene definition [Default %default]", + ) + parser.add_option( + "--indel_stitch", + dest="indel_stitch", + default=False, + action="store_true", + help="Stitch indel compensation shifts [Default: %default]", + ) + parser.add_option( + "-o", + dest="out_dir", + default="snpgene_out", + help="Output directory for tables and plots [Default: %default]", + ) + parser.add_option( + "-p", + dest="processes", + default=None, + type="int", + help="Number of processes, passed by multi script", + ) + parser.add_option( + "--rc", + dest="rc", + default=False, + action="store_true", + help="Average forward and reverse complement predictions [Default: %default]", + ) + parser.add_option( + "--shifts", + dest="shifts", + default="0", + type="str", + help="Ensemble prediction shifts [Default: %default]", + ) + parser.add_option( + "--span", + dest="span", + default=False, + action="store_true", + help="Aggregate entire gene span [Default: %default]", + ) + parser.add_option( + "--stats", + dest="snp_stats", + default="logSUM", + help="Comma-separated list of stats to save. 
[Default: %default]", + ) + parser.add_option( + "-t", + dest="targets_file", + default=None, + type="str", + help="File specifying target indexes and labels in table format", + ) + parser.add_option( + "-u", + dest="untransform_old", + default=False, + action="store_true", + help="Untransform old models [Default: %default]", + ) + parser.add_option( + "--gcs", + dest="gcs", + default=False, + action="store_true", + help="Input and output are in gcs", + ) + parser.add_option( + "--require_gpu", + dest="require_gpu", + default=False, + action="store_true", + help="Only run on GPU", + ) + parser.add_option( + "--tensorrt", + dest="tensorrt", + default=False, + action="store_true", + help="Model type is tensorrt optimized", + ) + (options, args) = parser.parse_args() + + if options.gcs: + """Assume that output_dir will be gcs""" + gcs_output_dir = options.out_dir + temp_dir = tempfile.mkdtemp() # create a temp dir for output + out_dir = temp_dir + "/output_dir" + options.out_dir = out_dir + + # is this here for GCS? + os.makedirs(options.out_dir, exist_ok=True) + + if len(args) == 3: + # single worker + params_file = args[0] + model_file = args[1] + vcf_file = args[2] + + elif len(args) == 4: + # multi separate + options_pkl_file = args[0] + params_file = args[1] + model_file = args[2] + vcf_file = args[3] + + # save out dir + out_dir = options.out_dir + + # load options + if options.gcs: + options_pkl_file = download_rename_inputs(options_pkl_file, temp_dir) + options = load_extra_options(options_pkl_file, options) + # update output directory + options.out_dir = out_dir + + elif len(args) == 5: + # multi worker + options_pkl_file = args[0] + params_file = args[1] + model_file = args[2] + vcf_file = args[3] + worker_index = int(args[4]) + + # load options + if options.gcs: + options_pkl_file = download_rename_inputs(options_pkl_file, temp_dir) + options = load_extra_options(options_pkl_file, options) + # update output directory + options.out_dir = "%s/job%d" % (options.out_dir, worker_index) + + else: + parser.error("Must provide parameters and model files and QTL VCF file") + + # check if the model type is correct + if options.tensorrt: + if model_file.endswith(".h5"): + raise SystemExit("Model type is tensorrt but model file is keras") + is_dir_model = True + else: + is_dir_model = False + + if not os.path.isdir(options.out_dir): + os.mkdir(options.out_dir) + + options.shifts = [int(shift) for shift in options.shifts.split(",")] + options.snp_stats = options.snp_stats.split(",") + if options.targets_file is None: + parser.error("Must provide targets file") + + ################################################################# + # check if the program is run on GPU, else quit + physical_devices = tf.config.list_physical_devices() + gpu_available = any(device.device_type == "GPU" for device in physical_devices) + if gpu_available: + print("Running on GPU") + if options.float16: + print("Using mixed precision") + policy = tf.keras.mixed_precision.Policy("mixed_float16") + tf.keras.mixed_precision.set_global_policy(policy) + else: + print("Running on CPU") + if options.require_gpu: + raise SystemExit("Job terminated because it's running on CPU") + + ################################################################# + # download input files from gcs to a local file + if options.gcs: + params_file = download_rename_inputs(params_file, temp_dir) + vcf_file = download_rename_inputs(vcf_file, temp_dir) + model_file = download_rename_inputs(model_file, temp_dir, is_dir_model) + if options.genome_fasta 
is not None: + options.genome_fasta = download_rename_inputs( + options.genome_fasta, temp_dir + ) + if options.targets_file is not None: + options.targets_file = download_rename_inputs( + options.targets_file, temp_dir + ) + + ################################################################# + # calculate SAD scores: + if options.processes is not None: + score_gene_snps(params_file, model_file, vcf_file, worker_index, options) + else: + score_gene_snps(params_file, model_file, vcf_file, 0, options) + + # if the output dir is in gcs, sync it up + if options.gcs: + upload_folder_gcs(options.out_dir, gcs_output_dir) + if os.path.isdir(temp_dir): + shutil.rmtree(temp_dir) # clean up temp dir + + +################################################################################ +# __main__ +################################################################################ +if __name__ == "__main__": + main() diff --git a/src/baskerville/scripts/hound_train.py b/src/baskerville/scripts/hound_train.py index 032d6f0..cec1dcf 100755 --- a/src/baskerville/scripts/hound_train.py +++ b/src/baskerville/scripts/hound_train.py @@ -56,6 +56,12 @@ def main(): default="train_out", help="Output directory [Default: %(default)s]", ) + parser.add_argument( + "-l", + "--log_dir", + default="log_out", + help="Tensorboard log directory [Default: %(default)s]", + ) parser.add_argument( "--restore", default=None, @@ -150,7 +156,7 @@ def main(): # initialize trainer seqnn_trainer = trainer.Trainer( - params_train, train_data, eval_data, args.out_dir + params_train, train_data, eval_data, args.out_dir, args.log_dir ) # compile model @@ -182,6 +188,7 @@ def main(): train_data, eval_data, args.out_dir, + args.log_dir, strategy, params_train["num_gpu"], args.keras_fit, diff --git a/src/baskerville/seqnn.py b/src/baskerville/seqnn.py index 82db788..87aa223 100644 --- a/src/baskerville/seqnn.py +++ b/src/baskerville/seqnn.py @@ -925,7 +925,12 @@ def __call__(self, x, head_i=None, dtype="float32"): else: model = self.model - return model(x).numpy().astype(dtype) + preds = model(x).numpy().astype(dtype) + # if isinstance(x, np.ndarray): + # preds = model(x).numpy().astype(dtype) + # else: + # preds = model(x) + return preds def predict( self, @@ -968,7 +973,7 @@ def predict( preds = model.predict_generator(dataset, **kwargs).astype(dtype) elif stream: preds = [] - for x, y in seq_data.dataset: + for x, y in dataset: yh = model.predict(x, **kwargs) if step > 1: yh = yh[:, step_i, :] diff --git a/src/baskerville/snps.py b/src/baskerville/snps.py index 6ecc4b4..b4fce23 100644 --- a/src/baskerville/snps.py +++ b/src/baskerville/snps.py @@ -1,3 +1,4 @@ +import concurrent import json import pdb import sys @@ -5,14 +6,17 @@ import h5py import numpy as np import pandas as pd +import pybedtools import pysam from scipy.special import rel_entr from tqdm import tqdm from baskerville import dna from baskerville import dataset +from baskerville.gene import Transcriptome from baskerville import seqnn from baskerville import vcf as bvcf +from baskerville.helpers.trt_optimized_model import OptimizedModel def score_snps(params_file, model_file, vcf_file, worker_index, options): @@ -62,19 +66,32 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): ################################################################# # setup model - # can we sum on GPU? 
- sum_length = options.snp_stats == "SAD" - seqnn_model = seqnn.SeqNN(params_model) - seqnn_model.restore(model_file) - seqnn_model.build_slice(targets_df.index) - if sum_length: - seqnn_model.build_sad() - seqnn_model.build_ensemble(options.rc) + + # load model + sum_length = options.snp_stats == "SAD" + if options.tensorrt: + seqnn_model.model = OptimizedModel(model_file, seqnn_model.strand_pair) + input_shape = tuple(seqnn_model.model.loaded_model_fn.inputs[0].shape.as_list()) + else: + seqnn_model.restore(model_file) + seqnn_model.build_slice(targets_df.index) + if sum_length: + seqnn_model.build_sad() + seqnn_model.build_ensemble(options.rc) + input_shape = seqnn_model.model.input_shape + + # make dummy predictions to warm up model + dummy_input_shape = (1,) + input_shape[1:] + dummy_input = np.random.random(dummy_input_shape).astype(np.float32) + seqnn_model(dummy_input) # shift outside seqnn num_shifts = len(options.shifts) targets_length = seqnn_model.target_lengths[0] + targets_length = seqnn_model.target_lengths[0] + model_stride = seqnn_model.model_strides[0] + model_crop = seqnn_model.target_crops[0] * model_stride ################################################################# # load SNPs @@ -133,106 +150,398 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options): num_shifts, ) + # CPU computation + def score_write(ref_preds, alt_preds, si): + scores = compute_scores( + ref_preds, alt_preds, options.snp_stats, strand_transform + ) + for snp_stat in options.snp_stats: + scores_out[snp_stat][si] = scores[snp_stat] + + if options.untransform_old: + untransform = dataset.untransform_preds1 + else: + untransform = dataset.untransform_preds + # SNP index si = 0 - for sc in tqdm(snp_clusters): - snp_1hot_list = sc.get_1hots(genome_open) - ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) - - # predict reference - ref_preds = [] - for shift in options.shifts: - ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) - ref_preds_shift = seqnn_model(ref_1hot_shift)[0] - - # untransform predictions - if options.targets_file is not None: - if options.untransform_old: - ref_preds_shift = dataset.untransform_preds1( - ref_preds_shift, targets_df - ) + with concurrent.futures.ThreadPoolExecutor() as executor: + # initialize 1 hot encoding + sc0 = snp_clusters[0] + s1l = executor.submit(sc0.get_1hots, genome_open) + + for ci, sc in enumerate(tqdm(snp_clusters)): + # pull latest 1 hot encoding + snp_1hot_list = s1l.result() + ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) + + # submit next 1 hot encoding + if ci + 1 < len(snp_clusters): + sc1 = snp_clusters[ci + 1] + s1l = executor.submit(sc1.get_1hots, genome_open) + + # predict reference + ref_preds = [] + for shift in options.shifts: + ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) + ref_preds_shift = seqnn_model(ref_1hot_shift)[0] + + # untransform predictions + if options.targets_file is None: + ref_preds.append(ref_preds_shift) else: - ref_preds_shift = dataset.untransform_preds( - ref_preds_shift, targets_df - ) + rpsf = executor.submit(untransform, ref_preds_shift, targets_df) + ref_preds.append(rpsf) + + for ai, alt_1hot in enumerate(snp_1hot_list[1:]): + alt_1hot = np.expand_dims(alt_1hot, axis=0) - # sum strand pairs - if strand_transform is not None: - ref_preds_shift = ref_preds_shift * strand_transform + # add left/right shifts for indels + indel_size = sc.snps[ai].indel_size() + if indel_size == 0: + alt_shifts = options.shifts + else: + alt_shifts = [] + for shift in options.shifts: + 
alt_shifts.append(shift) + alt_shifts.append(shift - indel_size) + + # predict alternate + alt_preds = [] + for shift in alt_shifts: + alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) + alt_preds_shift = seqnn_model(alt_1hot_shift)[0] + + # untransform predictions + if options.targets_file is None: + alt_preds.append(alt_preds_shift) + else: + apsf = executor.submit(untransform, alt_preds_shift, targets_df) + alt_preds.append(apsf) - # save shift prediction - ref_preds.append(ref_preds_shift) - ref_preds = np.array(ref_preds) + # result + if options.targets_file is not None: + # get result, only if not already gotten + if isinstance(ref_preds[0], concurrent.futures.Future): + ref_preds = [rpsf.result() for rpsf in ref_preds] + alt_preds = [apsf.result() for apsf in alt_preds] + + # stitch indel shifts + if indel_size != 0 and options.indel_stitch: + snp_seq_pos = sc.snps[ai].pos - sc.start - model_crop + snp_seq_bin = snp_seq_pos // model_stride + alt_preds = stitch_preds(alt_preds, options.shifts, snp_seq_bin) + + # flip reference and alternate + if snps[si].flipped: + rp_snp = np.array(alt_preds) + ap_snp = np.array(ref_preds) + else: + rp_snp = np.array(ref_preds) + ap_snp = np.array(alt_preds) - ai = 0 - for alt_1hot in snp_1hot_list[1:]: - alt_1hot = np.expand_dims(alt_1hot, axis=0) + # repeat reference predictions for indels w/o stitching + if indel_size != 0 and not options.indel_stitch: + rp_snp = np.repeat(rp_snp, 2, axis=0) - # add compensation shifts for indels - indel_size = sc.snps[ai].indel_size() - if indel_size == 0: - alt_shifts = options.shifts - else: - # repeat reference predictions - ref_preds = np.repeat(ref_preds, 2, axis=0) + # write SNP + if sum_length: + write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) + else: + executor.submit(score_write, rp_snp, ap_snp, si) + + # update SNP index + si += 1 + + # close genome + genome_open.close() + + # compute SAD distributions across variants + write_pct(scores_out, options.snp_stats) + scores_out.close() + + +def score_gene_snps(params_file, model_file, vcf_file, worker_index, options): + """ + Score SNPs in a VCF file with a SeqNN model. 
+ + :param params_file: Model parameters + :param model_file: Saved model weights + :param vcf_file: VCF + :param worker_index + :param options: options from cmd args + :return: + """ + + ################################################################# + # read parameters and targets + + # read model parameters + with open(params_file) as params_open: + params = json.load(params_open) + params_model = params["model"] + + # read targets + if options.targets_file is None: + print("Must provide targets file to clarify stranded datasets", file=sys.stderr) + exit(1) + targets_df = pd.read_csv(options.targets_file, sep="\t", index_col=0) + + # handle strand pairs + if "strand_pair" in targets_df.columns: + # prep strand + targets_strand_df = dataset.targets_prep_strand(targets_df) + + # set strand pairs (using new indexing) + orig_new_index = dict(zip(targets_df.index, np.arange(targets_df.shape[0]))) + targets_strand_pair = np.array( + [orig_new_index[ti] for ti in targets_df.strand_pair] + ) + params_model["strand_pair"] = [targets_strand_pair] + else: + targets_strand_df = targets_df + + # construct strand sum transform + plus_mask = targets_df.strand != "-" + minus_mask = targets_df.strand != "+" + + ################################################################# + # setup model + + seqnn_model = seqnn.SeqNN(params_model) + + # load model + if options.tensorrt: + seqnn_model.model = OptimizedModel(model_file, seqnn_model.strand_pair) + input_shape = tuple(seqnn_model.model.loaded_model_fn.inputs[0].shape.as_list()) + else: + seqnn_model.restore(model_file) + seqnn_model.build_slice(targets_df.index) + seqnn_model.build_ensemble(options.rc) + input_shape = seqnn_model.model.input_shape + + # make dummy predictions to warm up model + dummy_input_shape = (1,) + input_shape[1:] + dummy_input = np.random.random(dummy_input_shape).astype(np.float32) + seqnn_model(dummy_input) + + # shift outside seqnn + num_shifts = len(options.shifts) + targets_length = seqnn_model.target_lengths[0] + model_stride = seqnn_model.model_strides[0] + model_crop = seqnn_model.target_crops[0] * model_stride + + ################################################################# + # load SNPs + + # filter for worker SNPs + if options.processes is None: + start_i = None + end_i = None + else: + # determine boundaries + num_snps = bvcf.vcf_count(vcf_file) + worker_bounds = np.linspace(0, num_snps, options.processes + 1, dtype="int") + start_i = worker_bounds[worker_index] + end_i = worker_bounds[worker_index + 1] - # add compensation shifts - alt_shifts = [] - for shift in options.shifts: - alt_shifts.append(shift) - alt_shifts.append(shift - indel_size) + # read SNPs + snps = bvcf.vcf_snps( + vcf_file, + require_sorted=True, + flip_ref=False, + validate_ref_fasta=options.genome_fasta, + start_i=start_i, + end_i=end_i, + ) - # predict alternate - alt_preds = [] - for shift in alt_shifts: - alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) - alt_preds_shift = seqnn_model(alt_1hot_shift)[0] + # read genes + transcriptome = Transcriptome(options.genes_gtf) + + # cluster genes + genesnp_clusters = cluster_genes( + transcriptome, params_model["seq_length"], options.cluster_pct + ) + + # delimit sequence boundaries + [gsc.delimit(params_model["seq_length"], model_crop) for gsc in genesnp_clusters] + + # assign SNPs to genes + map_snps_genes(snps, genesnp_clusters) + + # remove genes w/o SNPs + genesnp_clusters = [gsc for gsc in genesnp_clusters if len(gsc.snps) > 0] + + # open genome FASTA + genome_open = 
pysam.Fastafile(options.genome_fasta) + + ################################################################# + # predict SNP scores, write output + + # setup output + scores_out = initialize_output_h5( + options.out_dir, + options.snp_stats, + snps, + targets_length, + targets_strand_df, + num_shifts, + genesnp_clusters, + ) + + # CPU computation + def score_write(ref_preds, alt_preds, gene_id, snp_id): + scores = compute_scores(ref_preds, alt_preds, options.snp_stats) + for snp_stat in options.snp_stats: + stat_out = scores_out.require_group(snp_stat) + snp_out = stat_out.require_group(snp_id) + snp_out.create_dataset(gene_id, data=scores[snp_stat], dtype="float16") + + if options.untransform_old: + untransform = dataset.untransform_preds1 + else: + untransform = dataset.untransform_preds + + with concurrent.futures.ThreadPoolExecutor() as executor: + for gsc in tqdm(genesnp_clusters): + snp_1hot_list = gsc.get_1hots(genome_open) + ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0) + + # predict reference + ref_preds = [] + for shift in options.shifts: + ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift) + ref_preds_shift = seqnn_model(ref_1hot_shift)[0] # untransform predictions + if options.targets_file is None: + ref_preds.append(ref_preds_shift) + else: + rpsf = executor.submit(untransform, ref_preds_shift, targets_df) + ref_preds.append(rpsf) + + for ai, alt_1hot in enumerate(snp_1hot_list[1:]): + alt_1hot = np.expand_dims(alt_1hot, axis=0) + + # add left/right shifts for indels + indel_size = gsc.snps[ai].indel_size() + if indel_size == 0: + alt_shifts = options.shifts + else: + alt_shifts = [] + for shift in options.shifts: + alt_shifts.append(shift) + alt_shifts.append(shift - indel_size) + + # predict alternate + alt_preds = [] + for shift in alt_shifts: + alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift) + alt_preds_shift = seqnn_model(alt_1hot_shift)[0] + + # untransform predictions + if options.targets_file is None: + alt_preds.append(alt_preds_shift) + else: + apsf = executor.submit(untransform, alt_preds_shift, targets_df) + alt_preds.append(apsf) + + # result if options.targets_file is not None: - if options.untransform_old: - alt_preds_shift = dataset.untransform_preds1( - alt_preds_shift, targets_df + # get result, only if not already gotten + if isinstance(ref_preds[0], concurrent.futures.Future): + ref_preds = [rpsf.result() for rpsf in ref_preds] + alt_preds = [apsf.result() for apsf in alt_preds] + + # stitch indel shifts + if indel_size != 0 and options.indel_stitch: + snp_seq_pos = gsc.snps[ai].pos - gsc.start - model_crop + snp_seq_bin = snp_seq_pos // model_stride + alt_preds = stitch_preds(alt_preds, options.shifts, snp_seq_bin) + + # flip reference and alternate + if gsc.snps[ai].flipped: + rp_snp = np.array(alt_preds) + ap_snp = np.array(ref_preds) + else: + rp_snp = np.array(ref_preds) + ap_snp = np.array(alt_preds) + + # repeat reference predictions for indels w/o stitching + if indel_size != 0 and not options.indel_stitch: + rp_snp = np.repeat(rp_snp, 2, axis=0) + + for gene in gsc.genes: + # slice gene positions + gene_slice = gene.output_slice( + gsc.pstart, gsc.pend - gsc.pstart, model_stride + ) + if len(gene_slice) == 0: + print( + f"WARNING: {gene.kv['gene_id']} exons fall outside prediction boundaries." 
) else: - alt_preds_shift = dataset.untransform_preds( - alt_preds_shift, targets_df + rp_gene = rp_snp[:, gene_slice] + ap_gene = ap_snp[:, gene_slice] + + # slice gene strand + if gene.strand == "+": + rp_gene = rp_gene[..., plus_mask] + ap_gene = ap_gene[..., plus_mask] + else: + rp_gene = rp_gene[..., minus_mask] + ap_gene = ap_gene[..., minus_mask] + + # write SNP + executor.submit( + score_write, + rp_gene, + ap_gene, + gene.kv["gene_id"], + gsc.snps[ai].rsid, ) - # sum strand pairs - if strand_transform is not None: - alt_preds_shift = alt_preds_shift * strand_transform + # close open files + genome_open.close() + scores_out.close() - # save shift prediction - alt_preds.append(alt_preds_shift) - # flip reference and alternate - if snps[si].flipped: - rp_snp = np.array(alt_preds) - ap_snp = np.array(ref_preds) - else: - rp_snp = np.array(ref_preds) - ap_snp = np.array(alt_preds) +def cluster_genes(transcriptome, seq_length: int, center_pct: float): + """Cluster genes into regions that will satisfy the required center_pct. - # write SNP - if sum_length: - write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats) + Args: + transcriptome (Transcriptome): Transcriptome object. + seq_length (int): Sequence length. + center_pct (float): Percent of sequence length to cluster genes. + """ + valid_gene_distance = int(seq_length * center_pct) + + gene_clusters = [] + + # re-sort genes by midpoint + chromosomes = set([gene.chrom for gene in transcriptome.genes.values()]) + for chrom in chromosomes: + gene_pos = [] + gene_objs = [] + for gene in transcriptome.genes.values(): + if gene.chrom == chrom: + gene_pos.append(gene.midpoint()) + gene_objs.append(gene) + + cluster_pos0 = -valid_gene_distance + for gi in np.argsort(gene_pos): + gene = gene_objs[gi] + if gene_pos[gi] < cluster_pos0 + valid_gene_distance: + # append to latest cluster + gene_clusters[-1].add_gene(gene) else: - # write_snp_len(rp_snp, ap_snp, scores_out, si, options.snp_stats) - scores = compute_scores(rp_snp, ap_snp, options.snp_stats) - for snp_stat in options.snp_stats: - scores_out[snp_stat][si] = scores[snp_stat] - - # update SNP index - si += 1 - - # close genome - genome_open.close() + # initialize new cluster + gene_clusters.append(GeneSNPCluster()) + gene_clusters[-1].add_gene(gene) + cluster_pos0 = gene_pos[gi] - # compute SAD distributions across variants - write_pct(scores_out, options.snp_stats) - scores_out.close() + return gene_clusters def cluster_snps(snps, seq_len: int, center_pct: float): @@ -263,13 +572,14 @@ def cluster_snps(snps, seq_len: int, center_pct: float): return snp_clusters -def compute_scores(ref_preds, alt_preds, snp_stats): +def compute_scores(ref_preds, alt_preds, snp_stats, strand_transform=None): """Compute SNP scores from reference and alternative predictions. Args: ref_preds (np.array): Reference allele predictions. alt_preds (np.array): Alternative allele predictions. snp_stats [str]: List of SAD stats to compute. + strand_transform (scipy.sparse): Strand transform matrix. 
""" num_shifts, seq_length, num_targets = ref_preds.shape @@ -279,98 +589,75 @@ def compute_scores(ref_preds, alt_preds, snp_stats): ref_preds_sqrt = np.sqrt(ref_preds) alt_preds_sqrt = np.sqrt(alt_preds) - # sum across length - ref_preds_sum = ref_preds.sum(axis=(0, 1)) - alt_preds_sum = alt_preds.sum(axis=(0, 1)) - ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1)) - alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1)) - ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1)) - alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1)) + # sum across length, mean across shifts + ref_preds_sum = ref_preds.sum(axis=(0, 1)) / num_shifts + alt_preds_sum = alt_preds.sum(axis=(0, 1)) / num_shifts + ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1)) / num_shifts + alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1)) / num_shifts + ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1)) / num_shifts + alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1)) / num_shifts # difference altref_diff = alt_preds - ref_preds - altref_adiff = np.abs(altref_diff) altref_log_diff = alt_preds_log - ref_preds_log - altref_log_adiff = np.abs(altref_log_diff) altref_sqrt_diff = alt_preds_sqrt - ref_preds_sqrt - altref_sqrt_adiff = np.abs(altref_sqrt_diff) # initialize scores dict scores = {} + def strand_clip_save(key, score, d2=False): + if strand_transform is not None: + if d2: + score = np.power(score, 2) + score = score @ strand_transform + score = np.sqrt(score) + else: + score = score @ strand_transform + score = np.clip(score, np.finfo(np.float16).min, np.finfo(np.float16).max) + scores[key] = score.astype("float16") + # compare reference to alternative via sum subtraction if "SUM" in snp_stats: sad = alt_preds_sum - ref_preds_sum - sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["SUM"] = sad.astype("float16") + strand_clip_save("SUM", sad) if "logSUM" in snp_stats: log_sad = alt_preds_log_sum - ref_preds_log_sum - log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["logSUM"] = log_sad.astype("float16") + strand_clip_save("logSUM", log_sad) if "sqrtSUM" in snp_stats: sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum - sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["sqrtSUM"] = sqrt_sad.astype("float16") - - # TEMP during name change - if "SAD" in snp_stats: - sad = alt_preds_sum - ref_preds_sum - sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["SAD"] = sad.astype("float16") - if "logSAD" in snp_stats: - log_sad = alt_preds_log_sum - ref_preds_log_sum - log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["logSAD"] = log_sad.astype("float16") - if "sqrtSAD" in snp_stats: - sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum - sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores["sqrtSAD"] = sqrt_sad.astype("sqrtSAD") + strand_clip_save("sqrtSUM", sqrt_sad) # compare reference to alternative via max subtraction if "SAX" in snp_stats: + altref_adiff = np.abs(altref_diff) sax = [] for s in range(num_shifts): max_i = np.argmax(altref_adiff[s], axis=0) sax.append(altref_diff[s, max_i, np.arange(num_targets)]) sax = np.array(sax).mean(axis=0) - scores["SAX"] = sax.astype("float16") + strand_clip_save("SAX", sax) # L1 norm of difference vector if "D1" in snp_stats: - sad_d1 = altref_adiff.sum(axis=1) - sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d1 = sad_d1.mean(axis=0) - 
scores["D1"] = sad_d1.mean().astype("float16") + sad_d1 = np.linalg.norm(altref_diff, ord=1, axis=1) + strand_clip_save("D1", sad_d1.mean(axis=0)) if "logD1" in snp_stats: - log_d1 = altref_log_adiff.sum(axis=1) - log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d1 = log_d1.mean(axis=0) - scores["logD1"] = log_d1.astype("float16") + log_d1 = np.linalg.norm(altref_log_diff, ord=1, axis=1) + strand_clip_save("logD1", log_d1.mean(axis=0)) if "sqrtD1" in snp_stats: - sqrt_d1 = altref_sqrt_adiff.sum(axis=1) - sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d1 = sqrt_d1.mean(axis=0) - scores["sqrtD1"] = sqrt_d1.astype("float16") + sqrt_d1 = np.linalg.norm(altref_sqrt_diff, ord=1, axis=1) + strand_clip_save("sqrtD1", sqrt_d1.mean(axis=0)) # L2 norm of difference vector if "D2" in snp_stats: - altref_diff2 = np.power(altref_diff, 2) - sad_d2 = np.sqrt(altref_diff2.sum(axis=1)) - sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d2 = sad_d2.mean(axis=0) - scores["D2"] = sad_d2.astype("float16") + sad_d2 = np.linalg.norm(altref_diff, ord=2, axis=1) + strand_clip_save("D2", sad_d2.mean(axis=0), d2=True) if "logD2" in snp_stats: - altref_log_diff2 = np.power(altref_log_diff, 2) - log_d2 = np.sqrt(altref_log_diff2.sum(axis=1)) - log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d2 = log_d2.mean(axis=0) - scores["logD2"] = log_d2.astype("float16") + log_d2 = np.linalg.norm(altref_log_diff, ord=2, axis=1) + strand_clip_save("logD2", log_d2.mean(axis=0), d2=True) if "sqrtD2" in snp_stats: - altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2) - sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1)) - sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d2 = sqrt_d2.mean(axis=0) - scores["sqrtD2"] = sqrt_d2.astype("float16") + sqrt_d2 = np.linalg.norm(altref_sqrt_diff, ord=2, axis=1) + strand_clip_save("sqrtD2", sqrt_d2.mean(axis=0), d2=True) if "JS" in snp_stats: # normalized scores @@ -387,7 +674,9 @@ def compute_scores(ref_preds, alt_preds, snp_stats): alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0) js_dist.append((ref_alt_entr + alt_ref_entr) / 2) js_dist = np.mean(js_dist, axis=0) - scores["JS"] = js_dist.astype("float16") + # handling strand this way is incorrect, but I'm punting for now + strand_clip_save("JS", js_dist) + if "logJS" in snp_stats: # normalized scores pseudocounts = np.percentile(ref_preds_log, 25, axis=0) @@ -399,15 +688,14 @@ def compute_scores(ref_preds, alt_preds, snp_stats): # compare normalized JS log_js_dist = [] for s in range(num_shifts): - ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum( - axis=0 - ) - alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum( - axis=0 - ) + rps = ref_preds_log_norm[s] + aps = alt_preds_log_norm[s] + ref_alt_entr = rel_entr(rps, aps).sum(axis=0) + alt_ref_entr = rel_entr(aps, rps).sum(axis=0) log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2) log_js_dist = np.mean(log_js_dist, axis=0) - scores["logJS"] = log_js_dist.astype("float16") + # handling strand this way is incorrect, but I'm punting for now + strand_clip_save("logJS", log_js_dist) # predictions if "REF" in snp_stats: @@ -425,7 +713,13 @@ def compute_scores(ref_preds, alt_preds, snp_stats): def initialize_output_h5( - out_dir, snp_stats, snps, targets_length, targets_df, num_shifts + out_dir, + snp_stats, + snps, + targets_length, + targets_df, + 
num_shifts, + geneseq_clusters=None, ): """Initialize an output HDF5 file for SAD stats. @@ -436,6 +730,7 @@ def initialize_output_h5( targets_length (int): Targets' sequence length targets_df (pd.DataFrame): Targets DataFrame. num_shifts (int): Number of shifts. + geneseq_clusters [GeneSNPCluster]: Gene sequence clusters. """ num_targets = targets_df.shape[0] @@ -476,18 +771,26 @@ def initialize_output_h5( "target_labels", data=np.array(targets_df.description, "S") ) - # initialize SAD stats - for snp_stat in snp_stats: - if snp_stat in ["REF", "ALT"]: - scores_out.create_dataset( - snp_stat, - shape=(num_snps, num_shifts, targets_length, num_targets), - dtype="float16", - ) - else: - scores_out.create_dataset( - snp_stat, shape=(num_snps, num_targets), dtype="float16" - ) + if geneseq_clusters is not None: + # write genes + gene_ids = [] + for gsc in geneseq_clusters: + gene_ids.extend([gene.kv["gene_id"] for gene in gsc.genes]) + gene_ids = np.array(gene_ids, "S") + scores_out.create_dataset("gene", data=gene_ids) + else: + # initialize stats + for snp_stat in snp_stats: + if snp_stat in ["REF", "ALT"]: + scores_out.create_dataset( + snp_stat, + shape=(num_snps, num_shifts, targets_length, num_targets), + dtype="float16", + ) + else: + scores_out.create_dataset( + snp_stat, shape=(num_snps, num_targets), dtype="float16" + ) return scores_out @@ -540,6 +843,55 @@ def make_alt_1hot(ref_1hot, snp_seq_pos, ref_allele, alt_allele): return alt_1hot +def make_gene_bedt(genesnp_clusters): + """Make a BedTool object for all gene sequences.""" + gene_bed_lines = [] + for gi, gsc in enumerate(genesnp_clusters): + geneseq_start = max(0, gsc.start) + gene_bed_lines.append("%s %d %d %d" % (gsc.chr, geneseq_start, gsc.end, gi)) + gene_bedt = pybedtools.BedTool("\n".join(gene_bed_lines), from_string=True) + return gene_bedt + + +def make_snp_bedt(snps): + """Make a BedTool object for all SNPs""" + snp_bed_lines = [] + for si, snp in enumerate(snps): + snp_bed_lines.append("%s %d %d %d" % (snp.chr, snp.pos - 1, snp.pos, si)) + snp_bedt = pybedtools.BedTool("\n".join(snp_bed_lines), from_string=True) + return snp_bedt + + +def map_snps_genes(snps, genesnp_clusters): + """Map SNPs to gene sequences.""" + geneseq_bedt = make_gene_bedt(genesnp_clusters) + snp_bedt = make_snp_bedt(snps) + + for overlap in geneseq_bedt.intersect(snp_bedt, wa=True, wb=True): + gchr, gstart, gend, gi, schr, spos, send, si = overlap + gi, si = int(gi), int(si) + genesnp_clusters[gi].add_snp(snps[si]) + + +def stitch_preds(preds, shifts, pos=None): + """Stitch indel left and right compensation shifts. + + Args: + preds [np.array]: List of predictions. + shifts [int]: List of shifts. + pos (int): SNP position to stitch at. + """ + if pos is None: + pos = preds[0].shape[0] // 2 + preds_stitch = [] + for hi, shift in enumerate(shifts): + hil = 2 * hi + hir = hil + 1 + preds_stitch_i = np.concatenate((preds[hil][:pos], preds[hir][pos:]), axis=0) + preds_stitch.append(preds_stitch_i) + return preds_stitch + + def write_pct(scores_out, snp_stats): """Compute percentile values for each target and write to HDF5. @@ -588,151 +940,6 @@ def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats): scores_out["SAD"][si] = sad.astype("float16") -def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats): - """Write SNP predictions to HDF, assuming the length dimension has - been maintained. - - Args: - ref_preds (np.array): Reference allele predictions. - alt_preds (np.array): Alternative allele predictions. 
- scores_out (h5py.File): Output HDF5 file. - si (int): SNP index. - snp_stats [str]: List of SAD stats to compute. - """ - num_shifts, seq_length, num_targets = ref_preds.shape - - # log/sqrt - ref_preds_log = np.log2(ref_preds + 1) - alt_preds_log = np.log2(alt_preds + 1) - ref_preds_sqrt = np.sqrt(ref_preds) - alt_preds_sqrt = np.sqrt(alt_preds) - - # sum across length - ref_preds_sum = ref_preds.sum(axis=(0, 1)) - alt_preds_sum = alt_preds.sum(axis=(0, 1)) - ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1)) - alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1)) - ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1)) - alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1)) - - # difference - altref_diff = alt_preds - ref_preds - altref_adiff = np.abs(altref_diff) - altref_log_diff = alt_preds_log - ref_preds_log - altref_log_adiff = np.abs(altref_log_diff) - altref_sqrt_diff = alt_preds_sqrt - ref_preds_sqrt - altref_sqrt_adiff = np.abs(altref_sqrt_diff) - - # compare reference to alternative via sum subtraction - if "SAD" in snp_stats: - sad = alt_preds_sum - ref_preds_sum - sad = np.clip(sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores_out["SAD"][si] = sad.astype("float16") - if "logSAD" in snp_stats: - log_sad = alt_preds_log_sum - ref_preds_log_sum - log_sad = np.clip(log_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores_out["logSAD"][si] = log_sad.astype("float16") - if "sqrtSAD" in snp_stats: - sqrt_sad = alt_preds_sqrt_sum - ref_preds_sqrt_sum - sqrt_sad = np.clip(sqrt_sad, np.finfo(np.float16).min, np.finfo(np.float16).max) - scores_out["sqrtSAD"][si] = sqrt_sad.astype("float16") - - # compare reference to alternative via max subtraction - if "SAX" in snp_stats: - sax = [] - for s in range(num_shifts): - max_i = np.argmax(altref_adiff[s], axis=0) - sax.append(altref_diff[s, max_i, np.arange(num_targets)]) - sax = np.array(sax).mean(axis=0) - scores_out["SAX"][si] = sax.astype("float16") - - # L1 norm of difference vector - if "D1" in snp_stats: - sad_d1 = altref_adiff.sum(axis=1) - sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d1 = sad_d1.mean(axis=0) - scores_out["D1"][si] = sad_d1.mean().astype("float16") - if "logD1" in snp_stats: - log_d1 = altref_log_adiff.sum(axis=1) - log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d1 = log_d1.mean(axis=0) - scores_out["logD1"][si] = log_d1.astype("float16") - if "sqrtD1" in snp_stats: - sqrt_d1 = altref_sqrt_adiff.sum(axis=1) - sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d1 = sqrt_d1.mean(axis=0) - scores_out["sqrtD1"][si] = sqrt_d1.astype("float16") - - # L2 norm of difference vector - if "D2" in snp_stats: - altref_diff2 = np.power(altref_diff, 2) - sad_d2 = np.sqrt(altref_diff2.sum(axis=1)) - sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sad_d2 = sad_d2.mean(axis=0) - scores_out["D2"][si] = sad_d2.astype("float16") - if "logD2" in snp_stats: - altref_log_diff2 = np.power(altref_log_diff, 2) - log_d2 = np.sqrt(altref_log_diff2.sum(axis=1)) - log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - log_d2 = log_d2.mean(axis=0) - scores_out["logD2"][si] = log_d2.astype("float16") - if "sqrtD2" in snp_stats: - altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2) - sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1)) - sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max) - sqrt_d2 = sqrt_d2.mean(axis=0) - 
scores_out["sqrtD2"][si] = sqrt_d2.astype("float16") - - if "JS" in snp_stats: - # normalized scores - pseudocounts = np.percentile(ref_preds, 25, axis=1) - ref_preds_norm = ref_preds + pseudocounts - ref_preds_norm /= ref_preds_norm.sum(axis=1) - alt_preds_norm = alt_preds + pseudocounts - alt_preds_norm /= alt_preds_norm.sum(axis=1) - - # compare normalized JS - js_dist = [] - for s in range(num_shifts): - ref_alt_entr = rel_entr(ref_preds_norm[s], alt_preds_norm[s]).sum(axis=0) - alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0) - js_dist.append((ref_alt_entr + alt_ref_entr) / 2) - js_dist = np.mean(js_dist, axis=0) - scores_out["JS"][si] = js_dist.astype("float16") - if "logJS" in snp_stats: - # normalized scores - pseudocounts = np.percentile(ref_preds_log, 25, axis=0) - ref_preds_log_norm = ref_preds_log + pseudocounts - ref_preds_log_norm /= ref_preds_log_norm.sum(axis=0) - alt_preds_log_norm = alt_preds_log + pseudocounts - alt_preds_log_norm /= alt_preds_log_norm.sum(axis=0) - - # compare normalized JS - log_js_dist = [] - for s in range(num_shifts): - ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum( - axis=0 - ) - alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum( - axis=0 - ) - log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2) - log_js_dist = np.mean(log_js_dist, axis=0) - scores_out["logJS"][si] = log_js_dist.astype("float16") - - # predictions - if "REF" in snp_stats: - ref_preds = np.clip( - ref_preds, np.finfo(np.float16).min, np.finfo(np.float16).max - ) - scores_out["REF"][si] = ref_preds.astype("float16") - if "ALT" in snp_stats: - alt_preds = np.clip( - alt_preds, np.finfo(np.float16).min, np.finfo(np.float16).max - ) - scores_out["ALT"][si] = alt_preds.astype("float16") - - class SNPCluster: def __init__(self): self.snps = [] @@ -755,8 +962,8 @@ def delimit(self, seq_len): self.start = pos_mid - seq_len // 2 self.end = self.start + seq_len - for snp in self.snps: - snp.seq_pos = snp.pos - 1 - self.start + # for snp in self.snps: + # snp.seq_pos = snp.pos - 1 - self.start def get_1hots(self, genome_open): """Get list of one hot coded sequences.""" @@ -777,7 +984,8 @@ def get_1hots(self, genome_open): # verify reference alleles for snp in self.snps: ref_n = len(snp.ref_allele) - ref_snp = ref_seq[snp.seq_pos : snp.seq_pos + ref_n] + snp_pos = snp.pos - 1 - self.start + ref_snp = ref_seq[snp_pos : snp_pos + ref_n] if snp.ref_allele != ref_snp: print( "ERROR: %s does not match reference %s" % (snp, ref_snp), @@ -792,9 +1000,29 @@ def get_1hots(self, genome_open): # make alternative 1 hot coded sequences # (assuming SNP is 1-based indexed) for snp in self.snps: + snp_pos = snp.pos - 1 - self.start alt_1hot = make_alt_1hot( - ref_1hot, snp.seq_pos, snp.ref_allele, snp.alt_alleles[0] + ref_1hot, snp_pos, snp.ref_allele, snp.alt_alleles[0] ) seqs1_list.append(alt_1hot) return seqs1_list + + +class GeneSNPCluster(SNPCluster): + def __init__(self): + super().__init__() + self.genes = [] + + def add_gene(self, gene): + """Add gene to cluster.""" + self.genes.append(gene) + + def delimit(self, seq_len, crop=0): + """Delimit sequence boundaries.""" + self.chr = self.genes[0].chrom + midp = int(np.mean([g.midpoint() for g in self.genes])) + self.start = midp - seq_len // 2 + self.end = self.start + seq_len + self.pstart = self.start + crop + self.pend = self.end - crop diff --git a/src/baskerville/trainer.py b/src/baskerville/trainer.py index 3136642..69aa96f 100644 --- a/src/baskerville/trainer.py +++ 
b/src/baskerville/trainer.py @@ -21,7 +21,8 @@ import numpy as np import tensorflow as tf - +import tempfile +from baskerville.helpers.gcs_utils import is_gcs_path, upload_folder_gcs from baskerville import metrics from tensorflow.keras import mixed_precision @@ -32,6 +33,8 @@ def parse_loss( keras_fit: bool = True, spec_weight: float = 1, total_weight: float = 1, + weight_range: float = 1, + weight_exp: int = 1, ): """Parse loss function from label, strategy, and fitting method. @@ -56,7 +59,10 @@ def parse_loss( ) elif loss_label == "poisson_mn": loss_fn = metrics.PoissonMultinomial( - total_weight, reduction=tf.keras.losses.Reduction.NONE + total_weight=total_weight, + weight_range=weight_range, + weight_exp=weight_exp, + reduction=tf.keras.losses.Reduction.NONE, ) elif loss_label == "poisson_kl": loss_fn = metrics.PoissonKL( @@ -78,7 +84,11 @@ def parse_loss( elif loss_label == "poisson_kl": loss_fn = metrics.PoissonKL(spec_weight) elif loss_label == "poisson_mn": - loss_fn = metrics.PoissonMultinomial(total_weight) + loss_fn = metrics.PoissonMultinomial( + total_weight=total_weight, + weight_range=weight_range, + weight_exp=weight_exp, + ) else: loss_fn = tf.keras.losses.Poisson() @@ -104,6 +114,7 @@ def __init__( train_data, eval_data, out_dir: str, + log_dir: str, strategy=None, num_gpu: int = 1, keras_fit: bool = False, @@ -117,12 +128,22 @@ def __init__( if type(self.eval_data) is not list: self.eval_data = [self.eval_data] self.out_dir = out_dir + self.log_dir = log_dir self.strategy = strategy self.num_gpu = num_gpu self.batch_size = self.train_data[0].batch_size self.compiled = False self.loss_scale = loss_scale + # if log_dir is in gcs then create a local temp dir + if is_gcs_path(self.log_dir): + folder_name = "/".join(self.log_dir.split("/")[3:]) + self.log_dir = tempfile.mkdtemp() + "/" + folder_name + self.gcs_log_dir = log_dir + self.gcs = True + else: + self.gcs = False + # early stopping self.patience = self.params.get("patience", 20) @@ -142,9 +163,17 @@ def __init__( # loss self.spec_weight = self.params.get("spec_weight", 1) self.total_weight = self.params.get("total_weight", 1) + self.weight_range = self.params.get("weight_range", 1) + self.weight_exp = self.params.get("weight_exp", 1) self.loss = self.params.get("loss", "poisson").lower() self.loss_fn = parse_loss( - self.loss, self.strategy, keras_fit, self.spec_weight, self.total_weight + self.loss, + self.strategy, + keras_fit, + self.spec_weight, + self.total_weight, + self.weight_range, + self.weight_exp, ) # optimizer @@ -203,7 +232,7 @@ def fit_keras(self, seqnn_model): callbacks = [ early_stop, - tf.keras.callbacks.TensorBoard(self.out_dir), + tf.keras.callbacks.TensorBoard(self.log_dir, histogram_freq=1), tf.keras.callbacks.ModelCheckpoint("%s/model_check.h5" % self.out_dir), save_best, ] @@ -417,6 +446,12 @@ def eval_step1_distr(xd, yd): file.write('epoch\tbatch\tgpu_mem(GB)\n') first_step = True + # set up summary writer + train_log_dir = self.log_dir + "/train" + valid_log_dir = self.log_dir + "/valid" + train_summary_writer = tf.summary.create_file_writer(train_log_dir) + valid_summary_writer = tf.summary.create_file_writer(valid_log_dir) + for ei in range(epoch_start, self.train_epochs_max): if ei >= self.train_epochs_min and np.min(unimproved) > self.patience: break @@ -456,6 +491,13 @@ def eval_step1_distr(xd, yd): for di in range(self.num_datasets): print(" Data %d" % di, end="") model = seqnn_model.models[di] + with train_summary_writer.as_default(): + tf.summary.scalar( + "loss", 
train_loss[di].result().numpy(), step=ei + ) + tf.summary.scalar("r", train_r[di].result().numpy(), step=ei) + tf.summary.scalar("r2", train_r2[di].result().numpy(), step=ei) + train_summary_writer.flush() # print training accuracy print( @@ -477,6 +519,14 @@ def eval_step1_distr(xd, yd): else: eval_step1_distr(x, y) + with valid_summary_writer.as_default(): + tf.summary.scalar( + "loss", valid_loss[di].result().numpy(), step=ei + ) + tf.summary.scalar("r", valid_r[di].result().numpy(), step=ei) + tf.summary.scalar("r2", valid_r2[di].result().numpy(), step=ei) + valid_summary_writer.flush() + # print validation accuracy print( " - valid_loss: %.4f" % valid_loss[di].result().numpy(), end="" @@ -485,6 +535,10 @@ def eval_step1_distr(xd, yd): print(" - valid_r2: %.4f" % valid_r2[di].result().numpy(), end="") early_stop_stat = valid_r[di].result().numpy() + # upload to gcs + if self.gcs: + upload_folder_gcs(train_log_dir, self.gcs_log_dir) + upload_folder_gcs(valid_log_dir, self.gcs_log_dir) # checkpoint managers[di].save() model.save( @@ -633,6 +687,12 @@ def eval_step_distr(xd, yd): valid_best = -np.inf unimproved = 0 + # set up summary writer + train_log_dir = self.log_dir + "/train" + valid_log_dir = self.log_dir + "/valid" + train_summary_writer = tf.summary.create_file_writer(train_log_dir) + valid_summary_writer = tf.summary.create_file_writer(valid_log_dir) + # training loop gpu_memory_callback = GPUMemoryUsageCallback() file_path='%s/gpu_mem.txt' % self.out_dir @@ -672,6 +732,13 @@ def eval_step_distr(xd, yd): train_loss_epoch = train_loss.result().numpy() train_r_epoch = train_r.result().numpy() train_r2_epoch = train_r2.result().numpy() + + with train_summary_writer.as_default(): + tf.summary.scalar("loss", train_loss_epoch, step=ei) + tf.summary.scalar("r", train_r_epoch, step=ei) + tf.summary.scalar("r2", train_r2_epoch, step=ei) + train_summary_writer.flush() + print( "Epoch %d - %ds - train_loss: %.4f - train_r: %.4f - train_r2: %.4f" % ( @@ -688,12 +755,24 @@ def eval_step_distr(xd, yd): valid_loss_epoch = valid_loss.result().numpy() valid_r_epoch = valid_r.result().numpy() valid_r2_epoch = valid_r2.result().numpy() + + with valid_summary_writer.as_default(): + tf.summary.scalar("loss", valid_loss_epoch, step=ei) + tf.summary.scalar("r", valid_r_epoch, step=ei) + tf.summary.scalar("r2", valid_r2_epoch, step=ei) + valid_summary_writer.flush() + print( " - valid_loss: %.4f - valid_r: %.4f - valid_r2: %.4f" % (valid_loss_epoch, valid_r_epoch, valid_r2_epoch), end="", ) + # upload to gcs + if self.gcs: + upload_folder_gcs(train_log_dir, self.gcs_log_dir) + upload_folder_gcs(valid_log_dir, self.gcs_log_dir) + # checkpoint manager.save() seqnn_model.save("%s/model_check.h5" % self.out_dir) diff --git a/tests/test_train.py b/tests/test_train.py index 75949b9..7ff1ba4 100755 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -18,6 +18,8 @@ def test_train(clean_data): "src/baskerville/scripts/hound_train.py", "-o", "tests/data/train1", + "-l", + "tests/data/train1/logs", "tests/data/params.json", "tests/data/tiny/hg38", ] @@ -33,6 +35,8 @@ def test_train2(clean_data): "src/baskerville/scripts/hound_train.py", "-o", "tests/data/train2", + "-l", + "tests/data/train2/logs", "tests/data/params.json", "tests/data/tiny/hg38", "tests/data/tiny/mm10",