-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speedup sample
and allow specifying compile_kwargs
(several major changes related to step samplers)
#7578
Conversation
009959b
to
161b859
Compare
d6f9e14
to
87fd299
Compare
874ae65
to
cb8d51e
Compare
10 minutes seem to be saved in pytest CI time compared to previous runs |
39efbd5
to
d752070
Compare
0c8da27
to
d2f9cf9
Compare
f5be4ab
to
95ce8bc
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7578 +/- ##
==========================================
- Coverage 92.84% 92.84% -0.01%
==========================================
Files 106 106
Lines 17686 17719 +33
==========================================
+ Hits 16421 16451 +30
- Misses 1265 1268 +3
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, left a few small comments
Also initialize empty trace and set `trust_input=True`
Also removes default `model.check_start_vals()`
95ce8bc
to
bd232d2
Compare
compile_kwargs
(several major changes related to step samplers)
compile_kwargs
(several major changes related to step samplers)sample
and allow specifying compile_kwargs
(several major changes related to step samplers)
Major changes
ravel_inputs
is specified explicitly. Eventually it will only be possible to useravel_inputs=True
.assign_step_method
does not callinstantiate_steppers
, but returns arguments needed for the latter.compile_kwargs
topm.sample
which is then forwarded to the step samplers functionsEnhancement
This PR speedups
NUTS
(and other step samplers), by:trust_input=True
which can have a large overhead.This PR speedups sample by:
init_nuts
. This will also reduce the path towards external samplers with nutpie/numpyro as it avoids the costly and useless compilation of the logp_dlogp_functiontrust_input
and avoiding deepcopies in the trace function by usingpytensor.In(borrow=True)
andpytensor.Out(borrow=True)
.Further speedups should come for free from #7539, specially for the Numba backend.
Benchmark
In the example below, sampling time is now only 7x slower than nutpie (5s vs 0.7s), compared to 13.5x slower (9.45s vs 0.7s) before. This assuming the same number of logp evals, in fact nutpie tuning allows us to get out with half the evals! We can hopefully bring it over.
Full time until from
pm.sample
to getting a trace is roughly halved as well (7.5s vs 14.4s), although this gain is not proportional to the number of draws.With
compile_kwargs=(mode="NUMBA")
, sampling time is only 3x slower (2.3s).📚 Documentation preview 📚: https://pymc--7578.org.readthedocs.build/en/7578/