Skip to content

Commit

Permalink
Merge pull request #485 from lnccbrown/481-reconciliate-re-sampling-b…
Browse files Browse the repository at this point in the history
…ehavior

Reconciliate re-sampling behavior
  • Loading branch information
digicosmos86 authored Jul 12, 2024
2 parents 6a11c47 + dcc208c commit 3558223
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 32 deletions.
52 changes: 20 additions & 32 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,10 @@ def sample(
Returns
-------
az.InferenceData | pm.Approximation
An ArviZ `InferenceData` instance if inference_method is `"mcmc"`
(default), "nuts_numpyro", "nuts_blackjax" or "laplace". An `Approximation`
object if `"vi"`.
A reference to the `model.traces` object, which stores the traces of the
last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData`
instance if `sampler` is `"mcmc"` (default), `"nuts_numpyro"`,
`"nuts_blackjax"` or `"laplace"`, or an `Approximation` object if `"vi"`.
"""
# If initvals are None (default)
# we skip processing initvals here.
Expand Down Expand Up @@ -537,18 +538,16 @@ def sample(
# If sampler is finally `numpyro` make sure
# the jitter argument is set to False
if sampler == "nuts_numpyro":
if "jitter" not in kwargs.keys():
kwargs["jitter"] = False
elif kwargs["jitter"]:
if kwargs.get("jitter", None):
_logger.warning(
"The jitter argument is set to True. "
+ "This argument is not supported "
+ "by the numpyro backend. "
+ "The jitter argument will be set to False."
)
kwargs["jitter"] = False
elif sampler != "nuts_numpyro":
if "jitter" in kwargs.keys():
kwargs["jitter"] = False
else:
if "jitter" in kwargs:
_logger.warning(
"The jitter keyword argument is "
+ "supported only by the nuts_numpyro sampler. \n"
Expand All @@ -560,27 +559,21 @@ def sample(
# If not specified, include the mean prediction in
# kwargs to be passed to the model.fit() method
kwargs["include_mean"] = True
idata = self.model.fit(inference_method=sampler, init=init, **kwargs)

if self._inference_obj is None:
self._inference_obj = idata
elif isinstance(self._inference_obj, az.InferenceData):
_logger.info(
"Inference data already exsits. \n"
"Data from this run will overwrite the idata file..."
)
self._inference_obj.extend(idata, join="right")
else:
raise ValueError(
"The model has an attached inference object under"
+ " self._inference_obj, but it is not an InferenceData object."
if self._inference_obj is not None:
_logger.warning(
"The model has already been sampled. Overwriting the previous "
+ "inference object. Any previous reference to the inference object "
+ "will still point to the old object."
)
self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
)

# The parent was previously not part of deterministics --> compute it via
# posterior_predictive (works because it acts as the 'mu' parameter
# in the GLM as far as bambi is concerned)
if self._inference_obj is not None:
if self._parent not in self._inference_obj.posterior.data_vars.keys():
if self._parent not in self._inference_obj.posterior.data_vars:
# self.model.predict(self._inference_obj, kind="mean", inplace=True)
# rename 'rt,response_mean' to 'v' so in the traces everything
# looks the way it should
Expand All @@ -595,10 +588,7 @@ def sample(
# if parent already in posterior
del self._inference_obj.posterior["rt,response_mean"]

# returning copy of traces here to detach from the actual _inference_obj
# attached to the class. Otherwise resampling will
# overwrite the 'returned' object leading to unexpected consequences
return deepcopy(self.traces)
return self.traces

def sample_posterior_predictive(
self,
Expand Down Expand Up @@ -1150,7 +1140,7 @@ def traces(self) -> az.InferenceData | pm.Approximation:
Returns
-------
az.InferenceData | pm.Approximation
The trace of the model after sampling.
The trace of the model after the last call to `sample()`.
"""
if not self._inference_obj:
raise ValueError("Please sample the model first.")
Expand Down Expand Up @@ -1515,9 +1505,7 @@ def _make_model_distribution(self) -> type[pm.Distribution]:
bounds=self.bounds,
lapse=self.lapse,
extra_fields=(
None
if not self.extra_fields
else [deepcopy(self.data[field].values) for field in self.extra_fields]
None if not self.extra_fields else deepcopy(self.extra_fields)
),
)

Expand Down
11 changes: 11 additions & 0 deletions tests/test_hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,14 @@ def test_override_default_link(caplog, data_ddm_reg):

assert "t" in caplog.records[0].message
assert "strange" in caplog.records[0].message


def test_resampling(data_ddm):
    """Check that each call to `sample()` replaces `model.traces`.

    The object returned by `sample()` must be the very object exposed as
    `model.traces`, and a second call must yield a different object rather
    than the first one.
    """
    # Minimal sampling settings shared by both runs.
    sampling_kwargs = {"draws": 10, "chains": 1, "tune": 0}
    hssm_model = HSSM(data=data_ddm)

    first_traces = hssm_model.sample(**sampling_kwargs)
    assert first_traces is hssm_model.traces

    second_traces = hssm_model.sample(**sampling_kwargs)
    assert second_traces is hssm_model.traces

    # Identity check: the two runs must not share one returned object.
    assert first_traces is not second_traces

0 comments on commit 3558223

Please sign in to comment.