diff --git a/tests/test_printer.py b/tests/test_printer.py index dc00cda82c..843e2d72e7 100644 --- a/tests/test_printer.py +++ b/tests/test_printer.py @@ -10,6 +10,7 @@ from xdsl.dialects import test from xdsl.dialects.arith import AddiOp, Arith, ConstantOp from xdsl.dialects.builtin import ( + AnyFloat, AnyFloatAttr, Builtin, FloatAttr, @@ -19,6 +20,7 @@ ModuleOp, SymbolRefAttr, UnitAttr, + f32, i32, ) from xdsl.dialects.func import Func @@ -763,6 +765,49 @@ def test_densearray_attr(): assert_print_op(parsed, prog, None) +def test_float(): + printer = Printer() + + def _test_float_print(expected: str, value: float, type: AnyFloat): + io = StringIO() + printer.stream = io + printer.print_float(value, type) + assert io.getvalue() == expected + + _test_float_print("3.000000e+00", 3, f32) + _test_float_print("-3.000000e+00", -3, f32) + _test_float_print("3.140000e+00", 3.14, f32) + _test_float_print("3.140000e+08", 3.14e8, f32) + _test_float_print("3.142857142857143", 22 / 7, f32) + _test_float_print("314285714.28571427", 22e8 / 7, f32) + _test_float_print("-3.142857142857143", -22 / 7, f32) + + +def test_float_attr(): + printer = Printer() + + def _test_float_attr(value: float, type: AnyFloat): + io_float = StringIO() + printer.stream = io_float + printer.print_float(value, type) + + io_attr = StringIO() + printer.stream = io_attr + printer.print_float_attr(FloatAttr(value, type)) + + assert io_float.getvalue() == io_attr.getvalue() + + for value in ( + 3, + 3.14, + 22 / 7, + float("nan"), + float("inf"), + float("-inf"), + ): + _test_float_attr(value, f32) + + def test_float_attr_specials(): printer = Printer() diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 15fe476ec0..15fc5915c4 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -836,7 +836,7 @@ def parse_with_type( return FloatAttr(parser.parse_float(), type) def print_without_type(self, printer: Printer): - return printer.print_float(self) + return printer.print_float_attr(self) AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat] @@ -2061,7 +2061,7 @@ def _print_one_elem(val: Attribute, printer: Printer): if isinstance(val, IntegerAttr): val.print_without_type(printer) elif isinstance(val, FloatAttr): - printer.print_float(cast(AnyFloatAttr, val)) + printer.print_float_attr(cast(AnyFloatAttr, val)) else: raise Exception( "unexpected attribute type " diff --git a/xdsl/printer.py b/xdsl/printer.py index 1ca2298e04..e6bb93fc7f 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -13,6 +13,7 @@ from xdsl.dialects.builtin import ( AffineMapAttr, AffineSetAttr, + AnyFloat, AnyFloatAttr, AnyUnrankedMemrefType, AnyUnrankedTensorType, @@ -463,26 +464,28 @@ def print_bytes_literal(self, bytestring: bytes): self.print_string(chr(byte)) self.print_string('"') - def print_float(self, attribute: AnyFloatAttr): - value = attribute.value - if math.isnan(value.data) or math.isinf(value.data): - if isinstance(attribute.type, Float16Type): - self.print_string(f"{hex(convert_f16_to_u16(value.data))}") - elif isinstance(attribute.type, Float32Type): - self.print_string(f"{hex(convert_f32_to_u32(value.data))}") - elif isinstance(attribute.type, Float64Type): - self.print_string(f"{hex(convert_f64_to_u64(value.data))}") + def print_float_attr(self, attribute: AnyFloatAttr): + self.print_float(attribute.value.data, attribute.type) + + def print_float(self, value: float, type: AnyFloat): + if math.isnan(value) or math.isinf(value): + if isinstance(type, Float16Type): + self.print_string(f"{hex(convert_f16_to_u16(value))}") + elif isinstance(type, Float32Type): + self.print_string(f"{hex(convert_f32_to_u32(value))}") + elif isinstance(type, Float64Type): + self.print_string(f"{hex(convert_f64_to_u64(value))}") else: raise NotImplementedError( - f"Cannot print '{value.data}' value for float type {str(attribute.type)}" + f"Cannot print '{value}' value for float type {str(type)}" ) else: # to mirror mlir-opt, attempt to print scientific notation iff the value parses losslessly - float_str = f"{value.data:.6e}" - if float(float_str) == value.data: + float_str = f"{value:.6e}" + if float(float_str) == value: self.print_string(float_str) else: - self.print_string(f"{repr(value.data)}") + self.print_string(f"{repr(value)}") def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, UnitAttr):