diff --git a/include/ttmlir/Dialect/TT/IR/TTOps.h b/include/ttmlir/Dialect/TT/IR/TTOps.h
index 047fc2a3c..16219c04e 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOps.h
+++ b/include/ttmlir/Dialect/TT/IR/TTOps.h
@@ -9,6 +9,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
diff --git a/include/ttmlir/Dialect/TT/IR/TTOps.td b/include/ttmlir/Dialect/TT/IR/TTOps.td
index d3b34fff8..9078028ab 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOps.td
+++ b/include/ttmlir/Dialect/TT/IR/TTOps.td
@@ -6,5 +6,30 @@
 #define TTMLIR_TTMLIR_TTOPS_TD
 
 include "ttmlir/Dialect/TT/IR/TTOpsTypes.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/CommonTypeConstraints.td"
+
+def TT_GetTupleElementOp: TT_Op<"get_tuple_element", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "GetTupleElement operation";
+  let description = [{
+    Extracts element at `index` position of the `operand` tuple and produces a `result`.
+
+    Example:
+    ```mlir
+    %result = tt.get_tuple_element %operand[0] : (tuple<tensor<32x32xbf16>, tuple<tensor<32x32xbf16>>>) -> tensor<32x32xbf16>
+    ```
+  }];
+
+  let arguments = (ins TT_Tuple:$operand,
+                       ConfinedAttr<I32Attr, [IntNonNegative]>:$index
+  );
+
+  let results = (outs TT_TupleReturnType:$result);
+
+  let assemblyFormat = [{
+    $operand `[` $index `]` attr-dict `:` functional-type(operands, results)
+  }];
+}
 
 #endif
diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
index c690b8bca..7472c298b 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
+++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -488,4 +488,12 @@ def TT_Device : TT_Type<"Device", "device", []> {
   let assemblyFormat = "`<` $desc `>`";
 }
 
+//===----------------------------------------------------------------------===//
+// Auxiliary type definitions
+//===----------------------------------------------------------------------===//
+
+def TT_Tuple : NestedTupleOf<[AnyRankedTensor]>;
+
+def TT_TupleReturnType : AnyTypeOf<[AnyRankedTensor]>;
+
 #endif
diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
index 99a9bed24..4597db87e 100644
--- a/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -86,4 +86,36 @@ def TTNNCreateInputGenerators: Pass<"ttnn-create-input-gens", "::mlir::ModuleOp"
   }];
 }
 
+def TTNNModifySignaturesForDylib: Pass<"ttnn-modify-signatures-for-dylib", "::mlir::ModuleOp"> {
+  let summary = "Modify function signatures for the dylib path.";
+  let description = [{
+    This pass is intended to be used only when the end result is a dylib!
+
+    It modifies signatures of forward functions so that they take a canonical
+    form. Essentially, input tensors are packed into a tuple and then accessed
+    in the function body. This allows for easier interfacing with the generated
+    dylib, as the signatures are then uniform across all forward functions.
+
+    Given a forward function like this:
+
+    ```mlir
+    func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+      %0 = "ttnn.add"(%arg0, %arg1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+      return %0 : tensor<32x32xbf16>
+    }
+    ```
+
+    The pass will modify the signature and prepend unpacking ops like so:
+
+    ```mlir
+    func.func @add(%arg0: tuple<tensor<32x32xbf16>, tensor<32x32xbf16>>) -> tensor<32x32xbf16> {
+      %0 = tt.get_tuple_element %arg0[0] : (tuple<tensor<32x32xbf16>, tensor<32x32xbf16>>) -> tensor<32x32xbf16>
+      %1 = tt.get_tuple_element %arg0[1] : (tuple<tensor<32x32xbf16>, tensor<32x32xbf16>>) -> tensor<32x32xbf16>
+      %2 = "ttnn.add"(%0, %1) : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+      return %2 : tensor<32x32xbf16>
+    }
+    ```
+  }];
+}
+
 #endif
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index b1d438b90..c1a07b5fc 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -5,6 +5,7 @@
 
 #include "ttmlir/Conversion/TTNNToEmitC/TTNNToEmitC.h"
 #include "ttmlir/Conversion/TTNNToEmitC/Utils.h"
+#include "ttmlir/Dialect/TT/IR/TTOps.h"
 #include "ttmlir/Dialect/TT/IR/TTOpsDialect.h.inc"
 #include "ttmlir/Dialect/TTNN/IR/TTNN.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
@@ -576,6 +577,42 @@ class ArithConstantOpConversionPattern
   }
 };
 
+class GetTupleElementOpConversionPattern
+    : public OpConversionPattern<tt::GetTupleElementOp> {
+
+public:
+  using OpConversionPattern<tt::GetTupleElementOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(tt::GetTupleElementOp getTupleElementOp,
+                  tt::GetTupleElementOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // SubscriptOp requires a Value object as its index, which is created by
+    // invoking the emitc::LiteralOp
+    //
+    Value indexAsVal = rewriter.create<emitc::LiteralOp>(
+        getTupleElementOp->getLoc(), rewriter.getIndexType(),
+        std::to_string(adaptor.getIndex()));
+
+    // SubscriptOp also returns an emitc::LValueType, so we wrap the OpaqueType
+    // with LValueType
+    //
+    emitc::LValueType lvalueReturnType = emitc::LValueType::get(
+        emitc::OpaqueType::get(rewriter.getContext(), "ttnn::Tensor"));
+    Value subscript = rewriter.create<emitc::SubscriptOp>(
+        getTupleElementOp->getLoc(), lvalueReturnType, adaptor.getOperand(),
+        indexAsVal);
+
+    // As SubscriptOp returns an LValueType, we need to convert it to an
+    // OpaqueType - this is done by invoking the emitc::LoadOp
+    //
+    rewriter.replaceOpWithNewOp<emitc::LoadOp>(
+        getTupleElementOp, emitc::OpaqueType::get(getContext(), "ttnn::Tensor"),
+        subscript);
+    return success();
+  }
+};
+
 // Module Op conversion pattern
 //
 // This conversion pattern removes attributes from the ModuleOp. Previously,
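For intuition, here is roughly what the pattern above produces once the EmitC dialect is translated to C++: the LiteralOp/SubscriptOp/LoadOp chain becomes plain vector indexing. A minimal sketch, assuming the tuple operand has already been converted to `std::vector<ttnn::Tensor>`; the `Tensor` struct and function name are illustrative stand-ins, not the emitter's actual output:

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-in for ttnn::Tensor.
struct Tensor {
  int id = 0;
};

// Shape of the C++ that a lowered `tt.get_tuple_element` boils down to:
// subscript the vector (an lvalue, per emitc::SubscriptOp), then load the
// element into a value (per emitc::LoadOp).
Tensor getTupleElement(const std::vector<Tensor> &operand, std::size_t index) {
  return operand[index];
}
```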
@@ -724,10 +761,6 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
 
   patterns.add<DefaultOpConversionPattern<ttnn::ArangeOp>>(typeConverter, ctx);
 
-  // Module op
-  //
-  patterns.add<ModuleOpConversionPattern>(typeConverter, ctx);
-
   // KV Cache ops
   //
   patterns.add<DefaultOpConversionPattern<ttnn::UpdateCacheOp>>(typeConverter,
@@ -738,6 +771,14 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   // Arith ops
   //
   patterns.add<ArithConstantOpConversionPattern>(typeConverter, ctx);
+
+  // Module op
+  //
+  patterns.add<ModuleOpConversionPattern>(typeConverter, ctx);
+
+  // Tuple ops
+  //
+  patterns.add<GetTupleElementOpConversionPattern>(typeConverter, ctx);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp
index bd0c9044f..95e722d84 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitCPass.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -40,6 +41,9 @@ class TTNNToEmitCTypeConverter : public TypeConverter {
     addConversion([ctx](mlir::TensorType type) -> emitc::OpaqueType {
       return emitc::OpaqueType::get(ctx, "ttnn::Tensor");
    });
+    addConversion([ctx](mlir::TupleType type) -> emitc::OpaqueType {
+      return emitc::OpaqueType::get(ctx, "std::vector<ttnn::Tensor>");
+    });
   }
 };
 
diff --git a/lib/Dialect/TT/IR/TTOps.cpp b/lib/Dialect/TT/IR/TTOps.cpp
index 6f15f813e..b4f3b951d 100644
--- a/lib/Dialect/TT/IR/TTOps.cpp
+++ b/lib/Dialect/TT/IR/TTOps.cpp
@@ -7,3 +7,28 @@
 
 #define GET_OP_CLASSES
 #include "ttmlir/Dialect/TT/IR/TTOps.cpp.inc"
+
+namespace mlir::tt {
+
+LogicalResult GetTupleElementOp::inferReturnTypes(
+    MLIRContext *, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+
+  GetTupleElementOp::Adaptor adaptor(operands, attributes, properties, regions);
+
+  auto operandType = dyn_cast<TupleType>(adaptor.getOperand().getType());
+  if (!operandType) {
+    return failure();
+  }
+  if (adaptor.getIndex() >= static_cast<int32_t>(operandType.size())) {
+    return emitOptionalError(location, "index ", adaptor.getIndex(),
+                             " is out of bounds of operand with size ",
+                             operandType.size());
+  }
+
+  inferredReturnTypes.push_back(operandType.getType(adaptor.getIndex()));
+  return success();
+}
+
+} // namespace mlir::tt
diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp
index 20172f4fd..f35768d63 100644
--- a/lib/Dialect/TTNN/Transforms/Passes.cpp
+++ b/lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -4,6 +4,8 @@
 
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h"
 
+#include "ttmlir/Dialect/TT/IR/TTOps.h"
+#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
@@ -12,6 +14,7 @@
 #include "mlir/Analysis/Liveness.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
@@ -24,6 +27,7 @@ namespace mlir::tt::ttnn {
 #define GEN_PASS_DEF_TTNNDEALLOCATE
 #define GEN_PASS_DEF_TTNNDECOMPOSELAYOUTS
 #define GEN_PASS_DEF_TTNNCREATEINPUTGENERATORS
+#define GEN_PASS_DEF_TTNNMODIFYSIGNATURESFORDYLIB
 #include "ttmlir/Dialect/TTNN/Transforms/Passes.h.inc"
 
 class TTNNDeallocate : public impl::TTNNDeallocateBase<TTNNDeallocate> {
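Since the type converter maps every `mlir::TupleType` to the opaque `std::vector<ttnn::Tensor>`, all tuplified forward functions lower to the same C++ shape. A hedged sketch of the resulting dylib-facing signature; the `forward` name and the `Tensor` stub are assumptions for illustration, as the real symbol name comes from the MLIR function:

```cpp
#include <vector>

namespace ttnn {
struct Tensor {}; // stub standing in for the real ttnn::Tensor
} // namespace ttnn

// Every forward function ends up with this uniform, packed signature,
// which is what makes the generated dylib straightforward to bind against.
ttnn::Tensor forward(std::vector<ttnn::Tensor> inputs) {
  // Stub body so the sketch is self-contained; the generated code would
  // run the compiled graph here.
  return inputs.empty() ? ttnn::Tensor{} : inputs.front();
}
```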
@@ -906,7 +910,7 @@ class TTNNCreateInputGenerators
     //
     Block *firstBlock = module.getBody(0);
 
-    // Find all the func.func ops in the module
+    // Find all the func.func ops in the module that are "forward" functions
     //
     SmallVector<func::FuncOp> forwardFuncOps;
     for (mlir::Operation &op : firstBlock->getOperations()) {
@@ -1065,4 +1069,96 @@ class TTNNCreateInputGenerators
   }
 };
 
+class TTNNModifySignaturesForDylib
+    : public impl::TTNNModifySignaturesForDylibBase<
+          TTNNModifySignaturesForDylib> {
+
+public:
+  using impl::TTNNModifySignaturesForDylibBase<
+      TTNNModifySignaturesForDylib>::TTNNModifySignaturesForDylibBase;
+
+  void runOnOperation() final {
+    ModuleOp module = getOperation();
+    IRRewriter rewriter(&getContext());
+
+    // Ensure that the module has a single region and a single block within
+    // that region
+    assert(module->getRegions().size() == 1);
+    assert(module->getRegion(0).getBlocks().size() == 1);
+
+    // Get the first block of the region at index 0
+    //
+    Block *firstBlock = module.getBody(0);
+
+    // Find all the func.func ops in the module that are "forward" functions
+    //
+    SmallVector<func::FuncOp> forwardFuncOps;
+    for (mlir::Operation &op : firstBlock->getOperations()) {
+      if (mlir::func::FuncOp funcOp = dyn_cast<func::FuncOp>(op)) {
+
+        // Skip functions that are called elsewhere in the IR
+        //
+        // This will skip utility functions that are used by other functions;
+        // only top-level "forward" functions should be considered
+        //
+        if (!funcOp->getUses().empty()) {
+          continue;
+        }
+
+        forwardFuncOps.push_back(funcOp);
+      }
+    }
+
+    // Iterate over all the func ops and modify the signatures
+    //
+    for (mlir::func::FuncOp forwardFuncOp : forwardFuncOps) {
+      // Replace the signature of the forward function so that all the tensor
+      // arguments are packed into a single tuple
+      //
+      mlir::FunctionType originalFuncType = forwardFuncOp.getFunctionType();
+      assert(
+          std::all_of(originalFuncType.getInputs().begin(),
+                      originalFuncType.getInputs().end(),
+                      [](Type t) { return mlir::isa<RankedTensorType>(t); }) &&
+          "Expected all inputs to be of type RankedTensorType");
+      mlir::TupleType inputTupleType =
+          mlir::TupleType::get(&getContext(), originalFuncType.getInputs());
+      FunctionType tuplifiedFuncType =
+          originalFuncType.clone(inputTupleType, originalFuncType.getResults());
+      rewriter.modifyOpInPlace(forwardFuncOp,
+                               [&forwardFuncOp, &tuplifiedFuncType]() {
+                                 forwardFuncOp.setType(tuplifiedFuncType);
+                               });
+
+      // The first block of the function (often referred to as the "entry
+      // block") needs its arguments updated as well - the args need to match
+      // the containing func's arguments; this is implemented here by first
+      // inserting the tuple as the first argument of the block, inserting
+      // GetTupleElementOp ops at the start of the block in order to unpack
+      // tuple elements, and then replacing all uses of the original block
+      // arguments with the GetTupleElementOp results - after this it's
+      // finally safe to remove the original block arguments as they have no
+      // live uses anymore
+      //
+      Block &entryBlock = forwardFuncOp.getBlocks().front();
+      entryBlock.insertArgument(/*index=*/0u,
+                                tuplifiedFuncType.getInputs().front(),
+                                forwardFuncOp.getLoc());
+
+      rewriter.setInsertionPointToStart(&entryBlock);
+      for (size_t idx = 0; idx < originalFuncType.getInputs().size(); idx++) {
+        ::mlir::tt::GetTupleElementOp getTupleElementOp =
+            rewriter.create<mlir::tt::GetTupleElementOp>(
+                forwardFuncOp.getLoc(), forwardFuncOp.getArgument(0), idx);
+
+        rewriter.replaceAllUsesWith(entryBlock.getArgument(1 + idx),
+                                    getTupleElementOp);
+      }
+
+      // Erase original arguments
+      //
+      entryBlock.eraseArguments(1, originalFuncType.getInputs().size());
+    }
+  }
+};
+
 } // namespace mlir::tt::ttnn
diff --git a/test/ttmlir/Dialect/TTNN/Transforms/ttnn_modify_signatures_for_dylib_0.mlir b/test/ttmlir/Dialect/TTNN/Transforms/ttnn_modify_signatures_for_dylib_0.mlir
new file mode 100644
index 000000000..f7cab8590
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/Transforms/ttnn_modify_signatures_for_dylib_0.mlir
@@ -0,0 +1,12 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" --ttnn-modify-signatures-for-dylib %s | FileCheck %s
+
+module attributes {} {
+  // CHECK: func.func @add(%arg0: tuple<[[TENSOR_A:.*>]], [[TENSOR_B:.*>]]>) -> tensor<32x32xbf16, #ttnn_layout> {
+  func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
+    // CHECK-NEXT: %0 = tt.get_tuple_element %arg0[0] : (tuple<[[TENSOR_A]], [[TENSOR_B]]>) -> [[TENSOR_A]]
+    // CHECK-NEXT: %1 = tt.get_tuple_element %arg0[1] : (tuple<[[TENSOR_A]], [[TENSOR_B]]>) -> [[TENSOR_B]]
+    %0 = tensor.empty() : tensor<32x32xbf16>
+    %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+    return %1 : tensor<32x32xbf16>
+  }
+}
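A subtle point in `TTNNModifySignaturesForDylib` above is the argument index bookkeeping: after `insertArgument(/*index=*/0u, ...)`, original argument `i` sits at block index `1 + i`, which is why uses are redirected from `entryBlock.getArgument(1 + idx)` and why `eraseArguments(1, N)` removes exactly the original tensor arguments. A toy model of that shuffle, with a plain `std::vector` standing in for the block argument list:

```cpp
#include <cassert>
#include <string>
#include <vector>

int main() {
  // The entry block initially holds the original tensor arguments.
  std::vector<std::string> args = {"%arg0", "%arg1"};

  // insertArgument(/*index=*/0u, ...): the tuple becomes argument 0.
  args.insert(args.begin(), "%tuple");

  // Original argument idx is now found at index 1 + idx.
  assert(args[1] == "%arg0");
  assert(args[2] == "%arg1");

  // eraseArguments(1, N): drop the N originals, keeping only the tuple.
  args.erase(args.begin() + 1, args.begin() + 1 + 2);
  assert(args.size() == 1 && args[0] == "%tuple");
  return 0;
}
```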
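End to end, a host application would pack its inputs to match the tuplified signature that the test above checks. A speculative usage sketch, assuming the dylib exposes the `add` function from the test; symbol loading, tensor construction, and the `Tensor` stub are all placeholders for the real TTNN runtime:

```cpp
#include <vector>

namespace ttnn {
struct Tensor {}; // placeholder for the real ttnn::Tensor
} // namespace ttnn

// Stub with the uniform signature this patch produces; in practice the
// symbol would be resolved from the generated dylib (e.g. with dlsym).
ttnn::Tensor add(std::vector<ttnn::Tensor> inputs) { return inputs.at(0); }

int main() {
  std::vector<ttnn::Tensor> inputs(2); // pack %arg0 and %arg1 into one vector
  ttnn::Tensor result = add(inputs);   // one call, no per-argument plumbing
  (void)result;
  return 0;
}
```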