Add ModifySignaturesForDylib pass (#1595)
svuckovicTT authored Dec 16, 2024
1 parent fa8ea65 commit 593e0d8
Showing 9 changed files with 249 additions and 5 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TT/IR/TTOps.h
@@ -9,6 +9,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
25 changes: 25 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOps.td
@@ -6,5 +6,30 @@
#define TTMLIR_TTMLIR_TTOPS_TD

include "ttmlir/Dialect/TT/IR/TTOpsTypes.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"

def TT_GetTupleElementOp: TT_Op<"get_tuple_element", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "GetTupleElement operation";
let description = [{
Extracts the element at position `index` from the `operand` tuple and produces a `result`.

Example:
```mlir
%result = tt.get_tuple_element %operand[0] : (tuple<tensor<32x32xbf16>, tuple<tensor<1x32xf32>>>) -> tensor<32x32xbf16>
```
}];

let arguments = (ins TT_Tuple:$operand,
ConfinedAttr<I32Attr, [IntNonNegative]>:$index
);

let results = (outs TT_TupleReturnType:$result);

let assemblyFormat = [{
$operand `[` $index `]` attr-dict `:` functional-type(operands, results)
}];
}

#endif
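
For reference, here is a sketch of building this op programmatically, mirroring how the TTNNModifySignaturesForDylib pass below uses it (the helper name and builder variables are illustrative):

```cpp
// A minimal sketch, assuming the usual MLIR headers and
// "ttmlir/Dialect/TT/IR/TTOps.h" are available.
mlir::Value buildGetTupleElement(mlir::OpBuilder &builder, mlir::Location loc,
                                 mlir::Value tupleOperand, int32_t index) {
  // The result type is inferred from the tuple's element type at `index`
  // via InferTypeOpInterface, so no explicit result type is passed.
  return builder.create<mlir::tt::GetTupleElementOp>(loc, tupleOperand, index);
}
```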
8 changes: 8 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -488,4 +488,12 @@ def TT_Device : TT_Type<"Device", "device", []> {
let assemblyFormat = "`<` $desc `>`";
}

//===----------------------------------------------------------------------===//
// Auxiliary type definitions
//===----------------------------------------------------------------------===//

def TT_Tuple : NestedTupleOf<[AnyRankedTensor]>;

def TT_TupleReturnType : AnyTypeOf<[AnyRankedTensor]>;

#endif
32 changes: 32 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -86,4 +86,36 @@ def TTNNCreateInputGenerators: Pass<"ttnn-create-input-gens", "::mlir::ModuleOp"
}];
}

def TTNNModifySignaturesForDylib: Pass<"ttnn-modify-signatures-for-dylib", "::mlir::ModuleOp"> {
let summary = "Modify forward function signatures for the dylib path.";
let description = [{
This pass is intended to be used only when the end result is a dylib!

It rewrites the signatures of forward functions into a canonical form: all
input tensors are packed into a single tuple argument and then unpacked at
the start of the function body. This makes interfacing with the generated
dylib easier, as the signature is uniform across all forward functions.

Given a forward function like this:

```mlir
func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
%0 = "ttnn.add"(%arg0, %arg1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
return %0 : tensor<32x32xbf16>
}
```

The pass will modify the signature and prepend unpacking ops like so:

```mlir
func.func @add(%arg0: tuple<tensor<32x32xbf16>, tensor<32x32xbf16>>) -> tensor<32x32xbf16> {
%0 = tt.get_tuple_element %arg0[0] : (tuple<tensor<32x32xbf16>, tensor<32x32xbf16>>) -> tensor<32x32xbf16>
%1 = tt.get_tuple_element %arg0[1] : (tuple<tensor<32x32xbf16>, tensor<32x32xbf16>>) -> tensor<32x32xbf16>
%2 = "ttnn.add"(%0, %1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
return %2 : tensor<32x32xbf16>
}
```
}];
}

#endif
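
For orientation, once this tuplified form is lowered through TTNNToEmitC (where, as of this commit, TupleType converts to std::vector&lt;ttnn::Tensor&gt;), every forward function takes a single vector-of-tensors parameter on the C++ side. A minimal sketch of that shape for the @add example - variable names and the exact ttnn::add call are illustrative, not the literal generated code:

```cpp
#include <vector>
// assumes ttnn::Tensor and ttnn::add from the ttnn library headers

ttnn::Tensor add(std::vector<ttnn::Tensor> inputs) {
  ttnn::Tensor lhs = inputs[0]; // tt.get_tuple_element %arg0[0]
  ttnn::Tensor rhs = inputs[1]; // tt.get_tuple_element %arg0[1]
  return ttnn::add(lhs, rhs);   // illustrative; the real call takes more args
}
```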
49 changes: 45 additions & 4 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -5,6 +5,7 @@
#include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"

#include "ttmlir/Conversion/TTNNToEmitC/Utils.h"
#include "ttmlir/Dialect/TT/IR/TTOps.h"
#include "ttmlir/Dialect/TT/IR/TTOpsDialect.h.inc"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
@@ -576,6 +577,42 @@ class ArithConstantOpConversionPattern
}
};

class GetTupleElementOpConversionPattern
: public OpConversionPattern<tt::GetTupleElementOp> {

public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(tt::GetTupleElementOp getTupleElementOp,
tt::GetTupleElementOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// SubscriptOp requires a Value object as the index, which is created by
// invoking emitc::LiteralOp
//
Value indexAsVal = rewriter.create<emitc::LiteralOp>(
getTupleElementOp->getLoc(), rewriter.getIndexType(),
std::to_string(adaptor.getIndex()));

// SubscriptOp also returns an emitc::LValueType, so we wrap the OpaqueType
// with LValueType
//
emitc::LValueType lvalueReturnType = emitc::LValueType::get(
emitc::OpaqueType::get(rewriter.getContext(), "ttnn::Tensor"));
Value subscript = rewriter.create<emitc::SubscriptOp>(
getTupleElementOp->getLoc(), lvalueReturnType, adaptor.getOperand(),
indexAsVal);

// As SubscriptOp returns an LValueType, we need to convert it to an
// OpaqueType - this is done by invoking the emitc::LoadOp
//
rewriter.replaceOpWithNewOp<emitc::LoadOp>(
getTupleElementOp, emitc::OpaqueType::get(getContext(), "ttnn::Tensor"),
subscript);
return success();
}
};

// Module Op conversion pattern
//
// This conversion pattern removes attributes from the ModuleOp. Previously,
@@ -724,10 +761,6 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<DefaultOpConversionPattern<ttnn::MeshShardOp>>(typeConverter,
ctx);

// Module op
//
patterns.add<ModuleOpConversionPattern>(typeConverter, ctx);

// KV Cache ops
//
patterns.add<DefaultOpConversionPattern<ttnn::UpdateCacheOp>>(typeConverter,
@@ -738,6 +771,14 @@
// Arith ops
//
patterns.add<ArithConstantOpConversionPattern>(typeConverter, ctx);

// Module op
//
patterns.add<ModuleOpConversionPattern>(typeConverter, ctx);

// Tuple ops
//
patterns.add<GetTupleElementOpConversionPattern>(typeConverter, ctx);
}

} // namespace mlir::tt
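
For intuition, the LiteralOp + SubscriptOp + LoadOp chain above amounts to a plain indexed read in the emitted C++. A minimal sketch, assuming the tuple operand has already been converted to std::vector&lt;ttnn::Tensor&gt; and using illustrative names (the generated code uses a literal index rather than a parameter):

```cpp
#include <cstddef>
#include <vector>
// assumes ttnn::Tensor from the ttnn library headers

ttnn::Tensor getTupleElement(const std::vector<ttnn::Tensor> &v1,
                             std::size_t index) {
  // SubscriptOp forms the lvalue v1[index]; LoadOp reads it out as a value.
  ttnn::Tensor v2 = v1[index];
  return v2;
}
```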
4 changes: 4 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -40,6 +41,9 @@ class TTNNToEmitCTypeConverter : public TypeConverter {
addConversion([ctx](mlir::TensorType type) -> emitc::OpaqueType {
return emitc::OpaqueType::get(ctx, "ttnn::Tensor");
});
addConversion([ctx](mlir::TupleType type) -> emitc::OpaqueType {
return emitc::OpaqueType::get(ctx, "std::vector<ttnn::Tensor>");
});
}
};

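Since TupleType now maps to std::vector&lt;ttnn::Tensor&gt;, a consumer of the generated dylib can pack inputs the same way for every entry point. A hypothetical caller-side sketch - the function-pointer alias and names are assumptions for illustration, not part of this commit:

```cpp
#include <vector>
// assumes ttnn::Tensor from the ttnn library headers

// Every forward function now takes one packed argument, so a single
// signature covers all entry points in the dylib.
using ForwardFn = ttnn::Tensor (*)(std::vector<ttnn::Tensor>);

ttnn::Tensor runForward(ForwardFn forward, ttnn::Tensor a, ttnn::Tensor b) {
  std::vector<ttnn::Tensor> inputs{a, b}; // pack in argument order
  return forward(inputs);
}
```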
25 changes: 25 additions & 0 deletions lib/Dialect/TT/IR/TTOps.cpp
@@ -7,3 +7,28 @@

#define GET_OP_CLASSES
#include "ttmlir/Dialect/TT/IR/TTOps.cpp.inc"

namespace mlir::tt {

LogicalResult GetTupleElementOp::inferReturnTypes(
MLIRContext *, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {

GetTupleElementOp::Adaptor adaptor(operands, attributes, properties, regions);

auto operandType = dyn_cast<TupleType>(adaptor.getOperand().getType());
if (!operandType) {
return failure();
}
if (adaptor.getIndex() >= static_cast<int64_t>(operandType.size())) {
return emitOptionalError(location, "index ", adaptor.getIndex(),
" is out of bounds of operand with size ",
operandType.size());
}

inferredReturnTypes.push_back(operandType.getType(adaptor.getIndex()));
return success();
}

} // namespace mlir::tt
98 changes: 97 additions & 1 deletion lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -4,6 +4,8 @@

#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

#include "ttmlir/Dialect/TT/IR/TTOps.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
@@ -12,6 +14,7 @@
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
@@ -24,6 +27,7 @@ namespace mlir::tt::ttnn {
#define GEN_PASS_DEF_TTNNDEALLOCATE
#define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
#define GEN_PASS_DEF_TTNNCREATEINPUTGENERATORS
#define GEN_PASS_DEF_TTNNMODIFYSIGNATURESFORDYLIB
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"

class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
Expand Down Expand Up @@ -906,7 +910,7 @@ class TTNNCreateInputGenerators
//
Block *firstBlock = module.getBody(0);

// Find all the func.func ops in the module
// Find all the func.func ops in the module that are "forward" functions
//
SmallVector<func::FuncOp, 1> forwardFuncOps;
for (mlir::Operation &op : firstBlock->getOperations()) {
@@ -1065,4 +1069,96 @@ class TTNNCreateInputGenerators
}
};

class TTNNModifySignaturesForDylib
: public impl::TTNNModifySignaturesForDylibBase<
TTNNModifySignaturesForDylib> {

public:
using impl::TTNNModifySignaturesForDylibBase<
TTNNModifySignaturesForDylib>::TTNNModifySignaturesForDylibBase;

void runOnOperation() final {
ModuleOp module = getOperation();
IRRewriter rewriter(&getContext());

// Ensure that the module has a single region and a single block within that
// region
assert(module->getRegions().size() == 1);
assert(module->getRegion(0).getBlocks().size() == 1);

// Get the first block of the region at index 0
//
Block *firstBlock = module.getBody(0);

// Find all the func.func ops in the module that are "forward" functions
//
SmallVector<func::FuncOp, 1> forwardFuncOps;
for (mlir::Operation &op : firstBlock->getOperations()) {
if (mlir::func::FuncOp funcOp = dyn_cast<func::FuncOp>(op)) {

// Skip functions that are called elsewhere in the IR
//
// This skips utility functions that are called from other functions;
// only top-level "forward" functions should be considered
//
if (!funcOp->getUses().empty()) {
continue;
}

forwardFuncOps.push_back(funcOp);
}
}

// Iterate over all the func ops and modify the signatures
//
for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) {
// Replace the signature of the forward function so that all the tensor
// arguments are packed into a single tuple
//
mlir::FunctionType originalFuncType = forwardFuncOp.getFunctionType();
assert(
std::all_of(originalFuncType.getInputs().begin(),
originalFuncType.getInputs().end(),
[](Type t) { return mlir::isa<RankedTensorType>(t); }) &&
"Expected all inputs must be of type RankedTensorType");
mlir::TupleType inputTupleType =
mlir::TupleType::get(&getContext(), originalFuncType.getInputs());
FunctionType tuplifiedFuncType =
originalFuncType.clone(inputTupleType, originalFuncType.getResults());
rewriter.modifyOpInPlace(forwardFuncOp,
[&forwardFuncOp, &tuplifiedFuncType]() {
forwardFuncOp.setType(tuplifiedFuncType);
});

// The first block of the function (often referred to as the "entry
// block") needs its arguments updated as well - they must match the
// containing func's arguments. This is implemented by first inserting the
// tuple as the first argument of the block, then inserting
// GetTupleElementOp ops at the start of the block to unpack the tuple
// elements, and then replacing all uses of the original block arguments
// with the GetTupleElementOp results - after this it's finally safe to
// erase the original block arguments as they have no live uses anymore
//
Block &entryBlock = forwardFuncOp.getBlocks().front();
entryBlock.insertArgument(/*index=*/0u,
tuplifiedFuncType.getInputs().front(),
forwardFuncOp.getLoc());

rewriter.setInsertionPointToStart(&entryBlock);
for (size_t idx = 0; idx < originalFuncType.getInputs().size(); idx++) {
::mlir::tt::GetTupleElementOp getTupleElementOp =
rewriter.create<mlir::tt::GetTupleElementOp>(
forwardFuncOp.getLoc(), forwardFuncOp.getArgument(0), idx);

rewriter.replaceAllUsesWith(entryBlock.getArgument(1 + idx),
getTupleElementOp);
}

// Erase original arguments
//
entryBlock.eraseArguments(1, originalFuncType.getInputs().size());
}
}
};

} // namespace mlir::tt::ttnn
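
A minimal sketch of wiring the new pass into a C++ pass pipeline; the constructor name assumes the usual tablegen-generated pattern, so check Passes.h.inc for the exact declaration. On the command line the pass is exposed as --ttnn-modify-signatures-for-dylib, as the test below exercises:

```cpp
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTNN/Transforms/Passes.h"

// Append the signature rewrite at the end of the regular TTNN pipeline,
// i.e. only when the compilation target is a dylib.
void addDylibPasses(mlir::PassManager &pm) {
  pm.addPass(mlir::tt::ttnn::createTTNNModifySignaturesForDylib());
}
```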
12 changes: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" --ttnn-modify-signatures-for-dylib %s | FileCheck %s

module attributes {} {
// CHECK: func.func @add(%arg0: tuple<[[TENSOR_A:.*>]], [[TENSOR_B:.*>]]>) -> tensor<32x32xbf16, #ttnn_layout> {
func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
// CHECK-NEXT: %0 = tt.get_tuple_element %arg0[0] : (tuple<[[TENSOR_A]], [[TENSOR_B]]>) -> [[TENSOR_A]]
// CHECK-NEXT: %1 = tt.get_tuple_element %arg0[1] : (tuple<[[TENSOR_A]], [[TENSOR_B]]>) -> [[TENSOR_B]]
%0 = tensor.empty() : tensor<32x32xbf16>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
return %1 : tensor<32x32xbf16>
}
}
