diff --git a/arviz/data/io_numpyro.py b/arviz/data/io_numpyro.py index f8fc098a68..71745e0467 100644 --- a/arviz/data/io_numpyro.py +++ b/arviz/data/io_numpyro.py @@ -181,20 +181,26 @@ def sample_stats_to_xarray(self): @requires("posterior") @requires("model") def log_likelihood_to_xarray(self): - """Extract log likelihood from NumPyro posterior.""" + """Extract log likelihood from NumPyro posterior using vectorization.""" if not self.log_likelihood: return None + data = {} if self.observations is not None: samples = self.posterior.get_samples(group_by_chain=False) if hasattr(samples, "_asdict"): samples = samples._asdict() - log_likelihood_dict = self.numpyro.infer.log_likelihood( - self.model, samples, *self._args, **self._kwargs - ) + + # Vectorized log likelihood calculation using jax.vmap + log_likelihood_dict = jax.vmap(lambda single_sample: + self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, **self._kwargs) + )(samples) + + # Process the log likelihood results for obs_name, log_like in log_likelihood_dict.items(): shape = (self.nchains, self.ndraws) + log_like.shape[1:] data[obs_name] = np.reshape(np.asarray(log_like), shape) + return dict_to_dataset( data, library=self.numpyro,