first version to create XML with new operators #109
walterxie committed Jan 16, 2024
1 parent 8ac4505 commit 2d2aaaf
Showing 4 changed files with 193 additions and 24 deletions.
17 changes: 17 additions & 0 deletions lphybeast/src/main/java/lphybeast/BEASTContext.java
@@ -100,6 +100,10 @@ public class BEASTContext {
// A list of strategy patterns define how to create operators in extensions,
// which already exclude DefaultTreeOperatorStrategy
private List<TreeOperatorStrategy> newTreeOperatorStrategies;
// BEAST 2.7.6 introduced AdaptableOperatorSampler and AdaptableVarianceMultivariateNormalOperator,
// which change the XML format of some operators.
// A list of BEAST objects to be wrapped in AdaptableOperatorSampler operators;
// an AVMNOperator is created if this list is not empty.
private List<BEASTInterface> beastObjForOpSamplers = new ArrayList<>();

//*** operators ***//
// a list of extra loggables in 3 default loggers: parameter logger, screen logger, tree logger.
@@ -1037,6 +1041,7 @@ public void clear() {
beastObjects.clear();
extraOperators.clear();
skipOperators.clear();
beastObjForOpSamplers.clear();
}

public void runBEAST(String logFileStem) {
@@ -1068,6 +1073,18 @@ public boolean hasExtraOperator(String opID) {
return extraOperators.stream().anyMatch(op -> op.getID().equals(opID));
}

public void addBeastObjForOpSamplers(BEASTInterface beastInterface) {
beastObjForOpSamplers.add(beastInterface);
}

public boolean isForOperatorSampler(BEASTInterface beastInterface) {
return beastObjForOpSamplers.contains(beastInterface);
}

public List<BEASTInterface> getBeastObjForOpSamplers() {
return beastObjForOpSamplers;
}

public List<StateNode> getState() {
return state;
}
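The three hooks above are intended to be used in two phases: generators register BEAST objects while converting, and the operator-building code later drains the list. A minimal sketch of that flow, not part of this commit; the class and method names below are hypothetical.

import beast.base.core.BEASTInterface;
import lphybeast.BEASTContext;

class OperatorSamplerFlowSketch {
    // hypothetical consumer of the BEASTContext hooks above
    static void buildSamplers(BEASTContext context) {
        // phase 1 happens inside generators, e.g. HKYToBEAST calls context.addBeastObjForOpSamplers(kappa);
        // phase 2, after conversion, drains the registered objects:
        for (BEASTInterface obj : context.getBeastObjForOpSamplers()) {
            // each registered object gets its own AdaptableOperatorSampler (see the XML quoted in HKYToBEAST)
            System.out.println("to wrap in AdaptableOperatorSampler: " + obj.getID());
        }
    }
}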
lphybeast/src/main/java/lphybeast/tobeast/generators/HKYToBEAST.java
@@ -1,6 +1,7 @@
package lphybeast.tobeast.generators;

import beast.base.core.BEASTInterface;
import beast.base.evolution.substitutionmodel.Frequencies;
import beast.base.inference.parameter.RealParameter;
import lphy.base.evolution.substitutionmodel.HKY;
import lphybeast.BEASTContext;
@@ -10,10 +11,31 @@ public class HKYToBEAST implements GeneratorToBEAST<HKY, beast.base.evolution.su
@Override
public beast.base.evolution.substitutionmodel.HKY generatorToBEAST(HKY hky, BEASTInterface value, BEASTContext context) {

RealParameter kappa = (RealParameter) context.getBEASTObject(hky.getKappa());
RealParameter freqParam = (RealParameter) context.getBEASTObject(hky.getFreq());
Frequencies frequencies = BEASTContext.createBEASTFrequencies(freqParam,"A C G T");

beast.base.evolution.substitutionmodel.HKY beastHKY = new beast.base.evolution.substitutionmodel.HKY();
beastHKY.setInputValue("kappa", context.getBEASTObject(hky.getKappa()));
beastHKY.setInputValue("frequencies", BEASTContext.createBEASTFrequencies((RealParameter) context.getBEASTObject(hky.getFreq()),"A C G T"));
beastHKY.setInputValue("kappa", kappa);
beastHKY.setInputValue("frequencies", frequencies);
beastHKY.initAndValidate();

// <operator id="KappaScaler.s:$(n)" spec="beast.base.evolution.operator.AdaptableOperatorSampler" weight="0.05">
// <parameter idref="kappa.s:$(n)"/>
// <operator idref="AVMNOperator.$(n)"/>
// <operator id='KappaScalerX.s:$(n)' spec='kernel.BactrianScaleOperator' scaleFactor="0.1" weight="0.1" parameter="@kappa.s:$(n)"/>
// </operator>
//
// <operator id="FrequenciesExchanger.s:$(n)" spec="beast.base.evolution.operator.AdaptableOperatorSampler" weight="0.05">
// <parameter idref="freqParameter.s:$(n)"/>
// <operator idref="AVMNOperator.$(n)"/>
// <operator id='FrequenciesExchangerX.s:$(n)' spec='kernel.BactrianDeltaExchangeOperator' delta="0.01" weight="0.1" parameter="@freqParameter.s:$(n)"/>
// </operator>

// kappa and frequencies will each be wrapped in an AdaptableOperatorSampler later
context.addBeastObjForOpSamplers(kappa);
context.addBeastObjForOpSamplers(frequencies);

return beastHKY;
}
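For reference, a rough Java equivalent of the KappaScaler XML quoted above (not part of this commit). The input names ("parameter", "operator", "weight") and the weight value are taken from that XML; the sub-operators (the shared AVMNOperator plus a Bactrian scale operator) are assumed to be built elsewhere, and the ID pattern is illustrative.

import beast.base.evolution.operator.AdaptableOperatorSampler;
import beast.base.inference.Operator;
import beast.base.inference.parameter.RealParameter;
import java.util.List;

class KappaSamplerSketch {
    static AdaptableOperatorSampler wrap(RealParameter kappa, List<Operator> subOperators) {
        AdaptableOperatorSampler sampler = new AdaptableOperatorSampler();
        sampler.setID("KappaScaler.s:" + kappa.getID());   // ID pattern loosely follows the XML above
        sampler.setInputValue("parameter", kappa);
        sampler.setInputValue("operator", subOperators);   // e.g. the shared AVMNOperator and a BactrianScaleOperator
        sampler.setInputValue("weight", "0.05");           // weight copied from the quoted XML
        sampler.initAndValidate();
        return sampler;
    }
}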

lphybeast/src/main/java/lphybeast/tobeast/generators/PhyloCTMCToBEAST.java
@@ -1,16 +1,21 @@
package lphybeast.tobeast.generators;

import beast.base.core.BEASTInterface;
import beast.base.core.Function;
import beast.base.core.Input;
import beast.base.evolution.branchratemodel.StrictClockModel;
import beast.base.evolution.branchratemodel.UCRelaxedClockModel;
import beast.base.evolution.datatype.DataType;
import beast.base.evolution.datatype.UserDataType;
import beast.base.evolution.likelihood.GenericTreeLikelihood;
import beast.base.evolution.likelihood.ThreadedTreeLikelihood;
import beast.base.evolution.operator.kernel.AdaptableVarianceMultivariateNormalOperator;
import beast.base.evolution.sitemodel.SiteModel;
import beast.base.evolution.substitutionmodel.Frequencies;
import beast.base.evolution.substitutionmodel.SubstitutionModel;
import beast.base.evolution.tree.Tree;
import beast.base.inference.distribution.Prior;
import beast.base.inference.operator.kernel.Transform;
import beast.base.inference.parameter.RealParameter;
import beastclassic.evolution.alignment.AlignmentFromTrait;
import beastclassic.evolution.likelihood.AncestralStateTreeLikelihood;
@@ -34,6 +39,7 @@
import lphybeast.tobeast.loggers.TraitTreeLogger;
import lphybeast.tobeast.operators.DefaultOperatorStrategy;

import java.util.List;
import java.util.Map;

public class PhyloCTMCToBEAST implements GeneratorToBEAST<PhyloCTMC, GenericTreeLikelihood> {
@@ -55,8 +61,9 @@ private AncestralStateTreeLikelihood createAncestralStateTreeLikelihood(PhyloCTM
AncestralStateTreeLikelihood treeLikelihood = new AncestralStateTreeLikelihood();
treeLikelihood.setInputValue("tag", LOCATION);
treeLikelihood.setInputValue("data", traitAlignment);

constructTreeAndBranchRate(phyloCTMC, treeLikelihood, context);
//TODO AVMN transforms are not passed in here yet (null is used below)
constructTreeAndBranchRate(phyloCTMC, treeLikelihood, null, null,
context, false);

DataType userDataType = traitAlignment.getDataType();
if (! (userDataType instanceof UserDataType) )
@@ -129,28 +136,37 @@ private ThreadedTreeLikelihood createThreadedTreeLikelihood(PhyloCTMC phyloCTMC,
beast.base.evolution.alignment.Alignment alignment = (beast.base.evolution.alignment.Alignment)value;
treeLikelihood.setInputValue("data", alignment);

constructTreeAndBranchRate(phyloCTMC, treeLikelihood, context);
// AVMNOperator for each TreeLikelihood
AdaptableVarianceMultivariateNormalOperator opAVMNN = DefaultOperatorStrategy.initAVMNOperator();
opAVMNN.setID(alignment.getID() + ".AVMNOperator");

Transform.LogConstrainedSumTransform sumTransform = DefaultOperatorStrategy.initAVMNSumTransform(alignment.getID());
Transform.LogTransform logTransform = DefaultOperatorStrategy.initLogTransform(alignment.getID());
Transform.NoTransform noTransform = DefaultOperatorStrategy.initNoTransform(alignment.getID());

// branch models
constructTreeAndBranchRate(phyloCTMC, treeLikelihood, logTransform, noTransform,
context, false);

SiteModel siteModel = constructSiteModel(phyloCTMC, context);
SiteModel siteModel = constructSiteModel(phyloCTMC, sumTransform, logTransform, context);
treeLikelihood.setInputValue("siteModel", siteModel);

treeLikelihood.initAndValidate();
treeLikelihood.setID(alignment.getID() + ".treeLikelihood");
// logging
context.addExtraLoggable(treeLikelihood);

return treeLikelihood;
}
// AVMNOperator
sumTransform.initAndValidate();
logTransform.initAndValidate();
noTransform.initAndValidate();

List<Transform> transformList = List.of(sumTransform, logTransform, noTransform);
opAVMNN.setInputValue("transformations", transformList);
opAVMNN.initAndValidate();
context.addBeastObjForOpSamplers(opAVMNN);

/**
* Create tree and clock rate inside this tree likelihood.
* @param phyloCTMC
* @param treeLikelihood
* @param context
*/
public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLikelihood treeLikelihood, BEASTContext context) {
constructTreeAndBranchRate(phyloCTMC, treeLikelihood, context, false);
return treeLikelihood;
}
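The AVMNOperator assembled above depends on constructTreeAndBranchRate(...) and constructSiteModel(...) below to fill its three transforms. A sketch of that wiring collapsed into one place, not part of this commit; it reuses the imports already added at the top of this file, and the input names "f" and "sum" are the ones used by the calls in this commit.

static void wireAVMNTransforms(Transform.LogConstrainedSumTransform sumTransform,
                               Transform.LogTransform logTransform,
                               Transform.NoTransform noTransform,
                               Function freqParam, RealParameter clockRate,
                               RealParameter shape, Tree tree) {
    sumTransform.setInputValue("f", freqParam);   // frequencies are constrained to sum to 1.0
    sumTransform.setInputValue("sum", "1.0");
    logTransform.setInputValue("f", clockRate);   // strictly positive scalars are log-transformed
    logTransform.setInputValue("f", shape);       // repeated calls append to the list input "f" (the mutation rate is handled the same way)
    noTransform.setInputValue("f", tree);         // the tree itself is left untransformed
}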

/**
@@ -160,7 +176,9 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi
* @param context
* @param skipBranchOperators skip constructing branch rates
*/
public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLikelihood treeLikelihood, BEASTContext context, boolean skipBranchOperators) {
public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLikelihood treeLikelihood,
Transform.LogTransform logTransform, Transform.NoTransform noTransform,
BEASTContext context, boolean skipBranchOperators) {
Value<TimeTree> timeTreeValue = phyloCTMC.getTree();
Tree tree = (Tree) context.getBEASTObject(timeTreeValue);
//tree.setInputValue("taxa", value);
@@ -175,7 +193,7 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi
Generator generator = branchRates.getGenerator();
if (generator instanceof IID &&
((IID<?>) generator).getBaseDistribution() instanceof LogNormal) {

//TODO migrate to new operators
// simpleRelaxedClock.lphy
UCRelaxedClockModel relaxedClockModel = new UCRelaxedClockModel();

@@ -216,8 +234,14 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi

if (clockRate instanceof RandomVariable && timeTreeValue instanceof RandomVariable && skipBranchOperators == false) {
DefaultOperatorStrategy.addUpDownOperator(tree, clockRatePara, context);
// an AdaptableOperatorSampler will be created for the clock rate parameter later
context.addBeastObjForOpSamplers(clockRatePara);
// AVMN log Transform
logTransform.setInputValue("f", clockRatePara);
}
}
// AVMN No Transform
noTransform.setInputValue("f", tree);
}


@@ -226,7 +250,8 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi
* @param context the beast context
* @return a BEAST SiteModel representing the site model of this LPHY PhyloCTMC
*/
public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, BEASTContext context) {
public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, Transform.LogConstrainedSumTransform sumTransform,
Transform.LogTransform logTransform, BEASTContext context) {

SiteModel siteModel = new SiteModel();

@@ -251,11 +276,20 @@ public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, BEASTContext con
} else {
throw new UnsupportedOperationException("Only discretized gamma site rates are supported by LPhyBEAST !");
}
siteModel.setInputValue("shape", context.getAsRealParameter(shape));
RealParameter shapeParam = context.getAsRealParameter(shape);
siteModel.setInputValue("shape", shapeParam);
// ncat is an Integer and does not need to be a RealParameter
siteModel.setInputValue("gammaCategoryCount", ncat.value());

//TODO add proportionInvariant

//TODO need a better solution than removing the RandomVariable siteRates
context.removeBEASTObject(context.getBEASTObject(siteRates));

// an AdaptableOperatorSampler will be created for the gamma shape parameter later
context.addBeastObjForOpSamplers(shapeParam);
// AVMN log Transform
logTransform.setInputValue("f", shapeParam);
}

// Scenario 2: siteRates = NULL
@@ -267,10 +301,40 @@ public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, BEASTContext con
if (substitutionModel == null) throw new IllegalArgumentException("Substitution Model was null!");
siteModel.setInputValue("substModel", substitutionModel);

if (substitutionModel instanceof SubstitutionModel.Base substBase) {

Map<String, Input<?>> allInputs = substBase.getInputs();
// check whether any inputs of SubstitutionModel.Base have already been registered for an AdaptableOperatorSampler
// (e.g. kappa and frequencies in HKYToBEAST), so context.addBeastObjForOpSamplers is not called again here
for (Map.Entry<String, Input<?>> entry : allInputs.entrySet()) {
Input<?> input = entry.getValue();
if (input.get() instanceof BEASTInterface beastInterface) {
if (context.isForOperatorSampler(beastInterface)) {
if (beastInterface instanceof Frequencies frequencies) {
Function freqParam = frequencies.frequenciesInput.get();
// AVMN Log Constrained Sum Transform
sumTransform.setInputValue("f", freqParam);
sumTransform.setInputValue("sum", "1.0");
} else
// AVMN log Transform
logTransform.setInputValue("f", beastInterface);
}
}
}

}

RateMatrix rateMatrix = (RateMatrix)qGenerator;
Value<Double> meanRate = rateMatrix.getMeanRate();
BEASTInterface mutationRate = meanRate==null ? null : context.getBEASTObject(meanRate);
if (mutationRate != null) siteModel.setInputValue("mutationRate", mutationRate);
if (mutationRate != null) {
siteModel.setInputValue("mutationRate", mutationRate);

// an AdaptableOperatorSampler will be created for the mutation rate later
context.addBeastObjForOpSamplers(mutationRate);
// AVMN log Transform
logTransform.setInputValue("f", mutationRate);
}

siteModel.initAndValidate();
}