Skip to content

Commit

Permalink
add optional arg to setInternalNodesID #518
Browse files Browse the repository at this point in the history
  • Loading branch information
walterxie committed Nov 21, 2024
1 parent e0b38ba commit a2ef370
Showing 1 changed file with 28 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,24 @@
import lphy.core.model.annotation.GeneratorInfo;
import lphy.core.model.annotation.ParameterInfo;

import java.util.Arrays;
import java.util.List;

import static lphy.base.evolution.EvolutionConstants.treeParamName;

/**
* A function to set internal nodes id given a tree
*/
public class InternalNodesID extends DeterministicFunction<TimeTree> {

public static final String INTER_NODE_ID = "internalNodesID";

public InternalNodesID(@ParameterInfo(name = treeParamName,
description = "the tree to set internal nodes id.") Value<TimeTree> tree) {
description = "the tree to set internal nodes id.") Value<TimeTree> tree,
@ParameterInfo(name = INTER_NODE_ID, optional = true,
description = "the vector of internal nodes id.") Value<Object[]> internalNodesID) {
setInput(treeParamName, tree);
setInput(INTER_NODE_ID, internalNodesID);
}

@GeneratorInfo(name = "setInternalNodesID", category = GeneratorCategory.TREE,
Expand All @@ -26,13 +34,29 @@ public InternalNodesID(@ParameterInfo(name = treeParamName,
public Value<TimeTree> apply() {

Value<TimeTree> tree = getParams().get(treeParamName);
Value<Object[]> internalNodesIDValue = getParams().get(INTER_NODE_ID);
String[] internalNodesID = new String[0];
if (internalNodesIDValue != null && internalNodesIDValue.value() != null) {
internalNodesID= Arrays.stream(internalNodesIDValue.value())
.map(Object::toString).toArray(String[]::new);
}

// do deep copy
TimeTree newTree = new TimeTree(tree.value());

for (TimeTreeNode node : newTree.getInternalNodes()) {
if (node.getId() == null) // set index as id
node.setId(String.valueOf(node.getIndex()));
List<TimeTreeNode> internalNodes = newTree.getInternalNodes();
for (int i = 0; i < internalNodes.size(); i++) {
TimeTreeNode node = internalNodes.get(i);
if (node.getId() == null) {// set index as id
if (internalNodesID.length > 0) {
if (internalNodesID.length != internalNodes.size())
throw new IllegalArgumentException("Internal nodes " + internalNodes.size() +
" do not match IDs + " + internalNodesID.length);
// given ids
node.setId(internalNodesID[i]);
} else // not given ids, default
node.setId(String.valueOf(node.getIndex()));
}
}

return new Value<>(null, newTree, this);
Expand Down

0 comments on commit a2ef370

Please sign in to comment.