diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 14eccbe53..8ef3f73ca 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -903,11 +903,12 @@ def TTNN_EmptyOp : TTNN_Op<"empty"> { Tensor empty operation }]; - let arguments = (ins Optional:$device, - TTNN_ShapeAttr:$shape, - OptionalAttr:$dtype, - OptionalAttr:$layout, - OptionalAttr:$memory_config); + let arguments = (ins TTNN_ShapeAttr:$shape, + TT_DataTypeAttr:$dtype, + TTNN_LayoutAttr:$layout, + TT_Device:$device, + TTNN_MemoryConfigAttr:$memory_config); + let results = (outs AnyRankedTensor:$result); let hasVerifier = 1; diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 193c90047..81aaf868b 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -64,33 +64,26 @@ class TensorEmptyConversionPattern ttnn::LayoutAttr tensorLayoutAttr = ttnn::LayoutAttr::get(op.getContext(), ttnnLayoutEnum); - // If the tensor is not going to device, we can create the op without - // device-specific attributes + // Device // - ttnn::TensorMemoryLayoutAttr memLayout = layoutAttr.getMemLayout(); - if (!memLayout) { - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), nullptr, - shapeAttr, dTypeAttr, tensorLayoutAttr, nullptr); - - return success(); - } - - ttnn::BufferType bufferType = layoutAttr.getBufferType(); + auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); // Create MemoryConfigAttr // - auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op); - llvm::SmallVector shardShape = layoutAttr.getShardShape(); - ttnn::MemoryConfigAttr memoryConfigAttr = ttnn::MemoryConfigAttr::get( - op.getContext(), ttnn::BufferTypeAttr::get(op.getContext(), bufferType), - ttnn::ShardSpecAttr::get( - op.getContext(), ttnn::ShapeAttr::get(op.getContext(), shardShape)), - memLayout); + ttnn::BufferTypeAttr bufferTypeAttr = + ttnn::BufferTypeAttr::get(op.getContext(), layoutAttr.getBufferType()); + ttnn::ShardSpecAttr shardSpecAttr = ttnn::ShardSpecAttr::get( + op.getContext(), + ttnn::ShapeAttr::get(op.getContext(), layoutAttr.getShardShape())); + ttnn::MemoryConfigAttr memoryConfigAttr = + ttnn::MemoryConfigAttr::get(op.getContext(), bufferTypeAttr, + shardSpecAttr, layoutAttr.getMemLayout()); + // Replace op + // rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), device, - shapeAttr, dTypeAttr, tensorLayoutAttr, memoryConfigAttr); + op, this->getTypeConverter()->convertType(op.getType()), shapeAttr, + dTypeAttr, tensorLayoutAttr, device, memoryConfigAttr); return success(); } diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp index fda64f925..dda73e8df 100644 --- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp +++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp @@ -113,6 +113,43 @@ class DefaultOpConversionPattern } }; +// Eltwise Unary op conversion pattern +// +// Currently, it has to insert nullopts for some parameters that are not +// modelled in the dialect (memcfg) +// +template +class EltwiseUnaryOpConversionPattern + : public TTNNToEmitCBaseOpConversionPattern { + +public: + EltwiseUnaryOpConversionPattern(const TypeConverter &typeConverter, + MLIRContext *context, + PatternBenefit benefit = 1) + : TTNNToEmitCBaseOpConversionPattern(typeConverter, context, + benefit) {} + + LogicalResult + matchAndRewrite(SourceOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so + // an ArrayAttr object holding IndexTypes is created to denote this + // + llvm::SmallVector attrs; + attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 0)); + attrs.push_back(ttnn_to_emitc::utils::createStdNullopt(rewriter)); + attrs.push_back(mlir::IntegerAttr::get(rewriter.getIndexType(), 1)); + + ArrayAttr arrayAttrs = ArrayAttr::get(srcOp->getContext(), attrs); + + rewriter.replaceOpWithNewOp( + srcOp, this->getTypeConverter()->convertType(srcOp.getType(0)), + this->convertOpName(srcOp), arrayAttrs, nullptr, adaptor.getOperands()); + + return success(); + } +}; + // Eltwise Binary op conversion pattern // // Currently, it has to insert nullopts for some parameters that are not @@ -132,6 +169,7 @@ class EltwiseBinaryOpConversionPattern LogicalResult matchAndRewrite(SourceOp srcOp, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so // an ArrayAttr object holding IndexTypes is created to denote this // @@ -152,6 +190,50 @@ class EltwiseBinaryOpConversionPattern } }; +// Matmul op conversion pattern +// +class MatmulOpConversionPattern + : public TTNNToEmitCBaseOpConversionPattern { + +public: + MatmulOpConversionPattern(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : TTNNToEmitCBaseOpConversionPattern(typeConverter, + context, benefit) {} + + LogicalResult + matchAndRewrite(ttnn::MatmulOp matmulOp, ttnn::MatmulOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // emitc::CallOpaqueOp needs to know positions of operands vs attributes, so + // an ArrayAttr object holding IndexTypes is created to denote this + // + ArrayAttr arrayAttrs = rewriter.getArrayAttr({ + mlir::IntegerAttr::get(rewriter.getIndexType(), 0), + mlir::IntegerAttr::get(rewriter.getIndexType(), 1), + ttnn_to_emitc::utils::convertBoolAttr( + rewriter, BoolAttr::get(rewriter.getContext(), false)), + ttnn_to_emitc::utils::convertBoolAttr( + rewriter, BoolAttr::get(rewriter.getContext(), false)), + ttnn_to_emitc::utils::createStdNullopt(rewriter), + ttnn_to_emitc::utils::createStdNullopt(rewriter), + ttnn_to_emitc::utils::createStdNullopt(rewriter), + ttnn_to_emitc::utils::createStdNullopt(rewriter), + ttnn_to_emitc::utils::createStdNullopt(rewriter), + ttnn_to_emitc::utils::createStdNullopt(rewriter), + ttnn_to_emitc::utils::createStdNullopt(rewriter), + mlir::IntegerAttr::get(rewriter.getIndexType(), 2), + }); + + rewriter.replaceOpWithNewOp( + matmulOp, this->getTypeConverter()->convertType(matmulOp.getType()), + this->convertOpName(matmulOp), arrayAttrs, nullptr, + adaptor.getOperands()); + + return success(); + } +}; + // GetDeviceOp conversion pattern // class GetDeviceOpConversionPattern @@ -390,46 +472,42 @@ class EmptyOpConversionPattern tt::DataTypeAttr dataTypeAttr = srcOp.getDtypeAttr(); ttnn::LayoutAttr layoutAttr = srcOp.getLayoutAttr(); + // Find the GetDeviceOp + // + ttnn::GetDeviceOp getDeviceOp; + srcOp->getParentOp()->walk( + [&getDeviceOp](ttnn::GetDeviceOp currGetDeviceOp) { + getDeviceOp = currGetDeviceOp; + }); + // Create ttnn::Shape() call // emitc::ExpressionOp shapeExpressionOp = ttnn_to_emitc::utils::createShapeOp( rewriter, shapeAttr, srcOp->getBlock(), srcOp.getLoc()); - llvm::SmallVector operands{ - shapeExpressionOp->getResult(0), - }; - - // If there is a device operand, create tensor on device + // Create operands vector // - ArrayAttr arrayAttr; - if (adaptor.getDevice()) { - operands.append(1, adaptor.getDevice()); + llvm::SmallVector operands{shapeExpressionOp->getResult(0), + adaptor.getDevice()}; - // Create MemoryConfig object first, then pass it to the op - // - emitc::CallOpaqueOp memCfgOp = ttnn_to_emitc::utils::createMemoryConfigOp( - rewriter, srcOp.getMemoryConfig().value(), srcOp.getLoc()); + // Create MemoryConfig object first, then pass it to the op + // + emitc::CallOpaqueOp memCfgOp = ttnn_to_emitc::utils::createMemoryConfigOp( + rewriter, srcOp.getMemoryConfig(), srcOp.getLoc()); - // Concat operands and MemoryConfig object - // - operands.append(1, memCfgOp.getResult(0)); + // Concat operands and MemoryConfig object + // + operands.append(1, memCfgOp.getResult(0)); - // Create ArrayAttr object holding attributes and pointers to operands - // - arrayAttr = rewriter.getArrayAttr({ - rewriter.getIndexAttr(0), // ttnn::Shape - ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr), - ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr), - rewriter.getIndexAttr(1), // ttnn::Device - rewriter.getIndexAttr(2), // ttnn::MemoryConfig - }); - } else { - arrayAttr = rewriter.getArrayAttr({ - rewriter.getIndexAttr(0), // ttnn::Shape - ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr), - ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr), - }); - } + // Create ArrayAttr object holding attributes and pointers to operands + // + ArrayAttr arrayAttr = rewriter.getArrayAttr({ + rewriter.getIndexAttr(0), // ttnn::Shape + ttnn_to_emitc::utils::convertDType(rewriter, dataTypeAttr), + ttnn_to_emitc::utils::convertLayoutAttr(rewriter, layoutAttr), + rewriter.getIndexAttr(1), // ttnn::Device + rewriter.getIndexAttr(2), // ttnn::MemoryConfig + }); // Finally, convert ttir::EmptyOp to ttnn::EmptyOp // @@ -469,14 +547,14 @@ class OnesOpConversionPattern // Attrs (like shape) need to be instantiated into objects before being // passed to the op. Therefore: // - // We first create a ttnn::Shape object (SSA) by calling createShapeOp() and - // add it to the operands vector, but also add an IndexAttr in ArrayAttr to - // reference it (this is an EmitC mechanism that allows for combining Attrs - // and Values when calling an OpaqueOp). - // All the other input params are optional, so we create them on-the-fly - // into the ArrayAttr, whether they are an actual Attr, or a Value pointed - // to by IndexAttr. If they are present, we create the object and pass it to - // the op. If not, we pass std::nullopt. + // We first create a ttnn::Shape object (SSA) by calling createShapeOp() + // and add it to the operands vector, but also add an IndexAttr in + // ArrayAttr to reference it (this is an EmitC mechanism that allows for + // combining Attrs and Values when calling an OpaqueOp). All the other + // input params are optional, so we create them on-the-fly into the + // ArrayAttr, whether they are an actual Attr, or a Value pointed to by + // IndexAttr. If they are present, we create the object and pass it to the + // op. If not, we pass std::nullopt. // Create ttnn::Shape() call // @@ -489,8 +567,8 @@ class OnesOpConversionPattern // Create ArrayAttr object holding attributes and pointers to operands // - // Params that are Values are added to the operands vector on-the-fly, and a - // corresponding IndexAttr is added to the ArrayAttr to reference them. + // Params that are Values are added to the operands vector on-the-fly, and + // a corresponding IndexAttr is added to the ArrayAttr to reference them. // size_t operandIndex = 0; ArrayAttr arrayAttr = rewriter.getArrayAttr({ @@ -594,8 +672,8 @@ class GetTupleElementOpConversionPattern getTupleElementOp->getLoc(), rewriter.getIndexType(), std::to_string(adaptor.getIndex())); - // SubscriptOp also returns an emitc::LValueType, so we wrap the OpaqueType - // with LValueType + // SubscriptOp also returns an emitc::LValueType, so we wrap the + // OpaqueType with LValueType // emitc::LValueType lvalueReturnType = emitc::LValueType::get( emitc::OpaqueType::get(rewriter.getContext(), "ttnn::Tensor")); @@ -621,9 +699,9 @@ class TupleOpConversionPattern : public OpConversionPattern { LogicalResult matchAndRewrite(tt::TupleOp tupleOp, tt::TupleOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // EmitC doesn't offer a way to create a vector from a list of values, so we - // need to create a utility function that does this. This is achieved by - // using EmitC's VerbatimOp. + // EmitC doesn't offer a way to create a vector from a list of values, so + // we need to create a utility function that does this. This is achieved + // by using EmitC's VerbatimOp. // Try to find if utility vec creation function is already defined in the // module. If not, insert it. @@ -708,7 +786,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, - DefaultOpConversionPattern, + EltwiseUnaryOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, DefaultOpConversionPattern, @@ -761,7 +839,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx, // Matmul ops // patterns.add, - DefaultOpConversionPattern>(typeConverter, ctx); + MatmulOpConversionPattern>(typeConverter, ctx); // Reduction ops // diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 286393858..82511a44e 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -174,9 +174,6 @@ ::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() { // EmptyOp verification ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() { - // ============================== - // === CHECK ATTRIBUTES START === - // ============================== // Check that the attributes of the op match the attributes of the output // tensor type. // @@ -192,50 +189,17 @@ ::mlir::LogicalResult mlir::tt::ttnn::EmptyOp::verify() { // DataType and Layout // - if (getLayout().has_value()) { - ttnn::Layout ttnnLayoutEnum = layoutAttr.getLayout(); - assert(ttnnLayoutEnum == getLayoutAttr().getValue()); - } - if (getDtype().has_value()) { - tt::DataType dtype = layoutAttr.getDataType(); - assert(dtype == getDtype()); - } + assert(getLayout() == layoutAttr.getLayout()); + assert(getDtype() == layoutAttr.getDataType()); // MemoryConfig - // Check that op has MemoryConfigAttr set on itself, then compare internal - // attrs with output tensor attrs. - // - if (getMemoryConfig().has_value()) { - ttnn::BufferType bufferType = layoutAttr.getBufferType(); - ttnn::TensorMemoryLayoutAttr tensorMemoryLayoutAttr = - layoutAttr.getMemLayout(); - assert(bufferType == getMemoryConfig()->getBufferType().getValue()); - assert(tensorMemoryLayoutAttr == - getMemoryConfig()->getTensorMemoryLayout()); - } + // Compare internal attrs with output tensor attrs. // - // ============================== - // ==== CHECK ATTRIBUTES END ==== - // ============================== - - // ============================== - // === CHECK SIGNATURES START === - // ============================== - // Check that call-site uses the correct signature. We only allow 2 for now: - // 1. none, Shape, DataType, Layout, none - // 2. Device, Shape, DataType, Layout, MemoryConfig - // - assert( - // 1. - (!getDevice() && getDtype().has_value() && getLayout().has_value() && - !getMemoryConfig().has_value()) || - // 2. - (getDevice() && getDtype().has_value() && getLayout().has_value() && - getMemoryConfig().has_value())); - // - // ============================== - // ==== CHECK SIGNATURES END ==== - // ============================== + assert(getMemoryConfig().getBufferType().getValue() == + layoutAttr.getBufferType()); + assert(getMemoryConfig().getTensorMemoryLayout() == + layoutAttr.getMemLayout()); + return success(); } diff --git a/lib/Dialect/TTNN/Transforms/Passes.cpp b/lib/Dialect/TTNN/Transforms/Passes.cpp index 2f06efb82..73f2eadbb 100644 --- a/lib/Dialect/TTNN/Transforms/Passes.cpp +++ b/lib/Dialect/TTNN/Transforms/Passes.cpp @@ -1003,12 +1003,9 @@ class TTNNCreateInputGenerators // Create a new tensor // - // TODO(svuckovic): Move from ttnn::EmptyOp to ttnn::OnesOp once #1476 - // lands - // - mlir::Value tensorValue = rewriter.create( - forwardFuncOp->getLoc(), tensorType, nullptr, shapeAttr, dTypeAttr, - tensorLayoutAttr, nullptr); + mlir::Value tensorValue = rewriter.create( + forwardFuncOp->getLoc(), tensorType, shapeAttr, dTypeAttr, + tensorLayoutAttr, nullptr, nullptr); generatedTensors.push_back(tensorValue); } diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 1e8492317..4f9dde1a6 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -325,9 +325,9 @@ ::flatbuffers::Offset<::tt::target::ttnn::EmptyOp> createOp(FlatbufferObjectCache &cache, EmptyOp op) { ::llvm::ArrayRef shape = op.getShape().getShape(); ::tt::target::DataType dtype = - ::tt::mlir::ttnn::utils::toTargetDataType(op.getDtype().value()); + ::tt::mlir::ttnn::utils::toTargetDataType(op.getDtype()); ::tt::target::TensorLayout layout = - ::tt::mlir::ttnn::utils::toTargetTensorLayout(op.getLayout().value()); + ::tt::mlir::ttnn::utils::toTargetTensorLayout(op.getLayout()); uint32_t numShards = 1; auto strategy = createDistributionStrategy( @@ -335,19 +335,10 @@ createOp(FlatbufferObjectCache &cache, EmptyOp op) { numShards); auto output = getOperandThroughDPSOps(op.getResult()); - // If the device is not set, we create on host - if (!op.getDevice()) { - return ::tt::target::ttnn::CreateEmptyOp( - *cache.fbb, cache.fbb->CreateVector(shape), dtype, layout, - numShards, /* device */ 0, /* memcfg */ 0, strategy, - cache.getOrCreate(output, tensorValueToFlatbuffer, - kHostAllocatedAddress, kHostAllocatedSize)); - } - auto device = getOperandThroughDPSOps(op.getDevice()); auto memoryConfigDesc = - cache.getOrCreate(*op.getMemoryConfig(), memoryConfigToFlatbuffer); + cache.getOrCreate(op.getMemoryConfig(), memoryConfigToFlatbuffer); return ::tt::target::ttnn::CreateEmptyOp( *cache.fbb, cache.fbb->CreateVector(shape), dtype, layout, diff --git a/test/ttmlir/Silicon/TTNN/emitc/mnist.mlir b/test/ttmlir/Silicon/TTNN/emitc/mnist.mlir new file mode 100644 index 000000000..f3076360e --- /dev/null +++ b/test/ttmlir/Silicon/TTNN/emitc/mnist.mlir @@ -0,0 +1,23 @@ +// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" --ttnn-create-input-gens --convert-ttnn-to-emitc %s > %t.mlir + +module @MNISTLinear attributes {tt.system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>} { + func.func @forward(%arg0: tensor<1x784xf32> {ttir.name = "input_1"}, %arg1: tensor<784x512xf32> {ttir.name = "linear_relu_stack.0.weight"}, %arg2: tensor<512xf32> {ttir.name = "linear_relu_stack.0.bias"}, %arg3: tensor<512x512xf32> {ttir.name = "linear_relu_stack.2.weight"}, %arg4: tensor<512xf32> {ttir.name = "linear_relu_stack.2.bias"}, %arg5: tensor<512x10xf32> {ttir.name = "linear_relu_stack.4.weight"}, %arg6: tensor<10xf32> {ttir.name = "linear_relu_stack.4.bias"}) -> (tensor<1x10xf32> {ttir.name = "MNISTLinear_350.output_add_981"}) { + %0 = tensor.empty() : tensor<1x512xf32> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x784xf32>, tensor<784x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> + %2 = tensor.empty() : tensor<1x512xf32> + %3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array}> : (tensor<1x512xf32>, tensor<512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> + %4 = tensor.empty() : tensor<1x512xf32> + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> + %6 = tensor.empty() : tensor<1x512xf32> + %7 = "ttir.matmul"(%5, %arg3, %6) : (tensor<1x512xf32>, tensor<512x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> + %8 = tensor.empty() : tensor<1x512xf32> + %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array}> : (tensor<1x512xf32>, tensor<512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> + %10 = tensor.empty() : tensor<1x512xf32> + %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array}> : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> + %12 = tensor.empty() : tensor<1x10xf32> + %13 = "ttir.matmul"(%11, %arg5, %12) : (tensor<1x512xf32>, tensor<512x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + %14 = tensor.empty() : tensor<1x10xf32> + %15 = "ttir.add"(%13, %arg6, %14) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + return %15 : tensor<1x10xf32> + } +} diff --git a/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp b/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp index 10980cbda..a8ad3df44 100644 --- a/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp +++ b/test/unittests/Optimizer/TestGreedyL1InterleavedPolicy.cpp @@ -54,9 +54,9 @@ class GreedyL1InterleavedPolicyBase : public ::testing::Test { mlir::Value createEmptyTensor() { ShapeAttr shapeAttr = ShapeAttr::get(&context, getTensorShape()); - return builder.create(builder.getUnknownLoc(), - getTensorRankedType(), nullptr, shapeAttr, - nullptr, nullptr, nullptr); + return builder.create(builder.getUnknownLoc(), + getTensorRankedType(), shapeAttr, nullptr, + nullptr, nullptr, nullptr); } mlir::func::FuncOp createFuncOp() { diff --git a/test/unittests/Optimizer/TestShardSolver.cpp b/test/unittests/Optimizer/TestShardSolver.cpp index 9386cf24c..4693bc685 100644 --- a/test/unittests/Optimizer/TestShardSolver.cpp +++ b/test/unittests/Optimizer/TestShardSolver.cpp @@ -47,9 +47,9 @@ class ShardSolverBase : public ::testing::Test { mlir::Value createEmptyTensor() { ShapeAttr shapeAttr = ShapeAttr::get(&context, getTensorShape()); - return builder.create(builder.getUnknownLoc(), - getTensorRankedType(), nullptr, shapeAttr, - nullptr, nullptr, nullptr); + return builder.create(builder.getUnknownLoc(), + getTensorRankedType(), shapeAttr, nullptr, + nullptr, nullptr, nullptr); } mlir::func::FuncOp createFuncOp() { diff --git a/tools/ttnn-standalone/CMakeLists.txt b/tools/ttnn-standalone/CMakeLists.txt index 0be29d763..bf31fdc36 100644 --- a/tools/ttnn-standalone/CMakeLists.txt +++ b/tools/ttnn-standalone/CMakeLists.txt @@ -56,8 +56,9 @@ set(INCLUDE_DIRS # TODO: Remove these when ttmetal removes the dependencies from public facing headers $ENV{TT_METAL_HOME}/.cpmcache/reflect/e75434c4c5f669e4a74e4d84e0a30d7249c1e66f $ENV{TT_METAL_HOME}/.cpmcache/fmt/73b5ec45edbd92babfd91c3777a9e1ab9cac8238/include - $ENV{TT_METAL_HOME}/.cpmcache/magic_enum/1e1af177d4ab0ef660f105434fd1017c4d1f8c17/include/magic_enum + $ENV{TT_METAL_HOME}/.cpmcache/magic_enum/4d76fe0a5b27a0e62d6c15976d02b33c54207096/include $ENV{TT_METAL_HOME}/.cpmcache/boost_core/e679bef5c160cf29d0f37d549881dc5f5a58c332/include + $ENV{TT_METAL_HOME}/.cpmcache/json/230202b6f5267cbf0c8e5a2f17301964d95f83ff/include # Metalium $ENV{TT_METAL_HOME} @@ -70,6 +71,7 @@ set(INCLUDE_DIRS $ENV{TT_METAL_HOME}/tt_metal/hw/inc/${ARCH_EXTRA_DIR} $ENV{TT_METAL_HOME}/tt_metal/third_party/umd/src/firmware/riscv/${ARCH_NAME} $ENV{TT_METAL_HOME}/tt_metal/third_party/magic_enum + $ENV{TT_METAL_HOME}/tt_metal/third_party/tracy/public # TTNN $ENV{TT_METAL_HOME}/ttnn/cpp diff --git a/tools/ttnn-standalone/ttnn-precompiled.hpp b/tools/ttnn-standalone/ttnn-precompiled.hpp index 4b9894fd5..dfd74f3f3 100644 --- a/tools/ttnn-standalone/ttnn-precompiled.hpp +++ b/tools/ttnn-standalone/ttnn-precompiled.hpp @@ -13,6 +13,7 @@ #include "operations/eltwise/binary/binary.hpp" #include "operations/embedding/embedding.hpp" #include "operations/embedding_backward/embedding_backward.hpp" +#include "operations/matmul/matmul.hpp" #include "tensor/tensor.hpp" #include "tensor/types.hpp" #include "types.hpp"