Skip to content

Commit

Permalink
Merge pull request #770 from null-a/guide-band-aid2
Browse files Browse the repository at this point in the history
Add `noAutoGuide` option
  • Loading branch information
stuhlmueller authored Jan 31, 2017
2 parents d812c0b + bf39e8d commit b9305e3
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 57 deletions.
18 changes: 8 additions & 10 deletions src/guide.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,14 @@ function runDistThunkCond(s, k, a, env, maybeThunk, alternate) {
alternate(s, k, a);
}

function getDist(maybeThunk, env, s, a, k) {
// Runs the guide thunk passed to a `sample` statement, and returns
// the guide distribution returned by the thunk. When no guide is
// given, then the 'noAutoGuide' flag determines whether to
// automatically generate a suitable guide distribution or return
// `null`.
function getDist(maybeThunk, noAutoGuide, targetDist, env, s, a, k) {
return runDistThunkCond(s, k, a, env, maybeThunk, function(s, k, a) {
return k(s, null);
});
}

function getDistOrAuto(maybeThunk, targetDist, env, s, a, k) {
return runDistThunkCond(s, k, a, env, maybeThunk, function(s, k, a) {
return k(s, independent(targetDist, a, env));
return k(s, noAutoGuide ? null : independent(targetDist, a, env));
});
}

Expand Down Expand Up @@ -334,6 +333,5 @@ function squishToInterval(interval) {
module.exports = {
independent: independent,
runThunk: runThunk,
getDist: getDist,
getDistOrAuto: getDistOrAuto
getDist: getDist
};
5 changes: 4 additions & 1 deletion src/inference/elbo.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,10 @@ module.exports = function(env) {

sample: function(s, k, a, dist, options) {
options = options || {};
return guide.getDistOrAuto(options.guide, dist, env, s, a, function(s, guideDist) {
return guide.getDist(options.guide, options.noAutoGuide, dist, env, s, a, function(s, guideDist) {
if (!guideDist) {
throw new Error('ELBO: No guide distribution to optimize.');
}

var ret = this.sampleGuide(guideDist, options);
var val = ret.val;
Expand Down
6 changes: 5 additions & 1 deletion src/inference/eubo.ad.js
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ module.exports = function(env) {
sample: function(s, k, a, dist, options) {
'use ad';
options = options || {};
return guide.getDistOrAuto(options.guide, dist, env, s, a, function(s, guideDist) {
return guide.getDist(options.guide, options.noAutoGuide, dist, env, s, a, function(s, guideDist) {
if (!guideDist) {
throw new Error('EUBO: No guide distribution to optimize.');
}

var rel = util.relativizeAddress(env, a);
var guideVal = this.trace.findChoice(this.trace.baseAddress + rel).val;
assert.notStrictEqual(guideVal, undefined);
Expand Down
9 changes: 6 additions & 3 deletions src/inference/forwardSample.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ module.exports = function(env) {
sample: function(s, k, a, dist, options) {
if (this.opts.guide) {
options = options || {};
return guide.getDistOrAuto(options.guide, dist, env, s, a, function(s, guideDist) {
return k(s, guideDist.sample());
});
return guide.getDist(
options.guide, options.noAutoGuide, dist, env, s, a,
function(s, maybeGuideDist) {
var d = maybeGuideDist || dist;
return k(s, d.sample());
});
} else {
return k(s, dist.sample());
}
Expand Down
12 changes: 2 additions & 10 deletions src/inference/smc.js
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,8 @@ module.exports = function(env) {
SMC.prototype.sample = function(s, k, a, dist, options) {
options = options || {};
var thunk = (this.importanceOpt === 'ignoreGuide') ? undefined : options.guide;
return guide.getDist(thunk, env, s, a, function(s, maybeDist) {

// maybeDist will be null if either the 'ignoreGuide' option is
// set, or no guide is specified in the program.

// Auto guide if requested.
var importanceDist =
!maybeDist && (this.importanceOpt === 'autoGuide') ?
guide.independent(dist, a, env) :
maybeDist;
var noAutoGuide = (this.importanceOpt !== 'autoGuide') || options.noAutoGuide;
return guide.getDist(thunk, noAutoGuide, dist, env, s, a, function(s, importanceDist) {

var _val, choiceScore, importanceScore;
if (importanceDist) {
Expand Down
5 changes: 1 addition & 4 deletions tests/test-data/deterministic/expected/guides.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
{
"result": [
1,
6
]
"result": [true, true, true, true, true, true, true]
}
6 changes: 0 additions & 6 deletions tests/test-data/deterministic/expected/smc.json

This file was deleted.

43 changes: 41 additions & 2 deletions tests/test-data/deterministic/models/guides.wppl
Original file line number Diff line number Diff line change
@@ -1,4 +1,43 @@
var numParamsCreatedBy = function(thunk) {
setFreshParamsId();
thunk();
return _.size(getParams());
};

[
param({mu: 1, sigma: 0}),
T.sumreduce(param({dims: [3, 2], mu: 1, sigma: 0}))
param({mu: 1, sigma: 0}) === 1,
T.sumreduce(param({dims: [3, 2], mu: 1, sigma: 0})) === 6,

// Check (indirectly) that a guide is automatically generated, by
// checking that a parameter is created.

numParamsCreatedBy(function() {
Infer({method: 'SMC', particles: 1, importance: 'default', model() {
return flip();
}});
}) === 0,

numParamsCreatedBy(function() {
Infer({method: 'SMC', particles: 1, importance: 'autoGuide', model() {
return flip();
}});
}) === 1,

numParamsCreatedBy(function() {
Infer({method: 'SMC', particles: 1, importance: 'autoGuide', model() {
return sample(Bernoulli({p: 0.5}), {noAutoGuide: true});
}});
}) === 0,

numParamsCreatedBy(function() {
Infer({method: 'forward', samples: 1, guide: true, model() {
return sample(Bernoulli({p: 0.5}));
}});
}) === 1,

numParamsCreatedBy(function() {
Infer({method: 'forward', samples: 1, guide: true, model() {
return sample(Bernoulli({p: 0.5}), {noAutoGuide: true});
}});
}) === 0
];
20 changes: 0 additions & 20 deletions tests/test-data/deterministic/models/smc.wppl

This file was deleted.

0 comments on commit b9305e3

Please sign in to comment.