Add CCL ops to TTNN dialect (#700) (#937)
These ops are added:
- all_gather
- reduce_scatter

Also add all_gather as a TTIR op.
gfengTT authored Oct 25, 2024
1 parent 6eaed3f commit ebb5398
Showing 17 changed files with 232 additions and 4 deletions.
24 changes: 22 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -477,8 +477,7 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
  let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                       AnyRankedTensor:$output,
                       SI32Attr:$dim,
-
-                      TT_OperandConstraintArrayAttr:$operand_constraints);
+                      TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

@@ -507,6 +506,27 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
}];
}

// CCL ops
def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> {
let summary = "All gather operation.";
let description = [{
Gathers the input tensor from all devices and concatenates the per-device shards along `dim`.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let summary = "Conv2d operation.";
let description = [{
30 changes: 30 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -547,4 +547,34 @@ def TTNN_DeallocOp : TTNN_Op<"dealloc"> {
let arguments = (ins AnyRankedTensor:$input);
}

def TTNN_AllGatherOp : TTNN_Op<"all_gather"> {
let summary = "All gather op.";
let description = [{
Tensor all-gather operation.
}];

let arguments = (ins AnyRankedTensor:$input,
SI32Attr:$dim,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_ReduceScatterOp : TTNN_Op<"reduce_scatter"> {
let summary = "Reduce scatter op.";
let description = [{
Tensor reduce-scatter operation.
}];

let arguments = (ins AnyRankedTensor:$input,
SI32Attr:$scatter_split_dim,
TTNN_ReduceType:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);
let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

#endif
20 changes: 20 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
@@ -57,4 +57,24 @@ def TTNN_BufferType : I32EnumAttr<"BufferType", "TTNN Buffer Type",
let cppNamespace = "::mlir::tt::ttnn";
}

def TTNN_ReduceType_Sum : I32EnumAttrCase<"Sum", 0, "sum">;
def TTNN_ReduceType_Mean : I32EnumAttrCase<"Mean", 1, "mean">;
def TTNN_ReduceType_Max : I32EnumAttrCase<"Max", 2, "max">;
def TTNN_ReduceType_Min : I32EnumAttrCase<"Min", 3, "min">;
def TTNN_ReduceType_Std : I32EnumAttrCase<"Std", 4, "std">;
def TTNN_ReduceType_Var : I32EnumAttrCase<"Var", 5, "var">;

def TTNN_ReduceType : I32EnumAttr<"ReduceType", "TTNN Reduce Operation Type",
[
TTNN_ReduceType_Sum,
TTNN_ReduceType_Mean,
TTNN_ReduceType_Max,
TTNN_ReduceType_Min,
TTNN_ReduceType_Std,
TTNN_ReduceType_Var,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt::ttnn::operations::reduction";
}

#endif
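
Because `genSpecializedAttr = 0`, no specialized attribute class is generated and ops reference the enum directly. A minimal sketch of how the tablegen-generated helpers for this enum are typically used — the helper names and include path follow mlir-tblgen's standard EnumAttr conventions and are assumptions, not code from this commit:

#include <optional>
#include "llvm/ADT/StringRef.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.h.inc" // assumed generated header

namespace red = ::mlir::tt::ttnn::operations::reduction;

void reduceTypeExample() {
  // Enum cases mirror the I32EnumAttrCase defs above.
  red::ReduceType op = red::ReduceType::Sum;
  // stringifyReduceType maps a case to its string form ("sum").
  llvm::StringRef name = red::stringifyReduceType(op);
  // symbolizeReduceType parses a string, yielding std::nullopt on failure.
  std::optional<red::ReduceType> parsed = red::symbolizeReduceType("max");
  (void)name;
  (void)parsed;
}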
10 changes: 9 additions & 1 deletion include/ttmlir/Target/TTNN/program.fbs
@@ -187,6 +187,13 @@ table DeallocOp {
in: tt.target.TensorRef;
}

table AllGatherOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
dim: uint32;
num_links: uint32;
}

union OpType {
GetDeviceOp,
ToMemoryConfigOp,
@@ -206,7 +213,8 @@ union OpType {
ReshapeOp,
SliceOp,
MaxPool2dOp,
DeallocOp
-  DeallocOp
+  DeallocOp,
+  AllGatherOp,
}

table Operation {
24 changes: 23 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -718,6 +718,27 @@ class SubtractOpConversionPattern
}
};

class AllGatherOpConversionPattern
: public OpConversionPattern<ttir::AllGatherOp> {
public:
using OpConversionPattern<ttir::AllGatherOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType type =
mlir::cast<RankedTensorType>(adaptor.getInput().getType());
Value device = getOrInsertDevice(rewriter, op);
tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
op.getLoc(), this->getTypeConverter()->convertType(type), device);

rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
op, this->getTypeConverter()->convertType(op.getType()), emptyOp,
adaptor.getDim());
return success();
}
};

namespace mlir::tt {

void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
@@ -765,7 +786,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
MatmulOpConversionPattern,
Conv2dOpConversionPattern,
MaxPool2dOpConversionPattern,
-           SubtractOpConversionPattern
+           SubtractOpConversionPattern,
+           AllGatherOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
// clang-format on
5 changes: 5 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -581,6 +581,11 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>,
DefaultOpConversionPattern<ttnn::EmbeddingOp>>(typeConverter,
ctx);

// CCL ops
//
patterns.add<DefaultOpConversionPattern<ttnn::AllGatherOp>>(typeConverter,
ctx);
}

} // namespace mlir::tt
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -752,6 +752,22 @@ ::mlir::LogicalResult mlir::tt::ttir::SoftmaxOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AllGatherOp
//===----------------------------------------------------------------------===//

// AllGatherOp verification
::mlir::LogicalResult mlir::tt::ttir::AllGatherOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t dim = getDim();

if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
return emitOpError("Invalid dimension for all gather op.");
}

return success();
}
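
Both this verifier and the TTNN one below accept `dim` in the half-open range [-rank, rank), so negative dimensions index from the back as usual. A standalone restatement of the check — the helper name and values are illustrative, not from the commit:

#include <cassert>
#include <cstdint>

// Hypothetical helper restating the verifier's range check.
bool isValidGatherDim(int32_t dim, int64_t rank) {
  // Valid dims lie in [-rank, rank); negative values count from the back.
  return dim >= -rank && dim < rank;
}

int main() {
  assert(isValidGatherDim(3, 4));   // last axis of a rank-4 tensor
  assert(isValidGatherDim(-4, 4));  // first axis via negative indexing
  assert(!isValidGatherDim(4, 4));  // out of range; the verifier emits an error
  return 0;
}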

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//
16 changes: 16 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -754,4 +754,20 @@ ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() {
return success();
}

::mlir::LogicalResult AllGatherOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t dim = getDim();

if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
return emitOpError("Invalid dimension for all gather op.");
}

return success();
}

::mlir::LogicalResult ReduceScatterOp::verify() {
// TODO
return success();
}

} // namespace mlir::tt::ttnn
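
The `ReduceScatterOp` verifier above is left as a TODO. One possible implementation — a sketch that assumes the same dim-range rule as all_gather applies to `scatter_split_dim`; it is not part of this commit:

// Sketch only: mirrors the AllGatherOp check; getScatterSplitDim() is the
// accessor mlir-tblgen would generate for $scatter_split_dim.
::mlir::LogicalResult ReduceScatterOp::verify() {
  ::mlir::RankedTensorType inputType = getInput().getType();
  int32_t dim = getScatterSplitDim();

  if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
    return emitOpError("Invalid dimension for reduce scatter op.");
  }

  return success();
}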
13 changes: 13 additions & 0 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -250,6 +250,16 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) {
op.getGroups());
}

::flatbuffers::Offset<::tt::target::ttnn::AllGatherOp>
createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
auto input =
cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize);
return ::tt::target::ttnn::CreateAllGatherOp(*cache.fbb, input, output,
op.getDim(), op.getNumLinks());
}

template <typename EltwiseOp>
::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
@@ -578,6 +588,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
return createOperation(cache, createOp(cache, conv2dOp), debugString);
}
if (auto allGatherOp = dyn_cast<AllGatherOp>(op); allGatherOp) {
return createOperation(cache, createOp(cache, allGatherOp), debugString);
}
if (auto concatOp = dyn_cast<ConcatOp>(op); concatOp) {
return createOperation(cache, createConcatOp(cache, concatOp), debugString);
}
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -46,6 +46,7 @@
#include "hostdevcommon/common_values.hpp"
#include "impl/device/mesh_device.hpp"
#include "ttnn/device.hpp"
#include "ttnn/operations/ccl/all_gather/all_gather.hpp"
#include "ttnn/operations/conv/conv2d/conv2d.hpp"
#include "ttnn/operations/copy.hpp"
#include "ttnn/operations/core/core.hpp"
1 change: 1 addition & 0 deletions runtime/lib/ttnn/operations/CMakeLists.txt
@@ -1,5 +1,6 @@
set(TTNN_OPS_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp
21 changes: 21 additions & 0 deletions runtime/lib/ttnn/operations/ccl/all_gather.cpp
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "all_gather.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/ttnn/operations/utils.h"

namespace tt::runtime::ttnn::operations::ccl {
void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) {
ProgramTensorPool &tensorPool = context.getTensorPool();
const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
int32_t dim = op->dim();
int32_t num_links = op->num_links();
::tt::tt_metal::MemoryConfig outputMemoryConfig =
utils::createMemoryConfig(op->out());
::ttnn::Tensor out =
::ttnn::all_gather(input, dim, num_links, outputMemoryConfig);
tensorPool.insert_or_assign(op->out()->global_id(), out);
}
} // namespace tt::runtime::ttnn::operations::ccl
15 changes: 15 additions & 0 deletions runtime/lib/ttnn/operations/ccl/all_gather.h
@@ -0,0 +1,15 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTNN_RUNTIME_ALL_GATHER_H
#define TTNN_RUNTIME_ALL_GATHER_H

#include "tt/runtime/ttnn/types.h"
#include "ttmlir/Target/TTNN/program_generated.h"

namespace tt::runtime::ttnn::operations::ccl {
void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context);
} // namespace tt::runtime::ttnn::operations::ccl

#endif
4 changes: 4 additions & 0 deletions runtime/lib/ttnn/program.cpp
@@ -1,6 +1,7 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0
#include "operations/ccl/all_gather.h"
#include "operations/context/get_device.h"
#include "operations/conv/conv2d.h"
#include "operations/creation/empty.h"
@@ -113,6 +114,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
case ::tt::target::ttnn::OpType::MaxPool2dOp: {
return operations::pool::run(op->type_as_MaxPool2dOp(), context);
}
case ::tt::target::ttnn::OpType::AllGatherOp: {
return operations::ccl::run(op->type_as_AllGatherOp(), context);
}
default: {
throw std::runtime_error("Unsupported operation type");
}
11 changes: 11 additions & 0 deletions test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
@@ -0,0 +1,11 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x1x32x128xbf16>
// CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
%1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
return %1 : tensor<1x1x32x128xbf16>
}
}
10 changes: 10 additions & 0 deletions test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir
@@ -0,0 +1,10 @@
// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
// CHECK: error: 'ttir.all_gather' op Invalid dimension for all gather op
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
%0 = tensor.empty() : tensor<1x1x32x128xbf16>
%1 = "ttir.all_gather"(%arg0, %0) <{dim = 4 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
return %1 : tensor<1x1x32x128xbf16>
}
}
15 changes: 15 additions & 0 deletions test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir
@@ -0,0 +1,15 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=4,1,1" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// UNSUPPORTED: true
// REQUIRES: multi-chip
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>

func.func @forward(%arg0: tensor<1x1x32x32xf32>) -> tensor<1x1x32x128xf32> {
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
%0 = tensor.empty() : tensor<1x1x32x128xf32>
// CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
%1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xf32>, tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32>
return %1 : tensor<1x1x32x128xf32>
}
