Skip to content

Commit

Permalink
add noise verification analysis check
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Nov 11, 2023
1 parent 37ae94f commit 14b0912
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lib/Transform/Noisy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ cc_library(
":pass_inc_gen",
"//lib/Analysis/ReduceNoiseAnalysis",
"//lib/Dialect/Noisy",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
],
Expand Down
48 changes: 47 additions & 1 deletion lib/Transform/Noisy/ReduceNoiseOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include "lib/Analysis/ReduceNoiseAnalysis/ReduceNoiseAnalysis.h"
#include "lib/Dialect/Noisy/NoisyOps.h"
#include "lib/Dialect/Noisy/NoisyTypes.h"
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/include/mlir/Analysis/DataFlowFramework.h"
#include "mlir/include/mlir/IR/Visitors.h"
#include "mlir/include/mlir/Pass/Pass.h"

namespace mlir {
Expand All @@ -21,14 +25,56 @@ struct ReduceNoiseOptimizer
ReduceNoiseAnalysis analysis(getOperation());
OpBuilder b(&getContext());

getOperation()->walk([&](Operation *op) {
Operation *module = getOperation();

module->walk([&](Operation *op) {
if (!analysis.shouldInsertReduceNoise(op))
return;

b.setInsertionPointAfter(op);
auto reduceOp = b.create<ReduceNoiseOp>(op->getLoc(), op->getResult(0));
op->getResult(0).replaceAllUsesExcept(reduceOp.getResult(), {reduceOp});
});

// Afterwards, use the int range analysis to confirm the noise is always
// below the maximum.
DataFlowSolver solver;
// The IntegerRangeAnalysis depends on DeadCodeAnalysis, but this
// dependence is not automatic and fails silently.
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(module)))
signalPassFailure();

auto result = module->walk([&](Operation *op) {
if (!llvm::isa<noisy::AddOp, noisy::SubOp, noisy::MulOp,
noisy::ReduceNoiseOp>(*op)) {
return WalkResult::advance();
}
const dataflow::IntegerValueRangeLattice *opRange =
solver.lookupState<dataflow::IntegerValueRangeLattice>(
op->getResult(0));
if (!opRange || opRange->getValue().isUninitialized()) {
op->emitOpError()
<< "Found op without a set integer range; did the analysis fail?";
return WalkResult::interrupt();
}

ConstantIntRanges range = opRange->getValue().getValue();
if (range.umax().getZExtValue() > MAX_NOISE) {
op->emitOpError() << "Found op after which the noise exceeds the "
"allowable maximum of "
<< MAX_NOISE
<< "; it was: " << range.umax().getZExtValue()
<< "\n";
return WalkResult::interrupt();
}

return WalkResult::advance();
});

if (result.wasInterrupted())
signalPassFailure();
}
};

Expand Down

0 comments on commit 14b0912

Please sign in to comment.