Skip to content

Commit

Permalink
fix: Fix random variables with missing values in pymc deterministics
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt authored and twiecki committed Jun 6, 2024
1 parent 150dcee commit 34d151a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
raise ValueError(f"Shared variables must have unique names: {val.name}")
shared_data[val.name] = val.get_value().copy()
shared_vars[val.name] = val
seen.add(val)

for val in shared_data.values():
val.flags.writeable = False
Expand Down
3 changes: 2 additions & 1 deletion tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def test_pymc_model_shared():
def test_missing():
with pm.Model(coords={"obs": range(4)}) as model:
mu = pm.Normal("mu")
pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs")
y = pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs")
pm.Deterministic("y2", 2 * y, dims="obs")

compiled = nutpie.compile_pymc_model(model)
tr = nutpie.sample(compiled, chains=1, seed=1)
Expand Down

0 comments on commit 34d151a

Please sign in to comment.