Add Pattern Boundary Marking API (#5930)
Co-authored-by: Siyuan Liu <[email protected]>
Co-authored-by: Chunnien Chan <[email protected]>
3 people authored Dec 12, 2023
1 parent e3012bf commit 8243b50
Showing 16 changed files with 962 additions and 1 deletion.
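For context: this commit adds a pattern boundary marking API, so a marked region of the lazily traced graph is emitted as a stablehlo.composite call during StableHLO export (the tests below assert on "@stablehlo.composite" in the exported text). A minimal usage sketch, based on test_basic in test/stablehlo/test_mark_pattern.py; the torch_xla.experimental.xla_marker import, presumably added elsewhere in this commit and not shown on this page, is what registers the torch.ops.xla_pattern_marking.mark_tensor op:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_marker  # registers torch.ops.xla_pattern_marking.mark_tensor

def f(x):
  # Mark the pattern input: (tensor, pattern name, argument position, instance id, is_input).
  x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
  x = x + 2
  # Mark the pattern output with is_input=False.
  x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
  return x

out = f(torch.randn(5).to(xm.xla_device()))
print(xm.get_stablehlo([out]))  # the marked region shows up as a stablehlo.composite named "p"

For multi-tensor patterns, the StableHLOCompositeBuilder helper exercised in the tests below wraps these calls.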
8 changes: 8 additions & 0 deletions WORKSPACE
@@ -16,6 +16,14 @@ http_archive(
urls = ["https://github.com/pybind/pybind11/archive/442261da585536521ff459b1457b2904895f23b4.tar.gz"],
)

http_archive(
name = "com_nlohmann_json",
build_file = "//bazel:nlohmann_json.BUILD",
sha256 = "d69f9deb6a75e2580465c6c4c5111b89c4dc2fa94e3a85fcd2ffcd9a143d9273",
strip_prefix = "json-3.11.2",
url = "https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz",
)

load("@pybind11_bazel//:python_configure.bzl", "python_configure")

# This is required for setting up the linkopts for -lpython.
9 changes: 9 additions & 0 deletions bazel/nlohmann_json.BUILD
@@ -0,0 +1,9 @@
cc_library(
name = "json",
hdrs = [
"single_include/nlohmann/json.hpp",
"single_include/nlohmann/json_fwd.hpp",
],
includes = ["single_include"],
visibility = ["//visibility:public"],
)
183 changes: 183 additions & 0 deletions test/stablehlo/test_mark_pattern.py
@@ -0,0 +1,183 @@
import sys
import unittest

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_marker
from torch.utils import _pytree as pytree
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder


class XlaMarkPatternTest(unittest.TestCase):

def run_func_get_stablehlo(self, f, input_args):

device = xm.xla_device()
input_args = pytree.tree_map_only(torch.Tensor,
lambda x: x.to(device=device), input_args)
out = f(*input_args)
if isinstance(out, tuple):
out = list(out)
else:
out = [out]
stablehlo = xm.get_stablehlo(out)
return stablehlo

def test_basic(self):

def f(x):
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
x = x + 2
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", False)
return x

input_args = (torch.randn(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)

def test_sdpa_pattern(self):
import torch.nn.functional as F

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
q, k, v = x.split(128, dim=-2)
q = torch.ops.xla_pattern_marking.mark_tensor(
q, "sdpa", pos=0, id="0", is_input=True)
k = torch.ops.xla_pattern_marking.mark_tensor(
k, "sdpa", pos=1, id="0", is_input=True)
v = torch.ops.xla_pattern_marking.mark_tensor(
v, "sdpa", pos=2, id="0", is_input=True)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = torch.ops.xla_pattern_marking.mark_tensor(
attn_out,
"sdpa",
pos=0,
id="0",
is_input=False,
attr={"scale": 0.25})
q, k, v = y.split(128, dim=-2)
q = torch.ops.xla_pattern_marking.mark_tensor(
q, "sdpa", pos=0, id="1", is_input=True)
k = torch.ops.xla_pattern_marking.mark_tensor(
k, "sdpa", pos=1, id="1", is_input=True)
v = torch.ops.xla_pattern_marking.mark_tensor(
v, "sdpa", pos=2, id="1", is_input=True)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
attn_out2 = torch.ops.xla_pattern_marking.mark_tensor(
attn_out2, "sdpa", pos=0, id="1", is_input=False, attr={"scale": 2})
return attn_out, attn_out2

input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
self.assertTrue(
'{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
stablehlo)
self.assertTrue(
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)

def test_composite_builder_sdpa_pattern(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
b = StableHLOCompositeBuilder("sdpa", {"scale": 0.25})
q, k, v = x.split(128, dim=-2)
q, k, v = b.mark_inputs(q, k, v)
attn_out = F.scaled_dot_product_attention(q, k, v, scale=0.25)
attn_out = b.mark_outputs(attn_out)

b2 = StableHLOCompositeBuilder("sdpa", {"scale": 2})
q, k, v = y.split(128, dim=-2)
q, k, v = b2.mark_inputs(q, k, v)
attn_out2 = F.scaled_dot_product_attention(q, k, v, scale=4)
attn_out2 = b2.mark_outputs(attn_out2)
return attn_out, attn_out2

input_args = (torch.randn((32, 8, 384, 64)), torch.randn((32, 8, 384, 64)))
stablehlo = self.run_func_get_stablehlo(M(), input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 2)
self.assertTrue(
'{attributes = {scale = 2.500000e-01 : f32}, name = "sdpa"}}' in
stablehlo)
self.assertTrue(
'{attributes = {scale = 2 : i64}, name = "sdpa"}}' in stablehlo)

def test_multiple_input(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
out = x + y
out = out * x * y
out = torch.ops.xla_pattern_marking.mark_tensor(out, "p", 0, "0", False)
return out

input_args = (torch.ones(5), torch.ones(5))
stablehlo = self.run_func_get_stablehlo(f, input_args)
self.assertEqual(stablehlo.count("@stablehlo.composite"), 1)
self.assertTrue('{attributes = {}, name = "p"}' in stablehlo)

@unittest.skip("Multiple outputs patterns are not supported now.")
def test_multiple_output(self):

def f(x, y):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p", 1, "0", True)
out1 = x + y
out2 = x * y
out1 = torch.ops.xla_pattern_marking.mark_tensor(out1, "p", 0, "0", False)
out2 = torch.ops.xla_pattern_marking.mark_tensor(out2, "p", 1, "0", False)
return out1, out2

input_args = (torch.ones(5), torch.ones(5))
stablehlo = self.run_func_get_stablehlo(f, input_args)

@unittest.skip("Nested pattern is not supported now.")
def test_nested_pattern(self):

def f(x):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
x = x * 2
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0",
False)
return x

input_args = (torch.ones(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)

@unittest.skip("Nested pattern is not supported now.")
def test_tangent_output(self):
# Special case of nested pattern, outputs don't have dependencies.
def f(x):
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_outter", 0, "0", True)
x = x + 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", True)
x = x + 1
y = x - 1
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p_inner", 0, "0", False)
y = torch.ops.xla_pattern_marking.mark_tensor(y, "p_outter", 0, "0",
False)
return x, y

input_args = (torch.ones(5),)
stablehlo = self.run_func_get_stablehlo(f, input_args)


if __name__ == '__main__':
  test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
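
The StableHLOCompositeBuilder used in test_composite_builder_sdpa_pattern plays the same role as the raw mark_tensor calls in test_sdpa_pattern: one builder instance per pattern instance, mark_inputs on the boundary inputs, mark_outputs on the boundary outputs, with the attribute dict attached to the resulting composite. A short sketch on a hypothetical two-input pattern, using only the methods exercised in the tests above:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder

class MulPattern(torch.nn.Module):

  def forward(self, x, y):
    # One builder per pattern instance; the attribute dict ends up on the composite.
    builder = StableHLOCompositeBuilder("mul", {"scale": 0.5})
    x, y = builder.mark_inputs(x, y)
    out = x * y * 0.5
    return builder.mark_outputs(out)

device = xm.xla_device()
out = MulPattern()(torch.ones(5, device=device), torch.ones(5, device=device))
print(xm.get_stablehlo([out]))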
17 changes: 16 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
@@ -740,6 +740,12 @@ void MapXlaEnvVarsToLazy() {
runtime::sys_util::GetEnvInt("XLA_TRIM_GRAPH_SIZE", 100000);
}

at::Tensor MarkTensor(const at::Tensor& input, const std::string& info) {
XLATensorPtr result =
tensor_methods::mark_tensor(bridge::GetXlaTensor(input), info);
return bridge::AtenFromXlaTensor(std::move(result));
}

std::string GetPyTypeString(py::handle obj) {
std::string type = obj.attr("__class__").attr("__name__").cast<std::string>();
return type;
@@ -2172,7 +2178,16 @@ void InitXlaModuleBindings(py::module m) {
}
return handles;
});

m.def("_xla_mark_tensor",
[](const at::Tensor& input, const std::string& info) {
TORCH_LAZY_COUNTER("XlaMarkTensor", 1);
at::Tensor result;
{
NoGilSection nogil;
result = MarkTensor(input, info);
}
return result;
});
m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) {
TORCH_LAZY_COUNTER("XlaMarkDynamic", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
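The new _xla_mark_tensor binding takes a tensor plus an opaque string and bumps the XlaMarkTensor counter. A hypothetical direct call, for debugging only; the Python op is the intended entry point, and the JSON layout of the info string is an assumption here (suggested by the nlohmann_json dependency added in this commit, not shown in this diff):

import json
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

x = torch.randn(5).to(xm.xla_device())
# Assumed payload; the real schema lives in the Python op registration (not shown here).
info = json.dumps({"name": "p", "pos": 0, "id": "0", "is_input": True})
x = torch_xla._XLAC._xla_mark_tensor(x, info)
print(met.counter_value("XlaMarkTensor"))  # incremented once per call by the binding above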
35 changes: 35 additions & 0 deletions torch_xla/csrc/ops/mark_tensor.cpp
@@ -0,0 +1,35 @@
#include "torch_xla/csrc/ops/mark_tensor.h"

#include <torch/csrc/lazy/core/tensor_util.h>

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/shape_helper.h"

namespace torch_xla {

MarkTensor::MarkTensor(const torch::lazy::Value& input, const std::string& info)
: XlaNode(xla_mark_tensor, {input}, GetXlaShape(input),
/*num_outputs=*/1, torch::lazy::MHash(info)),
info_(info) {}

torch::lazy::NodePtr MarkTensor::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<MarkTensor>(operands.at(0), info_);
}

XlaOpVector MarkTensor::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input);
static const std::string opname = "xla_mark_tensor";
xla::XlaOp output =
xla::CustomCall(input.builder(), opname, {input}, input_shape, info_);
return ReturnOp(output, loctx);
}

std::string MarkTensor::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", info=" << info_;
return ss.str();
}

} // namespace torch_xla
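
Before export, the marker is visible in the lazy IR as an xla::mark_tensor node whose ToString carries the info payload, so dumping the IR text is an easy way to confirm the mark was recorded. A sketch; _get_xla_tensors_text is a pre-existing debug binding, not part of this commit:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_marker

x = torch.randn(5).to(xm.xla_device())
x = torch.ops.xla_pattern_marking.mark_tensor(x, "p", 0, "0", True)
# Expect an xla::mark_tensor node annotated with "info=..." per MarkTensor::ToString above.
print(torch_xla._XLAC._get_xla_tensors_text([x]))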
24 changes: 24 additions & 0 deletions torch_xla/csrc/ops/mark_tensor.h
@@ -0,0 +1,24 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_
#define XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class MarkTensor : public XlaNode {
public:
MarkTensor(const torch::lazy::Value& input, const std::string& info);

std::string ToString() const override;

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

private:
std::string info_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_MARK_TENSOR_H_
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.cpp
@@ -15,6 +15,7 @@ const OpKindWrapper xla_diagonal_view_update("xla::diagonal_view_update");
const OpKindWrapper xla_einsum_backward("xla::einsum_backward");
const OpKindWrapper xla_generic_slice("xla::generic_slice");
const OpKindWrapper xla_get_dimensions_size("xla::xla_get_dimensions_size");
const OpKindWrapper xla_mark_tensor("xla::mark_tensor");
const OpKindWrapper xla_moving_average("xla::moving_average");
const OpKindWrapper xla_nms("xla::nms");
const OpKindWrapper xla_not_supported("xla::not_supported");
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.h
@@ -41,6 +41,7 @@ extern const OpKindWrapper xla_diagonal_view_update;
extern const OpKindWrapper xla_einsum_backward;
extern const OpKindWrapper xla_generic_slice;
extern const OpKindWrapper xla_get_dimensions_size;
extern const OpKindWrapper xla_mark_tensor;
extern const OpKindWrapper xla_moving_average;
extern const OpKindWrapper xla_nms;
extern const OpKindWrapper xla_not_supported;
13 changes: 13 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -266,13 +266,26 @@ cc_library(
],
)

cc_library(
name = "stablehlo_composite_helper",
srcs = ["stablehlo_composite_helper.cc"],
hdrs = ["stablehlo_composite_helper.h"],
deps = [
":types",
":xla_util",
"@com_nlohmann_json//:json",
"@xla//xla/mlir_hlo:all_passes",
],
)

cc_library(
name = "stablehlo_helper",
srcs = ["stablehlo_helper.cc"],
hdrs = ["stablehlo_helper.h"],
deps = [
":types",
":xla_util",
":stablehlo_composite_helper",
"@stablehlo//:stablehlo_portable_api",
"@stablehlo//:stablehlo_serialization",
"@xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
