Skip to content

Commit

Permalink
start trying to build the pdll-generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Jul 18, 2024
1 parent d6d73c1 commit 8615b61
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 3 deletions.
16 changes: 16 additions & 0 deletions lib/Transform/Arith/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ cc_library(
hdrs = ["Passes.h"],
deps = [
":MulToAdd",
":MulToAddPdll",
":pass_inc_gen",
],
)
Expand All @@ -67,3 +68,18 @@ gentbl_cc_library(
"@llvm-project//mlir:ArithOpsTdFiles",
],
)

cc_library(
name = "MulToAddPdll",
srcs = ["MulToAddPdll.cpp"],
hdrs = ["MulToAddPdll.h"],
deps = [
":pass_inc_gen",
":MulToAddPdllIncGen",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
)

7 changes: 4 additions & 3 deletions lib/Transform/Arith/MulToAdd.pdll
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include "mlir/Dialect/Arith/IR/ArithOps.td"

Constraint IsPowerOfTwo(attr: Attr) [{
int64_t value = cast<::mlir::IntegerAttr>(attr).getValue();
int64_t value = cast<::mlir::IntegerAttr>(attr).getValue().getSExtValue();
return success((value & (value - 1)) == 0);
}];

Constraint Halve(attr: Attr) -> Attr [{
int64_t value = cast<::mlir::IntegerAttr>(attr).getValue();
return rewriter.getIntegerAttr(attr.getType(), value / 2);
IntegerAttr cAttr = ::llvm::cast<::mlir::IntegerAttr>(attr);
int64_t value = cAttr.getValue().getSExtValue();
return rewriter.getIntegerAttr(cAttr.getType(), value / 2);
}];

Pattern PowerOfTwoExpand with benefit(2) {
Expand Down
24 changes: 24 additions & 0 deletions lib/Transform/Arith/MulToAddPdll.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "lib/Transform/Arith/MulToAddPdll.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/include/mlir/Pass/Pass.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_DEF_MULTOADDPDLL
#include "lib/Transform/Arith/Passes.h.inc"

struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> {
using MulToAddPdllBase::MulToAddPdllBase;

void runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
populateGeneratedPDLLPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace tutorial
} // namespace mlir
19 changes: 19 additions & 0 deletions lib/Transform/Arith/MulToAddPdll.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_
#define LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_

#include "mlir/Pass/Pass.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

namespace mlir {
namespace tutorial {

#define GEN_PASS_DECL_MULTOADDPDLL
#include "lib/Transform/Arith/Passes.h.inc"

#include "lib/Transform/Arith/MulToAddPdll.h.inc"

} // namespace tutorial
} // namespace mlir

#endif // LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_
7 changes: 7 additions & 0 deletions lib/Transform/Arith/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,11 @@ def MulToAdd : Pass<"mul-to-add"> {
}];
}

def MulToAddPdll : Pass<"mul-to-add-pdll"> {
let summary = "Convert multiplications to repeated additions using pdll";
let description = [{
Convert multiplications to repeated additions (using pdll).
}];
}

#endif // LIB_TRANSFORM_ARITH_PASSES_TD_

0 comments on commit 8615b61

Please sign in to comment.