
Commit: tmp
alanwaketan committed Mar 1, 2024
1 parent 169633f commit e3f2424
Showing 3 changed files with 34 additions and 9 deletions.
18 changes: 17 additions & 1 deletion test/test_operations.py
@@ -44,7 +44,7 @@
import torch_xla.core.xla_model as xm
import torch_xla.core.functions as xf
import torch_xla.debug.profiler as xp
-import torchvision
+# import torchvision
import unittest
import test_utils

@@ -1896,6 +1896,22 @@ def test_tpu_custom_call_pallas_raise(self):
torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload)
output.cpu()

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_flash_attention(self):
# This payload is generated by the following Pallas code:
# https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTkuMC4wZ2l0AAFBDQEDBQcJCwEDDQMFAw8FAxEHBxMVFwkLGRsdHyELAyMDpgI2AhsB9QcTCwsPEwsPDxMLCwsLkwsTCw8TDwsLCwsPCwsPDwsLDwsPDw8PExMTQwsbC8ULkwsLCwsbGwsbCxsLGwsbGxsbDw8PDxcPCxcPDwsXDw8LFw8PCxcPCxMLDw8XEx8LDw8XGw8THwsPFxsPDwsPFwsPCxMfCw8XGw8LBQeNkWEHA1kJBV1JARsTExcTExcXHwsTFxsLARsPBysfGwcXIw8LGy8vAs4MHwMDDYUFJQUnFY2THTICUQUpHSOHHSOzHSMOAgUrBS0FLwUxIxEJQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAAQAAAAAAAAADRkDAw2DBTMREwADA9/7EREBBTUFNwU5BTsdvb8FPQU/Hc07Fc8JBUEFQwED1QVFHdlHFdsJHelLFesJHfMCAh0iAlEVJgIJAw9VVxVZXV9hKWMpF2VnaQVHAQn19fX3DRdhZmZpbmVfbWFwPChkMCwgZDEsIGQyLCBkMykgLT4gKGQwLCBkMSwgZDIsIGQzKT4ABUkjEQlBAwAAAAAAAAACAAAAAAAAAAEAAAAAAAAAAQAAAAAAAAAFSwVNBU8FUQEJa29zdwMFGW0bHQkrAwUZcRsdCS0DBRl1Gx0JLwMFGXkbHQkxAwUVHxcrAwUVHxctAwUVHxcvAwUVHxcxEQEBEQMBFYkJHQeLFwUaCAEdj5EFUxcFSgUBFZWbHZeZBVUXBaoLARWdox2foQVXFwViAwEVpasdp6kFWRcFGgMBHa2vBVsXsVEBBV0VtQkdB7cXBR4IAQMDDbslBwkAAAAABV8VwQkdB8MXBSIIAQMFNSU3xxETAQMDDcslDQkAAID/BWEdB9EXBbYIAQMFPf0/QRERBR1DOwVjHQfdFwW6CAEFZR3jRwVnAwMN5yUNCQAAAAAFaR0H7RcFvggBAwU9/z9BHUNLBWsjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACNhcml0aC5mYXN0bWF0aDxub25lPgAjdmVjdG9yLmtpbmQ8bWF4aW11bWY+ACN2ZWN0b3Iua2luZDxhZGQ+ABUGAgkdBwoCFwXCCAEVEgIJHQcWAhcF3ggBAwMNHgIlCQkAAAAABW0dByoCFwXiCAEDBTUlNyUFbwECAgMX+QkFBQIEEQtbJwUCBAIECycFAgQRCwsnAwIECycJBQUCBBELAQIEAQknBQIEBQsFEQEBAQEFBQUFAQUJAQEBAQkBAQEBBD4HBQEQAQcDARUDEQFTBwNhqxEBAQEBAQEBAQUBBQEFAQUBCQMPAwMDCQMPAwMDCQMPAwMDCQMPAwMDEQYPAw8LCRETFRcPBg8DCQMZCQMRAwMDCQMRAwMDCQMRAwMDCQMRAwMDEQYRAw8LCx0fISMPBhEDCQMlCQMzuQMHBwczxQMHBxsnKQkDOckDDRMHOdMDDQUrLQ8G1wMVAy8VBkUDBwMxCwdFJwMHBSszGQfhJwMHAzUJA0nlAw0TB0nvAw0FNzkPBvEDFQM7FQZNAwcDPQ0HTScDBwU3PwkDEwMDAwkDEwMDAwkDEwMDAwkDEwMDAxEGEwMPCw1DRUdJDwYTAwkDSwkDTxoCAwkHB08uAgMJB0FNTwkDCwMDAwkDCwMDAwkDCwMDAwkDCwMDAxEGCwMPCw9TVVdZDwYLAwkDWw8GCwMPA1EXBAsNXw9TVVdZBQABAxEBewcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMFCQMRAX0HAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBwkDEQF/BwMLCwkBAQEBAQEBAQkDASEDAQUEAQkBAwcJAxEBgQcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMFCQYDAQUBAFoVcYYCzwsvCxMLL89TEyEjLTEdCyMhIyl5HwsdHRkZGRmCAh0lEx0NY8cJDRUhCxcLCxMPDw8LDw0JCxFidWlsdGluAGZ1bmMAdHB1AGFyaXRoAHZlY3RvcgBtYXRoAG1vZHVsZQByZXR1cm4AbWF0bXVsAGNvbnN0YW50AHN1YmYAZGl2ZgBzaGFwZV9jYXN0AGxvYWQAbXVsdGlfcmVkdWN0aW9uAGJyb2FkY2FzdABzdG9yZQBleHAAL2hvbWUvand0YW4vLmxvY2FsL2xpYi9weXRob24zLjEwL3NpdGUtcGFja2FnZXMvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9mbGFzaF9hdHRlbnRpb24ucHkAX2ZsYXNoX2F0dGVudGlvbl9rZXJuZWxfc2luZ2xlX2JhdGNoX3NpbmdsZV9zdGVwAHZhbHVlAGZ1bmN0aW9uX3R5cGUAc3ltX25hbWUAdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKCgqLCAqLCBDdXN0b21Ob2RlKFNsaWNlWygwLCAxMjgpXSwgW10pLCBDdXN0b21Ob2RlKFNsaWNlWygwLCA0KV0sIFtdKSkpLCAoMSwgMSwgMTI4LCA0KSwgKCkpXSwgWyosICpdKSwpKV0AdHJhbnNmb3JtXzAAdHJhbnNmb3JtXzEAdHJhbnNmb3JtXzIAdHJhbnNmb3JtXzMAdHJhbnNwb3NlX2xocwB0cmFuc3Bvc2VfcmhzAGtpbmQAcmVkdWN0aW9uX2RpbXMAL2Jyb2FkY2FzdF9pbl9kaW1bc2hhcGU9KDEyOCwgMSkgYnJvYWRjYXN0X2RpbWVuc2lvbnM9KDAsKV0AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2tlcm5lbABfZmxhc2hfYXR0ZW50aW9uX2ltcGwAX2ZsYXNoX2F0dGVudGlvbgBmbGFzaF9hdHRlbnRpb24APG1vZHVsZT4AL21udC9kaXNrcy9zc2Qvd29yay9wYWxsYXMvcGFsbGFzX2FkZC5weQAvZG90X2dlbmVyYWxbZGltZW5zaW9uX251bWJlcnM9KCgoMSwpLCAoMSwpKSwgKCgpLCAoKSkpIHByZWNpc2lvbj1Ob25lIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3JlZHVjZV9tYXhbYXhlcz0oMSwpXQAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW1bYXhlcz0oMSwpXQAvZGl2AC9kb3RfZ2
VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgwLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPU5vbmUgcHJlZmVycmVkX2VsZW1lbnRfdHlwZT1mbG9hdDMyXQAvc3dhcFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"needs_layout_passes\": true}}"

q = torch.ones(3, 2, 128, 4).to("xla")
k = torch.ones(3, 2, 128, 4).to("xla")
v = torch.ones(3, 2, 128, 4).to("xla")
o = torch.zeros(3, 2, 128, 4).to("xla")

torch_xla._XLAC._xla_tpu_custom_call_(o, [q, k, v], payload)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([o])
print(hlo)
print(o)


class MNISTComparator(nn.Module):

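Note: the Pallas payload above is checked in as an opaque string. Below is a minimal sketch (not part of this commit) of how such a string can be reproduced, assuming a JAX installation with TPU support and assuming the serialized Mosaic module appears in the StableHLO lowering of the flash-attention kernel referenced in the test comment; only the module path comes from that comment, the rest is illustrative.

# Sketch: lower the JAX Pallas flash-attention kernel and dump its StableHLO.
# The tpu_custom_call in the dump is assumed to carry the serialized Mosaic
# module that the test hard-codes as `payload`.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

# Shapes mirror the q/k/v tensors used in the test above.
q = jnp.ones((3, 2, 128, 4), dtype=jnp.float32)
k = jnp.ones((3, 2, 128, 4), dtype=jnp.float32)
v = jnp.ones((3, 2, 128, 4), dtype=jnp.float32)

# Lower without executing and print the IR text.
lowered = jax.jit(flash_attention).lower(q, k, v)
print(lowered.as_text())
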
19 changes: 12 additions & 7 deletions torch_xla/csrc/layout_manager.cpp
@@ -143,13 +143,16 @@ xla::Shape MakeTpuShape(absl::Span<const int64_t> dimensions,
static double max_padding_factor =
runtime::sys_util::GetEnvDouble("XLA_MAX_PADDING_FACTOR", 1.25);
xla::Shape shape;
-if (PaddingFactor(dimensions[dimensions.size() - 1], 128) *
-PaddingFactor(dimensions[dimensions.size() - 2], 8) <
-max_padding_factor) {
-shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(type, dimensions);
-} else {
-shape = MakeShapeWithSortedLayout(dimensions, type);
-}
+// if (PaddingFactor(dimensions[dimensions.size() - 1], 128) *
+// PaddingFactor(dimensions[dimensions.size() - 2], 8) <
+// max_padding_factor) {
+// std::cerr << "ff7: layout A" << std::endl;
+// shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(type, dimensions);
+// } else {
+// std::cerr << "ff7: layout B" << std::endl;
+// shape = MakeShapeWithSortedLayout(dimensions, type);
+// }
+shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(type, dimensions);
SetDynamicDimensions(&shape, dynamic_dimensions);
return shape;
}
@@ -181,10 +184,12 @@ xla::Shape MakeArrayShapeFromDimensions(
XlaDeviceType hw_type) {
auto layout_ptr = LayoutManager::Get()->GetLayout(dimensions);
if (layout_ptr != nullptr) {
+std::cerr << "ff7: layout 1" << std::endl;
return MakeShapeWithLayout(type, dimensions, dynamic_dimensions,
*layout_ptr);
}
if (dimensions.size() > 1 && hw_type == XlaDeviceType::TPU) {
+std::cerr << "ff7: layout 2" << std::endl;
return MakeTpuShape(dimensions, dynamic_dimensions, type);
}
return MakeTorchTensorLayout(dimensions, dynamic_dimensions, type);
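Note: the branch commented out above chose between two TPU layouts based on how much padding the two minor dimensions would need; this commit forces the descending layout unconditionally. Below is a rough Python sketch of that heuristic (not part of this commit), assuming PaddingFactor(size, padding) is the ratio between the size rounded up to a multiple of `padding` and the original size (its real definition lives elsewhere in layout_manager.cpp).

# Sketch of the disabled MakeTpuShape heuristic, under the assumption stated
# above about what PaddingFactor computes.
import math

def padding_factor(size: int, padding: int) -> float:
    # Ratio of the size rounded up to a multiple of `padding` over the size.
    return math.ceil(size / padding) * padding / size

def pick_tpu_layout(dimensions, max_padding_factor=1.25):
    # XLA_MAX_PADDING_FACTOR defaults to 1.25 in the code above.
    cost = padding_factor(dimensions[-1], 128) * padding_factor(dimensions[-2], 8)
    return "descending" if cost < max_padding_factor else "sorted"

# For the [3, 2, 128, 4] tensors in the flash-attention test, the minor dim 4
# pads to 128 (factor 32), so the old heuristic picked the sorted layout; the
# commit now uses the descending layout for such shapes instead.
print(pick_tpu_layout([3, 2, 128, 4]))  # "sorted" under the old heuristic
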
6 changes: 5 additions & 1 deletion torch_xla/csrc/xla_lower_util.cpp
@@ -1247,7 +1247,11 @@ xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
std::vector<xla::Shape> input_shapes;
input_shapes.reserve(inputs.size());
for (const auto& input : inputs) {
-input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
+auto shape = ShapeHelper::ShapeOfXlaOp(input);
+// shape.mutable_layout()->clear_minor_to_major();
+std::cerr << "ff7:" << shape << std::endl;
+
+input_shapes.push_back(std::move(shape));
}

XLA_CHECK(inputs.size() > 0) << "inputs are empty";
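Note: the std::cerr prints above dump each operand shape (including its layout) that BuildTpuCustomCall receives. Below is a sketch (not part of this commit) of getting roughly the same information from Python by scanning the dumped HLO text, assuming the op shows up there as a "custom-call" line.

# Sketch: trace the custom call and print the custom-call lines of the HLO.
import torch
import torch_xla

payload = "..."  # the Pallas payload string from the test above

q = torch.ones(3, 2, 128, 4).to("xla")
k = torch.ones(3, 2, 128, 4).to("xla")
v = torch.ones(3, 2, 128, 4).to("xla")
o = torch.zeros(3, 2, 128, 4).to("xla")
torch_xla._XLAC._xla_tpu_custom_call_(o, [q, k, v], payload)

hlo = torch_xla._XLAC._get_xla_tensors_hlo([o])
for line in hlo.splitlines():
    if "custom-call" in line:
        print(line)  # operand shapes/layouts of the TPU custom call
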
