From 5eee07b9c257e8a5b7f4210cacdf07a69378c8bb Mon Sep 17 00:00:00 2001 From: EvaLiyt Date: Tue, 18 Jun 2024 15:20:07 +1200 Subject: [PATCH] implement CalibratedYule distribution #497 --- .../evolution/birthdeath/CalibratedYule.java | 390 ++++++++++++++++++ .../evolution/tree/CalibratedYuleTest.java | 112 +++++ 2 files changed, 502 insertions(+) create mode 100644 lphy-base/src/main/java/lphy/base/evolution/birthdeath/CalibratedYule.java create mode 100644 lphy-base/src/test/java/lphy/base/evolution/tree/CalibratedYuleTest.java diff --git a/lphy-base/src/main/java/lphy/base/evolution/birthdeath/CalibratedYule.java b/lphy-base/src/main/java/lphy/base/evolution/birthdeath/CalibratedYule.java new file mode 100644 index 000000000..d99b664c4 --- /dev/null +++ b/lphy-base/src/main/java/lphy/base/evolution/birthdeath/CalibratedYule.java @@ -0,0 +1,390 @@ +package lphy.base.evolution.birthdeath; + +import lphy.base.distribution.DistributionConstants; +import lphy.base.distribution.Exp; +import lphy.base.distribution.UniformDiscrete; +import lphy.base.evolution.Taxa; +import lphy.base.evolution.Taxon; +import lphy.base.evolution.tree.TaxaConditionedTreeGenerator; +import lphy.base.evolution.tree.TimeTree; +import lphy.base.evolution.tree.TimeTreeNode; +import lphy.core.model.GenerativeDistribution; +import lphy.core.model.RandomVariable; +import lphy.core.model.Value; +import lphy.core.model.annotation.GeneratorCategory; +import lphy.core.model.annotation.GeneratorInfo; +import lphy.core.model.annotation.ParameterInfo; + +import java.util.*; + + +public class CalibratedYule extends TaxaConditionedTreeGenerator implements GenerativeDistribution { + Value rootAge; + Value birthRate; + Value cladeMRCAAge; + Value cladeTaxaValue; + Value otherTaxa; + Taxa taxa; + Taxa[] cladeTaxaArray; + List activeNodes; + List inactiveNodes; + public final String cladeMRCAAgeName = "cladeMRCAAge"; + public final String cladeTaxaName = "cladeTaxa"; + public final String otherTaxaName = "otherTaxa"; + + public CalibratedYule(@ParameterInfo(name = BirthDeathConstants.lambdaParamName, description = "per-lineage birth rate, possibly scaled to mutations or calendar units.") Value birthRate, + @ParameterInfo(name = DistributionConstants.nParamName, description = "the total number of taxa.", optional = true) Value n, + @ParameterInfo(name = cladeTaxaName, description = "a string array of taxa id or a taxa object for clade taxa (e.g. dataframe, alignment or tree)") Value cladeTaxa, + @ParameterInfo(name = cladeMRCAAgeName, description = "an array of ages for clade most recent common ancestor, ages should be correspond with clade taxa array.") Value cladeMRCAAge, + @ParameterInfo(name = otherTaxaName, description = "a string array of taxa id or a taxa object for other taxa (e.g. dataframe, alignment or tree)", optional = true) Value otherTaxa, + @ParameterInfo(name = BirthDeathConstants.rootAgeParamName, description = "the root age to be conditioned on optional.", optional = true) Value rootAge){ + super(n, null, null); + if (cladeTaxa == null) throw new IllegalArgumentException("The clade taxa shouldn't be null!"); + if (cladeMRCAAge == null) throw new IllegalArgumentException("The clade mrca age shouldn't be null!"); + if (n == null && otherTaxa == null) { + throw new IllegalArgumentException("At least one of " + DistributionConstants.nParamName + ", " + otherTaxaName + " must be specified."); + } + + this.cladeTaxaValue = cladeTaxa; + this.cladeMRCAAge = cladeMRCAAge; + this.otherTaxa = otherTaxa; + this.rootAge = rootAge; + this.birthRate = birthRate; + + if (otherTaxa == null) { + activeNodes = new ArrayList<>(n() - getTaxaLength(cladeTaxa)); + } else { + activeNodes = new ArrayList<>(getTaxaLength(otherTaxa)); + } + inactiveNodes = new ArrayList<>(); + } + + private int getTaxaLength(Value taxa) { + int nTaxa = 0; + if (taxa.value() instanceof Taxa) { + nTaxa = ((Taxa) taxa.value()).length(); + } else if (taxa.value().getClass().isArray()) { + nTaxa = ((Object[]) taxa.value()).length; + } else { + throw new IllegalArgumentException("Taxa must be of type Object[] or Taxa!"); + } + return nTaxa; + } + + @GeneratorInfo(name = "CalibratedYule", + category = GeneratorCategory.BD_TREE, + description = "The CalibratedYule method accepts one or more clade taxa and generates a tip-labelled time tree. If a root age is provided, the method conditions the tree generation on this root age.") + @Override + public RandomVariable sample() { + // construct the clade taxa first + constructCladeTaxa(); + + //adding other taxa to active node list, must come after constructCladeTaxa() call + constructOtherTaxa(); + + Number[] cladeMRCAAge = getCladeMRCAAge().value(); + + // do another check after constructing clade taxa + if (getCladeTaxaArray().length != cladeMRCAAge.length) throw new IllegalArgumentException("The number of clade mrca age should be the same as clade taxa number!"); + + // initialise a new tree + TimeTree tree = new TimeTree(); + + // get active nodes names + List activeNodeNames = new ArrayList<>(); + for (TimeTreeNode node : activeNodes){ + activeNodeNames.add(node.getId()); + } + + for (int i = 0; i < getCladeTaxaArray().length; i++) { + // generate the clade tree + TimeTree cladeTree = getCladeTree(cladeMRCAAge[i], cladeTaxaArray[i]); + // add the root node to inactiveNodes + inactiveNodes.add(cladeTree.getRoot()); + + // check repeat names + List leafNames = List.of(cladeTree.getRoot().getAllLeafNodeNames()); + // prepare for checking active nodes names + boolean hasRepeatName = leafNames.stream().anyMatch(activeNodeNames::contains); + + // prepare for checking inactive node leaf nodes names + Set allCladeNames = new HashSet<>(); + boolean hasRepeatCladeNames = false; + + for (Taxa cladeTaxa : getCladeTaxaArray()) { + for (Taxon taxon : cladeTaxa.getTaxonArray()) { + if (!allCladeNames.add(taxon.getName())) { + hasRepeatCladeNames = true; + break; + } + } + if (hasRepeatCladeNames) { + break; + } + } + + // if there are repeat names, change all leaf node names for the clade + if (hasRepeatName || hasRepeatCladeNames){ + List leafNodes = cladeTree.getLeafNodes(); + for (int j = 0; j 1){ + // sample t with exp distribution + double mean = activeNodes.size() * lambda; + Value meanValue = new Value<>("mean", mean); + Exp exp = new Exp(meanValue); + + t += exp.sample().value(); + + if (inactiveNodes.size() != 0) { + if (t >= getYoungestNode(inactiveNodes).getAge()) { // count clade root to coalesce + t = getYoungestNode(inactiveNodes).getAge(); + // add cladeRoot to the candidate list if it's not exist in it + if (!activeNodes.contains(getYoungestNode(inactiveNodes))) { + activeNodes.add(getYoungestNode(inactiveNodes)); + inactiveNodes.remove(getYoungestNode(inactiveNodes)); + } + + coalesceNodes(activeNodes, t); + } else { // do not count clade root to coalesce + coalesceNodes(activeNodes, t); + if (activeNodes.size() == 1 && !activeNodes.contains(getYoungestNode(inactiveNodes))) { + activeNodes.add(getYoungestNode(inactiveNodes)); + inactiveNodes.remove(getYoungestNode(inactiveNodes)); + t = getYoungestNode(inactiveNodes).getAge(); + } + } + } else coalesceNodes(activeNodes, t); + } + + System.out.println("done coalescent"); + // set root to construct the tree + if (tree != null) { + tree.setRoot(activeNodes.get(0), true); + } + + // specify the root age if given + if (rootAge != null){ + Number rootAgeValue = getRootAge().value(); + if (rootAgeValue instanceof Double) { + tree.getRoot().setAge((double) rootAgeValue); + } else { + // handle other number types if necessary + tree.getRoot().setAge(rootAgeValue.doubleValue()); + } + } + + System.out.println("calibrated yule tree is " + tree); + return new RandomVariable<>(null, tree, this); + } + + /** + * Get a Yule tree for each clade taxa. + * @param cladeMRCAAge + * @param taxa + * @return the Yule tree + */ + private TimeTree getCladeTree(Number cladeMRCAAge, Taxa taxa) { + Value cladeLengthValue = new Value<> (null, taxa.length()); + Value cladeMRCAAgeValue = new Value<>(null, cladeMRCAAge); + Value taxaValue = new Value<>(null, taxa); + Yule yuleInstance = new Yule(getBirthRate(), cladeLengthValue, taxaValue, cladeMRCAAgeValue); + + return yuleInstance.sample().value(); + } + + /** + * Get the youngest node in the node list. + * @param inactiveNodes a list of nodes + * @return the node with the smallest age + */ + private TimeTreeNode getYoungestNode(List inactiveNodes) { + TimeTreeNode tempNode = inactiveNodes.get(0); + double age = tempNode.getAge(); + for (TimeTreeNode node : inactiveNodes){ + if (age > node.getAge()){ + tempNode = node; + age = node.getAge(); + } + } + return tempNode; + } + + public void constructCladeTaxa() { + Object cladeTaxaValueObject = getCladeTaxa().value(); + + if (cladeTaxaValueObject instanceof Taxa) { + cladeTaxaArray = new Taxa[] {(Taxa) cladeTaxaValueObject}; + } else if (cladeTaxaValueObject.getClass().isArray()) { + if (cladeTaxaValueObject instanceof Taxa[]){ + cladeTaxaArray = (Taxa[]) cladeTaxaValueObject; + } else if (cladeTaxaValueObject instanceof Taxon[]) { + cladeTaxaArray = new Taxa[] {Taxa.createTaxa((Taxon[]) cladeTaxaValueObject)}; + } else if (cladeTaxaValueObject instanceof Taxon[][]){ + Taxon[][] taxonArray = (Taxon[][]) cladeTaxaValueObject; + cladeTaxaArray = new Taxa[taxonArray.length]; + + for (int i = 0; i < taxonArray.length; i++) { + Taxon[] innerArray = taxonArray[i]; + cladeTaxaArray[i] = Taxa.createTaxa(innerArray); + } + } else if (cladeTaxaValueObject instanceof Object[]) { + if (((Object[]) cladeTaxaValueObject).length > 0 && ((Object[]) cladeTaxaValueObject)[0] instanceof Object[]) { + Object[][] objectArray = (Object[][]) cladeTaxaValueObject; + cladeTaxaArray = new Taxa[objectArray.length]; + for (int i = 0; i < objectArray.length; i++) { + cladeTaxaArray[i] = Taxa.createTaxa(objectArray[i]); + } + } else { + cladeTaxaArray = new Taxa[]{Taxa.createTaxa((Object[]) cladeTaxaValueObject)}; + } + } else { + throw new IllegalArgumentException(taxaParamName + " must be of type Object[], Taxa, or Taxa[], but it is type " + cladeTaxaValueObject.getClass()); + } + } + } + + private void constructOtherTaxa() { + if (getOtherTaxa() == null) { + + int totalCladeTaxaLength = 0; + for (Taxa taxa : cladeTaxaArray) { + totalCladeTaxaLength += taxa.length(); + } + taxa = Taxa.createTaxa(n() - totalCladeTaxaLength); + mapActiveNodes(); + } else { + if (getOtherTaxa().value() instanceof Taxa) { + taxa = (Taxa) getOtherTaxa().value(); + mapActiveNodes(); + } else if (getOtherTaxa().value().getClass().isArray()) { + if (getOtherTaxa().value() instanceof Taxon[]) { + taxa = Taxa.createTaxa((Taxon[]) getOtherTaxa().value()); + mapActiveNodes(); + } else { + taxa = Taxa.createTaxa((Object[]) getOtherTaxa().value()); + mapActiveNodes(); + } + } else { + throw new IllegalArgumentException(taxaParamName + " must be of type Object[] or Taxa, but it is type " + getOtherTaxa().value().getClass()); + } + } + } + + private void mapActiveNodes() { + TimeTreeNode[] nodes = new TimeTreeNode[taxa.ntaxa()]; + for (int i = 0; i activeNodes, double t) { + // random two nodes to coalesceT + List nodes = randomTwoNodes(activeNodes); + + TimeTreeNode node1 = nodes.get(0); + TimeTreeNode node2 = nodes.get(1); + + // create the parent node + TimeTreeNode parentNode = new TimeTreeNode(t); + parentNode.addChild(node1); + parentNode.addChild(node2); + node1.setParent(parentNode); + node2.setParent(parentNode); + + // remove coalesced nodes from the candidate list and add parent + activeNodes.remove(node1); + activeNodes.remove(node2); + activeNodes.add(parentNode); + } + + // public for unit test + public static List randomTwoNodes(List activeNodes) { + // get node1 + TimeTreeNode node1 = randomNode(activeNodes); + + // get a new list without node1 + List copyList = new ArrayList<>(activeNodes); + copyList.remove(node1); + + // get node2 + TimeTreeNode node2 = randomNode(copyList); + + // create the random result list + List randomNodes = new ArrayList<>(2); + randomNodes.add(node1); + randomNodes.add(node2); + + return randomNodes; + } + + private static TimeTreeNode randomNode(List nodeList) { + // create uniform discrete instance + Value lower = new Value<>("low", 0); + Value upper = new Value<>("high", nodeList.size()-1); + UniformDiscrete uniformDiscrete = new UniformDiscrete(lower, upper); + // random an index + RandomVariable index = uniformDiscrete.sample(); + return nodeList.get(index.value()); + } + + @Override + public Map getParams() { + Map map = super.getParams(); + map.put(BirthDeathConstants.lambdaParamName, birthRate); + map.put(BirthDeathConstants.rootAgeParamName, rootAge); + map.put(cladeMRCAAgeName, cladeMRCAAge); + map.put(cladeTaxaName, cladeTaxaValue); + map.put(otherTaxaName,otherTaxa); + return map; + } + + public void setParam(String paramName, Value value){ + if (paramName.equals(BirthDeathConstants.lambdaParamName)) birthRate = value; + else if (paramName.equals(BirthDeathConstants.rootAgeParamName)) rootAge = value; + else if (paramName.equals(cladeTaxaName)) { + cladeTaxaValue = value; + constructCladeTaxa(); + } + else if (paramName.equals(cladeMRCAAgeName)) cladeMRCAAge = value; + else if (paramName.equals(otherTaxaName)) otherTaxa = value; + else super.setParam(paramName, value); + } + + public Value getBirthRate(){ + return getParams().get(BirthDeathConstants.lambdaParamName); + } + public Value getCladeTaxa(){ + return getParams().get(cladeTaxaName); + } + public Taxa[] getCladeTaxaArray(){ + return cladeTaxaArray; + } + public Value getCladeMRCAAge(){ + return getParams().get(cladeMRCAAgeName); + } + public Value getOtherTaxa(){ + return getParams().get(otherTaxaName); + } + public Value getRootAge(){ + return getParams().get(BirthDeathConstants.rootAgeParamName); + } +} \ No newline at end of file diff --git a/lphy-base/src/test/java/lphy/base/evolution/tree/CalibratedYuleTest.java b/lphy-base/src/test/java/lphy/base/evolution/tree/CalibratedYuleTest.java new file mode 100644 index 000000000..a8d3279c2 --- /dev/null +++ b/lphy-base/src/test/java/lphy/base/evolution/tree/CalibratedYuleTest.java @@ -0,0 +1,112 @@ +package lphy.base.evolution.tree; + +import lphy.base.evolution.birthdeath.CalibratedYule; +import lphy.core.model.Value; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class CalibratedYuleTest { + @Test + void test1() { //test one clade taxa + double birthRate = 0.25; + int n = 520; + String[] taxa = {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}; + Number[] cladeAge = new Number[]{5.5}; + + Value birthRateValue = new Value<>("birthRate", birthRate); + Value nValue = new Value<>("n", n); + Value taxaValue = new Value("taxa", taxa); + Value cladeAgeValue = new Value<>("cladeAge", cladeAge); + + CalibratedYule instance = new CalibratedYule(birthRateValue, nValue, taxaValue, cladeAgeValue, null, null); + TimeTree observe = instance.sample().value(); + + // node number should be same + assertEquals(n , observe.getRoot().getAllLeafNodes().size()); + // randomly check the names for clade taxa + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade_1")); + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade_3")); + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade_10")); + + } + + @Test + void randomNodeTest() { + TimeTreeNode node1 = new TimeTreeNode(1.0); + TimeTreeNode node2 = new TimeTreeNode(1.1); + TimeTreeNode node3 = new TimeTreeNode(1.2); + TimeTreeNode node4 = new TimeTreeNode(1.3); + + List activeNodes = new ArrayList<>(); + activeNodes.add(node1); + activeNodes.add(node2); + activeNodes.add(node3); + activeNodes.add(node4); + + for (int i = 0; i<30; i++) { + List observe = CalibratedYule.randomTwoNodes(activeNodes); + double age1 = observe.get(0).age; + double age2 = observe.get(1).age; + assertNotEquals(age1, age2, "The ages should not be equal"); + } + } + + @Test + void test2() { // test multiple clade taxa + double birthRate = 0.25; + int n = 520; + String[] taxa1 = {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}; + String[] taxa2 = {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}; + String[][] taxa = {taxa1, taxa2}; + Number[] cladeAge = new Number[]{5.5 , 7}; + + Value birthRateValue = new Value<>("birthRate", birthRate); + Value nValue = new Value<>("n", n); + Value taxaValue = new Value("taxa", taxa); + Value cladeAgeValue = new Value<>("cladeAge", cladeAge); + Value rootAgeValue = new Value<>("rootAge", 19); + + CalibratedYule instance = new CalibratedYule(birthRateValue, nValue, taxaValue, cladeAgeValue, null, rootAgeValue); + TimeTree observe = instance.sample().value(); + + // node number should be same + assertEquals(n , observe.getRoot().getAllLeafNodes().size()); + // check root age + assertEquals(19, observe.getRoot().age); + // randomly draw and check leaf node names + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade0_1")); + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade1_3")); + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade0_10")); + } + + @Test + void test3() { // test same clade names + double birthRate = 0.25; + int n = 520; + String[] taxa1 = {"taxa1", "taxa2", "taxa3", "taxa4", "taxa5", "taxa6", "taxa7", "taxa8", "taxa9", "taxa10"}; + String[] taxa2 = {"taxa1", "taxa2", "taxa3", "taxa4", "taxa5", "taxa6", "taxa7", "taxa8", "taxa9", "taxa10"}; + String[][] taxa = {taxa1, taxa2}; + Number[] cladeAge = new Number[]{5.5 , 7}; + + Value birthRateValue = new Value<>("birthRate", birthRate); + Value nValue = new Value<>("n", n); + Value taxaValue = new Value("taxa", taxa); + Value cladeAgeValue = new Value<>("cladeAge", cladeAge); + + CalibratedYule instance = new CalibratedYule(birthRateValue, nValue, taxaValue, cladeAgeValue, null, null); + TimeTree observe = instance.sample().value(); + + // node number should be same + assertEquals(n , observe.getRoot().getAllLeafNodes().size()); + + // randomly draw and check leaf node names + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade0_taxa1")); + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade1_taxa3")); + assert observe.getRoot().getAllLeafNodes().stream().anyMatch(node -> node.getId().equals("clade0_taxa10")); + } +}