Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

R2.5 #8220

Closed
wants to merge 11 commits into from
6 changes: 3 additions & 3 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
@@ -117,8 +117,7 @@ def test_tpu_custom_call_pallas_flash_attention(self):
# This payload is generated by the following Pallas code:
# https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
# To be noted, set `jax.config.update('jax_default_matmul_precision', 'highest')`` before generating the payload.
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTkuMC4wZ2l0AAFBDQEDBQcJCwEDDQMFAw8FAxEHBxMVFwkLGRsdHyELAyMDrgI+AhsB8wcTCwsPEwsPDxMLCwsLkwsTCw8TDwsLCwsPCwsLDw8LCw8LDw8PDxcTE0MLGwvFC5MLCwsLGxsLGwsbCxsLGxsbGw8PDw8XDwsXDw8LFw8PCxcPDwsXDwsTCw8PFxMfCw8PFyMPEx8LDxcbDw8LDxcLDwsTHwsPFxsFCY15kWEHA1kJBV1JAR8PCxMTFxMTFxcfCxMXIwsBGw8HKx8bBxcjDwsbLy8CYg0fAwMNhwUlBScVj5UdOgJTBSkdI4kdI7UdIxYCBSsFLQUvBTEjEQlBAQAAAAAAAAABAAAAAAAAAIAAAAAAAAAABAAAAAAAAAANGQMDDYUFMxETAAMD4fsREQEFNQU3BTkFOx2/wQU9BT8FQR3PPRXRCQVDBUUBA9cFRx3bSRXdCR3rTRXtCR0GAgoCHSoCUxUuAgkDD1dZFVtfYWMpZSkXZ2lrBUkBCfPz8/cNF2FmZmluZV9tYXA8KGQwLCBkMSwgZDIsIGQzKSAtPiAoZDAsIGQxLCBkMiwgZDMpPgAFSyMRCUEDAAAAAAAAAAIAAAAAAAAAAQAAAAAAAAABAAAAAAAAAAVNBU8FUQVTAQltcXV5AwUZbxsdCSsDBRlzGx0JLQMFGXcbHQkvAwUZexsdCTEDBRUfFysDBRUfFy0DBRUfFy8DBRUfFzERAQERAwEViwkdB40XBRoIAR2RkwVVFwVKBQEVl50dmZsFVxcFqgsBFZ+lHaGjBVkXBWIDARWnrR2pqwVbFwUaAwEdr7EFXRezZQEFXxW3CR0HuRcFHggBAwMNvSUHCQAAAAAFYRXDCR0HxRcFIggBAwc19TclOckREwEDAw3NJQ0JAACA/wVjHQfTFwW2CAEDBT/9QUMREQUdRT0FZR0H3xcFuggBBWcd5UkFaQMDDeklDQkAAAAABWsdB+8XBb4IAQMFP/9BQyN0cHUuZGltZW5zaW9uX3NlbWFudGljczxwYXJhbGxlbD4AI3RwdS5jb250cmFjdF9wcmVjaXNpb248ZnAzMj4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI3ZlY3Rvci5raW5kPGFkZD4AHUVNBW0VDgIJHQcSAhcFwggBFRoCCR0HHgIXBd4IAQMDDSYCJQkJAAAAAAVvHQcyAhcF4ggBAwc19TclOSUFcQECAgMX+QkFBQIEEQtdJwUCBAIECycFAgQRCwsnAwIECycJBQUCBBELAQIEAQknBQIEBQsFEQEBAQEFBQUFAQUJAQEBAQkBAQEBBEIHBQEQAQcDARUDEQFVBwNhqxEBAQEBAQEBAQUBBQEFAQUBCQMPAwMDCQMPAwMDCQMPAwMDCQMPAwMDEQYPAw8LCRETFRcPBg8DCQMZCQMRAwMDCQMRAwMDCQMRAwMDCQMRAwMDEQYRAw8LCx0fISMPBhEDCQMlCQMzuwMHBwczxwMHBxsnKQkDO8sDDRMHO9UDDQUrLQ8G2QMVAy8VBkcDBwMxCwdHJwMHBSszGQfjJwMHAzUJA0vnAw0TB0vxAw0FNzkPBgICAxUDOxUGTwMHAz0NB08nAwcFNz8JAxMDAwMJAxMDAwMJAxMDAwMJAxMDAwMRBhMDDwsNQ0VHSQ8GEwMJA0sJA1EiAgMJBwdRNgIDCQdBTU8JAwsDAwMJAwsDAwMJAwsDAwMJAwsDAwMRBgsDDwsPU1VXWQ8GCwMJA1sPBgsDDwNRFwQLDV8PU1VXWQUAAQMRAX0HAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkDEQF/BwMLCwkBAQEBAQEBAQkDASEDAQUEAQkBAwcJAxEBgQcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMHCQMRAYMHAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkGAwEFAQDuFnOGAk4CCy8LEwsvTgJTEyEjLTEdCyMhIyl5HwsdHRUZGRkZggIdJRMdDWPHCQ0VIQsXCwsTDw8PCw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbWF0aABtb2R1bGUAcmV0dXJuAG1hdG11bABjb25zdGFudABzdWJmAGRpdmYAc2hhcGVfY2FzdABsb2FkAG11bHRpX3JlZHVjdGlvbgBicm9hZGNhc3QAc3RvcmUAZXhwAC9ob21lL2p3dGFuLy5sb2NhbC9saWIvcHl0aG9uMy4xMC9zaXRlLXBhY2thZ2VzL2pheC9leHBlcmltZW50YWwvcGFsbGFzL29wcy90cHUvZmxhc2hfYXR0ZW50aW9uLnB5AF9mbGFzaF9hdHRlbnRpb25fa2VybmVsX3NpbmdsZV9iYXRjaF9zaW5nbGVfc3RlcAB2YWx1ZQBmdW5jdGlvbl90eXBlAHN5bV9uYW1lAHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHRyYW5zZm9ybV8zAHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAF9mbGFzaF9hdHRlbnRpb25fa2VybmVsAF9mbGFzaF9hdHRlbnRpb25faW1wbABfZmxhc2hfYXR0ZW50aW9uAGZsYXNoX2F0dGVudGlvbgA8bW9kdWxlPgAvbW50L2Rpc2tzL3NzZC93b3JrL3BhbGxhcy9wYWxsYXNfYWRkLnB5AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPSg8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+LCA8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+KSBwcmVmZXJyZWRfZWxlbWVudF90eXBlPWZsb2F0MzJdAC9yZWR1Y2VfbWF4W2F4ZXM9KDEsKV0AL3N1YgBmYXN0bWF0aAAvZXhwAC9yZWR1Y2Vfc3VtW2F4ZXM9KDEsKV0AL2RpdgAvZG90X2dlbmVyYWxbZGltZW5zaW9uX251bWJlcnM9KCgoMSwpLCAoMCwpKSwgKCgpLCAoKSkpIHByZWNpc2lvbj0oPFByZWNpc2lvbi5ISUdIRVNUOiAyPiwgPFByZWNpc2lvbi5ISUdIRVNUOiAyPikgcHJlZmVycmVkX2VsZW1lbnRfdHlwZT1mbG9hdDMyXQAvc3dhcFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"needs_layout_passes\": true}}"

payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMjAuMC4wZ2l0AAEvCwEDBQcJAQMLAxkNDxETFRcZGx0fISMD0gJaAhsB9QcTCwsPExMLDw8TCwsLC5MLCw8TDwsLCwsLDwsLCw8PCwsPCw8PExMXExMTCw9DCxsLxQuTCwsLCxsbCxsLGwsbCxsbGxsPDw8PFw8LFw8PCxcPDwsXDw8LFw8PCxcPCxcPDxcTHwsPDxcjDxMfCw8XGw8PCw8XCw8LBQmNeZFhBwNdCQNZASsXHwsTFx8PCxMTFxMTFxcfCxMXIwsHA0kBGw8HKx8bBxcPIwsbLy8C0g0fAwMPjwUlBScVl50DAw+NHVICVQUpHSORHSPDHSMuAgUrBS0FLwUxIw8JQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAAQAAAAAAAAADRkFMxETAAMD7/8RDwEFNQU3BTkFOwU9Hc3PBT8FQQVDHd0/Fd8JBUUFRwED5QVJHelLFesJHQoCTxUOAgkdHgIiAh1CAlUVRgIJAwNZWwVLEQ8JAw9fYRdjZ2lrKW0pGW9xcwVNAQn19fX5DRdhZmZpbmVfbWFwPChkMCwgZDEsIGQyLCBkMykgLT4gKGQwLCBkMSwgZDIsIGQzKT4ABU8jDwlBAwAAAAAAAAACAAAAAAAAAAEAAAAAAAAAAQAAAAAAAAAFUQVTBVUFVwEJdXl9gQMFG3cdHwkrAwUbex0fCS0DBRt/HR8JLwMFG4MdHwkxAwUXIRkrAwUXIRktAwUXIRkvAwUXIRkxEQEBEQMBFZMJHQeVFwUGCAEdmZsFWRcFSgUBFZ+lHaGjBVsXBYYLARWnrR2pqwVdFwViAwEVr7UdsbMFXxcFGgMBFbe9Hbm7BWEXM14DAR2/wQVjFzM2EAEVxQkdB8cXBQoIAQMDD8slBwkAAAAABWUV0QkdB9MXBQ4IAQMHN/c5JTvXERMBAwMP2yUNCQAAgP8FZx0H4RcFoggBAwVB/UNFEQ8FHUc/BWkdB+0XBaYIAQVrHfNLBW0jdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUuY29udHJhY3RfcHJlY2lzaW9uPGZwMzI+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+AAMDDwYCJQ0JAAAAAAVvHQcSAhcFqggBAwVBVgJDRR1HTwVxFSYCCR0HKgIXBa4IARUyAgkdBzYCFwXKCAEDAw8+AiUJCQAAAAAFcx0HSgIXBc4IAQMHN/c5JTslBXUjdmVjdG9yLmtpbmQ8YWRkPgABAgIDF/sJBQUCBBELZScFAgQCBAsnBQIEEQsLJwMCBAsBAgQnCQUFAgQRCwEJJwUCBAULBREBAQEBBQUFBQEFCQEBAQEJAQEBAQSuBwUBEQFXBwMBFQcRAV0HA2GrEQEBAQEBAQEBBQEFAQUBBQEDAxEDAwMDAxEDAwMDAxEDAwMDAxEDAwMLBhEDEQsJERMVFwUGEQMJAxkDAxMDAwMDAxMDAwMDAxMDAwMDAxMDAwMLBhMDEQsLHR8hIwUGEwMJAyUDAzXJAwcNBzXVAwcHGycpAwM92QMNDwc94wMNBSstBQbnAxUDLxEGSQMHAzETB0knAwcFKzMVB/EnAwcDNQMDTQICAw0PB00WAgMNBTc5BQYaAgMVAzsRBlEDBwM9FwdRJwMHBTc/AwMVAwMDAwMVAwMDAwMVAwMDAwMVAwMDCwYVAxELDUNFR0kFBhUDCQNLAwNTOgIDCQ0HU04CAwkHQU1PAwMNAwMDAwMNAwMDAwMNAwMDAwMNAwMDCwYNAxELD1NVV1kFBg0DCQNbBQYNAxEDURkEDQ1fD1NVV1kJAAEHEQGFBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBxEBhwcDDQ8JAQEBAQEBAQEDAwELAwEDAwELAwEJBAEJAQMHCQcRAYkHAw0PCQEBAQEBAQEBAwMBCwMBAwMBCwMBCQQBCQEDBwkHEQGLBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBgMBBQEAMhp37gImAgsvCxMLLyYCE2MhIy0xHQsjISMpLXkfCx0dFVkZGRkZ6gIdJRMdDWPvGxcTFyMvFxkZFSUfDw0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQB2ZWN0b3IAYXJpdGgAbW9kdWxlAGFyaXRoLmNvbnN0YW50AHZlY3Rvci5zaGFwZV9jYXN0AGZ1bmMuZnVuYwBmdW5jLnJldHVybgB2ZWN0b3IubG9hZAB0cHUubWF0bXVsAHZlY3Rvci5tdWx0aV9yZWR1Y3Rpb24AdmVjdG9yLmJyb2FkY2FzdABhcml0aC5zdWJmAG1hdGguZXhwAGFyaXRoLmRpdmYAdmVjdG9yLnN0b3JlAC9ob21lL2JiYWhsL21pbmljb25kYTMvZW52cy90b3JjaHNlcDEwL2xpYi9weXRob24zLjEwL3NpdGUtcGFja2FnZXMvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9mbGFzaF9hdHRlbnRpb24ucHkAX2ZsYXNoX2F0dGVudGlvbl9rZXJuZWxfc2luZ2xlX2JhdGNoX3NpbmdsZV9zdGVwAHZhbHVlAGZ1bmN0aW9uX3R5cGUAc3ltX25hbWUAdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKCgqLCAqLCBDdXN0b21Ob2RlKFNsaWNlWygwLCAxMjgsIDEpXSwgW05vbmUsIE5vbmVdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCwgMSldLCBbTm9uZSwgTm9uZV0pKSksICgxLCAxLCAxMjgsIDQpLCAoKSldLCBbKiwgKl0pLCkpXQB0cmFuc2Zvcm1fMAB0cmFuc2Zvcm1fMQB0cmFuc2Zvcm1fMgB0cmFuc2Zvcm1fMwAvaG9tZS9iYmFobC9weXRvcmNoL3hsYS90ZXN0L3Rlc3RfcGFsbGFzLnB5AHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBzdGFibGVfbW9zYWljLnZlcnNpb24AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2tlcm5lbABfZmxhc2hfYXR0ZW50aW9uX2ltcGwAX2ZsYXNoX2F0dGVudGlvbgBmbGFzaF9hdHRlbnRpb24AdGVzdF90cHVfY3VzdG9tX2NhbGxfcGFsbGFzX3dyYXBfZmxhc2hfYXR0ZW50aW9uADxtb2R1bGU+AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3JlZHVjZV9tYXhbYXhlcz0oMSwpXQAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW1bYXhlcz0oMSwpXQAvZGl2AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgwLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKCosICosIEN1c3RvbU5vZGUoU2xpY2VbKDAsIDEyOCwgMSldLCBbTm9uZSwgTm9uZV0pLCBDdXN0b21Ob2RlKFNsaWNlWygwLCA0LCAxKV0sIFtOb25lLCBOb25lXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"serialization_format\": 1, \"needs_layout_passes\": true}, \"implicit_sharding\": {\"type\": \"MANUAL\"}}"
# The division is to cause potential precision issue on TPU.
q_mini = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4) / 13
k_mini = torch.arange(
@@ -209,7 +208,8 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):

o = flash_attention_kernel(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
torch.testing.assert_close(o.cpu(), expected_o.cpu())
# self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
8 changes: 8 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
@@ -66,6 +66,14 @@ def _setup_libtpu_flags():
# improves device memory usage.
flags = _set_missing_flags(
flags, (('xla_tpu_prefer_async_allgather_to_allreduce', 'true'),))

# This flag enables FlashAttention HLO pass that pattern matches attention
# and rewrites it as flash attention. This pattern matching is causing
# issues for our standard dot product attention. Turning it off till
# we fix the issue with pattern matching.
flags = _set_missing_flags(
flags, (('xla_tpu_enable_flash_attention', 'false'),)
)

if tpu.version() == 5:
default_v5_flags = {