Skip to content

Commit

Permalink
feat: Use support_point as default init for pymc
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Nov 14, 2024
1 parent cad8e99 commit 51b0454
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
14 changes: 7 additions & 7 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,9 @@ def compile_pymc_model(
*,
backend: Literal["numba", "jax"] = "numba",
gradient_backend: Literal["pytensor", "jax"] = "pytensor",
overrides: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
initial_points: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
jitter_rvs: set["TensorVariable"] | None = None,
default_strategy: Literal["support_point", "prior"] = "prior",
default_initialization_strategy: Literal["support_point", "prior"] = "support_point",
**kwargs,
) -> CompiledModel:
"""Compile necessary functions for sampling a pymc model.
Expand All @@ -455,10 +455,10 @@ def compile_pymc_model(
The set (or list or tuple) of random variables for which a U(-1, +1)
jitter should be added to the initial value. Only available for
variables that have a transform or real-valued support.
default_strategy : str
default_initialization_strategy : str
Which of { "support_point", "prior" } to prefer if the initval setting
for an RV is None.
overrides : dict
initial_points : dict
Initial value (strategies) to use instead of what's specified in
`Model.initial_values`.
Returns
Expand All @@ -475,13 +475,13 @@ def compile_pymc_model(
"and restart your kernel in case you are in an interactive session."
)

if default_strategy == "support_point" and jitter_rvs is None:
if default_initialization_strategy == "support_point" and jitter_rvs is None:
jitter_rvs = set(model.free_RVs)

initial_point_fn = make_initial_point_fn(
model=model,
overrides=overrides,
default_strategy=default_strategy,
overrides=initial_points,
default_strategy=default_initialization_strategy,
jitter_rvs=jitter_rvs,
return_transformed=True,
)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ def test_pymc_model_float32(backend, gradient_backend):
trace.posterior.a # noqa: B018


@parameterize_backends
def test_pymc_model_no_prior(backend, gradient_backend):
with pm.Model() as model:
a = pm.Flat("a")
pm.Normal("b", mu=a, observed=0.)

compiled = nutpie.compile_pymc_model(
model, backend=backend, gradient_backend=gradient_backend
)
trace = nutpie.sample(compiled, chains=1)
trace.posterior.a # noqa: B018


@parameterize_backends
def test_blocking(backend, gradient_backend):
with pm.Model() as model:
Expand Down

0 comments on commit 51b0454

Please sign in to comment.