Skip to content

Commit

Permalink
API: split out MLIR-independent printing into BasePrinter and use in …
Browse files Browse the repository at this point in the history
…STIM [NFC] (#3613)

I'd also like to use the indentation logic in ASL, WGSL, and the
assembly dialects, PRs incoming.
  • Loading branch information
superlopuh authored Dec 11, 2024
1 parent 056f76e commit 414bcb0
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 197 deletions.
40 changes: 0 additions & 40 deletions tests/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,43 +952,3 @@ def test_get_printed_name():
printed = StringIO()
picked_name = Printer(printed).print_ssa_value(val)
assert f"%{picked_name}" == printed.getvalue()


def test_indented():
output = StringIO()
printer = Printer(stream=output)
printer.print("\n{")
with printer.indented():
printer.print("\nhello\nhow are you?")
printer.print("\n(")
with printer.indented():
printer.print("\nfoo,")
printer.print("\nbar,")
printer.print("\n")
printer.print_string("test\nraw print!", indent=0)
printer.print_string("\ndifferent indent level", indent=4)
printer.print("\n)")
printer.print("\n}")
printer.print("\n[")
with printer.indented(amount=3):
printer.print("\nbaz")
printer.print("\n]\n")

EXPECTED = """
{
hello
how are you?
(
foo,
bar,
test
raw print!
different indent level
)
}
[
baz
]
"""

assert output.getvalue() == EXPECTED
43 changes: 43 additions & 0 deletions tests/utils/test_base_printer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from io import StringIO

from xdsl.utils.base_printer import BasePrinter


def test_indented():
output = StringIO()
printer = BasePrinter(stream=output)
printer.print_string("\n{")
with printer.indented():
printer.print_string("\nhello\nhow are you?")
printer.print_string("\n(")
with printer.indented():
printer.print_string("\nfoo,")
printer.print_string("\nbar,")
printer.print_string("\n")
printer.print_string("test\nraw print!", indent=0)
printer.print_string("\ndifferent indent level", indent=4)
printer.print_string("\n)")
printer.print_string("\n}")
printer.print_string("\n[")
with printer.indented(amount=3):
printer.print_string("\nbaz")
printer.print_string("\n]\n")

EXPECTED = """
{
hello
how are you?
(
foo,
bar,
test
raw print!
different indent level
)
}
[
baz
]
"""

assert output.getvalue() == EXPECTED
33 changes: 9 additions & 24 deletions xdsl/dialects/stim/stim_printer_parser.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
import abc
from collections.abc import Callable, Iterable
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, TypeVar, cast

from xdsl.dialects.builtin import ArrayAttr, FloatData, IntAttr
from xdsl.ir import Attribute, Operation

"""
This file implements a printer that prints to the .stim file format.
Full documentation can be found here: https://github.com/quantumlib/Stim/blob/main/doc/file_format_stim_circuit.md
"""

import abc
from contextlib import contextmanager
from dataclasses import dataclass
from typing import cast

@dataclass(eq=False, repr=False)
class StimPrinter:
stream: Any | None = field(default=None)
from xdsl.dialects.builtin import ArrayAttr, FloatData, IntAttr
from xdsl.ir import Attribute, Operation
from xdsl.utils.base_printer import BasePrinter

def print_string(self, text: str) -> None:
print(text, end="", file=self.stream)

@dataclass(eq=False, repr=False)
class StimPrinter(BasePrinter):
@contextmanager
def in_braces(self):
self.print_string("{")
Expand All @@ -32,16 +27,6 @@ def in_parens(self):
yield
self.print_string(") ")

T = TypeVar("T")

def print_list(
self, elems: Iterable[T], print_fn: Callable[[T], Any], delimiter: str = ", "
) -> None:
for i, elem in enumerate(elems):
if i:
self.print_string(delimiter)
print_fn(elem)

def print_attribute(self, attribute: Attribute) -> None:
if isinstance(attribute, ArrayAttr):
attribute = cast(ArrayAttr[Attribute], attribute)
Expand Down
135 changes: 2 additions & 133 deletions xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from itertools import chain
from typing import Any, TypeVar, cast

from typing_extensions import deprecated

from xdsl.dialects.builtin import (
AffineMapAttr,
AffineSetAttr,
Expand Down Expand Up @@ -68,6 +66,7 @@
TypedAttribute,
)
from xdsl.traits import IsolatedFromAbove, IsTerminator
from xdsl.utils.base_printer import BasePrinter
from xdsl.utils.bitwise_casts import (
convert_f16_to_u16,
convert_f32_to_u32,
Expand All @@ -76,18 +75,14 @@
from xdsl.utils.diagnostic import Diagnostic
from xdsl.utils.mlir_lexer import MLIRLexer

indentNumSpaces = 2


@dataclass(eq=False, repr=False)
class Printer:
stream: Any | None = field(default=None)
class Printer(BasePrinter):
print_generic_format: bool = field(default=False)
print_properties_as_attributes: bool = field(default=False)
print_debuginfo: bool = field(default=False)
diagnostic: Diagnostic = field(default_factory=Diagnostic)

_indent: int = field(default=0, init=False)
_ssa_values: dict[SSAValue, str] = field(default_factory=dict, init=False)
"""
maps SSA Values to their "allocated" names
Expand All @@ -100,11 +95,6 @@ class Printer:
)
_next_valid_name_id: list[int] = field(default_factory=lambda: [0], init=False)
_next_valid_block_id: list[int] = field(default_factory=lambda: [0], init=False)
_current_line: int = field(default=0, init=False)
_current_column: int = field(default=0, init=False)
_next_line_callback: list[Callable[[], None]] = field(
default_factory=list, init=False
)

@property
def ssa_names(self):
Expand Down Expand Up @@ -148,114 +138,9 @@ def print(self, *argv: Any) -> None:
text = str(arg)
self.print_string(text)

@deprecated("Please use `printer.print_strint(text, indent=0)`")
def print_string_raw(self, text: str) -> None:
"""
Prints a string to the printer's output, without taking
indentation into account.
"""
self.print_string(text, indent=0)

def print_string(self, text: str, *, indent: int | None = None) -> None:
"""
Prints a string to the printer's output.
This function takes into account indentation level when
printing new lines.
If the indentation level is specified as 0, the string is printed as-is, if `None`
then the `Printer` instance's indentation level is used.
"""

num_newlines = text.count("\n")

if not num_newlines:
self._current_column += len(text)
print(text, end="", file=self.stream)
return

indent = self._indent if indent is None else indent
lines = text.split("\n")

if indent == 0 and not self._next_line_callback:
# No indent and no callback to print after the next newline, the text
# can be printed directly.
self._current_line += num_newlines
self._current_column = len(lines[-1])
print(text, end="", file=self.stream)
return

# Line and column information is not computed ahead of time
# as indent-aware newline printing may use it as part of
# callbacks.
print(lines[0], end="", file=self.stream)
self._current_column += len(lines[0])
for line in lines[1:]:
self._print_new_line(indent=indent)
print(line, end="", file=self.stream)
self._current_column += len(line)

@contextmanager
def indented(self, amount: int = 1):
"""
Increases the indentation level by the provided amount
for the duration of the context.
Only affects new lines printed within the context.
"""

self._indent += amount
try:
yield
finally:
self._indent -= amount

def _add_message_on_next_line(self, message: str, begin_pos: int, end_pos: int):
"""Add a message that will be displayed on the next line."""

def callback(indent: int = self._indent):
self._print_message(message, begin_pos, end_pos, indent)

self._next_line_callback.append(callback)

def _print_message(
self, message: str, begin_pos: int, end_pos: int, indent: int | None = None
):
"""
Print a message.
This is expected to be called at the beginning of a new line and to create a new
line at the end.
The span of the message to be underlined is represented as [begin_pos, end_pos).
"""
indent = self._indent if indent is None else indent
indent_size = indent * indentNumSpaces
self.print_string(" " * indent_size)
message_end_pos = max(map(len, message.split("\n"))) + indent_size + 2
first_line = (
(begin_pos - indent_size) * "-"
+ (end_pos - begin_pos) * "^"
+ (max(message_end_pos, end_pos) - end_pos) * "-"
)
self.print_string(first_line)
self._print_new_line(indent=indent, print_message=False)
for message_line in message.split("\n"):
self.print_string("| ")
self.print_string(message_line)
self._print_new_line(indent=indent, print_message=False)
self.print_string("-" * (max(message_end_pos, end_pos) - indent_size))
self._print_new_line(indent=0, print_message=False)

T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")

def print_list(
self, elems: Iterable[T], print_fn: Callable[[T], Any], delimiter: str = ", "
) -> None:
for i, elem in enumerate(elems):
if i:
self.print_string(delimiter)
print_fn(elem)

def print_dictionary(
self,
elems: dict[K, V],
Expand All @@ -270,22 +155,6 @@ def print_dictionary(
self.print_string("=")
print_value(value)

def _print_new_line(
self, indent: int | None = None, print_message: bool = True
) -> None:
indent = self._indent if indent is None else indent
# Prints a newline, bypassing the `print_string` method
print(file=self.stream)
self._current_line += 1
if print_message:
for callback in self._next_line_callback:
callback()
self._next_line_callback = []
num_spaces = indent * indentNumSpaces
# Prints indentation, bypassing the `print_string` method
print(" " * num_spaces, end="", file=self.stream)
self._current_column = num_spaces

def _get_new_valid_name_id(self) -> str:
self._next_valid_name_id[-1] += 1
return str(self._next_valid_name_id[-1] - 1)
Expand Down
Loading

0 comments on commit 414bcb0

Please sign in to comment.