From 2d2aaaf5b2c17a16b614a669944b94c62ab922fa Mon Sep 17 00:00:00 2001 From: walterxie Date: Tue, 16 Jan 2024 18:52:42 +1300 Subject: [PATCH] first version to create XML with new operators #109 --- .../src/main/java/lphybeast/BEASTContext.java | 17 +++ .../tobeast/generators/HKYToBEAST.java | 26 ++++- .../tobeast/generators/PhyloCTMCToBEAST.java | 102 ++++++++++++++---- .../operators/DefaultOperatorStrategy.java | 72 ++++++++++++- 4 files changed, 193 insertions(+), 24 deletions(-) diff --git a/lphybeast/src/main/java/lphybeast/BEASTContext.java b/lphybeast/src/main/java/lphybeast/BEASTContext.java index f224501..e301665 100644 --- a/lphybeast/src/main/java/lphybeast/BEASTContext.java +++ b/lphybeast/src/main/java/lphybeast/BEASTContext.java @@ -100,6 +100,10 @@ public class BEASTContext { // A list of strategy patterns define how to create operators in extensions, // which already exclude DefaultTreeOperatorStrategy private List newTreeOperatorStrategies; + // BEAST 2.7.6 introduced AdaptableOperatorSampler and AdaptableVarianceMultivariateNormalOperator + // which creates the new format for some operators. + // A list of AdaptableOperatorSampler, and create AVMNOperator if list is not empty. + private List beastObjForOpSamplers = new ArrayList<>(); //*** operators ***// // a list of extra loggables in 3 default loggers: parameter logger, screen logger, tree logger. @@ -1037,6 +1041,7 @@ public void clear() { beastObjects.clear(); extraOperators.clear(); skipOperators.clear(); + beastObjForOpSamplers.clear(); } public void runBEAST(String logFileStem) { @@ -1068,6 +1073,18 @@ public boolean hasExtraOperator(String opID) { return extraOperators.stream().anyMatch(op -> op.getID().equals(opID)); } + public void addBeastObjForOpSamplers(BEASTInterface beastInterface) { + beastObjForOpSamplers.add(beastInterface); + } + + public boolean isForOperatorSampler(BEASTInterface beastInterface) { + return beastObjForOpSamplers.contains(beastInterface); + } + + public List getBeastObjForOpSamplers() { + return beastObjForOpSamplers; + } + public List getState() { return state; } diff --git a/lphybeast/src/main/java/lphybeast/tobeast/generators/HKYToBEAST.java b/lphybeast/src/main/java/lphybeast/tobeast/generators/HKYToBEAST.java index d81a11d..085d105 100644 --- a/lphybeast/src/main/java/lphybeast/tobeast/generators/HKYToBEAST.java +++ b/lphybeast/src/main/java/lphybeast/tobeast/generators/HKYToBEAST.java @@ -1,6 +1,7 @@ package lphybeast.tobeast.generators; import beast.base.core.BEASTInterface; +import beast.base.evolution.substitutionmodel.Frequencies; import beast.base.inference.parameter.RealParameter; import lphy.base.evolution.substitutionmodel.HKY; import lphybeast.BEASTContext; @@ -10,10 +11,31 @@ public class HKYToBEAST implements GeneratorToBEAST +// +// +// +// +// +// +// +// +// +// + + // they will create AdaptableOperatorSampler later + context.addBeastObjForOpSamplers(kappa); + context.addBeastObjForOpSamplers(frequencies); + return beastHKY; } diff --git a/lphybeast/src/main/java/lphybeast/tobeast/generators/PhyloCTMCToBEAST.java b/lphybeast/src/main/java/lphybeast/tobeast/generators/PhyloCTMCToBEAST.java index ccdcb0a..b2b3f61 100644 --- a/lphybeast/src/main/java/lphybeast/tobeast/generators/PhyloCTMCToBEAST.java +++ b/lphybeast/src/main/java/lphybeast/tobeast/generators/PhyloCTMCToBEAST.java @@ -1,16 +1,21 @@ package lphybeast.tobeast.generators; import beast.base.core.BEASTInterface; +import beast.base.core.Function; +import beast.base.core.Input; import beast.base.evolution.branchratemodel.StrictClockModel; import beast.base.evolution.branchratemodel.UCRelaxedClockModel; import beast.base.evolution.datatype.DataType; import beast.base.evolution.datatype.UserDataType; import beast.base.evolution.likelihood.GenericTreeLikelihood; import beast.base.evolution.likelihood.ThreadedTreeLikelihood; +import beast.base.evolution.operator.kernel.AdaptableVarianceMultivariateNormalOperator; import beast.base.evolution.sitemodel.SiteModel; +import beast.base.evolution.substitutionmodel.Frequencies; import beast.base.evolution.substitutionmodel.SubstitutionModel; import beast.base.evolution.tree.Tree; import beast.base.inference.distribution.Prior; +import beast.base.inference.operator.kernel.Transform; import beast.base.inference.parameter.RealParameter; import beastclassic.evolution.alignment.AlignmentFromTrait; import beastclassic.evolution.likelihood.AncestralStateTreeLikelihood; @@ -34,6 +39,7 @@ import lphybeast.tobeast.loggers.TraitTreeLogger; import lphybeast.tobeast.operators.DefaultOperatorStrategy; +import java.util.List; import java.util.Map; public class PhyloCTMCToBEAST implements GeneratorToBEAST { @@ -55,8 +61,9 @@ private AncestralStateTreeLikelihood createAncestralStateTreeLikelihood(PhyloCTM AncestralStateTreeLikelihood treeLikelihood = new AncestralStateTreeLikelihood(); treeLikelihood.setInputValue("tag", LOCATION); treeLikelihood.setInputValue("data", traitAlignment); - - constructTreeAndBranchRate(phyloCTMC, treeLikelihood, context); +//TODO + constructTreeAndBranchRate(phyloCTMC, treeLikelihood, null, null, + context, false); DataType userDataType = traitAlignment.getDataType(); if (! (userDataType instanceof UserDataType) ) @@ -129,9 +136,19 @@ private ThreadedTreeLikelihood createThreadedTreeLikelihood(PhyloCTMC phyloCTMC, beast.base.evolution.alignment.Alignment alignment = (beast.base.evolution.alignment.Alignment)value; treeLikelihood.setInputValue("data", alignment); - constructTreeAndBranchRate(phyloCTMC, treeLikelihood, context); + // AVMNOperator for each TreeLikelihood + AdaptableVarianceMultivariateNormalOperator opAVMNN = DefaultOperatorStrategy.initAVMNOperator(); + opAVMNN.setID(alignment.getID() + ".AVMNOperator"); + + Transform.LogConstrainedSumTransform sumTransform = DefaultOperatorStrategy.initAVMNSumTransform(alignment.getID()); + Transform.LogTransform logTransform = DefaultOperatorStrategy.initLogTransform(alignment.getID()); + Transform.NoTransform noTransform = DefaultOperatorStrategy.initNoTransform(alignment.getID()); + + // branch models + constructTreeAndBranchRate(phyloCTMC, treeLikelihood, logTransform, noTransform, + context, false); - SiteModel siteModel = constructSiteModel(phyloCTMC, context); + SiteModel siteModel = constructSiteModel(phyloCTMC, sumTransform, logTransform, context); treeLikelihood.setInputValue("siteModel", siteModel); treeLikelihood.initAndValidate(); @@ -139,18 +156,17 @@ private ThreadedTreeLikelihood createThreadedTreeLikelihood(PhyloCTMC phyloCTMC, // logging context.addExtraLoggable(treeLikelihood); - return treeLikelihood; - } + // AVMNOperator + sumTransform.initAndValidate(); + logTransform.initAndValidate(); + noTransform.initAndValidate(); + List transformList = List.of(sumTransform, logTransform, noTransform); + opAVMNN.setInputValue("transformations", transformList); + opAVMNN.initAndValidate(); + context.addBeastObjForOpSamplers(opAVMNN); - /** - * Create tree and clock rate inside this tree likelihood. - * @param phyloCTMC - * @param treeLikelihood - * @param context - */ - public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLikelihood treeLikelihood, BEASTContext context) { - constructTreeAndBranchRate(phyloCTMC, treeLikelihood, context, false); + return treeLikelihood; } /** @@ -160,7 +176,9 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi * @param context * @param skipBranchOperators skip constructing branch rates */ - public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLikelihood treeLikelihood, BEASTContext context, boolean skipBranchOperators) { + public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLikelihood treeLikelihood, + Transform.LogTransform logTransform, Transform.NoTransform noTransform, + BEASTContext context, boolean skipBranchOperators) { Value timeTreeValue = phyloCTMC.getTree(); Tree tree = (Tree) context.getBEASTObject(timeTreeValue); //tree.setInputValue("taxa", value); @@ -175,7 +193,7 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi Generator generator = branchRates.getGenerator(); if (generator instanceof IID && ((IID) generator).getBaseDistribution() instanceof LogNormal) { - +//TODO migrate to new operators // simpleRelaxedClock.lphy UCRelaxedClockModel relaxedClockModel = new UCRelaxedClockModel(); @@ -216,8 +234,14 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi if (clockRate instanceof RandomVariable && timeTreeValue instanceof RandomVariable && skipBranchOperators == false) { DefaultOperatorStrategy.addUpDownOperator(tree, clockRatePara, context); + // will create AdaptableOperatorSampler later + context.addBeastObjForOpSamplers(clockRatePara); + // AVMN log Transform + logTransform.setInputValue("f", clockRatePara); } } + // AVMN No Transform + noTransform.setInputValue("f", tree); } @@ -226,7 +250,8 @@ public static void constructTreeAndBranchRate(PhyloCTMC phyloCTMC, GenericTreeLi * @param context the beast context * @return a BEAST SiteModel representing the site model of this LPHY PhyloCTMC */ - public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, BEASTContext context) { + public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, Transform.LogConstrainedSumTransform sumTransform, + Transform.LogTransform logTransform, BEASTContext context) { SiteModel siteModel = new SiteModel(); @@ -251,11 +276,20 @@ public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, BEASTContext con } else { throw new UnsupportedOperationException("Only discretized gamma site rates are supported by LPhyBEAST !"); } - siteModel.setInputValue("shape", context.getAsRealParameter(shape)); + RealParameter shapeParam = context.getAsRealParameter(shape); + siteModel.setInputValue("shape", shapeParam); + // ncat is Integer, do not require to be parameter siteModel.setInputValue("gammaCategoryCount", ncat.value()); + //TODO add proportionInvariant + //TODO need a better solution than rm RandomVariable siteRates context.removeBEASTObject(context.getBEASTObject(siteRates)); + + // will create AdaptableOperatorSampler later + context.addBeastObjForOpSamplers(shapeParam); + // AVMN log Transform + logTransform.setInputValue("f", shapeParam); } // Scenario 2: siteRates = NULL @@ -267,10 +301,40 @@ public static SiteModel constructSiteModel(PhyloCTMC phyloCTMC, BEASTContext con if (substitutionModel == null) throw new IllegalArgumentException("Substitution Model was null!"); siteModel.setInputValue("substModel", substitutionModel); + if (substitutionModel instanceof SubstitutionModel.Base substBase) { + + Map> allInputs = substBase.getInputs(); + // check if any inputs of SubstitutionModel.Base have been added to create AdaptableOperatorSampler, + // therefore no context.addBeastObjForOpSamplers here + for (Map.Entry> entry : allInputs.entrySet()) { + Input input = entry.getValue(); + if (input.get() instanceof BEASTInterface beastInterface) { + if (context.isForOperatorSampler(beastInterface)) { + if (beastInterface instanceof Frequencies frequencies) { + Function freqParam = frequencies.frequenciesInput.get(); + // AVMN Log Constrained Sum Transform + sumTransform.setInputValue("f", freqParam); + sumTransform.setInputValue("sum", "1.0"); + } else + // AVMN log Transform + logTransform.setInputValue("f", beastInterface); + } + } + } + + } + RateMatrix rateMatrix = (RateMatrix)qGenerator; Value meanRate = rateMatrix.getMeanRate(); BEASTInterface mutationRate = meanRate==null ? null : context.getBEASTObject(meanRate); - if (mutationRate != null) siteModel.setInputValue("mutationRate", mutationRate); + if (mutationRate != null) { + siteModel.setInputValue("mutationRate", mutationRate); + + // will create AdaptableOperatorSampler later + context.addBeastObjForOpSamplers(mutationRate); + // AVMN log Transform + logTransform.setInputValue("f", mutationRate); + } siteModel.initAndValidate(); } diff --git a/lphybeast/src/main/java/lphybeast/tobeast/operators/DefaultOperatorStrategy.java b/lphybeast/src/main/java/lphybeast/tobeast/operators/DefaultOperatorStrategy.java index 81b56be..2b044cd 100644 --- a/lphybeast/src/main/java/lphybeast/tobeast/operators/DefaultOperatorStrategy.java +++ b/lphybeast/src/main/java/lphybeast/tobeast/operators/DefaultOperatorStrategy.java @@ -2,6 +2,8 @@ import beast.base.core.BEASTInterface; import beast.base.core.BEASTObject; +import beast.base.evolution.operator.AdaptableOperatorSampler; +import beast.base.evolution.operator.kernel.AdaptableVarianceMultivariateNormalOperator; import beast.base.evolution.operator.kernel.BactrianScaleOperator; import beast.base.evolution.tree.Tree; import beast.base.inference.Operator; @@ -11,6 +13,7 @@ import beast.base.inference.operator.kernel.BactrianDeltaExchangeOperator; import beast.base.inference.operator.kernel.BactrianRandomWalkOperator; import beast.base.inference.operator.kernel.BactrianUpDownOperator; +import beast.base.inference.operator.kernel.Transform; import beast.base.inference.parameter.BooleanParameter; import beast.base.inference.parameter.IntegerParameter; import beast.base.inference.parameter.RealParameter; @@ -77,17 +80,25 @@ public List createOperators() { List operators = new ArrayList<>(); + for (BEASTInterface beastInterface : context.getBeastObjForOpSamplers()) { + if (beastInterface instanceof AdaptableVarianceMultivariateNormalOperator opAVMNN) + operators.add(opAVMNN); + } + Set skipOperators = context.getSkipOperators(); for (StateNode stateNode : context.getState()) { if (!skipOperators.contains(stateNode)) { // The default template to create operators if (stateNode instanceof RealParameter realParameter) { Operator operator = createBEASTOperator(realParameter); - if (operator != null) operators.add(operator); + if (operator != null) + addOperatorOrSampler(stateNode, operator, operators); } else if (stateNode instanceof IntegerParameter integerParameter) { - operators.add(createBEASTOperator(integerParameter)); + Operator operator = createBEASTOperator(integerParameter); + addOperatorOrSampler(stateNode, operator, operators); } else if (stateNode instanceof BooleanParameter booleanParameter) { - operators.add(createBitFlipOperator(booleanParameter)); + Operator operator = createBitFlipOperator(booleanParameter); + addOperatorOrSampler(stateNode, operator, operators); } else if (stateNode instanceof Tree tree) { TreeOperatorStrategy treeOperatorStrategy = context.resolveTreeOperatorStrategy(tree); // create operators @@ -95,6 +106,7 @@ public List createOperators() { if (treeOperators.size() < 1) throw new IllegalArgumentException("No operators are created by strategy " + treeOperatorStrategy.getName() + " !"); + //TODO or samplers? operators.addAll(treeOperators); } } @@ -106,6 +118,27 @@ public List createOperators() { return operators; } + protected void addOperatorOrSampler(StateNode stateNode, Operator operator, List operators) { + // frequencies, site and substitution model parameters, trees + if (context.isForOperatorSampler(stateNode)) { + AdaptableOperatorSampler operatorSampler = new AdaptableOperatorSampler(); + + operatorSampler.setInputValue("weight", "0.05"); + // TODO not only parameter + operatorSampler.setInputValue("parameter", stateNode); + // add operator here + operatorSampler.setInputValue("operator", operator); + + //TODO + + operatorSampler.setID(stateNode.getID() + ".OperatorSampler"); + operatorSampler.initAndValidate(); + operators.add(operatorSampler); + + } else operators.add(operator); + + } + //*** parameter operators ***// public Operator createBEASTOperator(RealParameter parameter) { @@ -190,6 +223,39 @@ private Operator createBitFlipOperator(BooleanParameter parameter) { //*** static methods ***// + // AVMNOperator for each TreeLikelihood + public static AdaptableVarianceMultivariateNormalOperator initAVMNOperator() { + AdaptableVarianceMultivariateNormalOperator opAVMNN = new AdaptableVarianceMultivariateNormalOperator(); + opAVMNN.setInputValue("weight", "0.1"); + opAVMNN.setInputValue("coefficient", "1.0"); + opAVMNN.setInputValue("scaleFactor", "1"); + opAVMNN.setInputValue("beta", "0.05"); + opAVMNN.setInputValue("initial", "800"); + opAVMNN.setInputValue("burnin", "400"); + opAVMNN.setInputValue("every", "1"); + opAVMNN.setInputValue("allowNonsense", "true"); + // require initAndValidate later for adding more input in runtime + return opAVMNN; + } + + public static Transform.LogConstrainedSumTransform initAVMNSumTransform(String idSteam) { + Transform.LogConstrainedSumTransform transform = new Transform.LogConstrainedSumTransform(); + transform.setID(idSteam + ".AVMNSumTransform"); + return transform; + } + + public static Transform.LogTransform initLogTransform(String idSteam) { + Transform.LogTransform transform = new Transform.LogTransform(); + transform.setID(idSteam + ".AVMNLogTransform"); + return transform; + } + + public static Transform.NoTransform initNoTransform(String idSteam) { + Transform.NoTransform transform = new Transform.NoTransform(); + transform.setID(idSteam + ".AVMNNoTransform"); + return transform; + } + // when both mu and tree are random var public static void addUpDownOperator(Tree tree, RealParameter clockRate, BEASTContext context) { String idStr = clockRate.getID() + "Up" + tree.getID() + "DownOperator";