diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
index aee213b5a..9eae608b7 100644
--- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
+++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -477,8 +477,7 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {
   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                        AnyRankedTensor:$output,
                        SI32Attr:$dim,
-
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       TT_OperandConstraintArrayAttr:$operand_constraints);

   let results = (outs AnyRankedTensor:$result);

@@ -507,6 +506,27 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
   }];
 }

+// CCL ops
+def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> {
+  let summary = "All gather operation.";
+  let description = [{
+    All gather op.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       SI32Attr:$dim,
+                       TT_OperandConstraintArrayAttr:$operand_constraints);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let extraClassDeclaration = [{
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
+  }];
+
+  let hasVerifier = 1;
+}
+
 def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
   let summary = "Conv2d operation.";
   let description = [{
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
index 46d0fcf52..58c078232 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -547,4 +547,34 @@ def TTNN_DeallocOp : TTNN_Op<"dealloc"> {
   let arguments = (ins AnyRankedTensor:$input);
 }

+def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
+  let summary = "All gather op.";
+  let description = [{
+    Tensor All Gather operation
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       SI32Attr:$dim,
+                       DefaultValuedAttr<SI32Attr, "1">:$num_links);
+
+  let results = (outs AnyRankedTensor:$result);
+
+  let hasVerifier = 1;
+}
+
+def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {
+  let summary = "Reduce scatter op.";
+  let description = [{
+    Tensor Reduce Scatter operation
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       SI32Attr:$scatter_split_dim,
+                       TTNN_ReduceType:$math_op,
+                       DefaultValuedAttr<SI32Attr, "1">:$num_links);
+  let results = (outs AnyRankedTensor:$result);
+
+  let hasVerifier = 1;
+}
+
 #endif
diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
index c90d204a8..d40d05113 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsEnums.td
@@ -57,4 +57,24 @@ def TTNN_BufferType : I32EnumAttr<"BufferType", "TTNN Buffer Type",
   let cppNamespace = "::mlir::tt::ttnn";
 }

+def TTNN_ReduceType_Sum : I32EnumAttrCase<"Sum", 0, "sum">;
+def TTNN_ReduceType_Mean : I32EnumAttrCase<"Mean", 1, "mean">;
+def TTNN_ReduceType_Max : I32EnumAttrCase<"Max", 2, "max">;
+def TTNN_ReduceType_Min : I32EnumAttrCase<"Min", 3, "min">;
+def TTNN_ReduceType_Std : I32EnumAttrCase<"Std", 4, "std">;
+def TTNN_ReduceType_Var : I32EnumAttrCase<"Var", 5, "var">;
+
+def TTNN_ReduceType: I32EnumAttr<"ReduceType", "TTNN Reduce Operation Type",
+                                 [
+                                  TTNN_ReduceType_Sum,
+                                  TTNN_ReduceType_Mean,
+                                  TTNN_ReduceType_Max,
+                                  TTNN_ReduceType_Min,
+                                  TTNN_ReduceType_Std,
+                                  TTNN_ReduceType_Var,
+                                 ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::tt::ttnn::operations::reduction";
+}
+
 #endif
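Note on the enum plumbing: with `genSpecializedAttr = 0`, mlir-tblgen emits only a plain C++ enum class and string helpers for `TTNN_ReduceType`, with no attribute wrapper class. A rough sketch of what EnumsGen produces, assuming MLIR's usual naming conventions (approximate, not the literal tblgen output):

```cpp
// Approximate shape of the C++ that MLIR's EnumsGen emits for the
// TTNN_ReduceType definition above -- a sketch, not the literal output.
#include <cstdint>

namespace mlir::tt::ttnn::operations::reduction {

enum class ReduceType : uint32_t {
  Sum = 0,
  Mean = 1,
  Max = 2,
  Min = 3,
  Std = 4,
  Var = 5,
};

// Generated helpers follow MLIR's usual enum conventions, roughly:
//   ::llvm::StringRef stringifyReduceType(ReduceType val);       // "sum", ...
//   ::std::optional<ReduceType> symbolizeReduceType(::llvm::StringRef str);
//   ::std::optional<ReduceType> symbolizeReduceType(uint32_t value);

} // namespace mlir::tt::ttnn::operations::reduction
```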
diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs
index bc98664db..bafd202e5 100644
--- a/include/ttmlir/Target/TTNN/program.fbs
+++ b/include/ttmlir/Target/TTNN/program.fbs
@@ -187,6 +187,13 @@ table DeallocOp {
   in: tt.target.TensorRef;
 }

+table AllGatherOp {
+  in: tt.target.TensorRef;
+  out: tt.target.TensorRef;
+  dim: uint32;
+  num_links: uint32;
+}
+
 union OpType {
   GetDeviceOp,
   ToMemoryConfigOp,
@@ -206,7 +213,8 @@ union OpType {
   ReshapeOp,
   SliceOp,
   MaxPool2dOp,
-  DeallocOp
+  DeallocOp,
+  AllGatherOp,
 }

 table Operation {
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
index c9052c866..f786a4a5d 100644
--- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
+++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -718,6 +718,27 @@ class SubtractOpConversionPattern
   }
 };

+class AllGatherOpConversionPattern
+    : public OpConversionPattern<ttir::AllGatherOp> {
+public:
+  using OpConversionPattern<ttir::AllGatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    RankedTensorType type =
+        mlir::cast<RankedTensorType>(adaptor.getInput().getType());
+    Value device = getOrInsertDevice(rewriter, op);
+    tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+        op.getLoc(), this->getTypeConverter()->convertType(type), device);
+
+    rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
+        op, this->getTypeConverter()->convertType(op.getType()), emptyOp,
+        adaptor.getDim());
+    return success();
+  }
+};
+
 namespace mlir::tt {

 void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
@@ -765,7 +786,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
            MatmulOpConversionPattern,
            Conv2dOpConversionPattern,
            MaxPool2dOpConversionPattern,
-           SubtractOpConversionPattern
+           SubtractOpConversionPattern,
+           AllGatherOpConversionPattern
            >(typeConverter, ctx);
   // ANCHOR_END: op_rewriter_pattern_set
   // clang-format on
diff --git a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
index 946862c23..f46eb1ab8 100644
--- a/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
+++ b/lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -581,6 +581,11 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
   patterns.add<DefaultOpConversionPattern<ttnn::SoftmaxOp>,
                DefaultOpConversionPattern<ttnn::EmbeddingOp>>(typeConverter,
                                                               ctx);
+
+  // CCL ops
+  //
+  patterns.add<DefaultOpConversionPattern<ttnn::AllGatherOp>>(typeConverter,
+                                                              ctx);
 }

 } // namespace mlir::tt
diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp
index 6f1426054..5e88470cf 100644
--- a/lib/Dialect/TTIR/IR/TTIROps.cpp
+++ b/lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -752,6 +752,22 @@ ::mlir::LogicalResult mlir::tt::ttir::SoftmaxOp::verify() {
   return success();
 }

+//===----------------------------------------------------------------------===//
+// AllGatherOp
+//===----------------------------------------------------------------------===//
+
+// AllGatherOp verification
+::mlir::LogicalResult mlir::tt::ttir::AllGatherOp::verify() {
+  ::mlir::RankedTensorType inputType = getInput().getType();
+  int32_t dim = getDim();
+
+  if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
+    return emitOpError("Invalid dimension for all gather op.");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GenericOp
 //===----------------------------------------------------------------------===//
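The verifier above accepts any `dim` in `[-rank, rank)`, i.e. the usual negative-indexing convention. A standalone C++ illustration of that bounds check (`normalizeDim` is a hypothetical helper, not part of this change):

```cpp
// Illustration of the bounds check in AllGatherOp::verify() above:
// a gather dim is valid iff -rank <= dim < rank, with negative values
// counting from the back. normalizeDim is a hypothetical helper.
#include <cassert>
#include <cstdint>

static int64_t normalizeDim(int32_t dim, int64_t rank) {
  assert(dim < rank && dim >= -rank && "Invalid dimension for all gather op.");
  return dim < 0 ? dim + rank : dim;
}

int main() {
  // For a rank-4 tensor such as tensor<1x1x32x32xbf16>:
  assert(normalizeDim(3, 4) == 3);  // gather along the last dim, as in the tests
  assert(normalizeDim(-1, 4) == 3); // negative dims wrap around
  // normalizeDim(4, 4) would trip the check -- exactly the case the
  // all_gather_negative.mlir test added later in this diff rejects.
  return 0;
}
```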
diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp
index 64e01af75..feed3a7b3 100644
--- a/lib/Dialect/TTNN/IR/TTNNOps.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -754,4 +754,20 @@ ::mlir::LogicalResult mlir::tt::ttnn::SoftmaxOp::verify() {
   return success();
 }

+::mlir::LogicalResult AllGatherOp::verify() {
+  ::mlir::RankedTensorType inputType = getInput().getType();
+  int32_t dim = getDim();
+
+  if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
+    return emitOpError("Invalid dimension for all gather op.");
+  }
+
+  return success();
+}
+
+::mlir::LogicalResult ReduceScatterOp::verify() {
+  // TODO
+  return success();
+}
+
 } // namespace mlir::tt::ttnn
diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
index 40ed701fd..d6093f448 100644
--- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp
+++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -250,6 +250,16 @@ createOp(FlatbufferObjectCache &cache, Conv2dOp op) {
                   op.getGroups());
 }

+::flatbuffers::Offset<::tt::target::ttnn::AllGatherOp>
+createOp(FlatbufferObjectCache &cache, AllGatherOp op) {
+  auto input =
+      cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput()));
+  auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer,
+                                  kHostAllocatedAddress, kHostAllocatedSize);
+  return ::tt::target::ttnn::CreateAllGatherOp(*cache.fbb, input, output,
+                                               op.getDim(), op.getNumLinks());
+}
+
 template <typename EltwiseOp>
 ::flatbuffers::Offset<::tt::target::ttnn::EltwiseOp>
 createEltwiseOp(FlatbufferObjectCache &cache, EltwiseOp op) {
@@ -578,6 +588,9 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op,
   if (auto conv2dOp = dyn_cast<Conv2dOp>(op); conv2dOp) {
     return createOperation(cache, createOp(cache, conv2dOp), debugString);
   }
+  if (auto allGatherOp = dyn_cast<AllGatherOp>(op); allGatherOp) {
+    return createOperation(cache, createOp(cache, allGatherOp), debugString);
+  }
   if (auto concatOp = dyn_cast<ConcatOp>(op); concatOp) {
     return createOperation(cache, createConcatOp(cache, concatOp), debugString);
   }
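The `createOp` overload above leans on two cache behaviors: operands must already have been serialized (`cache.at` fails if the entry is missing), while the result `TensorRef` is serialized on first use (`getOrCreate` memoizes). A generic sketch of that at/getOrCreate pattern (hypothetical stand-in, not the real `FlatbufferObjectCache`):

```cpp
// Generic sketch of the at/getOrCreate memoization pattern used by
// FlatbufferObjectCache above. A hypothetical stand-in, not the real class.
#include <functional>
#include <unordered_map>

template <typename Key, typename Value>
class ObjectCache {
public:
  // Inputs: must already be serialized; throws if the entry is missing.
  const Value &at(const Key &key) const { return map.at(key); }

  // Outputs: serialize lazily on first request, then reuse the entry.
  const Value &getOrCreate(const Key &key,
                           const std::function<Value(const Key &)> &create) {
    auto it = map.find(key);
    if (it == map.end())
      it = map.emplace(key, create(key)).first;
    return it->second;
  }

private:
  std::unordered_map<Key, Value> map;
};
```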
diff --git a/runtime/include/tt/runtime/detail/ttnn.h b/runtime/include/tt/runtime/detail/ttnn.h
index e03e5f7bd..f6cf78e93 100644
--- a/runtime/include/tt/runtime/detail/ttnn.h
+++ b/runtime/include/tt/runtime/detail/ttnn.h
@@ -46,6 +46,7 @@
 #include "hostdevcommon/common_values.hpp"
 #include "impl/device/mesh_device.hpp"
 #include "ttnn/device.hpp"
+#include "ttnn/operations/ccl/all_gather/all_gather.hpp"
 #include "ttnn/operations/conv/conv2d/conv2d.hpp"
 #include "ttnn/operations/copy.hpp"
 #include "ttnn/operations/core/core.hpp"
diff --git a/runtime/lib/ttnn/operations/CMakeLists.txt b/runtime/lib/ttnn/operations/CMakeLists.txt
index 9894e1e4c..3bca03ee2 100644
--- a/runtime/lib/ttnn/operations/CMakeLists.txt
+++ b/runtime/lib/ttnn/operations/CMakeLists.txt
@@ -1,5 +1,6 @@
 set(TTNN_OPS_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/include/tt/runtime/ttnn/operations/utils.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ccl/all_gather.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/conv/conv2d.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/empty.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/creation/full.cpp
diff --git a/runtime/lib/ttnn/operations/ccl/all_gather.cpp b/runtime/lib/ttnn/operations/ccl/all_gather.cpp
new file mode 100644
index 000000000..37bf7427b
--- /dev/null
+++ b/runtime/lib/ttnn/operations/ccl/all_gather.cpp
@@ -0,0 +1,21 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include "all_gather.h"
+#include "tt/runtime/detail/ttnn.h"
+#include "tt/runtime/ttnn/operations/utils.h"
+
+namespace tt::runtime::ttnn::operations::ccl {
+void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context) {
+  ProgramTensorPool &tensorPool = context.getTensorPool();
+  const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
+  int32_t dim = op->dim();
+  int32_t num_links = op->num_links();
+  ::tt::tt_metal::MemoryConfig outputMemoryConfig =
+      utils::createMemoryConfig(op->out());
+  ::ttnn::Tensor out =
+      ::ttnn::all_gather(input, dim, num_links, outputMemoryConfig);
+  tensorPool.insert_or_assign(op->out()->global_id(), out);
+}
+} // namespace tt::runtime::ttnn::operations::ccl
diff --git a/runtime/lib/ttnn/operations/ccl/all_gather.h b/runtime/lib/ttnn/operations/ccl/all_gather.h
new file mode 100644
index 000000000..f9a7e5624
--- /dev/null
+++ b/runtime/lib/ttnn/operations/ccl/all_gather.h
@@ -0,0 +1,15 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#ifndef TTNN_RUNTIME_ALL_GATHER_H
+#define TTNN_RUNTIME_ALL_GATHER_H
+
+#include "tt/runtime/ttnn/types.h"
+#include "ttmlir/Target/TTNN/program_generated.h"
+
+namespace tt::runtime::ttnn::operations::ccl {
+void run(const ::tt::target::ttnn::AllGatherOp *op, ProgramContext &context);
+} // namespace tt::runtime::ttnn::operations::ccl
+
+#endif
diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp
index 8839b868f..079e5ec7e 100644
--- a/runtime/lib/ttnn/program.cpp
+++ b/runtime/lib/ttnn/program.cpp
@@ -1,6 +1,7 @@
 // SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
 //
 // SPDX-License-Identifier: Apache-2.0
+#include "operations/ccl/all_gather.h"
 #include "operations/context/get_device.h"
 #include "operations/conv/conv2d.h"
 #include "operations/creation/empty.h"
@@ -113,6 +114,9 @@ void ProgramExecutor::runOperation(const ::tt::target::ttnn::Operation *op) {
   case ::tt::target::ttnn::OpType::MaxPool2dOp: {
     return operations::pool::run(op->type_as_MaxPool2dOp(), context);
   }
+  case ::tt::target::ttnn::OpType::AllGatherOp: {
+    return operations::ccl::run(op->type_as_AllGatherOp(), context);
+  }
   default: {
     throw std::runtime_error("Unsupported operation type");
   }
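Shape-wise, all_gather concatenates the per-device shards along `dim`, so every device ends up with `num_devices` times its local extent on that axis. A minimal sketch of that contract, assuming equal shards (the device count here is illustrative):

```cpp
// Shape contract of all-gather: output[dim] = input[dim] * num_devices,
// all other dims unchanged. Minimal sketch, assuming equal shards.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> allGatherShape(std::vector<int64_t> shape,
                                           size_t dim, int64_t numDevices) {
  shape[dim] *= numDevices; // every device receives the full concatenation
  return shape;
}

int main() {
  // Matches the silicon test below: mesh-shape=4,1,1 and dim = 3 take
  // tensor<1x1x32x32xf32> to tensor<1x1x32x128xf32>.
  std::vector<int64_t> local{1, 1, 32, 32};
  assert(allGatherShape(local, 3, 4) == (std::vector<int64_t>{1, 1, 32, 128}));
  return 0;
}
```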
diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
new file mode 100644
index 000000000..f1f5a5965
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
@@ -0,0 +1,11 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
+    // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+    %0 = tensor.empty() : tensor<1x1x32x128xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
+    %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
+    return %1 : tensor<1x1x32x128xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir
new file mode 100644
index 000000000..d3f6ac3da
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir
@@ -0,0 +1,10 @@
+// RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
+// CHECK: error: 'ttir.all_gather' op Invalid dimension for all gather op
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
+    %0 = tensor.empty() : tensor<1x1x32x128xbf16>
+    %1 = "ttir.all_gather"(%arg0, %0) <{dim = 4 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
+    return %1 : tensor<1x1x32x128xbf16>
+  }
+}
diff --git a/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir b/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir
new file mode 100644
index 000000000..edf0a4eaf
--- /dev/null
+++ b/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% mesh-shape=4,1,1" %s > %t.mlir
+// RUN: FileCheck %s --input-file=%t.mlir
+// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
+// UNSUPPORTED: true
+// REQUIRES: multi-chip
+#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
+#any_device_tile = #tt.operand_constraint<dram|l1|scalar|tile|any_device_tile>
+
+func.func @forward(%arg0: tensor<1x1x32x32xf32>) -> tensor<1x1x32x128xf32> {
+  // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
+  %0 = tensor.empty() : tensor<1x1x32x128xf32>
+  // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
+  %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xf32>, tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32>
+  return %1 : tensor<1x1x32x128xf32>
+}
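For contrast, the `reduce_scatter` op declared earlier is the dual collective: each device keeps a `1/num_devices` slice of the reduced tensor along `scatter_split_dim`. Since its verifier is still a TODO in this change, this is only a hedged sketch of the expected shape contract:

```cpp
// Dual of all-gather: reduce across devices, then scatter slices, so
// output[scatter_split_dim] = input[scatter_split_dim] / num_devices.
// Sketch of the expected contract only; the op's verifier is still TODO.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> reduceScatterShape(std::vector<int64_t> shape,
                                               size_t splitDim,
                                               int64_t numDevices) {
  assert(shape[splitDim] % numDevices == 0 &&
         "split dim must divide evenly across devices");
  shape[splitDim] /= numDevices;
  return shape;
}

int main() {
  std::vector<int64_t> full{1, 1, 32, 128};
  assert(reduceScatterShape(full, 3, 4) == (std::vector<int64_t>{1, 1, 32, 32}));
  return 0;
}
```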