diff --git a/include/imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h b/include/imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 000000000..3db902c6d --- /dev/null +++ b/include/imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,21 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H +#define IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H + +#include + +namespace imex { +namespace region { +void registerBufferizableOpInterfaceExternalModels( + ::mlir::DialectRegistry ®istry); +} // namespace region +} // namespace imex + +#endif // IMEX_DIALECT_REGION_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/include/imex/Dialect/Region/Transforms/CMakeLists.txt b/include/imex/Dialect/Region/Transforms/CMakeLists.txt index 76de27820..e69de29bb 100644 --- a/include/imex/Dialect/Region/Transforms/CMakeLists.txt +++ b/include/imex/Dialect/Region/Transforms/CMakeLists.txt @@ -1,7 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name Region) -mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Region) -mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Region) -add_public_tablegen_target(IMEXRegionPassIncGen) - -add_mlir_doc(Passes RegionPasses ./ -gen-pass-doc) diff --git a/include/imex/Dialect/Region/Transforms/Passes.h b/include/imex/Dialect/Region/Transforms/Passes.h deleted file mode 100644 index 22921e32a..000000000 --- a/include/imex/Dialect/Region/Transforms/Passes.h +++ /dev/null @@ -1,40 +0,0 @@ -//===-- Passes.h - Dist pass declaration file -------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This header file defines prototypes that expose pass constructors for the -/// Region dialect. -/// -//===----------------------------------------------------------------------===// - -#ifndef _Region_PASSES_H_INCLUDED_ -#define _Region_PASSES_H_INCLUDED_ - -#include - -namespace imex { - -//===----------------------------------------------------------------------===// -/// Dist passes. -//===----------------------------------------------------------------------===// - -/// Create a RegionBufferize pass -std::unique_ptr<::mlir::Pass> createRegionBufferizePass(); - -//===----------------------------------------------------------------------===// -// Registration -//===----------------------------------------------------------------------===// - -/// Generate the code for registering passes. -#define GEN_PASS_REGISTRATION -#include - -} // namespace imex - -#endif // _Region_PASSES_H_INCLUDED_ diff --git a/include/imex/Dialect/Region/Transforms/Passes.td b/include/imex/Dialect/Region/Transforms/Passes.td deleted file mode 100644 index 876dd6f55..000000000 --- a/include/imex/Dialect/Region/Transforms/Passes.td +++ /dev/null @@ -1,35 +0,0 @@ -//===-- Passes.td - Region pass definition file -----------*- tablegen -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines passes/transformations of the Region dialect. -/// -//===----------------------------------------------------------------------===// - -#ifndef _Region_PASSES_TD_INCLUDED_ -#define _Region_PASSES_TD_INCLUDED_ - -include "mlir/Pass/PassBase.td" - -//===----------------------------------------------------------------------===// -// DistCoalesce -//===----------------------------------------------------------------------===// - -def RegionBufferize : Pass<"region-bufferize"> { - let summary = "Bufferization of region ops"; - let description = [{ - Bufferize EnvironmentRegionOp and EnvironmentRegionYieldOp. - }]; - let constructor = "imex::createRegionBufferizePass()"; - let dependentDialects = ["::mlir::bufferization::BufferizationDialect", - "::mlir::memref::MemRefDialect"]; - let options = []; -} - -#endif // _Region_PASSES_TD_INCLUDED_ diff --git a/include/imex/InitIMEXDialects.h b/include/imex/InitIMEXDialects.h index ec0f78f20..8512813d2 100644 --- a/include/imex/InitIMEXDialects.h +++ b/include/imex/InitIMEXDialects.h @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace imex { @@ -37,6 +38,9 @@ inline void registerAllDialects(::mlir::DialectRegistry ®istry) { ::imex::xetile::XeTileDialect, ::imex::gpux::GPUXDialect>(); // clang-format on + + // Register all external models. + region::registerBufferizableOpInterfaceExternalModels(registry); } /// Append all the IMEX dialects to the registry contained in the given context. diff --git a/include/imex/InitIMEXPasses.h b/include/imex/InitIMEXPasses.h index 3d3afb87f..a6b254c08 100644 --- a/include/imex/InitIMEXPasses.h +++ b/include/imex/InitIMEXPasses.h @@ -20,7 +20,6 @@ #include #include #include -#include // #include #include "imex/Transforms/Passes.h" #include @@ -47,7 +46,6 @@ inline void registerAllPasses() { registerNDArrayPasses(); registerDistPasses(); registerDistRuntimePasses(); - registerRegionPasses(); registerXeTilePasses(); // register*Passes(); diff --git a/lib/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.cpp b/lib/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 000000000..1687a8c98 --- /dev/null +++ b/lib/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,164 @@ +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "imex/Dialect/Region/Transforms/BufferizableOpInterfaceImpl.h" +#include "imex/Dialect/Region/IR/RegionOps.h" + +#include + +namespace imex { +namespace region { +namespace { + +/// Convert values to buffers. If a value is a tensor, get a buffer for it. +::mlir::LogicalResult +convertToBuffers(::mlir::ValueRange values, + ::mlir::SmallVector<::mlir::Value> &buffers, + ::mlir::RewriterBase &rewriter, + const ::mlir::bufferization::BufferizationOptions &options) { + buffers.reserve(values.size()); + for (auto val : values) { + if (::mlir::isa<::mlir::TensorType>(val.getType())) { + ::mlir::FailureOr<::mlir::Value> maybeBuffer = + ::mlir::bufferization::getBuffer(rewriter, val, options); + if (failed(maybeBuffer)) { + return ::mlir::failure(); + } + buffers.push_back(*maybeBuffer); + } else { + buffers.push_back(val); + } + } + return ::mlir::success(); +} + +/// Bufferization of region.env_region op. Replaced with a new +/// op that takes and returns memrefs. +struct EnvironmentRegionOpInterface + : public ::mlir::bufferization::BufferizableOpInterface::ExternalModel< + EnvironmentRegionOpInterface, region::EnvironmentRegionOp> { + bool bufferizesToMemoryRead( + ::mlir::Operation *op, ::mlir::OpOperand &opOperand, + const ::mlir::bufferization::AnalysisState &state) const { + assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + // Assume all operands are read. + return true; + } + + bool bufferizesToMemoryWrite( + ::mlir::Operation *op, ::mlir::OpOperand &opOperand, + const ::mlir::bufferization::AnalysisState &state) const { + assert(::mlir::isa<::mlir::RankedTensorType>(opOperand.get().getType()) && + "only tensor types expected"); + // Assume all operands are written to. + return true; + } + + ::mlir::bufferization::AliasingValueList + getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand, + const ::mlir::bufferization::AnalysisState &state) const { + // Assume no aliasing. + return {}; + } + + ::mlir::LogicalResult + bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter, + const ::mlir::bufferization::BufferizationOptions &options) const { + auto envOp = ::mlir::cast(op); + // Convert op arguments to memrefs. + ::mlir::SmallVector<::mlir::Value> newArguments; + if (failed(convertToBuffers(envOp.getArgs(), newArguments, rewriter, + options))) { + return ::mlir::failure(); + } + // Infer result memref types by converting yield op operands to memrefs + ::mlir::SmallVector<::mlir::Value> newResults; + if (failed(convertToBuffers(envOp.getBody()->getTerminator()->getOperands(), + newResults, rewriter, options))) { + return ::mlir::failure(); + } + ::mlir::TypeRange resTypes(newResults); + // Create new op via generic constructor, op will have an empty region. + rewriter.setInsertionPoint(op); + ::mlir::OperationState state(op->getLoc(), op->getName(), newArguments, + resTypes, op->getAttrs()); + state.addRegion(); + ::mlir::Operation *newOp = ::mlir::Operation::create(state); + // Move block from old op into the new op. + newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(), + op->getRegion(0).getBlocks()); + rewriter.insert(newOp); + ::mlir::bufferization::replaceOpWithBufferizedValues(rewriter, op, + newOp->getResults()); + + return ::mlir::success(); + } +}; + +/// Bufferization of region.env_region_yield. Replaced with a new yield that +/// operates on memrefs. +struct EnvironmentRegionYieldOpInterface + : public ::mlir::bufferization::BufferizableOpInterface::ExternalModel< + EnvironmentRegionYieldOpInterface, region::EnvironmentRegionYieldOp> { + bool bufferizesToMemoryRead( + ::mlir::Operation *op, ::mlir::OpOperand &opOperand, + const ::mlir::bufferization::AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite( + ::mlir::Operation *op, ::mlir::OpOperand &opOperand, + const ::mlir::bufferization::AnalysisState &state) const { + return false; + } + + ::mlir::bufferization::AliasingValueList + getAliasingValues(::mlir::Operation *op, ::mlir::OpOperand &opOperand, + const ::mlir::bufferization::AnalysisState &state) const { + return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), + ::mlir::bufferization::BufferRelation::Equivalent}}; + } + + bool mustBufferizeInPlace( + ::mlir::Operation *op, ::mlir::OpOperand &opOperand, + const ::mlir::bufferization::AnalysisState &state) const { + // Yield operands always bufferize inplace. + return true; + } + + ::mlir::LogicalResult + bufferize(::mlir::Operation *op, ::mlir::RewriterBase &rewriter, + const ::mlir::bufferization::BufferizationOptions &options) const { + auto yieldOp = ::mlir::cast(op); + + // Create a new terminator with bufferized operands. + ::mlir::SmallVector<::mlir::Value> newOperands; + if (failed(convertToBuffers(yieldOp.getOperands(), newOperands, rewriter, + options))) { + return ::mlir::failure(); + } + ::mlir::bufferization::replaceOpWithNewBufferizedOp< + region::EnvironmentRegionYieldOp>(rewriter, op, newOperands); + return ::mlir::success(); + } +}; + +} // namespace +} // namespace region +} // namespace imex + +void imex::region::registerBufferizableOpInterfaceExternalModels( + ::mlir::DialectRegistry ®istry) { + registry.addExtension(+[](::mlir::MLIRContext *ctx, + region::RegionDialect *dialect) { + EnvironmentRegionOp::attachInterface(*ctx); + EnvironmentRegionYieldOp::attachInterface< + EnvironmentRegionYieldOpInterface>(*ctx); + }); +} diff --git a/lib/Dialect/Region/Transforms/CMakeLists.txt b/lib/Dialect/Region/Transforms/CMakeLists.txt index 13b81d409..d51698260 100644 --- a/lib/Dialect/Region/Transforms/CMakeLists.txt +++ b/lib/Dialect/Region/Transforms/CMakeLists.txt @@ -1,13 +1,10 @@ add_imex_dialect_library(IMEXRegionTransforms + BufferizableOpInterfaceImpl.cpp RegionConversions.cpp - RegionBufferize.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/imex/Dialect/Region - DEPENDS - IMEXRegionPassIncGen - LINK_LIBS PUBLIC IMEXRegionDialect MLIRPass diff --git a/lib/Dialect/Region/Transforms/RegionBufferize.cpp b/lib/Dialect/Region/Transforms/RegionBufferize.cpp deleted file mode 100644 index 6d9cc3b9c..000000000 --- a/lib/Dialect/Region/Transforms/RegionBufferize.cpp +++ /dev/null @@ -1,69 +0,0 @@ -//===- RegionBufferize.cpp - Bufferization for region ops -----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements bufferization of region ops. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "imex/Dialect/Region/IR/RegionOps.h" -#include "imex/Dialect/Region/Transforms/Passes.h" -#include "imex/Dialect/Region/Transforms/RegionConversions.h" - -namespace imex { -#define GEN_PASS_DEF_REGIONBUFFERIZE -#include "imex/Dialect/Region/Transforms/Passes.h.inc" -} // namespace imex - -static ::mlir::Value materializeToTensorRestrict(::mlir::OpBuilder &builder, - ::mlir::TensorType type, - ::mlir::ValueRange inputs, - ::mlir::Location loc) { - assert(inputs.size() == 1); - assert(::mlir::isa<::mlir::BaseMemRefType>(inputs[0].getType())); - return builder.create<::mlir::bufferization::ToTensorOp>(loc, type, inputs[0], - /*restrict=*/true); -} - -namespace { -struct RegionBufferizePass - : public ::imex::impl::RegionBufferizeBase { - using ::imex::impl::RegionBufferizeBase< - RegionBufferizePass>::RegionBufferizeBase; - - void runOnOperation() override { - auto module = getOperation(); - auto *context = &getContext(); - - ::mlir::bufferization::BufferizeTypeConverter typeConverter; - ::mlir::RewritePatternSet patterns(context); - ::mlir::ConversionTarget target(*context); - - typeConverter.addArgumentMaterialization(materializeToTensorRestrict); - typeConverter.addSourceMaterialization(materializeToTensorRestrict); - ::imex::populateRegionTypeConversionPatterns(patterns, typeConverter); - - target.addDynamicallyLegalOp<::imex::region::EnvironmentRegionOp, - ::imex::region::EnvironmentRegionYieldOp>( - [&](mlir::Operation *op) { return typeConverter.isLegal(op); }); - - if (::mlir::failed(::mlir::applyPartialConversion(module, target, - std::move(patterns)))) - signalPassFailure(); - } -}; -} // namespace - -std::unique_ptr<::mlir::Pass> imex::createRegionBufferizePass() { - return std::make_unique(); -} diff --git a/test/Dialect/Region/Transforms/RegionBufferize.mlir b/test/Dialect/Region/Transforms/RegionBufferize.mlir index 0678cb82a..c653f2cc8 100644 --- a/test/Dialect/Region/Transforms/RegionBufferize.mlir +++ b/test/Dialect/Region/Transforms/RegionBufferize.mlir @@ -1,31 +1,21 @@ -// RUN: imex-opt %s -region-bufferize --split-input-file | FileCheck %s +// RUN: imex-opt %s -one-shot-bufferize --split-input-file | FileCheck %s -#map = affine_map<(d0) -> (d0)> module { - func.func @sharpy_jit() -> memref<16xi64, strided<[?], offset: ?>> attributes {llvm.emit_c_interface} { - %cst = arith.constant 0.000000e+00 : f64 + func.func @test_bufferize() -> memref<16xi64, strided<[?], offset: ?>> { + %c1_i64 = arith.constant 1 : i64 %0 = region.env_region #region.gpu_env -> tensor<16xi64> { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<16xi64> - linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%alloc : memref<16xi64>) { - ^bb0(%out: i64): - %3 = linalg.index 0 : index - %4 = arith.index_cast %3 : index to i64 - %5 = arith.sitofp %4 : i64 to f64 - %6 = arith.addf %5, %cst : f64 - %7 = arith.fptosi %6 : f64 to i64 - linalg.yield %7 : i64 - } - %2 = bufferization.to_tensor %alloc : memref<16xi64> - region.env_region_yield %2 : tensor<16xi64> + %2 = bufferization.alloc_tensor() : tensor<16xi64> + %3 = linalg.fill ins(%c1_i64 : i64) outs(%2 : tensor<16xi64>) -> tensor<16xi64> + region.env_region_yield %3 : tensor<16xi64> } %1 = bufferization.to_memref %0 : memref<16xi64, strided<[?], offset: ?>> return %1 : memref<16xi64, strided<[?], offset: ?>> } } -// CHECK-LABEL: func.func @sharpy_jit() -> memref<16xi64, strided<[?], offset: ?>> attributes {llvm.emit_c_interface} { -// CHECK: region.env_region #region.gpu_env -> memref<16xi64> { -// CHECK: memref.alloc() {alignment = 64 : i64} : memref<16xi64> -// CHECK: linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%alloc : memref<16xi64>) { -// CHECK: region.env_region_yield %4 : memref<16xi64> -// CHECK: return -// CHECK-SAME: : memref<16xi64, strided<[?], offset: ?>> +// CHECK-LABEL: func.func @test_bufferize() -> memref<16xi64, strided<[?], offset: ?>> { +// CHECK: [[R1:%.*]] = region.env_region #region.gpu_env -> memref<16xi64> { +// CHECK-NEXT: [[V1:%.*]] = memref.alloc() {alignment = 64 : i64} : memref<16xi64> +// CHECK-NEXT: linalg.fill +// CHECK-NEXT: region.env_region_yield [[V1]] : memref<16xi64> +// CHECK: [[V2:%.*]] = memref.cast [[R1]] : memref<16xi64> to memref<16xi64, strided<[?], offset: ?>> +// CHECK-NEXT: return [[V2]] : memref<16xi64, strided<[?], offset: ?>> diff --git a/test/imex-runner/fullgpu.pp b/test/imex-runner/fullgpu.pp index 075bdece0..77ace9b97 100644 --- a/test/imex-runner/fullgpu.pp +++ b/test/imex-runner/fullgpu.pp @@ -19,7 +19,6 @@ memref-expand, func.func(empty-tensor-to-alloc-tensor), one-shot-bufferize{unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries} - region-bufferize, canonicalize, imex-remove-temporaries, func.func(convert-linalg-to-parallel-loops),