diff --git a/tests/filecheck/dialects/scf/for_args_types.mlir b/tests/filecheck/dialects/scf/for_args_types.mlir index b8f24671d5..3cdf4b688d 100644 --- a/tests/filecheck/dialects/scf/for_args_types.mlir +++ b/tests/filecheck/dialects/scf/for_args_types.mlir @@ -3,7 +3,7 @@ "builtin.module"() ({ %lbi = "test.op"() : () -> !test.type<"int"> %x:2 = "test.op"() : () -> (index, index) // ub, step -// CHECK: !test.type<"int"> should be of base attribute index +// CHECK: operand at position 0 does not verify "scf.for"(%lbi, %x#0, %x#1) ({ ^0(%iv : index): "scf.yield"() : () -> () @@ -21,3 +21,49 @@ }) : () -> () // CHECK: Expected induction var to be same type as bounds and step + +// ----- + +"builtin.module"() ({ + %lbi = "test.op"() : () -> si32 + %x:2 = "test.op"() : () -> (index, index) // ub, step +// CHECK: operand at position 0 does not verify + "scf.for"(%lbi, %x#0, %x#1) ({ + ^0(%iv : index): + "scf.yield"() : () -> () + }) : (si32, index, index) -> () +}) : () -> () + +// ----- + +"builtin.module"() ({ + %x:3 = "test.op"() : () -> (index, index, index) // lb, ub, step + "scf.for"(%x#0, %x#1, %x#2) ({ + ^0(%iv : i32): + "scf.yield"() : () -> () + }) : (index, index, index) -> () +}) : () -> () + +// CHECK: Expected induction var to be same type as bounds and step + +// ----- + +"builtin.module"() ({ + %x:3 = "test.op"() : () -> (si32, si32, si32) // lb, ub, step +// CHECK: operand at position 0 does not verify + "scf.for"(%x#0, %x#1, %x#2) ({ + ^0(%iv : si32): + "scf.yield"() : () -> () + }) : (si32, si32, si32) -> () +}) : () -> () + +// ----- + +"builtin.module"() ({ + %x:3 = "test.op"() : () -> (i32, i32, i32) // lb, ub, step +// CHECK: Expected induction var to be same type as bounds and step + "scf.for"(%x#0, %x#1, %x#2) ({ + ^0(%iv : index): + "scf.yield"() : () -> () + }) : (i32, i32, i32) -> () +}) : () -> () diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/scf/for_custom_non_index_iv.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/scf/for_custom_non_index_iv.mlir new file mode 100644 index 0000000000..a111310b34 --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/scf/for_custom_non_index_iv.mlir @@ -0,0 +1,26 @@ +// RUN: xdsl-opt %s | xdsl-opt | mlir-opt | filecheck %s + +%lb = arith.constant 0 : i32 +%ub = arith.constant 42 : i32 +%step = arith.constant 7 : i32 +%sum_init = arith.constant 36 : i32 +%sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_init) -> (i32) : i32 { + %sum_new = arith.addi %sum_iter, %iv : i32 + scf.yield %sum_new : i32 +} + +scf.for %iv = %lb to %ub step %step : i32 { +} + +// CHECK: module { +// CHECK-NEXT: %{{.*}} = arith.constant 0 : i32 +// CHECK-NEXT: %{{.*}} = arith.constant 42 : i32 +// CHECK-NEXT: %{{.*}} = arith.constant 7 : i32 +// CHECK-NEXT: %{{.*}} = arith.constant 36 : i32 +// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (i32) : i32 { +// CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i32 +// CHECK-NEXT: scf.yield %{{.*}} : i32 +// CHECK-NEXT: } +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 { +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/scf/for_generic_non_index_iv.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/scf/for_generic_non_index_iv.mlir new file mode 100644 index 0000000000..ac82d0bfc7 --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/scf/for_generic_non_index_iv.mlir @@ -0,0 +1,33 @@ +// RUN: xdsl-opt %s | mlir-opt --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s + +"builtin.module"() ({ + %lb = "arith.constant"() {"value" = 0 : i32} : () -> i32 + %ub = "arith.constant"() {"value" = 42 : i32} : () -> i32 + %step = "arith.constant"() {"value" = 7 : i32} : () -> i32 + %sum_init = "arith.constant"() {"value" = 36 : i32} : () -> i32 + %sum = "scf.for"(%lb, %ub, %step, %sum_init) ({ + ^0(%iv : i32, %sum_iter : i32): + %sum_new = "arith.addi"(%sum_iter, %iv) : (i32, i32) -> i32 + "scf.yield"(%sum_new) : (i32) -> () + }) : (i32, i32, i32, i32) -> i32 + "scf.for"(%lb, %ub, %step) ({ + ^bb0(%iv: i32): + "scf.yield"() : () -> () + }) : (i32, i32, i32) -> () +}) : () -> () + +// CHECK: "builtin.module"() ({ +// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 0 : i32}> : () -> i32 +// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 42 : i32}> : () -> i32 +// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 7 : i32}> : () -> i32 +// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 36 : i32}> : () -> i32 +// CHECK-NEXT: %{{.*}} = "scf.for"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ({ +// CHECK-NEXT: ^0(%{{.*}} : i32, %{{.*}} : i32): +// CHECK-NEXT: %{{.*}} = "arith.addi"(%{{.*}}, %{{.*}}) : (i32, i32) -> i32 +// CHECK-NEXT: "scf.yield"(%{{.*}}) : (i32) -> () +// CHECK-NEXT: }) : (i32, i32, i32, i32) -> i32 +// CHECK-NEXT: "scf.for"(%{{.*}}, %{{.*}}, %{{.*}}) ({ +// CHECK-NEXT: ^1(%{{.*}}: i32): +// CHECK-NEXT: "scf.yield"() : () -> () +// CHECK-NEXT: }) : (i32, i32, i32) -> () +// CHECK-NEXT: }) : () -> () diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index a3ca2c47d7..b22b366380 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -36,6 +36,7 @@ from xdsl.irdl import ( AllOf, AnyAttr, + AnyOf, AttrConstraint, GenericData, IRDLOperation, @@ -340,6 +341,15 @@ def __init__( i1 = IntegerType(1) +SignlessIntegerConstraint = ParamAttrConstraint( + IntegerType, [IntAttr, SignednessAttr(Signedness.SIGNLESS)] +) +"""Type constraint for signless IntegerType.""" + +AnySignlessIntegerType: TypeAlias = Annotated[IntegerType, SignlessIntegerConstraint] +"""Type alias constrained to signless IntegerType.""" + + @irdl_attr_definition class UnitAttr(ParametrizedAttribute): name = "unit" @@ -364,6 +374,11 @@ class IndexType(ParametrizedAttribute): "_IntegerAttrType", bound=IntegerType | IndexType, covariant=True ) +AnySignlessIntegerOrIndexType: TypeAlias = Annotated[ + Attribute, AnyOf([IndexType, SignlessIntegerConstraint]) +] +"""Type alias constrained to IndexType or signless IntegerType.""" + @irdl_attr_definition class IntegerAttr(Generic[_IntegerAttrType], ParametrizedAttribute): diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index 63941f1cb0..81d7f6dcd1 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -1,10 +1,15 @@ from __future__ import annotations from collections.abc import Sequence +from typing import Annotated from typing_extensions import Self -from xdsl.dialects.builtin import IndexType, IntegerType +from xdsl.dialects.builtin import ( + AnySignlessIntegerOrIndexType, + IndexType, + IntegerType, +) from xdsl.dialects.utils import ( AbstractYieldOperation, parse_assignment, @@ -14,6 +19,7 @@ from xdsl.irdl import ( AnyAttr, AttrSizedOperandSegments, + ConstraintVar, IRDLOperation, Operand, VarOperand, @@ -128,9 +134,11 @@ def get( class For(IRDLOperation): name = "scf.for" - lb: Operand = operand_def(IndexType) - ub: Operand = operand_def(IndexType) - step: Operand = operand_def(IndexType) + T = Annotated[AnySignlessIntegerOrIndexType, ConstraintVar("T")] + + lb: Operand = operand_def(T) + ub: Operand = operand_def(T) + step: Operand = operand_def(T) iter_args: VarOperand = var_operand_def(AnyAttr()) @@ -233,6 +241,10 @@ def print(self, printer: Printer): printer.print_string(") -> (") printer.print_list((a.type for a in iter_args), printer.print_attribute) printer.print_string(") ") + if not isinstance(indvar.type, IndexType): + printer.print_string(": ") + printer.print_attribute(indvar.type) + printer.print_string(" ") printer.print_region( self.body, print_entry_block_args=False, @@ -271,8 +283,12 @@ def parse(cls, parser: Parser) -> Self: iter_arg_unresolved_operands, iter_arg_types, pos ) - # Set block argument types + # Set induction variable type indvar.type = lb.type + if parser.parse_optional_characters(":"): + indvar.type = parser.parse_type() + + # Set block argument types for iter_arg, iter_arg_type in zip(iter_args, iter_arg_types): iter_arg.type = iter_arg_type