From 3c7daa246f21130cc34b8d413f127e7c11ff6b7a Mon Sep 17 00:00:00 2001 From: Bhavya Bahl Date: Fri, 13 Sep 2024 16:16:38 -0700 Subject: [PATCH] Set xla_tpu_enable_flash_attention=false to enable libtpu pin update (#8008) --- test/test_pallas.py | 6 +++--- torch_xla/__init__.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 4e7fb90cbf7..28088566118 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -117,8 +117,7 @@ def test_tpu_custom_call_pallas_flash_attention(self): # This payload is generated by the following Pallas code: # https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139 # To be noted, set `jax.config.update('jax_default_matmul_precision', 'highest')`` before generating the payload. - payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTkuMC4wZ2l0AAFBDQEDBQcJCwEDDQMFAw8FAxEHBxMVFwkLGRsdHyELAyMDrgI+AhsB8wcTCwsPEwsPDxMLCwsLkwsTCw8TDwsLCwsPCwsLDw8LCw8LDw8PDxcTE0MLGwvFC5MLCwsLGxsLGwsbCxsLGxsbGw8PDw8XDwsXDw8LFw8PCxcPDwsXDwsTCw8PFxMfCw8PFyMPEx8LDxcbDw8LDxcLDwsTHwsPFxsFCY15kWEHA1kJBV1JAR8PCxMTFxMTFxcfCxMXIwsBGw8HKx8bBxcjDwsbLy8CYg0fAwMNhwUlBScVj5UdOgJTBSkdI4kdI7UdIxYCBSsFLQUvBTEjEQlBAQAAAAAAAAABAAAAAAAAAIAAAAAAAAAABAAAAAAAAAANGQMDDYUFMxETAAMD4fsREQEFNQU3BTkFOx2/wQU9BT8FQR3PPRXRCQVDBUUBA9cFRx3bSRXdCR3rTRXtCR0GAgoCHSoCUxUuAgkDD1dZFVtfYWMpZSkXZ2lrBUkBCfPz8/cNF2FmZmluZV9tYXA8KGQwLCBkMSwgZDIsIGQzKSAtPiAoZDAsIGQxLCBkMiwgZDMpPgAFSyMRCUEDAAAAAAAAAAIAAAAAAAAAAQAAAAAAAAABAAAAAAAAAAVNBU8FUQVTAQltcXV5AwUZbxsdCSsDBRlzGx0JLQMFGXcbHQkvAwUZexsdCTEDBRUfFysDBRUfFy0DBRUfFy8DBRUfFzERAQERAwEViwkdB40XBRoIAR2RkwVVFwVKBQEVl50dmZsFVxcFqgsBFZ+lHaGjBVkXBWIDARWnrR2pqwVbFwUaAwEdr7EFXRezZQEFXxW3CR0HuRcFHggBAwMNvSUHCQAAAAAFYRXDCR0HxRcFIggBAwc19TclOckREwEDAw3NJQ0JAACA/wVjHQfTFwW2CAEDBT/9QUMREQUdRT0FZR0H3xcFuggBBWcd5UkFaQMDDeklDQkAAAAABWsdB+8XBb4IAQMFP/9BQyN0cHUuZGltZW5zaW9uX3NlbWFudGljczxwYXJhbGxlbD4AI3RwdS5jb250cmFjdF9wcmVjaXNpb248ZnAzMj4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI3ZlY3Rvci5raW5kPGFkZD4AHUVNBW0VDgIJHQcSAhcFwggBFRoCCR0HHgIXBd4IAQMDDSYCJQkJAAAAAAVvHQcyAhcF4ggBAwc19TclOSUFcQECAgMX+QkFBQIEEQtdJwUCBAIECycFAgQRCwsnAwIECycJBQUCBBELAQIEAQknBQIEBQsFEQEBAQEFBQUFAQUJAQEBAQkBAQEBBEIHBQEQAQcDARUDEQFVBwNhqxEBAQEBAQEBAQUBBQEFAQUBCQMPAwMDCQMPAwMDCQMPAwMDCQMPAwMDEQYPAw8LCRETFRcPBg8DCQMZCQMRAwMDCQMRAwMDCQMRAwMDCQMRAwMDEQYRAw8LCx0fISMPBhEDCQMlCQMzuwMHBwczxwMHBxsnKQkDO8sDDRMHO9UDDQUrLQ8G2QMVAy8VBkcDBwMxCwdHJwMHBSszGQfjJwMHAzUJA0vnAw0TB0vxAw0FNzkPBgICAxUDOxUGTwMHAz0NB08nAwcFNz8JAxMDAwMJAxMDAwMJAxMDAwMJAxMDAwMRBhMDDwsNQ0VHSQ8GEwMJA0sJA1EiAgMJBwdRNgIDCQdBTU8JAwsDAwMJAwsDAwMJAwsDAwMJAwsDAwMRBgsDDwsPU1VXWQ8GCwMJA1sPBgsDDwNRFwQLDV8PU1VXWQUAAQMRAX0HAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkDEQF/BwMLCwkBAQEBAQEBAQkDASEDAQUEAQkBAwcJAxEBgQcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMHCQMRAYMHAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkGAwEFAQDuFnOGAk4CCy8LEwsvTgJTEyEjLTEdCyMhIyl5HwsdHRUZGRkZggIdJRMdDWPHCQ0VIQsXCwsTDw8PCw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbWF0aABtb2R1bGUAcmV0dXJuAG1hdG11bABjb25zdGFudABzdWJmAGRpdmYAc2hhcGVfY2FzdABsb2FkAG11bHRpX3JlZHVjdGlvbgBicm9hZGNhc3QAc3RvcmUAZXhwAC9ob21lL2p3dGFuLy5sb2NhbC9saWIvcHl0aG9uMy4xMC9zaXRlLXBhY2thZ2VzL2pheC9leHBlcmltZW50YWwvcGFsbGFzL29wcy90cHUvZmxhc2hfYXR0ZW50aW9uLnB5AF9mbGFzaF9hdHRlbnRpb25fa2VybmVsX3NpbmdsZV9iYXRjaF9zaW5nbGVfc3RlcAB2YWx1ZQBmdW5jdGlvbl90eXBlAHN5bV9uYW1lAHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHRyYW5zZm9ybV8zAHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAF9mbGFzaF9hdHRlbnRpb25fa2VybmVsAF9mbGFzaF9hdHRlbnRpb25faW1wbABfZmxhc2hfYXR0ZW50aW9uAGZsYXNoX2F0dGVudGlvbgA8bW9kdWxlPgAvbW50L2Rpc2tzL3NzZC93b3JrL3BhbGxhcy9wYWxsYXNfYWRkLnB5AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPSg8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+LCA8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+KSBwcmVmZXJyZWRfZWxlbWVudF90eXBlPWZsb2F0MzJdAC9yZWR1Y2VfbWF4W2F4ZXM9KDEsKV0AL3N1YgBmYXN0bWF0aAAvZXhwAC9yZWR1Y2Vfc3VtW2F4ZXM9KDEsKV0AL2RpdgAvZG90X2dlbmVyYWxbZGltZW5zaW9uX251bWJlcnM9KCgoMSwpLCAoMCwpKSwgKCgpLCAoKSkpIHByZWNpc2lvbj0oPFByZWNpc2lvbi5ISUdIRVNUOiAyPiwgPFByZWNpc2lvbi5ISUdIRVNUOiAyPikgcHJlZmVycmVkX2VsZW1lbnRfdHlwZT1mbG9hdDMyXQAvc3dhcFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"needs_layout_passes\": true}}" - + payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMjAuMC4wZ2l0AAEvCwEDBQcJAQMLAxkNDxETFRcZGx0fISMD0gJaAhsB9QcTCwsPExMLDw8TCwsLC5MLCw8TDwsLCwsLDwsLCw8PCwsPCw8PExMXExMTCw9DCxsLxQuTCwsLCxsbCxsLGwsbCxsbGxsPDw8PFw8LFw8PCxcPDwsXDw8LFw8PCxcPCxcPDxcTHwsPDxcjDxMfCw8XGw8PCw8XCw8LBQmNeZFhBwNdCQNZASsXHwsTFx8PCxMTFxMTFxcfCxMXIwsHA0kBGw8HKx8bBxcPIwsbLy8C0g0fAwMPjwUlBScVl50DAw+NHVICVQUpHSORHSPDHSMuAgUrBS0FLwUxIw8JQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAAQAAAAAAAAADRkFMxETAAMD7/8RDwEFNQU3BTkFOwU9Hc3PBT8FQQVDHd0/Fd8JBUUFRwED5QVJHelLFesJHQoCTxUOAgkdHgIiAh1CAlUVRgIJAwNZWwVLEQ8JAw9fYRdjZ2lrKW0pGW9xcwVNAQn19fX5DRdhZmZpbmVfbWFwPChkMCwgZDEsIGQyLCBkMykgLT4gKGQwLCBkMSwgZDIsIGQzKT4ABU8jDwlBAwAAAAAAAAACAAAAAAAAAAEAAAAAAAAAAQAAAAAAAAAFUQVTBVUFVwEJdXl9gQMFG3cdHwkrAwUbex0fCS0DBRt/HR8JLwMFG4MdHwkxAwUXIRkrAwUXIRktAwUXIRkvAwUXIRkxEQEBEQMBFZMJHQeVFwUGCAEdmZsFWRcFSgUBFZ+lHaGjBVsXBYYLARWnrR2pqwVdFwViAwEVr7UdsbMFXxcFGgMBFbe9Hbm7BWEXM14DAR2/wQVjFzM2EAEVxQkdB8cXBQoIAQMDD8slBwkAAAAABWUV0QkdB9MXBQ4IAQMHN/c5JTvXERMBAwMP2yUNCQAAgP8FZx0H4RcFoggBAwVB/UNFEQ8FHUc/BWkdB+0XBaYIAQVrHfNLBW0jdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUuY29udHJhY3RfcHJlY2lzaW9uPGZwMzI+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+AAMDDwYCJQ0JAAAAAAVvHQcSAhcFqggBAwVBVgJDRR1HTwVxFSYCCR0HKgIXBa4IARUyAgkdBzYCFwXKCAEDAw8+AiUJCQAAAAAFcx0HSgIXBc4IAQMHN/c5JTslBXUjdmVjdG9yLmtpbmQ8YWRkPgABAgIDF/sJBQUCBBELZScFAgQCBAsnBQIEEQsLJwMCBAsBAgQnCQUFAgQRCwEJJwUCBAULBREBAQEBBQUFBQEFCQEBAQEJAQEBAQSuBwUBEQFXBwMBFQcRAV0HA2GrEQEBAQEBAQEBBQEFAQUBBQEDAxEDAwMDAxEDAwMDAxEDAwMDAxEDAwMLBhEDEQsJERMVFwUGEQMJAxkDAxMDAwMDAxMDAwMDAxMDAwMDAxMDAwMLBhMDEQsLHR8hIwUGEwMJAyUDAzXJAwcNBzXVAwcHGycpAwM92QMNDwc94wMNBSstBQbnAxUDLxEGSQMHAzETB0knAwcFKzMVB/EnAwcDNQMDTQICAw0PB00WAgMNBTc5BQYaAgMVAzsRBlEDBwM9FwdRJwMHBTc/AwMVAwMDAwMVAwMDAwMVAwMDAwMVAwMDCwYVAxELDUNFR0kFBhUDCQNLAwNTOgIDCQ0HU04CAwkHQU1PAwMNAwMDAwMNAwMDAwMNAwMDAwMNAwMDCwYNAxELD1NVV1kFBg0DCQNbBQYNAxEDURkEDQ1fD1NVV1kJAAEHEQGFBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBxEBhwcDDQ8JAQEBAQEBAQEDAwELAwEDAwELAwEJBAEJAQMHCQcRAYkHAw0PCQEBAQEBAQEBAwMBCwMBAwMBCwMBCQQBCQEDBwkHEQGLBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBgMBBQEAMhp37gImAgsvCxMLLyYCE2MhIy0xHQsjISMpLXkfCx0dFVkZGRkZ6gIdJRMdDWPvGxcTFyMvFxkZFSUfDw0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQB2ZWN0b3IAYXJpdGgAbW9kdWxlAGFyaXRoLmNvbnN0YW50AHZlY3Rvci5zaGFwZV9jYXN0AGZ1bmMuZnVuYwBmdW5jLnJldHVybgB2ZWN0b3IubG9hZAB0cHUubWF0bXVsAHZlY3Rvci5tdWx0aV9yZWR1Y3Rpb24AdmVjdG9yLmJyb2FkY2FzdABhcml0aC5zdWJmAG1hdGguZXhwAGFyaXRoLmRpdmYAdmVjdG9yLnN0b3JlAC9ob21lL2JiYWhsL21pbmljb25kYTMvZW52cy90b3JjaHNlcDEwL2xpYi9weXRob24zLjEwL3NpdGUtcGFja2FnZXMvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9mbGFzaF9hdHRlbnRpb24ucHkAX2ZsYXNoX2F0dGVudGlvbl9rZXJuZWxfc2luZ2xlX2JhdGNoX3NpbmdsZV9zdGVwAHZhbHVlAGZ1bmN0aW9uX3R5cGUAc3ltX25hbWUAdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKCgqLCAqLCBDdXN0b21Ob2RlKFNsaWNlWygwLCAxMjgsIDEpXSwgW05vbmUsIE5vbmVdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCwgMSldLCBbTm9uZSwgTm9uZV0pKSksICgxLCAxLCAxMjgsIDQpLCAoKSldLCBbKiwgKl0pLCkpXQB0cmFuc2Zvcm1fMAB0cmFuc2Zvcm1fMQB0cmFuc2Zvcm1fMgB0cmFuc2Zvcm1fMwAvaG9tZS9iYmFobC9weXRvcmNoL3hsYS90ZXN0L3Rlc3RfcGFsbGFzLnB5AHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBzdGFibGVfbW9zYWljLnZlcnNpb24AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2tlcm5lbABfZmxhc2hfYXR0ZW50aW9uX2ltcGwAX2ZsYXNoX2F0dGVudGlvbgBmbGFzaF9hdHRlbnRpb24AdGVzdF90cHVfY3VzdG9tX2NhbGxfcGFsbGFzX3dyYXBfZmxhc2hfYXR0ZW50aW9uADxtb2R1bGU+AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3JlZHVjZV9tYXhbYXhlcz0oMSwpXQAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW1bYXhlcz0oMSwpXQAvZGl2AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgwLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKCosICosIEN1c3RvbU5vZGUoU2xpY2VbKDAsIDEyOCwgMSldLCBbTm9uZSwgTm9uZV0pLCBDdXN0b21Ob2RlKFNsaWNlWygwLCA0LCAxKV0sIFtOb25lLCBOb25lXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"serialization_format\": 1, \"needs_layout_passes\": true}, \"implicit_sharding\": {\"type\": \"MANUAL\"}}" # The division is to cause potential precision issue on TPU. q_mini = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4) / 13 k_mini = torch.arange( @@ -209,7 +208,8 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self): o = flash_attention_kernel(q, k, v) expected_o = self._attention(q, k, v) - self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu())) + torch.testing.assert_close(o.cpu(), expected_o.cpu()) + # self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu())) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, "This test only works on TPUv3+.") diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 3a75796b155..3cefe486417 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -66,6 +66,14 @@ def _setup_libtpu_flags(): # improves device memory usage. flags = _set_missing_flags( flags, (('xla_tpu_prefer_async_allgather_to_allreduce', 'true'),)) + + # This flag enables FlashAttention HLO pass that pattern matches attention + # and rewrites it as flash attention. This pattern matching is causing + # issues for our standard dot product attention. Turning it off till + # we fix the issue with pattern matching. + flags = _set_missing_flags( + flags, (('xla_tpu_enable_flash_attention', 'false'),) + ) if tpu.version() == 5: default_v5_flags = {