diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 4a22569f14..c389bf9581 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -805,7 +805,7 @@ class TwoOperandsOp(IRDLOperation): [ ( "$args type($args) attr-dict", - '%0 = "test.op"() : () -> i32\n' "test.variadic_operand ", + '%0 = "test.op"() : () -> i32\n' "test.variadic_operand", '%0 = "test.op"() : () -> i32\n' '"test.variadic_operand"() : () -> ()', ), ( @@ -853,7 +853,7 @@ class VariadicOperandOp(IRDLOperation): [ ( "$args type($args) attr-dict", - '%0 = "test.op"() : () -> i32\n' "test.optional_operand ", + '%0 = "test.op"() : () -> i32\n' "test.optional_operand", '%0 = "test.op"() : () -> i32\n' '"test.optional_operand"() : () -> ()', ), ( @@ -944,7 +944,7 @@ class VariadicOperandsOp(IRDLOperation): "program, generic_program", [ ( - "test.optional_operands(: ) [: ]", + "test.optional_operands(:) [:]", '"test.optional_operands"() {operandSegmentSizes = array} : () -> ()', ), ( @@ -1314,7 +1314,7 @@ class TwoResultOp(IRDLOperation): [ ( "`:` type($res) attr-dict", - "test.variadic_result : ", + "test.variadic_result :", '"test.variadic_result"() : () -> ()', ), ( @@ -1357,7 +1357,7 @@ class VariadicResultOp(IRDLOperation): [ ( "`:` type($res) attr-dict", - "test.optional_result : ", + "test.optional_result :", '"test.optional_result"() : () -> ()', ), ( @@ -1681,7 +1681,7 @@ class TwoRegionsOp(IRDLOperation): [ ( "attr-dict-with-keyword $region", - "test.variadic_region ", + "test.variadic_region", '"test.variadic_region"() : () -> ()', ), ( @@ -1724,7 +1724,7 @@ class VariadicRegionOp(IRDLOperation): [ ( "attr-dict-with-keyword $region", - "test.optional_region ", + "test.optional_region", '"test.optional_region"() : () -> ()', ), ( @@ -1871,7 +1871,7 @@ class TwoSuccessorsOp(IRDLOperation): "program, generic_program", [ ( - '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.var_successor \n}) : () -> ()', + '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.var_successor\n}) : () -> ()', textwrap.dedent( """\ "test.op"() ({ @@ -1941,7 +1941,7 @@ class VarSuccessorOp(IRDLOperation): "program, generic_program", [ ( - '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.opt_successor \n}) : () -> ()', + '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.opt_successor\n}) : () -> ()', textwrap.dedent( """\ "test.op"() ({ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 4e4fb1c20f..e700e834e5 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -357,9 +357,12 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return self.inner.parse_many_types(parser, state) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + types = self.inner.get_types(op) + if not types: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - printer.print_list(self.inner.get_types(op), printer.print_attribute) + printer.print_list(types, printer.print_attribute) state.last_was_punctuation = False state.should_emit_space = True @@ -535,13 +538,14 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: return ret def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + operand = getattr(op, self.name) + if not operand: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - operand = getattr(op, self.name) - if operand: - printer.print_list(operand, printer.print_ssa_value) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_list(operand, printer.print_ssa_value) + state.last_was_punctuation = False + state.should_emit_space = True def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: return getattr(op, self.name).types @@ -579,13 +583,14 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: return ret def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + operand = getattr(op, self.name) + if not operand: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - operand = getattr(op, self.name) - if operand: - printer.print_ssa_value(operand) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_ssa_value(operand) + state.last_was_punctuation = False + state.should_emit_space = True def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: operand = getattr(op, self.name) @@ -866,13 +871,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(regions) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + region = getattr(op, self.name) + if not region: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - region = getattr(op, self.name) - if region: - printer.print_list(region, printer.print_region, delimiter=" ") - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_list(region, printer.print_region, delimiter=" ") + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.regions[self.index] = () @@ -893,13 +899,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(region) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + region = getattr(op, self.name) + if not region: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - region = getattr(op, self.name) - if region: - printer.print_region(region) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_region(region) + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.regions[self.index] = () @@ -953,13 +960,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(successors) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + successor = getattr(op, self.name) + if not successor: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - successor = getattr(op, self.name) - if successor: - printer.print_list(successor, printer.print_block_name, delimiter=" ") - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_list(successor, printer.print_block_name, delimiter=" ") + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.successors[self.index] = () @@ -980,13 +988,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(successor) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + successor = getattr(op, self.name) + if not successor: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - successor = getattr(op, self.name) - if successor: - printer.print_block_name(successor) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_block_name(successor) + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.successors[self.index] = ()