Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pallas] Support Flash Attention #6658

Merged
merged 9 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ variables:
dot.
type: string
default_value: ""
XLA_TPU_LAYOUT:
description:
- Determine to use TPU layout or not, where it will use sorted layout for TPU.
type: bool
default_value: true
PT_XLA_DEBUG_FILE:
description:
- If set, filepath used for printing out reports.
Expand Down Expand Up @@ -409,9 +414,9 @@ variables:
default_value: false
XLA_DISABLE_FUNCTIONALIZATION:
description:
- Setting this to true will disable functionalization, which is a dispatcher
- Setting this to true will disable functionalization, which is a dispatcher
pass in PyTorch that remove views and mutations to produce functional graphs.
This flag's main purpose is to A/B test the impact of functionalization in
This flag's main purpose is to A/B test the impact of functionalization in
your code.
type: bool
default_value: false
30 changes: 30 additions & 0 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1967,6 +1967,36 @@ def test_tpu_custom_call_pallas_raise(self):
torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload)
output.cpu()

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# Mosiac is not compatible with our sorted layout that boosts performance for dim > 2 tensor input applications, like resnet.
# For LLM, it should be fine since all inputs are 2D.
@unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
def test_tpu_custom_call_pallas_flash_attention(self):
# This payload is generated by the following Pallas code:
# https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
# To be noted, set `jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)`` before generating the payload.
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTkuMC4wZ2l0AAFBDQEDBQcJCwEDDQMFAw8FAxEHBxMVFwkLGRsdHyELAyMDrgI+AhsB8wcTCwsPEwsPDxMLCwsLkwsTCw8TDwsLCwsPCwsLDw8LCw8LDw8PDxcTE0MLGwvFC5MLCwsLGxsLGwsbCxsLGxsbGw8PDw8XDwsXDw8LFw8PCxcPDwsXDwsTCw8PFxMfCw8PFyMPEx8LDxcbDw8LDxcLDwsTHwsPFxsFCY15kWEHA1kJBV1JAR8PCxMTFxMTFxcfCxMXIwsBGw8HKx8bBxcjDwsbLy8CYg0fAwMNhwUlBScVj5UdOgJTBSkdI4kdI7UdIxYCBSsFLQUvBTEjEQlBAQAAAAAAAAABAAAAAAAAAIAAAAAAAAAABAAAAAAAAAANGQMDDYUFMxETAAMD4fsREQEFNQU3BTkFOx2/wQU9BT8FQR3PPRXRCQVDBUUBA9cFRx3bSRXdCR3rTRXtCR0GAgoCHSoCUxUuAgkDD1dZFVtfYWMpZSkXZ2lrBUkBCfPz8/cNF2FmZmluZV9tYXA8KGQwLCBkMSwgZDIsIGQzKSAtPiAoZDAsIGQxLCBkMiwgZDMpPgAFSyMRCUEDAAAAAAAAAAIAAAAAAAAAAQAAAAAAAAABAAAAAAAAAAVNBU8FUQVTAQltcXV5AwUZbxsdCSsDBRlzGx0JLQMFGXcbHQkvAwUZexsdCTEDBRUfFysDBRUfFy0DBRUfFy8DBRUfFzERAQERAwEViwkdB40XBRoIAR2RkwVVFwVKBQEVl50dmZsFVxcFqgsBFZ+lHaGjBVkXBWIDARWnrR2pqwVbFwUaAwEdr7EFXRezZQEFXxW3CR0HuRcFHggBAwMNvSUHCQAAAAAFYRXDCR0HxRcFIggBAwc19TclOckREwEDAw3NJQ0JAACA/wVjHQfTFwW2CAEDBT/9QUMREQUdRT0FZR0H3xcFuggBBWcd5UkFaQMDDeklDQkAAAAABWsdB+8XBb4IAQMFP/9BQyN0cHUuZGltZW5zaW9uX3NlbWFudGljczxwYXJhbGxlbD4AI3RwdS5jb250cmFjdF9wcmVjaXNpb248ZnAzMj4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI3ZlY3Rvci5raW5kPGFkZD4AHUVNBW0VDgIJHQcSAhcFwggBFRoCCR0HHgIXBd4IAQMDDSYCJQkJAAAAAAVvHQcyAhcF4ggBAwc19TclOSUFcQECAgMX+QkFBQIEEQtdJwUCBAIECycFAgQRCwsnAwIECycJBQUCBBELAQIEAQknBQIEBQsFEQEBAQEFBQUFAQUJAQEBAQkBAQEBBEIHBQEQAQcDARUDEQFVBwNhqxEBAQEBAQEBAQUBBQEFAQUBCQMPAwMDCQMPAwMDCQMPAwMDCQMPAwMDEQYPAw8LCRETFRcPBg8DCQMZCQMRAwMDCQMRAwMDCQMRAwMDCQMRAwMDEQYRAw8LCx0fISMPBhEDCQMlCQMzuwMHBwczxwMHBxsnKQkDO8sDDRMHO9UDDQUrLQ8G2QMVAy8VBkcDBwMxCwdHJwMHBSszGQfjJwMHAzUJA0vnAw0TB0vxAw0FNzkPBgICAxUDOxUGTwMHAz0NB08nAwcFNz8JAxMDAwMJAxMDAwMJAxMDAwMJAxMDAwMRBhMDDwsNQ0VHSQ8GEwMJA0sJA1EiAgMJBwdRNgIDCQdBTU8JAwsDAwMJAwsDAwMJAwsDAwMJAwsDAwMRBgsDDwsPU1VXWQ8GCwMJA1sPBgsDDwNRFwQLDV8PU1VXWQUAAQMRAX0HAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkDEQF/BwMLCwkBAQEBAQEBAQkDASEDAQUEAQkBAwcJAxEBgQcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMHCQMRAYMHAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkGAwEFAQDuFnOGAk4CCy8LEwsvTgJTEyEjLTEdCyMhIyl5HwsdHRUZGRkZggIdJRMdDWPHCQ0VIQsXCwsTDw8PCw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbWF0aABtb2R1bGUAcmV0dXJuAG1hdG11bABjb25zdGFudABzdWJmAGRpdmYAc2hhcGVfY2FzdABsb2FkAG11bHRpX3JlZHVjdGlvbgBicm9hZGNhc3QAc3RvcmUAZXhwAC9ob21lL2p3dGFuLy5sb2NhbC9saWIvcHl0aG9uMy4xMC9zaXRlLXBhY2thZ2VzL2pheC9leHBlcmltZW50YWwvcGFsbGFzL29wcy90cHUvZmxhc2hfYXR0ZW50aW9uLnB5AF9mbGFzaF9hdHRlbnRpb25fa2VybmVsX3NpbmdsZV9iYXRjaF9zaW5nbGVfc3RlcAB2YWx1ZQBmdW5jdGlvbl90eXBlAHN5bV9uYW1lAHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHRyYW5zZm9ybV8zAHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAF9mbGFzaF9hdHRlbnRpb25fa2VybmVsAF9mbGFzaF9hdHRlbnRpb25faW1wbABfZmxhc2hfYXR0ZW50aW9uAGZsYXNoX2F0dGVudGlvbgA8bW9kdWxlPgAvbW50L2Rpc2tzL3NzZC93b3JrL3BhbGxhcy9wYWxsYXNfYWRkLnB5AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPSg8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+LCA8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+KSBwcmVmZXJyZWRfZWxlbWVudF90eXBlPWZsb2F0MzJdAC9yZWR1Y2VfbWF4W2F4ZXM9KDEsKV0AL3N1YgBmYXN0bWF0aAAvZXhwAC9yZWR1Y2Vfc3VtW2F4ZXM9KDEsKV0AL2RpdgAvZG90X2dlbmVyYWxbZGltZW5zaW9uX251bWJlcnM9KCgoMSwpLCAoMCwpKSwgKCgpLCAoKSkpIHByZWNpc2lvbj0oPFByZWNpc2lvbi5ISUdIRVNUOiAyPiwgPFByZWNpc2lvbi5ISUdIRVNUOiAyPikgcHJlZmVycmVkX2VsZW1lbnRfdHlwZT1mbG9hdDMyXQAvc3dhcFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"needs_layout_passes\": true}}"

# The division is to cause potential precision issue on TPU.
q_mini = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4) / 13
k_mini = torch.arange(
1000, 1000 + 128 * 4, dtype=torch.float32).reshape(128, 4) / 13
q = q_mini.broadcast_to(3, 2, 128, 4).to("xla")
k = k_mini.broadcast_to(3, 2, 128, 4).to("xla")
v = torch.ones(3, 2, 128, 4).to("xla")
o = torch.zeros(3, 2, 128, 4).to("xla")

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

expected_o = attention(q, k, v)

torch_xla._XLAC._xla_tpu_custom_call_(o, [q, k, v], payload)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# TODO: Make the tpu_custom_call_ as functional.
@unittest.mock.patch.dict(os.environ, {"XLA_DISABLE_FUNCTIONALIZATION": "1"})
Expand Down
5 changes: 4 additions & 1 deletion torch_xla/csrc/layout_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,10 @@ xla::Shape MakeArrayShapeFromDimensions(
return MakeShapeWithLayout(type, dimensions, dynamic_dimensions,
*layout_ptr);
}
if (dimensions.size() > 1 && hw_type == XlaDeviceType::TPU) {

bool tpu_layout_env = runtime::sys_util::GetEnvBool("XLA_TPU_LAYOUT", true);
if (tpu_layout_env && dimensions.size() > 1 &&
hw_type == XlaDeviceType::TPU) {
return MakeTpuShape(dimensions, dynamic_dimensions, type);
}
return MakeTorchTensorLayout(dimensions, dynamic_dimensions, type);
Expand Down
Loading