Skip to content

Commit

Permalink
Adding bf16 data type workaround for max_pool2d op.
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT committed Dec 26, 2024
1 parent 436a9b8 commit 5ba32ac
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 50 deletions.
11 changes: 2 additions & 9 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,6 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::llvm::ArrayRef<int64_t> inputShape = getInput().getType().getShape();

if (!inputType.getElementType().isBF16()) {
return emitOpError()
<< "ttnn.max_pool2d currently only supports an input type of "
"bfloat16. Recieved "
<< inputType.getElementType() << ".";
}

if (getKernelHeight() > getInputHeight()) {
return emitOpError() << "Kernel height " << getKernelHeight()
<< " is greater than input height " << getInputHeight()
Expand All @@ -111,13 +104,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() {

if (inputType.getRank() != 4) {
return emitOpError()
<< "Input tensor rank must be 4. Recieved input with rank "
<< "Input tensor rank must be 4. Received input with rank "
<< inputType.getRank() << ". Shape: (" << inputShape << ").";
}

if (inputShape[0] != 1 || inputShape[1] != 1) {
return emitOpError() << "Maxpool input must be in the form (1, 1, N*H*W, "
"C). Recieved shape ("
"C). Received shape ("
<< inputShape << ").";
}

Expand Down
20 changes: 12 additions & 8 deletions lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,17 @@ TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(Operation *op) {
// tile layout, but the output of the operation is strictly in row-major layout.
// In order to keep the output consistent with the input, the row-major
// workaround is applied to both the input and output operands.
// The input and output operands are expected to use the bf16 data type, so the
// bf16 workaround is applied to both the input and output operands.
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround =
wa::TTNNOperandWorkarounds(Layout::RowMajor);
wa::TTNNOperandWorkarounds rowMajorLayoutBF16Workaround;
rowMajorLayoutBF16Workaround.tensorLayoutWorkaround = Layout::RowMajor;
rowMajorLayoutBF16Workaround.tensorDataTypeWorkaround = DataType::BFloat16;
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
.addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
.addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
.addOutputOperandWorkaround(rowMajorLayoutBF16Workaround);
}

// Factory method to create a set of workarounds for embedding operation
Expand All @@ -107,12 +110,13 @@ TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() {
// Create input and weight workarounds.
TTNNOperandWorkarounds inputWorkaround =
TTNNOperandWorkarounds(Layout::RowMajor);
TTNNOperandWorkarounds inputRowMajorInt32Workaround;
inputRowMajorInt32Workaround.tensorLayoutWorkaround = Layout::RowMajor;
inputRowMajorInt32Workaround.tensorDataTypeWorkaround = DataType::UInt32;
TTNNOperandWorkarounds weightWorkaround =
TTNNOperandWorkarounds(DataType::BFloat16);
return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(0, 0)
.addInputOperandWorkaround(inputWorkaround)
.addInputOperandWorkaround(inputRowMajorInt32Workaround)
.addInputOperandWorkaround(weightWorkaround)
.addInputOperandWorkaround(weightWorkaround)
.addOutputOperandWorkaround(weightWorkaround);
Expand Down
17 changes: 13 additions & 4 deletions runtime/include/tt/runtime/detail/workarounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ struct Env {
constexpr static Env
#endif
get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true,
bool readUpdateIndexFromDeviceForKVCache = true)
bool readUpdateIndexFromDeviceForKVCache = true,
bool toDtypeOnHost = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true);
return Env(true, true, true, true);
}
#endif
// TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
Expand All @@ -39,13 +40,19 @@ struct Env {
// to be able to pluck this update index from a runtime tensor.
bool readUpdateIndexFromDeviceForKVCache;

// TODO(bug #1658): We're currently use ttnn::to_dtype operation to cast the
// data type of a tensor on host. Once we have improved the typecast operation
// to handle this, we should remove this workaround.
bool toDtypeOnHost;

private:
constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache)
bool readUpdateIndexFromDeviceForKVCache, bool toDtypeOnHost)
: maxpool2dPreshard(maxpool2dPreshard),
swapBinaryOperands(swapBinaryOperands),
readUpdateIndexFromDeviceForKVCache(
readUpdateIndexFromDeviceForKVCache) {}
readUpdateIndexFromDeviceForKVCache),
toDtypeOnHost(toDtypeOnHost) {}
};

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
Expand All @@ -57,6 +64,8 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) {
os << "\t"
<< "readUpdateIndexFromDeviceForKVCache: "
<< env.readUpdateIndexFromDeviceForKVCache << "\n";
os << "\t"
<< "toDtypeOnHost: " << env.toDtypeOnHost << "\n";
os << "}";
return os;
}
Expand Down
5 changes: 3 additions & 2 deletions runtime/lib/common/workarounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache) {
bool readUpdateIndexFromDeviceForKVCache,
bool toDtypeOnHost) {
static const Env config(maxpool2dPreshard, swapBinaryOperands,
readUpdateIndexFromDeviceForKVCache);
readUpdateIndexFromDeviceForKVCache, toDtypeOnHost);
return config;
}
#endif
Expand Down
11 changes: 10 additions & 1 deletion runtime/lib/ttnn/operations/layout/typecast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

#include "operations/layout/typecast.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/detail/workarounds.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/operations/core/core.hpp"

namespace tt::runtime::ttnn::operations::layout {

Expand All @@ -16,7 +18,14 @@ void run(const ::tt::target::ttnn::TypecastOp *op, ProgramContext &context) {
::ttnn::DataType targetDataType =
::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype());

::ttnn::Tensor out = ::ttnn::typecast(inputTensor, targetDataType);
::ttnn::Tensor out;
if (workaround::Env::get().toDtypeOnHost &&
::tt::runtime::ttnn::utils::isOnHost(inputTensor.storage_type())) {
out = ::ttnn::to_dtype(inputTensor, targetDataType);
} else {
out = ::ttnn::typecast(inputTensor, targetDataType);
}

tensorPool.insert_or_assign(op->out()->global_id(), out);
}

Expand Down
8 changes: 8 additions & 0 deletions runtime/tools/python/ttrt/common/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ def initialize_api():
choices=[True, False],
help="disable read update index for kv cache workaround",
)
Run.register_arg(
name="--disable-to-dtype-on-host",
type=bool,
default=False,
choices=[True, False],
help="disable to_dtype on host workaround",
)
Run.register_arg(
name="--result-file",
type=str,
Expand Down Expand Up @@ -379,6 +386,7 @@ def _execute(binaries):
not self["--disable-maxpool2d-preshard"],
not self["--disable-swap-binary-operands"],
not self["--disable-read-update-index-for-kv-cache"],
not self["--disable-to-dtype-on-host"],
)
self.logging.debug(f"setting tt runtime workaround env={workaround_env}")
self.logging.debug(f"setting torch manual seed={self['--seed']}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
// CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]])
// Check that the input operand is transformed into the row major layout.
// CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<u32>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>
// CHECK-SAME: -> tensor<32x32xf32
// CHECK-SAME: -> tensor<32x32xi32
// Check that the data type of the weight operand is transformed in bf16.
// CHECK-NEXT: %[[TO_LAYOUT_WEIGHTS:.*]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,41 @@
#dram = #ttnn.buffer_type<dram>
#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#system_memory = #ttnn.buffer_type<system_memory>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<16384x32xbf16, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<4096x32xbf16, #system_memory>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<512x1x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<128x1x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<16384x32xf32, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<4096x32xf32, #system_memory>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<512x1x!tt.tile<32x32, f32>, #dram>, <interleaved>>
#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<128x1x!tt.tile<32x32, f32>, #dram>, <interleaved>>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func @forward(%arg0: tensor<1x128x128x32xbf16, #ttnn_layout>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> {
func.func @forward(%arg0: tensor<1x128x128x32xf32, #ttnn_layout>) -> tensor<1x64x64x32xf32, #ttnn_layout1> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
// CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"[[C:.*]]
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<512x1>>, <interleaved>>}> : (tensor<1x128x128x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xbf16, #ttnn_layout2>
%2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xbf16, #ttnn_layout2>) -> tensor<1x1x16384x32xbf16, #ttnn_layout2>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<128x1>>, <interleaved>>, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<512x1>>, <interleaved>>}> : (tensor<1x128x128x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xf32, #ttnn_layout2>
%2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xf32, #ttnn_layout2>) -> tensor<1x1x16384x32xf32, #ttnn_layout2>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<128x1>>, <interleaved>>, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
// CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]])
// Check that the input operand is transformed into the row major layout.
// CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16384x32>>, <interleaved>>
// CHECK-SAME: -> tensor<1x1x16384x32xbf16,
// Check that the output operand is transformed into the row major layout.
// CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]])
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<4096x32>>, <interleaved>>
// CHECK-SAME: -> tensor<1x1x4096x32xbf16,
%4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xbf16, #ttnn_layout2>, tensor<1x1x4096x32xbf16, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3>
%4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, tensor<1x1x4096x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
// CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_OUTPUT_DPS]], %[[DEVICE_OP]])
// Check that the output operand is transformed back into the tile and f32 data type.
// CHECK-NEXT: %[[TO_LAYOUT_OUTPUT:.*]] = "ttnn.to_layout"(%[[MAX_POOL_2D_OP]], %[[DEVICE_OP]])
// CHECK-SAME: dtype = #tt.supportedDataTypes<f32>
// CHECK-SAME: layout = #ttnn.layout<tile>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<128x1>>, <interleaved>>
// CHECK-SAME: -> tensor<1x1x4096x32xbf16
%5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout3>
// CHECK-SAME: -> tensor<1x1x4096x32xf32
%5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout3>
// CHECK-NEXT: ttnn.reshape
%6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout1>
return %6 : tensor<1x64x64x32xbf16, #ttnn_layout1>
%6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout1>
return %6 : tensor<1x64x64x32xf32, #ttnn_layout1>
}
}
8 changes: 4 additions & 4 deletions test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
module attributes {} {
func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> {
%0 = tensor.empty() : tensor<1x64x64x32xbf16>
func.func @forward(%arg0: tensor<1x128x128x32xf32>) -> tensor<1x64x64x32xf32> {
%0 = tensor.empty() : tensor<1x64x64x32xf32>
// CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
%1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16>
return %1 : tensor<1x64x64x32xbf16>
%1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xf32>, tensor<1x64x64x32xf32>) -> tensor<1x64x64x32xf32>
return %1 : tensor<1x64x64x32xf32>
}
}
8 changes: 4 additions & 4 deletions test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x32x128xbf16>
%0 = tensor.empty() : tensor<32x32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
return %1 : tensor<32x32x128xbf16>
%1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32>
return %1 : tensor<32x32x128xf32>
}
}
8 changes: 4 additions & 4 deletions test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> {
%0 = tensor.empty() : tensor<1x32x64x64xbf16>
func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x64x64xf32> {
%0 = tensor.empty() : tensor<1x32x64x64xf32>
// CHECK: "ttnn.permute"
// CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
// CHECK: "ttnn.max_pool2d"
Expand All @@ -16,7 +16,7 @@ module attributes {} {
window_strides = array<i64: 1, 1, 2, 2>,
base_dilations = array<i64: 1, 1, 1, 1>,
window_dilations = array<i64: 1, 1, 1, 1>,
padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16>
return %1 : tensor<1x32x64x64xbf16>
padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>}> : (tensor<1x32x128x128xf32>, tensor<1x32x64x64xf32>) -> tensor<1x32x64x64xf32>
return %1 : tensor<1x32x64x64xf32>
}
}

0 comments on commit 5ba32ac

Please sign in to comment.