[Pallas] Support XLA_USE_BF16 #6817
Conversation
@@ -74,9 +75,14 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
  import jax._src.pallas.mosaic.pallas_call_registration


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
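For context, here is a minimal sketch of the kind of mapping this helper performs; only the BF16 override is what the diff adds, and the dtype table below is an illustrative assumption, not copied from the PR:

```python
import os

import jax.numpy as jnp
import torch

# Read once at import time; mirrors the C++ side, which also latches the flag.
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"

# Sketch only: this dtype table is illustrative, not taken from the PR.
_TORCH_TO_JAX = {
    torch.float32: jnp.float32,
    torch.float16: jnp.float16,
    torch.bfloat16: jnp.bfloat16,
    torch.int32: jnp.int32,
    torch.int64: jnp.int64,
}


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  # With XLA_USE_BF16=1 the runtime stores float tensors as BF16 even though
  # the torch.Tensor wrapper still reports torch.float, so the JAX tracer has
  # to see bfloat16 for the generated Mosaic to match the real buffers.
  if XLA_USE_BF16 and dtype in (torch.float32, torch.float64):
    return jnp.bfloat16
  return _TORCH_TO_JAX[dtype]
```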
I think this should work in most cases, but the correct thing to do is to expose this static var (line 54 at 22fe1dc):

    static bool use_bf16 = ShouldUseBF16();

Otherwise there is a risk of XLA_USE_BF16 being changed during runtime...
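For illustration, a hedged sketch of the failure mode being described here; the sequencing below is hypothetical, not taken from the PR:

```python
import os

# The C++ side latches XLA_USE_BF16 into a static bool the first time it is
# queried. If Python re-reads the environment later, the two can disagree.
os.environ["XLA_USE_BF16"] = "0"
# ... torch_xla initializes here; the C++ static `use_bf16` becomes False ...

os.environ["XLA_USE_BF16"] = "1"  # changed at runtime
python_side = os.environ.get("XLA_USE_BF16", "0") == "1"  # True

# python_side now disagrees with the C++ static: the Pallas kernel would be
# traced with bfloat16 while the runtime still stores tensors as float32.
```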
Let me move this to the global scope to mimic the same effect.
Oh, I mean what we should do is expose a pybind to call UseBF16. This way we won't have a case where the C++ static variable is False and the Python static variable is True.
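A hedged sketch of that suggestion from the Python side; the binding name `_xla_get_use_bf16` is hypothetical and does not exist in this diff:

```python
import torch_xla


def _use_bf16() -> bool:
  # Hypothetical pybind that would return the value of the C++ static
  # `use_bf16`, so Python can never disagree with the runtime's choice.
  return torch_xla._XLAC._xla_get_use_bf16()  # assumed name, not a real API
```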
That seems like overkill given the source is the same env var...
Well, I am fine with the current approach, knowing it's unlikely we'll run into this issue in real life.
Thanks, Jack!
Summary:
XLA_USE_BF16=1 will make all the internal XLA tensors use BF16, but torch.Tensor wrappers will still report torch.float. To address this, we need to set the JAX tracers correctly so they produce the correct Mosaic.
Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_bf16
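For reference, a hedged sketch of the kind of check the test plan exercises, assuming the flash_attention wrapper in torch_xla.experimental.custom_kernel and that XLA_USE_BF16=1 is set before torch_xla initializes:

```python
import os
os.environ.setdefault("XLA_USE_BF16", "1")  # must be set before torch_xla loads

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention  # assumed path

device = xm.xla_device()
q = torch.randn(1, 2, 128, 64, device=device)
k = torch.randn(1, 2, 128, 64, device=device)
v = torch.randn(1, 2, 128, 64, device=device)

# The wrappers still report torch.float32, but the underlying XLA buffers are
# BF16; the kernel must therefore be traced with jnp.bfloat16 to match.
o = flash_attention(q, k, v)
print(o.dtype)  # torch.float32 from the wrapper's point of view
```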