GPJAX silently enables JAX 64bit #126
patel-zeel started this conversation in General
-
Thanks for raising this; I was not aware of it. I like your suggestion of just making it a manual specification. We're working towards supporting conjugate gradient methods soon, so 32-bit precision may even be desirable there, as there are no pesky matrix inversions to worry about.
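Making 64-bit mode a manual specification could look roughly like the sketch below: the library only checks the current mode and warns, while the user opts in explicitly. This is an illustrative assumption, not GPJAX's actual API; `x64_enabled` and `warn_if_not_x64` are hypothetical helper names. The only JAX call that touches global state is `jax.config.update("jax_enable_x64", ...)`, and it is made by user code.

```python
import warnings

import jax
import jax.numpy as jnp


def x64_enabled() -> bool:
    """Report whether JAX is currently in 64-bit mode, without changing it."""
    # The default float dtype follows the global jax_enable_x64 flag.
    return bool(jnp.zeros(()).dtype == jnp.float64)


def warn_if_not_x64() -> None:
    """Hypothetical library-side check: warn instead of flipping the flag."""
    if not x64_enabled():
        warnings.warn(
            "JAX is running in 32-bit mode. For 64-bit precision, call "
            'jax.config.update("jax_enable_x64", True) at the start of your program.',
            stacklevel=2,
        )


# User code opts in explicitly; the library never touches the global config.
jax.config.update("jax_enable_x64", True)
warn_if_not_x64()  # no warning: 64-bit mode was requested by the user
print(x64_enabled())
```

Users who prefer 32-bit precision (for example, for the conjugate gradient use case mentioned above) would simply skip the `jax.config.update` call and accept the warning.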
-
If you have the time, would you be willing to open a PR for this, @patel-zeel?
-
Hi,
I was recently comparing my implementation of a method with GPJAX and noticed that the results completely change after I run `import gpjax as gpx`. It took me a while to figure out that this happens because the import changes the precision (from 32-bit to 64-bit). Since `jax.random` behaves differently at different precisions, I was getting completely different results (please refer to the code below). Would it be better not to enable 64-bit by default, to avoid such problems? Another way to handle this might be to use `jax.config.update("jax_enable_x64", True)` in all documentation examples, so that it becomes a habit for GPJAX users.
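A minimal sketch of the effect described in this post, assuming a recent JAX release with its default PRNG implementation. It does not import GPJAX; instead it toggles `jax_enable_x64` directly to mimic what the import reportedly did, and it is a stand-in rather than the code the post refers to. `sample_normal` is a hypothetical helper. With the same key, `jax.random.normal` draws 32 random bits per value in 32-bit mode and 64 bits per value in 64-bit mode, so the samples differ in value, not just in precision.

```python
import jax
import numpy as np


def sample_normal(enable_x64: bool, seed: int = 0, n: int = 3) -> np.ndarray:
    """Hypothetical helper: draw n standard-normal samples under a given x64 setting."""
    # Toggle the global JAX flag directly, mimicking the side effect of the import.
    jax.config.update("jax_enable_x64", enable_x64)
    key = jax.random.PRNGKey(seed)
    # In 64-bit mode the default float dtype is float64, and the sampler
    # consumes 64 random bits per value instead of 32.
    return np.asarray(jax.random.normal(key, (n,)))


x32 = sample_normal(False)
x64 = sample_normal(True)
jax.config.update("jax_enable_x64", False)  # restore the JAX default

print(x32.dtype, x32)
print(x64.dtype, x64)
# Same seed, but the draws differ by far more than float32 rounding error.
print(np.allclose(x32, x64.astype(np.float32), atol=1e-3))
```

Because the flag is global to the process, enabling it from inside any imported module changes every later random draw, which is why the change is easy to miss.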