Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Indel #8

Merged
merged 3 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/baskerville/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def untransform_preds(preds, targets_df, unscale=False):

# sqrt
sqrt_mask = np.array([ss.find("_sqrt") != -1 for ss in targets_df.sum_stat])
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** 2 # (4 / 3)
preds[:, sqrt_mask] = -1 + (preds[:, sqrt_mask] + 1) ** 2 # (4 / 3)

# scale
if unscale:
Expand Down
183 changes: 126 additions & 57 deletions src/baskerville/snps.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):
seqnn_model.build_slice(targets_df.index)
if sum_length:
seqnn_model.build_sad()
seqnn_model.build_ensemble(options.rc, options.shifts)
seqnn_model.build_ensemble(options.rc)

# shift outside seqnn
num_shifts = len(options.shifts)

targets_length = seqnn_model.target_lengths[0]
num_targets = seqnn_model.num_targets()
Expand Down Expand Up @@ -128,54 +131,95 @@ def score_snps(params_file, model_file, vcf_file, worker_index, options):

# setup output
scores_out = initialize_output_h5(
options.out_dir, options.snp_stats, snps, targets_length, targets_strand_df
options.out_dir,
options.snp_stats,
snps,
targets_length,
targets_strand_df,
num_shifts,
)

# SNP index
si = 0

for sc in tqdm(snp_clusters):
snp_1hot_list = sc.get_1hots(genome_open)

# predict reference
ref_1hot = np.expand_dims(snp_1hot_list[0], axis=0)
ref_preds = seqnn_model(ref_1hot)[0]

# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
ref_preds = dataset.untransform_preds1(ref_preds, targets_df)
else:
ref_preds = dataset.untransform_preds(ref_preds, targets_df)

# sum strand pairs
if strand_transform is not None:
ref_preds = ref_preds * strand_transform

for alt_1hot in snp_1hot_list[1:]:
alt_1hot = np.expand_dims(alt_1hot, axis=0)

# predict alternate
alt_preds = seqnn_model(alt_1hot)[0]
# predict reference
ref_preds = []
for shift in options.shifts:
ref_1hot_shift = dna.hot1_augment(ref_1hot, shift=shift)
ref_preds_shift = seqnn_model(ref_1hot_shift)[0]

# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
alt_preds = dataset.untransform_preds1(alt_preds, targets_df)
ref_preds_shift = dataset.untransform_preds1(
ref_preds_shift, targets_df
)
else:
alt_preds = dataset.untransform_preds(alt_preds, targets_df)
ref_preds_shift = dataset.untransform_preds(
ref_preds_shift, targets_df
)

# sum strand pairs
if strand_transform is not None:
alt_preds = alt_preds * strand_transform
ref_preds_shift = ref_preds_shift * strand_transform

# save shift prediction
ref_preds.append(ref_preds_shift)
ref_preds = np.array(ref_preds)

ai = 0
for alt_1hot in snp_1hot_list[1:]:
alt_1hot = np.expand_dims(alt_1hot, axis=0)

# add compensation shifts for indels
indel_size = sc.snps[ai].indel_size()
if indel_size == 0:
alt_shifts = options.shifts
else:
# repeat reference predictions
ref_preds = np.repeat(ref_preds, 2, axis=0)

# add compensation shifts
alt_shifts = []
for shift in options.shifts:
alt_shifts.append(shift)
alt_shifts.append(shift - indel_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line adds shift in the direction opposite of the indel "direction", right? So if by default it is shifted to the right, this line will add a left shift

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct


# predict alternate
alt_preds = []
for shift in alt_shifts:
alt_1hot_shift = dna.hot1_augment(alt_1hot, shift=shift)
alt_preds_shift = seqnn_model(alt_1hot_shift)[0]

# untransform predictions
if options.targets_file is not None:
if options.untransform_old:
alt_preds_shift = dataset.untransform_preds1(
alt_preds_shift, targets_df
)
else:
alt_preds_shift = dataset.untransform_preds(
alt_preds_shift, targets_df
)

# sum strand pairs
if strand_transform is not None:
alt_preds_shift = alt_preds_shift * strand_transform

# save shift prediction
alt_preds.append(alt_preds_shift)

# flip reference and alternate
if snps[si].flipped:
rp_snp = alt_preds
ap_snp = ref_preds
rp_snp = np.array(alt_preds)
ap_snp = np.array(ref_preds)
else:
rp_snp = ref_preds
ap_snp = alt_preds
rp_snp = np.array(ref_preds)
ap_snp = np.array(alt_preds)

# write SNP
if sum_length:
Expand Down Expand Up @@ -222,15 +266,18 @@ def cluster_snps(snps, seq_len: int, center_pct: float):
return snp_clusters


def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df):
def initialize_output_h5(
out_dir, snp_stats, snps, targets_length, targets_df, num_shifts
):
"""Initialize an output HDF5 file for SAD stats.

Args:
out_dir (str): Output directory.
snp_stats [str]: List of SAD stats to compute.
snps [SNP]: List of SNPs.
targets_length (int): Targets' sequence length
targets_df (pd.DataFrame): Targets AataFrame.
targets_df (pd.DataFrame): Targets DataFrame.
num_shifts (int): Number of shifts.
"""

num_targets = targets_df.shape[0]
Expand Down Expand Up @@ -275,7 +322,9 @@ def initialize_output_h5(out_dir, snp_stats, snps, targets_length, targets_df):
for snp_stat in snp_stats:
if snp_stat in ["REF", "ALT"]:
scores_out.create_dataset(
snp_stat, shape=(num_snps, targets_length, num_targets), dtype="float16"
snp_stat,
shape=(num_snps, num_shifts, targets_length, num_targets),
dtype="float16",
)
else:
scores_out.create_dataset(
Expand Down Expand Up @@ -407,7 +456,8 @@ def write_snp(ref_preds_sum, alt_preds_sum, scores_out, si, snp_stats):
# compare reference to alternative via mean subtraction
if "SAD" in snp_stats:
sad = alt_preds_sum - ref_preds_sum
scores_out["SAD"][si, :] = sad.astype("float16")
sad = sad.mean(axis=0)
scores_out["SAD"][si] = sad.astype("float16")


def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
Expand All @@ -421,7 +471,7 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
si (int): SNP index.
snp_stats [str]: List of SAD stats to compute.
"""
seq_length, num_targets = ref_preds.shape
num_shifts, seq_length, num_targets = ref_preds.shape

# log/sqrt
ref_preds_log = np.log2(ref_preds + 1)
Expand All @@ -430,12 +480,12 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
alt_preds_sqrt = np.sqrt(alt_preds)

# sum across length
ref_preds_sum = ref_preds.sum(axis=0)
alt_preds_sum = alt_preds.sum(axis=0)
ref_preds_log_sum = ref_preds_log.sum(axis=0)
alt_preds_log_sum = alt_preds_log.sum(axis=0)
ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=0)
alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=0)
ref_preds_sum = ref_preds.sum(axis=(0, 1))
alt_preds_sum = alt_preds.sum(axis=(0, 1))
ref_preds_log_sum = ref_preds_log.sum(axis=(0, 1))
alt_preds_log_sum = alt_preds_log.sum(axis=(0, 1))
ref_preds_sqrt_sum = ref_preds_sqrt.sum(axis=(0, 1))
alt_preds_sqrt_sum = alt_preds_sqrt.sum(axis=(0, 1))

# difference
altref_diff = alt_preds - ref_preds
Expand All @@ -461,53 +511,65 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):

# compare reference to alternative via max subtraction
if "SAX" in snp_stats:
max_i = np.argmax(altref_adiff, axis=0)
sax = altref_diff[max_i, np.arange(num_targets)]
sax = []
for s in range(num_shifts):
max_i = np.argmax(altref_adiff[s], axis=0)
sax.append(altref_diff[s, max_i, np.arange(num_targets)])
sax = np.array(sax).mean(axis=0)
scores_out["SAX"][si] = sax.astype("float16")

# L1 norm of difference vector
if "D1" in snp_stats:
sad_d1 = altref_adiff.sum(axis=0)
sad_d1 = altref_adiff.sum(axis=1)
sad_d1 = np.clip(sad_d1, np.finfo(np.float16).min, np.finfo(np.float16).max)
scores_out["D1"][si] = sad_d1.astype("float16")
sad_d1 = sad_d1.mean(axis=0)
scores_out["D1"][si] = sad_d1.mean().astype("float16")
if "logD1" in snp_stats:
log_d1 = altref_log_adiff.sum(axis=0)
log_d1 = altref_log_adiff.sum(axis=1)
log_d1 = np.clip(log_d1, np.finfo(np.float16).min, np.finfo(np.float16).max)
log_d1 = log_d1.mean(axis=0)
scores_out["logD1"][si] = log_d1.astype("float16")
if "sqrtD1" in snp_stats:
sqrt_d1 = altref_sqrt_adiff.sum(axis=0)
sqrt_d1 = altref_sqrt_adiff.sum(axis=1)
sqrt_d1 = np.clip(sqrt_d1, np.finfo(np.float16).min, np.finfo(np.float16).max)
sqrt_d1 = sqrt_d1.mean(axis=0)
scores_out["sqrtD1"][si] = sqrt_d1.astype("float16")

# L2 norm of difference vector
if "D2" in snp_stats:
altref_diff2 = np.power(altref_diff, 2)
sad_d2 = np.sqrt(altref_diff2.sum(axis=0))
sad_d2 = np.sqrt(altref_diff2.sum(axis=1))
sad_d2 = np.clip(sad_d2, np.finfo(np.float16).min, np.finfo(np.float16).max)
sad_d2 = sad_d2.mean(axis=0)
scores_out["D2"][si] = sad_d2.astype("float16")
if "logD2" in snp_stats:
altref_log_diff2 = np.power(altref_log_diff, 2)
log_d2 = np.sqrt(altref_log_diff2.sum(axis=0))
log_d2 = np.sqrt(altref_log_diff2.sum(axis=1))
log_d2 = np.clip(log_d2, np.finfo(np.float16).min, np.finfo(np.float16).max)
log_d2 = log_d2.mean(axis=0)
scores_out["logD2"][si] = log_d2.astype("float16")
if "sqrtD2" in snp_stats:
altref_sqrt_diff2 = np.power(altref_sqrt_diff, 2)
sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=0))
sqrt_d2 = np.sqrt(altref_sqrt_diff2.sum(axis=1))
sqrt_d2 = np.clip(sqrt_d2, np.finfo(np.float16).min, np.finfo(np.float16).max)
sqrt_d2 = sqrt_d2.mean(axis=0)
scores_out["sqrtD2"][si] = sqrt_d2.astype("float16")

if "JS" in snp_stats:
# normalized scores
pseudocounts = np.percentile(ref_preds, 25, axis=0)
pseudocounts = np.percentile(ref_preds, 25, axis=1)
ref_preds_norm = ref_preds + pseudocounts
ref_preds_norm /= ref_preds_norm.sum(axis=0)
ref_preds_norm /= ref_preds_norm.sum(axis=1)
alt_preds_norm = alt_preds + pseudocounts
alt_preds_norm /= alt_preds_norm.sum(axis=0)
alt_preds_norm /= alt_preds_norm.sum(axis=1)

# compare normalized JS
ref_alt_entr = rel_entr(ref_preds_norm, alt_preds_norm).sum(axis=0)
alt_ref_entr = rel_entr(alt_preds_norm, ref_preds_norm).sum(axis=0)
js_dist = (ref_alt_entr + alt_ref_entr) / 2
js_dist = []
for s in range(num_shifts):
ref_alt_entr = rel_entr(ref_preds_norm[s], alt_preds_norm[s]).sum(axis=0)
alt_ref_entr = rel_entr(alt_preds_norm[s], ref_preds_norm[s]).sum(axis=0)
js_dist.append((ref_alt_entr + alt_ref_entr) / 2)
js_dist = np.mean(js_dist, axis=0)
scores_out["JS"][si] = js_dist.astype("float16")
if "logJS" in snp_stats:
# normalized scores
Expand All @@ -518,9 +580,16 @@ def write_snp_len(ref_preds, alt_preds, scores_out, si, snp_stats):
alt_preds_log_norm /= alt_preds_log_norm.sum(axis=0)

# compare normalized JS
ref_alt_entr = rel_entr(ref_preds_log_norm, alt_preds_log_norm).sum(axis=0)
alt_ref_entr = rel_entr(alt_preds_log_norm, ref_preds_log_norm).sum(axis=0)
log_js_dist = (ref_alt_entr + alt_ref_entr) / 2
log_js_dist = []
for s in range(num_shifts):
ref_alt_entr = rel_entr(ref_preds_log_norm[s], alt_preds_log_norm[s]).sum(
axis=0
)
alt_ref_entr = rel_entr(alt_preds_log_norm[s], ref_preds_log_norm[s]).sum(
axis=0
)
log_js_dist.append((ref_alt_entr + alt_ref_entr) / 2)
log_js_dist = np.mean(log_js_dist, axis=0)
scores_out["logJS"][si] = log_js_dist.astype("float16")

# predictions
Expand Down
4 changes: 4 additions & 0 deletions src/baskerville/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,10 @@ def get_alleles(self):
alleles = [self.ref_allele] + self.alt_alleles
return alleles

def indel_size(self):
    """Return the signed size of the indel.

    Computed as ``len(alt) - len(ref)``: positive for an insertion,
    negative for a deletion, and 0 for a same-length substitution.

    NOTE(review): this reads ``self.alt_allele`` (singular), while
    sibling methods use ``self.alt_alleles`` (a list) — confirm both
    attributes exist on this class.

    Returns:
        int: Alternate allele length minus reference allele length.
    """
    ref_len = len(self.ref_allele)
    alt_len = len(self.alt_allele)
    return alt_len - ref_len

def longest_alt(self):
    """Return the length of the longest alternate allele.

    The original docstring claimed the allele itself was returned, but
    the value is its length; the docstring is corrected here.

    Returns:
        int: Maximum ``len(allele)`` over ``self.alt_alleles``.

    Raises:
        ValueError: If ``self.alt_alleles`` is empty (from ``max``).
    """
    return max(len(al) for al in self.alt_alleles)
Expand Down
Loading