Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Region] Bufferize with one-shot-bufferize #973

Merged
merged 5 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H
#define IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H

#include <mlir/IR/MLIRContext.h>

namespace imex {
namespace region {
void registerBufferizableOpInterfaceExternalModels(
::mlir::DialectRegistry &registry);
} // namespace region
} // namespace imex

#endif // IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H
7 changes: 0 additions & 7 deletions include/imex/Dialect/Region/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Region)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Region)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Region)
add_public_tablegen_target(IMEXRegionPassIncGen)

add_mlir_doc(Passes RegionPasses ./ -gen-pass-doc)
40 changes: 0 additions & 40 deletions include/imex/Dialect/Region/Transforms/Passes.h

This file was deleted.

35 changes: 0 additions & 35 deletions include/imex/Dialect/Region/Transforms/Passes.td

This file was deleted.

4 changes: 4 additions & 0 deletions include/imex/InitIMEXDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <imex/Dialect/GPUX/IR/GPUXOps.h>
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
#include <imex/Dialect/Region/IR/RegionOps.h>
#include <imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h>
#include <imex/Dialect/XeTile/IR/XeTileOps.h>

namespace imex {
Expand All @@ -37,6 +38,9 @@ inline void registerAllDialects(::mlir::DialectRegistry &registry) {
::imex::xetile::XeTileDialect,
::imex::gpux::GPUXDialect>();
// clang-format on

// Register all external models.
region::registerBufferizableOpInterfaceExternalModels(registry);
}

/// Append all the IMEX dialects to the registry contained in the given context.
Expand Down
2 changes: 0 additions & 2 deletions include/imex/InitIMEXPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <imex/Dialect/Dist/Transforms/Passes.h>
#include <imex/Dialect/DistRuntime/Transforms/Passes.h>
#include <imex/Dialect/NDArray/Transforms/Passes.h>
#include <imex/Dialect/Region/Transforms/Passes.h>
// #include <imex/Dialect/*/Transforms/Passes.h>
#include "imex/Transforms/Passes.h"
#include <imex/Dialect/XeTile/Transforms/Passes.h>
Expand All @@ -47,7 +46,6 @@ inline void registerAllPasses() {
registerNDArrayPasses();
registerDistPasses();
registerDistRuntimePasses();
registerRegionPasses();
registerXeTilePasses();
// register*Passes();

Expand Down
164 changes: 164 additions & 0 deletions lib/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h"
#include "imex/Dialect/Region/IR/RegionOps.h"

#include <mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h>

namespace imex {
namespace region {
namespace {

/// Convert values to buffers. If a value is a tensor, get a buffer for it.
::mlir::LogicalResult
convertToBuffers(::mlir::ValueRange values,
::mlir::SmallVector<::mlir::Value> &buffers,
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options) {
buffers.reserve(values.size());
for (auto val : values) {
if (::mlir::isa<::mlir::TensorType>(val.getType())) {
::mlir::FailureOr<::mlir::Value> maybeBuffer =
::mlir::bufferization::getBuffer(rewriter, val, options);
if (failed(maybeBuffer)) {
return ::mlir::failure();
}
buffers.push_back(*maybeBuffer);
} else {
buffers.push_back(val);
}
}
return ::mlir::success();
}

/// Bufferization of region.env_region op. Replaced with a new
/// op that takes and returns memrefs.
struct EnvironmentRegionOpInterface
: public ::mlir::bufferization::BufferizableOpInterface::ExternalModel<
EnvironmentRegionOpInterface, region::EnvironmentRegionOp> {
bool bufferizesToMemoryRead(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
// Assume all operands are read.
return true;
}

bool bufferizesToMemoryWrite(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) &&
"only tensor types expected");
// Assume all operands are written to.
return true;
}

::mlir::bufferization::AliasingValueList
getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
// Assume no aliasing.
return {};
}

::mlir::LogicalResult
bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options) const {
auto envOp = ::mlir::cast<region::EnvironmentRegionOp>(op);
// Convert op arguments to memrefs.
::mlir::SmallVector<::mlir::Value> newArguments;
if (failed(convertToBuffers(envOp.getArgs(), newArguments, rewriter,
options))) {
return ::mlir::failure();
}
// Infer result memref types by converting yield op operands to memrefs
::mlir::SmallVector<::mlir::Value> newResults;
if (failed(convertToBuffers(envOp.getBody()->getTerminator()->getOperands(),
newResults, rewriter, options))) {
return ::mlir::failure();
}
::mlir::TypeRange resTypes(newResults);
// Create new op via generic constructor, op will have an empty region.
rewriter.setInsertionPoint(op);
::mlir::OperationState state(op->getLoc(), op->getName(), newArguments,
resTypes, op->getAttrs());
state.addRegion();
::mlir::Operation *newOp = ::mlir::Operation::create(state);
// Move block from old op into the new op.
newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
op->getRegion(0).getBlocks());
rewriter.insert(newOp);
::mlir::bufferization::replaceOpWithBufferizedValues(rewriter, op,
newOp->getResults());

return ::mlir::success();
}
};

/// Bufferization of region.env_region_yield. Replaced with a new yield that
/// operates on memrefs.
struct EnvironmentRegionYieldOpInterface
: public ::mlir::bufferization::BufferizableOpInterface::ExternalModel<
EnvironmentRegionYieldOpInterface, region::EnvironmentRegionYieldOp> {
bool bufferizesToMemoryRead(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
return true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is it reading?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation follows scf.yield op bufferization: read is true, write is false etc.

}

bool bufferizesToMemoryWrite(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
return false;
}

::mlir::bufferization::AliasingValueList
getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
return {{op->getParentOp()->getResult(opOperand.getOperandNumber()),
::mlir::bufferization::BufferRelation::Equivalent}};
}

bool mustBufferizeInPlace(
::mlir::Operation *op, ::mlir::OpOperand &opOperand,
const ::mlir::bufferization::AnalysisState &state) const {
// Yield operands always bufferize inplace.
return true;
}

::mlir::LogicalResult
bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options) const {
auto yieldOp = ::mlir::cast<region::EnvironmentRegionYieldOp>(op);

// Create a new terminator with bufferized operands.
::mlir::SmallVector<::mlir::Value> newOperands;
if (failed(convertToBuffers(yieldOp.getOperands(), newOperands, rewriter,
options))) {
return ::mlir::failure();
}
::mlir::bufferization::replaceOpWithNewBufferizedOp<
region::EnvironmentRegionYieldOp>(rewriter, op, newOperands);
return ::mlir::success();
}
};

} // namespace
} // namespace region
} // namespace imex

void imex::region::registerBufferizableOpInterfaceExternalModels(
::mlir::DialectRegistry &registry) {
registry.addExtension(+[](::mlir::MLIRContext *ctx,
region::RegionDialect *dialect) {
EnvironmentRegionOp::attachInterface<EnvironmentRegionOpInterface>(*ctx);
EnvironmentRegionYieldOp::attachInterface<
EnvironmentRegionYieldOpInterface>(*ctx);
});
}
5 changes: 1 addition & 4 deletions lib/Dialect/Region/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
add_imex_dialect_library(IMEXRegionTransforms
BufferizableOpInterfaceImpl.cpp
RegionConversions.cpp
RegionBufferize.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/imex/Dialect/Region

DEPENDS
IMEXRegionPassIncGen

LINK_LIBS PUBLIC
IMEXRegionDialect
MLIRPass
Expand Down
69 changes: 0 additions & 69 deletions lib/Dialect/Region/Transforms/RegionBufferize.cpp

This file was deleted.

Loading
Loading