From 6ee34539e3753dd8e70af8ad97e5be4de02cfebb Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Fri, 13 Dec 2024 16:21:55 +0100 Subject: [PATCH] fixes for insert_slice --- build_tools/llvm_version.txt | 2 +- .../Extensions/MeshShardingExtensions.cpp | 34 +++++++++++++------ lib/Dialect/NDArray/IR/InsertSliceOp.cpp | 19 ++++++++--- .../NDArray/Transforms/CoalesceShardOps.cpp | 9 +++-- 4 files changed, 45 insertions(+), 19 deletions(-) diff --git a/build_tools/llvm_version.txt b/build_tools/llvm_version.txt index 357f190a8..6f36fe915 100644 --- a/build_tools/llvm_version.txt +++ b/build_tools/llvm_version.txt @@ -1 +1 @@ -0eeb79d76a8284fae3e5e3b4ebbbe98d02249235 +d8bb4e6495793fc6bbc38a75dbef52091139c68a diff --git a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp index 41dd3c1f8..5188e0d8a 100644 --- a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp +++ b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp @@ -193,6 +193,7 @@ static std::array getShardSliceOffAndSz( ValueRange myIdx, int64_t dim, ArrayRef meshShape, ArrayRef splitAxes, Value targetOffs, ArrayRef srcShape, const SmallVector &slcOffs, + const SmallVector &slcSizes, const SmallVector &slcStrides, const SmallVector &haloSizes, const EasyI64 &zero, const EasyI64 &one, OpBuilder &builder, Location loc) { @@ -214,8 +215,12 @@ static std::array getShardSliceOffAndSz( std::tie(myOff, mySize) = getOffsetAndSize(myID, zero, one, targetOffs, currPos, builder, loc); } else { - myOff = getBaseShardDimOff(myID, numShards, extend, zero).get(); - mySize = getBaseShardDimSize(myID, numShards, extend, one, zero).get(); + auto myOff_ = getBaseShardDimOff(myID, numShards, extend, zero); + auto mySize_ = getBaseShardDimSize(myID, numShards, extend, one, zero); + auto slcSz = easyI64(loc, builder, slcSizes[dim]); + mySize_ = zero.max(slcSz - myOff_).min(mySize_); + myOff = myOff_.get(); + mySize = mySize_.get(); } // the global offset of the local shard is slice offset plus the computed @@ -290,7 +295,7 @@ getLocalOffSzAndStrFromSlice(OP op, ArrayRef srcShape, } else { auto offAndSz = getShardSliceOffAndSz( myIdx, dim, mesh.getShape(), splitAxes, targetOffs, srcShape, slcOffs, - slcStrides, haloSizes, zero, one, builder, loc); + slcSizes, slcStrides, haloSizes, zero, one, builder, loc); lShardOffs.emplace_back(offAndSz[0]); lShardSizes.emplace_back(offAndSz[1]); } @@ -439,6 +444,7 @@ struct InsertSliceShardingInterface } auto dstSharding = mlir::mesh::MeshSharding::get(shardingOption.mesh, res); maybeInsertSourceShardingAnnotation(dstSharding, op->getOpOperand(0), b); + maybeInsertTargetShardingAnnotation(dstSharding, op->getResult(0), b); return success(); } @@ -449,7 +455,8 @@ struct InsertSliceShardingInterface IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) const { - if (resultShardings.size() != 0) { + if (resultShardings.size() != 1 || operandShardings.size() < 2 || + resultShardings[0] != operandShardings[0]) { return failure(); } @@ -493,22 +500,29 @@ struct InsertSliceShardingInterface } scf::IfOp ifOp = builder.create( - loc, hasSize.get(), [&](OpBuilder &b, Location loc) { - (void)b.create( + loc, hasSize.get(), + [&](OpBuilder &b, Location loc) { + auto res = b.create( loc, spmdizedOperands[0], spmdizedOperands[1], lShardOffs, lShardSizes, lShardStrides); - b.create(loc); + b.create(loc, res.getResult()); + }, + [&](OpBuilder &b, Location loc) { + b.create(loc, spmdizedOperands[0]); }); - spmdizationMap.map(op, ifOp.getOperation()); - builder.create( - loc, spmdizedOperands[0].getType(), spmdizedOperands[0], + auto res = builder.create( + loc, spmdizedOperands[0].getType(), ifOp.getResult(0), dstSharding.getMeshAttr(), mlir::mesh::MeshAxesArrayAttr::get(op->getContext(), dstSharding.getSplitAxes()), dstSharding.getDynamicHaloSizes(), DenseI64ArrayAttr::get(op->getContext(), dstSharding.getStaticHaloSizes())); + + spmdizationMap.map(op->getResult(0), res->getResult(0)); + spmdizationMap.map(op, res.getOperation()); + return success(); } }; diff --git a/lib/Dialect/NDArray/IR/InsertSliceOp.cpp b/lib/Dialect/NDArray/IR/InsertSliceOp.cpp index 9de5c1c86..e013cdffe 100644 --- a/lib/Dialect/NDArray/IR/InsertSliceOp.cpp +++ b/lib/Dialect/NDArray/IR/InsertSliceOp.cpp @@ -132,7 +132,8 @@ class InsertSliceOpConstantArgumentFolder final return mlir::failure(); auto sourceType = insertSliceOp.getSourceType(); - auto dstTnsrType = insertSliceOp.getDestinationType(); //.getTensorType(); + auto dstTnsrType = insertSliceOp.getDestinationType(); + // Create the new op in canonical form. auto sourceTnsrType = mlir::tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( @@ -140,14 +141,22 @@ class InsertSliceOpConstantArgumentFolder final mixedSizes, mixedStrides); auto newSourceType = sourceType.cloneWith(sourceTnsrType.getShape(), sourceTnsrType.getElementType()); + mlir::Value toInsert = insertSliceOp.getSource(); if (newSourceType != sourceType) { - if (newSourceType.getRank() != sourceType.getRank()) + if (sourceType.getRank() == 0) { + if (newSourceType.getRank() > 1) { + return mlir::failure(); + } + } else if (newSourceType.getRank() != sourceType.getRank()) { return mlir::failure(); - mlir::OpBuilder::InsertionGuard g(rewriter); - toInsert = rewriter.create(insertSliceOp.getLoc(), - newSourceType, toInsert); + } else { + mlir::OpBuilder::InsertionGuard g(rewriter); + toInsert = rewriter.create( + insertSliceOp.getLoc(), newSourceType, toInsert); + } } + rewriter.replaceOpWithNewOp( insertSliceOp, insertSliceOp.getDestination(), toInsert, mixedOffsets, mixedSizes, mixedStrides); diff --git a/lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp b/lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp index fc3c48af1..49a8f84bb 100644 --- a/lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp +++ b/lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp @@ -101,8 +101,8 @@ struct CoalesceShardOpsPass return defOp; } else if (auto op = ::mlir::dyn_cast<::mlir::DestinationStyleOpInterface>( defOp)) { - return op.getNumDpsInputs() == 1 ? op.getDpsInits()[0].getDefiningOp() - : defOp; + return op.getNumDpsInits() == 1 ? getBaseArray(op.getDpsInits()[0]) + : defOp; } else if (auto op = ::mlir::dyn_cast<::imex::ndarray::SubviewOp>(defOp)) { return getBaseArray(op.getSource()); } else if (auto op = @@ -479,7 +479,10 @@ struct CoalesceShardOpsPass // update shardOps of dependent Subview/InsertSliceOps for (auto svShardOp : shardOps) { - svShardOp.getSrcMutable().assign(newShardOp.getResult()); + assert(svShardOp->hasOneUse()); + if (mlir::isa<::imex::ndarray::SubviewOp>(*svShardOp->user_begin())) { + svShardOp.getSrcMutable().assign(newShardOp.getResult()); + } svShardOp.getShardingMutable().assign(newSharding); } // barriers/halo-updates get inserted when InsertSliceOps (or other write