From caf5ab574905991a89463fc1bfc49b1ebe020960 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 6 Jun 2024 19:03:58 +0200 Subject: [PATCH] fix: Fix random variables with missing values in pymc deterministics --- python/nutpie/compile_pymc.py | 1 + tests/test_pymc.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 5d9a048..b9c28b2 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -202,6 +202,7 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel: raise ValueError(f"Shared variables must have unique names: {val.name}") shared_data[val.name] = val.get_value().copy() shared_vars[val.name] = val + seen.add(val) for val in shared_data.values(): val.flags.writeable = False diff --git a/tests/test_pymc.py b/tests/test_pymc.py index d9ebaa2..d39b06e 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -99,7 +99,8 @@ def test_pymc_model_shared(): def test_missing(): with pm.Model(coords={"obs": range(4)}) as model: mu = pm.Normal("mu") - pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs") + y = pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs") + pm.Deterministic("y2", 2 * y, dims="obs") compiled = nutpie.compile_pymc_model(model) tr = nutpie.sample(compiled, chains=1, seed=1)