Skip to content

Commit

Permalink
dialects: (scf) Allow signless integers as induction variable types f…
Browse files Browse the repository at this point in the history
…or `scf.for` (#1727)

This PR adds:

- Support for signless integers as induction variable types for
`scf.for` (also supported in MLIR)
- Custom parsing and printing for the above
- Tests (including MLIR interoperability tests)

Resolves #1304
  • Loading branch information
compor authored Nov 15, 2023
1 parent 42a6733 commit c5f3ce9
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 6 deletions.
48 changes: 47 additions & 1 deletion tests/filecheck/dialects/scf/for_args_types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"builtin.module"() ({
%lbi = "test.op"() : () -> !test.type<"int">
%x:2 = "test.op"() : () -> (index, index) // ub, step
// CHECK: !test.type<"int"> should be of base attribute index
// CHECK: operand at position 0 does not verify
"scf.for"(%lbi, %x#0, %x#1) ({
^0(%iv : index):
"scf.yield"() : () -> ()
Expand All @@ -21,3 +21,49 @@
}) : () -> ()

// CHECK: Expected induction var to be same type as bounds and step

// -----

"builtin.module"() ({
%lbi = "test.op"() : () -> si32
%x:2 = "test.op"() : () -> (index, index) // ub, step
// CHECK: operand at position 0 does not verify
"scf.for"(%lbi, %x#0, %x#1) ({
^0(%iv : index):
"scf.yield"() : () -> ()
}) : (si32, index, index) -> ()
}) : () -> ()

// -----

"builtin.module"() ({
%x:3 = "test.op"() : () -> (index, index, index) // lb, ub, step
"scf.for"(%x#0, %x#1, %x#2) ({
^0(%iv : i32):
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
}) : () -> ()

// CHECK: Expected induction var to be same type as bounds and step

// -----

"builtin.module"() ({
%x:3 = "test.op"() : () -> (si32, si32, si32) // lb, ub, step
// CHECK: operand at position 0 does not verify
"scf.for"(%x#0, %x#1, %x#2) ({
^0(%iv : si32):
"scf.yield"() : () -> ()
}) : (si32, si32, si32) -> ()
}) : () -> ()

// -----

"builtin.module"() ({
%x:3 = "test.op"() : () -> (i32, i32, i32) // lb, ub, step
// CHECK: Expected induction var to be same type as bounds and step
"scf.for"(%x#0, %x#1, %x#2) ({
^0(%iv : index):
"scf.yield"() : () -> ()
}) : (i32, i32, i32) -> ()
}) : () -> ()
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// RUN: xdsl-opt %s | xdsl-opt | mlir-opt | filecheck %s

%lb = arith.constant 0 : i32
%ub = arith.constant 42 : i32
%step = arith.constant 7 : i32
%sum_init = arith.constant 36 : i32
%sum = scf.for %iv = %lb to %ub step %step iter_args(%sum_iter = %sum_init) -> (i32) : i32 {
%sum_new = arith.addi %sum_iter, %iv : i32
scf.yield %sum_new : i32
}

scf.for %iv = %lb to %ub step %step : i32 {
}

// CHECK: module {
// CHECK-NEXT: %{{.*}} = arith.constant 0 : i32
// CHECK-NEXT: %{{.*}} = arith.constant 42 : i32
// CHECK-NEXT: %{{.*}} = arith.constant 7 : i32
// CHECK-NEXT: %{{.*}} = arith.constant 36 : i32
// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (i32) : i32 {
// CHECK-NEXT: %{{.*}} = arith.addi %{{.*}}, %{{.*}} : i32
// CHECK-NEXT: scf.yield %{{.*}} : i32
// CHECK-NEXT: }
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} : i32 {
// CHECK-NEXT: }
// CHECK-NEXT: }
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: xdsl-opt %s | mlir-opt --mlir-print-op-generic | xdsl-opt --print-op-generic | filecheck %s

"builtin.module"() ({
%lb = "arith.constant"() {"value" = 0 : i32} : () -> i32
%ub = "arith.constant"() {"value" = 42 : i32} : () -> i32
%step = "arith.constant"() {"value" = 7 : i32} : () -> i32
%sum_init = "arith.constant"() {"value" = 36 : i32} : () -> i32
%sum = "scf.for"(%lb, %ub, %step, %sum_init) ({
^0(%iv : i32, %sum_iter : i32):
%sum_new = "arith.addi"(%sum_iter, %iv) : (i32, i32) -> i32
"scf.yield"(%sum_new) : (i32) -> ()
}) : (i32, i32, i32, i32) -> i32
"scf.for"(%lb, %ub, %step) ({
^bb0(%iv: i32):
"scf.yield"() : () -> ()
}) : (i32, i32, i32) -> ()
}) : () -> ()

// CHECK: "builtin.module"() ({
// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 0 : i32}> : () -> i32
// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 42 : i32}> : () -> i32
// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 7 : i32}> : () -> i32
// CHECK-NEXT: %{{.*}} = "arith.constant"() <{"value" = 36 : i32}> : () -> i32
// CHECK-NEXT: %{{.*}} = "scf.for"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ({
// CHECK-NEXT: ^0(%{{.*}} : i32, %{{.*}} : i32):
// CHECK-NEXT: %{{.*}} = "arith.addi"(%{{.*}}, %{{.*}}) : (i32, i32) -> i32
// CHECK-NEXT: "scf.yield"(%{{.*}}) : (i32) -> ()
// CHECK-NEXT: }) : (i32, i32, i32, i32) -> i32
// CHECK-NEXT: "scf.for"(%{{.*}}, %{{.*}}, %{{.*}}) ({
// CHECK-NEXT: ^1(%{{.*}}: i32):
// CHECK-NEXT: "scf.yield"() : () -> ()
// CHECK-NEXT: }) : (i32, i32, i32) -> ()
// CHECK-NEXT: }) : () -> ()
15 changes: 15 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from xdsl.irdl import (
AllOf,
AnyAttr,
AnyOf,
AttrConstraint,
GenericData,
IRDLOperation,
Expand Down Expand Up @@ -340,6 +341,15 @@ def __init__(
i1 = IntegerType(1)


SignlessIntegerConstraint = ParamAttrConstraint(
IntegerType, [IntAttr, SignednessAttr(Signedness.SIGNLESS)]
)
"""Type constraint for signless IntegerType."""

AnySignlessIntegerType: TypeAlias = Annotated[IntegerType, SignlessIntegerConstraint]
"""Type alias constrained to signless IntegerType."""


@irdl_attr_definition
class UnitAttr(ParametrizedAttribute):
name = "unit"
Expand All @@ -364,6 +374,11 @@ class IndexType(ParametrizedAttribute):
"_IntegerAttrType", bound=IntegerType | IndexType, covariant=True
)

AnySignlessIntegerOrIndexType: TypeAlias = Annotated[
Attribute, AnyOf([IndexType, SignlessIntegerConstraint])
]
"""Type alias constrained to IndexType or signless IntegerType."""


@irdl_attr_definition
class IntegerAttr(Generic[_IntegerAttrType], ParametrizedAttribute):
Expand Down
26 changes: 21 additions & 5 deletions xdsl/dialects/scf.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Annotated

from typing_extensions import Self

from xdsl.dialects.builtin import IndexType, IntegerType
from xdsl.dialects.builtin import (
AnySignlessIntegerOrIndexType,
IndexType,
IntegerType,
)
from xdsl.dialects.utils import (
AbstractYieldOperation,
parse_assignment,
Expand All @@ -14,6 +19,7 @@
from xdsl.irdl import (
AnyAttr,
AttrSizedOperandSegments,
ConstraintVar,
IRDLOperation,
Operand,
VarOperand,
Expand Down Expand Up @@ -128,9 +134,11 @@ def get(
class For(IRDLOperation):
name = "scf.for"

lb: Operand = operand_def(IndexType)
ub: Operand = operand_def(IndexType)
step: Operand = operand_def(IndexType)
T = Annotated[AnySignlessIntegerOrIndexType, ConstraintVar("T")]

lb: Operand = operand_def(T)
ub: Operand = operand_def(T)
step: Operand = operand_def(T)

iter_args: VarOperand = var_operand_def(AnyAttr())

Expand Down Expand Up @@ -233,6 +241,10 @@ def print(self, printer: Printer):
printer.print_string(") -> (")
printer.print_list((a.type for a in iter_args), printer.print_attribute)
printer.print_string(") ")
if not isinstance(indvar.type, IndexType):
printer.print_string(": ")
printer.print_attribute(indvar.type)
printer.print_string(" ")
printer.print_region(
self.body,
print_entry_block_args=False,
Expand Down Expand Up @@ -271,8 +283,12 @@ def parse(cls, parser: Parser) -> Self:
iter_arg_unresolved_operands, iter_arg_types, pos
)

# Set block argument types
# Set induction variable type
indvar.type = lb.type
if parser.parse_optional_characters(":"):
indvar.type = parser.parse_type()

# Set block argument types
for iter_arg, iter_arg_type in zip(iter_args, iter_arg_types):
iter_arg.type = iter_arg_type

Expand Down

0 comments on commit c5f3ce9

Please sign in to comment.