diff --git a/test/stablehlo/test_mark_pattern.py b/test/stablehlo/test_mark_pattern.py index bd1bcd4467e..a045b76e20f 100644 --- a/test/stablehlo/test_mark_pattern.py +++ b/test/stablehlo/test_mark_pattern.py @@ -356,6 +356,34 @@ def test_update_kv_cache(self): self.assertEqual( shlo_text.count("stablehlo.composite \"test.update_kv_cache\""), 1) + def test_composite_builder_list_attr_value(self): + + class M(torch.nn.Module): + + def forward(self, x, y): + builder = StableHLOCompositeBuilder( + "test.add", { + "int_arr": [1, 2, 3], + "float_arr": [1.0, 1.1, 1.2], + "bool_arr": [True, False] + }) + x, y = builder.mark_inputs(x, y) + z = x + y + z = builder.mark_outputs(z) + return z + + input_args = (torch.randn((5, 5)), torch.randn((5, 5))) + stablehlo = self.run_func_get_stablehlo(M(), input_args) + self.assertEqual(stablehlo.count("stablehlo.composite \"test.add\""), 1) + self.assertTrue( + stablehlo.count("bool_arr = dense<[true, false]> : tensor<2xi1>"), 1) + self.assertTrue( + stablehlo.count( + "float_arr = dense<[1.000000e+00, 1.100000e+00, 1.200000e+00]> : tensor<3xf32>" + ), 1) + self.assertTrue( + stablehlo.count("int_arr = dense<[1, 2, 3]> : tensor<3xi64>"), 1) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc index c2681dbca0a..3bc9e9cc309 100644 --- a/torch_xla/csrc/runtime/stablehlo_composite_helper.cc +++ b/torch_xla/csrc/runtime/stablehlo_composite_helper.cc @@ -190,35 +190,78 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { return metadata; } + mlir::FailureOr BuildAttrFromJson(mlir::OpBuilder& builder, + mlir::Operation* op, + const json& json_value) { + switch (json_value.type()) { + case json::value_t::number_integer: + case json::value_t::number_unsigned: + return builder.getI64IntegerAttr(json_value.template get()); + case json::value_t::number_float: + return builder.getF32FloatAttr(json_value.template get()); + case json::value_t::boolean: + return builder.getBoolAttr(json_value.template get()); + case json::value_t::string: + return builder.getStringAttr(json_value.template get()); + case json::value_t::array: { + if (json_value.empty()) { + return builder.getArrayAttr({}); + } + auto get_json_type = [](const json& j) { + auto ty = j.type(); + if (ty == json::value_t::number_unsigned) { + return json::value_t::number_integer; + } + return ty; + }; + + auto head_type = get_json_type(json_value[0]); + bool is_homogeneous = llvm::all_of(json_value, [&](auto& el) { + return get_json_type(el) == head_type; + }); + if (!is_homogeneous) { + return op->emitError() + << "invalid JSON to MLIR, arrays must be homogeneous"; + } + + switch (head_type) { + case json::value_t::number_integer: + return builder.getI64TensorAttr( + json_value.template get>()); + case json::value_t::number_float: + return mlir::DenseFPElementsAttr::get( + mlir::RankedTensorType::get(json_value.size(), + builder.getF32Type()), + json_value.template get>()); + case json::value_t::boolean: + return mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get(json_value.size(), + builder.getI1Type()), + json_value.template get>()); + default: + return op->emitError() + << "invalid JSON to MLIR: invalid array type. arrays must " + "be " + "1-D homogeneous arrays of supported primitive types"; + } + } + default: + return op->emitError() + << "invalid JSON to MLIR: unsupported json value type"; + } + } + mlir::FailureOr BuildDictionaryAttrFromJsonMap( - mlir::OpBuilder& builder, + mlir::OpBuilder& builder, mlir::Operation* op, const std::unordered_map& json_map) { llvm::SmallVector named_attrs; for (auto& [key, j] : json_map) { - switch (j.type()) { - case json::value_t::number_integer: - case json::value_t::number_unsigned: - named_attrs.push_back( - {builder.getStringAttr(key), - builder.getI64IntegerAttr(j.template get())}); - break; - case json::value_t::number_float: - named_attrs.push_back( - {builder.getStringAttr(key), - builder.getF32FloatAttr(j.template get())}); - break; - case json::value_t::boolean: - named_attrs.push_back({builder.getStringAttr(key), - builder.getBoolAttr(j.template get())}); - break; - case json::value_t::string: - named_attrs.push_back( - {builder.getStringAttr(key), - builder.getStringAttr(j.template get())}); - break; - default: - return mlir::failure(); + mlir::FailureOr attribute_or = + BuildAttrFromJson(builder, op, j); + if (mlir::failed(attribute_or)) { + return mlir::failure(); } + named_attrs.push_back({builder.getStringAttr(key), *attribute_or}); } return builder.getDictionaryAttr(named_attrs); } @@ -430,7 +473,8 @@ class BuildStableHLOCompositePass : public mlir::OperationPass { mlir::OpBuilder builder(context); mlir::FailureOr attributes_or = - BuildDictionaryAttrFromJsonMap(builder, metadata.attrs); + BuildDictionaryAttrFromJsonMap(builder, boundary_output_op, + metadata.attrs); if (mlir::failed(attributes_or)) { return boundary_output_op->emitError() << "failed to transform boundary attr " diff --git a/torch_xla/experimental/xla_marker.py b/torch_xla/experimental/xla_marker.py index e15552f9ed6..adc77f8f4cd 100644 --- a/torch_xla/experimental/xla_marker.py +++ b/torch_xla/experimental/xla_marker.py @@ -40,10 +40,16 @@ def _assert_valid_composite_attr(attr): for k, v in attr.items(): if not isinstance(k, str): raise ValueError("Composite attr name must be a Python str.") - if type(k) not in (str, float, int, bool): - raise ValueError( - "Composite attr value must be either Python str, float, int, or bool." - ) + + invalid_attr_value_error = ValueError( + "Composite attr value must be either Python str, float, int, bool, list[int], list[float], list[bool]." + ) + if isinstance(v, (list, tuple)): + eltys = {type(el) for el in v} + if len(eltys) > 1 or next(iter(eltys)) not in (int, float, bool): + raise invalid_attr_value_error + elif type(v) not in (str, float, int, bool): + raise invalid_attr_value_error @torchdynamo.assume_constant_result