Skip to content

Commit

Permalink
Fixes to remove temporaries pass (#951)
Browse files Browse the repository at this point in the history
* rm-temp: fixes
  • Loading branch information
tkarna authored Dec 3, 2024
1 parent 7dbe11c commit b5483b9
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions lib/Transforms/RemoveTemporaries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,9 @@ bool checkReadWriteConflict(mlir::Operation *op, mlir::Operation *srcAllocOp,
/// propagating it through subview ops so we cannot just do a replaceAllUse
/// but need to propagate the type change and erase old subview ops. Ported
/// from mlir memref MultiBuffer.cpp
static void replaceUsesAndPropagateType(mlir::RewriterBase &rewriter,
mlir::Operation *oldOp,
mlir::Value val) {
mlir::SmallVector<mlir::Operation *> opsToDelete;
static void replaceUsesAndPropagateType(
mlir::RewriterBase &rewriter, mlir::Operation *oldOp, mlir::Value val,
::mlir::SmallVector<mlir::Operation *> &opsToDelete) {
mlir::SmallVector<mlir::OpOperand *> operandsToReplace;

// Save the operand to replace / delete later (avoid iterator invalidation).
Expand Down Expand Up @@ -390,7 +389,7 @@ static void replaceUsesAndPropagateType(mlir::RewriterBase &rewriter,
subviewUse.getMixedStrides());

// Ouch recursion ... is this really necessary?
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview);
replaceUsesAndPropagateType(rewriter, subviewUse, newSubview, opsToDelete);

opsToDelete.push_back(use.getOwner());
}
Expand All @@ -402,10 +401,6 @@ static void replaceUsesAndPropagateType(mlir::RewriterBase &rewriter,
operand->set(val);
rewriter.finalizeOpModification(op);
}

// Perform late op erasure.
for (mlir::Operation *op : opsToDelete)
rewriter.eraseOp(op);
}

// Moves op after markerOp if possible. Returns true if successful.
Expand Down Expand Up @@ -489,19 +484,21 @@ struct RemoveTemporaries
auto dst = opi.getTarget();
auto src = opi.getSource();
mlir::IRRewriter rewriter(op->getContext());
DEBUG_MSG("RemoveTemporaries", "------------------------------------------")
DEBUG_OP("RemoveTemporaries", "inspecting", op)

auto srcAllocOp = findAllocOp(src);
auto srcDeallocOp = findDeallocOp(src);
auto dstDeallocOp = findDeallocOp(dst);
auto dstDefOp = dst.getDefiningOp();
if (!srcAllocOp) {
// src is not associated with a temp array allocation
DEBUG_MSG("RemoveTemporaries",
"src is not associated with an alloc, skipping")
return;
}
auto allocOpParentReg = srcAllocOp->getParentRegion();
auto copyOpParentReg = op->getParentRegion();
DEBUG_MSG("RemoveTemporaries", "------------------------------------------")
DEBUG_OP("RemoveTemporaries", "inspecting", op)
DEBUG_OP("RemoveTemporaries", " src alloc op", srcAllocOp)

bool srcIsReturned = findReturn(srcAllocOp->getResult(0));
Expand Down Expand Up @@ -533,15 +530,17 @@ struct RemoveTemporaries
return;
}
// Move copy target right after src allocation
// unless target is defined earlier
auto &dom = getAnalysis<::mlir::DominanceInfo>();
if (!moveAfterIfPossible(dstDefOp, srcAllocOp, op, dom)) {
if (!dom.dominates(dstDefOp, srcAllocOp) &&
!moveAfterIfPossible(dstDefOp, srcAllocOp, op, dom)) {
DEBUG_MSG("RemoveTemporaries", "cannot move dst defining op, skipping")
return;
}
// Replace src alloc uses by dst defining op
DEBUG_OP("RemoveTemporaries", " replacing src alloc", srcAllocOp)
DEBUG_OP("RemoveTemporaries", " with", dstDefOp)
replaceUsesAndPropagateType(rewriter, srcAllocOp, dstDefOp->getResult(0));
replaceUsesAndPropagateType(rewriter, srcAllocOp, dst, opsToRemove);
} else {
if (srcIsReturned) {
// no defining op, dst is function argument, after removing scr allow
Expand All @@ -553,7 +552,7 @@ struct RemoveTemporaries
// no defining op, replace src with dst mlir::Value
DEBUG_OP("RemoveTemporaries", " replacing src alloc", srcAllocOp)
DEBUG_MSG("RemoveTemporaries", " with copy op dst value")
replaceUsesAndPropagateType(rewriter, srcAllocOp, dst);
replaceUsesAndPropagateType(rewriter, srcAllocOp, dst, opsToRemove);
}
DEBUG_OP("RemoveTemporaries", " removing op", op)
opsToRemove.push_back(op);
Expand Down

0 comments on commit b5483b9

Please sign in to comment.