diff --git a/lib/Transform/Arith/BUILD b/lib/Transform/Arith/BUILD index efd56e3..a1e6e0a 100644 --- a/lib/Transform/Arith/BUILD +++ b/lib/Transform/Arith/BUILD @@ -47,6 +47,7 @@ cc_library( hdrs = ["Passes.h"], deps = [ ":MulToAdd", + ":MulToAddPdll", ":pass_inc_gen", ], ) @@ -67,3 +68,18 @@ gentbl_cc_library( "@llvm-project//mlir:ArithOpsTdFiles", ], ) + +cc_library( + name = "MulToAddPdll", + srcs = ["MulToAddPdll.cpp"], + hdrs = ["MulToAddPdll.h"], + deps = [ + ":pass_inc_gen", + ":MulToAddPdllIncGen", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) + diff --git a/lib/Transform/Arith/MulToAdd.pdll b/lib/Transform/Arith/MulToAdd.pdll index 57cb4f4..54c2e4f 100644 --- a/lib/Transform/Arith/MulToAdd.pdll +++ b/lib/Transform/Arith/MulToAdd.pdll @@ -1,13 +1,14 @@ #include "mlir/Dialect/Arith/IR/ArithOps.td" Constraint IsPowerOfTwo(attr: Attr) [{ - int64_t value = cast<::mlir::IntegerAttr>(attr).getValue(); + int64_t value = cast<::mlir::IntegerAttr>(attr).getValue().getSExtValue(); return success((value & (value - 1)) == 0); }]; Constraint Halve(attr: Attr) -> Attr [{ - int64_t value = cast<::mlir::IntegerAttr>(attr).getValue(); - return rewriter.getIntegerAttr(attr.getType(), value / 2); + IntegerAttr cAttr = ::llvm::cast<::mlir::IntegerAttr>(attr); + int64_t value = cAttr.getValue().getSExtValue(); + return rewriter.getIntegerAttr(cAttr.getType(), value / 2); }]; Pattern PowerOfTwoExpand with benefit(2) { diff --git a/lib/Transform/Arith/MulToAddPdll.cpp b/lib/Transform/Arith/MulToAddPdll.cpp new file mode 100644 index 0000000..b88958f --- /dev/null +++ b/lib/Transform/Arith/MulToAddPdll.cpp @@ -0,0 +1,24 @@ +#include "lib/Transform/Arith/MulToAddPdll.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/include/mlir/Pass/Pass.h" + +namespace mlir { +namespace tutorial { + +#define GEN_PASS_DEF_MULTOADDPDLL +#include "lib/Transform/Arith/Passes.h.inc" + +struct MulToAddPdll : impl::MulToAddPdllBase { + using MulToAddPdllBase::MulToAddPdllBase; + + void runOnOperation() { + mlir::RewritePatternSet patterns(&getContext()); + populateGeneratedPDLLPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace tutorial +} // namespace mlir diff --git a/lib/Transform/Arith/MulToAddPdll.h b/lib/Transform/Arith/MulToAddPdll.h new file mode 100644 index 0000000..9ebdc3c --- /dev/null +++ b/lib/Transform/Arith/MulToAddPdll.h @@ -0,0 +1,19 @@ +#ifndef LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ +#define LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ + +#include "mlir/Pass/Pass.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +namespace mlir { +namespace tutorial { + +#define GEN_PASS_DECL_MULTOADDPDLL +#include "lib/Transform/Arith/Passes.h.inc" + +#include "lib/Transform/Arith/MulToAddPdll.h.inc" + +} // namespace tutorial +} // namespace mlir + +#endif // LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_ diff --git a/lib/Transform/Arith/Passes.td b/lib/Transform/Arith/Passes.td index 0c60eb4..d5dd6c4 100644 --- a/lib/Transform/Arith/Passes.td +++ b/lib/Transform/Arith/Passes.td @@ -10,4 +10,11 @@ def MulToAdd : Pass<"mul-to-add"> { }]; } +def MulToAddPdll : Pass<"mul-to-add-pdll"> { + let summary = "Convert multiplications to repeated additions using pdll"; + let description = [{ + Convert multiplications to repeated additions (using pdll). + }]; +} + #endif // LIB_TRANSFORM_ARITH_PASSES_TD_