Adding bf16 data type workaround for max_pool2d op and fixing embedding workaround tests #1657

Merged (1 commit) on Dec 26, 2024
11 changes: 2 additions & 9 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -90,13 +90,6 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::llvm::ArrayRef<int64_t> inputShape = getInput().getType().getShape();

if (!inputType.getElementType().isBF16()) {
return emitOpError()
<< "ttnn.max_pool2d currently only supports an input type of "
"bfloat16. Recieved "
<< inputType.getElementType() << ".";
}

if (getKernelHeight() > getInputHeight()) {
return emitOpError() << "Kernel height " << getKernelHeight()
<< " is greater than input height " << getInputHeight()
@@ -111,13 +104,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::MaxPool2dOp::verify() {

if (inputType.getRank() != 4) {
return emitOpError()
<< "Input tensor rank must be 4. Recieved input with rank "
<< "Input tensor rank must be 4. Received input with rank "
<< inputType.getRank() << ". Shape: (" << inputShape << ").";
}

if (inputShape[0] != 1 || inputShape[1] != 1) {
return emitOpError() << "Maxpool input must be in the form (1, 1, N*H*W, "
"C). Recieved shape ("
"C). Received shape ("
<< inputShape << ").";
}

20 changes: 12 additions & 8 deletions lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
@@ -83,14 +83,17 @@ TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(Operation *op) {
// tile layout, but the output of the operation is strictly in row-major layout.
// In order to keep the output consistent with the input, the row-major
// workaround is applied to both the input and output operands.
// The input and output operands are expected to use the bf16 data type, so the
// bf16 workaround is applied to both the input and output operands.
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds() {
wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround =
wa::TTNNOperandWorkarounds(Layout::RowMajor);
wa::TTNNOperandWorkarounds rowMajorLayoutBF16Workaround;
rowMajorLayoutBF16Workaround.tensorLayoutWorkaround = Layout::RowMajor;
rowMajorLayoutBF16Workaround.tensorDataTypeWorkaround = DataType::BFloat16;
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
.addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
.addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
.addOutputOperandWorkaround(rowMajorLayoutBF16Workaround);
}

// Factory method to create a set of workarounds for embedding operation
@@ -107,12 +110,13 @@ TTNNOperandsWorkaroundsFactory::createMaxPool2DOpOperandsWorkarounds()
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createEmbeddingOpOperandsWorkarounds() {
// Create input and weight workarounds.
TTNNOperandWorkarounds inputWorkaround =
TTNNOperandWorkarounds(Layout::RowMajor);
TTNNOperandWorkarounds inputRowMajorInt32Workaround;
inputRowMajorInt32Workaround.tensorLayoutWorkaround = Layout::RowMajor;
inputRowMajorInt32Workaround.tensorDataTypeWorkaround = DataType::UInt32;
TTNNOperandWorkarounds weightWorkaround =
TTNNOperandWorkarounds(DataType::BFloat16);
return TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds(0, 0)
.addInputOperandWorkaround(inputWorkaround)
.addInputOperandWorkaround(inputRowMajorInt32Workaround)
.addInputOperandWorkaround(weightWorkaround)
.addInputOperandWorkaround(weightWorkaround)
.addOutputOperandWorkaround(weightWorkaround);
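Both factories above use the same pattern for combining operand constraints. A minimal sketch based only on the constructors and fields visible in this diff (variable names are illustrative and namespace qualifiers are omitted):

```cpp
// Single-constraint workarounds can use the convenience constructors.
TTNNOperandWorkarounds rowMajorOnly = TTNNOperandWorkarounds(Layout::RowMajor);
TTNNOperandWorkarounds bf16Only = TTNNOperandWorkarounds(DataType::BFloat16);

// A workaround that constrains both layout and data type sets the optional
// fields on a default-constructed instance, as the two factories above now do.
TTNNOperandWorkarounds rowMajorBF16;
rowMajorBF16.tensorLayoutWorkaround = Layout::RowMajor;
rowMajorBF16.tensorDataTypeWorkaround = DataType::BFloat16;
```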
17 changes: 13 additions & 4 deletions runtime/include/tt/runtime/detail/workarounds.h
@@ -16,12 +16,13 @@ struct Env {
constexpr static Env
#endif
get(bool maxpool2dPreshard = true, bool swapBinaryOperands = true,
bool readUpdateIndexFromDeviceForKVCache = true)
bool readUpdateIndexFromDeviceForKVCache = true,
bool toDtypeOnHost = true)
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
;
#else
{
return Env(true, true, true);
return Env(true, true, true, true);
}
#endif
// TODO(bug #855): Ideally we should have an op that preshards for maxpool2d
@@ -39,13 +40,19 @@ struct Env {
// to be able to pluck this update index from a runtime tensor.
bool readUpdateIndexFromDeviceForKVCache;

// TODO(bug #1658): We currently use the ttnn::to_dtype operation to cast the
// data type of a tensor on host. Once we have improved the typecast operation
// to handle this, we should remove this workaround.
bool toDtypeOnHost;

private:
constexpr Env(bool maxpool2dPreshard, bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache)
bool readUpdateIndexFromDeviceForKVCache, bool toDtypeOnHost)
: maxpool2dPreshard(maxpool2dPreshard),
swapBinaryOperands(swapBinaryOperands),
readUpdateIndexFromDeviceForKVCache(
readUpdateIndexFromDeviceForKVCache) {}
readUpdateIndexFromDeviceForKVCache),
toDtypeOnHost(toDtypeOnHost) {}
};

inline std::ostream &operator<<(std::ostream &os, const Env &env) {
@@ -57,6 +64,8 @@ inline std::ostream &operator<<(std::ostream &os, const Env &env) {
os << "\t"
<< "readUpdateIndexFromDeviceForKVCache: "
<< env.readUpdateIndexFromDeviceForKVCache << "\n";
os << "\t"
<< "toDtypeOnHost: " << env.toDtypeOnHost << "\n";
os << "}";
return os;
}
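Runtime code gates behavior on this flag before picking a cast path. The snippet below is a condensed excerpt of the typecast.cpp change further down in this diff, reproduced only to show how toDtypeOnHost is consumed:

```cpp
// While the workaround is enabled, host-resident tensors are cast with
// ttnn::to_dtype; device tensors keep the original ttnn::typecast path.
::ttnn::Tensor out;
if (workaround::Env::get().toDtypeOnHost &&
    ::tt::runtime::ttnn::utils::isOnHost(inputTensor.storage_type())) {
  out = ::ttnn::to_dtype(inputTensor, targetDataType);
} else {
  out = ::ttnn::typecast(inputTensor, targetDataType);
}
```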
5 changes: 3 additions & 2 deletions runtime/lib/common/workarounds.cpp
@@ -7,9 +7,10 @@
namespace tt::runtime::workaround {
#if defined(TT_RUNTIME_WORKAROUNDS) && TT_RUNTIME_WORKAROUNDS == 1
const Env &Env::get(bool maxpool2dPreshard, bool swapBinaryOperands,
bool readUpdateIndexFromDeviceForKVCache) {
bool readUpdateIndexFromDeviceForKVCache,
bool toDtypeOnHost) {
static const Env config(maxpool2dPreshard, swapBinaryOperands,
readUpdateIndexFromDeviceForKVCache);
readUpdateIndexFromDeviceForKVCache, toDtypeOnHost);
return config;
}
#endif
11 changes: 10 additions & 1 deletion runtime/lib/ttnn/operations/layout/typecast.cpp
@@ -4,8 +4,10 @@

#include "operations/layout/typecast.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/detail/workarounds.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "tt/runtime/ttnn/utils.h"
#include "ttnn/operations/core/core.hpp"

namespace tt::runtime::ttnn::operations::layout {

@@ -16,7 +18,14 @@ void run(const ::tt::target::ttnn::TypecastOp *op, ProgramContext &context) {
::ttnn::DataType targetDataType =
::tt::runtime::ttnn::utils::toTTNNDataType(op->dtype());

::ttnn::Tensor out = ::ttnn::typecast(inputTensor, targetDataType);
::ttnn::Tensor out;
if (workaround::Env::get().toDtypeOnHost &&
::tt::runtime::ttnn::utils::isOnHost(inputTensor.storage_type())) {
out = ::ttnn::to_dtype(inputTensor, targetDataType);
} else {
out = ::ttnn::typecast(inputTensor, targetDataType);
}

tensorPool.insert_or_assign(op->out()->global_id(), out);
}

8 changes: 8 additions & 0 deletions runtime/tools/python/ttrt/common/run.py
@@ -140,6 +140,13 @@ def initialize_api():
choices=[True, False],
help="disable read update index for kv cache workaround",
)
Run.register_arg(
name="--disable-to-dtype-on-host",
type=bool,
default=False,
choices=[True, False],
help="disable to_dtype on host workaround",
)
Run.register_arg(
name="--result-file",
type=str,
@@ -379,6 +386,7 @@ def _execute(binaries):
not self["--disable-maxpool2d-preshard"],
not self["--disable-swap-binary-operands"],
not self["--disable-read-update-index-for-kv-cache"],
not self["--disable-to-dtype-on-host"],
)
self.logging.debug(f"setting tt runtime workaround env={workaround_env}")
self.logging.debug(f"setting torch manual seed={self['--seed']}")
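For reference, the new workaround is enabled by default and can be switched off from the ttrt CLI like the other workaround flags, e.g. something like `ttrt run out.ttnn --disable-to-dtype-on-host True`, where `out.ttnn` is a placeholder flatbuffer path and the exact invocation follows the bool/choices flag pattern registered above.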
@@ -19,9 +19,10 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
// CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]])
// Check that the input operand is transformed into the row major layout.
// CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<u32>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>
// CHECK-SAME: -> tensor<32x32xf32
// CHECK-SAME: -> tensor<32x32xi32
// Check that the data type of the weight operand is transformed in bf16.
// CHECK-NEXT: %[[TO_LAYOUT_WEIGHTS:.*]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
@@ -3,37 +3,41 @@
#dram = #ttnn.buffer_type<dram>
#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#system_memory = #ttnn.buffer_type<system_memory>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<16384x32xbf16, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<4096x32xbf16, #system_memory>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<512x1x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<128x1x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<16384x32xf32, #system_memory>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<4096x32xf32, #system_memory>>
#ttnn_layout2 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 16384 + d1 * 128 + d2, d3), <1x1>, memref<512x1x!tt.tile<32x32, f32>, #dram>, <interleaved>>
#ttnn_layout3 = #ttnn.ttnn_layout<(d0, d1, d2, d3) -> (d0 * 4096 + d1 * 64 + d2, d3), <1x1>, memref<128x1x!tt.tile<32x32, f32>, #dram>, <interleaved>>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func @forward(%arg0: tensor<1x128x128x32xbf16, #ttnn_layout>) -> tensor<1x64x64x32xbf16, #ttnn_layout1> {
func.func @forward(%arg0: tensor<1x128x128x32xf32, #ttnn_layout>) -> tensor<1x64x64x32xf32, #ttnn_layout1> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
// CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"[[C:.*]]
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<512x1>>, <interleaved>>}> : (tensor<1x128x128x32xbf16, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xbf16, #ttnn_layout2>
%2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xbf16, #ttnn_layout2>) -> tensor<1x1x16384x32xbf16, #ttnn_layout2>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<128x1>>, <interleaved>>, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<512x1>>, <interleaved>>}> : (tensor<1x128x128x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1x128x128x32xf32, #ttnn_layout2>
%2 = "ttnn.reshape"(%1) <{shape = [1 : i32, 1 : i32, 16384 : i32, 32 : i32]}> : (tensor<1x128x128x32xf32, #ttnn_layout2>) -> tensor<1x1x16384x32xf32, #ttnn_layout2>
%3 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<128x1>>, <interleaved>>, shape = #ttnn.shape<1x1x4096x32>}> : (!tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
// CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]])
// Check that the input operand is transformed into the row major layout.
// CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<16384x32>>, <interleaved>>
// CHECK-SAME: -> tensor<1x1x16384x32xbf16,
// Check that the output operand is transformed into the row major layout.
// CHECK-NEXT: %[[TO_LAYOUT_OUTPUT_DPS:.*]] = "ttnn.to_layout"(%[[EMPTY_OP]], %[[DEVICE_OP]])
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<4096x32>>, <interleaved>>
// CHECK-SAME: -> tensor<1x1x4096x32xbf16,
%4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xbf16, #ttnn_layout2>, tensor<1x1x4096x32xbf16, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xbf16, #ttnn_layout3>
%4 = "ttnn.max_pool2d"(%2, %3, %0) <{batch_size = 1 : si32, ceil_mode = false, channels = 32 : si32, dilation_height = 1 : si32, dilation_width = 1 : si32, input_height = 128 : si32, input_width = 128 : si32, kernel_height = 2 : si32, kernel_width = 2 : si32, padding_height = 0 : si32, padding_width = 0 : si32, stride_height = 2 : si32, stride_width = 2 : si32}> : (tensor<1x1x16384x32xf32, #ttnn_layout2>, tensor<1x1x4096x32xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<1x1x4096x32xf32, #ttnn_layout3>
// CHECK-NEXT: %[[MAX_POOL_2D_OP:.*]] = "ttnn.max_pool2d"(%[[TO_LAYOUT_INPUT]], %[[TO_LAYOUT_OUTPUT_DPS]], %[[DEVICE_OP]])
// Check that the output operand is transformed back into the tile and f32 data type.
// CHECK-NEXT: %[[TO_LAYOUT_OUTPUT:.*]] = "ttnn.to_layout"(%[[MAX_POOL_2D_OP]], %[[DEVICE_OP]])
// CHECK-SAME: dtype = #tt.supportedDataTypes<f32>
// CHECK-SAME: layout = #ttnn.layout<tile>
// CHECK-SAME: memory_config = #ttnn.memory_config<#dram, <<128x1>>, <interleaved>>
// CHECK-SAME: -> tensor<1x1x4096x32xbf16
%5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout3>
// CHECK-SAME: -> tensor<1x1x4096x32xf32
%5 = "ttnn.reshape"(%4) <{shape = [1 : i32, 64 : i32, 64 : i32, 32 : i32]}> : (tensor<1x1x4096x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout3>
// CHECK-NEXT: ttnn.reshape
%6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xbf16, #ttnn_layout3>) -> tensor<1x64x64x32xbf16, #ttnn_layout1>
return %6 : tensor<1x64x64x32xbf16, #ttnn_layout1>
%6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<4096x32>>>}> : (tensor<1x64x64x32xf32, #ttnn_layout3>) -> tensor<1x64x64x32xf32, #ttnn_layout1>
return %6 : tensor<1x64x64x32xf32, #ttnn_layout1>
}
}
8 changes: 4 additions & 4 deletions test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir
@@ -1,9 +1,9 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
module attributes {} {
func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> {
%0 = tensor.empty() : tensor<1x64x64x32xbf16>
func.func @forward(%arg0: tensor<1x128x128x32xf32>) -> tensor<1x64x64x32xf32> {
%0 = tensor.empty() : tensor<1x64x64x32xf32>
// CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
%1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16>
return %1 : tensor<1x64x64x32xbf16>
%1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xf32>, tensor<1x64x64x32xf32>) -> tensor<1x64x64x32xf32>
return %1 : tensor<1x64x64x32xf32>
}
}
8 changes: 4 additions & 4 deletions test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir
@@ -2,11 +2,11 @@
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<32x32x128xbf16>
%0 = tensor.empty() : tensor<32x32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
%1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
return %1 : tensor<32x32x128xbf16>
%1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xf32>, tensor<512x128xf32>, tensor<32x32x128xf32>) -> tensor<32x32x128xf32>
return %1 : tensor<32x32x128xf32>
}
}
8 changes: 4 additions & 4 deletions test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir
@@ -2,8 +2,8 @@
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
module attributes {} {
func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> {
%0 = tensor.empty() : tensor<1x32x64x64xbf16>
func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x64x64xf32> {
%0 = tensor.empty() : tensor<1x32x64x64xf32>
// CHECK: "ttnn.permute"
// CHECK-SAME: permutation = array<i64: 0, 2, 3, 1>
// CHECK: "ttnn.max_pool2d"
@@ -16,7 +16,7 @@ module attributes {} {
window_strides = array<i64: 1, 1, 2, 2>,
base_dilations = array<i64: 1, 1, 1, 1>,
window_dilations = array<i64: 1, 1, 1, 1>,
padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16>
return %1 : tensor<1x32x64x64xbf16>
padding = array<i64: 0, 0, 0, 0, 0, 0, 0, 0>}> : (tensor<1x32x128x128xf32>, tensor<1x32x64x64xf32>) -> tensor<1x32x64x64xf32>
return %1 : tensor<1x32x64x64xf32>
}
}