From 190c4f8b8f51924f2481d85f365e56a0f6cd7904 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Thu, 28 Nov 2024 22:52:30 +0000 Subject: [PATCH] core: remove unnecessary whitespace printing in assembly format (#3530) Variadic arguments were printing whitespace even if they had nothing to print after the whitespace, resulting in extra unnecessary whitespace appearing. This change moves the whitespace printing after the check to see if there's anything to print. Co-authored-by: Sasha Lopoukhine --- .../irdl/test_declarative_assembly_format.py | 18 ++--- xdsl/irdl/declarative_assembly_format.py | 71 +++++++++++-------- 2 files changed, 49 insertions(+), 40 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 4a22569f14..c389bf9581 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -805,7 +805,7 @@ class TwoOperandsOp(IRDLOperation): [ ( "$args type($args) attr-dict", - '%0 = "test.op"() : () -> i32\n' "test.variadic_operand ", + '%0 = "test.op"() : () -> i32\n' "test.variadic_operand", '%0 = "test.op"() : () -> i32\n' '"test.variadic_operand"() : () -> ()', ), ( @@ -853,7 +853,7 @@ class VariadicOperandOp(IRDLOperation): [ ( "$args type($args) attr-dict", - '%0 = "test.op"() : () -> i32\n' "test.optional_operand ", + '%0 = "test.op"() : () -> i32\n' "test.optional_operand", '%0 = "test.op"() : () -> i32\n' '"test.optional_operand"() : () -> ()', ), ( @@ -944,7 +944,7 @@ class VariadicOperandsOp(IRDLOperation): "program, generic_program", [ ( - "test.optional_operands(: ) [: ]", + "test.optional_operands(:) [:]", '"test.optional_operands"() {operandSegmentSizes = array} : () -> ()', ), ( @@ -1314,7 +1314,7 @@ class TwoResultOp(IRDLOperation): [ ( "`:` type($res) attr-dict", - "test.variadic_result : ", + "test.variadic_result :", '"test.variadic_result"() : () -> ()', ), ( @@ -1357,7 +1357,7 @@ class VariadicResultOp(IRDLOperation): [ ( "`:` type($res) attr-dict", - "test.optional_result : ", + "test.optional_result :", '"test.optional_result"() : () -> ()', ), ( @@ -1681,7 +1681,7 @@ class TwoRegionsOp(IRDLOperation): [ ( "attr-dict-with-keyword $region", - "test.variadic_region ", + "test.variadic_region", '"test.variadic_region"() : () -> ()', ), ( @@ -1724,7 +1724,7 @@ class VariadicRegionOp(IRDLOperation): [ ( "attr-dict-with-keyword $region", - "test.optional_region ", + "test.optional_region", '"test.optional_region"() : () -> ()', ), ( @@ -1871,7 +1871,7 @@ class TwoSuccessorsOp(IRDLOperation): "program, generic_program", [ ( - '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.var_successor \n}) : () -> ()', + '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.var_successor\n}) : () -> ()', textwrap.dedent( """\ "test.op"() ({ @@ -1941,7 +1941,7 @@ class VarSuccessorOp(IRDLOperation): "program, generic_program", [ ( - '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.opt_successor \n}) : () -> ()', + '"test.op"() ({\n "test.op"() [^0] : () -> ()\n^0:\n test.opt_successor\n}) : () -> ()', textwrap.dedent( """\ "test.op"() ({ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 4e4fb1c20f..e700e834e5 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -357,9 +357,12 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return self.inner.parse_many_types(parser, state) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + types = self.inner.get_types(op) + if not types: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - printer.print_list(self.inner.get_types(op), printer.print_attribute) + printer.print_list(types, printer.print_attribute) state.last_was_punctuation = False state.should_emit_space = True @@ -535,13 +538,14 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: return ret def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + operand = getattr(op, self.name) + if not operand: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - operand = getattr(op, self.name) - if operand: - printer.print_list(operand, printer.print_ssa_value) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_list(operand, printer.print_ssa_value) + state.last_was_punctuation = False + state.should_emit_space = True def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: return getattr(op, self.name).types @@ -579,13 +583,14 @@ def parse_many_types(self, parser: Parser, state: ParsingState) -> bool: return ret def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + operand = getattr(op, self.name) + if not operand: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - operand = getattr(op, self.name) - if operand: - printer.print_ssa_value(operand) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_ssa_value(operand) + state.last_was_punctuation = False + state.should_emit_space = True def get_types(self, op: IRDLOperation) -> Sequence[Attribute]: operand = getattr(op, self.name) @@ -866,13 +871,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(regions) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + region = getattr(op, self.name) + if not region: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - region = getattr(op, self.name) - if region: - printer.print_list(region, printer.print_region, delimiter=" ") - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_list(region, printer.print_region, delimiter=" ") + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.regions[self.index] = () @@ -893,13 +899,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(region) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + region = getattr(op, self.name) + if not region: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - region = getattr(op, self.name) - if region: - printer.print_region(region) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_region(region) + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.regions[self.index] = () @@ -953,13 +960,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(successors) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + successor = getattr(op, self.name) + if not successor: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - successor = getattr(op, self.name) - if successor: - printer.print_list(successor, printer.print_block_name, delimiter=" ") - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_list(successor, printer.print_block_name, delimiter=" ") + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.successors[self.index] = () @@ -980,13 +988,14 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: return bool(successor) def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + successor = getattr(op, self.name) + if not successor: + return if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - successor = getattr(op, self.name) - if successor: - printer.print_block_name(successor) - state.last_was_punctuation = False - state.should_emit_space = True + printer.print_block_name(successor) + state.last_was_punctuation = False + state.should_emit_space = True def set_empty(self, state: ParsingState): state.successors[self.index] = ()