From 52431ee0e9d2ba1a750534d14bbc9e77b13a41a8 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Mon, 2 Dec 2024 12:20:36 +0000 Subject: [PATCH] core: Fix 'parse_single_type' for operands/results directives (#3553) Reworks the way 'parse_single_type' works for operands/results directives, reusing the infrastructure introduced for 'parse_many_types'. This now allows 'parse_single_type' to work when there are variadic operands, as demonstrated by the new tests. The new tests manually make a `FormatProgram`, as the format program parser will never generate such a program. This currently makes no difference, as these functions are never called, but it will be important for the functional-type directive (#3517), as it allows a type like: ```mlir (i32, i32) -> i32 ``` to be parsed, even if the results of the operation is a variadic (the functional type directive calls 'parse_single_type' when the results are not wrapped in parentheses). --- .../irdl/test_declarative_assembly_format.py | 148 ++++++++++++++++++ xdsl/irdl/declarative_assembly_format.py | 16 +- 2 files changed, 158 insertions(+), 6 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index b9585c36e4..d02f333c14 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -65,6 +65,14 @@ var_result_def, var_successor_def, ) +from xdsl.irdl.declarative_assembly_format import ( + AttrDictDirective, + FormatProgram, + OperandsDirective, + PunctuationDirective, + ResultsDirective, + TypeDirective, +) from xdsl.parser import Parser from xdsl.printer import Printer from xdsl.utils.exceptions import ParseError, PyRDLOpDefinitionError, VerifyException @@ -1254,6 +1262,77 @@ class ThreeOperandsOp(IRDLOperation): parser.parse_operation() +def test_operands_directive_with_non_variadic_type_directive(): + """Tests the 'parse_single_type' function of the operands directive.""" + + # The parser will never generate a non-variadic TypeDirective containing + # an OperandsDirective, but we can manually make one. + format_program = FormatProgram( + ( + OperandsDirective(None), + AttrDictDirective(False, set(), False), + PunctuationDirective(":"), + TypeDirective(OperandsDirective(None)), + ), + {}, + ) + + @irdl_op_definition + class OneOperandOp(IRDLOperation): + name = "test.one_operand" + + op1 = operand_def() + + @classmethod + def parse(cls, parser: Parser) -> OneOperandOp: + return format_program.parse(parser, cls) + + def print(self, printer: Printer): + format_program.print(printer, self) + + ctx = MLContext() + ctx.load_op(OneOperandOp) + + check_roundtrip("test.one_operand %0 : i32", ctx) + + +def test_operands_directive_with_variadic_type_directive(): + """ + Tests the 'parse_single_type' function of the operands directive + when the operation has a variadic. + """ + # The parser will never generate a non-variadic TypeDirective containing + # an OperandsDirective, but we can manually make one. + format_program = FormatProgram( + ( + OperandsDirective((False, 1)), + AttrDictDirective(False, set(), False), + PunctuationDirective(":"), + TypeDirective(OperandsDirective((False, 1))), + ), + {}, + ) + + @irdl_op_definition + class TwoOperandOp(IRDLOperation): + name = "test.two_operand" + + op1 = operand_def() + op2 = var_operand_def() + + @classmethod + def parse(cls, parser: Parser) -> TwoOperandOp: + return format_program.parse(parser, cls) + + def print(self, printer: Printer): + format_program.print(printer, self) + + ctx = MLContext() + ctx.load_op(TwoOperandOp) + + check_roundtrip("test.two_operand %0 : i32", ctx) + + ################################################################################ # Results # ################################################################################ @@ -1610,6 +1689,75 @@ class ThreeResultsOp(IRDLOperation): parser.parse_operation() +def test_results_directive_with_non_variadic_type_directive(): + """Tests the 'parse_single_type' function of the results directive.""" + + # The parser will never generate a non-variadic TypeDirective containing + # a ResultsDirective, but we can manually make one. + format_program = FormatProgram( + ( + AttrDictDirective(False, set(), False), + PunctuationDirective(":"), + TypeDirective(ResultsDirective(None)), + ), + {}, + ) + + @irdl_op_definition + class OneResultOp(IRDLOperation): + name = "test.one_result" + + res = result_def() + + @classmethod + def parse(cls, parser: Parser) -> OneResultOp: + return format_program.parse(parser, cls) + + def print(self, printer: Printer): + format_program.print(printer, self) + + ctx = MLContext() + ctx.load_op(OneResultOp) + + check_roundtrip("%0 = test.one_result : i32", ctx) + + +def test_results_directive_with_variadic_type_directive(): + """ + Tests the 'parse_single_type' function of the results directive + when the operation has a variadic. + """ + # The parser will never generate a non-variadic TypeDirective containing + # a ResultsDirective, but we can manually make one. + format_program = FormatProgram( + ( + AttrDictDirective(False, set(), False), + PunctuationDirective(":"), + TypeDirective(ResultsDirective((False, 1))), + ), + {}, + ) + + @irdl_op_definition + class TwoResultsOp(IRDLOperation): + name = "test.two_results" + + res1 = result_def() + res2 = var_result_def() + + @classmethod + def parse(cls, parser: Parser) -> TwoResultsOp: + return format_program.parse(parser, cls) + + def print(self, printer: Printer): + format_program.print(printer, self) + + ctx = MLContext() + ctx.load_op(TwoResultsOp) + + check_roundtrip("%0 = test.two_results : i32", ctx) + + ################################################################################ # Regions # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index c0db9d9c7e..ecc9318807 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -665,9 +665,11 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(operands) def parse_single_type(self, parser: Parser, state: ParsingState) -> None: - if len(state.operand_types) > 1: - parser.raise_error("Expected multiple types but received one.") - state.operand_types[0] = parser.parse_type() + pos_start = parser.pos + if s := self._set_using_variadic_index( + state.operand_types, "operand types", (parser.parse_type(),) + ): + parser.raise_error(s, at_position=pos_start, end_position=parser.pos) def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: pos_start = parser.pos @@ -787,9 +789,11 @@ class ResultsDirective(OperandsOrResultDirective): """ def parse_single_type(self, parser: Parser, state: ParsingState) -> None: - if len(state.result_types) > 1: - parser.raise_error("Expected multiple types but received one.") - state.result_types[0] = parser.parse_type() + pos_start = parser.pos + if s := self._set_using_variadic_index( + state.result_types, "result types", (parser.parse_type(),) + ): + parser.raise_error(s, at_position=pos_start, end_position=parser.pos) def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: pos_start = parser.pos