Implement conversion for stablehlo.select and add Where Op (#852)
* Added conversion of stablehlo.select (SelectOp) to TTIR WhereOp, along with end-to-end support through the TTNN dialect, flatbuffer serialization, and the TTNN runtime (see the semantics sketch below the change summary).

---------

Co-authored-by: Stefan Djordjevic <[email protected]>
Co-authored-by: Milan Topalovic <[email protected]>
Co-authored-by: Nikola Obradovic <[email protected]>
Co-authored-by: Filip Bajraktari <[email protected]>
5 people authored Nov 12, 2024
1 parent 14cd5d0 commit 430b036
Showing 18 changed files with 218 additions and 2 deletions.
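
For orientation, the where ops introduced below all implement the standard element-wise select semantics: each output element is taken from the second operand where the predicate is non-zero, and from the third operand otherwise. The following is a minimal host-side C++ sketch of that behavior, for illustration only; the actual computation is dispatched to ::ttnn::where on device, and the function name here is made up.

#include <cstddef>
#include <vector>

// Element-wise where: out[i] = (cond[i] != 0) ? a[i] : b[i].
// Host-side illustration only; the runtime dispatches ::ttnn::where on device.
std::vector<float> elementwiseWhere(const std::vector<float> &cond,
                                    const std::vector<float> &a,
                                    const std::vector<float> &b) {
  std::vector<float> out(cond.size());
  for (std::size_t i = 0; i < cond.size(); ++i) {
    out[i] = (cond[i] != 0.0f) ? a[i] : b[i];
  }
  return out;
}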
23 changes: 23 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -186,6 +186,29 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
  let results = (outs Variadic<AnyRankedTensor>:$results);
}

class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
    TTIR_ElementwiseOp<mnemonic, traits> {
  let summary = "Eltwise ternary op.";
  let description = [{
    Eltwise ternary op.
  }];

  let builders =
  [
    OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out, "ArrayAttr": $operand_constraints),
    [{
      build($_builder, $_state, {out.getType()}, {first, second, third}, out, operand_constraints);
    }]>
  ];
}

def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
  let summary = "Eltwise where op.";
  let description = [{
    Eltwise where operation.
  }];
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
    TTIR_ElementwiseOp<mnemonic, traits> {
  let summary = "Eltwise unary op.";
23 changes: 23 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -146,6 +146,29 @@ class TTNN_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
  ];
}

class TTNN_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
    TTNN_ElementwiseOp<mnemonic, traits> {
  let summary = "Eltwise ternary op.";
  let description = [{
    Eltwise ternary op.
  }];

  let builders =
  [
    OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out),
    [{
      build($_builder, $_state, {out.getType()}, {first, second, third}, out);
    }]>
  ];
}

def TTNN_WhereOp : TTNN_ElementwiseTernaryOp<"where"> {
  let summary = "Eltwise where.";
  let description = [{
    Eltwise where operation.
  }];
}

def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
  let summary = "Eltwise absolute.";
  let description = [{
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -97,6 +97,7 @@ enum EltwiseOpType: uint32 {
  Remainder = 32,
  IsFinite = 33,
  Floor = 34,
  Where = 35,
}

union EltwiseOpParams {
2 changes: 2 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1027,6 +1027,8 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
                                                                    ctx);
  patterns.add<StableHLOToTTIROpDefaultConversionPattern<
      mlir::stablehlo::RemOp, mlir::tt::ttir::RemainderOp>>(typeConverter, ctx);
  patterns.add<StableHLOToTTIROpDefaultConversionPattern<
      mlir::stablehlo::SelectOp, mlir::tt::ttir::WhereOp>>(typeConverter, ctx);
}

void addReduceOpsConversionPatterns(MLIRContext *ctx,
1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -922,6 +922,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
           ElementwiseOpConversionPattern<ttir::CosOp, ttnn::CosOp>,
           ElementwiseOpConversionPattern<ttir::Expm1Op, ttnn::Expm1Op>,
           ElementwiseOpConversionPattern<ttir::RemainderOp, ttnn::RemainderOp>,
           ElementwiseOpConversionPattern<ttir::WhereOp, ttnn::WhereOp>,
           ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
           ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
           ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
4 changes: 2 additions & 2 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -679,8 +679,8 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
  // Other ops
  //
  patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>,
-              DefaultOpConversionPattern<ttnn::EmbeddingOp>>(typeConverter,
-                                                             ctx);
+              DefaultOpConversionPattern<ttnn::EmbeddingOp>,
+              DefaultOpConversionPattern<ttnn::WhereOp>>(typeConverter, ctx);

  // CCL ops
  //
5 changes: 5 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -381,6 +381,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
    type = ::tt::target::ttnn::EltwiseOpType::Expm1;
  } else if constexpr (std::is_same_v<EltwiseOp, RemainderOp>) {
    type = ::tt::target::ttnn::EltwiseOpType::Remainder;
  } else if constexpr (std::is_same_v<EltwiseOp, WhereOp>) {
    type = ::tt::target::ttnn::EltwiseOpType::Where;
  } else {
    llvm_unreachable("unhandled EltwiseOp");
  }
@@ -719,6 +721,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
  if (auto sinOp = dyn_cast<SinOp>(op); sinOp) {
    return createOperation(cache, createEltwiseOp(cache, sinOp), debugString);
  }
  if (auto whereOp = dyn_cast<WhereOp>(op); whereOp) {
    return createOperation(cache, createEltwiseOp(cache, whereOp), debugString);
  }

  llvm_unreachable("unhandled op in emitTTNNOperation");
}
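
The createEltwiseOp change above relies on the if constexpr + std::is_same_v idiom to map the compile-time op type to a flatbuffer enum value. Below is a minimal standalone sketch of that idiom; the names OpA, OpB, Kind, and kindOf are made up for illustration and are not part of the tt-mlir API.

#include <type_traits>

enum class Kind { A, B, Unknown };
struct OpA {};
struct OpB {};

// Map a compile-time type to an enum value, mirroring the dispatch in
// createEltwiseOp above; unhandled types fall through to Unknown.
template <typename Op>
constexpr Kind kindOf() {
  if constexpr (std::is_same_v<Op, OpA>) {
    return Kind::A;
  } else if constexpr (std::is_same_v<Op, OpB>) {
    return Kind::B;
  } else {
    return Kind::Unknown;
  }
}

static_assert(kindOf<OpA>() == Kind::A);
static_assert(kindOf<OpB>() == Kind::B);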
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -56,6 +56,7 @@
#include "ttnn/operations/data_movement/permute/permute.hpp"
#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
#include "ttnn/operations/eltwise/ternary/where.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/embedding/embedding.hpp"
#include "ttnn/operations/matmul/matmul.hpp"
2 changes: 2 additions & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
@@ -2,6 +2,7 @@ set(TTNN_OPS_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/binary/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/unary/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
@@ -15,6 +16,7 @@ set(TTNN_OPS_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/eltwise/ternary/ternary.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp
32 changes: 32 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/ternary/ternary.cpp
@@ -0,0 +1,32 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ternary.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::ternary {

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
  if (op->type() != ::tt::target::ttnn::EltwiseOpType::Where) {
    throw std::invalid_argument("Unsupported Eltwise Ternary operation");
  }

  ProgramTensorPool &tensorPool = context.getTensorPool();

  ::ttnn::Tensor *first = nullptr;
  ::ttnn::Tensor *second = nullptr;
  ::ttnn::Tensor *third = nullptr;
  getEltwiseTernaryOPInputTensors(op, tensorPool, &first, &second, &third);

  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
      utils::createMemoryConfig(op->out());

  ::ttnn::Tensor out =
      ::ttnn::where(*first, *second, *third, outputMemoryConfig);
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::ternary
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/ternary/ternary.h
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H
#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_TERNARY_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ternary {

inline bool isTernaryOp(const ::tt::target::ttnn::EltwiseOp *op) {
  return op->ins()->size() == 3;
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::ternary

#endif
24 changes: 24 additions & 0 deletions runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.cpp
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "utils.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/workarounds.h"

namespace tt::runtime::ttnn::operations::ternary {

void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
                                     ProgramTensorPool &tensorPool,
                                     ::ttnn::Tensor **first,
                                     ::ttnn::Tensor **second,
                                     ::ttnn::Tensor **third) {
  LOG_ASSERT(op->ins()->size() == 3, "Expected 3 inputs");
  *first = &(tensorPool.at(op->ins()->Get(0)->global_id()));
  *second = &(tensorPool.at(op->ins()->Get(1)->global_id()));
  *third = &(tensorPool.at(op->ins()->Get(2)->global_id()));
  DEBUG_ASSERT((*first)->is_allocated());
  DEBUG_ASSERT((*second)->is_allocated());
  DEBUG_ASSERT((*third)->is_allocated());
}

} // namespace tt::runtime::ttnn::operations::ternary
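
The helper above resolves each flatbuffer input by its global_id into the runtime's tensor pool and checks that the tensor is still allocated. The following is a minimal standalone sketch of that lookup pattern, assuming a plain std::unordered_map stands in for ProgramTensorPool and FakeTensor for ::ttnn::Tensor (both made-up names, not part of the runtime API).

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Stand-in for a runtime tensor; only what the sketch needs.
struct FakeTensor {
  std::vector<float> data;
  bool allocated = true;
  bool is_allocated() const { return allocated; }
};

// Stand-in tensor pool keyed by the flatbuffer global_id.
using FakeTensorPool = std::unordered_map<uint32_t, FakeTensor>;

// Resolve three input ids to pool entries and verify each tensor is
// still allocated before use, mirroring getEltwiseTernaryOPInputTensors.
void getTernaryInputs(const uint32_t ids[3], FakeTensorPool &pool,
                      FakeTensor **first, FakeTensor **second,
                      FakeTensor **third) {
  *first = &pool.at(ids[0]);
  *second = &pool.at(ids[1]);
  *third = &pool.at(ids[2]);
  assert((*first)->is_allocated());
  assert((*second)->is_allocated());
  assert((*third)->is_allocated());
}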
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/include/tt/runtime/ttnn/operations/eltwise/ternary/utils.h
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H
#define TTNN_RUNTIME_ELTWISE_TERNARY_UTILS_H

#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ternary {
void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
                                     ProgramTensorPool &tensorPool,
                                     ::ttnn::Tensor **first,
                                     ::ttnn::Tensor **second,
                                     ::ttnn::Tensor **third);

} // namespace tt::runtime::ttnn::operations::ternary

#endif
6 changes: 6 additions & 0 deletions runtime/lib/ttnn/program.cpp
@@ -13,6 +13,7 @@
#include "operations/deletion/dealloc.h"
#include "operations/eltwise/binary/binary.h"
#include "operations/eltwise/binary/binary_composite.h"
#include "operations/eltwise/ternary/ternary.h"
#include "operations/eltwise/unary/unary.h"
#include "operations/eltwise/unary/unary_composite.h"
#include "operations/embedding/embedding.h"
@@ -73,13 +74,18 @@ void ProgramExecutor::runEltwiseOperation(
    return operations::binary::run(op, context);
  };

  auto runTernaryOp = [&]() { return operations::ternary::run(op, context); };

  if (operations::unary::isUnaryOp(op)) {
    return runUnaryOp();
  }

  if (operations::binary::isBinaryOp(op)) {
    return runBinaryOp();
  }
  if (operations::ternary::isTernaryOp(op)) {
    return runTernaryOp();
  }

  throw std::invalid_argument("Unsupported Eltwise operation");
}
13 changes: 13 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir
@@ -0,0 +1,13 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module @jit_eltwise_select attributes {} {
  func.func public @test_select(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
    %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xi1>
    %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<13x37xi1>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
    // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty()
    // CHECK: %[[VAL1:[0-9]+]] = "ttir.eq"
    // CHECK: %[[SELECT:[0-9]+]] = "ttir.where"(%[[VAL1:[0-9]+]], %arg0, %arg1, %[[EMPTY:[0-9]+]]) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
    return %1 : tensor<13x37xf32>
  }
}
14 changes: 14 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_where.mlir
@@ -0,0 +1,14 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device_tile = #tt.operand_constraint<dram|l1|tile|none|interleaved|single_bank|height_sharded|width_sharded|block_sharded|any_layout|any_device_tile>
module @jit_eltwise_where {
  func.func public @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
    %0 = tensor.empty() : tensor<13x37xf32>
    %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
    %2 = tensor.empty() : tensor<13x37xf32>
    %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
    // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}}
    // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]])
    // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
    return %3 : tensor<13x37xf32>
  }
}
16 changes: 16 additions & 0 deletions test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir
@@ -0,0 +1,16 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
  %0 = tensor.empty() : tensor<13x37xbf16>
  %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16>
  %2 = tensor.empty() : tensor<13x37xf32>
  %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
  // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}}
  // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]])
  // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
  return %3 : tensor<13x37xf32>
}
11 changes: 11 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -258,3 +258,14 @@ func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> {
  return %0 : tensor<1xi32>
  // CHECK: return [[VAL]] : tensor<1xi32, {{.*}}>
}

func.func @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> {
  %0 = tensor.empty() : tensor<13x37xbf16>
  %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xbf16>) -> tensor<13x37xbf16>
  %2 = tensor.empty() : tensor<13x37xf32>
  %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array<i32: 3, 1>, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32>
  // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}}
  // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]])
  // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}})
  return %3 : tensor<13x37xf32>
}
