
[Pallas] Support XLA_USE_BF16 #6817

Merged: 2 commits merged into master from alanwaketan/pallas_bf16 on Mar 26, 2024
Conversation

alanwaketan (Collaborator) commented:

Summary:
XLA_USE_BF16=1 makes all internal XLA tensors use BF16, but the torch.Tensor wrappers still report torch.float. To address this, we need to set the JAX tracer dtypes correctly so that the correct Mosaic is produced.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_bf16

alanwaketan requested a review from JackCaoG on March 25, 2024 at 21:48
alanwaketan self-assigned this on Mar 25, 2024
@@ -74,9 +75,14 @@ def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
import jax._src.pallas.mosaic.pallas_call_registration

def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
    XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"
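For context only, here is a minimal sketch of how a conversion like this could honor the flag; the branch structure and the set of dtypes handled are illustrative assumptions, not the exact contents of this diff:

```python
# Illustrative sketch, not the exact diff: map a torch dtype to the JAX dtype
# the Pallas tracer should see, overriding float32 when XLA_USE_BF16=1 because
# the underlying XLA tensors are BF16 even though the wrappers say torch.float.
import os

import jax.numpy as jnp
import torch


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  use_bf16 = os.environ.get("XLA_USE_BF16", "0") == "1"
  if dtype == torch.float32:
    # Under XLA_USE_BF16=1 the tracer must see bfloat16 so that the
    # generated Mosaic matches the actual on-device dtype.
    return jnp.bfloat16 if use_bf16 else jnp.float32
  if dtype == torch.bfloat16:
    return jnp.bfloat16
  if dtype == torch.int32:
    return jnp.int32
  raise ValueError(f"Unsupported dtype: {dtype}")
```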
JackCaoG (Collaborator) commented on this line:

I think this should work in most cases, but the correct thing to do is to expose this static var:

static bool use_bf16 = ShouldUseBF16();

Otherwise there is a risk of XLA_USE_BF16 being changed during runtime...

alanwaketan (Collaborator, Author) replied:

Let me move this to the global scope to mimic the same effect.
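Roughly, that would look like the sketch below: the env var is sampled once at import time, mirroring how the C++ static is initialized only once (function body elided):

```python
# Sketch of the proposed change: read XLA_USE_BF16 once at module scope so it
# cannot flip mid-run, mimicking the C++ `static bool use_bf16`.
import os

import jax.numpy as jnp
import torch

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0") == "1"


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  # The dtype mapping itself is unchanged; it now reads the module-level
  # constant instead of re-reading the environment on every call.
  ...
```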

JackCaoG (Collaborator) replied:

Oh, what I mean is that we should expose a pybind to call UseBF16. That way we won't have a case where the C++ static variable is false while the Python static variable is true.
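For illustration, the pybind route could look roughly like this on the Python side; `_xla_use_bf16` is a made-up binding name used only to show the idea, not an existing torch_xla API:

```python
# Hypothetical sketch: query the C++ side for its cached static instead of
# re-reading the env var in Python, so the two layers can never disagree.
import torch_xla

# NOTE: `_xla_use_bf16` is a made-up binding name; it stands in for a pybind
# that would return the value of the C++ `static bool use_bf16`.
XLA_USE_BF16 = torch_xla._XLAC._xla_use_bf16()
```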

alanwaketan (Collaborator, Author) replied:

That seems like overkill given the source is the same env var...

JackCaoG (Collaborator) replied:

Well, I am fine with the current approach, knowing it's unlikely for us to run into this issue in real life.

alanwaketan (Collaborator, Author) replied:

Thanks, Jack!

alanwaketan force-pushed the alanwaketan/pallas_bf16 branch from 860a296 to 26e87d7 on March 26, 2024 at 01:14
alanwaketan merged commit 73c31db into master on Mar 26, 2024
18 checks passed
alanwaketan deleted the alanwaketan/pallas_bf16 branch on March 26, 2024 at 18:02