From 44617af8c767110605664595b9ba7dbe521ec87d Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 18 Jul 2024 15:41:29 -0700 Subject: [PATCH] try defining Halve as a native constraint --- lib/Transform/Arith/MulToAdd.pdll | 6 +----- lib/Transform/Arith/MulToAddPdll.cpp | 12 ++++++++++++ lib/Transform/Arith/MulToAddPdll.h | 1 + 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/lib/Transform/Arith/MulToAdd.pdll b/lib/Transform/Arith/MulToAdd.pdll index 54c2e4f..274693a 100644 --- a/lib/Transform/Arith/MulToAdd.pdll +++ b/lib/Transform/Arith/MulToAdd.pdll @@ -5,11 +5,7 @@ Constraint IsPowerOfTwo(attr: Attr) [{ return success((value & (value - 1)) == 0); }]; -Constraint Halve(attr: Attr) -> Attr [{ - IntegerAttr cAttr = ::llvm::cast<::mlir::IntegerAttr>(attr); - int64_t value = cAttr.getValue().getSExtValue(); - return rewriter.getIntegerAttr(cAttr.getType(), value / 2); -}]; +Constraint Halve(attr: Attr) -> Attr; Pattern PowerOfTwoExpand with benefit(2) { let root = op(op {value = const: Attr}, rhs: Value); diff --git a/lib/Transform/Arith/MulToAddPdll.cpp b/lib/Transform/Arith/MulToAddPdll.cpp index b88958f..36cc48a 100644 --- a/lib/Transform/Arith/MulToAddPdll.cpp +++ b/lib/Transform/Arith/MulToAddPdll.cpp @@ -10,12 +10,24 @@ namespace tutorial { #define GEN_PASS_DEF_MULTOADDPDLL #include "lib/Transform/Arith/Passes.h.inc" +Attribute halveImpl(PatternRewriter &rewriter, Attribute attr) { + IntegerAttr cAttr = ::llvm::cast<::mlir::IntegerAttr>(attr); + int64_t value = cAttr.getValue().getSExtValue(); + return rewriter.getIntegerAttr(cAttr.getType(), value / 2); +} + +void registerNativeConstraints(RewritePatternSet &patterns) { + patterns.getPDLPatterns().registerConstraintFunction( + "Halve", halveImpl); +} + struct MulToAddPdll : impl::MulToAddPdllBase { using MulToAddPdllBase::MulToAddPdllBase; void runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); populateGeneratedPDLLPatterns(patterns); + registerNativeConstraints(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; diff --git a/lib/Transform/Arith/MulToAddPdll.h b/lib/Transform/Arith/MulToAddPdll.h index 9ebdc3c..36ee995 100644 --- a/lib/Transform/Arith/MulToAddPdll.h +++ b/lib/Transform/Arith/MulToAddPdll.h @@ -4,6 +4,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Parser/Parser.h" namespace mlir { namespace tutorial {