diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index c0e5a78b84..6e7eeeb55f 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -326,6 +326,47 @@ def TTIR_ReluOp : TTIR_ElementwiseUnaryOp<"relu"> {
   }];
 }
 
+def TTIR_RoundOp : TTIR_DPSOp<"round"> {
+  let summary = "Eltwise round.";
+  let description = [{
+    Eltwise round operation.
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$input,
+    AnyRankedTensor:$output,
+    DefaultValuedAttr<I32Attr, "1">:$decimals
+  );
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let results = (outs AnyRankedTensor:$result);
+}
+
+def TTIR_RoundNearestEvenOp : TTIR_DPSOp<"roundnearesteven"> {
+  let summary = "Eltwise round towards nearest even.";
+  let description = [{
+    Rounds a number to the nearest value. If the number is exactly halfway
+    between two values, the value is rounded to the nearest even value.
+
+    Example:
+    // %operand = [-2.5, 0.4, 0.5, 0.6, 2.5]
+    %result = "stablehlo.round_nearest_even"(%operand) : (tensor<5xf64>) -> tensor<5xf64>
+    // %result: [-2.0, 0.0, 0.0, 1.0, 2.0]
+  }];
+  let arguments = (ins
+    AnyRankedTensor:$input,
+    AnyRankedTensor:$output,
+    DefaultValuedAttr<I32Attr, "0">:$decimals
+  );
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let results = (outs AnyRankedTensor:$result);
+}
+
 def TTIR_RsqrtOp : TTIR_ElementwiseUnaryOp<"rsqrt"> {
   let summary = "Eltwise reciprocal square root.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 7567364d10..87f6ae5448 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -296,6 +296,20 @@ def TTNN_ReluOp : TTNN_ElementwiseUnaryOp<"relu",
   }];
 }
 
+def TTNN_RoundOp : TTNN_Op<"round"> {
+  let summary = "Eltwise round operation.";
+  let description = [{
+    Eltwise round with decimal option to choose between Banker's rounding or normal rounding.
+  }];
+
+  let arguments = (ins
+    Variadic<AnyRankedTensor>:$inputs,
+    I32Attr:$decimals
+  );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result);
+}
+
 def TTNN_SinOp : TTNN_ElementwiseUnaryOp<"sin"> {
   let summary = "Eltwise sine.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index 4ba5443ad5..2495be1388 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -140,7 +140,8 @@ enum EltwiseOpType: uint32 {
   LeakyRelu,
   Scatter,
   Tan,
-  Tanh
+  Tanh,
+  Round
 }
 
 table ClampOpParams {
@@ -148,6 +149,10 @@
   max: float;
 }
 
+table RoundOpParams {
+  decimals: int32;
+}
+
 table EltwiseOpWithFloatParams {
   parameter: float;
 }
@@ -155,6 +160,7 @@
 union EltwiseOpParams {
   ClampOpParams,
   EltwiseOpWithFloatParams,
+  RoundOpParams,
 }
 
 table EltwiseOp {
diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index b541d0a3e3..ab580293f1 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -1755,6 +1755,41 @@ class StableHLOToTTIROpReverseOpConversionPattern
   }
 };
 
+template <typename SrcOp, typename Adaptor = typename SrcOp::Adaptor>
+class StableHLOToTTIRRoundOpConversionPattern
+    : public OpConversionPattern<SrcOp> {
+public:
+  using OpConversionPattern<SrcOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto outputType = mlir::cast<RankedTensorType>(
+        this->getTypeConverter()->convertType(srcOp.getResult().getType()));
+
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+
+    if (isa<mlir::stablehlo::RoundOp>(srcOp)) {
+      rewriter.replaceOpWithNewOp<mlir::tt::ttir::RoundOp>(
+          srcOp, outputType, adaptor.getOperand(), outputTensor,
+          rewriter.getI32IntegerAttr(1));
+    } else if (isa<mlir::stablehlo::RoundNearestEvenOp>(srcOp)) {
+      rewriter.replaceOpWithNewOp<mlir::tt::ttir::RoundNearestEvenOp>(
+          srcOp, outputType, adaptor.getOperand(), outputTensor,
+          rewriter.getI32IntegerAttr(0));
+    } else {
+      return rewriter.notifyMatchFailure(
+          srcOp, "ttir::RoundOp only supports stablehlo:RoundOp or "
+                 "stablehlo::RoundNearestEvenOp");
+    }
+
+    return success();
+  }
+};
+
 void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               TypeConverter &typeConverter) {
@@ -1963,6 +1998,15 @@ void addReverseOpConversionPattern(MLIRContext *ctx,
   patterns.add<StableHLOToTTIROpReverseOpConversionPattern>(typeConverter, ctx);
 }
 
+void addRoundOpConversionPattern(MLIRContext *ctx, RewritePatternSet &patterns,
+                                 TypeConverter &typeConverter) {
+  patterns.add<
+      StableHLOToTTIRRoundOpConversionPattern<mlir::stablehlo::RoundOp>>(
+      typeConverter, ctx);
+  patterns.add<StableHLOToTTIRRoundOpConversionPattern<
+      mlir::stablehlo::RoundNearestEvenOp>>(typeConverter, ctx);
+}
+
 } // namespace
 
 namespace mlir::tt {
@@ -1992,6 +2036,7 @@ void populateStableHLOToTTIRPatterns(MLIRContext *ctx,
   addScatterOpConversionPatterns(ctx, patterns, typeConverter);
   addReturnOpConversionPatterns(ctx, patterns, typeConverter);
   addReverseOpConversionPattern(ctx, patterns, typeConverter);
+  addRoundOpConversionPattern(ctx, patterns, typeConverter);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index a2b63a1bce..91261e35b6 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -965,6 +965,77 @@ class TypecastOpConversionPattern
   }
 };
 
+class BroadcastOpConversionPattern
+    : public OpConversionPattern<ttir::BroadcastOp> {
+  using OpConversionPattern<ttir::BroadcastOp>::OpConversionPattern;
+
+public:
+  LogicalResult
+  matchAndRewrite(ttir::BroadcastOp srcOp, ttir::BroadcastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Fold this operation into all consumer ops. It will only work with TTNN
+    // ops that support implicit broadcasting. We expect each Op's verify
+    // function to assert their arguments to verify that they can broadcast.
+
+    if (srcOp->getUsers().empty()) {
+      // This broadcast chain has already been replaced.
+      rewriter.eraseOp(srcOp);
+      return success();
+    }
+
+    mlir::Value input = srcOp.getOperand(0);
+
+    mlir::Operation *nextOp = srcOp;
+    while (isa<ttir::BroadcastOp>(*nextOp->getUsers().begin())) {
+      assert(nextOp->hasOneUse() &&
+             "Broadcast with multiple uses are not supported");
+      nextOp = *nextOp->getUsers().begin();
+      if (nextOp->getUsers().empty()) {
+        // This broadcast chain has already been replaced.
+        rewriter.eraseOp(srcOp);
+        return success();
+      }
+    }
+
+    rewriter.replaceAllOpUsesWith(nextOp, input);
+    rewriter.eraseOp(srcOp);
+
+    return success();
+  }
+};
+
+template <typename TTIROpTy, typename OpAdaptor = typename TTIROpTy::Adaptor>
+class RoundOpConversionPattern : public OpConversionPattern<TTIROpTy> {
+public:
+  using OpConversionPattern<TTIROpTy>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TTIROpTy op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!isa<ttir::RoundOp>(op) && !isa<ttir::RoundNearestEvenOp>(op)) {
+      return rewriter.notifyMatchFailure(
+          op, "ttnn::RoundOp only supports ttir:RoundOp or "
+              "ttir::RoundNearestEvenOp");
+    }
+    if (isa<ttir::RoundOp>(op) && adaptor.getDecimals() == 0) {
+      return rewriter.notifyMatchFailure(
+          op, "ttir::RoundOp requires decimals != 0");
+    }
+    if (isa<ttir::RoundNearestEvenOp>(op) && adaptor.getDecimals() != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "ttir::RoundNearestEvenOp requires decimals == 0");
+    }
+    rewriter.replaceOpWithNewOp<ttnn::RoundOp>(
+        op, this->getTypeConverter()->convertType(op.getType()),
+        adaptor.getInput(), adaptor.getDecimals());
+
+    return success();
+  }
+};
+
 class SubtractOpConversionPattern
     : public OpConversionPattern<ttir::SubtractOp> {
   using OpConversionPattern<ttir::SubtractOp>::OpConversionPattern;
@@ -1178,6 +1249,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ReductionOpConversionPattern,
            ReductionOpConversionPattern,
            ElementwiseUnaryWithFloatParameterOpConversionPattern,
+           RoundOpConversionPattern<ttir::RoundOp>,
+           RoundOpConversionPattern<ttir::RoundNearestEvenOp>,
            EmbeddingOpConversionPattern,
            EmbeddingBackwardOpConversionPattern,
            SoftmaxOpConversionPattern,
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 624ceddc33..d9e528cee1 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -678,6 +678,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   patterns.add<DefaultOpConversionPattern,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
+               DefaultOpConversionPattern<ttnn::RoundOp>,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
@@ -692,6 +693,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
+               DefaultOpConversionPattern<ttnn::RoundOp>,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
                DefaultOpConversionPattern,
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index b7ff1d7d57..03a160f519 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -530,6 +530,10 @@ createEltwiseOpParams(FlatbufferObjectCache &cache, EltwiseOp op) {
     return ::tt::target::ttnn::CreateEltwiseOpWithFloatParams(*cache.fbb,
                                                               parameter);
   }
+  if constexpr (std::is_same_v<EltwiseOpParams, ::tt::target::ttnn::RoundOpParams>) {
+    auto decimals = op.getDecimals();
+    return ::tt::target::ttnn::CreateRoundOpParams(*cache.fbb, decimals);
+  }
 }
 
 ::flatbuffers::Offset<::tt::target::ttnn::UpdateCacheOp>
@@ -569,6 +573,12 @@ createNonDPSEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
     params = createEltwiseOpParams(
                  cache, op)
                  .Union();
+  } else if (std::is_same_v<EltwiseOp, RoundOp>) {
+    type = ::tt::target::ttnn::EltwiseOpType::Round;
+    paramsType = ::tt::target::ttnn::EltwiseOpParams::RoundOpParams;
+    params = createEltwiseOpParams<RoundOp, ::tt::target::ttnn::RoundOpParams>(
+                 cache, op)
+                 .Union();
   } else {
     llvm_unreachable("unhandled non-DPS EltwiseOp");
   }
@@ -1098,6 +1108,10 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
     return createOperation(cache, createNonDPSEltwiseOp(cache, clampOp),
                            debugString, locInfo);
   }
+  if (auto roundOp = dyn_cast<RoundOp>(op); roundOp) {
+    return createOperation(cache, createNonDPSEltwiseOp(cache, roundOp),
+                           debugString, locInfo);
+  }
   if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
     return createOperation(cache, createOp(cache, conv2dOp), debugString,
                            locInfo);
diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
index 632377c2f4..0e13658568 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
+++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.cpp
@@ -43,6 +43,21 @@ static void runEltwiseUnaryCompositeClampOp(
   tensorPool.insert_or_assign(op->out()->global_id(), out);
 }
 
+static void runEltwiseUnaryCompositeRoundOp(
+    const ::tt::target::ttnn::EltwiseOp *op, ProgramTensorPool &tensorPool,
+    const std::function<::ttnn::Tensor(const ::ttnn::Tensor &, int,
+                                       const ::tt::tt_metal::MemoryConfig &)>
+        &ttnnOp) {
+  ::ttnn::Tensor *in = nullptr;
+  getEltwiseUnaryOpInputTensor(op, tensorPool, &in);
+
+  int32_t decimals = op->params_as_RoundOpParams()->decimals();
+  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
+      ::tt::runtime::ttnn::utils::createMemoryConfig(op->out());
+  ::ttnn::Tensor out = ttnnOp(*in, decimals, outputMemoryConfig);
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+
 void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
   ProgramTensorPool &tensorPool = context.getTensorPool();
   switch (op->type()) {
@@ -58,6 +73,10 @@ void run(const ::tt::target::ttnn::EltwiseOp *op, ProgramContext &context) {
     runEltwiseUnaryCompositeOp(op, tensorPool, ::ttnn::log1p);
     break;
   }
+  case ::tt::target::ttnn::EltwiseOpType::Round: {
+    runEltwiseUnaryCompositeRoundOp(op, tensorPool, ::ttnn::round);
+    break;
+  }
   default:
     LOG_FATAL("Unsupported Eltwise Binary Composite operation");
   }
diff --git a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
index f0c0c834c9..bcc441d142 100644
--- a/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
+++ b/runtime/lib/ttnn/operations/eltwise/unary/unary_composite.h
@@ -15,6 +15,7 @@ inline bool isUnaryCompositeOp(const ::tt::target::ttnn::EltwiseOp *op) {
   case ::tt::target::ttnn::EltwiseOpType::Cbrt:
   case ::tt::target::ttnn::EltwiseOpType::Clamp:
   case ::tt::target::ttnn::EltwiseOpType::Log1p:
+  case ::tt::target::ttnn::EltwiseOpType::Round:
     return true;
   default:
     return false;
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/round_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/round_op.mlir
new file mode 100644
index 0000000000..d49d206222
--- /dev/null
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/round_op.mlir
@@ -0,0 +1,16 @@
+// REQUIRES: stablehlo
+// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
+module @jit_eltwise_round attributes {} {
+  func.func public @test_round(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
+    %0 = stablehlo.round_nearest_afz %arg0 : tensor<4xbf16>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.round"[[C:.*]]
+    return %0 : tensor<4xbf16>
+  }
+  func.func public @test_roundnearesteven(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
+    %0 = stablehlo.round_nearest_even %arg0 : tensor<4xbf16>
+    // CHECK: %[[C:.*]] = tensor.empty[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttir.roundnearesteven"[[C:.*]]
+    return %0 : tensor<4xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/simple_round.mlir b/test/ttmlir/Dialect/TTNN/simple_round.mlir
new file mode 100644
index 0000000000..6d3de5c7e8
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/simple_round.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s
+module attributes {} {
+  func.func @roundnearesteven(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
+    %0 = tensor.empty() : tensor<4xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.round"[[C:.*]]
+    %1 = "ttir.roundnearesteven"(%arg0, %0) <{decimals = 0 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
+    return %1 : tensor<4xbf16>
+  }
+  func.func @round(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
+    %0 = tensor.empty() : tensor<4xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.round"[[C:.*]]
+    %1 = "ttir.round"(%arg0, %0) <{decimals = 1 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
+    return %1 : tensor<4xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
index a0452f01f8..202f8419f4 100644
--- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir
@@ -271,6 +271,20 @@ func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tens
   // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}}
 }
 
+func.func @roundnearesteven(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
+  %0 = tensor.empty() : tensor<4xbf16>
+  // CHECK: %[[C:.*]] = "ttnn.round"[[C:.*]]
+  %1 = "ttir.roundnearesteven"(%arg0, %0) <{decimals = 0 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
+  return %1 : tensor<4xbf16>
+}
+
+func.func @round(%arg0: tensor<4xbf16>) -> tensor<4xbf16> {
+  %0 = tensor.empty() : tensor<4xbf16>
+  // CHECK: %[[C:.*]] = "ttnn.round"[[C:.*]]
+  %1 = "ttir.round"(%arg0, %0) <{decimals = 1 : i32}> : (tensor<4xbf16>, tensor<4xbf16>) -> tensor<4xbf16>
+  return %1 : tensor<4xbf16>
+}
+
 func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> {
   %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32>
   // CHECK: [[VAL:%[0-9]+]] = "ttnn.full"(%{{[0-9]+}}) <{fillValue = 2.100000e+01 : f32}> : (!tt.device<#device>) -> tensor<1xi32, {{.*}}>