From e092cf1766448aef7ed61199e5a6bf64da789b08 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 22 Jul 2024 13:01:00 -0700 Subject: [PATCH] add PeelFromMul --- lib/Transform/Arith/MulToAdd.pdll | 34 +++++++++++++++++++++++++++- lib/Transform/Arith/MulToAddPdll.cpp | 11 ++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/lib/Transform/Arith/MulToAdd.pdll b/lib/Transform/Arith/MulToAdd.pdll index 7cfcc8d..8bbddcb 100644 --- a/lib/Transform/Arith/MulToAdd.pdll +++ b/lib/Transform/Arith/MulToAdd.pdll @@ -6,8 +6,11 @@ Constraint IsPowerOfTwo(attr: Attr) [{ }]; // Currently, constraints that return values must be defined in C++ -Constraint Halve(attr: Attr) -> Attr; +Constraint Halve(atttr: Attr) -> Attr; +Constraint MinusOne(attr: Attr) -> Attr; +// Replace y = C*x with y = C/2*x + C/2*x, when C is a power of 2, otherwise do +// nothing. Pattern PowerOfTwoExpandRhs with benefit(2) { let root = op(op {value = const: Attr}, rhs: Value); IsPowerOfTwo(const); @@ -33,3 +36,32 @@ Pattern PowerOfTwoExpandLhs with benefit(2) { replace root with newAdd; }; } + +// Replace y = 9*x with y = 8*x + x +Pattern PeelFromMulRhs with benefit(1) { + let root = op(lhs: Value, op {value = const: Attr}); + + // We are guaranteed `value` is not a power of two, because the greedy + // rewrite engine ensures the PowerOfTwoExpand pattern is run first, since + // it has higher benefit. + let minusOne: Attr = MinusOne(const); + + rewrite root with { + let newConst = op {value = minusOne}; + let newMul = op(lhs, newConst); + let newAdd = op(newMul, lhs); + replace root with newAdd; + }; +} + +Pattern PeelFromMulLhs with benefit(1) { + let root = op(op {value = const: Attr}, rhs: Value); + let minusOne: Attr = MinusOne(const); + + rewrite root with { + let newConst = op {value = minusOne}; + let newMul = op(newConst, rhs); + let newAdd = op(newMul, rhs); + replace root with newAdd; + }; +} diff --git a/lib/Transform/Arith/MulToAddPdll.cpp b/lib/Transform/Arith/MulToAddPdll.cpp index f76c834..0566342 100644 --- a/lib/Transform/Arith/MulToAddPdll.cpp +++ b/lib/Transform/Arith/MulToAddPdll.cpp @@ -11,7 +11,6 @@ namespace tutorial { #define GEN_PASS_DEF_MULTOADDPDLL #include "lib/Transform/Arith/Passes.h.inc" - LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results, ArrayRef args) { Attribute attr = args[0].cast(); @@ -21,8 +20,18 @@ LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results, return success(); } +LogicalResult minusOneImpl(PatternRewriter &rewriter, PDLResultList &results, + ArrayRef args) { + Attribute attr = args[0].cast(); + IntegerAttr cAttr = cast(attr); + int64_t value = cAttr.getValue().getSExtValue(); + results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value - 1)); + return success(); +} + void registerNativeConstraints(RewritePatternSet &patterns) { patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl); + patterns.getPDLPatterns().registerConstraintFunction("MinusOne", minusOneImpl); } struct MulToAddPdll : impl::MulToAddPdllBase {