Implementing leaky relu op in tt-mlir (#1290)
sdjordjevicTT authored Nov 18, 2024
1 parent 6c64fb7 commit 4ed0821
Showing 9 changed files with 189 additions and 11 deletions.
44 changes: 44 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -379,6 +379,50 @@ def TTIR_Expm1Op: TTIR_ElementwiseUnaryOp<"expm1"> {
  }];
}

class TTIR_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :
    TTIR_ElementwiseUnaryOp<mnemonic, traits> {
  let summary = "Eltwise unary op with a float parameter.";
  let description = [{
    Eltwise unary op with a float parameter.
  }];

  let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                       Variadic<AnyRankedTensor>:$outputs,
                       F32Attr:$parameter,
                       TT_OperandConstraintArrayAttr:$operand_constraints);

  let builders =
  [
    OpBuilder<(ins "Value":$in, "Value":$out, "FloatAttr":$parameter, "ArrayAttr":$operand_constraints),
    [{
      build($_builder, $_state, {out.getType()}, {in}, {out}, parameter, operand_constraints);
    }]>
  ];
}

def TTIR_LeakyReluOp : TTIR_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
  let summary = "Eltwise leaky relu operation.";
  let description = [{
    The Leaky ReLU (Rectified Linear Unit) operation computes an element-wise
    activation function over its input tensor. It is defined as:

      y = x             if x > 0
      y = parameter * x if x <= 0

    where `parameter` is a small, user-defined constant that determines the
    slope for negative inputs.

    Attributes:
    - `parameter` (float): The slope for negative values.

    Inputs:
    - `input` (Tensor): The input tensor to be activated.

    Outputs:
    - `output` (Tensor): The tensor after applying the Leaky ReLU activation.
  }];
}
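
For intuition, here is a minimal, self-contained scalar sketch of the activation defined above (illustration only, not part of the commit; the 0.01 slope matches the `parameter = 0.01 : f32` used in the tests later in this diff):

```cpp
#include <cstdio>
#include <vector>

// Reference semantics of leaky relu:
//   y = x             if x > 0
//   y = parameter * x if x <= 0
float leaky_relu(float x, float parameter) {
  return x > 0.0f ? x : parameter * x;
}

int main() {
  const float slope = 0.01f; // same slope as the tests in this commit
  for (float x : std::vector<float>{-2.0f, -0.5f, 0.0f, 1.5f})
    std::printf("leaky_relu(%5.2f) = %8.4f\n", x, leaky_relu(x, slope));
  return 0;
}
```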

class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
    TTIR_ElementwiseOp<mnemonic, traits> {
  let summary = "Eltwise binary op.";
43 changes: 43 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -325,6 +325,49 @@ def TTNN_Expm1Op: TTNN_ElementwiseUnaryOp<"expm1"> {
  }];
}

class TTNN_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :
    TTNN_ElementwiseUnaryOp<mnemonic, traits> {
  let summary = "Eltwise unary op with a float parameter.";
  let description = [{
    Eltwise unary op with a float parameter.
  }];

  let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                       Variadic<AnyRankedTensor>:$outputs,
                       F32Attr:$parameter);

  let builders =
  [
    OpBuilder<(ins "Value":$in, "Value":$out, "FloatAttr":$parameter),
    [{
      build($_builder, $_state, {out.getType()}, {in}, {out}, parameter);
    }]>
  ];
}

def TTNN_LeakyReluOp : TTNN_ElementwiseUnaryWithFloatParameterOp<"leaky_relu"> {
  let summary = "Eltwise leaky relu operation.";
  let description = [{
    The Leaky ReLU (Rectified Linear Unit) operation computes an element-wise
    activation function over its input tensor. It is defined as:

      y = x             if x > 0
      y = parameter * x if x <= 0

    where `parameter` is a small, user-defined constant that determines the
    slope for negative inputs.

    Attributes:
    - `parameter` (float): The slope for negative values.

    Inputs:
    - `input` (Tensor): The input tensor to be activated.

    Outputs:
    - `output` (Tensor): The tensor after applying the Leaky ReLU activation.
  }];
}

def TTNN_AddOp : TTNN_ElementwiseBinaryOp<"add"> {
  let summary = "Eltwise add.";
  let description = [{
16 changes: 11 additions & 5 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -3,11 +3,6 @@ include "Common/debug_info.fbs";

namespace tt.target.ttnn;

table ClampOpParams {
  min: float;
  max: float;
}

table GetDeviceOp {
  mesh: Dim2d;
  chip_ids: [uint32];
@@ -106,10 +101,21 @@ enum EltwiseOpType: uint32 {
  Gelu = 36,
  LogicalXor = 37,
  Clamp = 38,
  LeakyRelu = 39,
}

table ClampOpParams {
  min: float;
  max: float;
}

table EltwiseOpWithFloatParams {
  parameter: float;
}

union EltwiseOpParams {
  ClampOpParams,
  EltwiseOpWithFloatParams,
}

table EltwiseOp {
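
For readers new to FlatBuffers unions: `EltwiseOpParams` is a tagged union over per-op parameter tables, so parameterless ops carry nothing extra and each parameterized op (Clamp, LeakyRelu) carries only its own table. A self-contained C++ analogy of that layout and dispatch (illustrative only; this is not the generated FlatBuffers API):

```cpp
#include <cstdio>
#include <variant>

// Analogy for the EltwiseOpParams union above: a tagged variant over
// per-op parameter tables. Not the generated FlatBuffers API.
struct ClampOpParams { float min, max; };
struct EltwiseOpWithFloatParams { float parameter; };
using EltwiseOpParams =
    std::variant<std::monostate, ClampOpParams, EltwiseOpWithFloatParams>;

int main() {
  EltwiseOpParams params = EltwiseOpWithFloatParams{0.01f}; // LeakyRelu slope
  if (const auto *p = std::get_if<EltwiseOpWithFloatParams>(&params))
    std::printf("float parameter = %f\n", p->parameter);
  return 0;
}
```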
18 changes: 18 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -358,6 +358,23 @@ class ClampOpConversionPattern : public OpConversionPattern<ttir::ClampOp> {
  }
};

template <typename TTIROpTy, typename TTNNOpTy,
          typename OpAdaptor = typename TTIROpTy::Adaptor>
class ElementwiseUnaryWithFloatParameterOpConversionPattern
    : public OpConversionPattern<TTIROpTy> {
public:
  using OpConversionPattern<TTIROpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<TTNNOpTy>(
        op, this->getTypeConverter()->convertType(op.getType(0)),
        adaptor.getInputs(), adaptor.getOutputs(), adaptor.getParameter());
    return success();
  }
};

class ConcatOpConversionPattern : public OpConversionPattern<ttir::ConcatOp> {
public:
  using OpConversionPattern<ttir::ConcatOp>::OpConversionPattern;
@@ -936,6 +953,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
           ElementwiseOpConversionPattern<ttir::Expm1Op, ttnn::Expm1Op>,
           ElementwiseOpConversionPattern<ttir::RemainderOp, ttnn::RemainderOp>,
           ElementwiseOpConversionPattern<ttir::WhereOp, ttnn::WhereOp>,
           ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
           ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
           ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
           ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
1 change: 1 addition & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -652,6 +652,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
           DefaultOpConversionPattern<ttnn::LogicalNotOp>,
           DefaultOpConversionPattern<ttnn::NegOp>,
           DefaultOpConversionPattern<ttnn::ReluOp>,
           DefaultOpConversionPattern<ttnn::LeakyReluOp>,
           DefaultOpConversionPattern<ttnn::GeluOp>,
           DefaultOpConversionPattern<ttnn::SqrtOp>,
           DefaultOpConversionPattern<ttnn::RsqrtOp>,
35 changes: 29 additions & 6 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -6,6 +6,7 @@

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
#include "ttmlir/Dialect/TTKernel/IR/TTKernel.h"
#include "ttmlir/Dialect/TTKernel/IR/TTKernelOps.h"
#include "ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.h"
@@ -379,11 +380,19 @@ createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
op.getDim(), op.getNumLinks());
}

::flatbuffers::Offset<::tt::target::ttnn::ClampOpParams>
createEltwiseOpParams(FlatbufferObjectCache &cache, ClampOp op) {
  auto min = op.getMin().convertToFloat();
  auto max = op.getMax().convertToFloat();
  return ::tt::target::ttnn::CreateClampOpParams(*cache.fbb, min, max);
template <typename EltwiseOp, typename EltwiseOpParams>
::flatbuffers::Offset<EltwiseOpParams>
createEltwiseOpParams(FlatbufferObjectCache &cache, EltwiseOp op) {
  if constexpr (std::is_same_v<EltwiseOp, ClampOp>) {
    auto min = op.getMin().convertToFloat();
    auto max = op.getMax().convertToFloat();
    return ::tt::target::ttnn::CreateClampOpParams(*cache.fbb, min, max);
  }
  if constexpr (std::is_same_v<EltwiseOp, LeakyReluOp>) {
    auto parameter = op.getParameter().convertToFloat();
    return ::tt::target::ttnn::CreateEltwiseOpWithFloatParams(*cache.fbb,
                                                              parameter);
  }
}

template <typename EltwiseOp>
@@ -396,7 +405,9 @@ createNonDPSEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
  if constexpr (std::is_same_v<EltwiseOp, ClampOp>) {
    type = ::tt::target::ttnn::EltwiseOpType::Clamp;
    paramsType = ::tt::target::ttnn::EltwiseOpParams::ClampOpParams;
    params = createEltwiseOpParams(cache, op).Union();
    params = createEltwiseOpParams<ClampOp, ::tt::target::ttnn::ClampOpParams>(
                 cache, op)
                 .Union();
  } else {
    llvm_unreachable("unhandled non-DPS EltwiseOp");
  }
@@ -494,6 +505,14 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
    type = ::tt::target::ttnn::EltwiseOpType::Where;
  } else if constexpr (std::is_same_v<EltwiseOp, GeluOp>) {
    type = ::tt::target::ttnn::EltwiseOpType::Gelu;
  } else if constexpr (std::is_same_v<EltwiseOp, LeakyReluOp>) {
    type = ::tt::target::ttnn::EltwiseOpType::LeakyRelu;
    paramsType = ::tt::target::ttnn::EltwiseOpParams::EltwiseOpWithFloatParams;
    params =
        createEltwiseOpParams<LeakyReluOp,
                              ::tt::target::ttnn::EltwiseOpWithFloatParams>(
            cache, op)
            .Union();
  } else {
    llvm_unreachable("unhandled EltwiseOp");
  }
@@ -778,6 +797,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
    return createOperation(cache, createEltwiseOp(cache, remainderOp),
                           debugString);
  }
  if (auto leakyReluOp = dyn_cast<LeakyReluOp>(op); leakyReluOp) {
    return createOperation(cache, createEltwiseOp(cache, leakyReluOp),
                           debugString);
  }
  if (auto matmulOp = dyn_cast<MatmulOp>(op); matmulOp) {
    return createOperation(cache, createOp(cache, matmulOp), debugString);
  }
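
The templated `createEltwiseOpParams` above picks the parameter table at compile time with `if constexpr`, so each op type instantiates only its own serialization path. A self-contained sketch of that pattern (the op types and fields here are stand-ins, not the real TTNN ops):

```cpp
#include <cstdio>
#include <type_traits>

// Stand-in op types; the real code dispatches on ClampOp vs. LeakyReluOp.
struct ClampLike { float min = -1.0f, max = 1.0f; };
struct LeakyReluLike { float parameter = 0.01f; };

// Each instantiation compiles only the branch matching its op type.
template <typename Op>
void serializeParams(const Op &op) {
  if constexpr (std::is_same_v<Op, ClampLike>) {
    std::printf("ClampOpParams { min=%.2f, max=%.2f }\n", op.min, op.max);
  } else if constexpr (std::is_same_v<Op, LeakyReluLike>) {
    std::printf("EltwiseOpWithFloatParams { parameter=%.2f }\n", op.parameter);
  }
  // The real helper has no fallthrough case: callers only instantiate it
  // with ops that have a matching parameter table.
}

int main() {
  serializeParams(ClampLike{});
  serializeParams(LeakyReluLike{});
  return 0;
}
```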
20 changes: 20 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary.cpp
@@ -6,6 +6,7 @@
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/eltwise/unary/utils.h"
#include "tt/runtime/ttnn/operations/utils.h"
#include "ttmlir/Target/TTNN/program_generated.h"
#include "ttnn/operations/copy.hpp"

namespace tt::runtime::ttnn::operations::unary {
@@ -45,6 +46,21 @@ static void runEltwiseUnaryWithFastAndApproximateModeOp(
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}

static void runEltwiseUnaryWithFloatParameterOp(
    const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
    const std::function<::ttnn::Tensor(const ::ttnn::Tensor &, float,
                                       const ::tt::tt_metal::MemoryConfig &)>
        &ttnnOp) {
  ::ttnn::Tensor *in = nullptr;
  getEltwiseUnaryOpInputTensor(op, tensorPool, &in);

  float parameter = op->params_as_EltwiseOpWithFloatParams()->parameter();
  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
      utils::createMemoryConfig(op->out());
  ::ttnn::Tensor out = ttnnOp(*in, parameter, outputMemoryConfig);
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  switch (op->type()) {
@@ -122,6 +138,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
    runEltwiseUnaryOp(op, tensorPool, ::ttnn::expm1);
    break;
  }
  case ::tt::target::ttnn::EltwiseOpType::LeakyRelu: {
    runEltwiseUnaryWithFloatParameterOp(op, tensorPool, ::ttnn::leaky_relu);
    break;
  }
  default:
    LOG_FATAL("Unsupported unary operation");
  }
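
The runner above threads the extra float through a `std::function` whose signature carries the parameter, so any unary TTNN op taking one float can reuse it. A minimal self-contained sketch of that shape (the `Tensor` type and op here are stand-ins, not the TTNN API):

```cpp
#include <cstdio>
#include <functional>

// Stand-ins for ::ttnn::Tensor and a TTNN unary op; not the TTNN API.
struct Tensor { float value; };

Tensor leaky_relu(const Tensor &in, float parameter) {
  return {in.value > 0.0f ? in.value : parameter * in.value};
}

// Mirrors the shape of runEltwiseUnaryWithFloatParameterOp: the callable's
// signature carries the float parameter alongside the input tensor.
void runWithFloatParameter(
    const Tensor &in, float parameter,
    const std::function<Tensor(const Tensor &, float)> &op) {
  Tensor out = op(in, parameter);
  std::printf("out = %f\n", out.value);
}

int main() {
  runWithFloatParameter({-2.0f}, 0.01f, leaky_relu); // prints out = -0.020000
  return 0;
}
```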
@@ -0,0 +1,15 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
  func.func @leaky_relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
    // CHECK: %[[C:.*]] = "ttnn.empty"
    // CHECK-SAME: [[TENSOR:tensor<64x128xf32,]]
    %0 = tensor.empty() : tensor<64x128xf32>
    // CHECK: %[[C:.*]] = "ttnn.leaky_relu"
    // CHECK-SAME: tensor<64x128xf32,
    // CHECK-SAME: [[TENSOR]]
    // CHECK-SAME: -> [[TENSOR]]
    %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
    return %1 : tensor<64x128xf32>
  }
}
8 changes: 8 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -137,6 +137,14 @@ func.func @relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
  return %1 : tensor<64x128xf32>
}

func.func @leaky_relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
  // CHECK: %[[C:.*]] = "ttnn.empty"
  %0 = tensor.empty() : tensor<64x128xf32>
  // CHECK: %[[C:.*]] = "ttnn.leaky_relu"
  %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
  return %1 : tensor<64x128xf32>
}

func.func @reshape(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> {
  %0 = tensor.empty() : tensor<2x4x32x32xbf16>
  // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]]