[Pallas] Unable to convert negative values from float16/float32 to int8/int32 in Pallas #25047

Open
shangz-ai opened this issue Nov 21, 2024 · 3 comments
Labels: bug (Something isn't working)

@shangz-ai

Description

Hello,
I'm encountering an issue where converting negative float16 values to int8 inside a Pallas kernel incorrectly produces 0 instead of the expected negative int8 value.
I observe the same behavior when converting float32 to int32.
Is this a known issue in Pallas, or am I missing something?
I'm attaching a small reproducer.

I'm seeing the following output:

===a_int===
[-4 -3 -2 -1  0  1  2  3]
===convert_from_int===
[-4 -3 -2 -1  0  1  2  3]
==================
===a_half===
[-4. -3. -2. -1.  0.  1.  2.  3.]
===convert_from_half===
[0 0 0 0 0 1 2 3]
==================

The int8 input is copied through correctly, but every negative float16 value is converted to 0, which doesn't make sense to me.
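For reference, the same conversion done with plain `jnp` outside of a Pallas kernel gives the values one would expect. This is a minimal sketch that assumes only that `jax` is installed (it runs on CPU). It casts the reproducer's float16 input to int8, and also casts a few fractional float32 values to int32 to show that the conversion truncates toward zero rather than clamping negative values to 0.

```python
import jax.numpy as jnp

# Same input as the reproducer: [-4, -3, ..., 3] stored as float16.
a_half = jnp.arange(8, dtype=jnp.float16) - 4
expected_int8 = a_half.astype(jnp.int8)
print(expected_int8.tolist())  # [-4, -3, -2, -1, 0, 1, 2, 3]

# Fractional values are truncated toward zero (-2.7 -> -2, -0.5 -> 0).
frac = jnp.array([-2.7, -1.5, -0.5, 0.5, 1.5, 2.7], dtype=jnp.float32)
expected_int32 = frac.astype(jnp.int32)
print(expected_int32.tolist())  # [-2, -1, 0, 0, 1, 2]
```

A Pallas kernel that performs the same `astype` should match these values exactly.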

Thanks!

System info (python version, jaxlib version, accelerator, etc.)

from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

def read_convert_kernel(x_ref, o_ref):
  # Load the whole input block and cast it to the output dtype.
  x = x_ref[...]
  o_ref[...] = x.astype(o_ref.dtype)

@partial(jax.jit,static_argnames='out_type')
def read_convert(x: jax.Array, out_type) -> jax.Array:
  return pl.pallas_call(read_convert_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, out_type)
                        )(x)

a_half = jnp.arange(8, dtype=jnp.float16)-4  # [-4., -3., ..., 3.]
a_int = jnp.arange(8, dtype=jnp.int8)-4      # [-4, -3, ..., 3]
out_from_int = read_convert(a_int, out_type=jnp.int8)    # int8 -> int8
out_from_half = read_convert(a_half, out_type=jnp.int8)  # float16 -> int8

print("===a_int===")
print(a_int)
print("===convert_from_int===")
print(out_from_int)
print("==================")

print("===a_half===")
print(a_half)
print("===convert_from_half===")
print(out_from_half)
print("==================")
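One way to check whether the problem is in the GPU lowering rather than in the kernel code is to run the same kernel in Pallas interpret mode. `pallas_call` takes an `interpret` argument that executes the kernel with regular JAX operations instead of compiling it for the accelerator, so it also runs on CPU. The sketch below is a lightly modified copy of the reproducer. The extra `interpret` parameter on `read_convert` is an addition for illustration and not part of the original code.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def read_convert_kernel(x_ref, o_ref):
  # Load the whole input block and cast it to the output dtype.
  o_ref[...] = x_ref[...].astype(o_ref.dtype)


# `interpret` is an illustrative extra flag, not part of the original reproducer.
@partial(jax.jit, static_argnames=("out_type", "interpret"))
def read_convert(x: jax.Array, out_type, interpret: bool = False) -> jax.Array:
  return pl.pallas_call(
      read_convert_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, out_type),
      interpret=interpret,  # True: run the kernel body as plain JAX ops
  )(x)


a_half = jnp.arange(8, dtype=jnp.float16) - 4
interpreted = read_convert(a_half, out_type=jnp.int8, interpret=True)
print(interpreted.tolist())
```

Suppose the interpreted call returns the negative values correctly while the compiled call on the H100 still returns zeros for them. That would point at the GPU lowering of the conversion rather than at the kernel code.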
shangz-ai added the bug (Something isn't working) label Nov 21, 2024
@justinjfu
Collaborator

Hello,

Are you on TPU or GPU? This looks like a bug in Pallas.

justinjfu self-assigned this Nov 25, 2024
@shangz-ai
Author

Thanks for replying. I'm testing on an H100.

@justinjfu
Collaborator

#25134 should fix this issue.

It seems like it was a known failure that was being skipped in the test cases.
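A regression check along these lines could guard against this failure. This is a hypothetical sketch, not the test added in #25134, and the names `pallas_convert`, `CASES` and `check_convert` are illustrative. It compares the Pallas result with a plain `astype` for both dtype pairs reported in this issue, using negative and fractional inputs. On an accelerator it exercises the real lowering, and on a CPU-only machine it falls back to interpret mode so that it still runs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

# Hypothetical helper: use the compiled lowering on GPU/TPU and interpret
# mode on CPU, where Pallas kernels cannot be compiled.
INTERPRET = jax.default_backend() == "cpu"


def convert_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...].astype(o_ref.dtype)


def pallas_convert(x, out_type):
  return pl.pallas_call(
      convert_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, out_type),
      interpret=INTERPRET,
  )(x)


# (input dtype, output dtype) pairs reported in this issue.
CASES = [(jnp.float16, jnp.int8), (jnp.float32, jnp.int32)]


def check_convert():
  for in_type, out_type in CASES:
    # Negative, zero, positive and fractional values (8 elements).
    x = jnp.array([-4.0, -2.5, -1.0, -0.5, 0.0, 0.5, 1.5, 3.0], dtype=in_type)
    expected = x.astype(out_type)  # plain JAX conversion as the reference
    np.testing.assert_array_equal(
        np.asarray(pallas_convert(x, out_type)), np.asarray(expected))
  return True


check_convert()
```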
