Skip to content

Commit

Permalink
dialects: fix support for fastmath in some arith ops. (#1751)
Browse files Browse the repository at this point in the history
fastmath was ignored by custom syntax and not exposed through `__init__`
on common arith ops.
  • Loading branch information
PapyChacal authored Nov 6, 2023
1 parent bdcc260 commit 532cec8
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
12 changes: 10 additions & 2 deletions tests/dialects/test_arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
ExtFOp,
ExtSIOp,
ExtUIOp,
FastMathFlagsAttr,
FloatingPointLikeBinaryOp,
FloorDivSI,
FPToSIOp,
IndexCastOp,
Expand Down Expand Up @@ -111,10 +113,16 @@ class Test_float_arith_construction:
"func",
[Addf, Subf, Mulf, Divf, Maxf, Minf],
)
def test_arith_ops(self, func: type[BinaryOperation[_BinOpArgT]]):
op = func(self.a, self.b)
@pytest.mark.parametrize(
"flags", [FastMathFlagsAttr("none"), FastMathFlagsAttr("fast"), None]
)
def test_arith_ops(
self, func: type[FloatingPointLikeBinaryOp], flags: FastMathFlagsAttr | None
):
op = func(self.a, self.b, flags)
assert op.operands[0].owner is self.a
assert op.operands[1].owner is self.b
assert op.fastmath == flags


def test_select_op():
Expand Down
6 changes: 6 additions & 0 deletions tests/filecheck/dialects/arith/arith_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@
// CHECK-NEXT: %divf = arith.divf %lhsf32, %rhsf32 : f32
// CHECK-NEXT: %divf_vector = arith.divf %lhsvec, %rhsvec : vector<4xf32>

%faddf = arith.addf %lhsf32, %rhsf32 fastmath<fast> : f32
%faddf_vector = arith.addf %lhsvec, %rhsvec fastmath<fast> : vector<4xf32>

// CHECK-NEXT: %faddf = arith.addf %lhsf32, %rhsf32 fastmath<fast> : f32
// CHECK-NEXT: %faddf_vector = arith.addf %lhsvec, %rhsvec fastmath<fast> : vector<4xf32>

%negf = "arith.negf"(%lhsf32) : (f32) -> f32

// CHECK-NEXT: %negf = arith.negf %lhsf32 : f32
Expand Down
34 changes: 34 additions & 0 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,40 @@ def __hash__(self) -> int:
class BinaryOperationWithFastMath(Generic[_T], BinaryOperation[_T]):
fastmath = opt_prop_def(FastMathFlagsAttr)

def __init__(
self,
operand1: Operation | SSAValue,
operand2: Operation | SSAValue,
flags: FastMathFlagsAttr | None = None,
result_type: Attribute | None = None,
):
super().__init__(operand1, operand2, result_type)
self.fastmath = flags

@classmethod
def parse(cls, parser: Parser):
lhs = parser.parse_unresolved_operand()
parser.parse_punctuation(",")
rhs = parser.parse_unresolved_operand()
flags = FastMathFlagsAttr("none")
if parser.parse_optional_keyword("fastmath") is not None:
flags = FastMathFlagsAttr(FastMathFlagsAttr.parse_parameter(parser))
parser.parse_punctuation(":")
result_type = parser.parse_type()
(lhs, rhs) = parser.resolve_operands([lhs, rhs], 2 * [result_type], parser.pos)
return cls(lhs, rhs, flags, result_type)

def print(self, printer: Printer):
printer.print(" ")
printer.print_ssa_value(self.lhs)
printer.print(", ")
printer.print_ssa_value(self.rhs)
if self.fastmath is not None and self.fastmath != FastMathFlagsAttr("none"):
printer.print(" fastmath")
self.fastmath.print_parameter(printer)
printer.print(" : ")
printer.print_attribute(self.result.type)


FloatingPointLikeBinaryOp = BinaryOperationWithFastMath[
Annotated[Attribute, floatingPointLike]
Expand Down

0 comments on commit 532cec8

Please sign in to comment.