Add hoist mechanism for specific TTIR ops, with placeholder analysis pass (#1586)

### Goal:
The end-to-end goal is to integrate a path to compile and execute specific ops or sets of ops on the CPU.

### Context:
The entire task will be split into (tentatively) 7 PRs, as follows:
1. Hoist specific ops into isolated funcs in a separate module
2. Convert TTIR ops to linalg ops within the module of hoisted funcs
3. **Build a pipeline to lower linalg to llvm from existing conversion passes**
4. Translate LLVM Dialect into a dynamic library for packing into flatbuffer
5. Generate helper functions so that we can call all of our hoisted funcs with a common signature
6. Insert TTNN instructions to move operands to host before executing hoisted func, then back to device afterwards
7. Update ttir-to-ttnn and ttnn-to-flatbuffer pipelines to use new passes, generate dylibs, and embed them into output flatbuffers, and update runtime to consume dylibs from flatbuffers

This PR implements the 1st point above. Here, we build a (placeholder) hoisting analysis + transform pass that marks specific ops to be hoisted, pulls them into separate functions in a new "cpu" module, and replaces the original op with a call to that func.

## Example:
input:
```
module {
  func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
    %0 = tensor.empty() : tensor<32x32xbf16>
    %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> loc("add_op1")
    return %1 : tensor<32x32xbf16>
  }
}
```
output:
```
module {
  func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
    %0 = tensor.empty() : tensor<32x32xbf16>
    %1 = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
    return %1 : tensor<32x32xbf16>
  }
  module @cpu_module attributes {ttir.cpu_module} {
    func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> attributes {arg_ranks = [2, 2, 2, 2]} {
      %0 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
      return %0 : tensor<32x32xbf16>
    }
  }
  func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
}
```
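For quick reference, the transform is exposed as `--ttir-cpu-hoist-transform`, and the placeholder analysis simply collects ops carrying a `should_hoist` attribute. A minimal way to tag an op and exercise the pass, mirroring the lit test added in this commit:
```
// Run with: ttmlir-opt --ttir-cpu-hoist-transform input.mlir
func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
  %0 = tensor.empty() : tensor<32x32xbf16>
  // {should_hoist} is the manual tag the placeholder analysis currently looks for.
  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
  return %1 : tensor<32x32xbf16>
}
```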
Showing 6 changed files with 305 additions and 0 deletions.
@@ -0,0 +1,226 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRHOISTTRANSFORM
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Hoist CPU ops to standalone funcs pass
//===----------------------------------------------------------------------===//

// Helper function to get ranks of an op's operands; we use this to populate
// attrs which we need for tensor unpacking operations later.
static llvm::SmallVector<int64_t, 4>
getOperandTensorRanks(mlir::Operation *op) {
  llvm::SmallVector<int64_t, 4> ranks;

  // Iterate over operands (inputs)
  for (auto operand : op->getOperands()) {
    // Check if the operand is a tensor
    if (auto tensorType = dyn_cast<mlir::RankedTensorType>(operand.getType())) {
      // Add the rank of the tensor (number of dimensions)
      ranks.push_back(tensorType.getRank());
    }
  }

  return ranks;
}

// Generate a unique name based on operation type + argument tensors' dims & types.
static llvm::SmallString<16> generateHoistedFuncName(mlir::Operation *op) {
  // Start building the unique function name
  llvm::SmallString<16> uniqueName("hoisted_");
  uniqueName.append(op->getName().getStringRef().begin(),
                    op->getName().getStringRef().end());

  // Iterate over operands to extract tensor shapes and types
  for (auto operand : op->getOperands()) {
    auto rankedTensorType = dyn_cast<mlir::RankedTensorType>(operand.getType());
    if (rankedTensorType) {
      // Append the shape (dimensions) and the element type
      llvm::SmallString<5> shapeStr("_");
      for (auto dim : rankedTensorType.getShape()) {
        shapeStr += std::to_string(dim) + "x";
      }

      // Append the element type (e.g., f32, i32) -- unfortunately I don't
      // think there's a better way to get a string from mlir::Type
      std::string elementTypeStr;
      llvm::raw_string_ostream stream(elementTypeStr);
      rankedTensorType.getElementType().print(stream);

      uniqueName.append(shapeStr.begin(), shapeStr.end());
      uniqueName.append(elementTypeStr.begin(), elementTypeStr.end());
    }
  }

  // Add suffix to indicate it's a function
  uniqueName += "_func";

  return uniqueName;
}

// Helper function to hoist an arbitrary op into a new function in targetModule,
// generate a matching extern prototype in the sourceModule, and replace the
// original op with a callOp to the extern function.
static void hoistOperationToFunction(mlir::Operation *opToHoist,
                                     mlir::ModuleOp sourceModule,
                                     mlir::ModuleOp targetModule) {
  const llvm::SmallVector<int64_t, 4> ranks = getOperandTensorRanks(opToHoist);

  const llvm::SmallString<16> functionName = generateHoistedFuncName(opToHoist);

  auto localFunc = llvm::dyn_cast_or_null<func::FuncOp>(
      sourceModule.lookupSymbol(functionName));

  // Create a new hoisted function only if an equivalent one does not exist.
  if (localFunc == nullptr) {
    mlir::MLIRContext *context = sourceModule.getContext();

    // Gather operand and result types.
    llvm::SmallVector<mlir::Type> operandTypes, resultTypes;
    for (auto operand : opToHoist->getOperands()) {
      operandTypes.push_back(operand.getType());
    }
    for (auto result : opToHoist->getResultTypes()) {
      resultTypes.push_back(result);
    }

    // Create the function signature.
    mlir::FunctionType funcType =
        mlir::FunctionType::get(context, operandTypes, resultTypes);

    // Create the function in the target module.
    auto hoistedFunc =
        func::FuncOp::create(opToHoist->getLoc(), functionName, funcType);
    targetModule.push_back(hoistedFunc);

    // Add a basic block to the function.
    mlir::Block *block = hoistedFunc.addEntryBlock();
    mlir::OpBuilder builder(block, block->end());

    // Map operands to block arguments and clone the operation.
    llvm::SmallVector<mlir::Value> newOperands;
    for (auto operand : llvm::enumerate(opToHoist->getOperands())) {
      newOperands.push_back(block->getArgument(operand.index()));
    }

    mlir::IRMapping mapping;
    for (auto operand : llvm::zip(opToHoist->getOperands(), newOperands)) {
      mapping.map(std::get<0>(operand), std::get<1>(operand));
    }

    mlir::Operation *clonedOp = builder.clone(*opToHoist, mapping);

    // Add a return operation to the function.
    builder.create<mlir::func::ReturnOp>(opToHoist->getLoc(),
                                         clonedOp->getResults());

    // Declare the function prototype in the source module.
    localFunc =
        func::FuncOp::create(opToHoist->getLoc(), functionName, funcType);
    localFunc.setPrivate();
    sourceModule.push_back(localFunc);

    hoistedFunc->setAttr("arg_ranks", builder.getI64ArrayAttr(ranks));
  }

  // Replace the original operation with a call to the hoisted function.
  mlir::OpBuilder opBuilder(opToHoist);
  auto callOp = opBuilder.create<mlir::func::CallOp>(
      opToHoist->getLoc(), localFunc, opToHoist->getOperands());

  // Replace all results of the original operation with the call results.
  opToHoist->replaceAllUsesWith(callOp);

  // Erase the original operation.
  opToHoist->erase();
}

// An analysis class which currently relies on manually tagging ops with a
// `should_hoist` attribute, but in the future will also tag fall-back ops, etc.
class TTIRHoistAnalyze {
public:
  using HoistOpSet = llvm::SmallVector<llvm::SmallSet<mlir::Operation *, 4>>;

  TTIRHoistAnalyze(mlir::ModuleOp moduleOp) {
    moduleOp.walk([&](mlir::Operation *nestedOp) {
      if (nestedOp->hasAttr("should_hoist")) {
        llvm::SmallSet<mlir::Operation *, 4> opSet;
        opSet.insert(nestedOp);
        hoistedOps.push_back(opSet);
      }
    });
  }

  HoistOpSet getResults() { return hoistedOps; }

private:
  HoistOpSet hoistedOps;
};

// Transform pass to hoist specific ops (based on configured analysis pass) into
// a cpu submodule for later independent lowering.
class TTIRHoistTransform
    : public impl::TTIRHoistTransformBase<TTIRHoistTransform> {
public:
  using impl::TTIRHoistTransformBase<
      TTIRHoistTransform>::TTIRHoistTransformBase;

  void runOnOperation() final {
    mlir::ModuleOp moduleOp = getOperation();

    IRRewriter rewriter(&getContext());

    auto loc = moduleOp->getLoc();

    TTIRHoistAnalyze analysisPass(moduleOp);
    const TTIRHoistAnalyze::HoistOpSet &hoistOpSets = analysisPass.getResults();

    // Check if a "cpu_module" already exists.
    mlir::ModuleOp cpuModule;
    for (auto &op : moduleOp.getBody()->getOperations()) {
      if (auto module = llvm::dyn_cast<mlir::ModuleOp>(op)) {
        if (module->hasAttr("ttir.cpu_module")) {
          cpuModule = module;
          break;
        }
      }
    }

    // If no CPU module exists, create one.
    if (!cpuModule) {
      rewriter.setInsertionPointToEnd(moduleOp.getBody());
      cpuModule = rewriter.create<mlir::ModuleOp>(loc);
      cpuModule->setAttr("ttir.cpu_module", rewriter.getUnitAttr());
      cpuModule->setAttr(mlir::SymbolTable::getSymbolAttrName(),
                         rewriter.getStringAttr("cpu_module"));
      // try to make cpu module global
      mlir::SymbolTable::setSymbolVisibility(
          cpuModule, mlir::SymbolTable::Visibility::Public);
    }

    for (const auto &opSet : hoistOpSets) {
      assert(opSet.size() == 1 &&
             "currently don't support hoisting multiple instructions at once!");
      hoistOperationToFunction(*opSet.begin(), moduleOp, cpuModule);
    }
  }
};

} // namespace mlir::tt::ttir
@@ -0,0 +1,37 @@
// RUN: ttmlir-opt --ttir-cpu-hoist-transform %s | FileCheck %s

func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
  %0 = tensor.empty() : tensor<32x32xbf16>
  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
  return %1 : tensor<32x32xbf16>
}

func.func @add2(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
  %0 = tensor.empty() : tensor<32x32xf32>
  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
  return %1 : tensor<32x32xf32>
}

func.func @add3(%arg0: tensor<32x3xf32>, %arg1: tensor<32x3xf32>) -> tensor<32x3xf32> {
  %0 = tensor.empty() : tensor<32x3xf32>
  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func
  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x3xf32>, tensor<32x3xf32>, tensor<32x3xf32>) -> tensor<32x3xf32>
  return %1 : tensor<32x3xf32>
}

func.func @add4(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
  %0 = tensor.empty() : tensor<32x32xbf16>
  // CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
  %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
  return %1 : tensor<32x32xbf16>
}
// CHECK: module @cpu_module attributes {ttir.cpu_module}
// CHECK: func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
// CHECK: func.func @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
// CHECK: func.func @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func

// CHECK: func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
// CHECK: func.func private @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
// CHECK: func.func private @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func