Skip to content

Commit

Permalink
moving bufferization of insertslice and subview from ndarraytolinalg …
Browse files Browse the repository at this point in the history
…to bufferizableinterface; add shardinginterface to reshape (incomplete)
  • Loading branch information
fschlimb committed Dec 11, 2024
1 parent bf4ba36 commit 2728722
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 13 deletions.
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ecf1694333c05fc7180a2ad8fa80bbd709f35006
0eeb79d76a8284fae3e5e3b4ebbbe98d02249235
11 changes: 9 additions & 2 deletions include/imex/Dialect/NDArray/IR/NDArrayOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
// include "mlir/Interfaces/ShapedOpInterfaces.td"
// include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/IR/OpAsmInterface.td"
Expand Down Expand Up @@ -293,6 +294,10 @@ def SubviewOp : NDArray_OpWithOffsetSizesAndStrides<"subview", [
def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
DestinationStyleOpInterface,
Pure,
TypesMatchWith<"expected result type to match dest type",
"destination", "result", "$_self">
]> {
let summary = "Copy values from a array into a slice of another.";
let description = [{
Expand All @@ -312,6 +317,7 @@ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [
DenseI64ArrayAttr:$static_sizes,
DenseI64ArrayAttr:$static_strides
);
let results = (outs AnyRankedTensor:$result);

let assemblyFormat = [{
$source `into` $destination ``
Expand Down Expand Up @@ -377,6 +383,7 @@ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [
/// and `strides` operands.
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }

mlir::MutableOperandRange getDpsInitsMutable() { return getDestinationMutable(); }
}];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -408,8 +415,8 @@ def ReshapeOp : NDArray_Op<"reshape", []> {
See Array API.
}];

let arguments = (ins AnyType:$source, Variadic<Index>:$shape, OptionalAttr<I1Attr>:$copy);
let results = (outs AnyType);
let arguments = (ins AnyRankedTensor:$source, Variadic<Index>:$shape, OptionalAttr<I1Attr>:$copy);
let results = (outs AnyRankedTensor);

let assemblyFormat = [{
$source $shape attr-dict `:` qualified(type($source)) `->` qualified(type(results))
Expand Down
18 changes: 11 additions & 7 deletions lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,20 +430,24 @@ struct ConvertNDArrayToLinalgPass
});

::mlir::ConversionTarget target(ctxt);
// We convert all NDArray stuff...
target.addIllegalDialect<::imex::ndarray::NDArrayDialect>();
// ...into Linalg, Affine, Tensor, Arith
target.addLegalDialect<
::mlir::linalg::LinalgDialect, ::mlir::arith::ArithDialect,
::mlir::memref::MemRefDialect, ::mlir::tensor::TensorDialect,
::mlir::bufferization::BufferizationDialect, ::mlir::func::FuncDialect,
::imex::region::RegionDialect>();
target.addLegalOp<mlir::UnrealizedConversionCastOp>();

target.addLegalOp<imex::ndarray::SubviewOp, imex::ndarray::InsertSliceOp,
mlir::UnrealizedConversionCastOp>();

// We convert almost all NDArray stuff...
target.addDynamicallyLegalDialect<::imex::ndarray::NDArrayDialect>(
[&](mlir::Operation *op) {
return mlir::isa<imex::ndarray::SubviewOp,
imex::ndarray::InsertSliceOp>(op);
});
::mlir::RewritePatternSet patterns(&ctxt);
patterns.insert<SubviewLowering, InsertSliceLowering, LinSpaceLowering,
ReshapeLowering, CopyLowering, DeleteLowering,
CastElemTypeLowering>(&ctxt);
patterns.insert<LinSpaceLowering, ReshapeLowering, CopyLowering,
DeleteLowering, CastElemTypeLowering>(&ctxt);

if (::mlir::failed(::mlir::applyPartialConversion(getOperation(), target,
::std::move(patterns)))) {
Expand Down
6 changes: 4 additions & 2 deletions lib/Dialect/NDArray/Extensions/AllExtensions.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- AllExtensions.cpp - All NDArray Dialect Extensions ------------------===//
//===- AllExtensions.cpp - All NDArray Dialect Extensions -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -7,10 +7,12 @@
//===----------------------------------------------------------------------===//

#include "imex/Dialect/NDArray/Extensions/AllExtensions.h"
#include "imex/Dialect/NDArray/Extensions/BufferizableOpInterfaceImpl.h"
#include "imex/Dialect/NDArray/Extensions/MeshShardingExtensions.h"

using namespace mlir;

void imex::ndarray::registerAllExtensions(DialectRegistry &registry) {
registerShardingInterfaceExternalModels(registry);
}
registerBufferizableOpInterfaceExternalModels(registry);
}
4 changes: 3 additions & 1 deletion lib/Dialect/NDArray/Extensions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
set(LLVM_OPTIONAL_SOURCES
AllExtensions.cpp
BufferizableOpInterfaceImpl.cpp
MeshShardingExtensions.cpp
)

add_imex_extension_library(IMEXNDArrayMeshShardingExtensions
MeshShardingExtensions.cpp
BufferizableOpInterfaceImpl.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/mlir/Dialect/NDArray/Extensions
Expand All @@ -23,4 +25,4 @@ add_imex_extension_library(IMEXNDArrayAllExtensions

LINK_LIBS PUBLIC
IMEXNDArrayMeshShardingExtensions
)
)
16 changes: 16 additions & 0 deletions lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,21 @@ struct LinspaceShardingInterface
}
};

//===----------------------------------------------------------------------===//
// ReshapeShardingInterface
//===----------------------------------------------------------------------===//

struct ReshapeShardingInterface
: public BaseShardingInterface<ReshapeShardingInterface, ReshapeOp> {

SmallVector<mlir::utils::IteratorType>
getLoopIteratorTypes(::mlir::Operation *op) const {
auto rsop = cast<ReshapeOp>(op);
size_t rank = std::max(rsop.getSource().getType().getRank(),
rsop.getResult().getType().getRank());
return {rank, utils::IteratorType::parallel};
}
};
} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -672,6 +687,7 @@ void registerShardingInterfaceExternalModels(mlir::DialectRegistry &registry) {
SubviewOp::attachInterface<SubviewShardingInterface>(*ctx);
InsertSliceOp::attachInterface<InsertSliceShardingInterface>(*ctx);
LinSpaceOp::attachInterface<LinspaceShardingInterface>(*ctx);
ReshapeOp::attachInterface<ReshapeShardingInterface>(*ctx);
registerTrivial<CopyOp, DeleteOp, CastElemTypeOp>(ctx);
});
}
Expand Down

0 comments on commit 2728722

Please sign in to comment.