diff --git a/test/test_operations.py b/test/test_operations.py
index 9e274218c4c6..8cd7a06dc8a0 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -44,7 +44,7 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.core.functions as xf
 import torch_xla.debug.profiler as xp
-import torchvision
+# import torchvision
 import unittest
 import test_utils
 
@@ -1896,6 +1896,22 @@ def test_tpu_custom_call_pallas_raise(self):
     torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload)
     output.cpu()
 
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_tpu_custom_call_pallas_flash_attention(self):
+    # This payload is generated by the following Pallas code:
+    # https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
+    payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTkuMC4wZ2l0AAFBDQEDBQcJCwEDDQMFAw8FAxEHBxMVFwkLGRsdHyELAyMDpgI2AhsB9QcTCwsPEwsPDxMLCwsLkwsTCw8TDwsLCwsPCwsPDwsLDwsPDw8PExMTQwsbC8ULkwsLCwsbGwsbCxsLGwsbGxsbDw8PDxcPCxcPDwsXDw8LFw8PCxcPCxMLDw8XEx8LDw8XGw8THwsPFxsPDwsPFwsPCxMfCw8XGw8LBQeNkWEHA1kJBV1JARsTExcTExcXHwsTFxsLARsPBysfGwcXIw8LGy8vAs4MHwMDDYUFJQUnFY2THTICUQUpHSOHHSOzHSMOAgUrBS0FLwUxIxEJQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAAQAAAAAAAAADRkDAw2DBTMREwADA9/7EREBBTUFNwU5BTsdvb8FPQU/Hc07Fc8JBUEFQwED1QVFHdlHFdsJHelLFesJHfMCAh0iAlEVJgIJAw9VVxVZXV9hKWMpF2VnaQVHAQn19fX3DRdhZmZpbmVfbWFwPChkMCwgZDEsIGQyLCBkMykgLT4gKGQwLCBkMSwgZDIsIGQzKT4ABUkjEQlBAwAAAAAAAAACAAAAAAAAAAEAAAAAAAAAAQAAAAAAAAAFSwVNBU8FUQEJa29zdwMFGW0bHQkrAwUZcRsdCS0DBRl1Gx0JLwMFGXkbHQkxAwUVHxcrAwUVHxctAwUVHxcvAwUVHxcxEQEBEQMBFYkJHQeLFwUaCAEdj5EFUxcFSgUBFZWbHZeZBVUXBaoLARWdox2foQVXFwViAwEVpasdp6kFWRcFGgMBHa2vBVsXsVEBBV0VtQkdB7cXBR4IAQMDDbslBwkAAAAABV8VwQkdB8MXBSIIAQMFNSU3xxETAQMDDcslDQkAAID/BWEdB9EXBbYIAQMFPf0/QRERBR1DOwVjHQfdFwW6CAEFZR3jRwVnAwMN5yUNCQAAAAAFaR0H7RcFvggBAwU9/z9BHUNLBWsjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACNhcml0aC5mYXN0bWF0aDxub25lPgAjdmVjdG9yLmtpbmQ8bWF4aW11bWY+ACN2ZWN0b3Iua2luZDxhZGQ+ABUGAgkdBwoCFwXCCAEVEgIJHQcWAhcF3ggBAwMNHgIlCQkAAAAABW0dByoCFwXiCAEDBTUlNyUFbwECAgMX+QkFBQIEEQtbJwUCBAIECycFAgQRCwsnAwIECycJBQUCBBELAQIEAQknBQIEBQsFEQEBAQEFBQUFAQUJAQEBAQkBAQEBBD4HBQEQAQcDARUDEQFTBwNhqxEBAQEBAQEBAQUBBQEFAQUBCQMPAwMDCQMPAwMDCQMPAwMDCQMPAwMDEQYPAw8LCRETFRcPBg8DCQMZCQMRAwMDCQMRAwMDCQMRAwMDCQMRAwMDEQYRAw8LCx0fISMPBhEDCQMlCQMzuQMHBwczxQMHBxsnKQkDOckDDRMHOdMDDQUrLQ8G1wMVAy8VBkUDBwMxCwdFJwMHBSszGQfhJwMHAzUJA0nlAw0TB0nvAw0FNzkPBvEDFQM7FQZNAwcDPQ0HTScDBwU3PwkDEwMDAwkDEwMDAwkDEwMDAwkDEwMDAxEGEwMPCw1DRUdJDwYTAwkDSwkDTxoCAwkHB08uAgMJB0FNTwkDCwMDAwkDCwMDAwkDCwMDAwkDCwMDAxEGCwMPCw9TVVdZDwYLAwkDWw8GCwMPA1EXBAsNXw9TVVdZBQABAxEBewcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMFCQMRAX0HAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBwkDEQF/BwMLCwkBAQEBAQEBAQkDASEDAQUEAQkBAwcJAxEBgQcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMFCQYDAQUBAFoVcYYCzwsvCxMLL89TEyEjLTEdCyMhIyl5HwsdHRkZGRmCAh0lEx0NY8cJDRUhCxcLCxMPDw8LDw0JCxFidWlsdGluAGZ1bmMAdHB1AGFyaXRoAHZlY3RvcgBtYXRoAG1vZHVsZQByZXR1cm4AbWF0bXVsAGNvbnN0YW50AHN1YmYAZGl2ZgBzaGFwZV9jYXN0AGxvYWQAbXVsdGlfcmVkdWN0aW9uAGJyb2FkY2FzdABzdG9yZQBleHAAL2hvbWUvand0YW4vLmxvY2FsL2xpYi9weXRob24zLjEwL3NpdGUtcGFja2FnZXMvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9mbGFzaF9hdHRlbnRpb24ucHkAX2ZsYXNoX2F0dGVudGlvbl9rZXJuZWxfc2luZ2xlX2JhdGNoX3NpbmdsZV9zdGVwAHZhbHVlAGZ1bmN0aW9uX3R5cGUAc3ltX25hbWUAdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKCgqLCAqLCBDdXN0b21Ob2RlKFNsaWNlWygwLCAxMjgpXSwgW10pLCBDdXN0b21Ob2RlKFNsaWNlWygwLCA0KV0sIFtdKSkpLCAoMSwgMSwgMTI4LCA0KSwgKCkpXSwgWyosICpdKSwpKV0AdHJhbnNmb3JtXzAAdHJhbnNmb3JtXzEAdHJhbnNmb3JtXzIAdHJhbnNmb3JtXzMAdHJhbnNwb3NlX2xocwB0cmFuc3Bvc2VfcmhzAGtpbmQAcmVkdWN0aW9uX2RpbXMAL2Jyb2FkY2FzdF9pbl9kaW1bc2hhcGU9KDEyOCwgMSkgYnJvYWRjYXN0X2RpbWVuc2lvbnM9KDAsKV0AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2tlcm5lbABfZmxhc2hfYXR0ZW50aW9uX2ltcGwAX2ZsYXNoX2F0dGVudGlvbgBmbGFzaF9hdHRlbnRpb24APG1vZHVsZT4AL21udC9kaXNrcy9zc2Qvd29yay9wYWxsYXMvcGFsbGFzX2FkZC5weQAvZG90X2dlbmVyYWxbZGltZW5zaW9uX251bWJlcnM9KCgoMSwpLCAoMSwpKSwgKCgpLCAoKSkpIHByZWNpc2lvbj1Ob25lIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3JlZHVjZV9tYXhbYXhlcz0oMSwpXQAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW1bYXhlcz0oMSwpXQAvZGl2AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgwLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPU5vbmUgcHJlZmVycmVkX2VsZW1lbnRfdHlwZT1mbG9hdDMyXQAvc3dhcFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"needs_layout_passes\": true}}"
+
+    q = torch.ones(3, 2, 128, 4).to("xla")
+    k = torch.ones(3, 2, 128, 4).to("xla")
+    v = torch.ones(3, 2, 128, 4).to("xla")
+    o = torch.zeros(3, 2, 128, 4).to("xla")
+
+    torch_xla._XLAC._xla_tpu_custom_call_(o, [q, k, v], payload)
+    hlo = torch_xla._XLAC._get_xla_tensors_hlo([o])
+    print(hlo)
+    print(o)
+
 
 class MNISTComparator(nn.Module):
 
diff --git a/torch_xla/csrc/layout_manager.cpp b/torch_xla/csrc/layout_manager.cpp
index b488acbaefc1..e42ea769c13a 100644
--- a/torch_xla/csrc/layout_manager.cpp
+++ b/torch_xla/csrc/layout_manager.cpp
@@ -143,13 +143,16 @@ xla::Shape MakeTpuShape(absl::Span<const int64_t> dimensions,
   static double max_padding_factor =
       runtime::sys_util::GetEnvDouble("XLA_MAX_PADDING_FACTOR", 1.25);
   xla::Shape shape;
-  if (PaddingFactor(dimensions[dimensions.size() - 1], 128) *
-          PaddingFactor(dimensions[dimensions.size() - 2], 8) <
-      max_padding_factor) {
-    shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(type, dimensions);
-  } else {
-    shape = MakeShapeWithSortedLayout(dimensions, type);
-  }
+  // if (PaddingFactor(dimensions[dimensions.size() - 1], 128) *
+  //         PaddingFactor(dimensions[dimensions.size() - 2], 8) <
+  //     max_padding_factor) {
+  //   std::cerr << "ff7: layout A" << std::endl;
+  //   shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(type, dimensions);
+  // } else {
+  //   std::cerr << "ff7: layout B" << std::endl;
+  //   shape = MakeShapeWithSortedLayout(dimensions, type);
+  // }
+  shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(type, dimensions);
   SetDynamicDimensions(&shape, dynamic_dimensions);
   return shape;
 }
@@ -181,10 +184,12 @@ xla::Shape MakeArrayShapeFromDimensions(
     XlaDeviceType hw_type) {
   auto layout_ptr = LayoutManager::Get()->GetLayout(dimensions);
   if (layout_ptr != nullptr) {
+    std::cerr << "ff7: layout 1" << std::endl;
     return MakeShapeWithLayout(type, dimensions, dynamic_dimensions,
                                *layout_ptr);
   }
   if (dimensions.size() > 1 && hw_type == XlaDeviceType::TPU) {
+    std::cerr << "ff7: layout 2" << std::endl;
     return MakeTpuShape(dimensions, dynamic_dimensions, type);
   }
   return MakeTorchTensorLayout(dimensions, dynamic_dimensions, type);
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index e5af6a37dd55..c99fb02da57d 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1247,7 +1247,11 @@ xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
   std::vector<xla::Shape> input_shapes;
   input_shapes.reserve(inputs.size());
   for (const auto& input : inputs) {
-    input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
+    auto shape = ShapeHelper::ShapeOfXlaOp(input);
+    // shape.mutable_layout()->clear_minor_to_major();
+    std::cerr << "ff7:" << shape << std::endl;
+
+    input_shapes.push_back(std::move(shape));
   }
 
   XLA_CHECK(inputs.size() > 0) << "inputs are empty";
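
For reference, a payload like the one hard-coded in the new test can be regenerated from the Pallas kernel linked in the test comment. Below is a minimal sketch, not part of the patch: it assumes a JAX install that ships jax/experimental/pallas/ops/tpu/flash_attention.py (as at the linked commit) and a TPU-capable host, and that the serialized kernel appears as the backend_config attached to the tpu_custom_call op in the lowered StableHLO, which is the string the test then passes to torch_xla._XLAC._xla_tpu_custom_call_.

# Sketch only (assumptions noted above): regenerate a Mosaic payload comparable
# to the one hard-coded in test_tpu_custom_call_pallas_flash_attention.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

# Same shapes as the torch_xla test: (batch, heads, seq_len, head_dim).
q = jnp.ones((3, 2, 128, 4), dtype=jnp.float32)
k = jnp.ones((3, 2, 128, 4), dtype=jnp.float32)
v = jnp.ones((3, 2, 128, 4), dtype=jnp.float32)

# Lower without executing and print the StableHLO; the custom_call targeting
# "tpu_custom_call" should carry the serialized kernel in its backend_config,
# which can be copied into the test's `payload` string.
print(jax.jit(flash_attention).lower(q, k, v).as_text())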