
Add TTNN transform to ensure hoisted ops have operands moved to/from CPU properly #1649

Closed

wants to merge 3 commits
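At a high level, the new TTNNLayoutHoistedFuncCallRewriter pattern (added to TTNNLayout.cpp below) rewrites func.call ops carrying the hoisted_call attribute so device tensors are moved to host before the call and back to device afterwards. A minimal before/after sketch in MLIR-style pseudo-IR (function name, value names, and tensor types here are illustrative assumptions, not taken from this PR):

// before:
%r = func.call @hoisted_fn(%dev) {hoisted_call} : (tensor<32x32xf32>) -> tensor<32x32xf32>

// after (sketch):
%host = "ttnn.from_device"(%dev) : (tensor<32x32xf32>) -> tensor<32x32xf32>
%host_rm = <to_layout: system memory, untiled>   // via createToLayoutOp(..., false /* tiled */)
%r = func.call @hoisted_fn(%host_rm) : (tensor<32x32xf32>) -> tensor<32x32xf32>
%dev_r = "ttnn.to_device"(%r, %device) ...       // result moved back, tiled to_layout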
2 changes: 1 addition & 1 deletion env/CMakeLists.txt
@@ -57,7 +57,7 @@ ExternalProject_Add(
-DMLIR_LINK_MLIR_DYLIB=OFF
-DMLIR_BUILD_MLIR_C_DYLIB=ON
# ======================
-DCMAKE_BUILD_TYPE=MinSizeRel
-DCMAKE_BUILD_TYPE=Debug
-DLLVM_ENABLE_ASSERTIONS=ON
-DMLIR_ENABLE_BINDINGS_PYTHON=ON
-DCMAKE_C_FLAGS=-D_LIBCPP_HAS_NO_LIBRARY_ALIGNED_ALLOCATION
1 change: 1 addition & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"
#include "ttmlir/Target/Common/Target.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
3 changes: 3 additions & 0 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
@@ -146,6 +146,8 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
public:
using impl::TTNNOptimizerBase<TTNNOptimizer>::TTNNOptimizerBase;
void runOnOperation() final {
llvm::outs() << "s TTNNOptimizer::rOO\n";

// Generate legal OP configuration candidates.
// Perform memory layout analysis.
// Perform final configuration analysis.
@@ -338,6 +340,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
func.getContext(), funcType.getInputs(), funcResultTypes);
func.setType(newFuncType);
});
llvm::outs() << "e TTNNOptimizer::rOO\n";
}

private:
24 changes: 22 additions & 2 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -64,11 +64,16 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
}

void runOnOperation() final {
llvm::outs() << "s TTNNDeallocate::rOO\n";
ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

module->walk([&](func::FuncOp func) {
assert(func.getBody().hasOneBlock());
if (func.isDeclaration()) {
return;
}
assert(func.getBody().hasOneBlock() &&
"found func that didn't have one block!");
Liveness liveness(func.getOperation());
const LivenessBlockInfo *livenessInfo =
liveness.getLiveness(&func.getBody().front());
@@ -111,6 +116,7 @@ class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
}
});
});
llvm::outs() << "e TTNNDeallocate::rOO\n";
}
};

@@ -122,11 +128,17 @@ class TTNNDecomposeLayouts
TTNNDecomposeLayouts>::TTNNDecomposeLayoutsBase;

void runOnOperation() final {
llvm::outs() << "s TTNNDecomposeLayouts::rOO\n";

ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());
llvm::SmallVector<Operation *> opsToReplace;
module->walk([&](func::FuncOp func) {
assert(func.getBody().hasOneBlock());
if (func.isDeclaration()) {
return;
}
assert(func.getBody().hasOneBlock() &&
"found func that didn't have one block!");
func->walk([&](Operation *op) {
if (!isa<ttnn::ToLayoutOp>(op)) {
return;
@@ -139,6 +151,7 @@ class TTNNDecomposeLayouts
rewriter);
rewriter.eraseOp(op);
}
llvm::outs() << "e TTNNDecomposeLayouts::rOO\n";
}

private:
@@ -898,6 +911,8 @@ class TTNNCreateInputGenerators
TTNNCreateInputGenerators>::TTNNCreateInputGeneratorsBase;

void runOnOperation() final {
llvm::outs() << "s TTNNCreateInputGenerators::rOO\n";

ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

@@ -1066,6 +1081,8 @@ class TTNNCreateInputGenerators
rewriter.getI32IntegerAttr(0));
rewriter.create<func::ReturnOp>(mainFuncOp->getLoc(), constantZero);
}

llvm::outs() << "e TTNNCreateInputGenerators::rOO\n";
}
};

@@ -1078,6 +1095,8 @@ class TTNNModifySignaturesForDylib
TTNNModifySignaturesForDylib>::TTNNModifySignaturesForDylibBase;

void runOnOperation() final {
llvm::outs() << "s TTNNModifySignaturesForDylib::rOO\n";

ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

@@ -1158,6 +1177,7 @@ class TTNNModifySignaturesForDylib
//
entryBlock.eraseArguments(1, originalFuncType.getInputs().size());
}
llvm::outs() << "e TTNNModifySignaturesForDylib::rOO\n";
}
};

114 changes: 112 additions & 2 deletions lib/Dialect/TTNN/Transforms/TTNNLayout.cpp
@@ -5,12 +5,12 @@
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNLAYOUT
@@ -51,7 +51,7 @@ inline Location appendInputSuffix(Location loc, int64_t operandIndex) {
// To layout pass
//===----------------------------------------------------------------------===//

// Converts tensor types to have a ttnn layout attribute with default values
// Converts tensor types to have a ttnn layout attribute with deault values
//
// Example: tensor<15x10x32xf32> -> tensor<15x10x32xf32, ttnn_layout<...>>
// where ttnn_layout<...> is constructed with default values
@@ -127,7 +127,12 @@ class TTNNLayoutTensorTypeRewriter : public RewritePattern {
}
funcOp.setFunctionType(newType);

if (funcOp.isDeclaration()) {
return true;
}

Block &entryBlock = funcOp.getBody().front();

for (unsigned i = 0; i < entryBlock.getNumArguments(); ++i) {
entryBlock.getArgument(i).setType(inputTypes[i]);
}
@@ -143,6 +148,10 @@ class TTNNLayoutTensorTypeRewriter : public RewritePattern {
updated |= convertTypes(op->getOperands(), operands);
updated |= convertTypes(op->getResults(), results);
updated |= convertFuncType(op, rewriter);
if (updated) {
llvm::outs() << "TTNNLayoutTensorTypeRewriter::mAR succeeded on op:\n";
op->dump();
}
return updated ? success() : failure();
}

@@ -274,8 +283,12 @@ class TTNNLayoutDPSOperandsRewriter

LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const final {
llvm::outs() << "s TTNNLayoutDPSOperandsRewriter::mAR\n";
op.dump();

// ToLayoutOp is a special case; we don't want to rewrite it
if (mlir::isa<ttir::ToLayoutOp>(op.getOperation())) {
llvm::outs() << "e -- f TTNNLayoutDPSOperandsRewriter::mAR\n";
return failure();
}

@@ -325,11 +338,91 @@ class TTNNLayoutDPSOperandsRewriter
});
}
}
llvm::outs() << "e -- " << (modified ? "s" : "f")
<< " TTNNLayoutDPSOperandsRewriter::mAR\n";

return modified ? success() : failure();
}
};

class TTNNLayoutHoistedFuncCallRewriter
: public OpRewritePattern<func::CallOp> {
public:
TTNNLayoutHoistedFuncCallRewriter(MLIRContext *ctx)
: OpRewritePattern<func::CallOp>(ctx) {}

// Match and rewrite the CallOp.
LogicalResult matchAndRewrite(func::CallOp callOp,
PatternRewriter &rewriter) const override {
llvm::outs() << "s TTNNLayoutHoistedFuncCallRewriter::mAR\n";

auto device = utils::getOrInsertDevice(rewriter, callOp);
llvm::outs() << "0\n";

llvm::outs() << callOp->getName() << "\n";
if (!callOp->hasAttr("hoisted_call")) {
llvm::outs() << "e -- f TTNNLayoutHoistedFuncCallRewriter::mAR\n";
return failure();
}

llvm::outs() << "1\n";

// Create a FromDevice operation for each operand.
SmallVector<Value, 4> fromDeviceOperands;
size_t locIdx = 0;
for (auto operand : callOp.getOperands()) {
// Insert FromDevice op before the operand and collect the new operands.
auto fromDeviceOp = rewriter.create<ttnn::FromDeviceOp>(
callOp.getLoc(), operand.getType(), operand);
fromDeviceOp.dump();
Location newLoc = appendInputSuffix(callOp.getLoc(), locIdx++);
std::optional<Value> maybeLayoutOp = createToLayoutOp(
rewriter, newLoc, operand, BufferType::SystemMemory,
nullptr /* tensorMemoryLayoutAttr */, false /* tiled */);
Value hostOpValue = maybeLayoutOp.has_value() ? maybeLayoutOp.value()
: fromDeviceOp.getResult();
fromDeviceOperands.push_back(hostOpValue);
}

llvm::outs() << "2\n";

// Recreate the original CallOp with the new (FromDevice'd) operands.
auto newCallOp = rewriter.create<func::CallOp>(
callOp.getLoc(), callOp.getCallee(), callOp.getResultTypes(),
fromDeviceOperands);

llvm::outs() << "3\n";

// Now, insert ToDevice ops for the results of the CallOp
SmallVector<Value, 4> toDeviceResults;
for (auto result : newCallOp.getResults()) {
// Insert ToDevice op after the result
auto toDeviceOp = rewriter.create<ttnn::ToDeviceOp>(
callOp.getLoc(), result.getType(), result, device,
ttnn::MemoryConfigAttr{});
Location newLoc =
appendInputSuffix(callOp.getLoc(), result.getResultNumber() + locIdx);
std::optional<Value> maybeLayoutOp = createToLayoutOp(
rewriter, newLoc, result, BufferType::SystemMemory,
nullptr /* tensorMemoryLayoutAttr */, true /* tiled */);
Value deviceResultValue = maybeLayoutOp.has_value()
? maybeLayoutOp.value()
: toDeviceOp.getResult();
toDeviceResults.push_back(deviceResultValue);
}
llvm::outs() << "4\n";

// Replace the original call with the new ToDevice results.
rewriter.replaceOp(callOp, toDeviceResults);

llvm::outs() << "5\n";

llvm::outs() << "e -- s TTNNLayoutHoistedFuncCallRewriter::mAR\n";

return success();
}
};

// Updates the layout of the operands of a func::ReturnOp.
// The intent is to move the result to host.
class TTNNLayoutFuncReturnRewriter
@@ -340,6 +433,8 @@ class TTNNLayoutFuncReturnRewriter

LogicalResult matchAndRewrite(mlir::func::ReturnOp op,
PatternRewriter &rewriter) const final {
llvm::outs() << "s TTNNLayoutFuncReturnRewriter::mAR\n";

bool modified = false;
for (OpOperand &operand : op->getOpOperands()) {
Location newLoc =
@@ -353,6 +448,15 @@ class TTNNLayoutFuncReturnRewriter
modified = true;
}
}
llvm::outs() << "e -- ";
if (modified) {
llvm::outs() << "s";
op.dump();
} else {
llvm::outs() << "f";
}
llvm::outs() << " TTNNLayoutFuncReturnRewriter::mAR\n";

return modified ? success() : failure();
}

@@ -390,6 +494,9 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
// Takes func::ReturnOp and sets a layout that will
// move its operands to host
patterns.add<TTNNLayoutFuncReturnRewriter>(&getContext());
// Move operands + results of hoisted funcs to and from device
// appropriately
patterns.add<TTNNLayoutHoistedFuncCallRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
GreedyRewriteConfig config = GreedyRewriteConfig();
config.useTopDownTraversal = true;
@@ -398,6 +505,9 @@ class TTNNLayout : public impl::TTNNLayoutBase<TTNNLayout> {
signalPassFailure();
return;
}
llvm::outs() << "IR dump:\n";
getOperation()->dump();
llvm::outs() << "e TTNNLayout::rOO\n";
}
}

2 changes: 2 additions & 0 deletions lib/Dialect/TTNN/Transforms/TTNNWorkarounds.cpp
@@ -398,6 +398,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
using impl::TTNNWorkaroundsBase<TTNNWorkarounds>::TTNNWorkaroundsBase;

void runOnOperation() final {
llvm::outs() << "s TTNNWorkarounds::rOO\n";
if (decompositionWorkaroundsEnabled) {
// Placeholder for workaround decomposition patterns.
RewritePatternSet patterns(&getContext());
@@ -438,6 +439,7 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
return;
}
}
llvm::outs() << "e TTNNWorkarounds::rOO\n";
}
};
} // namespace mlir::tt::ttnn