Skip to content

Commit

Permalink
Refactor to create eltwise ternary op starting with whereOp.
Browse files Browse the repository at this point in the history
  • Loading branch information
uazizTT committed Nov 12, 2024
1 parent 8375558 commit d8fce32
Show file tree
Hide file tree
Showing 17 changed files with 143 additions and 151 deletions.
40 changes: 22 additions & 18 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,39 +167,43 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
let results = (outs Variadic<AnyRankedTensor>:$results);
}

class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise unary op.";
let summary = "Eltwise ternary op.";
let description = [{
Eltwise unary op.
Eltwise ternary op.
}];

let builders =
[
OpBuilder<(ins "Value": $in, "Value": $out, "ArrayAttr": $operand_constraints),
OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out, "ArrayAttr": $operand_constraints),
[{
build($_builder, $_state, {out.getType()}, in, out, operand_constraints);
build($_builder, $_state, {out.getType()}, {first, second, third}, out, operand_constraints);
}]>
];
}

def TTIR_WhereOp : TTIR_DPSOp<"where"> {
let summary = "Where operation.";
def TTIR_WhereOp: TTIR_ElementwiseTernaryOp<"where"> {
let summary = "Eltwise where op.";
let description = [{
Selects an element from on_true or on_false based on pred.
Eltwise where operation.
}];
}

let arguments = (ins AnyRankedTensor:$pred,
AnyRankedTensor:$on_true,
AnyRankedTensor:$on_false,
AnyRankedTensor:$output,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
TTIR_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise unary op.";
let description = [{
Eltwise unary op.
}];

let builders =
[
OpBuilder<(ins "Value": $in, "Value": $out, "ArrayAttr": $operand_constraints),
[{
build($_builder, $_state, {out.getType()}, in, out, operand_constraints);
}]>
];
}

def TTIR_AbsOp: TTIR_ElementwiseUnaryOp<"abs"> {
Expand Down
25 changes: 17 additions & 8 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -146,18 +146,27 @@ class TTNN_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
];
}

def TTNN_WhereOp : TTNN_NamedDPSOp<"where"> {
let summary = "Where op.";
class TTNN_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :
TTNN_ElementwiseOp<mnemonic, traits> {
let summary = "Eltwise binary op.";
let description = [{
Selects an element from on_true or on_false based on pred.
Eltwise binary op.
}];

let arguments = (ins AnyRankedTensor:$pred,
AnyRankedTensor:$on_true,
AnyRankedTensor:$on_false,
AnyRankedTensor:$outputs);
let builders =
[
OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out),
[{
build($_builder, $_state, {out.getType()}, {first, second, third}, out);
}]>
];
}

let results = (outs AnyRankedTensor:$result);
def TTNN_WhereOp : TTNN_ElementwiseTernaryOp<"where"> {
let summary = "Eltwise where.";
let description = [{
Eltwise where operation.
}];
}

def TTNN_AbsOp : TTNN_ElementwiseUnaryOp<"abs"> {
Expand Down
9 changes: 1 addition & 8 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ enum EltwiseOpType: uint32 {
Remainder = 32,
IsFinite = 33,
Floor = 34,
Where = 35,
}

union EltwiseOpParams {
Expand Down Expand Up @@ -120,13 +121,6 @@ table ReductionOp {
keep_dim: bool;
}

table WhereOp {
pred: tt.target.TensorRef;
on_true: tt.target.TensorRef;
on_false: tt.target.TensorRef;
output: tt.target.TensorRef;
}

table EmbeddingOp {
input: tt.target.TensorRef;
weight: tt.target.TensorRef;
Expand Down Expand Up @@ -238,7 +232,6 @@ union OpType {
EltwiseOp,
MatmulOp,
ReductionOp,
WhereOp,
EmbeddingOp,
SoftmaxOp,
TransposeOp,
Expand Down
36 changes: 2 additions & 34 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,33 +734,6 @@ class StableHLOToTTIRCompareOpConversionPattern
}
};

class StableHLOToTTIRSelectOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::SelectOp> {
using OpConversionPattern<mlir::stablehlo::SelectOp>::OpConversionPattern;

public:
LogicalResult
matchAndRewrite(mlir::stablehlo::SelectOp srcOp,
mlir::stablehlo::SelectOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

rewriter.replaceOpWithNewOp<::mlir::tt::ttir::WhereOp>(
srcOp, outputType, adaptor.getPred(), adaptor.getOnTrue(),
adaptor.getOnFalse(), Value(outputTensor),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));

return success();
}
};

class StableHLOToTTIRConcatOpConversionPattern
: public OpConversionPattern<mlir::stablehlo::ConcatenateOp> {

Expand Down Expand Up @@ -978,12 +951,8 @@ void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx,
ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::RemOp, mlir::tt::ttir::RemainderOp>>(typeConverter, ctx);
}

void addSelectOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
TypeConverter &typeConverter) {

patterns.add<StableHLOToTTIRSelectOpConversionPattern>(typeConverter, ctx);
patterns.add<StableHLOToTTIROpDefaultConversionPattern<
mlir::stablehlo::SelectOp, mlir::tt::ttir::WhereOp>>(typeConverter, ctx);
}

void addReduceOpsConversionPatterns(MLIRContext *ctx,
Expand Down Expand Up @@ -1090,7 +1059,6 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addReshapeOpConversionPattern(ctx, patterns, typeConverter);
addLogicalOpConversionPattern(ctx, patterns, typeConverter);
addSliceOpConversionPattern(ctx, patterns, typeConverter);
addSelectOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
19 changes: 1 addition & 18 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -826,23 +826,6 @@ class SubtractOpConversionPattern
}
};

template <typename TTIROpTy, typename TTNNOpTy,
typename OpAdaptor = typename TTIROpTy::Adaptor>
class SelectOpConversionPattern : public OpConversionPattern<TTIROpTy> {
public:
using OpConversionPattern<TTIROpTy>::OpConversionPattern;

LogicalResult
matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<TTNNOpTy>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getPred(), adaptor.getOnTrue(), adaptor.getOnFalse(),
adaptor.getOutput());
return success();
}
};

class AllGatherOpConversionPattern
: public OpConversionPattern<ttir::AllGatherOp> {
public:
Expand Down Expand Up @@ -906,10 +889,10 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ElementwiseOpConversionPattern<ttir::CosOp, ttnn::CosOp>,
ElementwiseOpConversionPattern<ttir::Expm1Op, ttnn::Expm1Op>,
ElementwiseOpConversionPattern<ttir::RemainderOp, ttnn::RemainderOp>,
ElementwiseOpConversionPattern<ttir::WhereOp, ttnn::WhereOp>,
ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
SelectOpConversionPattern<ttir::WhereOp, ttnn::WhereOp>,
BroadcastOpConversionPattern,
EmbeddingOpConversionPattern,
SoftmaxOpConversionPattern,
Expand Down
20 changes: 3 additions & 17 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
type = ::tt::target::ttnn::EltwiseOpType::Expm1;
} else if constexpr (std::is_same_v<EltwiseOp, RemainderOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Remainder;
} else if constexpr (std::is_same_v<EltwiseOp, WhereOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Where;
} else {
llvm_unreachable("unhandled EltwiseOp");
}
Expand Down Expand Up @@ -419,22 +421,6 @@ createTransposeOp(FlatbufferObjectCache &cache, TransposeOp op) {
return ::tt::target::ttnn::CreateTransposeOp(*cache.fbb, in, out, dim0, dim1);
}

template <typename WhereOp>
::flatbuffers::Offset<::tt::target::ttnn::WhereOp>
createWhereOp(FlatbufferObjectCache &cache, WhereOp op) {
auto out = cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getResult()));
auto pred =
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getPred()));
auto ontrue = cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getOnTrue()));
auto onfalse = cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getOnFalse()));

return ::tt::target::ttnn::CreateWhereOp(*cache.fbb, pred, ontrue, onfalse,
out);
}

template <typename ConcatOp>
::flatbuffers::Offset<::tt::target::ttnn::ConcatOp>
createConcatOp(FlatbufferObjectCache &cache, ConcatOp op) {
Expand Down Expand Up @@ -720,7 +706,7 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createEltwiseOp(cache, sinOp), debugString);
}
if (auto whereOp = dyn_cast<WhereOp>(op); whereOp) {
return createOperation(cache, createWhereOp(cache, whereOp), debugString);
return createOperation(cache, createEltwiseOp(cache, whereOp), debugString);
}

llvm_unreachable("unhandled op in emitTTNNOperation");
Expand Down
2 changes: 1 addition & 1 deletion runtime/lib/ttnn/operations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/binary/binary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/unary/unary_composite.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/tertiary/where.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eltwise/ternary/where.cpp
${CMAKE_CURRENT_SOURCE_DIR}/embedding/embedding.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/to_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/layout/from_device.cpp
Expand Down
27 changes: 27 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/ternary/where.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "where.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/ternary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::ternary {

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();

::ttnn::Tensor *first = nullptr;
::ttnn::Tensor *second = nullptr;
::ttnn::Tensor *third = nullptr;
getEltwiseTernaryOPInputTensors(op, tensorPool, &first, &second, &third);

::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());

::ttnn::Tensor out = ::ttnn::where(*first, *second, *third, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::ternary
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/ternary/where.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_WHERE_H
#define TTNN_RUNTIME_LIB_TTNN_OPERATIONS_ELTWISE_TERNARY_WHERE_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ternary {

inline bool isTernaryOp(const ::tt::target::ttnn::EltwiseOp *op) {
return op->ins()->size() == 3;
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context);

} // namespace tt::runtime::ttnn::operations::ternary

#endif
25 changes: 0 additions & 25 deletions runtime/lib/ttnn/operations/eltwise/tertiary/where.cpp

This file was deleted.

15 changes: 0 additions & 15 deletions runtime/lib/ttnn/operations/eltwise/tertiary/where.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "utils.h"
#include "tt/runtime/detail/logger.h"
#include "tt/runtime/detail/workarounds.h"

namespace tt::runtime::ttnn::operations::binary {

void getEltwiseTernaryOPInputTensors(const ::tt::target::ttnn::EltwiseOp *op,
ProgramTensorPool &tensorPool,
::ttnn::Tensor **first,
::ttnn::Tensor **seccond,
::ttnn::Tensor **third) {
LOG_ASSERT(op->ins()->size() == 3, "Expected 3 inputs");
*first = &(tensorPool.at(op->ins()->Get(0)->global_id()));
*second = &(tensorPool.at(op->ins()->Get(1)->global_id()));
*third = &(tensorPool.at(op->ins()->Get(2)->global_id()));
DEBUG_ASSERT((*first)->is_allocated());
DEBUG_ASSERT((*second)->is_allocated());
DEBUG_ASSERT((*third)->is_allocated());
}

} // namespace tt::runtime::ttnn::operations::binary
Loading

0 comments on commit d8fce32

Please sign in to comment.