From 0bf05d9e119da7b9ec2058197a01494a7f3ce7cb Mon Sep 17 00:00:00 2001 From: Vincent Wells Date: Mon, 23 Dec 2024 13:51:25 -0600 Subject: [PATCH] add some silly debugs --- lib/Dialect/TTNN/Transforms/Optimizer.cpp | 3 + lib/Dialect/TTNN/Transforms/Passes.cpp | 12 ++++ lib/Dialect/TTNN/Transforms/TTNNLayout.cpp | 68 ++++++++++++++----- .../TTNN/Transforms/TTNNWorkarounds.cpp | 2 + 4 files changed, 67 insertions(+), 18 deletions(-) diff --git a/lib/Dialect/TTNN/Transforms/Optimizer.cpp b/lib/Dialect/TTNN/Transforms/Optimizer.cpp index 9ada2dbb5..719d773c7 100644 --- a/lib/Dialect/TTNN/Transforms/Optimizer.cpp +++ b/lib/Dialect/TTNN/Transforms/Optimizer.cpp @@ -146,6 +146,8 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { public: using impl::TTNNOptimizerBase::TTNNOptimizerBase; void runOnOperation() final { + llvm::outs() << "s TTNNOptimizer::rOO\n"; + // Generate legal OP configuration candidates. // Perform memory layout analysis. // Perform final configuration analysis. @@ -338,6 +340,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase { func.getContext(), funcType.getInputs(), funcResultTypes); func.setType(newFuncType); }); + llvm::outs() << "e TTNNOptimizer::rOO\n"; } private: diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index e7aa57d94..3a02c1e8d 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -64,6 +64,7 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase { } void runOnOperation() final { + llvm::outs() << "s TTNNDeallocate::rOO\n"; ModuleOp module = getOperation(); IRRewriter rewriter(&getContext()); @@ -115,6 +116,7 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase { } }); }); + llvm::outs() << "e TTNNDeallocate::rOO\n"; } }; @@ -126,6 +128,8 @@ class TTNNDecomposeLayouts TTNNDecomposeLayouts>::TTNNDecomposeLayoutsBase; void runOnOperation() final { + llvm::outs() << "s TTNNDecomposeLayouts::rOO\n"; + ModuleOp module = getOperation(); IRRewriter rewriter(&getContext()); llvm::SmallVector opsToReplace; @@ -147,6 +151,7 @@ class TTNNDecomposeLayouts rewriter); rewriter.eraseOp(op); } + llvm::outs() << "e TTNNDecomposeLayouts::rOO\n"; } private: @@ -906,6 +911,8 @@ class TTNNCreateInputGenerators TTNNCreateInputGenerators>::TTNNCreateInputGeneratorsBase; void runOnOperation() final { + llvm::outs() << "s TTNNCreateInputGenerators::rOO\n"; + ModuleOp module = getOperation(); IRRewriter rewriter(&getContext()); @@ -1074,6 +1081,8 @@ class TTNNCreateInputGenerators rewriter.getI32IntegerAttr(0)); rewriter.create(mainFuncOp->getLoc(), constantZero); } + + llvm::outs() << "e TTNNCreateInputGenerators::rOO\n"; } }; @@ -1086,6 +1095,8 @@ class TTNNModifySignaturesForDylib TTNNModifySignaturesForDylib>::TTNNModifySignaturesForDylibBase; void runOnOperation() final { + llvm::outs() << "s TTNNModifySignaturesForDylib::rOO\n"; + ModuleOp module = getOperation(); IRRewriter rewriter(&getContext()); @@ -1166,6 +1177,7 @@ class TTNNModifySignaturesForDylib // entryBlock.eraseArguments(1, originalFuncType.getInputs().size()); } + llvm::outs() << "e TTNNModifySignaturesForDylib::rOO\n"; } }; diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index c32efa2d3..df958b2c4 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -279,8 +279,12 @@ class TTNNLayoutDPSOperandsRewriter LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const final { + llvm::outs() << "s TTNNLayoutDPSOperandsRewriter::mAR\n"; + op.dump(); + // To layout op is a special case, we don't want to rewrite it if (mlir::isa(op.getOperation())) { + llvm::outs() << "e -- f TTNNLayoutDPSOperandsRewriter::mAR\n"; return failure(); } @@ -330,6 +334,8 @@ class TTNNLayoutDPSOperandsRewriter }); } } + llvm::outs() << "e -- " << (modified ? "s" : "f") + << " TTNNLayoutDPSOperandsRewriter::mAR\n"; return modified ? success() : failure(); } @@ -341,61 +347,73 @@ class TTNNLayoutHoistedFuncCallRewriter TTNNLayoutHoistedFuncCallRewriter(MLIRContext *ctx) : OpRewritePattern(ctx) {} - // Match and rewrite the CallOp + // Match and rewrite the CallOp. LogicalResult matchAndRewrite(func::CallOp callOp, PatternRewriter &rewriter) const override { + llvm::outs() << "s TTNNLayoutHoistedFuncCallRewriter::mAR\n"; + auto device = utils::getOrInsertDevice(rewriter, callOp); llvm::outs() << "0\n"; llvm::outs() << callOp->getName() << "\n"; if (!callOp->hasAttr("hoisted_call")) { + llvm::outs() << "e -- f TTNNLayoutHoistedFuncCallRewriter::mAR\n"; return failure(); } llvm::outs() << "1\n"; - // Create a FromDevice operation for each operand + // Create a FromDevice operation for each operand. SmallVector fromDeviceOperands; + size_t locIdx = 0; for (auto operand : callOp.getOperands()) { - // Insert FromDevice op before the operand and collect the new operands + // Insert FromDevice op before the operand and collect the new operands. auto fromDeviceOp = rewriter.create( callOp.getLoc(), operand.getType(), operand); fromDeviceOp.dump(); - fromDeviceOperands.push_back(fromDeviceOp.getResult()); + Location newLoc = appendInputSuffix(callOp.getLoc(), locIdx++); + std::optional maybeLayoutOp = createToLayoutOp( + rewriter, newLoc, operand, BufferType::SystemMemory, + nullptr /* tensorMemoryLayoutAttr */, false /* tiled */); + Value hostOpValue = maybeLayoutOp.has_value() ? maybeLayoutOp.value() + : fromDeviceOp.getResult(); + fromDeviceOperands.push_back(hostOpValue); } llvm::outs() << "2\n"; - // Create the original CallOp with the new operands (FromDevice'd) - rewriter.create(callOp.getLoc(), callOp.getCallee(), - callOp.getResultTypes(), fromDeviceOperands); + // Create the original CallOp with the new operands (FromDevice'd). + auto newCallOp = rewriter.create( + callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(), + fromDeviceOperands); llvm::outs() << "3\n"; // Now, insert ToDevice ops for the results of the CallOp SmallVector toDeviceResults; - for (auto result : callOp.getResults()) { + for (auto result : newCallOp.getResults()) { // Insert ToDevice op after the result auto toDeviceOp = rewriter.create( callOp.getLoc(), result.getType(), result, device, ttnn::MemoryConfigAttr{}); - toDeviceResults.push_back(toDeviceOp.getResult()); + Location newLoc = + appendInputSuffix(callOp.getLoc(), result.getResultNumber() + locIdx); + std::optional maybeLayoutOp = createToLayoutOp( + rewriter, newLoc, result, BufferType::SystemMemory, + nullptr /* tensorMemoryLayoutAttr */, true /* tiled */); + Value deviceResultValue = maybeLayoutOp.has_value() + ? maybeLayoutOp.value() + : toDeviceOp.getResult(); + toDeviceResults.push_back(deviceResultValue); } llvm::outs() << "4\n"; - // Replace the original call with a dummy op (since it has been rewritten) + // Replace the original call with the new ToDevice results. rewriter.replaceOp(callOp, toDeviceResults); llvm::outs() << "5\n"; - // Replace the original operands with the FromDevice'd operands - for (size_t i = 0; i < callOp.getOperands().size(); ++i) { - // Replace the uses of the original operand with the new FromDevice - // operand - callOp.getOperand(i).replaceAllUsesWith(fromDeviceOperands[i]); - } - - llvm::outs() << "6\n"; + llvm::outs() << "e -- s TTNNLayoutHoistedFuncCallRewriter::mAR\n"; return success(); } @@ -411,6 +429,8 @@ class TTNNLayoutFuncReturnRewriter LogicalResult matchAndRewrite(mlir::func::ReturnOp op, PatternRewriter &rewriter) const final { + llvm::outs() << "s TTNNLayoutFuncReturnRewriter::mAR\n"; + bool modified = false; for (OpOperand &operand : op->getOpOperands()) { Location newLoc = @@ -424,6 +444,15 @@ class TTNNLayoutFuncReturnRewriter modified = true; } } + llvm::outs() << "e -- "; + if (modified) { + llvm::outs() << "s"; + op.dump(); + } else { + llvm::outs() << "f"; + } + llvm::outs() << " TTNNLayoutFuncReturnRewriter::mAR\n"; + return modified ? success() : failure(); } @@ -472,6 +501,9 @@ class TTNNLayout : public impl::TTNNLayoutBase { signalPassFailure(); return; } + llvm::outs() << "IR dump:\n"; + getOperation()->dump(); + llvm::outs() << "e TTNNLayout::rOO\n"; } } diff --git a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp index 2c0c48dbc..1e6c4dd68 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp @@ -398,6 +398,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { using impl::TTNNWorkaroundsBase::TTNNWorkaroundsBase; void runOnOperation() final { + llvm::outs() << "s TTNNWorkarounds::rOO\n"; if (decompositionWorkaroundsEnabled) { // Placeholder for workaround decomposition patterns. RewritePatternSet patterns(&getContext()); @@ -438,6 +439,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase { return; } } + llvm::outs() << "e TTNNWorkarounds::rOO\n"; } }; } // namespace mlir::tt::ttnn