diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index e896bb8f..6c51756c 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -460,9 +460,10 @@ def sample( Returns ------- az.InferenceData | pm.Approximation - An ArviZ `InferenceData` instance if inference_method is `"mcmc"` - (default), "nuts_numpyro", "nuts_blackjax" or "laplace". An `Approximation` - object if `"vi"`. + A reference to the `model.traces` object, which stores the traces of the + last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData` + instance if `sampler` is `"mcmc"` (default), `"nuts_numpyro"`, + `"nuts_blackjax"` or "`laplace"`, or an `Approximation` object if `"vi"`. """ # If initvals are None (default) # we skip processing initvals here. @@ -537,18 +538,16 @@ def sample( # If sampler is finally `numpyro` make sure # the jitter argument is set to False if sampler == "nuts_numpyro": - if "jitter" not in kwargs.keys(): - kwargs["jitter"] = False - elif kwargs["jitter"]: + if kwargs.get("jitter", None): _logger.warning( "The jitter argument is set to True. " + "This argument is not supported " + "by the numpyro backend. " + "The jitter argument will be set to False." ) - kwargs["jitter"] = False - elif sampler != "nuts_numpyro": - if "jitter" in kwargs.keys(): + kwargs["jitter"] = False + else: + if "jitter" in kwargs: _logger.warning( "The jitter keyword argument is " + "supported only by the nuts_numpyro sampler. \n" @@ -560,27 +559,21 @@ def sample( # If not specified, include the mean prediction in # kwargs to be passed to the model.fit() method kwargs["include_mean"] = True - idata = self.model.fit(inference_method=sampler, init=init, **kwargs) - - if self._inference_obj is None: - self._inference_obj = idata - elif isinstance(self._inference_obj, az.InferenceData): - _logger.info( - "Inference data already exsits. \n" - "Data from this run will overwrite the idata file..." - ) - self._inference_obj.extend(idata, join="right") - else: - raise ValueError( - "The model has an attached inference object under" - + " self._inference_obj, but it is not an InferenceData object." + if self._inference_obj is not None: + _logger.warning( + "The model has already been sampled. Overwriting the previous " + + "inference object. Any previous reference to the inference object " + + "will still point to the old object." ) + self._inference_obj = self.model.fit( + inference_method=sampler, init=init, **kwargs + ) # The parent was previously not part of deterministics --> compute it via # posterior_predictive (works because it acts as the 'mu' parameter # in the GLM as far as bambi is concerned) if self._inference_obj is not None: - if self._parent not in self._inference_obj.posterior.data_vars.keys(): + if self._parent not in self._inference_obj.posterior.data_vars: # self.model.predict(self._inference_obj, kind="mean", inplace=True) # rename 'rt,response_mean' to 'v' so in the traces everything # looks the way it should @@ -595,10 +588,7 @@ def sample( # if parent already in posterior del self._inference_obj.posterior["rt,response_mean"] - # returning copy of traces here to detach from the actual _inference_obj - # attached to the class. Otherise resampling will - # overwrite the 'returned' object leading to unexpected consequences - return deepcopy(self.traces) + return self.traces def sample_posterior_predictive( self, @@ -1150,7 +1140,7 @@ def traces(self) -> az.InferenceData | pm.Approximation: Returns ------- az.InferenceData | pm.Approximation - The trace of the model after sampling. + The trace of the model after the last call to `sample()`. """ if not self._inference_obj: raise ValueError("Please sample the model first.") @@ -1515,9 +1505,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]: bounds=self.bounds, lapse=self.lapse, extra_fields=( - None - if not self.extra_fields - else [deepcopy(self.data[field].values) for field in self.extra_fields] + None if not self.extra_fields else deepcopy(self.extra_fields) ), ) diff --git a/tests/test_hssm.py b/tests/test_hssm.py index 88ca46e3..8e2dce1e 100644 --- a/tests/test_hssm.py +++ b/tests/test_hssm.py @@ -267,3 +267,14 @@ def test_override_default_link(caplog, data_ddm_reg): assert "t" in caplog.records[0].message assert "strange" in caplog.records[0].message + + +def test_resampling(data_ddm): + model = HSSM(data=data_ddm) + sample_1 = model.sample(draws=10, chains=1, tune=0) + assert sample_1 is model.traces + + sample_2 = model.sample(draws=10, chains=1, tune=0) + assert sample_2 is model.traces + + assert sample_1 is not sample_2