diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 286393858..ce4edfb8b 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -90,13 +90,6 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() { ::mlir::RankedTensorType inputType = getInput().getType(); ::llvm::ArrayRef inputShape = getInput().getType().getShape(); - if (!inputType.getElementType().isBF16()) { - return emitOpError() - << "ttnn.max_pool2d currently only supports an input type of " - "bfloat16. Recieved " - << inputType.getElementType() << "."; - } - if (getKernelHeight() > getInputHeight()) { return emitOpError() << "Kernel height " << getKernelHeight() << " is greater than input height " << getInputHeight() @@ -111,13 +104,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() { if (inputType.getRank() != 4) { return emitOpError() - << "Input tensor rank must be 4. Recieved input with rank " + << "Input tensor rank must be 4. Received input with rank " << inputType.getRank() << ". Shape: (" << inputShape << ")."; } if (inputShape[0] != 1 || inputShape[1] != 1) { return emitOpError() << "Maxpool input must be in the form (1, 1, N*H*W, " - "C). Recieved shape (" + "C). Received shape (" << inputShape << ")."; } diff --git a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp index 00b705e14..a4adec064 100644 --- a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp @@ -7,6 +7,7 @@ #include "ttmlir/Utils.h" #include "llvm/ADT/SmallVector.h" +#include namespace mlir::tt::ttnn::wa { @@ -83,14 +84,17 @@ TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(Operation *op) { // tile layout, but the output of the operation is strictly in row-major layout. // In order to keep the output consistent with the input, the row-major // workaround is applied to both the input and output operands. +// The input and output operands are expected to use the bf16 data type, so the +// bf16 workaround is applied to both the input and output operands. TTNNOperandsWorkarounds TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() { - wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround = - wa::TTNNOperandWorkarounds(Layout::RowMajor); + wa::TTNNOperandWorkarounds rowMajorLayoutBF16Workaround; + rowMajorLayoutBF16Workaround.tensorLayoutWorkaround = Layout::RowMajor; + rowMajorLayoutBF16Workaround.tensorDataTypeWorkaround = DataType::BFloat16; return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds() - .addInputOperandWorkaround(rowMajorLayoutWorkaround) - .addInputOperandWorkaround(rowMajorLayoutWorkaround) - .addOutputOperandWorkaround(rowMajorLayoutWorkaround); + .addInputOperandWorkaround(rowMajorLayoutBF16Workaround) + .addInputOperandWorkaround(rowMajorLayoutBF16Workaround) + .addOutputOperandWorkaround(rowMajorLayoutBF16Workaround); } // Factory method to create a set of workarounds for embedding operation @@ -107,12 +111,13 @@ TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() { TTNNOperandsWorkarounds TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() { // Create input and weight workarounds. - TTNNOperandWorkarounds inputWorkaround = - TTNNOperandWorkarounds(Layout::RowMajor); + TTNNOperandWorkarounds inputRowMajorInt32Workaround; + inputRowMajorInt32Workaround.tensorLayoutWorkaround = Layout::RowMajor; + inputRowMajorInt32Workaround.tensorDataTypeWorkaround = DataType::UInt32; TTNNOperandWorkarounds weightWorkaround = TTNNOperandWorkarounds(DataType::BFloat16); return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(0, 0) - .addInputOperandWorkaround(inputWorkaround) + .addInputOperandWorkaround(inputRowMajorInt32Workaround) .addInputOperandWorkaround(weightWorkaround) .addInputOperandWorkaround(weightWorkaround) .addOutputOperandWorkaround(weightWorkaround); diff --git a/runtime/include/tt/runtime/detail/workarounds.h b/runtime/include/tt/runtime/detail/workarounds.h index a58675752..12607bde6 100644 --- a/runtime/include/tt/runtime/detail/workarounds.h +++ b/runtime/include/tt/runtime/detail/workarounds.h @@ -16,12 +16,13 @@ struct Env { constexpr static Env #endif get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true, - bool readUpdateIndexFromDeviceForKVCache = true) + bool readUpdateIndexFromDeviceForKVCache = true, + bool typecastOnHost = true) #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1 ; #else { - return Env(true, true, true); + return Env(true, true, true, true); } #endif // TODO(bug #855): Ideally we should have an op that preshards for maxpool2d @@ -39,13 +40,19 @@ struct Env { // to be able to pluck this update index from a runtime tensor. bool readUpdateIndexFromDeviceForKVCache; + // TODO(bug #1658): We're currently use ttnn::to_dtype operation to cast the + // data type of a tensor on host. Once we have improved the typecast operation + // to handle this, we should remove this workaround. + bool typecastOnHost; + private: constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands, - bool readUpdateIndexFromDeviceForKVCache) + bool readUpdateIndexFromDeviceForKVCache, bool typecastOnHost) : maxpool2dPreshard(maxpool2dPreshard), swapBinaryOperands(swapBinaryOperands), readUpdateIndexFromDeviceForKVCache( - readUpdateIndexFromDeviceForKVCache) {} + readUpdateIndexFromDeviceForKVCache), + typecastOnHost(typecastOnHost) {} }; inline std::ostream &operator<<(std::ostream &os, const Env &env) { @@ -57,6 +64,8 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) { os << "\t" << "readUpdateIndexFromDeviceForKVCache: " << env.readUpdateIndexFromDeviceForKVCache << "\n"; + os << "\t" + << "typecastOnHost: " << env.typecastOnHost << "\n"; os << "}"; return os; } diff --git a/runtime/lib/common/workarounds.cpp b/runtime/lib/common/workarounds.cpp index a9dbf7564..63adfc8d4 100644 --- a/runtime/lib/common/workarounds.cpp +++ b/runtime/lib/common/workarounds.cpp @@ -7,9 +7,10 @@ namespace tt::runtime::workaround { #if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1 const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands, - bool readUpdateIndexFromDeviceForKVCache) { + bool readUpdateIndexFromDeviceForKVCache, + bool typecastOnHost) { static const Env config(maxpool2dPreshard, swapBinaryOperands, - readUpdateIndexFromDeviceForKVCache); + readUpdateIndexFromDeviceForKVCache, typecastOnHost); return config; } #endif diff --git a/runtime/lib/ttnn/operations/layout/typecast.cpp b/runtime/lib/ttnn/operations/layout/typecast.cpp index 63a9ba63d..9d8b868f4 100644 --- a/runtime/lib/ttnn/operations/layout/typecast.cpp +++ b/runtime/lib/ttnn/operations/layout/typecast.cpp @@ -4,8 +4,10 @@ #include "operations/layout/typecast.h" #include "tt/runtime/detail/ttnn.h" +#include "tt/runtime/detail/workarounds.h" #include "tt/runtime/ttnn/operations/utils.h" #include "tt/runtime/ttnn/utils.h" +#include "ttnn/operations/core/core.hpp" namespace tt::runtime::ttnn::operations::layout { @@ -16,7 +18,14 @@ void run(const ::tt::target::ttnn::TypecastOp *op, ProgramContext &context) { ::ttnn::DataType targetDataType = ::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype()); - ::ttnn::Tensor out = ::ttnn::typecast(inputTensor, targetDataType); + ::ttnn::Tensor out; + if (workaround::Env::get().typecastOnHost && + ::tt::runtime::ttnn::utils::isOnHost(inputTensor.storage_type())) { + out = ::ttnn::to_dtype(inputTensor, targetDataType); + } else { + out = ::ttnn::typecast(inputTensor, targetDataType); + } + tensorPool.insert_or_assign(op->out()->global_id(), out); } diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir index 1aa9f55ac..66128ff62 100644 --- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/embedding_workaround.mlir @@ -19,9 +19,10 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} { // CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) // Check that the input operand is transformed into the row major layout. // CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout" + // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<32x32>>, > - // CHECK-SAME: -> tensor<32x32xf32 + // CHECK-SAME: -> tensor<32x32xi32 // Check that the data type of the weight operand is transformed in bf16. // CHECK-NEXT: %[[TO_LAYOUT_WEIGHTS:.*]] = "ttnn.to_layout" // CHECK-SAME: dtype = #tt.supportedDataTypes diff --git a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir index f24ff1184..f579c947f 100644 --- a/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir +++ b/test/ttmlir/Dialect/TTNN/Transforms/Workarounds/max_pool2d_workaround.mlir @@ -3,37 +3,41 @@ #dram = #ttnn.buffer_type #system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]> #system_memory = #ttnn.buffer_type -#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<16384x32xbf16, #system_memory>> -#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<4096x32xbf16, #system_memory>> -#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<512x1x!tt.tile<32x32, bf16>, #dram>, > -#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<128x1x!tt.tile<32x32, bf16>, #dram>, > +#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<16384x32xf32, #system_memory>> +#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<4096x32xf32, #system_memory>> +#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<512x1x!tt.tile<32x32, f32>, #dram>, > +#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<128x1x!tt.tile<32x32, f32>, #dram>, > module attributes {tt.device = #device, tt.system_desc = #system_desc} { - func.func @forward(%arg0: tensor<1x128x128x32xbf16, #ttnn_layout>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> { + func.func @forward(%arg0: tensor<1x128x128x32xf32, #ttnn_layout>) -> tensor<1x64x64x32xf32, #ttnn_layout1> { %0 = "ttnn.get_device"() <{mesh_shape = #ttnn}> : () -> !tt.device<#device> // CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"[[C:.*]] - %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<512x1>>, >}> : (tensor<1x128x128x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xbf16, #ttnn_layout2> - %2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xbf16, #ttnn_layout2>) -> tensor<1x1x16384x32xbf16, #ttnn_layout2> - %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<128x1>>, >, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3> + %1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<512x1>>, >}> : (tensor<1x128x128x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xf32, #ttnn_layout2> + %2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xf32, #ttnn_layout2>) -> tensor<1x1x16384x32xf32, #ttnn_layout2> + %3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#dram, <<128x1>>, >, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3> // CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]]) // Check that the input operand is transformed into the row major layout. // CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout" + // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16384x32>>, > // CHECK-SAME: -> tensor<1x1x16384x32xbf16, // Check that the output operand is transformed into the row major layout. // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<4096x32>>, > // CHECK-SAME: -> tensor<1x1x4096x32xbf16, - %4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xbf16, #ttnn_layout2>, tensor<1x1x4096x32xbf16, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3> + %4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, tensor<1x1x4096x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3> // CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_OUTPUT_DPS]], %[[DEVICE_OP]]) + // Check that the output operand is transformed back into the tile and f32 data type. // CHECK-NEXT: %[[TO_LAYOUT_OUTPUT:.*]] = "ttnn.to_layout"(%[[MAX_POOL_2D_OP]], %[[DEVICE_OP]]) + // CHECK-SAME: dtype = #tt.supportedDataTypes // CHECK-SAME: layout = #ttnn.layout // CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<128x1>>, > - // CHECK-SAME: -> tensor<1x1x4096x32xbf16 - %5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout3> + // CHECK-SAME: -> tensor<1x1x4096x32xf32 + %5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout3> // CHECK-NEXT: ttnn.reshape - %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> - return %6 : tensor<1x64x64x32xbf16, #ttnn_layout1> + %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes, layout = #ttnn.layout, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout1> + return %6 : tensor<1x64x64x32xf32, #ttnn_layout1> } } diff --git a/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir index dc48662d7..801b240f8 100644 --- a/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir +++ b/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir @@ -1,9 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s module attributes {} { - func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> { - %0 = tensor.empty() : tensor<1x64x64x32xbf16> + func.func @forward(%arg0: tensor<1x128x128x32xf32>) -> tensor<1x64x64x32xf32> { + %0 = tensor.empty() : tensor<1x64x64x32xf32> // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]] - %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16> - return %1 : tensor<1x64x64x32xbf16> + %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xf32>, tensor<1x64x64x32xf32>) -> tensor<1x64x64x32xf32> + return %1 : tensor<1x64x64x32xf32> } } diff --git a/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir b/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir index 583aa82e0..65936b18b 100644 --- a/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir +++ b/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir @@ -2,11 +2,11 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module attributes {} { - func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> { + func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] - %0 = tensor.empty() : tensor<32x32x128xbf16> + %0 = tensor.empty() : tensor<32x32x128xf32> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] - %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16> - return %1 : tensor<32x32x128xbf16> + %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32> + return %1 : tensor<32x32x128xf32> } } diff --git a/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir index 8b68ad5ec..fe873c2fd 100644 --- a/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir +++ b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir @@ -2,8 +2,8 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn module attributes {} { - func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> { - %0 = tensor.empty() : tensor<1x32x64x64xbf16> + func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x64x64xf32> { + %0 = tensor.empty() : tensor<1x32x64x64xf32> // CHECK: "ttnn.permute" // CHECK-SAME: permutation = array // CHECK: "ttnn.max_pool2d" @@ -16,7 +16,7 @@ module attributes {} { window_strides = array, base_dilations = array, window_dilations = array, - padding = array}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> - return %1 : tensor<1x32x64x64xbf16> + padding = array}> : (tensor<1x32x128x128xf32>, tensor<1x32x64x64xf32>) -> tensor<1x32x64x64xf32> + return %1 : tensor<1x32x64x64xf32> } }