Skip to content

Commit

Permalink
fixed missing Delta from state at timeOut
Browse files Browse the repository at this point in the history
  • Loading branch information
mensch72 committed Mar 1, 2024
1 parent 19c36e3 commit 9b681b7
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 12 deletions.
3 changes: 3 additions & 0 deletions examples/runVerySimpleGW.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ var env = getEnv(),
// Start a timer; the "TIME:" line below reports the time spent on the V / V2 evaluations that precede it.
var t0 = _SU.time();
// verify meeting of expectations:
console.log("\nV", V(startState, aleph0),"\n");
// Variance is printed as V2 - V*V (presumably V2 is the expected squared total — TODO confirm).
// The label uses "\n" like the other log lines; the former "\V" was a no-op escape that dropped the newline.
console.log("\nVariance", V2(startState, aleph0)-V(startState, aleph0)*V(startState, aleph0),"\n");
console.log("\nTIME:", _SU.time() - t0, "ms\n");
// Further start-state diagnostics (not included in the timing above).
console.log("\ncupLoss", cupLoss_state(mdp.startState, aleph0),"\n");
console.log("\nentropy", behaviorEntropy_state(mdp.startState, aleph0),"\n");
Expand All @@ -43,6 +44,8 @@ var gd = agent.getData, agentData = gd();

// Call simulateMDPAgentSatisfia for this mdp and agent, starting from mdp.startState with aleph0;
// the `trajectory` field of the result is used below.
var sym = simulateMDPAgentSatisfia(mdp, agent, mdp.startState, aleph0, argv); // simulate(mdp.startState, aleph0);

// Print the trajectory returned by the call above.
console.log(sym.trajectory);

// Wrap the trajectory in an Infer model that always returns it, and take the resulting distribution object.
var trajDist = Infer({ model() {
return sym.trajectory;
}}).getDist();
Expand Down
18 changes: 9 additions & 9 deletions src/agents/makeMDPAgentSatisfia.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,14 @@ var makeMDPAgentSatisfia = function(params_, world) {
lossCoeff4KLdiv1: 1, // weight of current-state KL divergence in loss function, must be >= 0
// coefficients for expensive-to-compute loss functions (all zero by default except for variance):
lossCoeff4Variance: 1, // weight of variance of total in loss function, must be >= 0
lossCoeff4Fourth: 1, // weight of centralized fourth moment of total in loss function, must be >= 0
lossCoeff4Cup: 1, // weight of "cup" loss component, based on sixth moment of total, must be >= 0
lossCoeff4LRA: 1, // weight of deviation of LRA from 0.5 in loss function, must be >= 0
lossCoeff4Time: 1, // weight of time in loss function, must be >= 0
lossCoeff4DeltaVariation: 1, // weight of variation of Delta in loss function, must be >= 0
lossCoeff4Entropy: 1, // weight of action entropy in loss function, must be >= 0
lossCoeff4KLdiv: 1, // weight of KL divergence in loss function, must be >= 0
lossCoeff4OtherLoss: 1, // weight of other loss components specified by otherLossIncrement, must be >= 0
lossCoeff4Fourth: 0, // weight of centralized fourth moment of total in loss function, must be >= 0
lossCoeff4Cup: 0, // weight of "cup" loss component, based on sixth moment of total, must be >= 0
lossCoeff4LRA: 0, // weight of deviation of LRA from 0.5 in loss function, must be >= 0
lossCoeff4Time: 0, // weight of time in loss function, must be >= 0
lossCoeff4DeltaVariation: 0, // weight of variation of Delta in loss function, must be >= 0
lossCoeff4Entropy: 0, // weight of action entropy in loss function, must be >= 0
lossCoeff4KLdiv: 0, // weight of KL divergence in loss function, must be >= 0
lossCoeff4OtherLoss: 0, // weight of other loss components specified by otherLossIncrement, must be >= 0
allowNegativeCoeffs: false, // if true, allow negative loss coefficients
}, params_), { options: extend(extend({
verbose: false, // if true, print explanatory messages
Expand Down Expand Up @@ -440,7 +440,7 @@ var makeMDPAgentSatisfia = function(params_, world) {
return Math.min(1e100, Math.max(Math.exp(-loss / lossTemperature), 1e-100));
}, indices);

if (verbose || debug) console.log(pad(state),"| localPolicyData", prettyState(state), aleph, actions, {propensities});
if (debug) console.log(pad(state),"| localPolicyData", prettyState(state), aleph, actions, {propensities});

// now we can construct the local policy as a WebPPL distribution object:
var locPol = Infer({ model() {
Expand Down
64 changes: 63 additions & 1 deletion src/environments/safety_gridworlds/very_simple.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var VerySimpleGW = function(gw, parms, _time, timeOutDelta) {
stateToActions0 = world.stateToActions,
startState = mdp.startState,
feature = world.feature,
expectedDelta = tableToExpectedDeltaFct(d.expectedDeltaTable, feature, d.timeOutDelta || timeOutDelta),
expectedDelta = tableToExpectedDeltaFct(d.expectedDeltaTable, feature, d.timeOutDelta === undefined ? timeOutDelta : d.timeOutDelta),
uninformedPolicy = UniformGridPolicy(),
referencePolicy = UniformGridPolicy(),
stateToActions = function(s) {
Expand Down Expand Up @@ -173,6 +173,68 @@ var VerySimpleGW = function(gw, parms, _time, timeOutDelta) {
/*
Desired: go to the G that gives 4.
*/
} else if (gw == "GW21") {
return makeGW({
/* grid: [
['#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#'],
['#','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','#'],
['#','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','#'],
['#',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','#'],
['#','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6','#'],
['#','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','#'],
['#','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','#'],
['#',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','#'],
['#','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6','#'],
['#','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','#'],
['#','6',' ','4','^','6',' ','4','^','6','A','4','^','6',' ','4','^','6',' ','4','#'],
['#',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','#'],
['#','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6','#'],
['#','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','#'],
['#','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','#'],
['#',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','#'],
['#','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6','#'],
['#','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','#'],
['#','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','^','6',' ','4','#'],
['#',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','1',' ','9','^','#'],
['#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#']
],
*/ grid: [
['#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#'],
['#',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','#'],
['#','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4','#'],
['#',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','#'],
['#','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6','#'],
['#',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','#'],
['#','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4','#'],
['#',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','#'],
['#','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6','#'],
['#',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','#'],
['#','6',' ','4',' ','6',' ','4',' ','6','A','4',' ','6',' ','4',' ','6',' ','4','#'],
['#',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','#'],
['#','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6','#'],
['#',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','#'],
['#','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4','#'],
['#',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','#'],
['#','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6','#'],
['#',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','#'],
['#','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4',' ','6',' ','4','#'],
['#',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','1',' ','9',' ','#'],
['#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#','#']
],
expectedDeltaTable: {
'1': 1,
'4': 4,
'6': 6,
'9': 9,
' ': 0
},
aleph0: 25,
totalTime: 2,
timeOutDelta: 0
});
/*
Desired: l-r-r-l-l-r-r-l-l-r. Challenge: uniform random policy also gives 25 in expectation.
*/
}
};

Expand Down
4 changes: 2 additions & 2 deletions src/utils/utilsSatisfia.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ var tableToExpectedDeltaFct = function(table, feature, timeOutDelta) {
// Returned function: looks up the expected Delta for `state` in `table`; the action argument is not used.
return function(state, unused_action) {
// Table key: the feature's `name` when it is truthy, otherwise the feature's '0' entry.
var f = feature(state),
stateFeatureName = f.name,
// NOTE(review): the next two lines and the two after them differ only in the timeLeft threshold
// (> 1 vs > 0) and in the fallback test (|| vs === undefined). They look like the before/after
// versions of the same statement from the diff — confirm which one the real file keeps.
Edel = state.timeLeft > 1 ? (stateFeatureName ? table[stateFeatureName] : table[f['0']])
: (timeOutDelta || -1000);
// This form still uses the table while timeLeft is 1, and keeps an explicit timeOutDelta of 0:
// only an undefined timeOutDelta falls back to -1000.
Edel = state.timeLeft > 0 ? (stateFeatureName ? table[stateFeatureName] : table[f['0']])
: (timeOutDelta === undefined ? -1000 : timeOutDelta);
return Edel;
};
};
Expand Down

0 comments on commit 9b681b7

Please sign in to comment.