-
Notifications
You must be signed in to change notification settings - Fork 12.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect conversion: extra signature conversion check #117471
base: main
Are you sure you want to change the base?
[mlir][Transforms] Dialect conversion: extra signature conversion check #117471
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit adds an extra assertion to To simplify the check, This commit is in preparation of adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust. Full diff: https://github.com/llvm/llvm-project/pull/117471.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5acd095da8e386..710c976281dc3d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -434,23 +434,25 @@ class MoveBlockRewrite : public BlockRewrite {
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, Block *origBlock)
- : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
- origBlock(origBlock) {}
+ Block *origBlock, Block *newBlock)
+ : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock),
+ newBlock(newBlock) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- Block *getOrigBlock() const { return origBlock; }
+ Block *getOrigBlock() const { return block; }
+
+ Block *getNewBlock() const { return newBlock; }
void commit(RewriterBase &rewriter) override;
void rollback() override;
private:
- /// The original block that was requested to have its signature converted.
- Block *origBlock;
+ /// The new block that was created as part of this signature conversion.
+ Block *newBlock;
};
/// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -721,6 +723,18 @@ static bool hasRewrite(R &&rewrites, Operation *op) {
});
}
+#ifndef NDEBUG
+/// Return "true" if there is a block rewrite that matches the specified
+/// rewrite type and block among the given rewrites.
+template <typename RewriteTy, typename R>
+static bool hasRewrite(R &&rewrites, Block *block) {
+ return any_of(std::forward<R>(rewrites), [&](auto &rewrite) {
+ auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+ return rewriteTy && rewriteTy->getBlock() == block;
+ });
+}
+#endif // NDEBUG
+
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -966,12 +980,12 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
// block.
if (auto *listener =
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
- for (Operation *op : block->getUsers())
+ for (Operation *op : getNewBlock()->getUsers())
listener->notifyOperationModified(op);
}
void BlockTypeConversionRewrite::rollback() {
- block->replaceAllUsesWith(origBlock);
+ getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
@@ -1223,6 +1237,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
+ // A block cannot be converted multiple times.
+ assert(!hasRewrite<BlockTypeConversionRewrite>(rewrites, block) &&
+ "block was already converted");
OpBuilder::InsertionGuard g(rewriter);
// If no arguments are being changed or added, there is nothing to do.
@@ -1308,7 +1325,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
+ appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, very useful!
This commit adds an extra assertion to
applySignatureConversion
to prevent incorrect API usage: The same block cannot be converted multiple times. That would mess with the underlying conversion value mapping. (Mappings would be overwritten.) This is similar to op replacements: The same op cannot be replaced multiple times.To simplify the check,
BlockTypeConversionRewrite::block
now stores the original block. The new block is stored in an extra field. (It used to be the other way around.)This commit is in preparation of adding 1:N support to the conversion value mapping. Before making any further changes to the mapping infrastructure, I'd like to make sure that the code base around it (that uses the mapping) is robust.