Skip to content

Commit

Permalink
adjusted aspiration propagation to rescaling as in Python version
Browse files Browse the repository at this point in the history
  • Loading branch information
mensch72 committed Mar 6, 2024
1 parent 3bb9e87 commit 2fb2009
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 55 deletions.
2 changes: 1 addition & 1 deletion examples/runVerySimpleGW.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ var _PP = webpplAgents.pretty;
var env = getEnv(),
argv = extend({}, env.argv),
params = extend({}, argv),
mdp = VerySimpleGW(argv.gw || "GW2", argv.gwparms, argv.time, argv.timeOutDelta),
mdp = VerySimpleGW(argv.gw || "GW2", argv.gwparms, argv.time, argv.timeOutDelta, true),
world = mdp.world,
transition = world.transition,
expectedDelta = mdp.expectedDelta,
Expand Down
79 changes: 29 additions & 50 deletions src/agents/makeMDPAgentSatisfia.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ var makeMDPAgentSatisfia = function(params_, world) {
→ minAdmissibleV, maxAdmissibleV
→ MIN(minAdmissibleQ), MAX(maxAdmissibleQ)
→ E(minAdmissibleV), E(maxAdmissibleV) (RECURSION)
estAspiration4action
aspiration4action
→ minAdmissibleV, maxAdmissibleV, minAdmissibleQ, maxAdmissibleQ
→ combinedLoss
→ Q, .., Q6, Q_DeltaSquare, Q_ones
Expand Down Expand Up @@ -320,8 +320,16 @@ var makeMDPAgentSatisfia = function(params_, world) {
// When constructing the local policy, we first use an estimated action aspiration interval
// that does not depend on the local policy but is simply based on the state's aspiration interval,
// moved from the admissibility interval of the state to the admissibility interval of the action.
var estAspiration4action = dp.cache(function(state, action, aleph4state){
var aspiration4action = dp.cache(function(state, action, aleph4state){
if (verbose || debug) console.log(pad(state),"| | estAspiration4action, state",prettyState(state),"action",action,"aleph4state",aleph4state,"...");
var res = interpolate(minAdmissibleQ(state, action),
relativePosition(minAdmissibleV(state), aleph4state, maxAdmissibleV(state)),
maxAdmissibleQ(state, action));

if (verbose || debug) console.log(pad(state),"| | ╰ estAspiration4action, state",prettyState(state),"action",action,"aleph4state",aleph4state,":",res,"(rescaled)");
return res;

// DISABLED:
var phi = admissibility4action(state, action);
if (isSubsetOf(phi, aleph4state)) {

Expand Down Expand Up @@ -426,7 +434,7 @@ var makeMDPAgentSatisfia = function(params_, world) {
// Estimate aspiration intervals for all possible actions in a way
// independent from the local policy that we are about to construct,
var actions = stateToActions(state),
estAlephs1 = map(function(action) { return estAspiration4action(state, action, aleph4state); }, actions);
alephs4action = map(function(action) { return aspiration4action(state, action, aleph4state); }, actions);

// Estimate losses based on this estimated aspiration intervals
// and use it to construct softmin propensities (probability weights) for choosing actions.
Expand All @@ -436,7 +444,7 @@ var makeMDPAgentSatisfia = function(params_, world) {
var indices = Array.from(actions.keys()),
losses = map(function(index) {
var action = actions[index],
loss = combinedLoss(state, action, aleph4state, estAlephs1[index], 1 / indices.length); // ! bottleneck !
loss = combinedLoss(state, action, aleph4state, alephs4action[index], 1 / indices.length); // ! bottleneck !
return loss;
}, indices),
minLoss = _SU.min(losses),
Expand Down Expand Up @@ -466,32 +474,25 @@ var makeMDPAgentSatisfia = function(params_, world) {

if (verbose || debug) console.log(pad(state),"| localPolicyData, state",prettyState(state),"aleph4state",aleph4state,": a1",a1,"adm1",adm1,"(need to draw a 2nd a)...");

// For drawing the second action, restrict actions so that the the midpoint of aleph4state can be mixed from
// For drawing the second action, restrict actions so that the midpoint of aleph4state can be mixed from
// those of estAlephs4action of the first and second action:
var midTarget = midpoint(aleph4state),
estAleph1 = estAlephs1[i1],
mid1 = midpoint(estAleph1),
aleph1 = alephs4action[i1],
mid1 = midpoint(aleph1),
indices2 = (mid1 <= midTarget)
? filter(function(index) { return midpoint(estAlephs1[index]) >= midTarget; }, indices)
: filter(function(index) { return midpoint(estAlephs1[index]) <= midTarget; }, indices);
// Since we are already set on giving a1 a considerable weight, we no longer aim to have aleph(a2)
// as close as possible to aleph(s), but to a target aleph that would allow mixing a1 and a2
// in roughly equal proportions, i.e., we aim to have aleph(a2) as close as possible to
// aleph(s) + (aleph(s) - aleph(a1)):
var aleph2target = interpolate(estAleph1, 2.0, aleph4state);
// Due to the new target aleph, we have to recompute the estimated alephs and resulting losses and propensities:
var estAlephs2 = map(function(index) { return estAspiration4action(state, actions[index], aleph2target); }, indices),
losses2 = map(function(index) {
? filter(function(index) { return midpoint(alephs4action[index]) >= midTarget; }, indices)
: filter(function(index) { return midpoint(alephs4action[index]) <= midTarget; }, indices);
var losses2 = map(function(index) {
var action = actions[index],
loss = combinedLoss(state, action, aleph4state, estAlephs2[index], 1 / indices2.length);
loss = combinedLoss(state, action, aleph4state, alephs4action[index], 1 / indices2.length);
return loss;
}, indices2),
meanLoss2 = _SU.min(losses2),
propensities2 = map(function(loss) {
return Math.max(Math.exp(-(loss-meanLoss2) / lossTemperature), 1e-100);
}, losses2);

if (debug) console.log(pad(state),"| | localPolicyData", prettyState(state), aleph4state, {a1, midTarget, estAleph1, mid1, indices2, aleph2target, estAlephs2, propensities2});
if (debug) console.log(pad(state),"| localPolicyData", prettyState(state), aleph4state, {a1, midTarget, aleph1, mid1, indices2, propensities2});

// Like for a1, we now draw a2 using a softmin mixture of these actions, based on the new propentities,
// and get its admissibility interval:
Expand All @@ -500,33 +501,11 @@ var makeMDPAgentSatisfia = function(params_, world) {
adm2 = admissibility4action(state, a2),
adm2Lo = adm2[0], adm2Hi = adm2[1];

// Now we need to find two aspiration intervals aleph1 in adm1 and aleph2 in adm2,
// and a probability p such that
// aleph1:p:aleph2 is contained in aleph4state
// and aleph1, aleph2 are close to the estimates we used above in estimating loss.
// Instead of optimizing this, we use the following heuristic:
// We first choose p so that the midpoints mix exactly:
var estAleph2 = estAlephs2[i2],
mid2 = midpoint(estAleph2),
var aleph2 = alephs4action[i2],
mid2 = midpoint(aleph2),
p = relativePosition(mid1, midTarget, mid2);

// Now we find the largest relative size of aleph1 and aleph2
// so that their mixture is still contained in aleph4state:
// we want aleph1Lo:p:aleph2Lo >= alephLo and aleph1Hi:p:aleph2Hi <= alephHi
// where aleph1Lo = mid1 - x * w1, aleph1Hi = mid1 + x * w1,
// aleph2Lo = mid2 - x * w2, aleph2Hi = mid2 + x * w2,
// hence midTarget - x * w1:p:w2 >= alephLo and midTarget + x * w1:p:w2 <= alephHi,
// i.e., x <= (midTarget - alephLo) / (w1:p:w2) and x <= (alephHi - midTarget) / (w1:p:w2):
var w1 = estAleph1[1] - estAleph1[0],
w2 = estAleph2[1] - estAleph2[0],
w = interpolate(w1, p, w2),
x = w > 0 ? Math.min((midTarget - alephLo) / w, (alephHi - midTarget) / w) : 0,
aleph1 = [mid1 - x * w1, mid1 + x * w1],
aleph2 = [mid2 - x * w2, mid2 + x * w2];

if (debug) console.log(pad(state),"| | localPolicyData",prettyState(state), aleph4state, {a1, estAleph1, adm1: adm1, w1, a2, estAleph2, adm2, w2, p, w, x, aleph1, aleph2});

if (verbose || debug) console.log(pad(state),"| | localPolicyData, state",prettyState(state),"aleph4state",aleph4state,": a1,p,a2",a1,p,a2,"adm12",adm1,adm2,"aleph12",aleph1,aleph2);
if (verbose || debug) console.log(pad(state),"| localPolicyData, state",prettyState(state),"aleph4state",aleph4state,": a1,p,a2",a1,p,a2,"adm12",adm1,adm2,"aleph12",aleph1,aleph2);

return sample(Categorical({vs: [[a1, aleph1], [a2, aleph2]], ps: [1-p, p]}));

Expand Down Expand Up @@ -566,10 +545,10 @@ var makeMDPAgentSatisfia = function(params_, world) {
// and aleph1, aleph2 are close to the estimates we used above in estimating loss.
// To measure the error, we use the sum of the squared deviations of aleph1, aleph2 from the estimates,
// divided by the squared width of adm1 and adm2.
var estAleph1 = estAlephs1[i1],
estAleph2 = estAlephs1[i2],
target1Lo = estAleph1[0], target1Hi = estAleph1[1], wsq1 = squared(adm1Hi - adm1Lo),
target2Lo = estAleph2[0], target2Hi = estAleph2[1], wsq2 = squared(adm2Hi - adm2Lo);
var aleph1 = alephs4action[i1],
aleph2 = alephs4action[i2],
target1Lo = aleph1[0], target1Hi = aleph1[1], wsq1 = squared(adm1Hi - adm1Lo),
target2Lo = aleph2[0], target2Hi = aleph2[1], wsq2 = squared(adm2Hi - adm2Lo);

if (wsq1 == 0 && wsq2 == 0) {

Expand All @@ -592,8 +571,8 @@ var makeMDPAgentSatisfia = function(params_, world) {
aleph2 = [target2Lo + devLo * wsq2, target2Hi + devHi * wsq2];

if (verbose || debug) console.log(pad(state),"| ",prettyState(state), aleph4state,
"\n ", a1, estAleph1, adm1, wsq1, d1Lo, d1Hi,
"\n ", a2, estAleph2, adm2, wsq2, d2Lo, d2Hi,
"\n ", a1, aleph1, adm1, wsq1, d1Lo, d1Hi,
"\n ", a2, aleph2, adm2, wsq2, d2Lo, d2Hi,
"\n ", facLo, facHi, p, wsq, devLo, devHi,
"\n ", aleph1, aleph2);
// TODO: verify that the following inquality constraints are always satisfied, otherwise deal with it:
Expand Down
8 changes: 4 additions & 4 deletions src/environments/safety_gridworlds/very_simple.wppl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// TODO: prevent agent from turning around

var VerySimpleGW = function(gw, parms, _time, timeOutDelta) {
var VerySimpleGW = function(gw, parms, _time, timeOutDelta, allowBackwards) {
var time = _time ? _time : 10;

var makeGW = function(d) {
Expand All @@ -26,7 +26,7 @@ var VerySimpleGW = function(gw, parms, _time, timeOutDelta) {
: (lx == px-1 && ly == py) ? "r"
: undefined,
backIndex = actions.indexOf(back);
return actions.slice(0, backIndex).concat(actions.slice(backIndex+1));
return allowBackwards ? actions : actions.slice(0, backIndex).concat(actions.slice(backIndex+1));
},
our_world = { transition: world.transition, stateToActions, feature };

Expand Down Expand Up @@ -228,8 +228,8 @@ var VerySimpleGW = function(gw, parms, _time, timeOutDelta) {
'9': 9,
' ': 0
},
aleph0: 5,
totalTime: 2,
aleph0: 15,
totalTime: 6,
timeOutDelta: 0
});
/*
Expand Down

0 comments on commit 2fb2009

Please sign in to comment.