Skip to content

Commit

Permalink
Use .replaceAllUsesWith() in MLIR pass
Browse files Browse the repository at this point in the history
  • Loading branch information
chunnienc authored and Siyuan Liu committed Nov 29, 2023
1 parent d78bf9e commit 661b595
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions torch_xla/csrc/runtime/stablehlo_composite_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,7 @@ class BuildStableHLOCompositePass : public mlir::OperationPass<mlir::ModuleOp> {
// func call.
for (size_t i = 0; i < op->getNumResults(); ++i) {
mlir::OpResult result = op->getResult(i);
mlir::OpResult new_result = composite_call_op->getResult(i);
for (mlir::OpOperand& use : result.getUses()) {
use.getOwner()->setOperand(use.getOperandNumber(), new_result);
}
result.replaceAllUsesWith(composite_call_op->getResult(i));
}

// The unused scope_ops will be eliminated with canonicalizer.
Expand All @@ -357,26 +354,18 @@ class RemoveXlaMarkTensorOpsPass
mlir::func::FuncOp func_op = getOperation();
llvm::SmallVector<mlir::Operation*> ops_to_erase;

for (auto mark_tensor_op :
func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
if (!IsXlaMarkTensorOp(mark_tensor_op.getOperation())) {
for (auto op : func_op.getOps<mlir::stablehlo::CustomCallOp>()) {
if (!IsXlaMarkTensorOp(op.getOperation())) {
continue;
}
mlir::Value original_value = mark_tensor_op.getOperand(0);

llvm::SmallVector<std::tuple<mlir::Operation*, size_t>> uses;
for (mlir::OpOperand& use : mark_tensor_op.getResult(0).getUses()) {
uses.push_back({use.getOwner(), use.getOperandNumber()});
}
mlir::Value original_value = op.getOperand(0);

for (auto [use_op, operand_number] : uses) {
use_op->setOperand(operand_number, original_value);
for (mlir::Value result : op.getResults()) {
result.replaceAllUsesWith(original_value);
}
ops_to_erase.push_back(mark_tensor_op.getOperation());
}
for (auto* op : ops_to_erase) {
op->erase();
}

// The unused custom_call ops will be eliminated with canonicalizer.
}

mlir::StringRef getName() const override {
Expand Down

0 comments on commit 661b595

Please sign in to comment.