Skip to content

Commit

Permalink
modify checks arg lik
Browse files Browse the repository at this point in the history
  • Loading branch information
GertjanBisschop authored and mergify[bot] committed Dec 7, 2023
1 parent 71cdb16 commit 2c5fe55
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
11 changes: 10 additions & 1 deletion msprime/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,21 @@ def log_arg_likelihood(ts, recombination_rate, Ne=1):
`-DBL_MAX`.
"""
for tree in ts.trees():
if np.any(tree.num_children_array > 2):
if np.any(tree.num_children_array[:-1] > 2):
raise ValueError(
"ARG likelihood encountered a polytomy."
" Tree sequences must contain binary mergers only for"
" valid likelihood evaluation."
)
if tree.num_children_array[-1] > 1:
if ts.num_edges > 1:
# num_edges check is here because to avoid breaking the expected
# result of the TestOddToplogies tests.
raise ValueError(
"ARG likelihood encountered a tree with multiple roots."
" All local trees must have a single mrca for"
" valid likelihood evaluation."
)
if ts.num_trees > 1 and not np.any(ts.nodes_flags & _msprime.NODE_IS_RE_EVENT):
raise ValueError(
"ARG likelihood only valid for tree sequences where recombinations"
Expand Down
14 changes: 14 additions & 0 deletions tests/test_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,19 @@ def test_arg_likelihood_polytomy_handling(self):
):
msprime.log_arg_likelihood(arg, 1)

def test_arg_likelihood_multi_root(self):
num_samples = 10
arg = msprime.sim_ancestry(num_samples, record_full_arg=True)
slice_time = arg.nodes_time[num_samples] + 0.01
decap_arg = arg.decapitate(slice_time)
with pytest.raises(
ValueError,
match="ARG likelihood encountered a tree with multiple roots."
" All local trees must have a single mrca for"
" valid likelihood evaluation.",
):
msprime.log_arg_likelihood(decap_arg, 1)

def test_arg_likelihood_no_re_node_handling(self):
tables = tskit.TableCollection(sequence_length=1)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
Expand All @@ -205,6 +218,7 @@ def test_arg_likelihood_no_re_node_handling(self):
tables.edges.add_row(left=0, right=0.5, parent=2, child=0)
tables.edges.add_row(left=0.5, right=1, parent=3, child=0)
tables.edges.add_row(left=0, right=1, parent=3, child=1)
tables.edges.add_row(left=0, right=0.5, parent=3, child=2)
bad_arg = tables.tree_sequence()
with pytest.raises(ValueError, match="NODE_IS_RE_EVENT"):
msprime.log_arg_likelihood(bad_arg, 1)
Expand Down

0 comments on commit 2c5fe55

Please sign in to comment.