From 9bc6dc11c454c54cae5fe85306574d3ef4b97cf4 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Wed, 27 Nov 2024 11:00:57 +0000 Subject: [PATCH] core: Add operands directive (#3507) Adds the operands directive to the assembly format, which works as a format directive and in type directives. --- .../irdl/test_declarative_assembly_format.py | 229 ++++++++++++++++++ xdsl/irdl/declarative_assembly_format.py | 96 +++++++- .../declarative_assembly_format_parser.py | 29 +++ 3 files changed, 353 insertions(+), 1 deletion(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 5b6d4e1f89..7057586ca9 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -1004,6 +1004,235 @@ class TwoOperandsOp(IRDLOperation): check_roundtrip(program, ctx) +@pytest.mark.parametrize( + "program", + [ + "test.operands_directive %0 : i32", + "test.operands_directive %0, %1 : i32, i32", + "test.operands_directive %0, %1, %2 : i32, i32, i32", + ], +) +def test_operands_directive(program: str): + """Test the operands directive""" + + @irdl_op_definition + class OperandsDirectiveOp(IRDLOperation): + name = "test.operands_directive" + + op1 = operand_def() + op2 = var_operand_def() + + assembly_format = "operands `:` type(operands) attr-dict" + + ctx = MLContext() + ctx.load_op(OperandsDirectiveOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + + +@pytest.mark.parametrize( + "program", + [ + "test.operands_directive %0 : i32", + "test.operands_directive %0, %1 : i32, i32", + ], +) +def test_operands_directive_with_optional(program: str): + """Test the operands directive""" + + @irdl_op_definition + class OperandsDirectiveOp(IRDLOperation): + name = "test.operands_directive" + + op1 = opt_operand_def() + op2 = operand_def() + + assembly_format = "operands `:` type(operands) attr-dict" + + ctx = MLContext() + ctx.load_op(OperandsDirectiveOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + + +def test_operands_directive_fails_with_two_var(): + """Test operands directive cannot be used with two variadic operands""" + + with pytest.raises( + PyRDLOpDefinitionError, + match="'operands' is ambiguous with multiple variadic operands", + ): + + @irdl_op_definition + class TwoVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.two_var_op" + + op1 = var_operand_def() + op2 = var_operand_def() + + irdl_options = [AttrSizedOperandSegments()] + + assembly_format = "operands attr-dict `:` type(operands)" + + +def test_operands_directive_fails_with_no_operands(): + """Test operands directive cannot be used with no operands""" + + with pytest.raises( + PyRDLOpDefinitionError, + match="'operands' should not be used when there are no operands", + ): + + @irdl_op_definition + class NoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.no_operands_op" + + assembly_format = "operands attr-dict `:` type(operands)" + + +def test_operands_directive_fails_with_other_directive(): + """Test operands directive cannot be used with no operands""" + + with pytest.raises( + PyRDLOpDefinitionError, + match="'operands' cannot be used with other operand directives", + ): + + @irdl_op_definition + class TwoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.two_operands_op" + + op1 = operand_def() + op2 = operand_def() + + assembly_format = "$op1 `,` operands attr-dict `:` type(operands)" + + +def test_operands_directive_fails_with_other_type_directive(): + """Test operands directive cannot be used with no operands""" + + with pytest.raises( + PyRDLOpDefinitionError, + match="'operands' cannot be used in a type directive with other operand type directives", + ): + + @irdl_op_definition + class TwoOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.two_operands_op" + + op1 = operand_def() + op2 = operand_def() + + assembly_format = "operands attr-dict `:` type($op1) `,` type(operands)" + + +@pytest.mark.parametrize( + "program, error", + [ + ("test.two_operands %0 : i32, i32", "Expected 2 operands but found 1"), + ( + "test.two_operands %0, %1, %2 : i32, i32", + "Expected 2 operands but found 3", + ), + ("test.two_operands %0, %1 : i32", "Expected 2 operand types but found 1"), + ( + "test.two_operands %0, %1 : i32, i32, i32", + "Expected 2 operand types but found 3", + ), + ], +) +def test_operands_directive_bounds(program: str, error: str): + @irdl_op_definition + class TwoOperandsOp(IRDLOperation): + name = "test.two_operands" + + op1 = operand_def() + op2 = operand_def() + + assembly_format = "operands attr-dict `:` type(operands)" + + ctx = MLContext() + ctx.load_op(TwoOperandsOp) + + with pytest.raises(ParseError, match=error): + parser = Parser(ctx, program) + parser.parse_operation() + + +@pytest.mark.parametrize( + "program, error", + [ + ( + "test.three_operands %0 : i32, i32", + "Expected at least 2 operands but found 1", + ), + ( + "test.three_operands %0, %1, %2, %3 : i32, i32, i32", + "Expected at most 3 operands but found 4", + ), + ( + "test.three_operands %0, %1 : i32", + "Expected at least 2 operand types but found 1", + ), + ( + "test.three_operands %0, %1, %3 : i32, i32, i32, i32", + "Expected at most 3 operand types but found 4", + ), + ], +) +def test_operands_directive_bounds_with_opt(program: str, error: str): + @irdl_op_definition + class ThreeOperandsOp(IRDLOperation): + name = "test.three_operands" + + op1 = operand_def() + op2 = opt_operand_def() + op3 = operand_def() + + assembly_format = "operands attr-dict `:` type(operands)" + + ctx = MLContext() + ctx.load_op(ThreeOperandsOp) + + with pytest.raises(ParseError, match=error): + parser = Parser(ctx, program) + parser.parse_operation() + + +@pytest.mark.parametrize( + "program, error", + [ + ( + "test.three_operands %0 : i32, i32", + "Expected at least 2 operands but found 1", + ), + ( + "test.three_operands %0, %1 : i32", + "Expected at least 2 operand types but found 1", + ), + ], +) +def test_operands_directive_bound_with_var(program: str, error: str): + @irdl_op_definition + class ThreeOperandsOp(IRDLOperation): + name = "test.three_operands" + + op1 = operand_def() + op2 = var_operand_def() + op3 = operand_def() + + assembly_format = "operands attr-dict `:` type(operands)" + + ctx = MLContext() + ctx.load_op(ThreeOperandsOp) + + with pytest.raises(ParseError, match=error): + parser = Parser(ctx, program) + parser.parse_operation() + + ################################################################################ # Results # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 4bd6b1cbd4..1c5d3a6500 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Literal +from typing import Literal, TypeVar from xdsl.dialects.builtin import UnitAttr from xdsl.ir import ( @@ -600,6 +600,100 @@ def set_types_empty(self, state: ParsingState) -> None: state.operand_types[self.index] = () +_T = TypeVar("_T") + + +@dataclass(frozen=True) +class OperandsDirective(VariadicOperandDirective, VariadicTypeableDirective): + """ + An operands directive, with the following format: + operands-directive ::= operands + Prints each operand of the operation, inserting a comma between each. + """ + + variadic_index: tuple[bool, int] | None + """ + Represents the position of a (single) variadic variable, with the boolean + representing whether it is optional + """ + + def _set_using_variadic_index( + self, + field: list[_T | None | Sequence[_T]], + field_name: str, + set_to: Sequence[_T], + ) -> str | None: + if self.variadic_index is None: + if len(set_to) != len(field): + return f"Expected {len(field)} {field_name} but found {len(set_to)}" + field = [o for o in set_to] # Copy needed as list is not covariant + return + + is_optional, var_position = self.variadic_index + var_length = len(set_to) - len(field) + 1 + if var_length < 0: + return f"Expected at least {len(field) - 1} {field_name} but found {len(set_to)}" + if var_length > 1 and is_optional: + return f"Expected at most {len(field)} {field_name} but found {len(set_to)}" + field[:var_position] = set_to[:var_position] + field[var_position] = set_to[var_position : var_position + var_length] + field[var_position + 1 :] = set_to[var_position + var_length :] + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + pos_start = parser.pos + operands = ( + parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_unresolved_operand, + parser.parse_unresolved_operand, + ) + or [] + ) + + if s := self._set_using_variadic_index(state.operands, "operands", operands): + parser.raise_error(s, at_position=pos_start, end_position=parser.pos) + return bool(operands) + + def parse_single_type(self, parser: Parser, state: ParsingState) -> None: + if len(state.operand_types) > 1: + parser.raise_error("Expected multiple types but received one.") + state.operand_types[0] = parser.parse_type() + + def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: + pos_start = parser.pos + types = ( + parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_type, parser.parse_type + ) + or [] + ) + + if s := self._set_using_variadic_index( + state.operand_types, "operand types", types + ): + parser.raise_error(s, at_position=pos_start, end_position=parser.pos) + return bool(types) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if op.operands: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + printer.print_list(op.operands, printer.print_ssa_value) + state.last_was_punctuation = False + state.should_emit_space = True + + def set_types_empty(self, state: ParsingState) -> None: + state.operand_types = [() for _ in state.operand_types] + + def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: + return op.operand_types + + def set_empty(self, state: ParsingState): + state.operands = [() for _ in state.operands] + + def is_present(self, op: IRDLOperation) -> bool: + return bool(op.operands) + + @dataclass(frozen=True) class ResultVariable(VariableDirective, TypeableDirective): """ diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 850cd967e9..92f0d1e17e 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -46,6 +46,7 @@ FormatProgram, KeywordDirective, OperandOrResult, + OperandsDirective, OperandVariable, OptionalAttributeVariable, OptionalGroupDirective, @@ -699,6 +700,8 @@ def parse_typeable_directive(self) -> TypeableDirective: Parse a typeable directive, with the following format: directive ::= variable """ + if self.parse_optional_keyword("operands"): + return self.create_operands_directive(False) if variable := self.parse_optional_typeable_variable(): return variable self.raise_error(f"unexpected token '{self._current_token.text}'") @@ -718,6 +721,8 @@ def parse_format_directive(self) -> FormatDirective: return self.create_attr_dict_directive(True) if self.parse_optional_keyword("type"): return self.parse_type_directive() + if self.parse_optional_keyword("operands"): + return self.create_operands_directive(True) if self._current_token.text == "`": return self.parse_keyword_or_punctuation() if self.parse_optional_punctuation("("): @@ -744,3 +749,27 @@ def create_attr_dict_directive(self, with_keyword: bool) -> AttrDictDirective: reserved_attr_names=set(), print_properties=print_properties, ) + + def create_operands_directive(self, top_level: bool) -> OperandsDirective: + if not self.op_def.operands: + self.raise_error("'operands' should not be used when there are no operands") + if top_level and any(self.seen_operands): + self.raise_error("'operands' cannot be used with other operand directives") + if not top_level and any(self.seen_operand_types): + self.raise_error( + "'operands' cannot be used in a type directive with other operand type directives" + ) + variadics = tuple( + (isinstance(o, OptionalDef), i) + for i, (_, o) in enumerate(self.op_def.operands) + if isinstance(o, VariadicDef) + ) + if len(variadics) > 1: + self.raise_error("'operands' is ambiguous with multiple variadic operands") + if top_level: + self.seen_operands = [True] * len(self.seen_operands) + else: + self.seen_operand_types = [True] * len(self.seen_operand_types) + if not variadics: + return OperandsDirective(None) + return OperandsDirective(variadics[0])