Skip to content

Commit

Permalink
fixed wrong loss computation due to overflows
Browse files Browse the repository at this point in the history
  • Loading branch information
mensch72 committed Mar 1, 2024
1 parent 1aa3e15 commit 400ec5e
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/agents/makeMDPAgentSatisfia.wppl
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,15 @@ var makeMDPAgentSatisfia = function(params_, world) {
// but the loss estimation requires an estimate of the probability of the chosen action,
// we estimate the probability at 1 / number of actions:
var indices = Array.from(actions.keys()),
propensities = map(function(index) {
losses = map(function(index) {
var action = actions[index],
loss = combinedLoss(state, action, aleph4state, estAlephs1[index], 1 / indices.length); // ! bottleneck !
return Math.min(1e100, Math.max(Math.exp(-loss / lossTemperature), 1e-100));
}, indices);
return loss;
}, indices),
meanLoss = sum(losses) / indices.length,
propensities = map(function(loss) {
return Math.min(1e100, Math.max(Math.exp(-(loss-meanLoss) / lossTemperature), 1e-100));
}, losses);

if (debug) console.log(pad(state),"| localPolicyData", prettyState(state), aleph, actions, {propensities});

Expand Down Expand Up @@ -477,11 +481,15 @@ var makeMDPAgentSatisfia = function(params_, world) {
var aleph2target = interpolate(estAleph1, 2.0, aleph4state);
// Due to the new target aleph, we have to recompute the estimated alephs and resulting losses and propensities:
var estAlephs2 = map(function(index) { return estAspiration4action(state, actions[index], aleph2target); }, indices),
propensities2 = map(function(index) {
losses2 = map(function(index) {
var action = actions[index],
loss = combinedLoss(state, action, aleph4state, estAlephs2[index], 1 / indices2.length);
return Math.min(1e100, Math.max(Math.exp(-loss / lossTemperature), 1e-100));
}, indices2);
return loss;
}, indices2),
meanLoss2 = sum(losses2) / indices2.length,
propensities2 = map(function(loss) {
return Math.min(1e100, Math.max(Math.exp(-(loss-meanLoss2) / lossTemperature), 1e-100));
}, losses2);

if (debug) console.log(pad(state),"| | localPolicyData", prettyState(state), aleph4state, {a1, midTarget, estAleph1, mid1, indices2, aleph2target, estAlephs2, propensities2});

Expand Down

0 comments on commit 400ec5e

Please sign in to comment.