Skip to content

Commit

Permalink
Merge pull request #35 from calico/stitch-pos
Browse files Browse the repository at this point in the history
Stitch pos
  • Loading branch information
davek44 authored Jun 20, 2024
2 parents b6f5c3c + a4202d7 commit cf70c86
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 81 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"numpy~=1.24.3",
"pandas~=1.5.3",
"pybigwig~=0.3.18",
"pybedtools~=0.9.0",
"pybedtools~=0.10.0",
"pysam~=0.22.0",
"qnorm~=0.8.1",
"seaborn~=0.12.2",
Expand Down
85 changes: 31 additions & 54 deletions src/baskerville/scripts/hound_predbed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import tensorflow as tf

from baskerville import bed
from baskerville import dataset
from baskerville import dna
from baskerville import seqnn
from baskerville import stream


"""
hound_predbed.py
Expand Down Expand Up @@ -120,35 +121,18 @@ def main():
default=None,
help="File specifying target indexes and labels in table format",
)

parser.add_argument(
"-u",
"--untransform_old",
default=False,
action="store_true",
help="Untransform old models [Default: %default]",
)
parser.add_argument("params_file", help="Parameters file")
parser.add_argument("model_file", help="Model file")
parser.add_argument("bed_file", help="BED file")
args = parser.parse_args()

if len(args) == 3:
params_file = args[0]
model_file = args[1]
bed_file = args[2]

elif len(args) == 5:
# multi worker
options_pkl_file = args[0]
params_file = args[1]
model_file = args[2]
bed_file = args[3]
worker_index = int(args[4])

# load options
options_pkl = open(options_pkl_file, "rb")
options = pickle.load(options_pkl)
options_pkl.close()

# update output directory
args.out_dir = "%s/job%d" % (args.out_dir, worker_index)
else:
parser.error("Must provide parameter and model files and BED file")

os.makedirs(args.out_dir, exist_ok=True)

args.shifts = [int(shift) for shift in args.shifts.split(",")]
Expand All @@ -166,7 +150,7 @@ def main():
#################################################################
# read parameters and collet target information

with open(params_file) as params_open:
with open(args.params_file) as params_open:
params = json.load(params_open)
params_model = params["model"]

Expand All @@ -181,7 +165,7 @@ def main():

# initialize model
seqnn_model = seqnn.SeqNN(params_model)
seqnn_model.restore(model_file, args.head)
seqnn_model.restore(args.model_file, args.head)
seqnn_model.build_slice(target_slice)
seqnn_model.build_ensemble(args.rc, args.shifts)

Expand All @@ -205,27 +189,8 @@ def main():

# construct model sequences
model_seqs_dna, model_seqs_coords = bed.make_bed_seqs(
bed_file, args.genome_fasta, params_model["seq_length"], stranded=False
args.bed_file, args.genome_fasta, params_model["seq_length"], stranded=False
)

# construct site coordinates
site_seqs_coords = bed.read_bed_coords(bed_file, args.site_length)

# filter for worker SNPs
if args.processes is not None:
worker_bounds = np.linspace(
0, len(model_seqs_dna), args.processes + 1, dtype="int"
)
model_seqs_dna = model_seqs_dna[
worker_bounds[worker_index] : worker_bounds[worker_index + 1]
]
model_seqs_coords = model_seqs_coords[
worker_bounds[worker_index] : worker_bounds[worker_index + 1]
]
site_seqs_coords = site_seqs_coords[
worker_bounds[worker_index] : worker_bounds[worker_index + 1]
]

num_seqs = len(model_seqs_dna)

#################################################################
Expand Down Expand Up @@ -258,6 +223,7 @@ def main():
)

# store site coordinates
site_seqs_coords = bed.read_bed_coords(args.bed_file, args.site_length)
site_seqs_chr, site_seqs_start, site_seqs_end = zip(*site_seqs_coords)
site_seqs_chr = np.array(site_seqs_chr, dtype="S")
site_seqs_start = np.array(site_seqs_start)
Expand All @@ -268,7 +234,7 @@ def main():

#################################################################
# predict scores, write output

"""
# define sequence generator
def seqs_gen():
for seq_dna in model_seqs_dna:
Expand All @@ -278,18 +244,29 @@ def seqs_gen():
preds_stream = stream.PredStreamGen(
seqnn_model, seqs_gen(), params["train"]["batch_size"]
)
"""

for si, seq_dna in enumerate(model_seqs_dna):
seq_1hot = np.expand_dims(dna.dna_1hot(seq_dna), axis=0)
preds_seq = seqnn_model.predict(seq_1hot)[0]

for si in range(num_seqs):
preds_seq = preds_stream[si]
if args.untransform_old:
preds_seq = dataset.untransform_preds1(preds_seq, targets_df)
else:
preds_seq = dataset.untransform_preds(preds_seq, targets_df)

# slice site
preds_site = preds_seq[site_preds_start:site_preds_end, :]

# write
# optionally, sum
if args.sum:
out_h5["preds"][si] = preds_site.sum(axis=0)
else:
out_h5["preds"][si] = preds_site
preds_site = np.sum(preds_site, axis=0)

# clip to float16 max
preds_write = np.clip(preds_site, 0, np.finfo(np.float16).max)

# write
out_h5["preds"][si] = preds_write

# write bigwig
for ti in args.bigwig_indexes:
Expand Down
7 changes: 2 additions & 5 deletions src/baskerville/scripts/hound_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def main():
out_dir = temp_dir + "/output_dir"
options.out_dir = out_dir

if not os.path.isdir(options.out_dir):
os.mkdir(options.out_dir)
# is this here for GCS?
os.makedirs(options.out_dir, exist_ok=True)

if len(args) == 3:
# single worker
Expand Down Expand Up @@ -200,9 +200,6 @@ def main():
if options.targets_file is None:
parser.error("Must provide targets file")

if options.cluster_snps_pct > 0 and options.indel_stitch:
parser.error("Cannot use --cluster_snps_pct and --indel_stitch together")

#################################################################
# check if the program is run on GPU, else quit
physical_devices = tf.config.list_physical_devices()
Expand Down
11 changes: 9 additions & 2 deletions src/baskerville/scripts/hound_snpgene.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def main():
default="%s/genes/gencode41/gencode41_basic_nort.gtf" % os.environ["HG38"],
help="GTF for gene definition [Default %default]",
)
parser.add_option(
"--indel_stitch",
dest="indel_stitch",
default=False,
action="store_true",
help="Stitch indel compensation shifts [Default: %default]",
)
parser.add_option(
"-o",
dest="out_dir",
Expand Down Expand Up @@ -155,8 +162,8 @@ def main():
out_dir = temp_dir + "/output_dir"
options.out_dir = out_dir

if not os.path.isdir(options.out_dir):
os.mkdir(options.out_dir)
# is this here for GCS?
os.makedirs(options.out_dir, exist_ok=True)

if len(args) == 3:
# single worker
Expand Down
8 changes: 6 additions & 2 deletions src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,11 @@ def __call__(self, x, head_i=None, dtype="float32"):
else:
model = self.model

return model(x).numpy().astype(dtype)
if isinstance(x, np.ndarray):
preds = model(x).numpy().astype(dtype)
else:
preds = model(x)
return preds

def predict(
self,
Expand Down Expand Up @@ -955,7 +959,7 @@ def predict(
preds = model.predict_generator(dataset, **kwargs).astype(dtype)
elif stream:
preds = []
for x, y in seq_data.dataset:
for x, y in dataset:
yh = model.predict(x, **kwargs)
if step > 1:
yh = yh[:, step_i, :]
Expand Down
45 changes: 28 additions & 17 deletions src/baskerville/snps.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
# shift outside seqnn
num_shifts = len(options.shifts)
targets_length = seqnn_model.target_lengths[0]
targets_length = seqnn_model.target_lengths[0]
model_stride = seqnn_model.model_strides[0]
model_crop = seqnn_model.target_crops[0] * model_stride

#################################################################
# load SNPs
Expand Down Expand Up @@ -194,16 +197,11 @@ def score_write(ref_preds, alt_preds, si):
for ai, alt_1hot in enumerate(snp_1hot_list[1:]):
alt_1hot = np.expand_dims(alt_1hot, axis=0)

# add compensation shifts for indels
# add left/right shifts for indels
indel_size = sc.snps[ai].indel_size()
if indel_size == 0:
alt_shifts = options.shifts
else:
# repeat reference predictions, unless stitching
if not options.indel_stitch:
ref_preds = np.repeat(ref_preds, 2, axis=0)

# add compensation shifts
alt_shifts = []
for shift in options.shifts:
alt_shifts.append(shift)
Expand All @@ -229,9 +227,11 @@ def score_write(ref_preds, alt_preds, si):
ref_preds = [rpsf.result() for rpsf in ref_preds]
alt_preds = [apsf.result() for apsf in alt_preds]

# stitch indel compensation shifts
# stitch indel shifts
if indel_size != 0 and options.indel_stitch:
alt_preds = stitch_preds(alt_preds, options.shifts)
snp_seq_pos = sc.snps[ai].pos - sc.start - model_crop
snp_seq_bin = snp_seq_pos // model_stride
alt_preds = stitch_preds(alt_preds, options.shifts, snp_seq_bin)

# flip reference and alternate
if snps[si].flipped:
Expand All @@ -241,6 +241,10 @@ def score_write(ref_preds, alt_preds, si):
rp_snp = np.array(ref_preds)
ap_snp = np.array(alt_preds)

# repeat reference predictions for indels w/o stitching
if indel_size != 0 and not options.indel_stitch:
rp_snp = np.repeat(rp_snp, 2, axis=0)

# write SNP
if sum_length:
write_snp(rp_snp, ap_snp, scores_out, si, options.snp_stats)
Expand Down Expand Up @@ -420,16 +424,11 @@ def score_write(ref_preds, alt_preds, gene_id, snp_id):
for ai, alt_1hot in enumerate(snp_1hot_list[1:]):
alt_1hot = np.expand_dims(alt_1hot, axis=0)

# add compensation shifts for indels
# add left/right shifts for indels
indel_size = gsc.snps[ai].indel_size()
if indel_size == 0:
alt_shifts = options.shifts
else:
# repeat reference predictions, unless stitching
if not options.indel_stitch:
ref_preds = np.repeat(ref_preds, 2, axis=0)

# add compensation shifts
alt_shifts = []
for shift in options.shifts:
alt_shifts.append(shift)
Expand All @@ -455,6 +454,12 @@ def score_write(ref_preds, alt_preds, gene_id, snp_id):
ref_preds = [rpsf.result() for rpsf in ref_preds]
alt_preds = [apsf.result() for apsf in alt_preds]

# stitch indel shifts
if indel_size != 0 and options.indel_stitch:
snp_seq_pos = gsc.snps[ai].pos - gsc.start - model_crop
snp_seq_bin = snp_seq_pos // model_stride
alt_preds = stitch_preds(alt_preds, options.shifts, snp_seq_bin)

# flip reference and alternate
if gsc.snps[ai].flipped:
rp_snp = np.array(alt_preds)
Expand All @@ -463,6 +468,10 @@ def score_write(ref_preds, alt_preds, gene_id, snp_id):
rp_snp = np.array(ref_preds)
ap_snp = np.array(alt_preds)

# repeat reference predictions for indels w/o stitching
if indel_size != 0 and not options.indel_stitch:
rp_snp = np.repeat(rp_snp, 2, axis=0)

for gene in gsc.genes:
# slice gene positions
gene_slice = gene.output_slice(
Expand Down Expand Up @@ -864,19 +873,21 @@ def map_snps_genes(snps, genesnp_clusters):
genesnp_clusters[gi].add_snp(snps[si])


def stitch_preds(preds, shifts):
def stitch_preds(preds, shifts, pos=None):
"""Stitch indel left and right compensation shifts.
Args:
preds [np.array]: List of predictions.
shifts [int]: List of shifts.
pos (int): SNP position to stitch at.
"""
cp = preds[0].shape[0] // 2
if pos is None:
pos = preds[0].shape[0] // 2
preds_stitch = []
for hi, shift in enumerate(shifts):
hil = 2 * hi
hir = hil + 1
preds_stitch_i = np.concatenate((preds[hil][:cp], preds[hir][cp:]), axis=0)
preds_stitch_i = np.concatenate((preds[hil][:pos], preds[hir][pos:]), axis=0)
preds_stitch.append(preds_stitch_i)
return preds_stitch

Expand Down

0 comments on commit cf70c86

Please sign in to comment.