Skip to content

Commit

Permalink
add PeelFromMul
Browse files Browse the repository at this point in the history
j2kun committed Jul 22, 2024
1 parent 83053ed commit 139a63e
Showing 2 changed files with 46 additions and 1 deletion.
36 changes: 36 additions & 0 deletions lib/Transform/Arith/MulToAdd.pdll
Original file line number Diff line number Diff line change
@@ -7,7 +7,10 @@ Constraint IsPowerOfTwo(attr: Attr) [{

// Currently, constraints that return values must be defined in C++
Constraint Halve(attr: Attr) -> Attr;
Constraint MinusOne(attr: Attr) -> Attr;

// Replace y = C*x with y = C/2*x + C/2*x, when C is a power of 2, otherwise do
// nothing.
Pattern PowerOfTwoExpandRhs with benefit(2) {
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value);
IsPowerOfTwo(const);
@@ -33,3 +36,36 @@ Pattern PowerOfTwoExpandLhs with benefit(2) {
replace root with newAdd;
};
}

// Replace y = 9*x with y = 8*x + x
Pattern PeelFromMulRhs with benefit(1) {
let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr});

// We are guaranteed `value` is not a power of two, because the greedy
// rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
// it has higher benefit.
let minusOne: Attr = MinusOne(const);

rewrite root with {
let newConst = op<arith.constant> {value = minusOne};
let newMul = op<arith.muli>(lhs, newConst);
let newAdd = op<arith.addi>(newMul, lhs);
replace root with newAdd;
};
}

Pattern PeelFromMulLhs with benefit(1) {
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value);

// We are guaranteed `value` is not a power of two, because the greedy
// rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
// it has higher benefit.
let minusOne: Attr = MinusOne(const);

rewrite root with {
let newConst = op<arith.constant> {value = minusOne};
let newMul = op<arith.muli>(newConst, rhs);
let newAdd = op<arith.addi>(newMul, rhs);
replace root with newAdd;
};
}
11 changes: 10 additions & 1 deletion lib/Transform/Arith/MulToAddPdll.cpp
Original file line number Diff line number Diff line change
@@ -11,7 +11,6 @@ namespace tutorial {
#define GEN_PASS_DEF_MULTOADDPDLL
#include "lib/Transform/Arith/Passes.h.inc"


LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
Attribute attr = args[0].cast<Attribute>();
@@ -21,8 +20,18 @@ LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results,
return success();
}

LogicalResult minusOneImpl(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
Attribute attr = args[0].cast<Attribute>();
IntegerAttr cAttr = cast<IntegerAttr>(attr);
int64_t value = cAttr.getValue().getSExtValue();
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value - 1));
return success();
}

void registerNativeConstraints(RewritePatternSet &patterns) {
patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl);
patterns.getPDLPatterns().registerConstraintFunction("MinusOne", minusOneImpl);
}

struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> {

0 comments on commit 139a63e

Please sign in to comment.