Skip to content

Commit

Permalink
[Region] Bufferize with one-shot-bufferize (#973)
Browse files Browse the repository at this point in the history
* Region dialect: bufferize via one-shot-bufferize
* Region dialect: remove region-bufferize pass
  • Loading branch information
tkarna authored Dec 18, 2024
1 parent 65fab90 commit 43a7d7c
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 181 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H
#define IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H

#include <mlir/IR/MLIRContext.h>

namespace imex {
namespace region {
void registerBufferizableOpInterfaceExternalModels(
::mlir::DialectRegistry &registry);
} // namespace region
} // namespace imex

#endif // IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H
7 changes: 0 additions & 7 deletions include/imex/Dialect/Region/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Region)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Region)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Region)
add_public_tablegen_target(IMEXRegionPassIncGen)

add_mlir_doc(Passes RegionPasses ./ -gen-pass-doc)
40 changes: 0 additions & 40 deletions include/imex/Dialect/Region/Transforms/Passes.h

This file was deleted.

35 changes: 0 additions & 35 deletions include/imex/Dialect/Region/Transforms/Passes.td

This file was deleted.

4 changes: 4 additions & 0 deletions include/imex/InitIMEXDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <imex/Dialect/GPUX/IR/GPUXOps.h>
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
#include <imex/Dialect/Region/IR/RegionOps.h>
#include <imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h>
#include <imex/Dialect/XeTile/IR/XeTileOps.h>

namespace imex {
Expand All @@ -37,6 +38,9 @@ inline void registerAllDialects(::mlir::DialectRegistry &registry) {
::imex::xetile::XeTileDialect,
::imex::gpux::GPUXDialect>();
// clang-format on

// Register all external models.
region::registerBufferizableOpInterfaceExternalModels(registry);
}

/// Append all the IMEX dialects to the registry contained in the given context.
Expand Down
2 changes: 0 additions & 2 deletions include/imex/InitIMEXPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <imex/Dialect/Dist/Transforms/Passes.h>
#include <imex/Dialect/DistRuntime/Transforms/Passes.h>
#include <imex/Dialect/NDArray/Transforms/Passes.h>
#include <imex/Dialect/Region/Transforms/Passes.h>
// #include <imex/Dialect/*/Transforms/Passes.h>
#include "imex/Transforms/Passes.h"
#include <imex/Dialect/XeTile/Transforms/Passes.h>
Expand All @@ -47,7 +46,6 @@ inline void registerAllPasses() {
registerNDArrayPasses();
registerDistPasses();
registerDistRuntimePasses();
registerRegionPasses();
registerXeTilePasses();
// register*Passes();

Expand Down
164 changes: 164 additions & 0 deletions lib/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h"
#include "imex/Dialect/Region/IR/RegionOps.h"

#include <mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h>

namespace imex {
namespace region {
namespace {

/// Convert values to buffers. If a value is a tensor, get a buffer for it.
::mlir::LogicalResult
convertToBuffers(::mlir::ValueRange values,
::mlir::SmallVector<::mlir::Value> &buffers,
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options) {
buffers.reserve(values.size());
for (auto val : values) {
if (::mlir::isa<::mlir::TensorType>(val.getType())) {
::mlir::FailureOr<::mlir::Value> maybeBuffer =
::mlir::bufferization::getBuffer(rewriter, val, options);
if (failed(maybeBuffer)) {
return ::mlir::failure();
}
buffers.push_back(*maybeBuffer);
} else {
buffers.push_back(val);
}
}
return ::mlir::success();
}

/// Bufferization of region.env_region op. Replaced with a new
/// op that takes and returns memrefs.
struct EnvironmentRegionOpInterface
: public ::mlir::bufferization::BufferizableOpInterface::ExternalModel<
EnvironmentRegionOpInterface, region::EnvironmentRegionOp> {
bool bufferizesToMemoryRead(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
// Assume all operands are read.
return true;
}

bool bufferizesToMemoryWrite(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
// Assume all operands are written to.
return true;
}

::mlir::bufferization::AliasingValueList
getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
// Assume no aliasing.
return {};
}

::mlir::LogicalResult
bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options) const {
auto envOp = ::mlir::cast<region::EnvironmentRegionOp>(op);
// Convert op arguments to memrefs.
::mlir::SmallVector<::mlir::Value> newArguments;
if (failed(convertToBuffers(envOp.getArgs(), newArguments, rewriter,
options))) {
return ::mlir::failure();
}
// Infer result memref types by converting yield op operands to memrefs
::mlir::SmallVector<::mlir::Value> newResults;
if (failed(convertToBuffers(envOp.getBody()->getTerminator()->getOperands(),
newResults, rewriter, options))) {
return ::mlir::failure();
}
::mlir::TypeRange resTypes(newResults);
// Create new op via generic constructor, op will have an empty region.
rewriter.setInsertionPoint(op);
::mlir::OperationState state(op->getLoc(), op->getName(), newArguments,
resTypes, op->getAttrs());
state.addRegion();
::mlir::Operation *newOp = ::mlir::Operation::create(state);
// Move block from old op into the new op.
newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
op->getRegion(0).getBlocks());
rewriter.insert(newOp);
::mlir::bufferization::replaceOpWithBufferizedValues(rewriter, op,
newOp->getResults());

return ::mlir::success();
}
};

/// Bufferization of region.env_region_yield. Replaced with a new yield that
/// operates on memrefs.
struct EnvironmentRegionYieldOpInterface
: public ::mlir::bufferization::BufferizableOpInterface::ExternalModel<
EnvironmentRegionYieldOpInterface, region::EnvironmentRegionYieldOp> {
bool bufferizesToMemoryRead(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
return true;
}

bool bufferizesToMemoryWrite(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
return false;
}

::mlir::bufferization::AliasingValueList
getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
::mlir::bufferization::BufferRelation::Equivalent}};
}

bool mustBufferizeInPlace(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
// Yield operands always bufferize inplace.
return true;
}

::mlir::LogicalResult
bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options) const {
auto yieldOp = ::mlir::cast<region::EnvironmentRegionYieldOp>(op);

// Create a new terminator with bufferized operands.
::mlir::SmallVector<::mlir::Value> newOperands;
if (failed(convertToBuffers(yieldOp.getOperands(), newOperands, rewriter,
options))) {
return ::mlir::failure();
}
::mlir::bufferization::replaceOpWithNewBufferizedOp<
region::EnvironmentRegionYieldOp>(rewriter, op, newOperands);
return ::mlir::success();
}
};

} // namespace
} // namespace region
} // namespace imex

void imex::region::registerBufferizableOpInterfaceExternalModels(
::mlir::DialectRegistry &registry) {
registry.addExtension(+[](::mlir::MLIRContext *ctx,
region::RegionDialect *dialect) {
EnvironmentRegionOp::attachInterface<EnvironmentRegionOpInterface>(*ctx);
EnvironmentRegionYieldOp::attachInterface<
EnvironmentRegionYieldOpInterface>(*ctx);
});
}
5 changes: 1 addition & 4 deletions lib/Dialect/Region/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
add_imex_dialect_library(IMEXRegionTransforms
BufferizableOpInterfaceImpl.cpp
RegionConversions.cpp
RegionBufferize.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/imex/Dialect/Region

DEPENDS
IMEXRegionPassIncGen

LINK_LIBS PUBLIC
IMEXRegionDialect
MLIRPass
Expand Down
69 changes: 0 additions & 69 deletions lib/Dialect/Region/Transforms/RegionBufferize.cpp

This file was deleted.

Loading

0 comments on commit 43a7d7c

Please sign in to comment.