Skip to content

Commit

Permalink
should have vectorized log_likelihood function for NumPyro, here
Browse files Browse the repository at this point in the history
  • Loading branch information
aporsch1 committed Oct 3, 2024
1 parent 9411317 commit 1955360
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,20 +181,26 @@ def sample_stats_to_xarray(self):
@requires("posterior")
@requires("model")
def log_likelihood_to_xarray(self):
"""Extract log likelihood from NumPyro posterior."""
"""Extract log likelihood from NumPyro posterior using vectorization."""
if not self.log_likelihood:
return None

data = {}
if self.observations is not None:
samples = self.posterior.get_samples(group_by_chain=False)
if hasattr(samples, "_asdict"):
samples = samples._asdict()
log_likelihood_dict = self.numpyro.infer.log_likelihood(
self.model, samples, *self._args, **self._kwargs
)

# Vectorized log likelihood calculation using jax.vmap
log_likelihood_dict = jax.vmap(lambda single_sample:
self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, **self._kwargs)
)(samples)

# Process the log likelihood results
for obs_name, log_like in log_likelihood_dict.items():
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
data[obs_name] = np.reshape(np.asarray(log_like), shape)

return dict_to_dataset(
data,
library=self.numpyro,
Expand Down

0 comments on commit 1955360

Please sign in to comment.