Skip to content

Commit

Permalink
Perform initialization in pass initialize method (openxla#1966)
Browse files Browse the repository at this point in the history
This is unlikely to give us performance gains since most of our passes
run on modules anyway (so the initialization probably already occurs
only once), but it is cleaner to separate the initialization of a pass
from the actual running of the pass.

The code in `initialize` will run when the pass runs regardless of
whether there is at least one instance of the target operation. However,
modules are pretty much always present so this is unlikely to change
anything.
  • Loading branch information
mlevesquedion authored Feb 1, 2024
1 parent 68da793 commit 550d946
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2573,27 +2573,33 @@ struct StablehloLegalizeToLinalgPass
: impl::StablehloLegalizeToLinalgPassBase<StablehloLegalizeToLinalgPass> {
using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase;

void runOnOperation() override {
auto *context = &getContext();
auto target = ConversionTarget{*context};
auto patterns = RewritePatternSet{context};
auto typeConverter = std::make_unique<LinalgTypeConverter>();

target.addLegalDialect<
LogicalResult initialize(MLIRContext *context) override {
target = std::make_shared<ConversionTarget>(*context);
target->addLegalDialect<
bufferization::BufferizationDialect, arith::ArithDialect,
complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect,
tensor::TensorDialect, sparse_tensor::SparseTensorDialect,
scf::SCFDialect, shape::ShapeDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
target->addLegalOp<UnrealizedConversionCastOp>();

populateConversionPatterns(context, *typeConverter, &patterns,
RewritePatternSet patterns_(context);
populateConversionPatterns(context, converter, &patterns_,
enablePrimitiveOps);
patterns = std::move(patterns_);

return success();
}

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
void runOnOperation() override {
if (failed(applyPartialConversion(getOperation(), *target, patterns))) {
return signalPassFailure();
}
}

private:
std::shared_ptr<ConversionTarget> target;
FrozenRewritePatternSet patterns;
LinalgTypeConverter converter;
};
} // namespace
} // namespace mlir::stablehlo
14 changes: 11 additions & 3 deletions stablehlo/tests/TestUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,22 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
#include "stablehlo/tests/TestUtils.h.inc"

struct HloTestInferPass : public impl::HloTestInferPassBase<HloTestInferPass> {
LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet patterns_(context);
patterns_.add<InferReturnTypesPattern>(context);
patterns_.add<ReifyReturnTypeShapesPattern>(context);
patterns = std::move(patterns_);
return success();
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<InferReturnTypesPattern>(&getContext());
patterns.add<ReifyReturnTypeShapesPattern>(&getContext());
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}

private:
FrozenRewritePatternSet patterns;
};

#define GEN_PASS_REGISTRATION
Expand Down
21 changes: 14 additions & 7 deletions stablehlo/transforms/StablehloCanonicalizeDynamism.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,24 +298,31 @@ struct StablehloCanonicalizeDynamismPass
using StablehloCanonicalizeDynamismPassBase::
StablehloCanonicalizeDynamismPassBase;

void runOnOperation() override {
GreedyRewriteConfig config;
LogicalResult initialize(MLIRContext* context) override {
config.useTopDownTraversal = true;
config.enableRegionSimplification = true;
config.maxIterations = 2;
config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
config.strictMode = GreedyRewriteStrictness::AnyOp;

RewritePatternSet patterns(&getContext());
populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext());
RewritePatternSet patterns_(context);
populateStablehloCanonicalizeDynamismPatterns(&patterns_, context);
patterns = std::move(patterns_);

return success();
}

void runOnOperation() override {
auto func = getOperation();
if (failed(
applyPatternsAndFoldGreedily(func, std::move(patterns), config))) {
if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
func.emitError("Failed to converge StablehloCanonicalizeDynamism in ")
<< config.maxIterations << " iterations";
return signalPassFailure();
}
}

private:
FrozenRewritePatternSet patterns;
GreedyRewriteConfig config;
};

} // namespace
Expand Down
29 changes: 18 additions & 11 deletions stablehlo/transforms/StablehloLegalizeToVhlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,24 +862,31 @@ void populateStablehloToVhloPatterns(RewritePatternSet* patterns,
struct StablehloLegalizeToVhloPass
: public impl::StablehloLegalizeToVhloPassBase<
StablehloLegalizeToVhloPass> {
void runOnOperation() override {
ConversionTarget target(getContext());
target.addIllegalDialect<stablehlo::StablehloDialect>();
target.addIllegalDialect<func::FuncDialect>();
target.addLegalDialect<vhlo::VhloDialect>();
LogicalResult initialize(MLIRContext* context) override {
target = std::make_shared<ConversionTarget>(*context);
target->addIllegalDialect<stablehlo::StablehloDialect>();
target->addIllegalDialect<func::FuncDialect>();
target->addLegalDialect<vhlo::VhloDialect>();

RewritePatternSet patterns_(context);
stablehlo::populateStablehloToVhloPatterns(&patterns_, &converter, context);
patterns = std::move(patterns_);

StablehloToVhloTypeConverter converter;
RewritePatternSet patterns(&getContext());
stablehlo::populateStablehloToVhloPatterns(&patterns, &converter,
&getContext());
return success();
}

void runOnOperation() override {
// StableHLO should always be convertible to VHLO.
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
if (failed(applyPartialConversion(getOperation(), *target, patterns))) {
LLVM_DEBUG(llvm::dbgs() << "Failed partial conversion\n");
return signalPassFailure();
}
}

private:
StablehloToVhloTypeConverter converter;
FrozenRewritePatternSet patterns;
std::shared_ptr<ConversionTarget> target;
};

void populateStablehloToVhloPatterns(RewritePatternSet* patterns,
Expand Down
26 changes: 16 additions & 10 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1064,33 +1064,39 @@ struct StablehloRefineShapesPass
: public impl::StablehloRefineShapesPassBase<StablehloRefineShapesPass> {
using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase;

void runOnOperation() override {
auto func = getStablehloRefineShapesTarget(getOperation());
if (!func) return signalPassFailure();

LogicalResult initialize(MLIRContext* context) override {
// The algorithm behind this pass consists of a single traversal of the
// function. This is sufficient because we only support one function per
// program at the moment.
// TODO(#1048): Find out why .maxIterations = 1 no longer works.
// There have been recent refactors to applyPatternsAndFoldGreedily
// upstream, and that might be the reason.
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = true;
config.maxIterations = 2;
config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
config.strictMode = GreedyRewriteStrictness::AnyOp;

RewritePatternSet patterns(&getContext());
RewritePatternSet patterns_(context);
populateStablehloRefineShapesPatterns(&patterns_, context);
patterns = std::move(patterns_);

return success();
}

void runOnOperation() override {
auto func = getStablehloRefineShapesTarget(getOperation());
if (!func) return signalPassFailure();

populateStablehloRefineShapesPatterns(&patterns, &getContext());
if (failed(
applyPatternsAndFoldGreedily(func, std::move(patterns), config))) {
if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) {
func.emitError("Failed to converge StablehloRefineShapes in ")
<< config.maxIterations << " iterations";
return signalPassFailure();
}
}

private:
FrozenRewritePatternSet patterns;
GreedyRewriteConfig config;
};

} // namespace
Expand Down
29 changes: 18 additions & 11 deletions stablehlo/transforms/VhloLegalizeToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,25 +872,32 @@ void populateVhloToStablehloPatterns(RewritePatternSet* patterns,
struct VhloLegalizeToStablehloPass
: public impl::VhloLegalizeToStablehloPassBase<
VhloLegalizeToStablehloPass> {
void runOnOperation() override {
ConversionTarget target(getContext());
target.addIllegalDialect<vhlo::VhloDialect>();
target.addLegalDialect<stablehlo::StablehloDialect>();
target.addLegalDialect<func::FuncDialect>();
LogicalResult initialize(MLIRContext* context) override {
target = std::make_shared<ConversionTarget>(*context);
target->addIllegalDialect<vhlo::VhloDialect>();
target->addLegalDialect<stablehlo::StablehloDialect>();
target->addLegalDialect<func::FuncDialect>();

RewritePatternSet patterns_(context);
stablehlo::populateVhloToStablehloPatterns(&patterns_, &converter, context);
patterns = std::move(patterns_);

VhloToStablehloTypeConverter converter;
RewritePatternSet patterns(&getContext());
stablehlo::populateVhloToStablehloPatterns(&patterns, &converter,
&getContext());
return success();
}

void runOnOperation() override {
// Upgraded VHLO should always be convertible to StableHLO.
// Arbitrary VHLO might not be convertible if it uses deprecated features
// which are no longer available in StableHLO.
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
if (failed(applyPartialConversion(getOperation(), *target, patterns))) {
return signalPassFailure();
}
}

private:
VhloToStablehloTypeConverter converter;
FrozenRewritePatternSet patterns;
std::shared_ptr<ConversionTarget> target;
};

void populateVhloToStablehloPatterns(RewritePatternSet* patterns,
Expand Down
20 changes: 13 additions & 7 deletions stablehlo/transforms/VhloToVersion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ struct VhloToVersionPass : public VhloToVersionPassBase<VhloToVersionPass> {
VhloToVersionPass(const VhloToVersionPassOptions& opts)
: VhloToVersionPassBase<VhloToVersionPass>(opts) {}

LogicalResult initialize(MLIRContext* context) override {
RewritePatternSet patterns_(context);
stablehlo::populateVhloToVersionPatterns(&patterns_, &converter, context);
patterns = std::move(patterns_);

return success();
}

void runOnOperation() override {
ConversionTarget target(getContext());

Expand Down Expand Up @@ -248,16 +256,14 @@ struct VhloToVersionPass : public VhloToVersionPassBase<VhloToVersionPass> {
return isLegalOperation(op, targetVersion);
});

vhlo::VhloToVersionConverter converter;
RewritePatternSet patterns(&getContext());
stablehlo::populateVhloToVersionPatterns(&patterns, &converter,
&getContext());

// Conversions within VHLO may fail if new features or ops are used.
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
if (failed(applyPartialConversion(getOperation(), target, patterns)))
return signalPassFailure();
}

private:
vhlo::VhloToVersionConverter converter;
FrozenRewritePatternSet patterns;
};

////////////////////////////////////////////
Expand Down

0 comments on commit 550d946

Please sign in to comment.