Skip to content

Commit

Permalink
End-to-End Conversion for Round Op
Browse files Browse the repository at this point in the history
StableHLO has two different round ops:
- round nearest even
- round nearest away from zero (or simply, round)

Both are lowered into ttnn::RoundOp.

ttnn::RoundOp is a unary composite op with an integer parameter (`decimals`)
that determines which rounding mechanism to use.
  • Loading branch information
ddilbazTT committed Dec 20, 2024
1 parent c44a4bd commit 2250958
Show file tree
Hide file tree
Showing 12 changed files with 261 additions and 1 deletion.
41 changes: 41 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,47 @@ def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
}];
}

// TTIR round op in destination-passing style: `output` is the DPS init operand
// (exposed through getDpsInitsMutable below) and `result` is the produced
// tensor. The StableHLO-to-TTIR conversion creates this op from
// stablehlo::RoundOp (round to nearest, ties away from zero) with decimals = 1.
def TTIR_RoundOp : TTIR_DPSOp<"round"> {
let summary = "Eltwise round.";
let description = [{
Eltwise round operation.
}];
let arguments = (ins
AnyRankedTensor:$input,
AnyRankedTensor:$output,
// Rounding selector that is forwarded to ttnn.round. The TTIR-to-TTNN
// conversion rejects decimals == 0 for this op (0 is reserved for
// ttir.roundnearesteven).
// NOTE(review): the runtime passes this value as the integer argument of
// ::ttnn::round; confirm that argument selects the tie-breaking mode rather
// than a number of decimal places.
DefaultValuedAttr<I32Attr, "1">:$decimals
);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let results = (outs AnyRankedTensor:$result);
}

// TTIR round-to-nearest-even op in destination-passing style. Created by the
// StableHLO-to-TTIR conversion from stablehlo::RoundNearestEvenOp with
// decimals = 0, and lowered to ttnn.round by the TTIR-to-TTNN conversion.
def TTIR_RoundNearestEvenOp : TTIR_DPSOp<"roundnearesteven"> {
    let summary = "Eltwise round towards nearest even.";
    let description = [{
      Rounds a number to the nearest value. If the number is exactly halfway between two values, the value is rounded to the nearest even value.

      `output` is the destination (DPS init) tensor. `decimals` must be 0 for
      this op; the TTIR to TTNN conversion rejects any other value.

      Example:
        // %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
        %result = "ttir.roundnearesteven"(%operand, %output) <{decimals = 0 : i32}> : (tensor<5xf64>, tensor<5xf64>) -> tensor<5xf64>
        // %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
    }];
    let arguments = (ins
      AnyRankedTensor:$input,
      AnyRankedTensor:$output,
      DefaultValuedAttr<I32Attr, "0">:$decimals
    );

    let extraClassDeclaration = [{
      MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
    }];

    let results = (outs AnyRankedTensor:$result);
}

def TTIR_RsqrtOp : TTIR_ElementwiseUnaryOp<"rsqrt"> {
let summary = "Eltwise reciprocal square root.";
let description = [{
Expand Down
14 changes: 14 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,20 @@ def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu",
}];
}

// TTNN round op. Both ttir.round and ttir.roundnearesteven are converted to
// this op; `decimals` is serialized into RoundOpParams and handed to the
// runtime call of ::ttnn::round.
// As enforced by the TTIR-to-TTNN conversion, decimals == 0 originates from
// ttir.roundnearesteven (Banker's rounding) and decimals != 0 from ttir.round.
// NOTE(review): confirm that the integer argument of ::ttnn::round selects the
// tie-breaking mode rather than a number of decimal places.
def TTNN_RoundOp : TTNN_Op<"round"> {
let summary = "Eltwise round operation.";
let description = [{
Eltwise round with decimal option to choose between Banker's rounding or normal rounding.
}];

let arguments = (ins
// NOTE(review): declared variadic, but the TTIR-to-TTNN conversion builds this
// op with exactly one input and one result type.
Variadic<AnyRankedTensor>:$inputs,
I32Attr:$decimals
);

let results = (outs Variadic<AnyRankedTensor>:$result);
}

def TTNN_SinOp : TTNN_ElementwiseUnaryOp<"sin"> {
let summary = "Eltwise sine.";
let description = [{
Expand Down
8 changes: 7 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -140,21 +140,27 @@ enum EltwiseOpType: uint32 {
LeakyRelu,
Scatter,
Tan,
Tanh
Tanh,
Round
}

table ClampOpParams {
min: float;
max: float;
}

// Parameters carried by an EltwiseOp of type Round. `decimals` is written from
// the ttnn.round op's `decimals` attribute and read back by the runtime, which
// forwards it unchanged to the TTNN round call.
table RoundOpParams {
decimals: int32;
}

table EltwiseOpWithFloatParams {
parameter: float;
}

// Optional per-op parameter payload of an EltwiseOp:
//   ClampOpParams            - used for Clamp,
//   EltwiseOpWithFloatParams - single float parameter,
//   RoundOpParams            - used for EltwiseOpType::Round.
union EltwiseOpParams {
ClampOpParams,
EltwiseOpWithFloatParams,
RoundOpParams,
}

table EltwiseOp {
Expand Down
45 changes: 45 additions & 0 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,41 @@ class StableHLOToTTIROpReverseOpConversionPattern
}
};

/// Lowers the StableHLO rounding ops to their TTIR destination-passing-style
/// counterparts.
///
/// The rounding mode is encoded in the `decimals` attribute of the created op:
/// 1 for stablehlo::RoundOp (ties away from zero) and 0 for
/// stablehlo::RoundNearestEvenOp. Any other source op is reported as a match
/// failure before any new IR is created.
template <typename SrcOp, typename DestOp,
          typename Adaptor = typename SrcOp::Adaptor>
class StableHLOToTTIRRoundOpConversionPattern
    : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Pick the rounding mode first so that an unsupported source op fails
    // without having created the output tensor.
    int32_t decimals = 0;
    if (isa<stablehlo::RoundOp>(srcOp)) {
      decimals = 1;
    } else if (!isa<stablehlo::RoundNearestEvenOp>(srcOp)) {
      return rewriter.notifyMatchFailure(
          srcOp, "ttir::RoundOp only supports stablehlo::RoundOp or "
                 "stablehlo::RoundNearestEvenOp");
    }

    auto outputType = mlir::cast<RankedTensorType>(
        this->getTypeConverter()->convertType(srcOp.getResult().getType()));

    // Destination operand required by the DPS form of the TTIR op.
    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());

    rewriter.replaceOpWithNewOp<DestOp>(srcOp, outputType,
                                        adaptor.getOperand(), outputTensor,
                                        rewriter.getI32IntegerAttr(decimals));

    return success();
  }
};

void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand Down Expand Up @@ -1963,6 +1998,15 @@ void addReverseOpConversionPattern(MLIRContext *ctx,
patterns.add<StableHLOToTTIROpReverseOpConversionPattern>(typeConverter, ctx);
}

/// Registers the StableHLO -> TTIR conversions for both rounding ops
/// (stablehlo.round_nearest_afz and stablehlo.round_nearest_even).
void addRoundOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
                                 TypeConverter &typeConverter) {
  using RoundPattern = StableHLOToTTIRRoundOpConversionPattern<
      mlir::stablehlo::RoundOp, mlir::tt::ttir::RoundOp>;
  using RoundNearestEvenPattern = StableHLOToTTIRRoundOpConversionPattern<
      mlir::stablehlo::RoundNearestEvenOp, mlir::tt::ttir::RoundNearestEvenOp>;

  // Same registration order as before: RoundOp first, then RoundNearestEvenOp.
  patterns.add<RoundPattern, RoundNearestEvenPattern>(typeConverter, ctx);
}

} // namespace

namespace mlir::tt {
Expand Down Expand Up @@ -1992,6 +2036,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
addScatterOpConversionPatterns(ctx, patterns, typeConverter);
addReturnOpConversionPatterns(ctx, patterns, typeConverter);
addReverseOpConversionPattern(ctx, patterns, typeConverter);
addRoundOpConversionPattern(ctx, patterns, typeConverter);
}

} // namespace mlir::tt
73 changes: 73 additions & 0 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,77 @@ class TypecastOpConversionPattern
}
};

// Conversion pattern that removes ttir::BroadcastOp instead of lowering it to
// a TTNN op: the users at the end of a chain of broadcasts are rewired to read
// the first broadcast's input directly, and the matched broadcast is erased.
class BroadcastOpConversionPattern
: public OpConversionPattern<ttir::BroadcastOp> {
using OpConversionPattern<ttir::BroadcastOp>::OpConversionPattern;

public:
// Always returns success(); the op is either erased outright (no users) or
// erased after the end of its broadcast chain has been redirected to `input`.
LogicalResult
matchAndRewrite(ttir::BroadcastOp srcOp, ttir::BroadcastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

// Fold this operation into all consumer ops. It will only work with TTNN
// ops that support implicit broadcasting. We expect each Op's verify
// function to assert their arguments to verify that they can broadcast.

if (srcOp->getUsers().empty()) {
// This broadcast chain has already been replaced.
rewriter.eraseOp(srcOp);
return success();
}

// Value that consumers will read instead of the broadcast result. Taken from
// the original op (not the adaptor).
mlir::Value input = srcOp.getOperand(0);

// Walk down the chain for as long as the first user is another BroadcastOp.
// NOTE(review): only the first user is inspected; the single-use requirement
// is enforced solely by the assert, which is compiled out under NDEBUG.
mlir::Operation *nextOp = srcOp;
while (isa<ttir::BroadcastOp>(*nextOp->getUsers().begin())) {
assert(nextOp->hasOneUse() &&
"Broadcast with multiple uses are not supported");
nextOp = *nextOp->getUsers().begin();
if (nextOp->getUsers().empty()) {
// This broadcast chain has already been replaced.
// NOTE(review): srcOp still has at least one user here (the next broadcast in
// the chain) when it is erased; confirm this is intended.
rewriter.eraseOp(srcOp);
return success();
}
}

// Redirect every use of the last broadcast in the chain to the original input,
// then drop the matched op. Intermediate broadcasts are not erased here.
rewriter.replaceAllOpUsesWith(nextOp, input);
rewriter.eraseOp(srcOp);

return success();
}
};

/// Lowers ttir::RoundOp and ttir::RoundNearestEvenOp to the TTNN round op.
///
/// The `decimals` attribute is forwarded unchanged and doubles as the rounding
/// mode selector, so it is validated against the source op kind:
///   - ttir::RoundOp            requires decimals != 0
///   - ttir::RoundNearestEvenOp requires decimals == 0
/// The DPS output operand of the TTIR op is dropped; only the input is passed
/// to the TTNN op.
template <typename TTIROpTy, typename TTNNOpTy,
          typename OpAdaptor = typename TTIROpTy::Adaptor>
class RoundOpConversionPattern : public OpConversionPattern<TTIROpTy> {
public:
  using OpConversionPattern<TTIROpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const bool isRound = isa<ttir::RoundOp>(op);
    const bool isRoundNearestEven = isa<ttir::RoundNearestEvenOp>(op);

    if (!isRound && !isRoundNearestEven) {
      return rewriter.notifyMatchFailure(
          op, "ttnn::RoundOp only supports ttir::RoundOp or "
              "ttir::RoundNearestEvenOp");
    }

    const auto decimals = adaptor.getDecimals();
    if (isRound && decimals == 0) {
      return rewriter.notifyMatchFailure(
          op, "ttir::RoundOp requires decimals != 0");
    }
    if (isRoundNearestEven && decimals != 0) {
      return rewriter.notifyMatchFailure(
          op, "ttir::RoundNearestEvenOp requires decimals == 0");
    }

    rewriter.replaceOpWithNewOp<TTNNOpTy>(
        op, this->getTypeConverter()->convertType(op.getType()),
        adaptor.getInput(), decimals);

    return success();
  }
};

class SubtractOpConversionPattern
: public OpConversionPattern<ttir::SubtractOp> {
using OpConversionPattern<ttir::SubtractOp>::OpConversionPattern;
Expand Down Expand Up @@ -1178,6 +1249,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
ReductionOpConversionPattern<ttir::MeanOp, ttnn::MeanOp>,
ReductionOpConversionPattern<ttir::MaxOp, ttnn::MaxOp>,
ElementwiseUnaryWithFloatParameterOpConversionPattern<ttir::LeakyReluOp, ttnn::LeakyReluOp>,
RoundOpConversionPattern<ttir::RoundOp, ttnn::RoundOp>,
RoundOpConversionPattern<ttir::RoundNearestEvenOp, ttnn::RoundOp>,
EmbeddingOpConversionPattern,
EmbeddingBackwardOpConversionPattern,
SoftmaxOpConversionPattern,
Expand Down
2 changes: 2 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<DefaultOpConversionPattern<ttnn::AbsOp>,
DefaultOpConversionPattern<ttnn::CbrtOp>,
DefaultOpConversionPattern<ttnn::ClampOp>,
DefaultOpConversionPattern<ttnn::RoundOp>,
DefaultOpConversionPattern<ttnn::FloorOp>,
DefaultOpConversionPattern<ttnn::IsFiniteOp>,
DefaultOpConversionPattern<ttnn::LogicalNotOp>,
Expand All @@ -692,6 +693,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
DefaultOpConversionPattern<ttnn::SigmoidOp>,
DefaultOpConversionPattern<ttnn::Log1pOp>,
DefaultOpConversionPattern<ttnn::ReciprocalOp>,
DefaultOpConversionPattern<ttnn::RoundOp>,
DefaultOpConversionPattern<ttnn::ExpOp>,
DefaultOpConversionPattern<ttnn::CeilOp>,
DefaultOpConversionPattern<ttnn::SinOp>,
Expand Down
14 changes: 14 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ createEltwiseOpParams(FlatbufferObjectCache &cache, EltwiseOp op) {
return ::tt::target::ttnn::CreateEltwiseOpWithFloatParams(*cache.fbb,
parameter);
}
if constexpr (std::is_same_v<EltwiseOp, RoundOp>) {
auto decimals = op.getDecimals();
return ::tt::target::ttnn::CreateRoundOpParams(*cache.fbb, decimals);
}
}

::flatbuffers::Offset<::tt::target::ttnn::UpdateCacheOp>
Expand Down Expand Up @@ -569,6 +573,12 @@ createNonDPSEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
params = createEltwiseOpParams<ClampOp, ::tt::target::ttnn::ClampOpParams>(
cache, op)
.Union();
} else if (std::is_same_v<EltwiseOp, RoundOp>) {
type = ::tt::target::ttnn::EltwiseOpType::Round;
paramsType = ::tt::target::ttnn::EltwiseOpParams::RoundOpParams;
params = createEltwiseOpParams<RoundOp, ::tt::target::ttnn::RoundOpParams>(
cache, op)
.Union();
} else {
llvm_unreachable("unhandled non-DPS EltwiseOp");
}
Expand Down Expand Up @@ -1098,6 +1108,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
return createOperation(cache, createNonDPSEltwiseOp(cache, clampOp),
debugString, locInfo);
}
if (auto roundOp = dyn_cast<RoundOp>(op); roundOp) {
return createOperation(cache, createNonDPSEltwiseOp(cache, roundOp),
debugString, locInfo);
}
if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
return createOperation(cache, createOp(cache, conv2dOp), debugString,
locInfo);
Expand Down
19 changes: 19 additions & 0 deletions runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ static void runEltwiseUnaryCompositeClampOp(
tensorPool.insert_or_assign(op->out()->global_id(), out);
}

// Executes a unary composite round op: looks up the input tensor, reads
// `decimals` from the op's RoundOpParams, invokes `ttnnOp` with the output
// memory config derived from `op->out()`, and stores the result in the tensor
// pool under the output's global id.
// NOTE(review): the result of params_as_RoundOpParams() is dereferenced without
// a null check; presumably the compiler always emits RoundOpParams for Round
// ops - confirm.
static void runEltwiseUnaryCompositeRoundOp(
    const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
    const std::function<::ttnn::Tensor(const ::ttnn::Tensor &, int,
                                       const ::tt::tt_metal::MemoryConfig &)>
        &ttnnOp) {
  ::ttnn::Tensor *inputTensor = nullptr;
  getEltwiseUnaryOpInputTensor(op, tensorPool, &inputTensor);

  const auto *roundParams = op->params_as_RoundOpParams();
  int32_t roundDecimals = roundParams->decimals();

  ::tt::tt_metal::MemoryConfig memConfig =
      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());

  ::ttnn::Tensor result = ttnnOp(*inputTensor, roundDecimals, memConfig);
  tensorPool.insert_or_assign(op->out()->global_id(), result);
}

void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
switch (op->type()) {
Expand All @@ -58,6 +73,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
runEltwiseUnaryCompositeOp(op, tensorPool, ::ttnn::log1p);
break;
}
case ::tt::target::ttnn::EltwiseOpType::Round: {
runEltwiseUnaryCompositeRoundOp(op, tensorPool, ::ttnn::round);
break;
}
default:
LOG_FATAL("Unsupported Eltwise Binary Composite operation");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
case ::tt::target::ttnn::EltwiseOpType::Cbrt:
case ::tt::target::ttnn::EltwiseOpType::Clamp:
case ::tt::target::ttnn::EltwiseOpType::Log1p:
case ::tt::target::ttnn::EltwiseOpType::Round:
return true;
default:
return false;
Expand Down
16 changes: 16 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/round_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_eltwise_round attributes {} {
  // round_nearest_afz must lower to ttir.round with decimals = 1.
  // CHECK-LABEL: @test_round(
  func.func public @test_round(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
    %0 = stablehlo.round_nearest_afz %arg0 : tensor<4xbf16>
    // CHECK: %{{.*}} = tensor.empty
    // CHECK: %{{.*}} = "ttir.round"{{.*}}decimals = 1 : i32
    return %0 : tensor<4xbf16>
  }
  // round_nearest_even must lower to ttir.roundnearesteven with decimals = 0.
  // CHECK-LABEL: @test_roundnearesteven(
  func.func public @test_roundnearesteven(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
    %0 = stablehlo.round_nearest_even %arg0 : tensor<4xbf16>
    // CHECK: %{{.*}} = tensor.empty
    // CHECK: %{{.*}} = "ttir.roundnearesteven"{{.*}}decimals = 0 : i32
    return %0 : tensor<4xbf16>
  }
}
15 changes: 15 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_round.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
module attributes {} {
  // Both TTIR round ops lower to ttnn.round; the decimals attribute is the only
  // thing that distinguishes them, so it is checked explicitly.
  // CHECK-LABEL: @roundnearesteven(
  func.func @roundnearesteven(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
    %0 = tensor.empty() : tensor<4xbf16>
    // CHECK: %{{.*}} = "ttnn.round"{{.*}}decimals = 0 : i32
    %1 = "ttir.roundnearesteven"(%arg0, %0) <{decimals = 0 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
    return %1 : tensor<4xbf16>
  }
  // CHECK-LABEL: @round(
  func.func @round(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
    %0 = tensor.empty() : tensor<4xbf16>
    // CHECK: %{{.*}} = "ttnn.round"{{.*}}decimals = 1 : i32
    %1 = "ttir.round"(%arg0, %0) <{decimals = 1 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
    return %1 : tensor<4xbf16>
  }
}
14 changes: 14 additions & 0 deletions test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,20 @@ func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tens
// CHECK: return {{.*}} : tensor<32x32xf32, {{.*}}
}

func.func @roundnearesteven(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
  %0 = tensor.empty() : tensor<4xbf16>
  // Nearest-even rounding must reach TTNN with decimals = 0.
  // CHECK: %{{.*}} = "ttnn.round"{{.*}}decimals = 0 : i32
  %1 = "ttir.roundnearesteven"(%arg0, %0) <{decimals = 0 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
  return %1 : tensor<4xbf16>
}

func.func @round(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
  %0 = tensor.empty() : tensor<4xbf16>
  // Round-away-from-zero must reach TTNN with decimals = 1.
  // CHECK: %{{.*}} = "ttnn.round"{{.*}}decimals = 1 : i32
  %1 = "ttir.round"(%arg0, %0) <{decimals = 1 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
  return %1 : tensor<4xbf16>
}

func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> {
%0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32>
// CHECK: [[VAL:%[0-9]+]] = "ttnn.full"(%{{[0-9]+}}) <{fillValue = 2.100000e+01 : f32}> : (!tt.device<#device>) -> tensor<1xi32, {{.*}}>
Expand Down

0 comments on commit 2250958

Please sign in to comment.