Skip to content

Commit

Permalink
Support array attribute in stablehlo composite (#6840)
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc authored Mar 29, 2024
1 parent c1db875 commit 7bbe9d7
Showing 3 changed files with 107 additions and 29 deletions.
28 changes: 28 additions & 0 deletions test/stablehlo/test_mark_pattern.py
Original file line number Diff line number Diff line change
@@ -356,6 +356,34 @@ def test_update_kv_cache(self):
self.assertEqual(
shlo_text.count("stablehlo.composite \"test.update_kv_cache\""), 1)

def test_composite_builder_list_attr_value(self):
  """Array-valued composite attrs (list[int]/list[float]/list[bool]) lower to
  dense tensor attributes on the stablehlo.composite op."""

  class M(torch.nn.Module):

    def forward(self, x, y):
      builder = StableHLOCompositeBuilder(
          "test.add", {
              "int_arr": [1, 2, 3],
              "float_arr": [1.0, 1.1, 1.2],
              "bool_arr": [True, False]
          })
      x, y = builder.mark_inputs(x, y)
      z = x + y
      z = builder.mark_outputs(z)
      return z

  input_args = (torch.randn((5, 5)), torch.randn((5, 5)))
  stablehlo = self.run_func_get_stablehlo(M(), input_args)
  self.assertEqual(stablehlo.count("stablehlo.composite \"test.add\""), 1)
  # BUG FIX: these were `assertTrue(count(...), 1)`, where the `1` is taken
  # as the failure *message*, so the assertion only checked truthiness.
  # assertEqual pins the exact occurrence count, as intended.
  self.assertEqual(
      stablehlo.count("bool_arr = dense<[true, false]> : tensor<2xi1>"), 1)
  self.assertEqual(
      stablehlo.count(
          "float_arr = dense<[1.000000e+00, 1.100000e+00, 1.200000e+00]> : tensor<3xf32>"
      ), 1)
  self.assertEqual(
      stablehlo.count("int_arr = dense<[1, 2, 3]> : tensor<3xi64>"), 1)


if __name__ == '__main__':
  # Discover and run every test in this module.
  test = unittest.main()
94 changes: 69 additions & 25 deletions torch_xla/csrc/runtime/stablehlo_composite_helper.cc
Original file line number Diff line number Diff line change
@@ -190,35 +190,78 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
return metadata;
}

// Translates one JSON value into an MLIR attribute: scalars map to scalar
// attributes (i64 / f32 / i1 / string) and 1-D homogeneous arrays map to
// dense tensor attributes. `op` is used only to anchor error diagnostics.
mlir::FailureOr<mlir::Attribute> BuildAttrFromJson(mlir::OpBuilder& builder,
                                                   mlir::Operation* op,
                                                   const json& json_value) {
  const json::value_t value_kind = json_value.type();

  // Scalar cases: both signed and unsigned JSON integers lower to i64.
  if (value_kind == json::value_t::number_integer ||
      value_kind == json::value_t::number_unsigned) {
    return builder.getI64IntegerAttr(json_value.template get<int64_t>());
  }
  if (value_kind == json::value_t::number_float) {
    return builder.getF32FloatAttr(json_value.template get<float>());
  }
  if (value_kind == json::value_t::boolean) {
    return builder.getBoolAttr(json_value.template get<bool>());
  }
  if (value_kind == json::value_t::string) {
    return builder.getStringAttr(json_value.template get<std::string>());
  }
  if (value_kind != json::value_t::array) {
    return op->emitError()
           << "invalid JSON to MLIR: unsupported json value type";
  }

  // Array case. An empty array becomes an empty ArrayAttr.
  if (json_value.empty()) {
    return builder.getArrayAttr({});
  }

  // Unsigned elements are folded into the signed-integer category so that a
  // mix of signed and unsigned integers still counts as homogeneous.
  const auto canonical_kind = [](const json& element) {
    const json::value_t kind = element.type();
    return kind == json::value_t::number_unsigned
               ? json::value_t::number_integer
               : kind;
  };

  const json::value_t element_kind = canonical_kind(json_value[0]);
  for (const json& element : json_value) {
    if (canonical_kind(element) != element_kind) {
      return op->emitError()
             << "invalid JSON to MLIR, arrays must be homogeneous";
    }
  }

  switch (element_kind) {
    case json::value_t::number_integer:
      return builder.getI64TensorAttr(
          json_value.template get<llvm::SmallVector<int64_t>>());
    case json::value_t::number_float:
      return mlir::DenseFPElementsAttr::get(
          mlir::RankedTensorType::get(json_value.size(), builder.getF32Type()),
          json_value.template get<llvm::SmallVector<float>>());
    case json::value_t::boolean:
      return mlir::DenseIntElementsAttr::get(
          mlir::RankedTensorType::get(json_value.size(), builder.getI1Type()),
          json_value.template get<llvm::SmallVector<bool>>());
    default:
      return op->emitError()
             << "invalid JSON to MLIR: invalid array type. arrays must be "
                "1-D homogeneous arrays of supported primitive types";
  }
}

// Converts a map of (attr name -> JSON value) into an MLIR DictionaryAttr by
// delegating each value to BuildAttrFromJson. `op` is only the anchor for
// error diagnostics. Fails if any single value fails to convert.
//
// NOTE(review): the pasted diff interleaved the removed pre-change switch
// with the added delegation code (and duplicated the `builder` parameter
// line); this is the reconstructed post-commit function.
mlir::FailureOr<mlir::DictionaryAttr> BuildDictionaryAttrFromJsonMap(
    mlir::OpBuilder& builder, mlir::Operation* op,
    const std::unordered_map<std::string, json>& json_map) {
  llvm::SmallVector<mlir::NamedAttribute> named_attrs;
  for (auto& [key, j] : json_map) {
    mlir::FailureOr<mlir::Attribute> attribute_or =
        BuildAttrFromJson(builder, op, j);
    if (mlir::failed(attribute_or)) {
      return mlir::failure();
    }
    named_attrs.push_back({builder.getStringAttr(key), *attribute_or});
  }
  return builder.getDictionaryAttr(named_attrs);
}
@@ -430,7 +473,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
mlir::OpBuilder builder(context);

mlir::FailureOr<mlir::DictionaryAttr> attributes_or =
BuildDictionaryAttrFromJsonMap(builder, metadata.attrs);
BuildDictionaryAttrFromJsonMap(builder, boundary_output_op,
metadata.attrs);
if (mlir::failed(attributes_or)) {
return boundary_output_op->emitError()
<< "failed to transform boundary attr "
14 changes: 10 additions & 4 deletions torch_xla/experimental/xla_marker.py
Original file line number Diff line number Diff line change
@@ -40,10 +40,16 @@ def _assert_valid_composite_attr(attr):
for k, v in attr.items():
  if not isinstance(k, str):
    raise ValueError("Composite attr name must be a Python str.")

  invalid_attr_value_error = ValueError(
      "Composite attr value must be either Python str, float, int, bool, list[int], list[float], list[bool]."
  )
  if isinstance(v, (list, tuple)):
    # Element types present in the sequence. All elements must share one
    # type, and that type must be int, float, or bool.
    eltys = {type(el) for el in v}
    # BUG FIX: for an empty list `eltys` is empty, so the original
    # `next(iter(eltys))` raised StopIteration. Empty arrays are valid —
    # the C++ lowering emits an empty ArrayAttr for them — so guard with
    # `eltys and ...` to accept them.
    if len(eltys) > 1 or (eltys and
                          next(iter(eltys)) not in (int, float, bool)):
      raise invalid_attr_value_error
  elif type(v) not in (str, float, int, bool):
    raise invalid_attr_value_error


@torchdynamo.assume_constant_result

0 comments on commit 7bbe9d7

Please sign in to comment.