diff --git a/tests/test_printer.py b/tests/test_printer.py index 843e2d72e7..6645c8c489 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -952,43 +952,3 @@ def test_get_printed_name(): printed = StringIO() picked_name = Printer(printed).print_ssa_value(val) assert f"%{picked_name}" == printed.getvalue() - - -def test_indented(): - output = StringIO() - printer = Printer(stream=output) - printer.print("\n{") - with printer.indented(): - printer.print("\nhello\nhow are you?") - printer.print("\n(") - with printer.indented(): - printer.print("\nfoo,") - printer.print("\nbar,") - printer.print("\n") - printer.print_string("test\nraw print!", indent=0) - printer.print_string("\ndifferent indent level", indent=4) - printer.print("\n)") - printer.print("\n}") - printer.print("\n[") - with printer.indented(amount=3): - printer.print("\nbaz") - printer.print("\n]\n") - - EXPECTED = """ -{ - hello - how are you? - ( - foo, - bar, - test -raw print! - different indent level - ) -} -[ - baz -] -""" - - assert output.getvalue() == EXPECTED diff --git a/tests/utils/test_base_printer.py b/tests/utils/test_base_printer.py new file mode 100644 index 0000000000..5d5882f440 --- /dev/null +++ b/tests/utils/test_base_printer.py @@ -0,0 +1,43 @@ +from io import StringIO + +from xdsl.utils.base_printer import BasePrinter + + +def test_indented(): + output = StringIO() + printer = BasePrinter(stream=output) + printer.print_string("\n{") + with printer.indented(): + printer.print_string("\nhello\nhow are you?") + printer.print_string("\n(") + with printer.indented(): + printer.print_string("\nfoo,") + printer.print_string("\nbar,") + printer.print_string("\n") + printer.print_string("test\nraw print!", indent=0) + printer.print_string("\ndifferent indent level", indent=4) + printer.print_string("\n)") + printer.print_string("\n}") + printer.print_string("\n[") + with printer.indented(amount=3): + printer.print_string("\nbaz") + printer.print_string("\n]\n") + + EXPECTED = """ +{ + hello + how are you? + ( + foo, + bar, + test +raw print! + different indent level + ) +} +[ + baz +] +""" + + assert output.getvalue() == EXPECTED diff --git a/xdsl/dialects/stim/stim_printer_parser.py b/xdsl/dialects/stim/stim_printer_parser.py index b48e7da4a7..c63cc36067 100644 --- a/xdsl/dialects/stim/stim_printer_parser.py +++ b/xdsl/dialects/stim/stim_printer_parser.py @@ -1,25 +1,20 @@ -import abc -from collections.abc import Callable, Iterable -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Any, TypeVar, cast - -from xdsl.dialects.builtin import ArrayAttr, FloatData, IntAttr -from xdsl.ir import Attribute, Operation - """ This file implements a printer that prints to the .stim file format. Full documentation can be found here: https://github.com/quantumlib/Stim/blob/main/doc/file_format_stim_circuit.md """ +import abc +from contextlib import contextmanager +from dataclasses import dataclass +from typing import cast -@dataclass(eq=False, repr=False) -class StimPrinter: - stream: Any | None = field(default=None) +from xdsl.dialects.builtin import ArrayAttr, FloatData, IntAttr +from xdsl.ir import Attribute, Operation +from xdsl.utils.base_printer import BasePrinter - def print_string(self, text: str) -> None: - print(text, end="", file=self.stream) +@dataclass(eq=False, repr=False) +class StimPrinter(BasePrinter): @contextmanager def in_braces(self): self.print_string("{") @@ -32,16 +27,6 @@ def in_parens(self): yield self.print_string(") ") - T = TypeVar("T") - - def print_list( - self, elems: Iterable[T], print_fn: Callable[[T], Any], delimiter: str = ", " - ) -> None: - for i, elem in enumerate(elems): - if i: - self.print_string(delimiter) - print_fn(elem) - def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, ArrayAttr): attribute = cast(ArrayAttr[Attribute], attribute) diff --git a/xdsl/printer.py b/xdsl/printer.py index 958c5ccf21..79914321f8 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -8,8 +8,6 @@ from itertools import chain from typing import Any, TypeVar, cast -from typing_extensions import deprecated - from xdsl.dialects.builtin import ( AffineMapAttr, AffineSetAttr, @@ -68,6 +66,7 @@ TypedAttribute, ) from xdsl.traits import IsolatedFromAbove, IsTerminator +from xdsl.utils.base_printer import BasePrinter from xdsl.utils.bitwise_casts import ( convert_f16_to_u16, convert_f32_to_u32, @@ -76,18 +75,14 @@ from xdsl.utils.diagnostic import Diagnostic from xdsl.utils.mlir_lexer import MLIRLexer -indentNumSpaces = 2 - @dataclass(eq=False, repr=False) -class Printer: - stream: Any | None = field(default=None) +class Printer(BasePrinter): print_generic_format: bool = field(default=False) print_properties_as_attributes: bool = field(default=False) print_debuginfo: bool = field(default=False) diagnostic: Diagnostic = field(default_factory=Diagnostic) - _indent: int = field(default=0, init=False) _ssa_values: dict[SSAValue, str] = field(default_factory=dict, init=False) """ maps SSA Values to their "allocated" names @@ -100,11 +95,6 @@ class Printer: ) _next_valid_name_id: list[int] = field(default_factory=lambda: [0], init=False) _next_valid_block_id: list[int] = field(default_factory=lambda: [0], init=False) - _current_line: int = field(default=0, init=False) - _current_column: int = field(default=0, init=False) - _next_line_callback: list[Callable[[], None]] = field( - default_factory=list, init=False - ) @property def ssa_names(self): @@ -148,114 +138,9 @@ def print(self, *argv: Any) -> None: text = str(arg) self.print_string(text) - @deprecated("Please use `printer.print_strint(text, indent=0)`") - def print_string_raw(self, text: str) -> None: - """ - Prints a string to the printer's output, without taking - indentation into account. - """ - self.print_string(text, indent=0) - - def print_string(self, text: str, *, indent: int | None = None) -> None: - """ - Prints a string to the printer's output. - - This function takes into account indentation level when - printing new lines. - If the indentation level is specified as 0, the string is printed as-is, if `None` - then the `Printer` instance's indentation level is used. - """ - - num_newlines = text.count("\n") - - if not num_newlines: - self._current_column += len(text) - print(text, end="", file=self.stream) - return - - indent = self._indent if indent is None else indent - lines = text.split("\n") - - if indent == 0 and not self._next_line_callback: - # No indent and no callback to print after the next newline, the text - # can be printed directly. - self._current_line += num_newlines - self._current_column = len(lines[-1]) - print(text, end="", file=self.stream) - return - - # Line and column information is not computed ahead of time - # as indent-aware newline printing may use it as part of - # callbacks. - print(lines[0], end="", file=self.stream) - self._current_column += len(lines[0]) - for line in lines[1:]: - self._print_new_line(indent=indent) - print(line, end="", file=self.stream) - self._current_column += len(line) - - @contextmanager - def indented(self, amount: int = 1): - """ - Increases the indentation level by the provided amount - for the duration of the context. - - Only affects new lines printed within the context. - """ - - self._indent += amount - try: - yield - finally: - self._indent -= amount - - def _add_message_on_next_line(self, message: str, begin_pos: int, end_pos: int): - """Add a message that will be displayed on the next line.""" - - def callback(indent: int = self._indent): - self._print_message(message, begin_pos, end_pos, indent) - - self._next_line_callback.append(callback) - - def _print_message( - self, message: str, begin_pos: int, end_pos: int, indent: int | None = None - ): - """ - Print a message. - This is expected to be called at the beginning of a new line and to create a new - line at the end. - The span of the message to be underlined is represented as [begin_pos, end_pos). - """ - indent = self._indent if indent is None else indent - indent_size = indent * indentNumSpaces - self.print_string(" " * indent_size) - message_end_pos = max(map(len, message.split("\n"))) + indent_size + 2 - first_line = ( - (begin_pos - indent_size) * "-" - + (end_pos - begin_pos) * "^" - + (max(message_end_pos, end_pos) - end_pos) * "-" - ) - self.print_string(first_line) - self._print_new_line(indent=indent, print_message=False) - for message_line in message.split("\n"): - self.print_string("| ") - self.print_string(message_line) - self._print_new_line(indent=indent, print_message=False) - self.print_string("-" * (max(message_end_pos, end_pos) - indent_size)) - self._print_new_line(indent=0, print_message=False) - - T = TypeVar("T") K = TypeVar("K") V = TypeVar("V") - def print_list( - self, elems: Iterable[T], print_fn: Callable[[T], Any], delimiter: str = ", " - ) -> None: - for i, elem in enumerate(elems): - if i: - self.print_string(delimiter) - print_fn(elem) - def print_dictionary( self, elems: dict[K, V], @@ -270,22 +155,6 @@ def print_dictionary( self.print_string("=") print_value(value) - def _print_new_line( - self, indent: int | None = None, print_message: bool = True - ) -> None: - indent = self._indent if indent is None else indent - # Prints a newline, bypassing the `print_string` method - print(file=self.stream) - self._current_line += 1 - if print_message: - for callback in self._next_line_callback: - callback() - self._next_line_callback = [] - num_spaces = indent * indentNumSpaces - # Prints indentation, bypassing the `print_string` method - print(" " * num_spaces, end="", file=self.stream) - self._current_column = num_spaces - def _get_new_valid_name_id(self) -> str: self._next_valid_name_id[-1] += 1 return str(self._next_valid_name_id[-1] - 1) diff --git a/xdsl/utils/base_printer.py b/xdsl/utils/base_printer.py new file mode 100644 index 0000000000..42658b5091 --- /dev/null +++ b/xdsl/utils/base_printer.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from collections.abc import Callable, Iterable +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import IO, Any, TypeVar + + +@dataclass(eq=False, repr=False) +class BasePrinter: + stream: IO[str] | None = field(default=None) + indent_num_spaces: int = field(default=2, kw_only=True) + _indent: int = field(default=0, init=False) + _current_line: int = field(default=0, init=False) + _current_column: int = field(default=0, init=False) + + _next_line_callback: list[Callable[[], None]] = field( + default_factory=list, init=False + ) + + def print_string(self, text: str, *, indent: int | None = None) -> None: + """ + Prints a string to the printer's output. + + This function takes into account indentation level when + printing new lines. + If the indentation level is specified as 0, the string is printed as-is, if `None` + then the `Printer` instance's indentation level is used. + """ + + num_newlines = text.count("\n") + + if not num_newlines: + self._current_column += len(text) + print(text, end="", file=self.stream) + return + + indent = self._indent if indent is None else indent + lines = text.split("\n") + + if indent == 0 and not self._next_line_callback: + # No indent and no callback to print after the next newline, the text + # can be printed directly. + self._current_line += num_newlines + self._current_column = len(lines[-1]) + print(text, end="", file=self.stream) + return + + # Line and column information is not computed ahead of time + # as indent-aware newline printing may use it as part of + # callbacks. + print(lines[0], end="", file=self.stream) + self._current_column += len(lines[0]) + for line in lines[1:]: + self._print_new_line(indent=indent) + print(line, end="", file=self.stream) + self._current_column += len(line) + + T = TypeVar("T") + + def print_list( + self, elems: Iterable[T], print_fn: Callable[[T], Any], delimiter: str = ", " + ) -> None: + for i, elem in enumerate(elems): + if i: + self.print_string(delimiter) + print_fn(elem) + + def _print_new_line( + self, indent: int | None = None, print_message: bool = True + ) -> None: + indent = self._indent if indent is None else indent + # Prints a newline, bypassing the `print_string` method + print(file=self.stream) + self._current_line += 1 + if print_message: + for callback in self._next_line_callback: + callback() + self._next_line_callback = [] + num_spaces = indent * self.indent_num_spaces + # Prints indentation, bypassing the `print_string` method + print(" " * num_spaces, end="", file=self.stream) + self._current_column = num_spaces + + @contextmanager + def indented(self, amount: int = 1): + """ + Increases the indentation level by the provided amount + for the duration of the context. + + Only affects new lines printed within the context. + """ + + self._indent += amount + try: + yield + finally: + self._indent -= amount + + def _add_message_on_next_line(self, message: str, begin_pos: int, end_pos: int): + """Add a message that will be displayed on the next line.""" + + def callback(indent: int = self._indent): + self._print_message(message, begin_pos, end_pos, indent) + + self._next_line_callback.append(callback) + + def _print_message( + self, message: str, begin_pos: int, end_pos: int, indent: int | None = None + ): + """ + Print a message. + This is expected to be called at the beginning of a new line and to create a new + line at the end. + The span of the message to be underlined is represented as [begin_pos, end_pos). + """ + indent = self._indent if indent is None else indent + indent_size = indent * self.indent_num_spaces + self.print_string(" " * indent_size) + message_end_pos = max(map(len, message.split("\n"))) + indent_size + 2 + first_line = ( + (begin_pos - indent_size) * "-" + + (end_pos - begin_pos) * "^" + + (max(message_end_pos, end_pos) - end_pos) * "-" + ) + self.print_string(first_line) + self._print_new_line(indent=indent, print_message=False) + for message_line in message.split("\n"): + self.print_string("| ") + self.print_string(message_line) + self._print_new_line(indent=indent, print_message=False) + self.print_string("-" * (max(message_end_pos, end_pos) - indent_size)) + self._print_new_line(indent=0, print_message=False)