[Pallas] Set the major-minor layout for inputs and outputs (pytorch#6826)

Summary:
Mosaic only accepts the major-to-minor layout for both its inputs and outputs, so we need to enforce those layouts by setting the expected input and output shapes in xla::CustomCallWithLayout. After this change, XLA_TPU_LAYOUT is no longer needed.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py
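
For context, "major-to-minor" here is XLA's name for the default C-order (row-major) layout, where the last dimension is minor-most, i.e. minor_to_major = {rank-1, ..., 0}. A minimal standalone sketch (not part of this commit; the header paths and the xla::ShapeUtil helper are assumptions based on an OpenXLA checkout) that builds such a shape:

// Illustration only: the layout this commit pins for Mosaic inputs/outputs.
#include <iostream>

#include "xla/shape.h"
#include "xla/shape_util.h"

int main() {
  // An f32[8, 128] shape with the default C-order (major-to-minor) layout:
  // minor_to_major = {1, 0}, so the last dimension varies fastest.
  xla::Shape shape =
      xla::ShapeUtil::MakeShapeWithDescendingLayout(xla::F32, {8, 128});
  // Prints "f32[8,128]{1,0}"; the "{1,0}" suffix is the minor-to-major order.
  std::cout << shape.ToString(/*print_layout=*/true) << std::endl;
  return 0;
}
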
alanwaketan authored Mar 27, 2024
1 parent d6fb539 commit d99ceb6
Showing 4 changed files with 12 additions and 18 deletions.
5 changes: 0 additions & 5 deletions configuration.yaml
@@ -319,11 +319,6 @@ variables:
dot.
type: string
default_value: ""
- XLA_TPU_LAYOUT:
- description:
- - Determine to use TPU layout or not, where it will use sorted layout for TPU.
- type: bool
- default_value: true
PT_XLA_DEBUG_FILE:
description:
- If set, filepath used for printing out reports.
8 changes: 0 additions & 8 deletions test/test_pallas.py
@@ -64,9 +64,6 @@ def test_tpu_custom_call_pallas_raise(self):
output.cpu()

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
- # Mosiac is not compatible with our sorted layout that boosts performance for dim > 2 tensor input applications, like resnet.
- # For LLM, it should be fine since all inputs are 2D.
- @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
def test_tpu_custom_call_pallas_flash_attention(self):
# This payload is generated by the following Pallas code:
# https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
@@ -178,7 +175,6 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
- @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
def test_tpu_custom_call_pallas_wrap_flash_attention(self):
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
@@ -204,7 +200,6 @@ def attention(q, k, v):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
- @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
def test_flash_attention_wrapper(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention
@@ -226,7 +221,6 @@ def attention(q, k, v):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
- @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
def test_flash_attention_wrapper_causal(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention
@@ -250,7 +244,6 @@ def attention(q, k, v):

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
- @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
@unittest.mock.patch.dict(os.environ, {"XLA_USE_BF16": "1"})
def test_flash_attention_wrapper_bf16(self):
from torch_xla.experimental.custom_kernel import flash_attention
@@ -265,7 +258,6 @@ def test_flash_attention_wrapper_bf16(self):

if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
- # TODO: do we want to set the following flags?
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
4 changes: 1 addition & 3 deletions torch_xla/csrc/layout_manager.cpp
@@ -185,9 +185,7 @@ xla::Shape MakeArrayShapeFromDimensions(
return MakeShapeWithLayout(type, dimensions, dynamic_dimensions,
*layout_ptr);
}
-
- bool tpu_layout_env = runtime::sys_util::GetEnvBool("XLA_TPU_LAYOUT", true);
- if (tpu_layout_env && dimensions.size() > 1 && CheckTpuDevice(hw_type)) {
+ if (dimensions.size() > 1 && CheckTpuDevice(hw_type)) {
return MakeTpuShape(dimensions, dynamic_dimensions, type);
}
return MakeTorchTensorLayout(dimensions, dynamic_dimensions, type);
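
With the XLA_TPU_LAYOUT gate removed, this branch now unconditionally uses the TPU "sorted" layout for rank > 1 tensors on TPU devices and the Torch C-order layout everywhere else; the custom-call change below then re-pins the layout per operand. As a conceptual sketch only (not torch_xla code; the actual permutation chosen by MakeTpuShape is not reproduced here), this is what a non-default dimension ordering looks like next to the C-order default:

// Conceptual sketch; header paths assume an OpenXLA checkout.
#include <iostream>

#include "xla/layout_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"

int main() {
  // Default C-order layout for a rank-3 tensor: minor_to_major = {2, 1, 0}.
  xla::Shape c_order =
      xla::ShapeUtil::MakeShapeWithDescendingLayout(xla::F32, {4, 8, 128});

  // The same logical shape with a hypothetical device-preferred permutation.
  xla::Shape permuted = xla::ShapeUtil::MakeShape(xla::F32, {4, 8, 128});
  *permuted.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1, 2});

  std::cout << c_order.ToString(/*print_layout=*/true) << "\n"    // f32[4,8,128]{2,1,0}
            << permuted.ToString(/*print_layout=*/true) << "\n";  // f32[4,8,128]{0,1,2}
  return 0;
}
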
13 changes: 11 additions & 2 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -10,6 +10,7 @@
#include "torch_xla/csrc/data_ops.h"
#include "torch_xla/csrc/elementwise.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/layout_manager.h"
#include "torch_xla/csrc/random.h"
#include "torch_xla/csrc/reduction.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
@@ -1252,16 +1253,24 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
const xla::Shape& output_shape,
const std::string& payload) {
+ // We need to enforce the default C-order (major-to-minor) layouts for inputs
+ // to Mosaic and outputs from Mosaic.
std::vector<xla::Shape> input_shapes;
input_shapes.reserve(inputs.size());
for (const auto& input : inputs) {
input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
xla::Shape shape = ShapeHelper::ShapeOfXlaOp(input);
input_shapes.push_back(MakeTorchTensorLayout(
shape.dimensions(), shape.dynamic_dimensions(), shape.element_type()));
}
+ xla::Shape output_shape_impl = MakeTorchTensorLayout(
+ output_shape.dimensions(), output_shape.dynamic_dimensions(),
+ output_shape.element_type());

XLA_CHECK(inputs.size() > 0) << "inputs are empty";
return xla::CustomCallWithLayout(inputs[0].builder(),
/*call_target_name=*/"tpu_custom_call",
- inputs, output_shape, input_shapes, payload);
+ inputs, output_shape_impl, input_shapes,
+ payload);
}

} // namespace torch_xla
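
For orientation, a hedged, self-contained sketch of the same pattern using plain XLA client APIs, assuming a single f32[8, 128] operand with an identically shaped result; xla::ShapeUtil::MakeShapeWithDescendingLayout stands in for torch_xla's MakeTorchTensorLayout, and the builder/payload handling is simplified:

// Standalone sketch, not the commit's code. Header paths assume OpenXLA.
#include <string>

#include "xla/client/xla_builder.h"
#include "xla/shape_util.h"

xla::XlaOp BuildExampleTpuCustomCall(xla::XlaBuilder* builder,
                                     const std::string& payload) {
  // One f32[8, 128] operand, pinned to the C-order (major-to-minor) layout.
  xla::Shape operand_shape =
      xla::ShapeUtil::MakeShapeWithDescendingLayout(xla::F32, {8, 128});
  xla::XlaOp input = xla::Parameter(builder, /*parameter_number=*/0,
                                    operand_shape, "input");

  // The result shape is pinned to the same layout, which is what Mosaic
  // expects for the "tpu_custom_call" target.
  xla::Shape result_shape =
      xla::ShapeUtil::MakeShapeWithDescendingLayout(xla::F32, {8, 128});

  return xla::CustomCallWithLayout(builder,
                                   /*call_target_name=*/"tpu_custom_call",
                                   /*operands=*/{input},
                                   /*shape_with_layout=*/result_shape,
                                   /*operand_shapes_with_layout=*/{operand_shape},
                                   /*opaque=*/payload);
}
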
