-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement conversion for stablehlo.select and add Where Op (#852)
* Added conversion for SelectOp to TTIR WhereOp along with end-to-end support. --------- Co-authored-by: Stefan Djordjevic <[email protected]> Co-authored-by: Milan Topalovic <[email protected]> Co-authored-by: Nikola Obradovic <[email protected]> Co-authored-by: Filip Bajraktari <[email protected]>
- Loading branch information
1 parent
14cd5d0
commit 430b036
Showing
18 changed files
with
218 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ternary.h" | ||
#include "tt/runtime/detail/logger.h" | ||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h" | ||
#include "tt/runtime/ttnn/operations/utils.h" | ||
|
||
namespace tt::runtime::ttnn::operations::ternary { | ||
|
||
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) { | ||
if (op->type() != ::tt::target::ttnn::EltwiseOpType::Where) { | ||
throw std::invalid_argument("Unsupported Eltwise Ternary operation"); | ||
} | ||
|
||
ProgramTensorPool &tensorPool = context.getTensorPool(); | ||
|
||
::ttnn::Tensor *first = nullptr; | ||
::ttnn::Tensor *second = nullptr; | ||
::ttnn::Tensor *third = nullptr; | ||
getEltwiseTernaryOPInputTensors(op, tensorPool, &first, &second, &third); | ||
|
||
::tt::tt_metal::MemoryConfig outputMemoryConfig = | ||
utils::createMemoryConfig(op->out()); | ||
|
||
::ttnn::Tensor out = | ||
::ttnn::where(*first, *second, *third, outputMemoryConfig); | ||
tensorPool.insert_or_assign(op->out()->global_id(), out); | ||
} | ||
} // namespace tt::runtime::ttnn::operations::ternary |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H | ||
#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H | ||
|
||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::ternary { | ||
|
||
inline bool isTernaryOp(const ::tt::target::ttnn::EltwiseOp *op) { | ||
return op->ins()->size() == 3; | ||
} | ||
|
||
void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context); | ||
|
||
} // namespace tt::runtime::ttnn::operations::ternary | ||
|
||
#endif |
24 changes: 24 additions & 0 deletions
24
runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#include "utils.h" | ||
#include "tt/runtime/detail/logger.h" | ||
#include "tt/runtime/detail/workarounds.h" | ||
|
||
namespace tt::runtime::ttnn::operations::ternary { | ||
|
||
void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, | ||
ProgramTensorPool &tensorPool, | ||
::ttnn::Tensor **first, | ||
::ttnn::Tensor **second, | ||
::ttnn::Tensor **third) { | ||
LOG_ASSERT(op->ins()->size() == 3, "Expected 3 inputs"); | ||
*first = &(tensorPool.at(op->ins()->Get(0)->global_id())); | ||
*second = &(tensorPool.at(op->ins()->Get(1)->global_id())); | ||
*third = &(tensorPool.at(op->ins()->Get(2)->global_id())); | ||
DEBUG_ASSERT((*first)->is_allocated()); | ||
DEBUG_ASSERT((*second)->is_allocated()); | ||
DEBUG_ASSERT((*third)->is_allocated()); | ||
} | ||
|
||
} // namespace tt::runtime::ttnn::operations::ternary |
21 changes: 21 additions & 0 deletions
21
runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H | ||
#define TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H | ||
|
||
#include "tt/runtime/detail/ttnn.h" | ||
#include "tt/runtime/ttnn/types.h" | ||
#include "ttmlir/Target/TTNN/program_generated.h" | ||
|
||
namespace tt::runtime::ttnn::operations::ternary { | ||
void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op, | ||
ProgramTensorPool &tensorPool, | ||
::ttnn::Tensor **first, | ||
::ttnn::Tensor **second, | ||
::ttnn::Tensor **third); | ||
|
||
} // namespace tt::runtime::ttnn::operations::ternary | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// REQUIRES: stablehlo | ||
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s | ||
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile> | ||
module @jit_eltwise_select attributes {} { | ||
func.func public @test_select(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { | ||
%0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xi1> | ||
%1 = stablehlo.select %0, %arg0, %arg1 : (tensor<13x37xi1>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> | ||
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() | ||
// CHECK: %[[VAL1:[0-9]+]] = "ttir.eq" | ||
// CHECK: %[[SELECT:[0-9]+]] = "ttir.where"(%[[VAL1:[0-9]+]], %arg0, %arg1, %[[EMPTY:[0-9]+]]) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> | ||
return %1 : tensor<13x37xf32> | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s | ||
#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile> | ||
module @jit_eltwise_where { | ||
func.func public @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { | ||
%0 = tensor.empty() : tensor<13x37xf32> | ||
%1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> | ||
%2 = tensor.empty() : tensor<13x37xf32> | ||
%3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> | ||
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} | ||
// CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) | ||
// CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) | ||
return %3 : tensor<13x37xf32> | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir | ||
// RUN: FileCheck %s --input-file=%t.mlir | ||
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn | ||
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile> | ||
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile> | ||
|
||
func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { | ||
%0 = tensor.empty() : tensor<13x37xbf16> | ||
%1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16> | ||
%2 = tensor.empty() : tensor<13x37xf32> | ||
%3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> | ||
// CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} | ||
// CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) | ||
// CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) | ||
return %3 : tensor<13x37xf32> | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters