Skip to content

Commit

Permalink
Make the returned rank a namedtuple
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong authored and mergify[bot] committed Jul 14, 2022
1 parent 812af67 commit 8a69970
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 9 deletions.
6 changes: 6 additions & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,13 @@ Also see the {ref}`sec_python_api_tree_sequences` summary.
```{eval-rst}
.. autoclass:: Interval()
:members:
```

#### The {class}`Rank` class

```{eval-rst}
.. autoclass:: Rank()
:members:
```

### TableCollection and Table classes
Expand Down
14 changes: 12 additions & 2 deletions python/tests/test_combinatorics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# MIT License
#
# Copyright (c) 2020-2021 Tskit Developers
# Copyright (c) 2020-2022 Tskit Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,6 +37,7 @@
import tskit
import tskit.combinatorics as comb
from tests import test_stats
from tskit.combinatorics import Rank
from tskit.combinatorics import RankTree


Expand Down Expand Up @@ -220,7 +221,7 @@ def test_unrank_labelled(self, n):
@pytest.mark.parametrize("n", range(10))
def test_unrank_unlabelled(self, n):
for shape_rank in range(comb.num_shapes(n)):
rank = (shape_rank, 0)
rank = Rank(shape_rank, 0)
unranked = RankTree.unrank(n, rank)
assert rank, unranked.rank()

Expand Down Expand Up @@ -288,6 +289,15 @@ def test_label_unrank(self, n):
assert labelled_tree.rank() == rank
assert unranked.rank() == rank

def test_rank_names(self):
shape = 1
label = 0
n = 3
tree = tskit.Tree.unrank(n, (shape, label))
rank = tree.rank()
assert rank.shape == shape
assert rank.label == label

@pytest.mark.parametrize("n", range(6))
def test_unrank_rank_round_trip(self, n):
for shape_rank in range(comb.num_shapes(n)):
Expand Down
1 change: 1 addition & 0 deletions python/tskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
all_tree_shapes,
all_tree_labellings,
TopologyCounter,
Rank,
)
from tskit.drawing import SVGString # NOQA
from tskit.exceptions import * # NOQA
Expand Down
20 changes: 19 additions & 1 deletion python/tskit/combinatorics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,31 @@
import itertools
import json
import random
from typing import NamedTuple

import attr
import numpy as np

import tskit


class Rank(NamedTuple):
"""
A tuple of 2 numbers, ``(shape, label)``, together defining a unique
topology for a labeled tree. See :ref:`sec_combinatorics`.
"""

shape: int
"""
A non-negative integer representing the (unlabelled) topology of a tree with a
defined number of tips.
"""
label: int
"""
A non-negative integer representing the order of labels for a given tree topology.
"""


def equal_chunks(lst, k):
"""
Yield k successive equally sized chunks from lst of size n.
Expand Down Expand Up @@ -796,7 +814,7 @@ def num_labellings(self):
return num_list_of_group_labellings(child_groups)

def rank(self):
return self.shape_rank(), self.label_rank()
return Rank(self.shape_rank(), self.label_rank())

def shape_rank(self):
if self._shape_rank is None:
Expand Down
9 changes: 3 additions & 6 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,19 +849,18 @@ def seek(self, position):
raise ValueError("Position out of bounds")
self._ll_tree.seek(position)

def rank(self):
def rank(self) -> tskit.Rank:
"""
Produce the rank of this tree in the enumeration of all leaf-labelled
trees of n leaves. See the :ref:`sec_tree_ranks` section for
details on ranking and unranking trees.
:rtype: tuple(int)
:raises ValueError: If the tree has multiple roots.
"""
return combinatorics.RankTree.from_tsk_tree(self).rank()

@staticmethod
def unrank(num_leaves, rank, *, span=1, branch_length=1):
def unrank(num_leaves, rank, *, span=1, branch_length=1) -> Tree:
"""
Reconstruct the tree of the given ``rank``
(see :meth:`tskit.Tree.rank`) with ``num_leaves`` leaves.
Expand All @@ -880,14 +879,13 @@ def unrank(num_leaves, rank, *, span=1, branch_length=1):
from which the tree is taken will have its
:attr:`~tskit.TreeSequence.sequence_length` equal to ``span``.
:param: float branch_length: The minimum length of a branch in this tree.
:rtype: Tree
:raises: ValueError: If the given rank is out of bounds for trees
with ``num_leaves`` leaves.
"""
rank_tree = combinatorics.RankTree.unrank(num_leaves, rank)
return rank_tree.to_tsk_tree(span=span, branch_length=branch_length)

def count_topologies(self, sample_sets=None):
def count_topologies(self, sample_sets=None) -> tskit.TopologyCounter:
"""
Calculates the distribution of embedded topologies for every combination
of the sample sets in ``sample_sets``. ``sample_sets`` defaults to all
Expand Down Expand Up @@ -926,7 +924,6 @@ def count_topologies(self, sample_sets=None):
:param list sample_sets: A list of lists of Node IDs, specifying the
groups of nodes to compute the statistic with.
Defaults to all samples grouped by population.
:rtype: tskit.TopologyCounter
:raises ValueError: If nodes in ``sample_sets`` are invalid or are
internal samples.
"""
Expand Down

0 comments on commit 8a69970

Please sign in to comment.