Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: make importance resampling robust to importance samples with weight log(0) #68

Merged
merged 2 commits into from
Apr 10, 2024

Conversation

Schaechtle
Copy link
Collaborator

What does this do?

Fixes a bug in importance resampling. Before, importance resampling crashed if any samples had a weight = log(0). But that should be a fine case for importance resampling to handle. In this case, importance sampling should act like rejection sampling and never return those samples.

How was this tested?

I've added a test in a separate commit.

@Schaechtle Schaechtle requested a review from sritchie April 10, 2024 14:58
@Schaechtle
Copy link
Collaborator Author

@sritchie, sorry for the premature review request. The failing CLJS test makes sense to me and should be easy to fix. What's the Gen.clj-test-idiomatic way to test versions of dist that are compatible with both CLJ and CLJS?

(Not sure I track why the linter is failing 🤷‍♂️)

@sritchie
Copy link
Collaborator

@Schaechtle in this case let's test with gen.distribution.kixi, that works with both clj and cljs.

The linter's failing because Double/... won't work in cljs. I wish it would be more clear that it's firing for the 'cljs' side of the cljc file!

Instead you could do

(defn- neg-inf?
  [v]
  (= v ##-Inf))

which I believe will work in cljs as well.

@@ -31,8 +35,8 @@
(let [candidate (gf/generate gf args observations)
candidate-model-trace (:trace candidate)
log-weight (:weight candidate)]
(vswap! log-total-weight #(logsumexp [log-weight %]))
(when (dist/bernoulli (math/exp (- log-weight @log-total-weight)))
(when-not (neg-inf? log-weight) (vswap! log-total-weight #(logsumexp [log-weight %])))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about moving the lower when into the when-not block?

(when-not (neg-inf? log-weight)
          (vswap! log-total-weight #(logsumexp [log-weight %]))
          (when (dist/bernoulli (math/exp (- log-weight @log-total-weight)))
            (vreset! model-trace candidate-model-trace)))

Also this makes me thing that dist/bernoulli should really take the log-weight without running math/exp on it again...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

@Schaechtle Schaechtle force-pushed the schaechtle/make-importance-resampling-more-robust branch from 4d2538e to 67e382b Compare April 10, 2024 18:53
@Schaechtle Schaechtle force-pushed the schaechtle/make-importance-resampling-more-robust branch from 67e382b to e6b6fa6 Compare April 10, 2024 18:58
@Schaechtle
Copy link
Collaborator Author

@sritchie, thanks for the quick turnaround and the suggestions. All implemented.

Also this makes me thing that dist/bernoulli should really take the log-weight without running math/exp on it again...

maybe we add a log-bernoulli?

@Schaechtle Schaechtle requested a review from sritchie April 10, 2024 19:01
@sritchie sritchie merged commit 51235ed into main Apr 10, 2024
4 checks passed
@sritchie sritchie deleted the schaechtle/make-importance-resampling-more-robust branch April 10, 2024 19:37
@sritchie
Copy link
Collaborator

Nice work!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants