Skip to content

Commit

Permalink
Pass kwargs to nutpie + create env.yml file (#855)
Browse files Browse the repository at this point in the history
* Pass kwargs to nutpie + create env.yml file

* Add comment, format code, and move environment file to its own directory

* Rerun alternative sampler NB

* Handle case when cores and chains are None for nutpie

* Update bayeux version pin

* Run example notebook and remove old comment from pyproject.toml

* Add pyproject
  • Loading branch information
AlexAndorra authored Dec 21, 2024
1 parent 649a304 commit 27f8136
Show file tree
Hide file tree
Showing 4 changed files with 4,156 additions and 1,163 deletions.
43 changes: 36 additions & 7 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,26 @@
import operator
import traceback
import warnings

from copy import deepcopy
from importlib.metadata import version

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr

from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_observations
from pymc.util import get_default_varnames
from pytensor.tensor.special import softmax

from bambi.backend.inference_methods import inference_methods
from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.links import (
arctan_2,
cloglog,
identity,
inverse_squared,
logit,
probit,
)
from bambi.backend.model_components import (
ConstantComponent,
DistributionalComponent,
Expand Down Expand Up @@ -246,6 +251,17 @@ def _run_mcmc(
import bayeux as bx # pylint: disable=import-outside-toplevel
import jax # pylint: disable=import-outside-toplevel

# pylint: disable=import-outside-toplevel
from pymc.sampling.parallel import (
_cpu_count,
)

# handle case where cores and chains are not provided
if cores is None:
cores = min(4, _cpu_count())
if chains is None:
chains = max(2, cores)

# Set the seed for reproducibility if provided
if random_seed is not None:
if not isinstance(random_seed, int):
Expand All @@ -255,10 +271,20 @@ def _run_mcmc(
jax_seed = jax.random.PRNGKey(np.random.randint(2**31 - 1))

bx_model = bx.Model.from_pymc(self.model)
bx_sampler = operator.attrgetter(sampler_backend)(
bx_model.mcmc # pylint: disable=no-member
# pylint: disable=no-member
bx_sampler = operator.attrgetter(sampler_backend)(bx_model.mcmc)

# We pass 'draws', 'tune', 'chains', and 'cores' because they can be used by some
# samplers. Since those are keyword arguments of `Model.fit()`, they would not
# be passed in the `kwargs` dict.
idata = bx_sampler(
seed=jax_seed,
draws=draws,
tune=tune,
chains=chains,
cores=cores,
**kwargs,
)
idata = bx_sampler(seed=jax_seed, **kwargs)
idata_from = "bayeux"
else:
raise ValueError(
Expand Down Expand Up @@ -494,7 +520,10 @@ def create_posterior_bayeux(posterior, pm_model):
# https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html
data_vars_values = {}
for data_var_name, data_var_dims in data_vars_dims.items():
data_vars_values[data_var_name] = (data_var_dims, posterior[data_var_name].to_numpy())
data_vars_values[data_var_name] = (
data_var_dims,
posterior[data_var_name].to_numpy(),
)

# Get coords
dims_in_use = set(dim for dims in data_vars_dims.values() for dim in dims)
Expand Down
24 changes: 24 additions & 0 deletions conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: bambi-env
channels:
- conda-forge
- defaults
dependencies:
- python>=3.10,<3.13
- arviz>=0.12.0
- formulae>=0.5.3
- graphviz
- pandas>=1.0.0
- pymc>=5.16.1
# Dev dependencies
- black=24.3.0
- ipython>=5.8.0,!=8.7.0
- pre-commit>=2.19
- pylint=3.1.0
- pytest-cov>=2.6.1
- pytest>=4.4.0
- seaborn>=0.9.0
- pip
- watermark
- pip:
- quartodoc==0.6.1
- bayeux-ml==0.1.15 # Optional JAX dependency
Loading

0 comments on commit 27f8136

Please sign in to comment.