Add hoist mechanism for specific TTIR ops, with placeholder analysis pass (#1586)

### Goal: Integrate an end-to-end path to compile and execute specific ops or sets of ops on the CPU.

### Context:
The entire task will be split into (tentatively) 7 PRs, as follows:
1. Hoist specific ops into isolated funcs in a separate module
2. Convert TTIR ops to linalg ops within the module of hoisted funcs
3. **Build a pipeline to lower linalg to llvm from existing conversion
passes**
4. Translate LLVM Dialect into a dynamic library for packing into
flatbuffer
5. Generate helper functions so that we can call all of our hoisted
funcs with a common signature
6. Insert TTNN instructions to move operands to host before executing
hoisted func, then back to device afterwards
7. Update ttir-to-ttnn and ttnn-to-flatbuffer pipelines to use the new
passes, generate dylibs and embed them into output flatbuffers, and
update the runtime to consume dylibs from flatbuffers

This PR implements the 1st point above. Here, we build a (placeholder)
hoisting analysis plus a transform pass which marks specific ops to be
hoisted, pulls them into separate functions in a new "cpu" module, and
replaces each original op with a call to its hoisted func.


## Example:
    input:
```
    module {
      func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
        %0 = tensor.empty() : tensor<32x32xbf16>
        %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> loc("add_op1")
        return %1 : tensor<32x32xbf16>
      }
    }
```

    output:
```
    module {
      func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
        %0 = tensor.empty() : tensor<32x32xbf16>
        %1 = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
        return %1 : tensor<32x32xbf16>
      }
      module @cpu_module attributes {ttir.cpu_module} {
        func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> attributes {arg_ranks = [2, 2, 2, 2]} {
          %0 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
          return %0 : tensor<32x32xbf16>
        }
      }
      func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
    }
```
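
For reference, the pass is invoked as `ttmlir-opt --ttir-cpu-hoist-transform` (see the lit test below). The placeholder analysis in this PR selects ops that carry a `should_hoist` attribute; the `hoist-locations` option declared in Passes.td is the hook for the named-loc manual path. A minimal way to mark an op for hoisting, taken from the test:
```
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
```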
vwellsTT authored Dec 20, 2024
1 parent 2ed01e6 commit 432f8d8
Showing 6 changed files with 305 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.h
@@ -10,6 +10,7 @@

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallString.h"

namespace mlir::tt::ttir {
#define GEN_PASS_DECL
39 changes: 39 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -125,4 +125,43 @@ def TTIRBroadcastFold: Pass<"ttir-broadcast-fold", "::mlir::ModuleOp"> {
}];
}

def TTIRHoistTransform: Pass<"ttir-cpu-hoist-transform", "::mlir::ModuleOp">
{
let summary = "Transform to perform hoist mechanics on any ops marked to be hoisted for CPU lowering";
let description = [{
Transform pass which runs an analysis to find ops which should be hoisted, and then hoists those ops. Currently we only have a manual analysis, which requires a command-line list of named locs to hoist; in the future, we will have an automatic analysis as well.

Example:
input:
module {
func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
%0 = tensor.empty() : tensor<32x32xbf16>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> loc("add_op1")
return %1 : tensor<32x32xbf16>
}
}

output:
module {
func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
%0 = tensor.empty() : tensor<32x32xbf16>
%1 = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
return %1 : tensor<32x32xbf16>
}
module @cpu_module attributes {ttir.cpu_module} {
func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>) -> tensor<32x32xbf16> attributes {arg_ranks = [2, 2, 2, 2]} {
%0 = "ttir.add"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
return %0 : tensor<32x32xbf16>
}
}
func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func(tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
}

}];
let options = [
ListOption<"hoistLocs", "hoist-locations", "std::string",
"Comma-separated list of NameLocs from the input.">,
];
}

#endif
1 change: 1 addition & 0 deletions lib/Dialect/TTIR/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTTIRTransforms
Broadcast.cpp
Constant.cpp
Generic.cpp
HoistCPUOps.cpp
Layout.cpp
Transforms.cpp
Utility.cpp
226 changes: 226 additions & 0 deletions lib/Dialect/TTIR/Transforms/HoistCPUOps.cpp
@@ -0,0 +1,226 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::tt::ttir {
#define GEN_PASS_DEF_TTIRHOISTTRANSFORM
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc"

//===----------------------------------------------------------------------===//
// Hoist CPU ops to standalone funcs pass
//===----------------------------------------------------------------------===//

// Helper function to get the ranks of an op's operands; we use this to
// populate attrs which we need for tensor unpacking operations later.
static llvm::SmallVector<int64_t, 4>
getOperandTensorRanks(mlir::Operation *op) {
llvm::SmallVector<int64_t, 4> ranks;

// Iterate over operands (inputs)
for (auto operand : op->getOperands()) {
// Check if the operand is a tensor
if (auto tensorType = dyn_cast<mlir::RankedTensorType>(operand.getType())) {
// Add the rank of the tensor (number of dimensions)
ranks.push_back(tensorType.getRank());
}
}

return ranks;
}

// Generate a unique name based on operation type + argument tensor dims & types.
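// e.g. hoisting a ttir.add over three 32x32xbf16 tensors produces the name
//   "hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func".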
static llvm::SmallString<16> generateHoistedFuncName(mlir::Operation *op) {
// Start building the unique function name
llvm::SmallString<16> uniqueName("hoisted_");
uniqueName.append(op->getName().getStringRef().begin(),
op->getName().getStringRef().end());

// Iterate over operands to extract tensor shapes and types
for (auto operand : op->getOperands()) {
auto rankedTensorType = dyn_cast<mlir::RankedTensorType>(operand.getType());
if (rankedTensorType) {
// Append the shape (dimensions) and the element type
llvm::SmallString<5> shapeStr("_");
for (auto dim : rankedTensorType.getShape()) {
shapeStr += std::to_string(dim) + "x";
}

// Append the element type (e.g., f32, i32) -- unfortunately I don't think
// there's a better way to get a string from mlir::Type.
std::string elementTypeStr;
llvm::raw_string_ostream stream(elementTypeStr);
rankedTensorType.getElementType().print(stream);

uniqueName.append(shapeStr.begin(), shapeStr.end());
uniqueName.append(elementTypeStr.begin(), elementTypeStr.end());
}
}

// Add suffix to indicate it's a function
uniqueName += "_func";

return uniqueName;
}

// Helper function to hoist an arbitrary op into a new function in targetModule,
// generate a matching extern prototype in the sourceModule, and replace the
// original op with a callOp to the extern function.
static void hoistOperationToFunction(mlir::Operation *opToHoist,
mlir::ModuleOp sourceModule,
mlir::ModuleOp targetModule) {
const llvm::SmallVector<int64_t, 4> ranks = getOperandTensorRanks(opToHoist);

const llvm::SmallString<16> functionName = generateHoistedFuncName(opToHoist);

auto localFunc = llvm::dyn_cast_or_null<func::FuncOp>(
sourceModule.lookupSymbol(functionName));

// Create a new hoisted function only if an equivalent one does not exist.
if (localFunc == nullptr) {
mlir::MLIRContext *context = sourceModule.getContext();

// Gather operand and result types.
llvm::SmallVector<mlir::Type> operandTypes, resultTypes;
for (auto operand : opToHoist->getOperands()) {
operandTypes.push_back(operand.getType());
}
for (auto result : opToHoist->getResultTypes()) {
resultTypes.push_back(result);
}

// Create the function signature.
mlir::FunctionType funcType =
mlir::FunctionType::get(context, operandTypes, resultTypes);

// Create the function in the target module.
auto hoistedFunc =
func::FuncOp::create(opToHoist->getLoc(), functionName, funcType);
targetModule.push_back(hoistedFunc);

// Add a basic block to the function.
mlir::Block *block = hoistedFunc.addEntryBlock();
mlir::OpBuilder builder(block, block->end());

// Map operands to block arguments and clone the operation.
llvm::SmallVector<mlir::Value> newOperands;
for (auto operand : llvm::enumerate(opToHoist->getOperands())) {
newOperands.push_back(block->getArgument(operand.index()));
}

mlir::IRMapping mapping;
for (auto operand : llvm::zip(opToHoist->getOperands(), newOperands)) {
mapping.map(std::get<0>(operand), std::get<1>(operand));
}

mlir::Operation *clonedOp = builder.clone(*opToHoist, mapping);

// Add a return operation to the function.
builder.create<mlir::func::ReturnOp>(opToHoist->getLoc(),
clonedOp->getResults());

// Declare the function prototype in the source module.
localFunc =
func::FuncOp::create(opToHoist->getLoc(), functionName, funcType);
localFunc.setPrivate();
sourceModule.push_back(localFunc);

hoistedFunc->setAttr("arg_ranks", builder.getI64ArrayAttr(ranks));
}

// Replace the original operation with a call to the hoisted function.
mlir::OpBuilder opBuilder(opToHoist);
auto callOp = opBuilder.create<mlir::func::CallOp>(
opToHoist->getLoc(), localFunc, opToHoist->getOperands());

// Replace all results of the original operation with the call results.
opToHoist->replaceAllUsesWith(callOp);

// Erase the original operation.
opToHoist->erase();
}

// An analysis class which currently relies on manually tagging ops with a
// `should_hoist` attribute, but in the future will also tag fall-back ops, etc.
class TTIRHoistAnalyze {
public:
using HoistOpSet = llvm::SmallVector<llvm::SmallSet<mlir::Operation *, 4>>;

TTIRHoistAnalyze(mlir::ModuleOp moduleOp) {
moduleOp.walk([&](mlir::Operation *nestedOp) {
if (nestedOp->hasAttr("should_hoist")) {
llvm::SmallSet<mlir::Operation *, 4> opSet;
opSet.insert(nestedOp);
hoistedOps.push_back(opSet);
}
});
}

HoistOpSet getResults() { return hoistedOps; }

private:
HoistOpSet hoistedOps;
};

// Transform pass to hoist specific ops (based on configured analysis pass) into
// a cpu submodule for later independent lowering.
class TTIRHoistTransform
: public impl::TTIRHoistTransformBase<TTIRHoistTransform> {
public:
using impl::TTIRHoistTransformBase<
TTIRHoistTransform>::TTIRHoistTransformBase;

void runOnOperation() final {
mlir::ModuleOp moduleOp = getOperation();

IRRewriter rewriter(&getContext());

auto loc = moduleOp->getLoc();

TTIRHoistAnalyze analysisPass(moduleOp);
const TTIRHoistAnalyze::HoistOpSet &hoistOpSets = analysisPass.getResults();

// Check if a "cpu_module" already exists.
mlir::ModuleOp cpuModule;
for (auto &op : moduleOp.getBody()->getOperations()) {
if (auto module = llvm::dyn_cast<mlir::ModuleOp>(op)) {
if (module->hasAttr("ttir.cpu_module")) {
cpuModule = module;
break;
}
}
}

// If no CPU module exists, create one.
if (!cpuModule) {
rewriter.setInsertionPointToEnd(moduleOp.getBody());
cpuModule = rewriter.create<mlir::ModuleOp>(loc);
cpuModule->setAttr("ttir.cpu_module", rewriter.getUnitAttr());
cpuModule->setAttr(mlir::SymbolTable::getSymbolAttrName(),
rewriter.getStringAttr("cpu_module"));
// Make the CPU module's symbol publicly visible.
mlir::SymbolTable::setSymbolVisibility(
cpuModule, mlir::SymbolTable::Visibility::Public);
}

for (const auto &opSet : hoistOpSets) {
assert(opSet.size() == 1 &&
"currently don't support hoisting multiple instructions at once!");
hoistOperationToFunction(*opSet.begin(), moduleOp, cpuModule);
}
}
};

} // namespace mlir::tt::ttir
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -4,6 +4,7 @@

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/Types/Types.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
37 changes: 37 additions & 0 deletions test/ttmlir/Dialect/TTIR/hoist/simple_add.mlir
@@ -0,0 +1,37 @@
// RUN: ttmlir-opt --ttir-cpu-hoist-transform %s | FileCheck %s

func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
%0 = tensor.empty() : tensor<32x32xbf16>
// CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
return %1 : tensor<32x32xbf16>
}

func.func @add2(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
%0 = tensor.empty() : tensor<32x32xf32>
// CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
return %1 : tensor<32x32xf32>
}

func.func @add3(%arg0: tensor<32x3xf32>, %arg1: tensor<32x3xf32>) -> tensor<32x3xf32> {
%0 = tensor.empty() : tensor<32x3xf32>
// CHECK: %{{.*}} = call @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x3xf32>, tensor<32x3xf32>, tensor<32x3xf32>) -> tensor<32x3xf32>
return %1 : tensor<32x3xf32>
}

func.func @add4(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
%0 = tensor.empty() : tensor<32x32xbf16>
// CHECK: %{{.*}} = call @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> {should_hoist} : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
return %1 : tensor<32x32xbf16>
}
// CHECK: module @cpu_module attributes {ttir.cpu_module}
// CHECK: func.func @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
// CHECK: func.func @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
// CHECK: func.func @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func

// CHECK: func.func private @hoisted_ttir.add_32x32xbf16_32x32xbf16_32x32xbf16_func
// CHECK: func.func private @hoisted_ttir.add_32x32xf32_32x32xf32_32x32xf32_func
// CHECK: func.func private @hoisted_ttir.add_32x3xf32_32x3xf32_32x3xf32_func
