jax warnings #106

Open
MickaelRigault opened this issue Jan 3, 2023 · 2 comments

Comments

@MickaelRigault

Hello guys,

I've started using jax_cosmo and I get warnings with the current (pip) release:

/Users/rigault/miniforge3/lib/python3.9/site-packages/jax_cosmo/scipy/interpolate.py:35: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)
/Users/rigault/miniforge3/lib/python3.9/site-packages/jax_cosmo/scipy/interpolate.py:36: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
a = (fp[ind + np.copysign(1, s).astype(np.int64)] - fp[ind]) / (
/Users/rigault/miniforge3/lib/python3.9/site-packages/jax_cosmo/scipy/interpolate.py:37: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
xp[ind + np.copysign(1, s).astype(np.int64)] - xp[ind]

This is likely related to the jax version: I am running jax v0.4.1.

@EiffL
Member

EiffL commented Jan 3, 2023

Hi Mickael, awesome if it could be useful to you :-)
Yes, these warnings come from the code explicitly requesting int64 integers: by default jax only uses int32 and float32, so the request is truncated to int32 unless you set the jax_enable_x64 configuration option (or the JAX_ENABLE_X64 environment variable).
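A minimal sketch of the behaviour described above, assuming a default jax install (x64 mode disabled). It mirrors the `np.sign(...).astype(np.int64)` pattern from the log: the cast records the same UserWarning and the resulting dtype is int32. The helper name `cast_sign_to_int64` is hypothetical, for illustration only.

```python
import warnings

import jax
import jax.numpy as jnp

# To opt in to 64-bit types instead, run this once at startup,
# before any jax arrays are created:
# jax.config.update("jax_enable_x64", True)


def cast_sign_to_int64(x):
    """Hypothetical helper: cast jnp.sign(x) to int64 and record any warnings."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record the warning even if already shown
        s = jnp.sign(x).astype(jnp.int64)
    return s, [str(w.message) for w in caught]


s, messages = cast_sign_to_int64(jnp.array([-0.3, 0.0, 2.0]))
print(s.dtype)  # int32 unless x64 mode is enabled
for m in messages:
    print(m)
```

Setting the environment variable `JAX_ENABLE_X64=1` before starting Python has the same effect as the commented-out config line.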

These messages are completely harmless, but definitely something we should clean up. I'm going to look into this for the upcoming 0.1.0 release.
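One possible cleanup, sketched under assumptions: a hypothetical `interp_linear` that does nearest-neighbour linear interpolation in the spirit of the lines quoted in the log, but builds the ±1 neighbour offset with `jnp.where` and casts it to `jnp.int32`. That dtype is available with or without x64 mode, so no truncation warning is raised. This is not the jax_cosmo implementation; it assumes a scalar query `x` with `xp[1] <= x <= xp[-2]` and a sorted 1-D grid `xp` of at least three points.

```python
import warnings

import jax.numpy as jnp


def interp_linear(x, xp, fp):
    """Hypothetical sketch: linearly interpolate fp(xp) at a scalar x.

    Assumes xp is sorted, 1-D, has at least three points,
    and that xp[1] <= x <= xp[-2].
    """
    # Index of the grid point nearest to x, kept away from both edges
    # so that ind - 1 and ind + 1 are always valid indices.
    ind = jnp.clip(jnp.argmin((x - xp) ** 2), 1, len(xp) - 2)
    xi = xp[ind]
    # Side of the nearest point on which x lies: +1 (right or equal) or -1 (left).
    side = jnp.sign(jnp.clip(x, xp[1], xp[-2]) - xi)
    step = jnp.where(side >= 0, 1, -1).astype(jnp.int32)  # no int64 request
    j = ind + step
    # Straight line through (xp[ind], fp[ind]) and (xp[j], fp[j]).
    a = (fp[j] - fp[ind]) / (xp[j] - xp[ind])
    b = fp[ind] - a * xp[ind]
    return a * x + b


xp = jnp.linspace(0.0, 4.0, 5)
fp = xp**2
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    y = interp_linear(2.5, xp, fp)
print(float(y))  # 6.5: linear interpolation between (2, 4) and (3, 9)
```

Casting to int32 rather than int64 keeps the code warning-free in the default configuration, and under x64 mode the int32 offset is simply promoted when added to the int64 index.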

@MickaelRigault
Author

Hi François,
Yes, I know they are harmless (so far); I just wanted to let you know for future releases.
Cheers
