diff --git a/src/gen/inference/importance.cljc b/src/gen/inference/importance.cljc index 1fb04bb..6e9b577 100644 --- a/src/gen/inference/importance.cljc +++ b/src/gen/inference/importance.cljc @@ -22,6 +22,10 @@ (+ (math/log nr) x) (recur rst nr (double x))))))) +(defn- neg-inf? + [v] + (= v ##-Inf)) + (defn resampling [gf args observations n-samples] ;; https://github.com/probcomp/Gen.jl/blob/master/src/inference/importance.jl#L77...L95 (let [result (gf/generate gf args observations) @@ -31,9 +35,10 @@ (let [candidate (gf/generate gf args observations) candidate-model-trace (:trace candidate) log-weight (:weight candidate)] - (vswap! log-total-weight #(logsumexp [log-weight %])) - (when (dist/bernoulli (math/exp (- log-weight @log-total-weight))) - (vreset! model-trace candidate-model-trace)))) + (when-not (neg-inf? log-weight) + (vswap! log-total-weight #(logsumexp [log-weight %])) + (when (dist/bernoulli (math/exp (- log-weight @log-total-weight))) + (vreset! model-trace candidate-model-trace))))) (let [log-ml-estimate (- @log-total-weight (math/log n-samples))] {:trace @model-trace :weight log-ml-estimate}))) diff --git a/test/gen/inference/importance_test.cljc b/test/gen/inference/importance_test.cljc new file mode 100644 index 0000000..d3be762 --- /dev/null +++ b/test/gen/inference/importance_test.cljc @@ -0,0 +1,23 @@ +(ns gen.inference.importance-test + (:require [clojure.test :refer [deftest is testing]] + [gen.choicemap :as choicemap :refer [choicemap get-value]] + [gen.distribution.kixi :as dist] + [gen.dynamic :as dynamic :refer [gen]] + [gen.inference.importance :as importance] + [gen.trace :as trace])) + +(def model-causing-rejection-sampling + (gen + [] + (if (dynamic/trace! :foo dist/bernoulli 0.5) + (dynamic/trace! :bar dist/bernoulli 1.0) + (dynamic/trace! :bar dist/bernoulli 0.0)))) + +(deftest rejection + (testing "Robustness in the presence of importance samples with weight log(0)." + (is {:foo true :bar true} + ;; Needs a couple of samples to trigger previous bug here. + (-> (importance/resampling model-causing-rejection-sampling [] (choicemap {:bar true}) 10) + (:trace) + (trace/get-choices) + (get-value)))))