-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Region] Bufferize with one-shot-bufferize (#973)
* Region dialect: bufferize via one-shot-bufferize * Region dialect: remove region-bufferize pass
- Loading branch information
Showing
11 changed files
with
203 additions
and
181 deletions.
There are no files selected for viewing
21 changes: 21 additions & 0 deletions
21
include/imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H | ||
#define IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H | ||
|
||
#include <mlir/IR/MLIRContext.h> | ||
|
||
namespace imex { | ||
namespace region { | ||
void registerBufferizableOpInterfaceExternalModels( | ||
::mlir::DialectRegistry ®istry); | ||
} // namespace region | ||
} // namespace imex | ||
|
||
#endif // IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +0,0 @@ | ||
set(LLVM_TARGET_DEFINITIONS Passes.td) | ||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Region) | ||
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Region) | ||
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Region) | ||
add_public_tablegen_target(IMEXRegionPassIncGen) | ||
|
||
add_mlir_doc(Passes RegionPasses ./ -gen-pass-doc) | ||
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
164 changes: 164 additions & 0 deletions
164
lib/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h" | ||
#include "imex/Dialect/Region/IR/RegionOps.h" | ||
|
||
#include <mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h> | ||
|
||
namespace imex { | ||
namespace region { | ||
namespace { | ||
|
||
/// Convert values to buffers. If a value is a tensor, get a buffer for it. | ||
::mlir::LogicalResult | ||
convertToBuffers(::mlir::ValueRange values, | ||
::mlir::SmallVector<::mlir::Value> &buffers, | ||
::mlir::RewriterBase &rewriter, | ||
const ::mlir::bufferization::BufferizationOptions &options) { | ||
buffers.reserve(values.size()); | ||
for (auto val : values) { | ||
if (::mlir::isa<::mlir::TensorType>(val.getType())) { | ||
::mlir::FailureOr<::mlir::Value> maybeBuffer = | ||
::mlir::bufferization::getBuffer(rewriter, val, options); | ||
if (failed(maybeBuffer)) { | ||
return ::mlir::failure(); | ||
} | ||
buffers.push_back(*maybeBuffer); | ||
} else { | ||
buffers.push_back(val); | ||
} | ||
} | ||
return ::mlir::success(); | ||
} | ||
|
||
/// Bufferization of region.env_region op. Replaced with a new | ||
/// op that takes and returns memrefs. | ||
struct EnvironmentRegionOpInterface | ||
: public ::mlir::bufferization::BufferizableOpInterface::ExternalModel< | ||
EnvironmentRegionOpInterface, region::EnvironmentRegionOp> { | ||
bool bufferizesToMemoryRead( | ||
::mlir::Operation *op, ::mlir::OpOperand &opOperand, | ||
const ::mlir::bufferization::AnalysisState &state) const { | ||
assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) && | ||
"only tensor types expected"); | ||
// Assume all operands are read. | ||
return true; | ||
} | ||
|
||
bool bufferizesToMemoryWrite( | ||
::mlir::Operation *op, ::mlir::OpOperand &opOperand, | ||
const ::mlir::bufferization::AnalysisState &state) const { | ||
assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) && | ||
"only tensor types expected"); | ||
// Assume all operands are written to. | ||
return true; | ||
} | ||
|
||
::mlir::bufferization::AliasingValueList | ||
getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand, | ||
const ::mlir::bufferization::AnalysisState &state) const { | ||
// Assume no aliasing. | ||
return {}; | ||
} | ||
|
||
::mlir::LogicalResult | ||
bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter, | ||
const ::mlir::bufferization::BufferizationOptions &options) const { | ||
auto envOp = ::mlir::cast<region::EnvironmentRegionOp>(op); | ||
// Convert op arguments to memrefs. | ||
::mlir::SmallVector<::mlir::Value> newArguments; | ||
if (failed(convertToBuffers(envOp.getArgs(), newArguments, rewriter, | ||
options))) { | ||
return ::mlir::failure(); | ||
} | ||
// Infer result memref types by converting yield op operands to memrefs | ||
::mlir::SmallVector<::mlir::Value> newResults; | ||
if (failed(convertToBuffers(envOp.getBody()->getTerminator()->getOperands(), | ||
newResults, rewriter, options))) { | ||
return ::mlir::failure(); | ||
} | ||
::mlir::TypeRange resTypes(newResults); | ||
// Create new op via generic constructor, op will have an empty region. | ||
rewriter.setInsertionPoint(op); | ||
::mlir::OperationState state(op->getLoc(), op->getName(), newArguments, | ||
resTypes, op->getAttrs()); | ||
state.addRegion(); | ||
::mlir::Operation *newOp = ::mlir::Operation::create(state); | ||
// Move block from old op into the new op. | ||
newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(), | ||
op->getRegion(0).getBlocks()); | ||
rewriter.insert(newOp); | ||
::mlir::bufferization::replaceOpWithBufferizedValues(rewriter, op, | ||
newOp->getResults()); | ||
|
||
return ::mlir::success(); | ||
} | ||
}; | ||
|
||
/// Bufferization of region.env_region_yield. Replaced with a new yield that | ||
/// operates on memrefs. | ||
struct EnvironmentRegionYieldOpInterface | ||
: public ::mlir::bufferization::BufferizableOpInterface::ExternalModel< | ||
EnvironmentRegionYieldOpInterface, region::EnvironmentRegionYieldOp> { | ||
bool bufferizesToMemoryRead( | ||
::mlir::Operation *op, ::mlir::OpOperand &opOperand, | ||
const ::mlir::bufferization::AnalysisState &state) const { | ||
return true; | ||
} | ||
|
||
bool bufferizesToMemoryWrite( | ||
::mlir::Operation *op, ::mlir::OpOperand &opOperand, | ||
const ::mlir::bufferization::AnalysisState &state) const { | ||
return false; | ||
} | ||
|
||
::mlir::bufferization::AliasingValueList | ||
getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand, | ||
const ::mlir::bufferization::AnalysisState &state) const { | ||
return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), | ||
::mlir::bufferization::BufferRelation::Equivalent}}; | ||
} | ||
|
||
bool mustBufferizeInPlace( | ||
::mlir::Operation *op, ::mlir::OpOperand &opOperand, | ||
const ::mlir::bufferization::AnalysisState &state) const { | ||
// Yield operands always bufferize inplace. | ||
return true; | ||
} | ||
|
||
::mlir::LogicalResult | ||
bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter, | ||
const ::mlir::bufferization::BufferizationOptions &options) const { | ||
auto yieldOp = ::mlir::cast<region::EnvironmentRegionYieldOp>(op); | ||
|
||
// Create a new terminator with bufferized operands. | ||
::mlir::SmallVector<::mlir::Value> newOperands; | ||
if (failed(convertToBuffers(yieldOp.getOperands(), newOperands, rewriter, | ||
options))) { | ||
return ::mlir::failure(); | ||
} | ||
::mlir::bufferization::replaceOpWithNewBufferizedOp< | ||
region::EnvironmentRegionYieldOp>(rewriter, op, newOperands); | ||
return ::mlir::success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
} // namespace region | ||
} // namespace imex | ||
|
||
void imex::region::registerBufferizableOpInterfaceExternalModels( | ||
::mlir::DialectRegistry ®istry) { | ||
registry.addExtension(+[](::mlir::MLIRContext *ctx, | ||
region::RegionDialect *dialect) { | ||
EnvironmentRegionOp::attachInterface<EnvironmentRegionOpInterface>(*ctx); | ||
EnvironmentRegionYieldOp::attachInterface< | ||
EnvironmentRegionYieldOpInterface>(*ctx); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.