From 5b901594959e11c416614239f3dc15f16275d107 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 8 Jul 2024 15:49:05 -0400 Subject: [PATCH 1/5] Fixed logic about resampling --- src/hssm/hssm.py | 43 +++++++++++++++---------------------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index e896bb8f..e35a7c38 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -537,18 +537,16 @@ def sample( # If sampler is finally `numpyro` make sure # the jitter argument is set to False if sampler == "nuts_numpyro": - if "jitter" not in kwargs.keys(): - kwargs["jitter"] = False - elif kwargs["jitter"]: + if kwargs.get("jitter", None): _logger.warning( "The jitter argument is set to True. " + "This argument is not supported " + "by the numpyro backend. " + "The jitter argument will be set to False." ) - kwargs["jitter"] = False - elif sampler != "nuts_numpyro": - if "jitter" in kwargs.keys(): + kwargs["jitter"] = False + else: + if "jitter" in kwargs: _logger.warning( "The jitter keyword argument is " + "supported only by the nuts_numpyro sampler. \n" @@ -560,27 +558,21 @@ def sample( # If not specified, include the mean prediction in # kwargs to be passed to the model.fit() method kwargs["include_mean"] = True - idata = self.model.fit(inference_method=sampler, init=init, **kwargs) - - if self._inference_obj is None: - self._inference_obj = idata - elif isinstance(self._inference_obj, az.InferenceData): - _logger.info( - "Inference data already exsits. \n" - "Data from this run will overwrite the idata file..." - ) - self._inference_obj.extend(idata, join="right") - else: - raise ValueError( - "The model has an attached inference object under" - + " self._inference_obj, but it is not an InferenceData object." + if self._inference_obj is not None: + _logger.warning( + "The model has already been sampled. Overwriting the previous " + + "inference object. Any previous reference to the inference object " + + "will still point to the old object." ) + self._inference_obj = self.model.fit( + inference_method=sampler, init=init, **kwargs + ) # The parent was previously not part of deterministics --> compute it via # posterior_predictive (works because it acts as the 'mu' parameter # in the GLM as far as bambi is concerned) if self._inference_obj is not None: - if self._parent not in self._inference_obj.posterior.data_vars.keys(): + if self._parent not in self._inference_obj.posterior.data_vars: # self.model.predict(self._inference_obj, kind="mean", inplace=True) # rename 'rt,response_mean' to 'v' so in the traces everything # looks the way it should @@ -595,10 +587,7 @@ def sample( # if parent already in posterior del self._inference_obj.posterior["rt,response_mean"] - # returning copy of traces here to detach from the actual _inference_obj - # attached to the class. Otherise resampling will - # overwrite the 'returned' object leading to unexpected consequences - return deepcopy(self.traces) + return self.traces def sample_posterior_predictive( self, @@ -1515,9 +1504,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: bounds=self.bounds, lapse=self.lapse, extra_fields=( - None - if not self.extra_fields - else [deepcopy(self.data[field].values) for field in self.extra_fields] + None if not self.extra_fields else deepcopy(self.extra_fields) ), ) From 3b9a74cf345c743ab10b582bc77daa472bf21225 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Mon, 8 Jul 2024 15:49:30 -0400 Subject: [PATCH 2/5] Added tests to ensure resampling behavior --- tests/test_hssm.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 88ca46e3..8e2dce1e 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -267,3 +267,14 @@ def test_override_default_link(caplog, data_ddm_reg): assert "t" in caplog.records[0].message assert "strange" in caplog.records[0].message + + +def test_resampling(data_ddm): + model = HSSM(data=data_ddm) + sample_1 = model.sample(draws=10, chains=1, tune=0) + assert sample_1 is model.traces + + sample_2 = model.sample(draws=10, chains=1, tune=0) + assert sample_2 is model.traces + + assert sample_1 is not sample_2 From d666a0f3e1728097fdf55507a65094380ce97402 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 12 Jul 2024 08:32:18 -0400 Subject: [PATCH 3/5] update docstring to clarify behavior --- src/hssm/hssm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index e35a7c38..7b2ecb53 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -460,9 +460,9 @@ def sample( Returns ------- az.InferenceData | pm.Approximation - An ArviZ `InferenceData` instance if inference_method is `"mcmc"` - (default), "nuts_numpyro", "nuts_blackjax" or "laplace". An `Approximation` - object if `"vi"`. + A reference to the `model.traces` object, which is an ArviZ `InferenceData` + instance if inference_method is `"mcmc"` (default), `"nuts_numpyro"`, + `"nuts_blackjax"` or "`laplace"`, or an `Approximation` object if `"vi"`. """ # If initvals are None (default) # we skip processing initvals here. From 3ed3418c1aace0bdb4e3d9758d6fbe3ea6d5a023 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 12 Jul 2024 08:41:20 -0400 Subject: [PATCH 4/5] update docstring to clarify behavior --- src/hssm/hssm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 7b2ecb53..a4dbff04 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -1139,7 +1139,7 @@ def traces(self) -> az.InferenceData | pm.Approximation: Returns ------- az.InferenceData | pm.Approximation - The trace of the model after sampling. + The trace of the model after the last call to `sample()`. """ if not self._inference_obj: raise ValueError("Please sample the model first.") From dcc208c55d4fc89b9bc8a89382ee967f6141e811 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Fri, 12 Jul 2024 08:45:53 -0400 Subject: [PATCH 5/5] update docstring to clarify behavior --- src/hssm/hssm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index a4dbff04..6c51756c 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -460,8 +460,9 @@ def sample( Returns ------- az.InferenceData | pm.Approximation - A reference to the `model.traces` object, which is an ArviZ `InferenceData` - instance if inference_method is `"mcmc"` (default), `"nuts_numpyro"`, + A reference to the `model.traces` object, which stores the traces of the + last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData` + instance if `sampler` is `"mcmc"` (default), `"nuts_numpyro"`, `"nuts_blackjax"` or "`laplace"`, or an `Approximation` object if `"vi"`. """ # If initvals are None (default)