From 006dd1273266b1f98ab47938f03fd97d97807076 Mon Sep 17 00:00:00 2001 From: heitzig Date: Sat, 3 Feb 2024 00:07:54 +0100 Subject: [PATCH] example of generating the tree --- examples/runVerySimpleGW.wppl | 161 ++++++++++++++------------- package.json | 1 + src/agents/makeMDPAgentSatisfia.wppl | 2 + src/main.js | 9 ++ src/simulation/getDynTree.wppl | 50 +++++++++ 5 files changed, 144 insertions(+), 79 deletions(-) create mode 100644 src/simulation/getDynTree.wppl diff --git a/examples/runVerySimpleGW.wppl b/examples/runVerySimpleGW.wppl index 32afc08..fbc90a4 100644 --- a/examples/runVerySimpleGW.wppl +++ b/examples/runVerySimpleGW.wppl @@ -30,82 +30,85 @@ var env = getEnv(), messingPotential = agent.messingPotential_state, cupLoss = agent.cupLoss_state; -// Generate and draw a trajectory: -var simulate = function(state, aleph, _t) { - var t = _t ? _t : 0, - aleph4state = asInterval(aleph); - if (options.verbose || options.debug) console.log(pad(state),"SIMULATE, t",t,"state",prettyState(state),"aleph4state",aleph4state,"..."); - var localPolicy = localPolicy(state, aleph4state), - actionAndAleph = sample(localPolicy), - action = actionAndAleph[0], - aleph4action = actionAndAleph[1], - Edel = expectedDelta(state, action); - var stepData = {state, aleph4state, action, aleph4action, Edel}; - if (state.terminateAfterAction) { - if (options.verbose || options.debug) console.log(pad(state),"SIMULATE, t",t,"state",prettyState(state),"aleph4state",aleph4state,": localPolicy",JSON.stringify(localPolicy.params),"\n"+pad(state),"| action",action,"aleph4action",aleph4action,"Edel",Edel,"(terminal)"); - return { - trajectory: [stepData], // sequence of [state, action] pairs - conditionalExpectedIndicator: Edel // expected indicator conditional on this trajectory - }; - } else { - var nextState = transition(state, action), - nextAleph4state = propagateAspiration(state, action, aleph4action, Edel, nextState); - if (options.verbose || options.debug) console.log(pad(state),"SIMULATE, t",t,"state",prettyState(state),"aleph4state",aleph4state,": localPolicy",JSON.stringify(localPolicy.params),"\n"+pad(state),"| action",action,"aleph4action",aleph4action,"Edel",Edel,"nextState",prettyState(nextState),"nextAleph4state",nextAleph4state); - var nextOut = simulate(nextState, nextAleph4state, t+1); - return { - trajectory: [stepData].concat(nextOut.trajectory), - conditionalExpectedIndicator: Edel + nextOut.conditionalExpectedIndicator - }; - } -}; - -console.log("aleph0", asInterval(aleph0)); - -var t0 = webpplAgents.time(); -// verify meeting of expectations: -console.log("V", V(startState, aleph0)); -console.log("TIME:", webpplAgents.time() - t0, "ms"); -console.log("cupLoss", cupLoss(mdp.startState, aleph0)); -console.log("entropy", entropy(mdp.startState, aleph0)); -console.log("KLdiv", KLdiv(mdp.startState, aleph0)); -console.log("messPot", messingPotential(mdp.startState, aleph0)); - -var gd = agent.getData, agentData = gd(); - -// estimate distribution of trajectories: - -var trajDist = Infer({ model() { - return simulate(mdp.startState, aleph0).trajectory; -}}).getDist(); - -console.log("\nDATA FOR REGRESSION TESTS: \ntrajDist"); -var regressionTestData = webpplAgents.trajDist2simpleJSON(trajDist); -console.log(JSON.stringify(regressionTestData)); -console.log("END OF DATA FOR REGRESSION TESTS\n"); - -var trajData = trajDist2TrajData(trajDist, agent); - -//console.log("trajData", trajData); - -var locActionData = webpplAgents.trajDist2LocActionData(trajDist, trajData); -console.log("locActionData", locActionData); - -console.log("\nminAdmissibleQ:"); -console.log(stateActionFct2ASCII(agent.minAdmissibleQ, agentData.stateActionPairs)); -console.log("\nmaxAdmissibleQ:"); -console.log(stateActionFct2ASCII(agent.maxAdmissibleQ, agentData.stateActionPairs)); - -console.log("\nQ:"); -console.log(webpplAgents.locActionData2ASCII(locActionData.Q)); -console.log("\ncupLoss:"); -console.log(webpplAgents.locActionData2ASCII(locActionData.cupLoss)); -console.log("\nmessingPotential:"); -console.log(webpplAgents.locActionData2ASCII(locActionData.messingPotential)); -console.log("\ncombinedLoss:"); -console.log(webpplAgents.locActionData2ASCII(locActionData.combinedLoss)); - -console.log("\naction frequencies:"); -console.log(webpplAgents.locActionData2ASCII(locActionData.actionFrequency)); - - - +console.log(JSON.stringify(getDynTree(agent, mdp.startState, aleph0))); + +if (false) { + + // Generate and draw a trajectory: + var simulate = function(state, aleph, _t) { + var t = _t ? _t : 0, + aleph4state = asInterval(aleph); + if (options.verbose || options.debug) console.log(pad(state),"SIMULATE, t",t,"state",prettyState(state),"aleph4state",aleph4state,"..."); + var localPolicy = localPolicy(state, aleph4state), + actionAndAleph = sample(localPolicy), + action = actionAndAleph[0], + aleph4action = actionAndAleph[1], + Edel = expectedDelta(state, action); + var stepData = {state, aleph4state, action, aleph4action, Edel}; + if (state.terminateAfterAction) { + if (options.verbose || options.debug) console.log(pad(state),"SIMULATE, t",t,"state",prettyState(state),"aleph4state",aleph4state,": localPolicy",JSON.stringify(localPolicy.params),"\n"+pad(state),"| action",action,"aleph4action",aleph4action,"Edel",Edel,"(terminal)"); + return { + trajectory: [stepData], // sequence of [state, action] pairs + conditionalExpectedIndicator: Edel // expected indicator conditional on this trajectory + }; + } else { + var nextState = transition(state, action), + nextAleph4state = propagateAspiration(state, action, aleph4action, Edel, nextState); + if (options.verbose || options.debug) console.log(pad(state),"SIMULATE, t",t,"state",prettyState(state),"aleph4state",aleph4state,": localPolicy",JSON.stringify(localPolicy.params),"\n"+pad(state),"| action",action,"aleph4action",aleph4action,"Edel",Edel,"nextState",prettyState(nextState),"nextAleph4state",nextAleph4state); + var nextOut = simulate(nextState, nextAleph4state, t+1); + return { + trajectory: [stepData].concat(nextOut.trajectory), + conditionalExpectedIndicator: Edel + nextOut.conditionalExpectedIndicator + }; + } + }; + + console.log("aleph0", asInterval(aleph0)); + + var t0 = webpplAgents.time(); + // verify meeting of expectations: + console.log("V", V(startState, aleph0)); + console.log("TIME:", webpplAgents.time() - t0, "ms"); + console.log("cupLoss", cupLoss(mdp.startState, aleph0)); + console.log("entropy", entropy(mdp.startState, aleph0)); + console.log("KLdiv", KLdiv(mdp.startState, aleph0)); + console.log("messPot", messingPotential(mdp.startState, aleph0)); + + var gd = agent.getData, agentData = gd(); + + // estimate distribution of trajectories: + + var trajDist = Infer({ model() { + return simulate(mdp.startState, aleph0).trajectory; + }}).getDist(); + + console.log("\nDATA FOR REGRESSION TESTS: \ntrajDist"); + var regressionTestData = webpplAgents.trajDist2simpleJSON(trajDist); + console.log(JSON.stringify(regressionTestData)); + console.log("END OF DATA FOR REGRESSION TESTS\n"); + + var trajData = trajDist2TrajData(trajDist, agent); + + //console.log("trajData", trajData); + + var locActionData = webpplAgents.trajDist2LocActionData(trajDist, trajData); + console.log("locActionData", locActionData); + + console.log("\nminAdmissibleQ:"); + console.log(stateActionFct2ASCII(agent.minAdmissibleQ, agentData.stateActionPairs)); + console.log("\nmaxAdmissibleQ:"); + console.log(stateActionFct2ASCII(agent.maxAdmissibleQ, agentData.stateActionPairs)); + + console.log("\nQ:"); + console.log(webpplAgents.locActionData2ASCII(locActionData.Q)); + console.log("\ncupLoss:"); + console.log(webpplAgents.locActionData2ASCII(locActionData.cupLoss)); + console.log("\nmessingPotential:"); + console.log(webpplAgents.locActionData2ASCII(locActionData.messingPotential)); + console.log("\ncombinedLoss:"); + console.log(webpplAgents.locActionData2ASCII(locActionData.combinedLoss)); + + console.log("\naction frequencies:"); + console.log(webpplAgents.locActionData2ASCII(locActionData.actionFrequency)); + +} \ No newline at end of file diff --git a/package.json b/package.json index 8b870d3..f7c6b66 100644 --- a/package.json +++ b/package.json @@ -31,6 +31,7 @@ "src/agents/makeMDPAgent.wppl", "src/agents/makeMDPAgentSatisfia.wppl", "src/agents/makePOMDPAgent.wppl", + "src/simulation/getDynTree.wppl", "src/simulation/simulateMDP.wppl", "src/simulation/simulatePOMDP.wppl", "src/visualization/gridworld.wppl", diff --git a/src/agents/makeMDPAgentSatisfia.wppl b/src/agents/makeMDPAgentSatisfia.wppl index 179ac0e..ce9dc58 100644 --- a/src/agents/makeMDPAgentSatisfia.wppl +++ b/src/agents/makeMDPAgentSatisfia.wppl @@ -1221,6 +1221,8 @@ var makeMDPAgentSatisfia = function(params_, world) { }; return { + transitionDistribution, + expectedDelta, varianceOfDelta, skewnessOfDelta, excessKurtosisOfDelta, minAdmissibleQ, maxAdmissibleQ, minAdmissibleV, maxAdmissibleV, localPolicy, localPolicyData, propagateAspiration, Q, V, Q2, V2, Q_DeltaSquare, V_DeltaSquare, Q_ones, V_ones, diff --git a/src/main.js b/src/main.js index e7a9121..c91e199 100644 --- a/src/main.js +++ b/src/main.js @@ -34,6 +34,15 @@ module.exports = { setFrom: (arg) => new Set(arg), + objectFromPairs: (pairs) => { + var result = {}; + for (var index in pairs) { + var [key, value] = pairs[index]; + result[key] = value; + } + return result; + }, + min: (arr) => Math.min.apply(null, arr), max: (arr) => Math.max.apply(null, arr), diff --git a/src/simulation/getDynTree.wppl b/src/simulation/getDynTree.wppl new file mode 100644 index 0000000..500586c --- /dev/null +++ b/src/simulation/getDynTree.wppl @@ -0,0 +1,50 @@ +// test with +// $ webppl --require webppl-dp --require . examples/runVerySimpleGW.wppl -- --gw=GW4 + +var getDynTree = function(agent, state, aleph) { + /* Construct a tree that represents all possible histories starting at state with aspiration aleph. + Return value has format + + {action1: [aleph4action, actionLogit, {nextState1: [aleph4state, nextStateLogit, Edel, nextStateBranch], + nextState2: [aleph4state, nextStateLogit, Edel, nextStateBranch], + ...}], + action2: [aleph4action, actionLogit, {nextState1: [aleph4state, nextStateLogit, Edel, nextStateBranch], + nextState2: [aleph4state, nextStateLogit, Edel, nextStateBranch], + ...}], + ...} + + where + - action1, action2, ... are the actions available at state + - aleph4action is the aspiration for the action + - actionLogit is the logit for the action according to the local policy + - nextState1, nextState2, ... are the possible next states + - aleph4state is the aspiration for the next state + - nextStateLogit is the logit for the next state according to the transition distribution + - Edel is the expected delta for the action + - nextStateBranch is the tree for the next state + */ + var localPolicy = agent.localPolicy, + expectedDelta = agent.expectedDelta, + transitionDistribution = agent.transitionDistribution, + propagateAspiration = agent.propagateAspiration, + locPol = localPolicy(state, asInterval(aleph)); + var stateBranch = webpplAgents.objectFromPairs(map(function(actionAndAleph) { + var action = actionAndAleph[0], + aleph4action = actionAndAleph[1], + actionLogit = locPol.score(actionAndAleph), + Edel = expectedDelta(state, action); + if (state.terminateAfterAction) { + return [action, [aleph4action, actionLogit]]; + } else { + var transDist = transitionDistribution(state, action), + actionBranch = webpplAgents.objectFromPairs(map(function(nextState) { + var nextStateLogit = transDist.score(nextState), + nextAleph4state = propagateAspiration(state, action, aleph4action, Edel, nextState); + var nextStateBranch = getDynTree(agent, nextState, nextAleph4state); + return [JSON.stringify(nextState), [nextAleph4state, nextStateLogit, Edel, nextStateBranch]]; + }, transDist.support())); + return [action, [aleph4action, actionLogit, actionBranch]]; + } + }, locPol.support())); + return stateBranch; +};