(WIP?) vectorized log_likelihood function for NumPyro #2390

Open
wants to merge 1 commit into
base: main

@aporsch1 aporsch1 commented Oct 3, 2024

Description

Checklist

  • Follows official PR format
  • New features are properly documented
  • Code style correct (follows pylint and black guidelines)

📚 Documentation preview 📚: https://arviz--2390.org.readthedocs.build/en/2390/

@aporsch1
Author

aporsch1 commented Oct 3, 2024

Hey, I looked at the checks that failed, and they are failing because they can't even find test cases. I don't think that is related to the updated code at all? Let me know if I am missing something, though.

@OriolAbril
Member

@virajpandya could you try it out and see how timing compares to the ~80 mins from the latest release and setting log_likelihood=False?

You can install the arviz version of this PR with:

pip install "arviz @ git+https://github.com/aporsch1/arviz"

@OriolAbril OriolAbril changed the title from "(WIP?) vectorized log_likelihood function for NumPyro (https://github.com/arviz-devs/arviz/issues/2373)" to "(WIP?) vectorized log_likelihood function for NumPyro" Oct 7, 2024
@OriolAbril OriolAbril linked an issue Oct 7, 2024 that may be closed by this pull request
@OriolAbril
Member

Hey, I looked at the checks that failed, and they are failing because they can't even find test cases. I don't think that is related to the updated code at all? Let me know if I am missing something, though.

The pylint checks are failing. These are the specific errors:

************* Module arviz.data.io_numpyro
arviz/data/io_numpyro.py:195:64: C0303: Trailing whitespace (trailing-whitespace)
arviz/data/io_numpyro.py:196:0: C0301: Line too long (105/100) (line-too-long)
arviz/data/io_numpyro.py:195:34: E0602: Undefined variable 'jax' (undefined-variable)

For the jax import, note that jax is not a dependency of ArviZ (nor should it be), so it needs to be imported at runtime from inside the method itself. This is already done in the __init__ method, for example: https://github.com/arviz-devs/arviz/blob/main/arviz/data/io_numpyro.py#L67
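
A minimal sketch of the deferred-import pattern described above. The class name `NumPyroConverterSketch` and the placeholder return value are hypothetical stand-ins, not the actual ArviZ code; the only point is that `import jax` sits inside the method, so the module can be imported without jax installed.

```python
class NumPyroConverterSketch:
    """Hypothetical stand-in for the converter class in io_numpyro.py."""

    def __init__(self, *, log_likelihood=True):
        self.log_likelihood = log_likelihood

    def log_likelihood_to_xarray(self):
        if not self.log_likelihood:
            return None
        # Deferred import: jax is only needed once this branch is reached,
        # so importing this module never requires jax to be installed.
        import jax

        # Placeholder for the real computation (assumption for illustration).
        return jax.__version__


converter = NumPyroConverterSketch(log_likelihood=False)
result = converter.log_likelihood_to_xarray()  # returns before the jax import
```

With `log_likelihood=False` the method returns before the import is ever executed, which is why this layout also satisfies pylint without making jax a hard dependency.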

@lucifer4073

You might try this in your terminal.

black arviz/ examples/ asv_benchmarks/

This will format the code according to the project's style. Once done, stage the changes again (git add -u) and commit them.
Let me know how it goes.

Member

@OriolAbril OriolAbril left a comment

I tried testing this locally on a variation of the model in https://python.arviz.org/en/stable/getting_started/CreatingInferenceData.html#from-numpyro but with random y and sigma of 30k elements each, and generating 2k posterior samples.

The version with vmap (after the fixes mentioned in the review) and the current version took basically the same time. The log_likelihood function itself in numpyro calls a soft_vmap so there might not even be any difference between using vmap directly on our side or calling numpyro directly.

I still crashed my computer multiple times with both versions when I tried running things in a loop to get average timings, which makes me suspect there are memory leaks somewhere in the process; these might even be the reason for the slowness.

I am sorry, but I don't think it makes sense to merge this before we have reproducible models that take extremely long with the current version yet run fast with this vmap version.
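
The repeated-timing experiment mentioned above could be sketched with the standard library as follows. The helper `profile_repeated` and the `workload` function are hypothetical; `workload` stands in for the conversion call being measured. Note that `tracemalloc` only sees allocations made through Python's allocators, so memory held in buffers that JAX manages outside of Python would not show up here.

```python
import gc
import time
import tracemalloc


def profile_repeated(fn, repeats=5):
    """Run fn() several times; return (mean seconds, list of peak traced bytes)."""
    durations, peaks = [], []
    for _ in range(repeats):
        gc.collect()  # start each run from a comparable state
        tracemalloc.start()
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        peaks.append(peak)
    return sum(durations) / repeats, peaks


def workload():
    # Hypothetical stand-in for the conversion call being timed.
    return [i * i for i in range(50_000)]


mean_seconds, peaks = profile_repeated(workload, repeats=3)
print(f"mean: {mean_seconds:.4f}s, peak bytes per run: {peaks}")
```

Comparing this per-run peak with the overall memory use of the process across iterations is one way to check whether memory is accumulating between runs.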

@@ -181,20 +181,26 @@ def sample_stats_to_xarray(self):
     @requires("posterior")
     @requires("model")
     def log_likelihood_to_xarray(self):
-        """Extract log likelihood from NumPyro posterior."""
+        """Extract log likelihood from NumPyro posterior using vectorization."""
         if not self.log_likelihood:
             return None
Member

Suggested change
-            return None
+            return None
+        import jax


+        # Vectorized log likelihood calculation using jax.vmap
+        log_likelihood_dict = jax.vmap(lambda single_sample:
+            self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, **self._kwargs)
Member

Suggested change
-            self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, **self._kwargs)
+            self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, batch_ndims=0, **self._kwargs)

It doesn't work without this: batching is now taken care of directly by vmap, but this function also expects a batch dimension by default and fails when it is not there (or when it seemingly changes between the different variables).

Successfully merging this pull request may close these issues.

Log likelihood computation in numpyro can be extremely slow
3 participants