Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare for 0.10 release #97

Merged
merged 5 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 114 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,130 @@

All notable changes to this project will be documented in this file.

## [0.8.0] - 2023-08-06
## [0.10.0] - 2024-03-20

### Documentation

- Mention non-blocking sampling in readme (Adrian Seyboldt)


### Features

- Allow sampling in the backgound (Adrian Seyboldt)

- Implement check if background sampling is complete (Adrian Seyboldt)

- Implement pausing and unpausing of samplers (Adrian Seyboldt)

- Filter warnings and compile through pymc (Adrian Seyboldt)


### Miscellaneous Tasks

- Bump actions/setup-python from 4 to 5 (dependabot[bot])

- Bump uraimo/run-on-arch-action from 2.5.0 to 2.7.1 (dependabot[bot])

- Bump actions/checkout from 3 to 4 (dependabot[bot])

- Bump actions/upload-artifact from 3 to 4 (dependabot[bot])

- Bump the cargo group across 1 directory with 2 updates (dependabot[bot])

- Some major version bumps in rust deps (Adrian Seyboldt)

- Bump dependency versions (Adrian Seyboldt)

- Update changelog (Adrian Seyboldt)

- Bump version (Adrian Seyboldt)


### Performance

- Set the number of parallel chains dynamically (Adrian Seyboldt)


## [0.9.2] - 2024-02-19

### Bug Fixes

- Allow dims with only length specified (Adrian Seyboldt)


### Documentation

- Update suggested mamba commands in README (#70) (Ben Mares)

- Fix README typo bridgestan→nutpie (#69) (Ben Mares)


### Features

- Handle missing libraries more robustly (#72) (Ben Mares)


### Miscellaneous Tasks

- Bump actions/download-artifact from 3 to 4 (dependabot[bot])


### Ci

- Make sure the local nutpie is installed (Adrian Seyboldt)

- Install local nutpie package in all jobs (Adrian Seyboldt)


## [0.9.0] - 2023-09-12

### Bug Fixes

- Better error context for init point errors (Adrian Seyboldt)


### Features

- Improve error message by providing context (Adrian Seyboldt)

- Use standard normal to initialize chains (Adrian Seyboldt)


### Miscellaneous Tasks

- Update deps (Adrian Seyboldt)

- Rename stan model transpose function (Adrian Seyboldt)

- Update nutpie (Adrian Seyboldt)


### Styling

- Fix formatting (Adrian Seyboldt)


## [0.8.0] - 2023-08-18

### Bug Fixes

- Initialize points in uniform(-2, 2) (Adrian Seyboldt)

- Multidimensional stan variables were stored in incorrect order (Adrian Seyboldt)

- Fix with_coords for stan models (Adrian Seyboldt)


### Miscellaneous Tasks

- Update deps (Adrian Seyboldt)

- Update deps (Adrian Seyboldt)

- Bump version (Adrian Seyboldt)

- Update deps (Adrian Seyboldt)


## [0.7.0] - 2023-07-21

Expand Down
50 changes: 23 additions & 27 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "nutpie"
version = "0.9.2"
version = "0.10.0"
authors = [
"Adrian Seyboldt <[email protected]>",
"PyMC Developers <[email protected]>"
Expand All @@ -21,7 +21,7 @@ name = "_lib"
crate-type = ["cdylib"]

[dependencies]
nuts-rs = "0.7.0"
nuts-rs = "0.8.0"
numpy = "0.20.0"
ndarray = "0.15.6"
rand = "0.8.5"
Expand Down
38 changes: 27 additions & 11 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import itertools
import warnings
from dataclasses import dataclass
from importlib.util import find_spec
from math import prod
Expand Down Expand Up @@ -200,12 +201,26 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
logp_numba_raw, c_sig = _make_c_logp_func(
n_dim, logp_fn, user_data, shared_logp, shared_data
)
logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Cannot cache compiled function .* as it uses dynamic globals",
category=numba.NumbaWarning,
)

logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)

expand_numba_raw, c_sig_expand = _make_c_expand_func(
n_dim, n_expanded, expand_fn, user_data, shared_expand, shared_data
)
expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Cannot cache compiled function .* as it uses dynamic globals",
category=numba.NumbaWarning,
)

expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)

coords = {}
for name, vals in model.coords.items():
Expand Down Expand Up @@ -276,6 +291,7 @@ def _make_functions(model):
import pytensor
import pytensor.link.numba.dispatch
import pytensor.tensor as pt
from pymc.pytensorf import compile_pymc

shapes = _compute_shapes(model)

Expand Down Expand Up @@ -340,9 +356,8 @@ def _make_functions(model):
(logp, grad) = pytensor.graph_replace([logp, grad], replacements)

# We should avoid compiling the function, and optimize only
logp_fn_pt = pytensor.compile.function.function(
(joined,), (logp, grad), mode=pytensor.compile.NUMBA
)
with model:
logp_fn_pt = compile_pymc((joined,), (logp, grad), mode=pytensor.compile.NUMBA)

logp_fn = logp_fn_pt.vm.jit_fn

Expand All @@ -368,12 +383,13 @@ def _make_functions(model):
num_expanded = count

allvars = pt.concatenate([joined, *[var.ravel() for var in remaining_rvs]])
expand_fn_pt = pytensor.compile.function.function(
(joined,),
(allvars,),
givens=list(replacements.items()),
mode=pytensor.compile.NUMBA,
)
with model:
expand_fn_pt = compile_pymc(
(joined,),
(allvars,),
givens=list(replacements.items()),
mode=pytensor.compile.NUMBA,
)
expand_fn = expand_fn_pt.vm.jit_fn

return (
Expand Down
Loading