Add e2e support for reduce sum op (#116)
derdeljanTT authored Jul 9, 2024
1 parent 6cb6cb6 commit c0b43fc
Showing 12 changed files with 197 additions and 8 deletions.
26 changes: 26 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -153,6 +153,32 @@ def TTIR_MultiplyOp : TTIR_ElementwiseBinaryOp<"multiply"> {
}];
}

class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> : TTIR_DPSOp<mnemonic, traits> {
let summary = "Reduction op.";
let description = [{
Reduction op.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
BoolAttr:$keep_dim,
OptionalAttr<I32ArrayAttr>:$dim_arg,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}

def TTIR_SumOp : TTIR_ReductionOp<"sum"> {
let summary = "Sum reduction op.";
let description = [{
Sum reduction op.
}];
}

// ANCHOR: adding_an_op_matmul_ttir
def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {
let summary = "Matrix multiply operation.";
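Note: a minimal sketch of building the new op from C++, assuming the default TableGen-generated builder (result type, then operands, then attributes in declaration order); the buildSum helper, the attribute values, and the full namespace qualification are illustrative assumptions, not part of this commit:

    // Hypothetical helper: create a ttir.sum reducing the last dimension,
    // keeping it in the result shape (keep_dim = true).
    mlir::Value buildSum(mlir::OpBuilder &builder, mlir::Location loc,
                         mlir::Value input, mlir::Value output,
                         mlir::ArrayAttr operandConstraints) {
      auto sum = builder.create<mlir::tt::ttir::SumOp>(
          loc, output.getType(), input, output,
          /*keep_dim=*/builder.getBoolAttr(true),
          /*dim_arg=*/builder.getI32ArrayAttr({-1}),
          /*operand_constraints=*/operandConstraints);
      return sum.getResult();
    }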
25 changes: 25 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -92,6 +92,31 @@ def TTNN_MultiplyOp : TTNN_ElementwiseBinaryOp<"multiply"> {
}];
}

class TTNN_ReductionOp<string mnemonic, list<Trait> traits = []> : TTNN_NamedDPSOp<mnemonic, traits> {
let summary = "Reduction op.";
let description = [{
Reduction op.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
BoolAttr:$keep_dim,
OptionalAttr<I32ArrayAttr>:$dim_arg);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
}

def TTNN_SumOp : TTNN_ReductionOp<"sum"> {
let summary = "Sum reduction op.";
let description = [{
Sum reduction op.
}];
}

// ANCHOR: adding_an_op_matmul_ttnn
def TTNN_MatmulOp : TTNN_NamedDPSOp<"matmul"> {
let arguments = (ins AnyRankedTensor:$a,
13 changes: 13 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -36,6 +36,18 @@ table EltwiseOp {
out: tt.target.TensorRef;
}

enum ReductionOpType: uint32 {
Sum = 0,
}

table ReductionOp {
type: ReductionOpType;
in: tt.target.TensorRef;
out: tt.target.TensorRef;
dim_arg: [int32];
keep_dim: bool;
}

// ANCHOR: adding_an_op_matmul_fbs
table MatmulOp {
in0: tt.target.TensorRef;
@@ -51,6 +63,7 @@ union OpType {
FullOp,
EltwiseOp,
MatmulOp,
ReductionOp
}

table Operation {
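Note: in the C++ API that flatc generates from this schema, the Create* argument order follows the field order of the table above. A sketch of serializing a sum over the last dimension (the TensorRef offsets and the helper name are illustrative assumptions):

    // Sketch: build a ReductionOp table for Sum over dim -1, keep_dim = true.
    ::flatbuffers::Offset<::tt::target::ttnn::ReductionOp>
    serializeSumExample(::flatbuffers::FlatBufferBuilder &fbb,
                        ::flatbuffers::Offset<::tt::target::TensorRef> in,
                        ::flatbuffers::Offset<::tt::target::TensorRef> out) {
      auto dimArg = fbb.CreateVector(std::vector<int32_t>{-1});
      return ::tt::target::ttnn::CreateReductionOp(
          fbb, ::tt::target::ttnn::ReductionOpType::Sum, in, out, dimArg,
          /*keep_dim=*/true);
    }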
37 changes: 37 additions & 0 deletions include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
@@ -212,6 +212,43 @@ inline DataType elementTypeToDataType(Type elementType) {
return dtype;
}

template <typename AttrType, typename ValueType>
struct ArrayAttrToFlatbufferSerializer {
static flatbuffers::Offset<flatbuffers::Vector<ValueType>>
impl(FlatbufferObjectCache &cache, const ArrayAttr &arrayAttr) {
assert(false && "unsupported array attr to value type serializer");
}
};

template <typename ValueType>
struct ArrayAttrToFlatbufferSerializer<IntegerAttr, ValueType> {
static flatbuffers::Offset<flatbuffers::Vector<ValueType>>
impl(FlatbufferObjectCache &cache, const ::mlir::ArrayAttr &arrayAttr) {
return cache.fbb->CreateVector<ValueType>(
arrayAttr.size(), [&arrayAttr](size_t i) {
return static_cast<ValueType>(
mlir::cast<IntegerAttr>(arrayAttr[i]).getInt());
});
}
};

template <typename AttrType, typename ValueType>
inline flatbuffers::Offset<flatbuffers::Vector<ValueType>>
arrayAttrToFlatbuffer(FlatbufferObjectCache &cache,
const ::mlir::ArrayAttr &arrayAttr) {
return ArrayAttrToFlatbufferSerializer<AttrType, ValueType>::impl(cache,
arrayAttr);
}

template <typename AttrType, typename ValueType>
inline flatbuffers::Offset<flatbuffers::Vector<ValueType>>
arrayAttrToFlatbuffer(FlatbufferObjectCache &cache,
const std::optional<::mlir::ArrayAttr> &arrayAttrOpt) {
return arrayAttrOpt.has_value() ? arrayAttrToFlatbuffer<AttrType, ValueType>(
cache, arrayAttrOpt.value())
: 0;
}

inline flatbuffers::Offset<::tt::target::MemoryDesc>
memrefAttrToFlatbuffer(FlatbufferObjectCache &cache, MemRefType memref) {
auto shapeInt64 = memref.getShape();
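Note: the primary template asserts for unsupported attribute element types; only the IntegerAttr specialization is provided, and the std::optional overload maps std::nullopt to offset 0, which flatbuffers treats as an absent field. A usage sketch mirroring how SerializeToBinary.cpp calls it (the wrapper is illustrative):

    // Sketch: [-1 : i32] serializes to a [int32] vector; std::nullopt -> 0,
    // i.e. the dim_arg field is simply omitted from the table.
    static ::flatbuffers::Offset<::flatbuffers::Vector<int32_t>>
    serializeDims(FlatbufferObjectCache &cache,
                  std::optional<mlir::ArrayAttr> dimArg) {
      return arrayAttrToFlatbuffer<mlir::IntegerAttr, int32_t>(cache, dimArg);
    }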
14 changes: 8 additions & 6 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -586,12 +586,14 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
}
{
RewritePatternSet patterns(&getContext());
patterns.add<TTIRLayoutOperandsRewriter<GenericOp>,
TTIRLayoutOperandsRewriter<AddOp>,
TTIRLayoutOperandsRewriter<MultiplyOp>,
TTIRLayoutOperandsRewriter<SubtractOp>,
TTIRLayoutOperandsRewriter<MatmulOp>,
TTIRLayoutFuncReturnRewriter>(&getContext());
patterns
.add<TTIRLayoutOperandsRewriter<GenericOp>,
TTIRLayoutOperandsRewriter<AddOp>,
TTIRLayoutOperandsRewriter<MultiplyOp>,
TTIRLayoutOperandsRewriter<SubtractOp>,
TTIRLayoutOperandsRewriter<MatmulOp>,
TTIRLayoutOperandsRewriter<SumOp>, TTIRLayoutFuncReturnRewriter>(
&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
signalPassFailure();
14 changes: 14 additions & 0 deletions lib/Dialect/TTNN/Transforms/Passes.cpp
@@ -94,6 +94,19 @@ class TTIRToTTNNOpRewriter : public OpRewritePattern<TTIROp> {
}
};

template <typename TTIROp, typename TTNNOp>
class TTIRToTTNNReductionOpRewriter : public OpRewritePattern<TTIROp> {
using OpRewritePattern<TTIROp>::OpRewritePattern;

LogicalResult matchAndRewrite(TTIROp op,
PatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<TTNNOp>(
op, op.getResult().getType(), op.getInput(), op.getOutput(),
op.getKeepDim(), op.getDimArg().value_or(nullptr));
return success();
}
};

// ANCHOR: adding_an_op_matmul_op_rewriter
template <typename TTIROp, typename TTNNOp>
class TTIRToTTNNBinaryOpRewriter : public OpRewritePattern<TTIROp> {
@@ -135,6 +148,7 @@ class ConvertTTIRToTTNN
TTIRToTTNNOpRewriter<ttir::MultiplyOp, MultiplyOp>,
TTIRToTTNNOpRewriter<ttir::SubtractOp, SubtractOp>,
TTIRToTTNNBinaryOpRewriter<ttir::MatmulOp, MatmulOp>,
TTIRToTTNNReductionOpRewriter<ttir::SumOp, SumOp>,
TensorEmptyToFullRewriter>(&getContext());
// ANCHOR_END: adding_an_op_matmul_rewrite_pattern_set
FrozenRewritePatternSet patternSet(std::move(patterns));
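Note: in the reduction rewriter above, getDimArg() returns std::optional<mlir::ArrayAttr>, so value_or(nullptr) collapses the absent case into a null attribute, which the generated TTNN builder treats as "attribute not set". Equivalent expansion, for illustration only:

    // Sketch: a null ArrayAttr stands in for "no dim_arg given"; downstream,
    // the runtime forwards that absence as std::nullopt to the ttnn call.
    mlir::ArrayAttr dimArgOrNull = op.getDimArg().value_or(nullptr);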
24 changes: 24 additions & 0 deletions lib/Dialect/TTNN/Transforms/SerializeToBinary.cpp
@@ -127,6 +127,27 @@ createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
getOperandThroughDPSOps(op.getOutputs().front())));
}

template <typename ReductionOp>
::flatbuffers::Offset<::tt::target::ttnn::ReductionOp>
createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) {
::tt::target::ttnn::ReductionOpType type;
if constexpr (std::is_same_v<ReductionOp, SumOp>) {
type = ::tt::target::ttnn::ReductionOpType::Sum;
} else {
llvm_unreachable("unhandled ReductionOp");
}

auto in =
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto output = cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getResult()));
auto dim_arg =
arrayAttrToFlatbuffer<mlir::IntegerAttr, int>(cache, op.getDimArg());

return ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output,
dim_arg, op.getKeepDim());
}

::flatbuffers::Offset<::tt::target::ttnn::Operation>
emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
std::string const &debugString) {
@@ -158,6 +179,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto matmulOp = dyn_cast<MatmulOp>(op); matmulOp) {
return createOperation(cache, createOp(cache, matmulOp), debugString);
}
if (auto sumOp = dyn_cast<SumOp>(op); sumOp) {
return createOperation(cache, createReductionOp(cache, sumOp), debugString);
}

llvm_unreachable("unhandled op in emitTTNNOperation");
}
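Note: the if constexpr chain in createReductionOp is the single dispatch point from MLIR op type to the serialized enum. A hypothetical second reduction (say a MeanOp, which is not part of this commit) would extend it as follows, alongside a matching enum value in program.fbs:

    // Hypothetical extension sketch; MeanOp and ReductionOpType::Mean are
    // not defined in this commit.
    if constexpr (std::is_same_v<ReductionOp, SumOp>) {
      type = ::tt::target::ttnn::ReductionOpType::Sum;
    } else if constexpr (std::is_same_v<ReductionOp, MeanOp>) {
      type = ::tt::target::ttnn::ReductionOpType::Mean;
    } else {
      llvm_unreachable("unhandled ReductionOp");
    }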
5 changes: 5 additions & 0 deletions lib/Dialect/TTNN/Transforms/TTNNToCpp.cpp
@@ -144,6 +144,7 @@ class ConvertTTNNToEmitC
TTNNToEmitCOpaqueRewriter<MultiplyOp>,
TTNNToEmitCOpaqueRewriter<SubtractOp>,
TTNNToEmitCOpaqueRewriter<MatmulOp>,
TTNNToEmitCOpaqueRewriter<SumOp>,
TTNNToEmitCOpaqueRewriter<CloseDeviceOp>>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
@@ -164,6 +165,10 @@ class ConvertTTNNToEmitC
module.getBody()->push_front(builder.create<emitc::IncludeOp>(
module.getLoc(), "ttnn/operations/creation.hpp",
/*isStandard=*/false));
module.getBody()->push_front(builder.create<emitc::IncludeOp>(
module.getLoc(),
"ttnn/operations/reduction/generic/generic_reductions.hpp",
/*isStandard=*/false));
}
}

1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -31,6 +31,7 @@
#pragma clang diagnostic ignored "-Wunused-local-typedef"
#pragma clang diagnostic ignored "-Wunused-function"
#define FMT_HEADER_ONLY
#include "ttnn//operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/device.hpp"
#include "ttnn/operations/binary.hpp"
#include "ttnn/operations/core.hpp"
5 changes: 3 additions & 2 deletions runtime/lib/binary.cpp
@@ -28,8 +28,9 @@ static std::string asJson(void const *fbb, uint8_t const *binarySchema,
}

std::string text;
if (::flatbuffers::GenerateText(parser, fbb, &text)) {
throw std::runtime_error("Failed to generate JSON");
const char *err = ::flatbuffers::GenerateText(parser, fbb, &text);
if (err) {
throw std::runtime_error("Failed to generate JSON: " + std::string(err));
}

return text;
26 changes: 26 additions & 0 deletions runtime/lib/ttnn/program.cpp
@@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include <list>
#include <optional>
#include <unordered_map>

#include "tt/runtime/detail/ttnn.h"
@@ -88,6 +89,28 @@ run(::tt::target::ttnn::EltwiseOp const *op, ::ttnn::Device &device,
}
}

static void
run(::tt::target::ttnn::ReductionOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
std::list<::ttnn::Tensor> &tensorPool) {
switch (op->type()) {
case ::tt::target::ttnn::ReductionOpType::Sum: {
auto &in = *liveTensors.at(op->in()->global_id());

const auto *dim_arg_fb_ptr = op->dim_arg();
std::optional<std::vector<int>> dim_arg =
dim_arg_fb_ptr ? std::make_optional(std::vector<int>(
dim_arg_fb_ptr->begin(), dim_arg_fb_ptr->end()))
: std::nullopt;

tensorPool.push_back(::ttnn::sum(in, dim_arg, op->keep_dim()));

liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
break;
}
}
}

// ANCHOR: adding_an_op_matmul_runtime
static void
run(::tt::target::ttnn::MatmulOp const *op, ::ttnn::Device &device,
@@ -127,6 +150,9 @@ run(::tt::target::ttnn::Operation const *op, ::ttnn::Device &device,
case ::tt::target::ttnn::OpType::MatmulOp: {
return run(op->type_as_MatmulOp(), device, liveTensors, tensorPool);
}
case ::tt::target::ttnn::OpType::ReductionOp: {
return run(op->type_as_ReductionOp(), device, liveTensors, tensorPool);
}
default:
throw std::runtime_error("Unsupported operation type");
}
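Note: at runtime the optional dim_arg round-trips as shown above: a null flatbuffers vector becomes std::nullopt. For the test case below (dim_arg = [-1], keep_dim = true) the dispatch reduces to a single call; that a std::nullopt dim list means "reduce over all dimensions" is assumed from the ttnn API rather than spelled out in this commit:

    // Sketch of the call the runtime makes for simple_sum.mlir below:
    // sum over the last dimension of the input tensor, keeping that dim.
    static ::ttnn::Tensor sumLastDim(::ttnn::Tensor const &in) {
      std::optional<std::vector<int>> dims = std::vector<int>{-1};
      return ::ttnn::sum(in, dims, /*keep_dim=*/true);
    }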
15 changes: 15 additions & 0 deletions test/ttmlir/Dialect/TTNN/simple_sum.mlir
@@ -0,0 +1,15 @@
// RUN: ttmlir-opt --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|tile|any_device|any_device_tile>
module attributes {torch.debug_module_name = "_lambda", tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = <8x8>, l1_size = 1048576, num_dram_channels = 12, dram_channel_size = 1048576}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.full"[[C:.*]]
%0 = tensor.empty() : tensor<512x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]]
%1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.to_memory_config"[[C:.*]]
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<512x32xbf16>
}
}
