diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 2331e54..7bbce4e 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -435,9 +435,9 @@ def compile_pymc_model( *, backend: Literal["numba", "jax"] = "numba", gradient_backend: Literal["pytensor", "jax"] = "pytensor", - overrides: dict[Union["Variable", str], np.ndarray | float | int] | None = None, + initial_points: dict[Union["Variable", str], np.ndarray | float | int] | None = None, jitter_rvs: set["TensorVariable"] | None = None, - default_strategy: Literal["support_point", "prior"] = "prior", + default_initialization_strategy: Literal["support_point", "prior"] = "support_point", **kwargs, ) -> CompiledModel: """Compile necessary functions for sampling a pymc model. @@ -455,10 +455,10 @@ def compile_pymc_model( The set (or list or tuple) of random variables for which a U(-1, +1) jitter should be added to the initial value. Only available for variables that have a transform or real-valued support. - default_strategy : str + default_initialization_strategy : str Which of { "support_point", "prior" } to prefer if the initval setting for an RV is None. - overrides : dict + initial_points : dict Initial value (strategies) to use instead of what's specified in `Model.initial_values`. Returns @@ -475,13 +475,13 @@ def compile_pymc_model( "and restart your kernel in case you are in an interactive session." ) - if default_strategy == "support_point" and jitter_rvs is None: + if default_initialization_strategy == "support_point" and jitter_rvs is None: jitter_rvs = set(model.free_RVs) initial_point_fn = make_initial_point_fn( model=model, - overrides=overrides, - default_strategy=default_strategy, + overrides=initial_points, + default_strategy=default_initialization_strategy, jitter_rvs=jitter_rvs, return_transformed=True, ) diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 8f70699..7c7f597 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -38,6 +38,19 @@ def test_pymc_model_float32(backend, gradient_backend): trace.posterior.a # noqa: B018 +@parameterize_backends +def test_pymc_model_no_prior(backend, gradient_backend): + with pm.Model() as model: + a = pm.Flat("a") + pm.Normal("b", mu=a, observed=0.) + + compiled = nutpie.compile_pymc_model( + model, backend=backend, gradient_backend=gradient_backend + ) + trace = nutpie.sample(compiled, chains=1) + trace.posterior.a # noqa: B018 + + @parameterize_backends def test_blocking(backend, gradient_backend): with pm.Model() as model: