Skip to content

Commit

Permalink
Merge pull request #8 from calico/indel
Browse files Browse the repository at this point in the history
Indel
  • Loading branch information
davek44 authored Nov 14, 2023
2 parents e9390be + a36eb89 commit 0edb29c
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/baskerville/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def untransform_preds(preds, targets_df, unscale=False):

# sqrt
sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat])
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** 2 # (4 / 3)
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** 2 # (4 / 3)

# scale
if unscale:
Expand Down
183 changes: 126 additions & 57 deletions src/baskerville/snps.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
seqnn_model.build_slice(targets_df.index)
if sum_length:
seqnn_model.build_sad()
seqnn_model.build_ensemble(options.rc, options.shifts)
seqnn_model.build_ensemble(options.rc)

# shift outside seqnn
num_shifts = len(options.shifts)

targets_length = seqnn_model.target_lengths[0]
num_targets = seqnn_model.num_targets()
Expand Down Expand Up @@ -128,54 +131,95 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):

# setup output
scores_out = initialize_output_h5(
options.out_dir, options.snp_stats, snps, targets_length, targets_strand_df
options.out_dir,
options.snp_stats,
snps,
targets_length,
targets_strand_df,
num_shifts,
)

# SNP index
si = 0

for sc in tqdm(snp_clusters):
snp_1hot_list = sc.get_1hots(genome_open)

# predict reference
ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0)
ref_preds = seqnn_model(ref_1hot)[0]

# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
ref_preds = dataset.untransform_preds1(ref_preds, targets_df)
else:
ref_preds = dataset.untransform_preds(ref_preds, targets_df)

# sum strand pairs
if strand_transform is not None:
ref_preds = ref_preds * strand_transform

for alt_1hot in snp_1hot_list[1:]:
alt_1hot = np.expand_dims(alt_1hot, axis=0)

# predict alternate
alt_preds = seqnn_model(alt_1hot)[0]
# predict reference
ref_preds = []
for shift in options.shifts:
ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift)
ref_preds_shift = seqnn_model(ref_1hot_shift)[0]

# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
alt_preds = dataset.untransform_preds1(alt_preds, targets_df)
ref_preds_shift = dataset.untransform_preds1(
ref_preds_shift, targets_df
)
else:
alt_preds = dataset.untransform_preds(alt_preds, targets_df)
ref_preds_shift = dataset.untransform_preds(
ref_preds_shift, targets_df
)

# sum strand pairs
if strand_transform is not None:
alt_preds = alt_preds * strand_transform
ref_preds_shift = ref_preds_shift * strand_transform

# save shift prediction
ref_preds.append(ref_preds_shift)
ref_preds = np.array(ref_preds)

ai = 0
for alt_1hot in snp_1hot_list[1:]:
alt_1hot = np.expand_dims(alt_1hot, axis=0)

# add compensation shifts for indels
indel_size = sc.snps[ai].indel_size()
if indel_size == 0:
alt_shifts = options.shifts
else:
# repeat reference predictions
ref_preds = np.repeat(ref_preds, 2, axis=0)

# add compensation shifts
alt_shifts = []
for shift in options.shifts:
alt_shifts.append(shift)
alt_shifts.append(shift - indel_size)

# predict alternate
alt_preds = []
for shift in alt_shifts:
alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift)
alt_preds_shift = seqnn_model(alt_1hot_shift)[0]

# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
alt_preds_shift = dataset.untransform_preds1(
alt_preds_shift, targets_df
)
else:
alt_preds_shift = dataset.untransform_preds(
alt_preds_shift, targets_df
)

# sum strand pairs
if strand_transform is not None:
alt_preds_shift = alt_preds_shift * strand_transform

# save shift prediction
alt_preds.append(alt_preds_shift)

# flip reference and alternate
if snps[si].flipped:
rp_snp = alt_preds
ap_snp = ref_preds
rp_snp = np.array(alt_preds)
ap_snp = np.array(ref_preds)
else:
rp_snp = ref_preds
ap_snp = alt_preds
rp_snp = np.array(ref_preds)
ap_snp = np.array(alt_preds)

# write SNP
if sum_length:
Expand Down Expand Up @@ -222,15 +266,18 @@ def cluster_snps(snps, seq_len: int, center_pct: float):
return snp_clusters


def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df):
def initialize_output_h5(
out_dir, snp_stats, snps, targets_length, targets_df, num_shifts
):
"""Initialize an output HDF5 file for SAD stats.
Args:
out_dir (str): Output directory.
snp_stats [str]: List of SAD stats to compute.
snps [SNP]: List of SNPs.
targets_length (int): Targets' sequence length
targets_df (pd.DataFrame): Targets AataFrame.
targets_df (pd.DataFrame): Targets DataFrame.
num_shifts (int): Number of shifts.
"""

num_targets = targets_df.shape[0]
Expand Down Expand Up @@ -275,7 +322,9 @@ def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df):
for snp_stat in snp_stats:
if snp_stat in ["REF", "ALT"]:
scores_out.create_dataset(
snp_stat, shape=(num_snps, targets_length, num_targets), dtype="float16"
snp_stat,
shape=(num_snps, num_shifts, targets_length, num_targets),
dtype="float16",
)
else:
scores_out.create_dataset(
Expand Down Expand Up @@ -407,7 +456,8 @@ def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats):
# compare reference to alternative via mean subtraction
if "SAD" in snp_stats:
sad = alt_preds_sum - ref_preds_sum
scores_out["SAD"][si, :] = sad.astype("float16")
sad = sad.mean(axis=0)
scores_out["SAD"][si] = sad.astype("float16")


def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
Expand All @@ -421,7 +471,7 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
si (int): SNP index.
snp_stats [str]: List of SAD stats to compute.
"""
seq_length, num_targets = ref_preds.shape
num_shifts, seq_length, num_targets = ref_preds.shape

# log/sqrt
ref_preds_log = np.log2(ref_preds + 1)
Expand All @@ -430,12 +480,12 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
alt_preds_sqrt = np.sqrt(alt_preds)

# sum across length
ref_preds_sum = ref_preds.sum(axis=0)
alt_preds_sum = alt_preds.sum(axis=0)
ref_preds_log_sum = ref_preds_log.sum(axis=0)
alt_preds_log_sum = alt_preds_log.sum(axis=0)
ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=0)
alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=0)
ref_preds_sum = ref_preds.sum(axis=(0, 1))
alt_preds_sum = alt_preds.sum(axis=(0, 1))
ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1))
alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1))
ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1))
alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1))

# difference
altref_diff = alt_preds - ref_preds
Expand All @@ -461,53 +511,65 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):

# compare reference to alternative via max subtraction
if "SAX" in snp_stats:
max_i = np.argmax(altref_adiff, axis=0)
sax = altref_diff[max_i, np.arange(num_targets)]
sax = []
for s in range(num_shifts):
max_i = np.argmax(altref_adiff[s], axis=0)
sax.append(altref_diff[s, max_i, np.arange(num_targets)])
sax = np.array(sax).mean(axis=0)
scores_out["SAX"][si] = sax.astype("float16")

# L1 norm of difference vector
if "D1" in snp_stats:
sad_d1 = altref_adiff.sum(axis=0)
sad_d1 = altref_adiff.sum(axis=1)
sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max)
scores_out["D1"][si] = sad_d1.astype("float16")
sad_d1 = sad_d1.mean(axis=0)
scores_out["D1"][si] = sad_d1.mean().astype("float16")
if "logD1" in snp_stats:
log_d1 = altref_log_adiff.sum(axis=0)
log_d1 = altref_log_adiff.sum(axis=1)
log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max)
log_d1 = log_d1.mean(axis=0)
scores_out["logD1"][si] = log_d1.astype("float16")
if "sqrtD1" in snp_stats:
sqrt_d1 = altref_sqrt_adiff.sum(axis=0)
sqrt_d1 = altref_sqrt_adiff.sum(axis=1)
sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max)
sqrt_d1 = sqrt_d1.mean(axis=0)
scores_out["sqrtD1"][si] = sqrt_d1.astype("float16")

# L2 norm of difference vector
if "D2" in snp_stats:
altref_diff2 = np.power(altref_diff, 2)
sad_d2 = np.sqrt(altref_diff2.sum(axis=0))
sad_d2 = np.sqrt(altref_diff2.sum(axis=1))
sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max)
sad_d2 = sad_d2.mean(axis=0)
scores_out["D2"][si] = sad_d2.astype("float16")
if "logD2" in snp_stats:
altref_log_diff2 = np.power(altref_log_diff, 2)
log_d2 = np.sqrt(altref_log_diff2.sum(axis=0))
log_d2 = np.sqrt(altref_log_diff2.sum(axis=1))
log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max)
log_d2 = log_d2.mean(axis=0)
scores_out["logD2"][si] = log_d2.astype("float16")
if "sqrtD2" in snp_stats:
altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2)
sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=0))
sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1))
sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max)
sqrt_d2 = sqrt_d2.mean(axis=0)
scores_out["sqrtD2"][si] = sqrt_d2.astype("float16")

if "JS" in snp_stats:
# normalized scores
pseudocounts = np.percentile(ref_preds, 25, axis=0)
pseudocounts = np.percentile(ref_preds, 25, axis=1)
ref_preds_norm = ref_preds + pseudocounts
ref_preds_norm /= ref_preds_norm.sum(axis=0)
ref_preds_norm /= ref_preds_norm.sum(axis=1)
alt_preds_norm = alt_preds + pseudocounts
alt_preds_norm /= alt_preds_norm.sum(axis=0)
alt_preds_norm /= alt_preds_norm.sum(axis=1)

# compare normalized JS
ref_alt_entr = rel_entr(ref_preds_norm, alt_preds_norm).sum(axis=0)
alt_ref_entr = rel_entr(alt_preds_norm, ref_preds_norm).sum(axis=0)
js_dist = (ref_alt_entr + alt_ref_entr) / 2
js_dist = []
for s in range(num_shifts):
ref_alt_entr = rel_entr(ref_preds_norm[s], alt_preds_norm[s]).sum(axis=0)
alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0)
js_dist.append((ref_alt_entr + alt_ref_entr) / 2)
js_dist = np.mean(js_dist, axis=0)
scores_out["JS"][si] = js_dist.astype("float16")
if "logJS" in snp_stats:
# normalized scores
Expand All @@ -518,9 +580,16 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
alt_preds_log_norm /= alt_preds_log_norm.sum(axis=0)

# compare normalized JS
ref_alt_entr = rel_entr(ref_preds_log_norm, alt_preds_log_norm).sum(axis=0)
alt_ref_entr = rel_entr(alt_preds_log_norm, ref_preds_log_norm).sum(axis=0)
log_js_dist = (ref_alt_entr + alt_ref_entr) / 2
log_js_dist = []
for s in range(num_shifts):
ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum(
axis=0
)
alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum(
axis=0
)
log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2)
log_js_dist = np.mean(log_js_dist, axis=0)
scores_out["logJS"][si] = log_js_dist.astype("float16")

# predictions
Expand Down
4 changes: 4 additions & 0 deletions src/baskerville/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,10 @@ def get_alleles(self):
alleles = [self.ref_allele] + self.alt_alleles
return alleles

def indel_size(self):
    """Return the signed length difference between alt and ref alleles.

    Positive for an insertion, negative for a deletion, and zero when
    the two alleles have equal length (e.g. a substitution).
    NOTE(review): reads ``self.alt_allele`` (singular) while sibling
    methods use the ``self.alt_alleles`` list — presumably the primary
    alt allele; confirm the attribute exists on this class.
    """
    alt_len = len(self.alt_allele)
    ref_len = len(self.ref_allele)
    return alt_len - ref_len

def longest_alt(self):
    """Return the length of the longest alternate allele.

    NOTE: despite the name, this returns an ``int`` length, not the
    allele string itself.

    Raises:
        ValueError: if ``self.alt_alleles`` is empty (``max`` of an
            empty sequence).
    """
    # Generator expression avoids materializing a throwaway list.
    return max(len(al) for al in self.alt_alleles)
Expand Down

0 comments on commit 0edb29c

Please sign in to comment.