From 95a7d1fca69573d144f806bc3f6dee8c6487f907 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Mon, 6 Nov 2023 12:01:53 +0000 Subject: [PATCH] core: Implement dialect attribute opaque syntax. (#1726) Implement the dialect attribute opaque syntax, i.e., `#arith>` Adds an OpaqueSyntaxAttribute simple class to inherit from on Attributes meant to be printed using that syntax, just like TypeAttribute for using `!` instead of `#` More MLIR compliancy! --------- Co-authored-by: Mathieu Fehr --- docs/irdl.ipynb | 8 +- docs/tutorial.ipynb | 4 +- tests/dialects/test_builtin.py | 2 +- tests/dialects/test_gpu.py | 18 +-- .../filecheck/dialects/arith/arith_attrs.mlir | 3 +- tests/filecheck/dialects/builtin/attrs.mlir | 12 +- .../builtin/dense_array_invalid_element.mlir | 2 +- tests/filecheck/dialects/gpu/invalid.mlir | 4 +- .../filecheck/dialects/riscv_snitch/ops.mlir | 4 +- .../filecheck/dialects/snitch/snitch_ops.mlir | 20 +-- .../snitch/snitch_to_riscv_lowering.mlir | 18 +-- .../convert_snitch_stream_to_snitch.mlir | 22 ++-- .../filecheck/dialects/snitch_stream/ops.mlir | 24 ++-- .../dialects/stencil/stencil_ops.mlir | 4 +- .../parser-printer/builtin_attrs.mlir | 6 - .../parser-printer/unregistered_dialect.mlir | 14 +- .../add_snitch_stream.mlir | 8 +- .../relu_snitch_stream.mlir | 6 +- .../snitch_register_allocation.mlir | 12 +- tests/interpreters/test_wgsl_printer.py | 8 +- tests/test_attribute_definition.py | 106 +++++++++------ tests/test_operation_definition.py | 2 +- tests/test_printer.py | 6 +- xdsl/dialects/builtin.py | 10 +- xdsl/dialects/gpu.py | 106 +++++---------- .../interpreters/experimental/wgsl_printer.py | 10 +- xdsl/ir/core.py | 37 +++--- xdsl/irdl/irdl.py | 2 +- xdsl/parser/attribute_parser.py | 121 +++++++++++++----- xdsl/printer.py | 25 +++- 30 files changed, 346 insertions(+), 278 deletions(-) diff --git a/docs/irdl.ipynb b/docs/irdl.ipynb index 0a629fb1d7..3b7296c184 100644 --- a/docs/irdl.ipynb +++ b/docs/irdl.ipynb @@ -233,7 +233,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "#int<3> should be of base attribute string\n" + "#builtin.int<3> should be of base attribute string\n" ] } ], @@ -588,7 +588,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\"ga\" should be of base attribute int\n" + "\"ga\" should be of base attribute builtin.int\n" ] } ], @@ -694,7 +694,7 @@ "output_type": "stream", "text": [ "In integer_type attribute verifier: 1 parameters expected, got 2\n", - "\"ga\" should be of base attribute int\n" + "\"ga\" should be of base attribute builtin.int\n" ] } ], @@ -1381,7 +1381,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "\"string_attr_op\"() {\"value\" = \"ga\", \"other_attr\" = #int<42>} : () -> ()" + "\"string_attr_op\"() {\"value\" = \"ga\", \"other_attr\" = #builtin.int<42>} : () -> ()" ] } ], diff --git a/docs/tutorial.ipynb b/docs/tutorial.ipynb index 51963820b6..9e8f443b9f 100644 --- a/docs/tutorial.ipynb +++ b/docs/tutorial.ipynb @@ -156,7 +156,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "#int<42>" + "#builtin.int<42>" ] } ], @@ -246,7 +246,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "#int<64>" + "#builtin.int<64>" ] } ], diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index c49ecaf48f..a8c1b46820 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -309,7 +309,7 @@ def verify( context = getattr(e, "__context__") assert "fail" in context - with pytest.raises(VerifyException, match="wrapped #int<1>") as e: + with pytest.raises(VerifyException, match="wrapped #builtin.int<1>") as e: outer = CustomErrorMessageAttrConstraint(inner, lambda k: f"wrapped {k}") outer.verify(one, {}) assert hasattr(e, "__context__") diff --git a/tests/dialects/test_gpu.py b/tests/dialects/test_gpu.py index e93b8c2551..9978efe38e 100644 --- a/tests/dialects/test_gpu.py +++ b/tests/dialects/test_gpu.py @@ -3,7 +3,7 @@ from xdsl.dialects.gpu import ( AllocOp, AllReduceOp, - AllReduceOperationAttr, + AllReduceOpAttr, AsyncTokenType, BarrierOp, BlockDimOp, @@ -34,7 +34,7 @@ def test_dimension(): - dim = DimensionAttr.from_dimension("x") + dim = DimensionAttr("x") assert dim.data == "x" @@ -79,13 +79,13 @@ def test_alloc(): def test_all_reduce_operation(): - op = AllReduceOperationAttr.from_op("add") + op = AllReduceOpAttr("add") assert op.data == "add" def test_all_reduce(): - op = AllReduceOperationAttr.from_op("add") + op = AllReduceOpAttr("add") init = arith.Constant.from_int_and_width(0, builtin.IndexType()) @@ -120,7 +120,7 @@ def test_barrier(): def test_block_dim(): - dim = DimensionAttr.from_dimension("x") + dim = DimensionAttr("x") block_dim = BlockDimOp(dim) @@ -129,7 +129,7 @@ def test_block_dim(): def test_block_id(): - dim = DimensionAttr.from_dimension("x") + dim = DimensionAttr("x") block_id = BlockIdOp(dim) @@ -181,7 +181,7 @@ def test_gpu_module_end(): def test_global_id(): - dim = DimensionAttr.from_dimension("x") + dim = DimensionAttr("x") global_id = GlobalIdOp(dim) @@ -190,7 +190,7 @@ def test_global_id(): def test_grid_dim(): - dim = DimensionAttr.from_dimension("x") + dim = DimensionAttr("x") grid_dim = GridDimOp(dim) @@ -395,7 +395,7 @@ def test_subgroup_size(): def test_thread_id(): - dim = DimensionAttr.from_dimension("x") + dim = DimensionAttr("x") thread_id = ThreadIdOp(dim) diff --git a/tests/filecheck/dialects/arith/arith_attrs.mlir b/tests/filecheck/dialects/arith/arith_attrs.mlir index 8ad2b31593..e5800aa6d6 100644 --- a/tests/filecheck/dialects/arith/arith_attrs.mlir +++ b/tests/filecheck/dialects/arith/arith_attrs.mlir @@ -4,7 +4,8 @@ "test.op"() {attrs = [ #arith.fastmath, - // CHECK: #arith.fastmath + #arith>, + // CHECK: #arith.fastmath, #arith.fastmath #arith.fastmath, // CHECK-SAME: #arith.fastmath #arith.fastmath, diff --git a/tests/filecheck/dialects/builtin/attrs.mlir b/tests/filecheck/dialects/builtin/attrs.mlir index bae1bb3e54..423c9eda8a 100644 --- a/tests/filecheck/dialects/builtin/attrs.mlir +++ b/tests/filecheck/dialects/builtin/attrs.mlir @@ -19,12 +19,12 @@ %x7 = "arith.constant"() {"value" = 0 : i64, "test" = array} : () -> i64 // CHECK: "test" = array %x8 = "arith.constant"() {"value" = 0 : i64, "test" = array} : () -> i64 - // CHECK: "test" = #signedness - %x9 = "arith.constant"() {"value" = 0 : i64, "test" = #signedness} : () -> i64 - // CHECK: "test" = #signedness - %x10 = "arith.constant"() {"value" = 0 : i64, "test" = #signedness} : () -> i64 - // CHECK: "test" = #signedness - %x11 = "arith.constant"() {"value" = 0 : i64, "test" = #signedness} : () -> i64 + // CHECK: "test" = #builtin.signedness + %x9 = "arith.constant"() {"value" = 0 : i64, "test" = #builtin.signedness} : () -> i64 + // CHECK: "test" = #builtin.signedness + %x10 = "arith.constant"() {"value" = 0 : i64, "test" = #builtin.signedness} : () -> i64 + // CHECK: "test" = #builtin.signedness + %x11 = "arith.constant"() {"value" = 0 : i64, "test" = #builtin.signedness} : () -> i64 // CHECK: "test" = @foo %x12 = "arith.constant"() {"value" = 0 : i64, "test" = @foo} : () -> i64 // CHECK: "test" = @foo::@bar diff --git a/tests/filecheck/dialects/builtin/dense_array_invalid_element.mlir b/tests/filecheck/dialects/builtin/dense_array_invalid_element.mlir index e6363437ad..72fb02ee2e 100644 --- a/tests/filecheck/dialects/builtin/dense_array_invalid_element.mlir +++ b/tests/filecheck/dialects/builtin/dense_array_invalid_element.mlir @@ -1,6 +1,6 @@ // RUN: xdsl-opt %s --parsing-diagnostics | filecheck %s -"builtin.module" () {"test" = array: 2, 5, 2>} ({ +"builtin.module" () {"test" = array<()->(): 2, 5, 2>} ({ }) // CHECK: dense array element type must be an integer or floating point type diff --git a/tests/filecheck/dialects/gpu/invalid.mlir b/tests/filecheck/dialects/gpu/invalid.mlir index b24b93c329..d26e656208 100644 --- a/tests/filecheck/dialects/gpu/invalid.mlir +++ b/tests/filecheck/dialects/gpu/invalid.mlir @@ -19,7 +19,7 @@ "builtin.module"() ({ }) {"wrong_all_reduce_operation" = #gpu}: () -> () -// CHECK: Unexpected op magic. A gpu all_reduce_op can only be add, and, max, min, mul, or, or xor +// CHECK: Expected add, and, max, min, mul, or, or xor. // ----- @@ -77,7 +77,7 @@ "builtin.module"() ({ }) {"wrong_dim" = #gpu}: () -> () -// CHECK: Unexpected dim w. A gpu dim can only be x, y, or z +// CHECK: Expected x, y or z. // ----- diff --git a/tests/filecheck/dialects/riscv_snitch/ops.mlir b/tests/filecheck/dialects/riscv_snitch/ops.mlir index 5de7a099f7..bf8ed7395c 100644 --- a/tests/filecheck/dialects/riscv_snitch/ops.mlir +++ b/tests/filecheck/dialects/riscv_snitch/ops.mlir @@ -37,10 +37,10 @@ riscv_func.func @main() { // CHECK-GENERIC-NEXT: %scfgwi_zero = "riscv_snitch.scfgwi"(%0) {"immediate" = 42 : si12} : (!riscv.reg<>) -> !riscv.reg // CHECK-GENERIC-NEXT: "riscv_snitch.frep_outer"(%{{.*}}) ({ // CHECK-GENERIC-NEXT: %{{.*}} = "riscv.add"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> -// CHECK-GENERIC-NEXT: }) {"stagger_mask" = #int<0>, "stagger_count" = #int<0>} : (!riscv.reg<>) -> () +// CHECK-GENERIC-NEXT: }) {"stagger_mask" = #builtin.int<0>, "stagger_count" = #builtin.int<0>} : (!riscv.reg<>) -> () // CHECK-GENERIC-NEXT: "riscv_snitch.frep_inner"(%{{.*}}) ({ // CHECK-GENERIC-NEXT: %{{.*}} = "riscv.add"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> -// CHECK-GENERIC-NEXT: }) {"stagger_mask" = #int<0>, "stagger_count" = #int<0>} : (!riscv.reg<>) -> () +// CHECK-GENERIC-NEXT: }) {"stagger_mask" = #builtin.int<0>, "stagger_count" = #builtin.int<0>} : (!riscv.reg<>) -> () // CHECK-GENERIC-NEXT: "riscv_func.return"() : () -> () // CHECK-GENERIC-NEXT: }) {"sym_name" = "main", "function_type" = () -> ()} : () -> () // CHECK-GENERIC-NEXT: }) : () -> () diff --git a/tests/filecheck/dialects/snitch/snitch_ops.mlir b/tests/filecheck/dialects/snitch/snitch_ops.mlir index 9a78db234c..21deca2338 100644 --- a/tests/filecheck/dialects/snitch/snitch_ops.mlir +++ b/tests/filecheck/dialects/snitch/snitch_ops.mlir @@ -5,16 +5,16 @@ %stride = "test.op"() : () -> !riscv.reg<> %rep = "test.op"() : () -> !riscv.reg<> // Usual SSR setup sequence: - "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - // CHECK: "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - // CHECK: "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - "snitch.ssr_set_dimension_source"(%addr) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - // CHECK: "snitch.ssr_set_dimension_source"(%addr) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - // CHECK: "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () - "snitch.ssr_set_stream_repetition"(%rep) {"dm" = #int<0>} : (!riscv.reg<>) -> () - // CHECK: "snitch.ssr_set_stream_repetition"(%rep) {"dm" = #int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + // CHECK: "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + // CHECK: "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_source"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + // CHECK: "snitch.ssr_set_dimension_source"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + // CHECK: "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_stream_repetition"(%rep) {"dm" = #builtin.int<0>} : (!riscv.reg<>) -> () + // CHECK: "snitch.ssr_set_stream_repetition"(%rep) {"dm" = #builtin.int<0>} : (!riscv.reg<>) -> () "snitch.ssr_enable"() : () -> () // CHECK-NEXT: "snitch.ssr_enable"() : () -> () "snitch.ssr_disable"() : () -> () diff --git a/tests/filecheck/dialects/snitch/snitch_to_riscv_lowering.mlir b/tests/filecheck/dialects/snitch/snitch_to_riscv_lowering.mlir index f51cd8143f..29087f87c2 100644 --- a/tests/filecheck/dialects/snitch/snitch_to_riscv_lowering.mlir +++ b/tests/filecheck/dialects/snitch/snitch_to_riscv_lowering.mlir @@ -5,32 +5,32 @@ builtin.module { %stride = riscv.li 4 : () -> !riscv.reg<> %rep = riscv.li 0 : () -> !riscv.reg<> // SSR setup sequence for dimension 0 - "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 64 : () -> !riscv.reg<> // CHECK-NEXT: %{{.*}} = riscv_snitch.scfgw %bound, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg - "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 192 : () -> !riscv.reg<> // CHECK-NEXT: %{{.*}} riscv_snitch.scfgw %stride, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg - "snitch.ssr_set_dimension_source"(%addr) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_source"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 768 : () -> !riscv.reg<> // %{{.*}} = riscv_snitch.scfgw %addr, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg - "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #int<0>, "dimension" = #int<0>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 896 : () -> !riscv.reg<> // CHECK-NEXT: %{{.*}} = riscv_snitch.scfgw %addr, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg // SSR setup sequence for dimension 3 - "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #int<0>, "dimension" = #int<3>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_bound"(%bound) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<3>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 160 : () -> !riscv.reg<> // CHECK-NEXT: %{{.*}} = riscv_snitch.scfgw %bound, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg - "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #int<0>, "dimension" = #int<3>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_stride"(%stride) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<3>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 288 : () -> !riscv.reg<> // CHECK-NEXT: %{{.*}} = riscv_snitch.scfgw %stride, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg - "snitch.ssr_set_dimension_source"(%addr) {"dm" = #int<0>, "dimension" = #int<3>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_source"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<3>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 864 : () -> !riscv.reg<> // riscv_snitch.scfgw %addr, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg - "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #int<0>, "dimension" = #int<3>} : (!riscv.reg<>) -> () + "snitch.ssr_set_dimension_destination"(%addr) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<3>} : (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 992 : () -> !riscv.reg<> // CHECK-NEXT: %{{.*}} = riscv_snitch.scfgw %addr, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg - "snitch.ssr_set_stream_repetition"(%rep) {"dm" = #int<0>}: (!riscv.reg<>) -> () + "snitch.ssr_set_stream_repetition"(%rep) {"dm" = #builtin.int<0>}: (!riscv.reg<>) -> () // CHECK: %{{.*}} = riscv.li 32 : () -> !riscv.reg<> // CHECK-NEXT: %{{.*}} = riscv_snitch.scfgw %rep, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg // On/Off switching sequence diff --git a/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir b/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir index f966e6c7f5..db2a84fc53 100644 --- a/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir +++ b/tests/filecheck/dialects/snitch_stream/convert_snitch_stream_to_snitch.mlir @@ -5,31 +5,31 @@ %A, %B, %C = "test.op"() : () -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>) // CHECK-NEXT: %A, %B, %C = "test.op"() : () -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -%0 = "snitch_stream.stride_pattern"() {"ub" = [#int<2>, #int<3>], "strides" = [#int<24>, #int<8>], "dm" = #int<31>} : () -> !snitch_stream.stride_pattern_type +%0 = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<2>, #builtin.int<3>], "strides" = [#builtin.int<24>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type // CHECK-NEXT: %0 = riscv.li 2 : () -> !riscv.reg<> // CHECK-NEXT: %1 = riscv.li 3 : () -> !riscv.reg<> // CHECK-NEXT: %2 = riscv.li 24 : () -> !riscv.reg<> // CHECK-NEXT: %3 = riscv.li 8 : () -> !riscv.reg<> // CHECK-NEXT: %4 = riscv.addi %0, -1 : (!riscv.reg<>) -> !riscv.reg<> // CHECK-NEXT: %5 = riscv.addi %1, -1 : (!riscv.reg<>) -> !riscv.reg<> -// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%4) {"dm" = #int<31>, "dimension" = #int<0>} : (!riscv.reg<>) -> () -// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%5) {"dm" = #int<31>, "dimension" = #int<1>} : (!riscv.reg<>) -> () -// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%2) {"dm" = #int<31>, "dimension" = #int<0>} : (!riscv.reg<>) -> () +// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%4) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () +// CHECK-NEXT: "snitch.ssr_set_dimension_bound"(%5) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> () +// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%2) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<0>} : (!riscv.reg<>) -> () // CHECK-NEXT: %6 = riscv.mul %4, %2 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> // CHECK-NEXT: %7 = riscv.sub %3, %6 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<> -// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%7) {"dm" = #int<31>, "dimension" = #int<1>} : (!riscv.reg<>) -> () +// CHECK-NEXT: "snitch.ssr_set_dimension_stride"(%7) {"dm" = #builtin.int<31>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> () -%1 = "snitch_stream.strided_read"(%A, %0) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-NEXT: "snitch.ssr_set_dimension_source"(%A) {"dm" = #int<0>, "dimension" = #int<1>} : (!riscv.reg<>) -> () +%1 = "snitch_stream.strided_read"(%A, %0) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: "snitch.ssr_set_dimension_source"(%A) {"dm" = #builtin.int<0>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> () // CHECK-NEXT: %a = riscv.get_float_register : () -> !riscv.freg -%2 = "snitch_stream.strided_read"(%B, %0) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-NEXT: "snitch.ssr_set_dimension_source"(%B) {"dm" = #int<1>, "dimension" = #int<1>} : (!riscv.reg<>) -> () +%2 = "snitch_stream.strided_read"(%B, %0) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: "snitch.ssr_set_dimension_source"(%B) {"dm" = #builtin.int<1>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> () // CHECK-NEXT: %b = riscv.get_float_register : () -> !riscv.freg -%3 = "snitch_stream.strided_write"(%C, %0) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> -// CHECK-NEXT: "snitch.ssr_set_dimension_destination"(%C) {"dm" = #int<2>, "dimension" = #int<1>} : (!riscv.reg<>) -> () +%3 = "snitch_stream.strided_write"(%C, %0) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +// CHECK-NEXT: "snitch.ssr_set_dimension_destination"(%C) {"dm" = #builtin.int<2>, "dimension" = #builtin.int<1>} : (!riscv.reg<>) -> () %4 = riscv.li 6 : () -> !riscv.reg<> diff --git a/tests/filecheck/dialects/snitch_stream/ops.mlir b/tests/filecheck/dialects/snitch_stream/ops.mlir index 9ea4716800..7778eeb103 100644 --- a/tests/filecheck/dialects/snitch_stream/ops.mlir +++ b/tests/filecheck/dialects/snitch_stream/ops.mlir @@ -3,10 +3,10 @@ %X, %Y, %Z, %n = "test.op"() : () -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -%pattern = "snitch_stream.stride_pattern"() {"ub" = [#int<8>, #int<16>], "strides" = [#int<128>, #int<8>], "dm" = #int<31>} : () -> !snitch_stream.stride_pattern_type -%X_str = "snitch_stream.strided_read"(%X, %pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -%Y_str = "snitch_stream.strided_read"(%Y, %pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -%Z_str = "snitch_stream.strided_write"(%Z, %pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +%pattern = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<8>, #builtin.int<16>], "strides" = [#builtin.int<128>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type +%X_str = "snitch_stream.strided_read"(%X, %pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +%Y_str = "snitch_stream.strided_read"(%Y, %pattern) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +%Z_str = "snitch_stream.strided_write"(%Z, %pattern) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> "snitch_stream.generic"(%n, %X_str, %Y_str, %Z_str) <{"operandSegmentSizes" = array}> ({ ^0(%x : !riscv.freg<>, %y : !riscv.freg<>): %z = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<> @@ -14,10 +14,10 @@ }) : (!riscv.reg<>, !stream.readable>, !stream.readable>, !stream.writable>) -> () // CHECK: %X, %Y, %Z, %n = "test.op"() : () -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -// CHECK-NEXT: %pattern = "snitch_stream.stride_pattern"() {"ub" = [#int<8>, #int<16>], "strides" = [#int<128>, #int<8>], "dm" = #int<31>} : () -> !snitch_stream.stride_pattern_type -// CHECK-NEXT: %X_str = "snitch_stream.strided_read"(%X, %pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-NEXT: %Y_str = "snitch_stream.strided_read"(%Y, %pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-NEXT: %Z_str = "snitch_stream.strided_write"(%Z, %pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +// CHECK-NEXT: %pattern = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<8>, #builtin.int<16>], "strides" = [#builtin.int<128>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type +// CHECK-NEXT: %X_str = "snitch_stream.strided_read"(%X, %pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: %Y_str = "snitch_stream.strided_read"(%Y, %pattern) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: %Z_str = "snitch_stream.strided_write"(%Z, %pattern) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> // CHECK-NEXT: "snitch_stream.generic"(%n, %X_str, %Y_str, %Z_str) <{"operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%x : !riscv.freg<>, %y : !riscv.freg<>): // CHECK-NEXT: %z = riscv.fadd.d %x, %y : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<> @@ -25,10 +25,10 @@ // CHECK-NEXT: }) : (!riscv.reg<>, !stream.readable>, !stream.readable>, !stream.writable>) -> () // CHECK-GENERIC: %X, %Y, %Z, %n = "test.op"() : () -> (!riscv.reg<>, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -// CHECK-GENERIC-NEXT: %pattern = "snitch_stream.stride_pattern"() {"ub" = [#int<8>, #int<16>], "strides" = [#int<128>, #int<8>], "dm" = #int<31>} : () -> !snitch_stream.stride_pattern_type -// CHECK-GENERIC-NEXT: %X_str = "snitch_stream.strided_read"(%X, %pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-GENERIC-NEXT: %Y_str = "snitch_stream.strided_read"(%Y, %pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-GENERIC-NEXT: %Z_str = "snitch_stream.strided_write"(%Z, %pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +// CHECK-GENERIC-NEXT: %pattern = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<8>, #builtin.int<16>], "strides" = [#builtin.int<128>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type +// CHECK-GENERIC-NEXT: %X_str = "snitch_stream.strided_read"(%X, %pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-GENERIC-NEXT: %Y_str = "snitch_stream.strided_read"(%Y, %pattern) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-GENERIC-NEXT: %Z_str = "snitch_stream.strided_write"(%Z, %pattern) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> // CHECK-GENERIC-NEXT: "snitch_stream.generic"(%n, %X_str, %Y_str, %Z_str) <{"operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^0(%x : !riscv.freg<>, %y : !riscv.freg<>): // CHECK-GENERIC-NEXT: %z = "riscv.fadd.d"(%x, %y) : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<> diff --git a/tests/filecheck/dialects/stencil/stencil_ops.mlir b/tests/filecheck/dialects/stencil/stencil_ops.mlir index e2b69e5532..979dd7cb44 100644 --- a/tests/filecheck/dialects/stencil/stencil_ops.mlir +++ b/tests/filecheck/dialects/stencil/stencil_ops.mlir @@ -180,7 +180,7 @@ builtin.module { %4 = "stencil.load"(%2) : (!stencil.field<[-4,68]x[-4,68]xf64>) -> !stencil.temp<[-1,65]x[-1,65]xf64> %5 = "stencil.apply"(%4) ({ ^0(%6 : !stencil.temp<[-1,65]x[-1,65]xf64>): - %7 = "stencil.access"(%6) {"offset" = #stencil.index<-1, 0>, "offset_mapping" = [#int<1>, #int<0>]} : (!stencil.temp<[-1,65]x[-1,65]xf64>) -> f64 + %7 = "stencil.access"(%6) {"offset" = #stencil.index<-1, 0>, "offset_mapping" = [#builtin.int<1>, #builtin.int<0>]} : (!stencil.temp<[-1,65]x[-1,65]xf64>) -> f64 "stencil.return"(%7) : (f64) -> () }) : (!stencil.temp<[-1,65]x[-1,65]xf64>) -> !stencil.temp<[0,64]x[0,64]xf64> "stencil.store"(%5, %3) {"lb" = #stencil.index<0, 0>, "ub" = #stencil.index<64, 64>} : (!stencil.temp<[0,64]x[0,64]xf64>, !stencil.field<[-4,68]x[-4,68]xf64>) -> () @@ -195,7 +195,7 @@ builtin.module { // CHECK-NEXT: %4 = "stencil.load"(%2) : (!stencil.field<[-4,68]x[-4,68]xf64>) -> !stencil.temp<[-1,65]x[-1,65]xf64> // CHECK-NEXT: %5 = "stencil.apply"(%4) ({ // CHECK-NEXT: ^0(%6 : !stencil.temp<[-1,65]x[-1,65]xf64>): -// CHECK-NEXT: %7 = "stencil.access"(%6) {"offset" = #stencil.index<-1, 0>, "offset_mapping" = [#int<1>, #int<0>]} : (!stencil.temp<[-1,65]x[-1,65]xf64>) -> f64 +// CHECK-NEXT: %7 = "stencil.access"(%6) {"offset" = #stencil.index<-1, 0>, "offset_mapping" = [#builtin.int<1>, #builtin.int<0>]} : (!stencil.temp<[-1,65]x[-1,65]xf64>) -> f64 // CHECK-NEXT: "stencil.return"(%7) : (f64) -> () // CHECK-NEXT: }) : (!stencil.temp<[-1,65]x[-1,65]xf64>) -> !stencil.temp<[0,64]x[0,64]xf64> // CHECK-NEXT: "stencil.store"(%5, %3) {"lb" = #stencil.index<0, 0>, "ub" = #stencil.index<64, 64>} : (!stencil.temp<[0,64]x[0,64]xf64>, !stencil.field<[-4,68]x[-4,68]xf64>) -> () diff --git a/tests/filecheck/parser-printer/builtin_attrs.mlir b/tests/filecheck/parser-printer/builtin_attrs.mlir index 83767d679f..5d74bffda5 100644 --- a/tests/filecheck/parser-printer/builtin_attrs.mlir +++ b/tests/filecheck/parser-printer/builtin_attrs.mlir @@ -238,12 +238,6 @@ // CHECK: "type_attr" = index - "func.func"() ({}) {function_type = () -> (), - type_attr = !index, - sym_name = "index_type_prefix"} : () -> () - - // CHECK: "type_attr" = index - "func.func"() ({}) {function_type = () -> (), strided = strided<[1, 0x23, -23, -0x21, ?], offset: -3>, sym_name = "strided"} : () -> () diff --git a/tests/filecheck/parser-printer/unregistered_dialect.mlir b/tests/filecheck/parser-printer/unregistered_dialect.mlir index 6ce97aa415..28ce28f675 100644 --- a/tests/filecheck/parser-printer/unregistered_dialect.mlir +++ b/tests/filecheck/parser-printer/unregistered_dialect.mlir @@ -3,15 +3,17 @@ "builtin.module"() ({ %0 = "region_op"() ({ - %y = "op_with_res"() {otherattr = #unknown_attr>>} : () -> (i32) - %z = "op_with_operands"(%y, %y) : (i32, i32) -> !unknown_type<{[<()>]}> - "op"() {ab = !unknown_singleton_type} : () -> () + %x = "op_with_res"() {otherattr = #unknowndialect.unknown_attr>>} : () -> (i32) + %y = "op_with_res"() {otherattr = #unknowndialect>>} : () -> (i32) + %z = "op_with_operands"(%y, %y) : (i32, i32) -> !unknowndialect.unknown_type<{[<()>]}> + "op"() {ab = !unknowndialect.unknown_singleton_type} : () -> () }) {testattr = "foo"} : () -> i32 // CHECK: %{{.*}} = "region_op"() ({ - // CHECK-NEXT: %{{.*}} = "op_with_res"() {"otherattr" = #unknown_attr>>} : () -> i32 - // CHECK-NEXT: %{{.*}} = "op_with_operands"(%{{.*}}, %{{.*}}) : (i32, i32) -> !unknown_type<{[<()>]}> - // CHECK-NEXT: "op"() {"ab" = !unknown_singleton_type} : () -> () + // CHECK-NEXT: %{{.*}} = "op_with_res"() {"otherattr" = #unknowndialect.unknown_attr>>} : () -> i32 + // CHECK-NEXT: %{{.*}} = "op_with_res"() {"otherattr" = #unknowndialect>>} : () -> i32 + // CHECK-NEXT: %{{.*}} = "op_with_operands"(%{{.*}}, %{{.*}}) : (i32, i32) -> !unknowndialect.unknown_type<{[<()>]}> + // CHECK-NEXT: "op"() {"ab" = !unknowndialect.unknown_singleton_type} : () -> () // CHECK-NEXT: }) {"testattr" = "foo"} : () -> i32 }) : () -> () diff --git a/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir b/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir index 5f591ddf69..65e3524e12 100644 --- a/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/add_snitch_stream.mlir @@ -16,10 +16,10 @@ builtin.module { %A = riscv.li "a" : () -> !riscv.reg<> %B = riscv.li "b" : () -> !riscv.reg<> %C = riscv.li "c" : () -> !riscv.reg<> - %0 = "snitch_stream.stride_pattern"() {"ub" = [#int<2>, #int<3>], "strides" = [#int<24>, #int<8>], "dm" = #int<31>} : () -> !snitch_stream.stride_pattern_type - %1 = "snitch_stream.strided_read"(%A, %0) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> - %2 = "snitch_stream.strided_read"(%B, %0) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> - %3 = "snitch_stream.strided_write"(%C, %0) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> + %0 = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<2>, #builtin.int<3>], "strides" = [#builtin.int<24>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type + %1 = "snitch_stream.strided_read"(%A, %0) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> + %2 = "snitch_stream.strided_read"(%B, %0) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> + %3 = "snitch_stream.strided_write"(%C, %0) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> %4 = riscv.li 6 : () -> !riscv.reg<> "snitch_stream.generic"(%4, %1, %2, %3) <{"operandSegmentSizes" = array}> ({ ^0(%a : !riscv.freg, %b : !riscv.freg): diff --git a/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir b/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir index 67bbb57236..28d3fedbc3 100644 --- a/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/relu_snitch_stream.mlir @@ -21,9 +21,9 @@ builtin.module { %zero_2 = riscv.li 0 : () -> !riscv.reg<> riscv.sw %zero, %zero_2, -8 : (!riscv.reg, !riscv.reg<>) -> () %zero_3 = riscv.fld %zero, -8 : (!riscv.reg) -> !riscv.freg<> - %stride_pattern = "snitch_stream.stride_pattern"() {"ub" = [#int<2>, #int<3>], "strides" = [#int<24>, #int<8>], "dm" = #int<31>} : () -> !snitch_stream.stride_pattern_type - %a_stream = "snitch_stream.strided_read"(%A, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> - %b_stream = "snitch_stream.strided_write"(%B, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> + %stride_pattern = "snitch_stream.stride_pattern"() {"ub" = [#builtin.int<2>, #builtin.int<3>], "strides" = [#builtin.int<24>, #builtin.int<8>], "dm" = #builtin.int<31>} : () -> !snitch_stream.stride_pattern_type + %a_stream = "snitch_stream.strided_read"(%A, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> + %b_stream = "snitch_stream.strided_write"(%B, %stride_pattern) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> %c6 = riscv.li 6 : () -> !riscv.reg<> "snitch_stream.generic"(%c6, %a_stream, %b_stream) <{"operandSegmentSizes" = array}> ({ ^0(%a : !riscv.freg): diff --git a/tests/filecheck/transforms/snitch_register_allocation.mlir b/tests/filecheck/transforms/snitch_register_allocation.mlir index d96ba1a146..ef33d0e84f 100644 --- a/tests/filecheck/transforms/snitch_register_allocation.mlir +++ b/tests/filecheck/transforms/snitch_register_allocation.mlir @@ -1,9 +1,9 @@ // RUN: xdsl-opt -p snitch-allocate-registers %s | filecheck %s %stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -%s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -%s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -%s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +%s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +%s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +%s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> %c128 = riscv.li 128 : () -> !riscv.reg<> "snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array}> ({ @@ -21,9 +21,9 @@ // CHECK: builtin.module { // CHECK-NEXT: %stride_pattern, %ptr0, %ptr1, %ptr2 = "test.op"() : () -> (!snitch_stream.stride_pattern_type, !riscv.reg<>, !riscv.reg<>, !riscv.reg<>) -// CHECK-NEXT: %s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #int<0>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-NEXT: %s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #int<1>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> -// CHECK-NEXT: %s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #int<2>, "rank" = #int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> +// CHECK-NEXT: %s0 = "snitch_stream.strided_read"(%ptr0, %stride_pattern) {"dm" = #builtin.int<0>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: %s1 = "snitch_stream.strided_read"(%ptr1, %stride_pattern) {"dm" = #builtin.int<1>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.readable> +// CHECK-NEXT: %s2 = "snitch_stream.strided_write"(%ptr2, %stride_pattern) {"dm" = #builtin.int<2>, "rank" = #builtin.int<2>} : (!riscv.reg<>, !snitch_stream.stride_pattern_type) -> !stream.writable> // CHECK-NEXT: %c128 = riscv.li 128 : () -> !riscv.reg<> // CHECK-NEXT: "snitch_stream.generic"(%c128, %s0, %s1, %s2) <{"operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%{{.*}} : !riscv.freg, %{{.*}} : !riscv.freg): diff --git a/tests/interpreters/test_wgsl_printer.py b/tests/interpreters/test_wgsl_printer.py index 493c7361f5..37cbbdff6c 100644 --- a/tests/interpreters/test_wgsl_printer.py +++ b/tests/interpreters/test_wgsl_printer.py @@ -14,7 +14,7 @@ def test_gpu_global_id(): file = StringIO("") - global_id_x = gpu.GlobalIdOp(gpu.DimensionAttr.from_dimension("x")) + global_id_x = gpu.GlobalIdOp(gpu.DimensionAttr("x")) printer = WGSLPrinter() printer.print(global_id_x, file) @@ -25,7 +25,7 @@ def test_gpu_global_id(): def test_gpu_thread_id(): file = StringIO("") - thread_id_x = gpu.ThreadIdOp(gpu.DimensionAttr.from_dimension("x")) + thread_id_x = gpu.ThreadIdOp(gpu.DimensionAttr("x")) printer = WGSLPrinter() printer.print(thread_id_x, file) @@ -36,7 +36,7 @@ def test_gpu_thread_id(): def test_gpu_block_id(): file = StringIO("") - block_id_x = gpu.BlockIdOp(gpu.DimensionAttr.from_dimension("x")) + block_id_x = gpu.BlockIdOp(gpu.DimensionAttr("x")) printer = WGSLPrinter() printer.print(block_id_x, file) @@ -47,7 +47,7 @@ def test_gpu_block_id(): def test_gpu_grid_dim(): file = StringIO("") - num_workgroups = gpu.GridDimOp(gpu.DimensionAttr.from_dimension("x")) + num_workgroups = gpu.GridDimOp(gpu.DimensionAttr("x")) printer = WGSLPrinter() printer.print(num_workgroups, file) diff --git a/tests/test_attribute_definition.py b/tests/test_attribute_definition.py index a36b020c15..d13b6cd996 100644 --- a/tests/test_attribute_definition.py +++ b/tests/test_attribute_definition.py @@ -27,6 +27,19 @@ from xdsl.printer import Printer from xdsl.utils.exceptions import PyRDLAttrDefinitionError, VerifyException + +def test_wrong_attribute_type(): + with pytest.raises( + TypeError, + match="Class AbstractAttribute should either be a subclass of 'Data' or 'ParametrizedAttribute'", + ): + + @irdl_attr_definition + class AbstractAttribute(Attribute): # pyright: ignore[reportUnusedClass] + name = "test.wrong" + pass + + ################################################################################ # Data attributes ################################################################################ @@ -36,7 +49,7 @@ class BoolData(Data[bool]): """An attribute holding a boolean value.""" - name = "bool" + name = "test.bool" @classmethod def parse_parameter(cls, parser: AttrParser) -> bool: @@ -56,7 +69,7 @@ def print_parameter(self, printer: Printer): class IntData(Data[int]): """An attribute holding an integer value.""" - name = "int" + name = "test.int" @classmethod def parse_parameter(cls, parser: AttrParser) -> int: @@ -72,7 +85,7 @@ def print_parameter(self, printer: Printer): class StringData(Data[str]): """An attribute holding a string value.""" - name = "str" + name = "test.str" @classmethod def parse_parameter(cls, parser: AttrParser) -> str: @@ -90,7 +103,7 @@ def test_simple_data(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(b) - assert stream.getvalue() == "#bool" + assert stream.getvalue() == "#test.bool" @irdl_attr_definition @@ -99,7 +112,7 @@ class IntListData(Data[list[int]]): An attribute holding a list of integers. """ - name = "int_list" + name = "test.int_list" @classmethod def parse_parameter(cls, parser: AttrParser) -> list[int]: @@ -118,7 +131,7 @@ def test_non_class_data(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#int_list<[0, 1, 42]>" + assert stream.getvalue() == "#test.int_list<[0, 1, 42]>" ################################################################################ @@ -168,7 +181,7 @@ def test_signless_integer_attr(): @irdl_attr_definition class BoolWrapperAttr(ParametrizedAttribute): - name = "bool_wrapper" + name = "test.bool_wrapper" param: ParameterDef[BoolData] @@ -179,14 +192,14 @@ def test_bose_constraint(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#bool_wrapper<#bool>" + assert stream.getvalue() == "#test.bool_wrapper<#test.bool>" def test_base_constraint_fail(): """Test the verifier of a union constraint.""" with pytest.raises(Exception) as e: BoolWrapperAttr([StringData("foo")]) - assert e.value.args[0] == "#str should be of base attribute bool" + assert e.value.args[0] == "#test.str should be of base attribute test.bool" ################################################################################ @@ -196,7 +209,7 @@ def test_base_constraint_fail(): @irdl_attr_definition class BoolOrIntParamAttr(ParametrizedAttribute): - name = "bool_or_int" + name = "test.bool_or_int" param: ParameterDef[BoolData | IntData] @@ -207,7 +220,7 @@ def test_union_constraint_left(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#bool_or_int<#bool>" + assert stream.getvalue() == "#test.bool_or_int<#test.bool>" def test_union_constraint_right(): @@ -216,14 +229,14 @@ def test_union_constraint_right(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#bool_or_int<#int<42>>" + assert stream.getvalue() == "#test.bool_or_int<#test.int<42>>" def test_union_constraint_fail(): """Test the verifier of a union constraint.""" with pytest.raises(Exception) as e: BoolOrIntParamAttr([StringData("foo")]) - assert e.value.args[0] == "Unexpected attribute #str" + assert e.value.args[0] == "Unexpected attribute #test.str" ################################################################################ @@ -243,7 +256,7 @@ def verify(self, attr: Attribute, constraint_vars: dict[str, Attribute]) -> None @irdl_attr_definition class PositiveIntAttr(ParametrizedAttribute): - name = "positive_int" + name = "test.positive_int" param: ParameterDef[Annotated[IntData, PositiveIntConstr()]] @@ -254,7 +267,7 @@ def test_annotated_constraint(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#positive_int<#int<42>>" + assert stream.getvalue() == "#test.positive_int<#test.int<42>>" def test_annotated_constraint_fail(): @@ -273,7 +286,7 @@ def test_annotated_constraint_fail(): @irdl_attr_definition class ParamWrapperAttr(Generic[_T], ParametrizedAttribute): - name = "int_or_bool_generic" + name = "test.int_or_bool_generic" param: ParameterDef[_T] @@ -284,7 +297,7 @@ def test_typevar_attribute_int(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#int_or_bool_generic<#int<42>>" + assert stream.getvalue() == "#test.int_or_bool_generic<#test.int<42>>" def test_typevar_attribute_bool(): @@ -293,19 +306,19 @@ def test_typevar_attribute_bool(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#int_or_bool_generic<#bool>" + assert stream.getvalue() == "#test.int_or_bool_generic<#test.bool>" def test_typevar_attribute_fail(): """Test that the verifier of an generic attribute can fail.""" with pytest.raises(Exception) as e: ParamWrapperAttr([StringData("foo")]) - assert e.value.args[0] == "Unexpected attribute #str" + assert e.value.args[0] == "Unexpected attribute #test.str" @irdl_attr_definition class ParamConstrAttr(ParametrizedAttribute): - name = "param_constr" + name = "test.param_constr" param: ParameterDef[ParamWrapperAttr[IntData]] @@ -316,7 +329,10 @@ def test_param_attr_constraint(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#param_constr<#int_or_bool_generic<#int<42>>>" + assert ( + stream.getvalue() + == "#test.param_constr<#test.int_or_bool_generic<#test.int<42>>>" + ) def test_param_attr_constraint_fail(): @@ -326,7 +342,7 @@ def test_param_attr_constraint_fail(): """ with pytest.raises(Exception) as e: ParamConstrAttr([ParamWrapperAttr([BoolData(True)])]) - assert e.value.args[0] == "#bool should be of base attribute int" + assert e.value.args[0] == "#test.bool should be of base attribute test.int" _U = TypeVar("_U", bound=IntData) @@ -334,7 +350,7 @@ def test_param_attr_constraint_fail(): @irdl_attr_definition class NestedParamWrapperAttr(Generic[_U], ParametrizedAttribute): - name = "nested_param_wrapper" + name = "test.nested_param_wrapper" param: ParameterDef[ParamWrapperAttr[_U]] @@ -348,7 +364,10 @@ def test_nested_generic_constraint(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#nested_param_wrapper<#int_or_bool_generic<#int<42>>>" + assert ( + stream.getvalue() + == "#test.nested_param_wrapper<#test.int_or_bool_generic<#test.int<42>>>" + ) def test_nested_generic_constraint_fail(): @@ -358,12 +377,12 @@ def test_nested_generic_constraint_fail(): """ with pytest.raises(Exception) as e: NestedParamWrapperAttr([ParamWrapperAttr([BoolData(True)])]) - assert e.value.args[0] == "#bool should be of base attribute int" + assert e.value.args[0] == "#test.bool should be of base attribute test.int" @irdl_attr_definition class NestedParamConstrAttr(ParametrizedAttribute): - name = "nested_param_constr" + name = "test.nested_param_constr" param: ParameterDef[NestedParamWrapperAttr[Annotated[IntData, PositiveIntConstr()]]] @@ -380,7 +399,7 @@ def test_nested_param_attr_constraint(): p.print_attribute(attr) assert ( stream.getvalue() - == "#nested_param_constr<#nested_param_wrapper<#int_or_bool_generic<#int<42>>>>" + == "#test.nested_param_constr<#test.nested_param_wrapper<#test.int_or_bool_generic<#test.int<42>>>>" ) @@ -404,7 +423,7 @@ def test_nested_param_attr_constraint_fail(): @irdl_attr_definition class MissingGenericDataData(Data[_MissingGenericDataData]): - name = "missing_genericdata" + name = "test.missing_genericdata" @classmethod def parse_parameter(cls, parser: AttrParser) -> _MissingGenericDataData: @@ -418,7 +437,7 @@ def verify(self) -> None: class MissingGenericDataDataWrapper(ParametrizedAttribute): - name = "missing_genericdata_wrapper" + name = "test.missing_genericdata_wrapper" param: ParameterDef[MissingGenericDataData[int]] @@ -431,7 +450,7 @@ def test_data_with_generic_missing_generic_data_failure(): with pytest.raises(Exception) as e: irdl_attr_definition(MissingGenericDataDataWrapper) assert e.value.args[0] == ( - "Generic `Data` type 'missing_genericdata' cannot be converted to " + "Generic `Data` type 'test.missing_genericdata' cannot be converted to " "an attribute constraint. Consider making it inherit from " "`GenericData` instead of `Data`." ) @@ -457,7 +476,7 @@ def verify(self, attr: Attribute, constraint_vars: dict[str, Attribute]) -> None @irdl_attr_definition class ListData(GenericData[list[A]]): - name = "list" + name = "test.list" @classmethod def parse_parameter(cls, parser: AttrParser) -> list[A]: @@ -499,12 +518,15 @@ def test_generic_data_verifier(self): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#list<[#bool, #list<[#bool]>]>" + assert ( + stream.getvalue() + == "#test.list<[#test.bool, #test.list<[#test.bool]>]>" + ) @irdl_attr_definition class ListDataWrapper(ParametrizedAttribute): - name = "list_wrapper" + name = "test.list_wrapper" val: ParameterDef[ListData[BoolData]] @@ -517,7 +539,10 @@ def test_generic_data_wrapper_verifier(): stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) - assert stream.getvalue() == "#list_wrapper<#list<[#bool, #bool]>>" + assert ( + stream.getvalue() + == "#test.list_wrapper<#test.list<[#test.bool, #test.bool]>>" + ) def test_generic_data_wrapper_verifier_failure(): @@ -527,12 +552,19 @@ def test_generic_data_wrapper_verifier_failure(): """ with pytest.raises(VerifyException) as e: ListDataWrapper([ListData([BoolData(True), ListData([BoolData(False)])])]) - assert e.value.args[0] == "#list<[#bool]> should be of base attribute bool" + assert ( + e.value.args[0] + == "#test.list<[#test.bool]> should be of base attribute test.bool" + ) + assert ( + e.value.args[0] + == "#test.list<[#test.bool]> should be of base attribute test.bool" + ) @irdl_attr_definition class ListDataNoGenericsWrapper(ParametrizedAttribute): - name = "list_no_generics_wrapper" + name = "test.list_no_generics_wrapper" val: ParameterDef[AnyListData] @@ -549,7 +581,7 @@ def test_generic_data_no_generics_wrapper_verifier(): p.print_attribute(attr) assert ( stream.getvalue() - == "#list_no_generics_wrapper<#list<[#bool, #list<[#bool]>]>>" + == "#test.list_no_generics_wrapper<#test.list<[#test.bool, #test.list<[#test.bool]>]>>" ) diff --git a/tests/test_operation_definition.py b/tests/test_operation_definition.py index 32cd31dbc9..9c8717b985 100644 --- a/tests/test_operation_definition.py +++ b/tests/test_operation_definition.py @@ -155,7 +155,7 @@ class AttrOp(IRDLOperation): def test_attr_verify(): op = AttrOp.create(attributes={"attr": IntAttr(1)}) with pytest.raises( - VerifyException, match="#int<1> should be of base attribute string" + VerifyException, match="#builtin.int<1> should be of base attribute string" ): op.verify() diff --git a/tests/test_printer.py b/tests/test_printer.py index 1662a3ff4e..aa9315b00b 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -637,7 +637,7 @@ def test_missing_custom_format(): @irdl_attr_definition class CustomFormatAttr(ParametrizedAttribute): - name = "custom" + name = "test.custom" attr: ParameterDef[IntAttr] @@ -668,13 +668,13 @@ def test_custom_format_attr(): """ prog = """\ "builtin.module"() ({ - "any"() {"attr" = #custom} : () -> () + "any"() {"attr" = #test.custom} : () -> () }) : () -> () """ expected = """\ "builtin.module"() ({ - "any"() {"attr" = #custom} : () -> () + "any"() {"attr" = #test.custom} : () -> () }) : () -> ()""" ctx = MLContext() diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 3f017c111c..8d221aff1e 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -266,7 +266,7 @@ def verify(self, attr: Attribute, constraint_vars: dict[str, Attribute]) -> None @irdl_attr_definition class IntAttr(Data[int]): - name = "int" + name = "builtin.int" @classmethod def parse_parameter(cls, parser: AttrParser) -> int: @@ -291,7 +291,7 @@ class Signedness(Enum): @irdl_attr_definition class SignednessAttr(Data[Signedness]): - name = "signedness" + name = "builtin.signedness" @classmethod def parse_parameter(cls, parser: AttrParser) -> Signedness: @@ -1315,6 +1315,7 @@ class UnregisteredAttr(ParametrizedAttribute, ABC): attr_name: ParameterDef[StringAttr] is_type: ParameterDef[IntAttr] + is_opaque: ParameterDef[IntAttr] value: ParameterDef[StringAttr] """ This parameter is non-null is the attribute is a type, and null otherwise. @@ -1324,15 +1325,18 @@ def __init__( self, attr_name: str | StringAttr, is_type: bool | IntAttr, + is_opaque: bool | IntAttr, value: str | StringAttr, ): if isinstance(attr_name, str): attr_name = StringAttr(attr_name) if isinstance(is_type, bool): is_type = IntAttr(int(is_type)) + if isinstance(is_opaque, bool): + is_opaque = IntAttr(int(is_opaque)) if isinstance(value, str): value = StringAttr(value) - super().__init__([attr_name, is_type, value]) + super().__init__([attr_name, is_type, is_opaque, value]) @classmethod def with_name_and_type(cls, name: str, is_type: bool) -> type[UnregisteredAttr]: diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 493c0c034a..9420f07467 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Generic, TypeVar +from typing import Literal, TypeVar from xdsl.dialects import memref from xdsl.dialects.builtin import ( @@ -18,7 +18,9 @@ from xdsl.ir import ( Attribute, Block, + Data, Dialect, + OpaqueSyntaxAttribute, Operation, OpResult, ParametrizedAttribute, @@ -66,85 +68,39 @@ class AsyncTokenType(ParametrizedAttribute, TypeAttribute): name = "gpu.async.token" -@irdl_attr_definition -class _AllReduceOperationAttr(ParametrizedAttribute): - name = "all_reduce_op" +class AllReduceOpAttr( + Data[Literal["add", "and", "max", "min", "mul", "or", "xor"]], OpaqueSyntaxAttribute +): + name = "gpu.all_reduce_op" param: ParameterDef[StringAttr] - def print_parameters(self, printer: Printer) -> None: - printer.print(f"all_reduce_op {self.param.data}") - - -@irdl_attr_definition -class _DimensionAttr(ParametrizedAttribute): - name = "dim" - - param: ParameterDef[StringAttr] - - def print_parameters(self, printer: Printer) -> None: - printer.print(f"dim {self.param.data}") - - -T = TypeVar("T", bound=_AllReduceOperationAttr | _DimensionAttr, covariant=True) - - -@irdl_attr_definition -class _GPUAttr(ParametrizedAttribute, Generic[T]): - name = "gpu" - - value: ParameterDef[T] + def print_parameter(self, printer: Printer) -> None: + printer.print(f" {self.data}") @classmethod - def parse_parameters(cls, parser: AttrParser) -> list[Attribute]: - parser.parse_characters( - "<", - ": gpu attributes currently have the #gpu syntax.", - ) - if parser.parse_optional_keyword("dim"): - attrtype = _DimensionAttr - vtok = parser.parse_optional_identifier() - if vtok not in ["x", "y", "z"]: - parser.raise_error( - f"Unexpected dim {vtok}. A gpu dim can only be x, y, or z", - ) + def parse_parameter( + cls, parser: AttrParser + ) -> Literal["add", "and", "max", "min", "mul", "or", "xor"]: + val = parser.parse_identifier() + if val in ("add", "and", "max", "min", "mul", "or", "xor"): + return val + parser.raise_error("Expected add, and, max, min, mul, or, or xor.") - elif parser.parse_optional_keyword("all_reduce_op"): - attrtype = _AllReduceOperationAttr - vtok = parser.parse_optional_identifier() - if vtok not in ["add", "and", "max", "min", "mul", "or", "xor"]: - parser.raise_error( - f"Unexpected op {vtok}. A gpu all_reduce_op can only be add, " - "and, max, min, mul, or, or xor ", - ) - else: - parser.raise_error("'dim' or 'all_reduce_op' expected") - parser.parse_characters( - ">", - ". gpu attributes currently have the #gpu syntax.", - ) - return [attrtype([StringAttr(vtok)])] - @staticmethod - def from_op(value: str) -> AllReduceOperationAttr: - return AllReduceOperationAttr([_AllReduceOperationAttr([StringAttr(value)])]) - - @property - def data(self) -> str: - return self.value.param.data +class DimensionAttr(Data[Literal["x", "y", "z"]], OpaqueSyntaxAttribute): + name = "gpu.dim" - @staticmethod - def from_dimension(value: str) -> DimensionAttr: - return DimensionAttr([_DimensionAttr([StringAttr(value)])]) - - def print_parameters(self, printer: Printer) -> None: - printer.print_string("<") - self.value.print_parameters(printer) - printer.print_string(">") + def print_parameter(self, printer: Printer) -> None: + printer.print(f" {self.data}") + @classmethod + def parse_parameter(cls, parser: AttrParser) -> Literal["x", "y", "z"]: + val = parser.parse_identifier() + if val in ("x", "y", "z"): + return val + parser.raise_error("Expected x, y or z.") -DimensionAttr = _GPUAttr[_DimensionAttr] -AllReduceOperationAttr = _GPUAttr[_AllReduceOperationAttr] _Element = TypeVar("_Element", bound=Attribute, covariant=True) @@ -201,7 +157,7 @@ def __init__( @irdl_op_definition class AllReduceOp(IRDLOperation): name = "gpu.all_reduce" - op: AllReduceOperationAttr | None = opt_prop_def(AllReduceOperationAttr) + op: AllReduceOpAttr | None = opt_prop_def(AllReduceOpAttr) uniform: UnitAttr | None = opt_prop_def(UnitAttr) operand: Operand = operand_def(Attribute) result: OpResult = result_def(Attribute) @@ -211,7 +167,7 @@ class AllReduceOp(IRDLOperation): @staticmethod def from_op( - op: AllReduceOperationAttr, + op: AllReduceOpAttr, operand: SSAValue | Operation, uniform: UnitAttr | None = None, ): @@ -730,10 +686,6 @@ def verify_(self) -> None: ) -# _GPUAttr has to be registered instead of DimensionAttr and AllReduceOperationAttr here. -# This is a hack to fit MLIR's syntax in xDSL's way of parsing attributes, without making GPU builtin. -# Hopefully MLIR will parse it in a more xDSL-friendly way soon, so all that can be factored in proper xDSL -# atrributes. GPU = Dialect( "gpu", [ @@ -763,5 +715,5 @@ def verify_(self) -> None: ThreadIdOp, YieldOp, ], - [_GPUAttr], + [AllReduceOpAttr, DimensionAttr], ) diff --git a/xdsl/interpreters/experimental/wgsl_printer.py b/xdsl/interpreters/experimental/wgsl_printer.py index 00e994f719..fec0c84c74 100644 --- a/xdsl/interpreters/experimental/wgsl_printer.py +++ b/xdsl/interpreters/experimental/wgsl_printer.py @@ -85,7 +85,7 @@ def _(self, op: gpu.ReturnOp, out_stream: IO[str]): @print.register def _(self, op: gpu.BlockIdOp, out_stream: IO[str]): - dim = str(op.dimension.value.param).strip('"') + dim = str(op.dimension.data).strip('"') name_hint = self.wgsl_name(op.result) out_stream.write( f""" @@ -94,7 +94,7 @@ def _(self, op: gpu.BlockIdOp, out_stream: IO[str]): @print.register def _(self, op: gpu.ThreadIdOp, out_stream: IO[str]): - dim = str(op.dimension.value.param).strip('"') + dim = str(op.dimension.data).strip('"') name_hint = self.wgsl_name(op.result) out_stream.write( f""" @@ -103,7 +103,7 @@ def _(self, op: gpu.ThreadIdOp, out_stream: IO[str]): @print.register def _(self, op: gpu.GridDimOp, out_stream: IO[str]): - dim = str(op.dimension.value.param).strip('"') + dim = str(op.dimension.data).strip('"') name_hint = self.wgsl_name(op.result) out_stream.write( f""" @@ -112,7 +112,7 @@ def _(self, op: gpu.GridDimOp, out_stream: IO[str]): @print.register def _(self, op: gpu.BlockDimOp, out_stream: IO[str]): - dim = str(op.dimension.value.param).strip('"') + dim = str(op.dimension.data).strip('"') name_hint = self.wgsl_name(op.result) out_stream.write( f""" @@ -121,7 +121,7 @@ def _(self, op: gpu.BlockDimOp, out_stream: IO[str]): @print.register def _(self, op: gpu.GlobalIdOp, out_stream: IO[str]): - dim = str(op.dimension.value.param).strip('"') + dim = str(op.dimension.data).strip('"') name_hint = self.wgsl_name(op.result) out_stream.write( f""" diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index 50d57402ab..62329d9ec5 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -356,21 +356,6 @@ def __hash__(self) -> int: return hash(id(self)) -@dataclass -class TypeAttribute: - """ - This class should only be inherited by classes inheriting Attribute. - This class is only used for printing attributes in the MLIR format, - inheriting this class prefix the attribute by `!` instead of `#`. - """ - - def __post_init__(self): - if not isinstance(self, Attribute): - raise TypeError( - "TypeAttribute should only be inherited by classes inheriting Attribute" - ) - - A = TypeVar("A", bound="Attribute") @@ -387,6 +372,8 @@ class Attribute(ABC): def __post_init__(self): self._verify() + if not isinstance(self, Data | ParametrizedAttribute): + raise TypeError("Attributes should only be Data or ParameterizedAttribute") def _verify(self): self.verify() @@ -407,6 +394,26 @@ def __str__(self) -> str: return res.getvalue() +class TypeAttribute(Attribute): + """ + This class should only be inherited by classes inheriting Attribute. + This class is only used for printing attributes in the MLIR format, + inheriting this class prefix the attribute by `!` instead of `#`. + """ + + pass + + +class OpaqueSyntaxAttribute(Attribute): + """ + This class should only be inherited by classes inheriting Attribute. + This class is only used for printing attributes in the opaque form, + as described at https://mlir.llvm.org/docs/LangRef/#dialect-attribute-values. + """ + + pass + + DataElement = TypeVar("DataElement", covariant=True) AttributeCovT = TypeVar("AttributeCovT", bound=Attribute, covariant=True) diff --git a/xdsl/irdl/irdl.py b/xdsl/irdl/irdl.py index df75113554..752b213fbc 100644 --- a/xdsl/irdl/irdl.py +++ b/xdsl/irdl/irdl.py @@ -2262,7 +2262,7 @@ def irdl_attr_definition(cls: TypeAttributeInvT) -> TypeAttributeInvT: dict(cls.__dict__), ) ) - raise Exception( + raise TypeError( f"Class {cls.__name__} should either be a subclass of 'Data' or " "'ParametrizedAttribute'" ) diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py index b87a85037c..82cf52bf11 100644 --- a/xdsl/parser/attribute_parser.py +++ b/xdsl/parser/attribute_parser.py @@ -92,7 +92,7 @@ def parse_optional_type(self) -> Attribute | None: if ( token := self._parse_optional_token(Token.Kind.EXCLAMATION_IDENT) ) is not None: - return self._parse_dialect_type_or_attribute_inner(token.text[1:], True) + return self._parse_dialect_type_or_attribute(token.text[1:], True) return self._parse_optional_builtin_type() def parse_type(self) -> Attribute: @@ -124,7 +124,7 @@ def parse_optional_attribute(self) -> Attribute | None: | [^[]<>(){}\0]+ """ if (token := self._parse_optional_token(Token.Kind.HASH_IDENT)) is not None: - return self._parse_dialect_type_or_attribute_inner(token.text[1:], False) + return self._parse_dialect_type_or_attribute(token.text[1:], False) return self._parse_optional_builtin_attr() def parse_attribute(self) -> Attribute: @@ -172,18 +172,22 @@ def parse_optional_dictionary_attr_dict(self) -> dict[str, Attribute]: return dict() return dict(attrs) - def _parse_dialect_type_or_attribute_inner( - self, attr_name: str, is_type: bool = True - ) -> Attribute: + def _parse_dialect_type_or_attribute_body( + self, + attr_name: str, + is_type: bool, + is_opaque: bool, + starting_opaque_pos: Position | None, + ): """ - Parse the contents of a dialect type or attribute, with format: - dialect-attr-contents ::= `<` dialect-attribute-contents+ `>` - | `(` dialect-attribute-contents+ `)` - | `[` dialect-attribute-contents+ `]` - | `{` dialect-attribute-contents+ `}` + Parse the contents of an attribute or type, with syntax: + dialect-attr-contents ::= `<` dialect-attr-contents+ `>` + | `(` dialect-attr-contents+ `)` + | `[` dialect-attr-contents+ `]` + | `{` dialect-attr-contents+ `}` | [^[]<>(){}\0]+ - The contents will be parsed by a user-defined parser, or by a generic parser - if the dialect attribute/type is not registered. + In the case where the attribute or type is using the opaque syntax, + the attribute or type mnemonic should have already been parsed. """ attr_def = self.ctx.get_optional_attr( attr_name, @@ -191,32 +195,75 @@ def _parse_dialect_type_or_attribute_inner( ) if attr_def is None: self.raise_error(f"'{attr_name}' is not registered") - - # Pass the task of parsing parameters on to the attribute/type definition if issubclass(attr_def, UnregisteredAttr): - body = self._parse_unregistered_attr_body() - return attr_def(attr_name, is_type, body) - if issubclass(attr_def, ParametrizedAttribute): + if not is_opaque: + if self.parse_optional_punctuation("<") is None: + return attr_def(attr_name, is_type, is_opaque, "") + body = self._parse_unregistered_attr_body(starting_opaque_pos) + attr = attr_def(attr_name, is_type, is_opaque, body) + if not is_opaque: + self.parse_punctuation(">") + return attr + + elif issubclass(attr_def, ParametrizedAttribute): param_list = attr_def.parse_parameters(self) return attr_def.new(param_list) - if issubclass(attr_def, Data): + elif issubclass(attr_def, Data): param: Any = attr_def.parse_parameter(self) return cast(Data[Any], attr_def(param)) - assert False, "Attributes are either ParametrizedAttribute or Data." + else: + raise TypeError("Attributes are either ParametrizedAttribute or Data.") + + def _parse_dialect_type_or_attribute( + self, attr_or_dialect_name: str, is_type: bool = True + ) -> Attribute: + """ + Parse the contents of a dialect type or attribute, with format: + dialect-attr-contents ::= `<` dialect-attr-contents+ `>` + | `(` dialect-attr-contents+ `)` + | `[` dialect-attr-contents+ `]` + | `{` dialect-attr-contents+ `}` + | [^[]<>(){}\0]+ + The contents will be parsed by a user-defined parser, or by a generic parser + if the dialect attribute/type is not registered. - def _parse_unregistered_attr_body(self) -> str: + In the case that the type or attribute is using the opaque syntax (where the + identifier parsed is the dialect name), this function will parse the opaque + attribute with the following format: + opaque-attr-contents ::= `<` bare-ident dialect-attr-contents+ `>` + """ + is_opaque = "." not in attr_or_dialect_name + starting_opaque_pos = None + if is_opaque: + self.parse_punctuation("<") + attr_name_token = self._parse_token( + Token.Kind.BARE_IDENT, "Expected attribute name." + ) + starting_opaque_pos = attr_name_token.span.end + + attr_or_dialect_name += "." + attr_name_token.text + + attr = self._parse_dialect_type_or_attribute_body( + attr_or_dialect_name, is_type, is_opaque, starting_opaque_pos + ) + + if is_opaque: + self.parse_punctuation(">") + + return attr + + def _parse_unregistered_attr_body(self, start_pos: Position | None) -> str: """ Parse the body of an unregistered attribute, which is a balanced string for `<`, `(`, `[`, `{`, and may contain string literals. + The body ends when no parentheses are opened, and an `>` is encountered. """ - start_token = self._parse_optional_token(Token.Kind.LESS) - if start_token is None: - return "" - start_pos = start_token.span.start + if start_pos is None: + start_pos = self.pos end_pos: Position = start_pos - symbols_stack = [Token.Kind.LESS] + symbols_stack: list[Token.Kind] = [] parentheses = { Token.Kind.GREATER: Token.Kind.LESS, Token.Kind.R_PAREN: Token.Kind.L_PAREN, @@ -238,17 +285,31 @@ def _parse_unregistered_attr_body(self) -> str: continue # Closing a parenthesis - if (token := self._parse_optional_token_in(parentheses.keys())) is not None: + if (token := self._current_token).kind in parentheses.keys(): closing = parentheses[token.kind] + + # If we don't have any open parenthesis, either we end the parsing if + # the parenthesis is a `>`, or we raise an error. + if len(symbols_stack) == 0: + if token.kind == Token.Kind.GREATER: + end_pos = self.pos + break + self.raise_error( + "Unexpected closing parenthesis " + f"{parentheses_names[token.kind]} in attribute body!", + self._current_token.span, + ) + + # If we have an open parenthesis, check that we are closing it + # with the right parenthesis kind. if symbols_stack[-1] != closing: self.raise_error( - f"Mismatched {parentheses_names[token.kind]} in attribute body!", + "Unexpected closing parenthesis " + f"{parentheses_names[token.kind]} in attribute body! {symbols_stack}", self._current_token.span, ) symbols_stack.pop() - if len(symbols_stack) == 0: - end_pos = token.span.end - break + self._consume_token() continue # Checking for unexpected EOF diff --git a/xdsl/printer.py b/xdsl/printer.py index 62956e48d3..2ebcceeee5 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -52,6 +52,7 @@ Block, BlockArgument, Data, + OpaqueSyntaxAttribute, Operation, ParametrizedAttribute, Region, @@ -640,20 +641,34 @@ def print_int_or_question(value: IntAttr | NoneAttr) -> None: if isinstance(attribute, UnregisteredAttr): # Do not print `!` or `#` for unregistered builtin attributes self.print("!" if attribute.is_type.data else "#") - self.print(attribute.attr_name.data, attribute.value.data) + if attribute.is_opaque.data: + self.print(attribute.attr_name.data.replace(".", "<", 1)) + self.print(attribute.value.data) + self.print(">") + else: + self.print(attribute.attr_name.data) + if attribute.value.data: + self.print("<") + self.print(attribute.value.data) + self.print(">") return # Print dialect attributes self.print("!" if isinstance(attribute, TypeAttribute) else "#") - self.print(attribute.name) + + if isinstance(attribute, OpaqueSyntaxAttribute): + self.print(attribute.name.replace(".", "<", 1)) + else: + self.print(attribute.name) if isinstance(attribute, Data): attribute.print_parameter(self) - return - assert isinstance(attribute, ParametrizedAttribute) + elif isinstance(attribute, ParametrizedAttribute): + attribute.print_parameters(self) - attribute.print_parameters(self) + if isinstance(attribute, OpaqueSyntaxAttribute): + self.print(">") return def print_successors(self, successors: list[Block]):