From 1ddbd0eec3faff38744ec2c5e7112a6c5e86d29b Mon Sep 17 00:00:00 2001
From: Darko Golubovic <158151710+dgolubovicTT@users.noreply.github.com>
Date: Fri, 2 Aug 2024 11:37:43 +0200
Subject: [PATCH] Implement reduction avg op end to end (#271)

---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td  |  7 +++++++
 include/ttmlir/Dialect/TTNN/IR/TTNNOps.td  |  7 +++++++
 include/ttmlir/Target/TTNN/program.fbs     |  1 +
 lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp   |  1 +
 lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp |  1 +
 lib/Dialect/TTIR/Transforms/Passes.cpp     |  4 +++-
 lib/Target/TTNN/TTNNToFlatbuffer.cpp       |  5 +++++
 runtime/lib/ttnn/program.cpp               | 14 ++++++++++++++
 test/ttmlir/Dialect/TTNN/simple_avg.mlir   | 15 +++++++++++++++
 9 files changed, 54 insertions(+), 1 deletion(-)
 create mode 100644 test/ttmlir/Dialect/TTNN/simple_avg.mlir

diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index f17bbd018..a0091c97e 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -217,6 +217,13 @@ def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
   }];
 }
 
+def TTIR_AvgOp : TTIR_ReductionOp<"avg"> {
+  let summary = "Average reduction op.";
+  let description = [{
+    Average reduction op.
+  }];
+}
+
 def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {
   let summary = "Softmax operation.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 4f060c67d..a9c5f0b69 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -124,6 +124,13 @@ def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
   }];
 }
 
+def TTNN_AvgOp : TTNN_ReductionOp<"avg"> {
+  let summary = "Average reduction op.";
+  let description = [{
+    Average reduction op.
+  }];
+}
+
 def TTNN_ReluOp : TTNN_ElementwiseOp<"relu"> {
   let summary = "Eltwise ReLU.";
   let description = [{
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index a2251c1cb..f13e8ab7d 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -40,6 +40,7 @@ table EltwiseOp {
 
 enum ReductionOpType: uint32 {
   Sum = 0,
+  Avg = 1,
 }
 
 table ReductionOp {
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index 9501ca7f2..7d4441c7d 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -144,6 +144,7 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            ElementwiseBinaryOpConversionPattern,
            ElementwiseBinaryOpConversionPattern,
            ReductionOpConversionPattern<ttir::SumOp, ttnn::SumOp>,
+           ReductionOpConversionPattern<ttir::AvgOp, ttnn::AvgOp>,
            SoftmaxOpConversionPattern,
            MatmulOpConversionPattern
            >(typeConverter, ctx);
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index d0fe6c9d1..b4103726f 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -119,6 +119,7 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   // Reduction ops
   //
   patterns.add<DefaultOpConversionPattern<ttnn::SumOp>>(typeConverter, ctx);
+  patterns.add<DefaultOpConversionPattern<ttnn::AvgOp>>(typeConverter, ctx);
 }
 
 } // namespace mlir::tt
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index 836108ca4..5d372adea 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -534,7 +534,9 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
         TTIRLayoutOperandsRewriter,
         TTIRLayoutOperandsRewriter,
         TTIRLayoutOperandsRewriter,
-        TTIRLayoutOperandsRewriter<SumOp>, TTIRLayoutOperandsRewriter<SoftmaxOp>,
+        TTIRLayoutOperandsRewriter<SumOp>,
+        TTIRLayoutOperandsRewriter<AvgOp>,
+        TTIRLayoutOperandsRewriter<SoftmaxOp>,
         TTIRLayoutOperandsRewriter, TTIRLayoutOperandsRewriter,
         TTIRLayoutFuncReturnRewriter>(
        &getContext());
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 5660d8607..d9dc2e378 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -142,6 +142,8 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
   ::tt::target::ttnn::ReductionOpType type;
   if constexpr (std::is_same_v<ReductionOp, SumOp>) {
     type = ::tt::target::ttnn::ReductionOpType::Sum;
+  } else if constexpr (std::is_same_v<ReductionOp, AvgOp>) {
+    type = ::tt::target::ttnn::ReductionOpType::Avg;
   } else {
     llvm_unreachable("unhandled ReductionOp");
   }
@@ -209,6 +211,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
   if (auto sumOp = dyn_cast<SumOp>(op); sumOp) {
     return createOperation(cache, createReductionOp(cache, sumOp), debugString);
   }
+  if (auto avgOp = dyn_cast<AvgOp>(op); avgOp) {
+    return createOperation(cache, createReductionOp(cache, avgOp), debugString);
+  }
   if (auto softmaxOp = dyn_cast<SoftmaxOp>(op); softmaxOp) {
     return createOperation(cache, createSoftmaxOp(cache, softmaxOp),
                            debugString);
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index 558eccb22..6e122a56c 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -138,6 +138,20 @@ run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
     liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
     break;
   }
+  case ::tt::target::ttnn::ReductionOpType::Avg: {
+    auto &in = *liveTensors.at(op->in()->global_id());
+
+    const auto *dim_arg_fb_ptr = op->dim_arg();
+    std::optional<std::vector<int>> dim_arg =
+        dim_arg_fb_ptr ? std::make_optional(std::vector<int>(
+                             dim_arg_fb_ptr->begin(), dim_arg_fb_ptr->end()))
+                       : std::nullopt;
+
+    tensorPool.push_back(::ttnn::avg(in, dim_arg, op->keep_dim()));
+
+    liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
+    break;
+  }
   }
 }
 
diff --git a/test/ttmlir/Dialect/TTNN/simple_avg.mlir b/test/ttmlir/Dialect/TTNN/simple_avg.mlir
new file mode 100644
index 000000000..6d367c96b
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/simple_avg.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+#any_device = #tt.operand_constraint<any_device>
+module attributes {tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
+  func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
+    // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
+    %0 = tensor.empty() : tensor<512x32xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: %[[C:.*]] = "ttnn.avg"[[C:.*]]
+    %1 = "ttir.avg"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
+    // CHECK: "ttnn.close_device"[[C:.*]]
+    return %1 : tensor<512x32xbf16>
+  }
+}
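Note on the semantics: the Avg branch added to runtime/lib/ttnn/program.cpp mirrors the existing Sum branch, with `::ttnn::avg(in, dim_arg, op->keep_dim())` presumably reducing over the dimensions listed in `dim_arg` and dividing by the number of reduced elements (the patch itself does not spell the math out). A minimal host-side sketch of that behavior, written for illustration only and not taken from this patch or from the TTNN library (the helper name `avgLastDim`, the shapes, and the float element type are all assumptions), for the `dim_arg = [-1], keep_dim = true` case exercised by simple_avg.mlir:

    // Hand-written reference for an average reduction over the last dimension
    // of a row-major 2D tensor. Illustrative only; `avgLastDim` is not a name
    // from the patch or the TTNN library.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Average `rows x cols` values over the last dimension. With keep_dim the
    // result conceptually has shape `rows x 1`, flattened here to `rows`.
    std::vector<float> avgLastDim(const std::vector<float> &in,
                                  std::size_t rows, std::size_t cols) {
      std::vector<float> out(rows, 0.0f);
      for (std::size_t r = 0; r < rows; ++r) {
        float sum = 0.0f;
        for (std::size_t c = 0; c < cols; ++c)
          sum += in[r * cols + c];
        out[r] = sum / static_cast<float>(cols); // avg = sum / reduced extent
      }
      return out;
    }

    int main() {
      // 2x4 input; the expected row averages are 2.5 and 6.5.
      std::vector<float> in = {1, 2, 3, 4, 5, 6, 7, 8};
      for (float v : avgLastDim(in, 2, 4))
        std::cout << v << '\n';
      return 0;
    }

Compiled with any C++11-or-later compiler, this prints 2.5 and 6.5 for the 2x4 example; that sum-then-divide result is the reference the device-side avg reduction is expected to match.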