Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: split print_float into print_float and print_float_attr #3600

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions tests/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from xdsl.dialects import test
from xdsl.dialects.arith import AddiOp, Arith, ConstantOp
from xdsl.dialects.builtin import (
AnyFloat,
AnyFloatAttr,
Builtin,
FloatAttr,
Expand All @@ -19,6 +20,7 @@
ModuleOp,
SymbolRefAttr,
UnitAttr,
f32,
i32,
)
from xdsl.dialects.func import Func
Expand Down Expand Up @@ -763,6 +765,49 @@ def test_densearray_attr():
assert_print_op(parsed, prog, None)


def test_float():
printer = Printer()

def _test_float_print(expected: str, value: float, type: AnyFloat):
io = StringIO()
printer.stream = io
printer.print_float(value, type)
assert io.getvalue() == expected

_test_float_print("3.000000e+00", 3, f32)
_test_float_print("-3.000000e+00", -3, f32)
_test_float_print("3.140000e+00", 3.14, f32)
_test_float_print("3.140000e+08", 3.14e8, f32)
_test_float_print("3.142857142857143", 22 / 7, f32)
_test_float_print("314285714.28571427", 22e8 / 7, f32)
_test_float_print("-3.142857142857143", -22 / 7, f32)


def test_float_attr():
printer = Printer()

def _test_float_attr(value: float, type: AnyFloat):
io_float = StringIO()
printer.stream = io_float
printer.print_float(value, type)

io_attr = StringIO()
printer.stream = io_attr
printer.print_float_attr(FloatAttr(value, type))

assert io_float.getvalue() == io_attr.getvalue()

for value in (
3,
3.14,
22 / 7,
float("nan"),
float("inf"),
float("-inf"),
):
_test_float_attr(value, f32)


def test_float_attr_specials():
printer = Printer()

Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def parse_with_type(
return FloatAttr(parser.parse_float(), type)

def print_without_type(self, printer: Printer):
return printer.print_float(self)
return printer.print_float_attr(self)


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
Expand Down Expand Up @@ -2061,7 +2061,7 @@ def _print_one_elem(val: Attribute, printer: Printer):
if isinstance(val, IntegerAttr):
val.print_without_type(printer)
elif isinstance(val, FloatAttr):
printer.print_float(cast(AnyFloatAttr, val))
printer.print_float_attr(cast(AnyFloatAttr, val))
else:
raise Exception(
"unexpected attribute type "
Expand Down
29 changes: 16 additions & 13 deletions xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from xdsl.dialects.builtin import (
AffineMapAttr,
AffineSetAttr,
AnyFloat,
AnyFloatAttr,
AnyUnrankedMemrefType,
AnyUnrankedTensorType,
Expand Down Expand Up @@ -463,26 +464,28 @@ def print_bytes_literal(self, bytestring: bytes):
self.print_string(chr(byte))
self.print_string('"')

def print_float(self, attribute: AnyFloatAttr):
value = attribute.value
if math.isnan(value.data) or math.isinf(value.data):
if isinstance(attribute.type, Float16Type):
self.print_string(f"{hex(convert_f16_to_u16(value.data))}")
elif isinstance(attribute.type, Float32Type):
self.print_string(f"{hex(convert_f32_to_u32(value.data))}")
elif isinstance(attribute.type, Float64Type):
self.print_string(f"{hex(convert_f64_to_u64(value.data))}")
def print_float_attr(self, attribute: AnyFloatAttr):
self.print_float(attribute.value.data, attribute.type)

def print_float(self, value: float, type: AnyFloat):
if math.isnan(value) or math.isinf(value):
if isinstance(type, Float16Type):
self.print_string(f"{hex(convert_f16_to_u16(value))}")
elif isinstance(type, Float32Type):
self.print_string(f"{hex(convert_f32_to_u32(value))}")
elif isinstance(type, Float64Type):
self.print_string(f"{hex(convert_f64_to_u64(value))}")
else:
raise NotImplementedError(
f"Cannot print '{value.data}' value for float type {str(attribute.type)}"
f"Cannot print '{value}' value for float type {str(type)}"
)
else:
# to mirror mlir-opt, attempt to print scientific notation iff the value parses losslessly
float_str = f"{value.data:.6e}"
if float(float_str) == value.data:
float_str = f"{value:.6e}"
if float(float_str) == value:
self.print_string(float_str)
else:
self.print_string(f"{repr(value.data)}")
self.print_string(f"{repr(value)}")

def print_attribute(self, attribute: Attribute) -> None:
if isinstance(attribute, UnitAttr):
Expand Down
Loading