Skip to content

Commit

Permalink
Merge pull request #478 from lnccbrown/476-inference-object-samples-d…
Browse files Browse the repository at this point in the history
…ont-get-updated-if-new-samples-are-generated-based-on-an-existing-model

fix bug where resampling doesnt update the traces argument
  • Loading branch information
AlexanderFengler authored Jul 1, 2024
2 parents 39df93a + a0b28d9 commit 6a11c47
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ keywords = ["HSSM", "sequential sampling models", "bayesian", "bayes", "mcmc"]

[tool.poetry.dependencies]
python = ">=3.10,<3.12"
pymc = "^5.14.0"
pymc = ">=5.14.0,<5.15.0"
arviz = "^0.18.0"
onnx = "^1.16.0"
ssm-simulators = "^0.7.2"
Expand Down
12 changes: 10 additions & 2 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,11 @@ def sample(
if self._inference_obj is None:
self._inference_obj = idata
elif isinstance(self._inference_obj, az.InferenceData):
self._inference_obj.extend(idata)
_logger.info(
"Inference data already exsits. \n"
"Data from this run will overwrite the idata file..."
)
self._inference_obj.extend(idata, join="right")
else:
raise ValueError(
"The model has an attached inference object under"
Expand All @@ -590,7 +594,11 @@ def sample(
# drop redundant 'rt,response_mean' variable,
# if parent already in posterior
del self._inference_obj.posterior["rt,response_mean"]
return self.traces

# returning copy of traces here to detach from the actual _inference_obj
# attached to the class. Otherise resampling will
# overwrite the 'returned' object leading to unexpected consequences
return deepcopy(self.traces)

def sample_posterior_predictive(
self,
Expand Down

0 comments on commit 6a11c47

Please sign in to comment.