Skip to content

Commit

Permalink
Optimized likelihood (#29)
Browse files Browse the repository at this point in the history
Co-authored-by: Laurenz Keller <[email protected]>
  • Loading branch information
laurenzkeller and laukeller authored Nov 14, 2023
1 parent f1cc76d commit 16b6412
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 158 deletions.
210 changes: 77 additions & 133 deletions src/pmhn/_trees/_backend_code.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,29 @@
from typing import Optional
from typing import cast
import numpy as np
from pmhn._trees._tree_utils_geno import create_mappings
from anytree import Node
from joblib import Parallel, delayed


class TreeWrapperCode:
"""Tree wrapper using smart encoding of subtrees."""

def __init__(self, tree: Node) -> None:
self._genotype_subtree_node_map: dict[
tuple[tuple[Node, int], ...], tuple[int, int]
]
self._genotype_subtree_node_map: dict[tuple[tuple[Node, int], ...], int]
self._genotype_list_subtree_map: dict[tuple[int, ...], int]
self._index_subclone_map: dict[int, tuple[int, ...]]
self._subclone_index_map: dict[tuple[int, ...], int]

(
self._genotype_subtree_node_map,
self._genotype_list_subtree_map,
self._index_subclone_map,
self._subclone_index_map,
) = create_mappings(tree)


class TreeMHNBackendCode:
def __init__(self, jitter: float = 1e-10) -> None:
self._jitter: float = jitter

def _diag_entry(
self,
tree_wrapper: TreeWrapperCode,
genotype: tuple[tuple[Node, int], ...],
theta: np.ndarray,
all_mut: set[int],
) -> float:
"""Calculates a diagonal entry of the V matrix.
Args:
tree: a tree wrappper
genotype: the genotype of a subtree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
all_mut: a set containing all possible mutations
Returns:
the diagonal entry of the V matrix corresponding to
genotype
"""
lamb_sum = 0
for i, (node, val) in enumerate(genotype):
if val:
lineage = tree_wrapper._index_subclone_map[i]
lineage = list(lineage)
tree_mutations = set(lineage + [c.name for c in node.children])

exit_mutations = all_mut.difference(tree_mutations)

for mutation in exit_mutations:
lamb = 0
lamb += theta[mutation - 1][mutation - 1]
for j in lineage:
if j != 0:
lamb += theta[mutation - 1][j - 1]
lamb = np.exp(lamb)
lamb_sum -= lamb
return lamb_sum

def find_single_difference(
self, arr1: np.ndarray, arr2: np.ndarray
) -> Optional[int]:
"""
Checks if two binary arrays of equal size differ in only one entry.
If so, the index of the differing entry is returned, otherwise None.
Args:
arr1: the first array
arr2: the second array
Returns:
the index of the differing entry if there's
a single difference, otherwise None.
"""
differing_indices = np.nonzero(np.bitwise_xor(arr1, arr2))[0]

return differing_indices[0] if len(differing_indices) == 1 else None

def _off_diag_entry(
self,
tree_wrapper: TreeWrapperCode,
genotype_i: np.ndarray,
genotype_j: np.ndarray,
theta: np.ndarray,
) -> float:
"""
Calculates an off-diagonal entry of the V matrix.
Args:
tree: the original tree
genotype_i: the genotype of a subtree
genotype_j: the genotype of another subtree
theta: real-valued (i.e., log-theta) matrix,
shape (n_mutations, n_mutations)
Returns:
an off-diagonal entry of the V matrix corresponding to
the genotype_i and genotype_j
"""
index = self.find_single_difference(genotype_i, genotype_j)
if index is None:
return 0
else:
lamb = 0
lineage = tree_wrapper._index_subclone_map[index]
exit_mutation = lineage[-1]
for mutation in lineage:
if mutation != 0:
lamb += theta[exit_mutation - 1][mutation - 1]
lamb = np.exp(lamb)
return float(lamb)

def loglikelihood(
self,
tree_wrapper: TreeWrapperCode,
Expand All @@ -135,38 +45,74 @@ def loglikelihood(
the loglikelihood of tree
"""
subtrees_size = len(tree_wrapper._genotype_subtree_node_map)
x = np.zeros(subtrees_size)
x[0] = 1
genotype_lists = []
for genotype in tree_wrapper._genotype_subtree_node_map.keys():
genotype_lists.append(np.array([item[1] for item in genotype]))
for genotype_i, (
i,
subtree_size_i,
) in tree_wrapper._genotype_subtree_node_map.items():
V_col = []
V_diag = 0.0
for j, subtree_size_j in tree_wrapper._genotype_subtree_node_map.values():
if subtree_size_i - subtree_size_j == 1:
V_col.append(
(
j,
-self._off_diag_entry(
tree_wrapper,
genotype_lists[j],
genotype_lists[i],
theta,
),

subclone_lamb_map = {}

for i, subclone in enumerate(tree_wrapper._subclone_index_map.keys()):
lamb = 0
last_mut = subclone[-1]
for mutation in subclone[1:]:
lamb += theta[last_mut - 1][mutation - 1]
lamb = np.exp(lamb)
subclone_lamb_map[i] = lamb
exit_lamb_map = {}
for i, (node, val) in enumerate(
list(tree_wrapper._genotype_subtree_node_map.keys())[-1]
):
lineage = tree_wrapper._index_subclone_map[i]
lineage = list(lineage)
tree_mutations = set(lineage + [c.name for c in node.children])

exit_mutations = all_mut.difference(tree_mutations)

for mutation in exit_mutations:
lamb = 0
lamb += theta[mutation - 1][mutation - 1]
for j in lineage[1:]:
lamb += theta[mutation - 1][j - 1]
lamb = np.exp(lamb)
exit_lamb_map[tuple(lineage + [mutation])] = lamb

V_old = np.zeros(subtrees_size)
V_old[0] = -1.0
V_new = np.zeros(subtrees_size)
for genotype, index in tree_wrapper._genotype_subtree_node_map.items():
x = 0.0
genotype_list = [item[1] for item in genotype]
for i, (node, val) in enumerate(genotype):
if val:
lineage = tree_wrapper._index_subclone_map[i]
lineage = list(lineage)
tree_mutations = set(lineage + [c.name for c in node.children])

exit_mutations = all_mut.difference(tree_mutations)

for mutation in exit_mutations:
subclone_index = tree_wrapper._subclone_index_map.get(
tuple(lineage + [mutation])
)
)
elif i == j:
V_diag = sampling_rate - self._diag_entry(
tree_wrapper, genotype_i, theta, all_mut
)
for index, val in V_col:
x[i] -= val * x[index]
x[i] /= V_diag
return np.log(x[-1] + self._jitter) + np.log(sampling_rate)
if subclone_index is None:
lamb = exit_lamb_map[tuple(lineage + [mutation])]
V_new[index] += lamb
else:
genotype_list[subclone_index] = 1
ind = tree_wrapper._genotype_list_subtree_map.get(
tuple(genotype_list)
)

lamb = subclone_lamb_map[subclone_index]
V_new[ind] = -lamb
V_new[index] += lamb
genotype_list[subclone_index] = 0
V_new[index] += sampling_rate

x = -V_old[index] / V_new[index]
if index == subtrees_size - 1:
return np.log(x + self._jitter) + np.log(sampling_rate)
V_old += V_new * x
V_new = np.zeros_like(V_new)

return 0.0

def loglikelihood_tree_list(
self,
Expand All @@ -187,11 +133,9 @@ def loglikelihood_tree_list(
Returns:
a list of loglikelihoods, one for each tree
"""
loglikelihoods = cast(
list[float],
Parallel(n_jobs=-1)(
delayed(self.loglikelihood)(tree, theta, sampling_rate, all_mut)
for tree in trees
),
)
loglikelihoods = []
for i, tree in enumerate(trees):
loglikelihoods.append(
self.loglikelihood(tree, theta, sampling_rate, all_mut)
)
return loglikelihoods
28 changes: 20 additions & 8 deletions src/pmhn/_trees/_tree_utils_geno.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ def create_index_subclone_maps(
index = 0
for level in LevelOrderGroupIter(root):
for node in level:
index_subclone_map[index] = get_lineage(node)
subclone_index_map[get_lineage(node)] = index
lineage = get_lineage(node)

index_subclone_map[index] = lineage
subclone_index_map[lineage] = index
index += 1
return index_subclone_map, subclone_index_map

Expand Down Expand Up @@ -165,7 +167,10 @@ def create_genotype(
def create_mappings(
root: Node,
) -> tuple[
dict[tuple[tuple[Node, int], ...], tuple[int, int]], dict[int, tuple[int, ...]]
dict[tuple[tuple[Node, int], ...], int],
dict[tuple[int, ...], int],
dict[int, tuple[int, ...]],
dict[tuple[int, ...], int],
]:
"""
Creates the required mappings to calculate the likelihood of a tree.
Expand All @@ -179,12 +184,19 @@ def create_mappings(
"""
index_subclone_map, subclone_index_map = create_index_subclone_maps(root)
genotype_subtree_map = {}
genotype_list_subtree_map = {}
subtrees = get_subtrees(root)
original_tree = subtrees[-1]
all_node_lists_with_len = [(subtree, len(subtree)) for subtree in subtrees]
size = len(subtrees)
for index, (subtree, subtree_size) in enumerate(all_node_lists_with_len):
size = len(original_tree)
for index, subtree in enumerate(subtrees):
subtree = create_subtree(subtree, original_tree)
genotype = create_genotype(size, subtree, subclone_index_map)
genotype_subtree_map[genotype] = (index, subtree_size)
return genotype_subtree_map, index_subclone_map
genotype_list = tuple([item[1] for item in genotype])
genotype_subtree_map[genotype] = index
genotype_list_subtree_map[genotype_list] = index
return (
genotype_subtree_map,
genotype_list_subtree_map,
index_subclone_map,
subclone_index_map,
)
Loading

0 comments on commit 16b6412

Please sign in to comment.