Skip to content

Commit

Permalink
core: add functional-type directive (#3517)
Browse files Browse the repository at this point in the history
Adds the functional-type directive, which takes two typeable directives
as input.

The key difficulty here was dealing with the results, which are usually
surrounded by parentheses but don't have to be if there is exactly one
result.
  • Loading branch information
alexarice authored Dec 2, 2024
1 parent 52431ee commit 1517d8d
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 0 deletions.
64 changes: 64 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,70 @@ def print(self, printer: Printer):
check_roundtrip("%0 = test.two_results : i32", ctx)


################################################################################
# Functional type #
################################################################################


@pytest.mark.parametrize(
"program",
[
"%0 = test.functional_type %1, %2 : (i32, i32) -> i32",
"test.functional_type %0, %1 : (i32, i32) -> ()",
"%0, %1 = test.functional_type %2, %3 : (i32, i32) -> (i32, i32)",
"%0 = test.functional_type %1 : (i32) -> i32",
"%0 = test.functional_type : () -> i32",
],
)
def test_functional_type(program: str):
"""Test the parsing of the functional-type directive"""

@irdl_op_definition
class FunctionalTypeOp(IRDLOperation):
name = "test.functional_type"

ops = var_operand_def()
res = var_result_def()

assembly_format = "$ops attr-dict `:` functional-type($ops, $res)"

ctx = MLContext()
ctx.load_op(FunctionalTypeOp)

check_roundtrip(program, ctx)


@pytest.mark.parametrize(
"program",
[
"%0 = test.functional_type %1, %2 : (i32, i32) -> i32",
"%0, %1 = test.functional_type %2, %3 : (i32, i32) -> (i32, i32)",
"%0 = test.functional_type %1 : (i32) -> i32",
],
)
def test_functional_type_with_operands_and_results(program: str):
"""
Test the parsing of the functional-type directive using the operands and
results directives
"""

@irdl_op_definition
class FunctionalTypeOp(IRDLOperation):
name = "test.functional_type"

op1 = operand_def()
ops2 = var_operand_def()
res1 = var_result_def()
res2 = result_def()

assembly_format = "operands attr-dict `:` functional-type(operands, results)"

ctx = MLContext()
ctx.load_op(FunctionalTypeOp)

check_roundtrip(program, ctx)


################################################################################
# Regions #
################################################################################
Expand Down
55 changes: 55 additions & 0 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,61 @@ def is_present(self, op: IRDLOperation) -> bool:
return bool(op.results)


@dataclass(frozen=True)
class FunctionalTypeDirective(OptionallyParsableDirective):
"""
A directive which parses a functional type, with format:
functional-type-directive ::= functional-type(typeable-directive, typeable-directive)
A functional type is either of the form
`(` type-list `)` `->` `(` type-list `)`
or
`(` type-list `)` `->` type
where type-list is a comma separated list of types (or the empty string to signify the empty list).
The second format is preferred for printing when possible.
"""

operand_typeable_directive: TypeableDirective
result_typeable_directive: TypeableDirective

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
if not parser.parse_optional_punctuation("("):
return False
if isinstance(self.operand_typeable_directive, VariadicTypeableDirective):
self.operand_typeable_directive.parse_many_types(parser, state)
else:
self.operand_typeable_directive.parse_single_type(parser, state)
parser.parse_punctuation(")")
parser.parse_punctuation("->")
if parser.parse_optional_punctuation("("):
if isinstance(self.result_typeable_directive, VariadicTypeableDirective):
self.result_typeable_directive.parse_many_types(parser, state)
else:
self.result_typeable_directive.parse_single_type(parser, state)
parser.parse_punctuation(")")
else:
self.result_typeable_directive.parse_single_type(parser, state)
return True

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print_string(" ")
state.should_emit_space = True
printer.print_string("(")
printer.print_list(
self.operand_typeable_directive.get_types(op), printer.print_attribute
)
printer.print_string(") -> ")
result_types = self.result_typeable_directive.get_types(op)
if len(result_types) == 1:
printer.print_attribute(result_types[0])
state.last_was_punctuation = False
else:
printer.print_string("(")
printer.print_list(result_types, printer.print_attribute)
printer.print_string(")")
state.last_was_punctuation = True


class RegionDirective(OptionallyParsableDirective, ABC):
"""
Baseclass to help keep typechecking simple.
Expand Down
16 changes: 16 additions & 0 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DefaultValuedAttributeVariable,
FormatDirective,
FormatProgram,
FunctionalTypeDirective,
KeywordDirective,
OperandOrResult,
OperandsDirective,
Expand Down Expand Up @@ -617,6 +618,19 @@ def parse_type_directive(self) -> FormatDirective:
return VariadicTypeDirective(inner)
return TypeDirective(inner)

def parse_functional_type_directive(self) -> FormatDirective:
"""
Parse a functional-type directive with the following format
functional-type-directive ::= `functional-type` `(` typeable-directive `,` typeable-directive `)`
`functional-type` is expected to have already been parsed
"""
self.parse_punctuation("(")
operands = self.parse_typeable_directive()
self.parse_punctuation(",")
results = self.parse_typeable_directive()
self.parse_punctuation(")")
return FunctionalTypeDirective(operands, results)

def parse_optional_group(self) -> FormatDirective:
"""
Parse an optional group, with the following format:
Expand Down Expand Up @@ -739,6 +753,8 @@ def parse_format_directive(self) -> FormatDirective:
return self.parse_type_directive()
if self.parse_optional_keyword("operands"):
return self.create_operands_directive(True)
if self.parse_optional_keyword("functional-type"):
return self.parse_functional_type_directive()
if self._current_token.text == "`":
return self.parse_keyword_or_punctuation()
if self.parse_optional_punctuation("("):
Expand Down

0 comments on commit 1517d8d

Please sign in to comment.