diff --git a/src/agents/makeMDPAgentSatisfia.wppl b/src/agents/makeMDPAgentSatisfia.wppl index 04d381f..2114e10 100644 --- a/src/agents/makeMDPAgentSatisfia.wppl +++ b/src/agents/makeMDPAgentSatisfia.wppl @@ -434,11 +434,15 @@ var makeMDPAgentSatisfia = function(params_, world) { // but the loss estimation requires an estimate of the probability of the chosen action, // we estimate the probability at 1 / number of actions: var indices = Array.from(actions.keys()), - propensities = map(function(index) { + losses = map(function(index) { var action = actions[index], loss = combinedLoss(state, action, aleph4state, estAlephs1[index], 1 / indices.length); // ! bottleneck ! - return Math.min(1e100, Math.max(Math.exp(-loss / lossTemperature), 1e-100)); - }, indices); + return loss; + }, indices), + meanLoss = sum(losses) / indices.length, + propensities = map(function(loss) { + return Math.min(1e100, Math.max(Math.exp(-(loss-meanLoss) / lossTemperature), 1e-100)); + }, losses); if (debug) console.log(pad(state),"| localPolicyData", prettyState(state), aleph, actions, {propensities}); @@ -477,11 +481,15 @@ var makeMDPAgentSatisfia = function(params_, world) { var aleph2target = interpolate(estAleph1, 2.0, aleph4state); // Due to the new target aleph, we have to recompute the estimated alephs and resulting losses and propensities: var estAlephs2 = map(function(index) { return estAspiration4action(state, actions[index], aleph2target); }, indices), - propensities2 = map(function(index) { + losses2 = map(function(index) { var action = actions[index], loss = combinedLoss(state, action, aleph4state, estAlephs2[index], 1 / indices2.length); - return Math.min(1e100, Math.max(Math.exp(-loss / lossTemperature), 1e-100)); - }, indices2); + return loss; + }, indices2), + meanLoss2 = sum(losses2) / indices2.length, + propensities2 = map(function(loss) { + return Math.min(1e100, Math.max(Math.exp(-(loss-meanLoss2) / lossTemperature), 1e-100)); + }, losses2); if (debug) console.log(pad(state),"| | localPolicyData", prettyState(state), aleph4state, {a1, midTarget, estAleph1, mid1, indices2, aleph2target, estAlephs2, propensities2});