Skip to content

Commit

Permalink
add some silly debugs
Browse files Browse the repository at this point in the history
  • Loading branch information
vwellsTT committed Dec 27, 2024
1 parent 6e0fdaa commit 0bf05d9
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
3 changes: 3 additions & 0 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
public:
using impl::TTNNOptimizerBase<TTNNOptimizer>::TTNNOptimizerBase;
void runOnOperation() final {
llvm::outs() << "s TTNNOptimizer::rOO\n";

// Generate legal OP configuration candidates.
// Perform memory layout analysis.
// Perform final configuration analysis.
Expand Down Expand Up @@ -338,6 +340,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
func.getContext(), funcType.getInputs(), funcResultTypes);
func.setType(newFuncType);
});
llvm::outs() << "e TTNNOptimizer::rOO\n";
}

private:
Expand Down
12 changes: 12 additions & 0 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
}

void runOnOperation() final {
llvm::outs() << "s TTNNDeallocate::rOO\n";
ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

Expand Down Expand Up @@ -115,6 +116,7 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
}
});
});
llvm::outs() << "e TTNNDeallocate::rOO\n";
}
};

Expand All @@ -126,6 +128,8 @@ class TTNNDecomposeLayouts
TTNNDecomposeLayouts>::TTNNDecomposeLayoutsBase;

void runOnOperation() final {
llvm::outs() << "s TTNNDecomposeLayouts::rOO\n";

ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());
llvm::SmallVector<Operation *> opsToReplace;
Expand All @@ -147,6 +151,7 @@ class TTNNDecomposeLayouts
rewriter);
rewriter.eraseOp(op);
}
llvm::outs() << "e TTNNDecomposeLayouts::rOO\n";
}

private:
Expand Down Expand Up @@ -906,6 +911,8 @@ class TTNNCreateInputGenerators
TTNNCreateInputGenerators>::TTNNCreateInputGeneratorsBase;

void runOnOperation() final {
llvm::outs() << "s TTNNCreateInputGenerators::rOO\n";

ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

Expand Down Expand Up @@ -1074,6 +1081,8 @@ class TTNNCreateInputGenerators
rewriter.getI32IntegerAttr(0));
rewriter.create<func::ReturnOp>(mainFuncOp->getLoc(), constantZero);
}

llvm::outs() << "e TTNNCreateInputGenerators::rOO\n";
}
};

Expand All @@ -1086,6 +1095,8 @@ class TTNNModifySignaturesForDylib
TTNNModifySignaturesForDylib>::TTNNModifySignaturesForDylibBase;

void runOnOperation() final {
llvm::outs() << "s TTNNModifySignaturesForDylib::rOO\n";

ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

Expand Down Expand Up @@ -1166,6 +1177,7 @@ class TTNNModifySignaturesForDylib
//
entryBlock.eraseArguments(1, originalFuncType.getInputs().size());
}
llvm::outs() << "e TTNNModifySignaturesForDylib::rOO\n";
}
};

Expand Down
68 changes: 50 additions & 18 deletions lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,12 @@ class TTNNLayoutDPSOperandsRewriter

LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const final {
llvm::outs() << "s TTNNLayoutDPSOperandsRewriter::mAR\n";
op.dump();

// To layout op is a special case, we don't want to rewrite it
if (mlir::isa<ttir::ToLayoutOp>(op.getOperation())) {
llvm::outs() << "e -- f TTNNLayoutDPSOperandsRewriter::mAR\n";
return failure();
}

Expand Down Expand Up @@ -330,6 +334,8 @@ class TTNNLayoutDPSOperandsRewriter
});
}
}
llvm::outs() << "e -- " << (modified ? "s" : "f")
<< " TTNNLayoutDPSOperandsRewriter::mAR\n";

return modified ? success() : failure();
}
Expand All @@ -341,61 +347,73 @@ class TTNNLayoutHoistedFuncCallRewriter
TTNNLayoutHoistedFuncCallRewriter(MLIRContext *ctx)
: OpRewritePattern<func::CallOp>(ctx) {}

// Match and rewrite the CallOp
// Match and rewrite the CallOp.
LogicalResult matchAndRewrite(func::CallOp callOp,
PatternRewriter &rewriter) const override {
llvm::outs() << "s TTNNLayoutHoistedFuncCallRewriter::mAR\n";

auto device = utils::getOrInsertDevice(rewriter, callOp);
llvm::outs() << "0\n";

llvm::outs() << callOp->getName() << "\n";
if (!callOp->hasAttr("hoisted_call")) {
llvm::outs() << "e -- f TTNNLayoutHoistedFuncCallRewriter::mAR\n";
return failure();
}

llvm::outs() << "1\n";

// Create a FromDevice operation for each operand
// Create a FromDevice operation for each operand.
SmallVector<Value, 4> fromDeviceOperands;
size_t locIdx = 0;
for (auto operand : callOp.getOperands()) {
// Insert FromDevice op before the operand and collect the new operands
// Insert FromDevice op before the operand and collect the new operands.
auto fromDeviceOp = rewriter.create<ttnn::FromDeviceOp>(
callOp.getLoc(), operand.getType(), operand);
fromDeviceOp.dump();
fromDeviceOperands.push_back(fromDeviceOp.getResult());
Location newLoc = appendInputSuffix(callOp.getLoc(), locIdx++);
std::optional<Value> maybeLayoutOp = createToLayoutOp(
rewriter, newLoc, operand, BufferType::SystemMemory,
nullptr /* tensorMemoryLayoutAttr */, false /* tiled */);
Value hostOpValue = maybeLayoutOp.has_value() ? maybeLayoutOp.value()
: fromDeviceOp.getResult();
fromDeviceOperands.push_back(hostOpValue);
}

llvm::outs() << "2\n";

// Create the original CallOp with the new operands (FromDevice'd)
rewriter.create<func::CallOp>(callOp.getLoc(), callOp.getCallee(),
callOp.getResultTypes(), fromDeviceOperands);
// Create the original CallOp with the new operands (FromDevice'd).
auto newCallOp = rewriter.create<func::CallOp>(
callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
fromDeviceOperands);

llvm::outs() << "3\n";

// Now, insert ToDevice ops for the results of the CallOp
SmallVector<Value, 4> toDeviceResults;
for (auto result : callOp.getResults()) {
for (auto result : newCallOp.getResults()) {
// Insert ToDevice op after the result
auto toDeviceOp = rewriter.create<ttnn::ToDeviceOp>(
callOp.getLoc(), result.getType(), result, device,
ttnn::MemoryConfigAttr{});
toDeviceResults.push_back(toDeviceOp.getResult());
Location newLoc =
appendInputSuffix(callOp.getLoc(), result.getResultNumber() + locIdx);
std::optional<Value> maybeLayoutOp = createToLayoutOp(
rewriter, newLoc, result, BufferType::SystemMemory,
nullptr /* tensorMemoryLayoutAttr */, true /* tiled */);
Value deviceResultValue = maybeLayoutOp.has_value()
? maybeLayoutOp.value()
: toDeviceOp.getResult();
toDeviceResults.push_back(deviceResultValue);
}
llvm::outs() << "4\n";

// Replace the original call with a dummy op (since it has been rewritten)
// Replace the original call with the new ToDevice results.
rewriter.replaceOp(callOp, toDeviceResults);

llvm::outs() << "5\n";

// Replace the original operands with the FromDevice'd operands
for (size_t i = 0; i < callOp.getOperands().size(); ++i) {
// Replace the uses of the original operand with the new FromDevice
// operand
callOp.getOperand(i).replaceAllUsesWith(fromDeviceOperands[i]);
}

llvm::outs() << "6\n";
llvm::outs() << "e -- s TTNNLayoutHoistedFuncCallRewriter::mAR\n";

return success();
}
Expand All @@ -411,6 +429,8 @@ class TTNNLayoutFuncReturnRewriter

LogicalResult matchAndRewrite(mlir::func::ReturnOp op,
PatternRewriter &rewriter) const final {
llvm::outs() << "s TTNNLayoutFuncReturnRewriter::mAR\n";

bool modified = false;
for (OpOperand &operand : op->getOpOperands()) {
Location newLoc =
Expand All @@ -424,6 +444,15 @@ class TTNNLayoutFuncReturnRewriter
modified = true;
}
}
llvm::outs() << "e -- ";
if (modified) {
llvm::outs() << "s";
op.dump();
} else {
llvm::outs() << "f";
}
llvm::outs() << " TTNNLayoutFuncReturnRewriter::mAR\n";

return modified ? success() : failure();
}

Expand Down Expand Up @@ -472,6 +501,9 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
signalPassFailure();
return;
}
llvm::outs() << "IR dump:\n";
getOperation()->dump();
llvm::outs() << "e TTNNLayout::rOO\n";
}
}

Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
using impl::TTNNWorkaroundsBase<TTNNWorkarounds>::TTNNWorkaroundsBase;

void runOnOperation() final {
llvm::outs() << "s TTNNWorkarounds::rOO\n";
if (decompositionWorkaroundsEnabled) {
// Placeholder for workaround decomposition patterns.
RewritePatternSet patterns(&getContext());
Expand Down Expand Up @@ -438,6 +439,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
return;
}
}
llvm::outs() << "e TTNNWorkarounds::rOO\n";
}
};
} // namespace mlir::tt::ttnn

0 comments on commit 0bf05d9

Please sign in to comment.