diff --git a/configuration.yaml b/configuration.yaml index fe70139ce33..d78492a9ec2 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -319,6 +319,11 @@ variables: dot. type: string default_value: "" + XLA_TPU_LAYOUT: + description: + - Determine to use TPU layout or not, where it will use sorted layout for TPU. + type: bool + default_value: true PT_XLA_DEBUG_FILE: description: - If set, filepath used for printing out reports. @@ -409,9 +414,9 @@ variables: default_value: false XLA_DISABLE_FUNCTIONALIZATION: description: - - Setting this to true will disable functionalization, which is a dispatcher + - Setting this to true will disable functionalization, which is a dispatcher pass in PyTorch that remove views and mutations to produce functional graphs. - This flag's main purpose is to A/B test the impact of functionalization in + This flag's main purpose is to A/B test the impact of functionalization in your code. type: bool default_value: false diff --git a/test/test_operations.py b/test/test_operations.py index 0328482ce5e..4f3cc1bed22 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -1967,6 +1967,36 @@ def test_tpu_custom_call_pallas_raise(self): torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload) output.cpu() + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") + # Mosiac is not compatible with our sorted layout that boosts performance for dim > 2 tensor input applications, like resnet. + # For LLM, it should be fine since all inputs are 2D. + @unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"}) + def test_tpu_custom_call_pallas_flash_attention(self): + # This payload is generated by the following Pallas code: + # https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139 + # To be noted, set `jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)`` before generating the payload. + payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTkuMC4wZ2l0AAFBDQEDBQcJCwEDDQMFAw8FAxEHBxMVFwkLGRsdHyELAyMDrgI+AhsB8wcTCwsPEwsPDxMLCwsLkwsTCw8TDwsLCwsPCwsLDw8LCw8LDw8PDxcTE0MLGwvFC5MLCwsLGxsLGwsbCxsLGxsbGw8PDw8XDwsXDw8LFw8PCxcPDwsXDwsTCw8PFxMfCw8PFyMPEx8LDxcbDw8LDxcLDwsTHwsPFxsFCY15kWEHA1kJBV1JAR8PCxMTFxMTFxcfCxMXIwsBGw8HKx8bBxcjDwsbLy8CYg0fAwMNhwUlBScVj5UdOgJTBSkdI4kdI7UdIxYCBSsFLQUvBTEjEQlBAQAAAAAAAAABAAAAAAAAAIAAAAAAAAAABAAAAAAAAAANGQMDDYUFMxETAAMD4fsREQEFNQU3BTkFOx2/wQU9BT8FQR3PPRXRCQVDBUUBA9cFRx3bSRXdCR3rTRXtCR0GAgoCHSoCUxUuAgkDD1dZFVtfYWMpZSkXZ2lrBUkBCfPz8/cNF2FmZmluZV9tYXA8KGQwLCBkMSwgZDIsIGQzKSAtPiAoZDAsIGQxLCBkMiwgZDMpPgAFSyMRCUEDAAAAAAAAAAIAAAAAAAAAAQAAAAAAAAABAAAAAAAAAAVNBU8FUQVTAQltcXV5AwUZbxsdCSsDBRlzGx0JLQMFGXcbHQkvAwUZexsdCTEDBRUfFysDBRUfFy0DBRUfFy8DBRUfFzERAQERAwEViwkdB40XBRoIAR2RkwVVFwVKBQEVl50dmZsFVxcFqgsBFZ+lHaGjBVkXBWIDARWnrR2pqwVbFwUaAwEdr7EFXRezZQEFXxW3CR0HuRcFHggBAwMNvSUHCQAAAAAFYRXDCR0HxRcFIggBAwc19TclOckREwEDAw3NJQ0JAACA/wVjHQfTFwW2CAEDBT/9QUMREQUdRT0FZR0H3xcFuggBBWcd5UkFaQMDDeklDQkAAAAABWsdB+8XBb4IAQMFP/9BQyN0cHUuZGltZW5zaW9uX3NlbWFudGljczxwYXJhbGxlbD4AI3RwdS5jb250cmFjdF9wcmVjaXNpb248ZnAzMj4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI3ZlY3Rvci5raW5kPGFkZD4AHUVNBW0VDgIJHQcSAhcFwggBFRoCCR0HHgIXBd4IAQMDDSYCJQkJAAAAAAVvHQcyAhcF4ggBAwc19TclOSUFcQECAgMX+QkFBQIEEQtdJwUCBAIECycFAgQRCwsnAwIECycJBQUCBBELAQIEAQknBQIEBQsFEQEBAQEFBQUFAQUJAQEBAQkBAQEBBEIHBQEQAQcDARUDEQFVBwNhqxEBAQEBAQEBAQUBBQEFAQUBCQMPAwMDCQMPAwMDCQMPAwMDCQMPAwMDEQYPAw8LCRETFRcPBg8DCQMZCQMRAwMDCQMRAwMDCQMRAwMDCQMRAwMDEQYRAw8LCx0fISMPBhEDCQMlCQMzuwMHBwczxwMHBxsnKQkDO8sDDRMHO9UDDQUrLQ8G2QMVAy8VBkcDBwMxCwdHJwMHBSszGQfjJwMHAzUJA0vnAw0TB0vxAw0FNzkPBgICAxUDOxUGTwMHAz0NB08nAwcFNz8JAxMDAwMJAxMDAwMJAxMDAwMJAxMDAwMRBhMDDwsNQ0VHSQ8GEwMJA0sJA1EiAgMJBwdRNgIDCQdBTU8JAwsDAwMJAwsDAwMJAwsDAwMJAwsDAwMRBgsDDwsPU1VXWQ8GCwMJA1sPBgsDDwNRFwQLDV8PU1VXWQUAAQMRAX0HAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkDEQF/BwMLCwkBAQEBAQEBAQkDASEDAQUEAQkBAwcJAxEBgQcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMHCQMRAYMHAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkGAwEFAQDuFnOGAk4CCy8LEwsvTgJTEyEjLTEdCyMhIyl5HwsdHRUZGRkZggIdJRMdDWPHCQ0VIQsXCwsTDw8PCw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbWF0aABtb2R1bGUAcmV0dXJuAG1hdG11bABjb25zdGFudABzdWJmAGRpdmYAc2hhcGVfY2FzdABsb2FkAG11bHRpX3JlZHVjdGlvbgBicm9hZGNhc3QAc3RvcmUAZXhwAC9ob21lL2p3dGFuLy5sb2NhbC9saWIvcHl0aG9uMy4xMC9zaXRlLXBhY2thZ2VzL2pheC9leHBlcmltZW50YWwvcGFsbGFzL29wcy90cHUvZmxhc2hfYXR0ZW50aW9uLnB5AF9mbGFzaF9hdHRlbnRpb25fa2VybmVsX3NpbmdsZV9iYXRjaF9zaW5nbGVfc3RlcAB2YWx1ZQBmdW5jdGlvbl90eXBlAHN5bV9uYW1lAHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHRyYW5zZm9ybV8zAHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAF9mbGFzaF9hdHRlbnRpb25fa2VybmVsAF9mbGFzaF9hdHRlbnRpb25faW1wbABfZmxhc2hfYXR0ZW50aW9uAGZsYXNoX2F0dGVudGlvbgA8bW9kdWxlPgAvbW50L2Rpc2tzL3NzZC93b3JrL3BhbGxhcy9wYWxsYXNfYWRkLnB5AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPSg8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+LCA8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+KSBwcmVmZXJyZWRfZWxlbWVudF90eXBlPWZsb2F0MzJdAC9yZWR1Y2VfbWF4W2F4ZXM9KDEsKV0AL3N1YgBmYXN0bWF0aAAvZXhwAC9yZWR1Y2Vfc3VtW2F4ZXM9KDEsKV0AL2RpdgAvZG90X2dlbmVyYWxbZGltZW5zaW9uX251bWJlcnM9KCgoMSwpLCAoMCwpKSwgKCgpLCAoKSkpIHByZWNpc2lvbj0oPFByZWNpc2lvbi5ISUdIRVNUOiAyPiwgPFByZWNpc2lvbi5ISUdIRVNUOiAyPikgcHJlZmVycmVkX2VsZW1lbnRfdHlwZT1mbG9hdDMyXQAvc3dhcFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"needs_layout_passes\": true}}" + + # The division is to cause potential precision issue on TPU. + q_mini = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4) / 13 + k_mini = torch.arange( + 1000, 1000 + 128 * 4, dtype=torch.float32).reshape(128, 4) / 13 + q = q_mini.broadcast_to(3, 2, 128, 4).to("xla") + k = k_mini.broadcast_to(3, 2, 128, 4).to("xla") + v = torch.ones(3, 2, 128, 4).to("xla") + o = torch.zeros(3, 2, 128, 4).to("xla") + + def attention(q, k, v): + attn_weight = q @ k.transpose(-2, -1) + attn_weight = nn.functional.softmax(attn_weight, dim=-1) + attn_output = attn_weight @ v + return attn_output + + expected_o = attention(q, k, v) + + torch_xla._XLAC._xla_tpu_custom_call_(o, [q, k, v], payload) + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu())) + @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") # TODO: Make the tpu_custom_call_ as functional. @unittest.mock.patch.dict(os.environ, {"XLA_DISABLE_FUNCTIONALIZATION": "1"}) diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp index b488acbaefc..33deba58f24 100644 --- a/torch_xla/csrc/layout_manager.cpp +++ b/torch_xla/csrc/layout_manager.cpp @@ -184,7 +184,10 @@ xla::Shape MakeArrayShapeFromDimensions( return MakeShapeWithLayout(type, dimensions, dynamic_dimensions, *layout_ptr); } - if (dimensions.size() > 1 && hw_type == XlaDeviceType::TPU) { + + bool tpu_layout_env = runtime::sys_util::GetEnvBool("XLA_TPU_LAYOUT", true); + if (tpu_layout_env && dimensions.size() > 1 && + hw_type == XlaDeviceType::TPU) { return MakeTpuShape(dimensions, dynamic_dimensions, type); } return MakeTorchTensorLayout(dimensions, dynamic_dimensions, type);