From 435c7872ca4cbba46e8dce7bc01d4e3daf521345 Mon Sep 17 00:00:00 2001 From: Milan Topalovic Date: Fri, 27 Dec 2024 19:17:10 +0000 Subject: [PATCH] Fixing reduction ops --- include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 27 ++++++++++++++++--- include/ttmlir/Dialect/TTNN/IR/TTNNOps.td | 24 ++++++++++++++++- .../Decomposition/ReduceOpsRewritePattern.h | 15 ++++------- include/ttmlir/Target/TTNN/program.fbs | 2 +- lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 2 +- lib/Dialect/TTIR/IR/TTIROps.cpp | 17 +++++------- lib/Dialect/TTNN/IR/TTNNOps.cpp | 18 ++++++------- .../Decomposition/ReduceOpsRewritePattern.cpp | 17 +----------- lib/Target/TTNN/TTNNToFlatbuffer.cpp | 16 ++++++----- .../ttnn/operations/reduction/reduction.cpp | 8 +++--- .../reduce_ops/negative_invalid_dim_high.mlir | 2 +- .../reduce_ops/negative_invalid_dim_low.mlir | 2 +- .../reduce_ops/negative_repeating_dims.mlir | 2 +- .../TTNN/reduction/max_op_negative.mlir | 2 +- .../TTNN/reduction/mean_op_negative.mlir | 2 +- .../TTNN/reduction/sum_op_negative.mlir | 2 +- test/ttmlir/Dialect/TTNN/simple_max.mlir | 2 +- .../ttmlir/Silicon/TTMetal/simple_reduce.mlir | 8 +++--- .../Silicon/TTMetal/simple_reduce_1x1.mlir | 6 ++--- test/ttmlir/Silicon/TTNN/simple_mean.mlir | 22 +++++++++++++++ 20 files changed, 119 insertions(+), 77 deletions(-) diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index b571287c8..0aecb8d50 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -615,7 +615,7 @@ class TTIR_ReductionOp traits = []> : let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, BoolAttr:$keep_dim, - OptionalAttr:$dim_arg); + OptionalAttr>:$dim); let results = (outs AnyRankedTensor:$result); @@ -624,6 +624,26 @@ class TTIR_ReductionOp traits = []> : void buildGenericRegion(::mlir::OpBuilder &opBuilder, ::mlir::Block* block); + SmallVector getReduceDims() { + mlir::Attribute reduceDimsAttr = getDim().value_or(mlir::Attribute{}); + SmallVector reduceDimsVec; + if (!reduceDimsAttr) { + return reduceDimsVec; + } + + if (auto intAttr = mlir::dyn_cast(reduceDimsAttr)) { + reduceDimsVec.push_back(intAttr.getSInt()); + } else { + auto arrayAttr = mlir::cast(reduceDimsAttr); + for (auto dimAttr : arrayAttr) { + int64_t dim = mlir::cast(dimAttr).getInt(); + reduceDimsVec.push_back(dim); + } + } + + return reduceDimsVec; + } + // Returns the indexing maps and iterator types for the reduction op. // Indexing maps are identity maps with dropped dimensions corresponding to the // reduction dimensions. Iterator types are parallel for non-reduction dimensions @@ -635,10 +655,9 @@ class TTIR_ReductionOp traits = []> : SmallVector iteratorTypes( rank, builder.getAttr(IteratorType::Parallel)); - auto reduceDims = getDimArgAttr(); + SmallVector reduceDims = getReduceDims(); auto resultIndexingMap = indexingMaps.back(); - for (auto reduceDim : reduceDims) { - int64_t reduceDimInt = mlir::cast(reduceDim).getInt(); + for (auto reduceDimInt : reduceDims) { if (reduceDimInt < 0) { reduceDimInt += rank; } diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 8ef3f73ca..a8c8a13d7 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -578,10 +578,32 @@ class TTNN_ReductionOp traits = []> : TTNN_Op:$dim_arg); + OptionalAttr>:$dim); let results = (outs AnyRankedTensor:$result); + let extraClassDeclaration = [{ + SmallVector getReduceDims() { + mlir::Attribute reduceDimsAttr = getDim().value_or(mlir::Attribute{}); + SmallVector reduceDimsVec; + if (!reduceDimsAttr) { + return reduceDimsVec; + } + + if (auto intAttr = mlir::dyn_cast(reduceDimsAttr)) { + reduceDimsVec.push_back(intAttr.getSInt()); + } else { + auto arrayAttr = mlir::cast(reduceDimsAttr); + for (auto dimAttr : arrayAttr) { + int64_t dim = mlir::cast(dimAttr).getInt(); + reduceDimsVec.push_back(dim); + } + } + + return reduceDimsVec; + } + }]; + let hasVerifier = 1; } diff --git a/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h b/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h index 741fbfc06..746b5f80e 100644 --- a/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h +++ b/include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h @@ -16,16 +16,11 @@ namespace mlir::tt::ttnn::workarounds::decomposition { -// Extracts reduce dimensions' values from the dimArg attribute. In case when -// dimArg is not specified, returns empty vector. -llvm::SmallVector -getReduceDims(const std::optional &dimArg); - // Calculates the shape of the new Reduce op created in the workaround, based // on the input shape and reducing dimensions. llvm::SmallVector calculateNewReduceShape(RankedTensorType inputType, - const std::optional &dimArg); + const llvm::SmallVector &reduceDims); // This workaround addresses the next Metal issue: // https://github.com/tenstorrent/tt-metal/issues/13361 @@ -70,7 +65,7 @@ class ReduceOpsKeepDimRewritePattern : public OpRewritePattern { RankedTensorType inputType, RankedTensorType outputType) const { llvm::SmallVector outputShapeVec = - calculateNewReduceShape(inputType, srcOp.getDimArg()); + calculateNewReduceShape(inputType, srcOp.getReduceDims()); TTNNLayoutAttr newOutputLayoutAttr = mlir::cast(outputType.getEncoding()) @@ -81,7 +76,7 @@ class ReduceOpsKeepDimRewritePattern : public OpRewritePattern { return rewriter.create(srcOp.getLoc(), newOutputType, srcOp.getInput(), true /*keep_dim*/, - srcOp.getDimArg().value_or(nullptr)); + srcOp.getDim().value_or(nullptr)); } void replaceOpWithReshapeOp(ReduceOp srcOp, ReduceOp newReduceOp, @@ -108,11 +103,11 @@ class ReduceOpsAllDimsRewritePattern : public OpRewritePattern { LogicalResult matchAndRewrite(ReduceOp srcOp, PatternRewriter &rewriter) const override { - if (!srcOp.getDimArg() || srcOp.getDimArg()->empty()) { + llvm::SmallVector reduceDims = srcOp.getReduceDims(); + if (reduceDims.empty()) { return failure(); } - llvm::SmallVector reduceDims = getReduceDims(srcOp.getDimArg()); llvm::SmallSet uniqueReduceDims(reduceDims.begin(), reduceDims.end()); diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index c8c0f1ad0..2a67fd80c 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -174,7 +174,7 @@ table ReductionOp { type: ReductionOpType; in: tt.target.TensorRef; out: tt.target.TensorRef; - dim_arg: [int32]; + dim: [int32]; keep_dim: bool; } diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index 81aaf868b..449fe99af 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -332,7 +332,7 @@ class ReductionOpConversionPattern : public OpConversionPattern { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), adaptor.getInput(), adaptor.getKeepDim(), - adaptor.getDimArg().value_or(nullptr)); + adaptor.getDim().value_or(nullptr)); return success(); } }; diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 83bb98baa..59d286436 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -1729,23 +1729,20 @@ static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block, // Common verifier for all Reduce ops. static mlir::LogicalResult verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, - const std::optional &reduceDims) { - if (!reduceDims) { + const llvm::SmallVector &reduceDims) { + if (reduceDims.empty()) { return mlir::success(); } - int64_t inputTensorRank = inputType.getRank(); - llvm::SmallSet uniqueReduceDims; - for (mlir::Attribute reduceDim : *reduceDims) { - int64_t reduceDimInt = mlir::cast(reduceDim).getInt(); + for (int64_t reduceDimInt : reduceDims) { if (reduceDimInt < -inputTensorRank || reduceDimInt >= inputTensorRank) { return reduceOp->emitOpError("Reduce dimensions are out of range"); } uniqueReduceDims.insert(reduceDimInt); } - if (uniqueReduceDims.size() != reduceDims->size()) { + if (uniqueReduceDims.size() != reduceDims.size()) { return reduceOp->emitOpError("Reduce dimensions are not unique"); } @@ -1770,7 +1767,7 @@ void mlir::tt::ttir::MaxOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MaxOp verification. ::mlir::LogicalResult mlir::tt::ttir::MaxOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims()); } //===----------------------------------------------------------------------===// @@ -1786,7 +1783,7 @@ void mlir::tt::ttir::MeanOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // MeanOp verification. ::mlir::LogicalResult mlir::tt::ttir::MeanOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims()); } //===----------------------------------------------------------------------===// @@ -1802,5 +1799,5 @@ void mlir::tt::ttir::SumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, // SumOp verification. ::mlir::LogicalResult mlir::tt::ttir::SumOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims()); } diff --git a/lib/Dialect/TTNN/IR/TTNNOps.cpp b/lib/Dialect/TTNN/IR/TTNNOps.cpp index 9f2b2f6e3..fb5bec628 100644 --- a/lib/Dialect/TTNN/IR/TTNNOps.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOps.cpp @@ -152,8 +152,8 @@ ::mlir::LogicalResult mlir::tt::ttnn::ArangeOp::verify() { << getStart() << ", end=" << getEnd() << ", step=" << getStep(); } - std::vector expectedShape = {1, 1, 1, numValues}; - if (getType().getShape().vec() != expectedShape) { + llvm::SmallVector expectedShape = {1, 1, 1, numValues}; + if (getType().getShape() != ArrayRef(expectedShape)) { return emitOpError() << "Output tensor shape must be " << expectedShape << ", but got " << getType().getShape(); } @@ -1274,13 +1274,13 @@ ::mlir::LogicalResult mlir::tt::ttnn::PermuteOp::verify() { // Common verifier for all Reduction ops. static mlir::LogicalResult verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, - const std::optional &reduceDims) { - int64_t inputTensorRank = inputType.getRank(); + const llvm::SmallVector &reduceDims) { + size_t inputTensorRank = inputType.getRank(); // TODO(mrakita): Only last two dimensions can be reduced, check for that // too. - if (reduceDims && reduceDims->size() > 2 && - static_cast(reduceDims->size()) != inputTensorRank) { + if (!reduceDims.empty() && reduceDims.size() > 2 && + reduceDims.size() != inputTensorRank) { return reduceOp->emitOpError("Reduce on more than two dimensions is not " "currently supported by TTNN"); } @@ -1294,7 +1294,7 @@ verifyReduceOp(mlir::Operation *reduceOp, mlir::RankedTensorType inputType, // MaxOp verification. ::mlir::LogicalResult MaxOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims()); } //===----------------------------------------------------------------------===// @@ -1303,7 +1303,7 @@ ::mlir::LogicalResult MaxOp::verify() { // MeanOp verification. ::mlir::LogicalResult MeanOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims()); } //===----------------------------------------------------------------------===// @@ -1312,7 +1312,7 @@ ::mlir::LogicalResult MeanOp::verify() { // SumOp verification. ::mlir::LogicalResult SumOp::verify() { - return verifyReduceOp(getOperation(), getInput().getType(), getDimArg()); + return verifyReduceOp(getOperation(), getInput().getType(), getReduceDims()); } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp index 99b61ef0b..de02e39b0 100644 --- a/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp +++ b/lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.cpp @@ -8,25 +8,10 @@ namespace mlir::tt::ttnn::workarounds::decomposition { -llvm::SmallVector -getReduceDims(const std::optional &dimArg) { - llvm::SmallVector reduceDims; - if (!dimArg) { - return reduceDims; - } - - for (const mlir::Attribute &reduceDim : *dimArg) { - reduceDims.push_back(mlir::cast(reduceDim).getInt()); - } - - return reduceDims; -} - llvm::SmallVector calculateNewReduceShape(RankedTensorType inputType, - const std::optional &dimArg) { + const llvm::SmallVector &reduceDims) { llvm::SmallVector outputShapeVec(inputType.getShape()); - llvm::SmallVector reduceDims = getReduceDims(dimArg); if (reduceDims.empty()) { // When reduce dimensions are not specified that means we are reducing over diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index 4f9dde1a6..75b36b663 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -307,8 +307,8 @@ createDistributionStrategy(FlatbufferObjectCache &cache, // tensor is sliced at the fastest dimension. if (meshShape[0] == 1 || meshShape[1] == 1) { assert(type.getShape().size() > 0 && "expected non-zero tensor shape"); - uint32_t target_dim = type.getShape().size() - 1; - auto strategy = ::tt::target::CreateShardTensor(*cache.fbb, target_dim); + uint32_t targetDim = type.getShape().size() - 1; + auto strategy = ::tt::target::CreateShardTensor(*cache.fbb, targetDim); return ::tt::target::CreateDistributionStrategy( *cache.fbb, ::tt::target::DistributedTensorConfig::ShardTensor, strategy.Union()); @@ -730,11 +730,13 @@ createReductionOp(FlatbufferObjectCache &cache, ReductionOp op) { cache.at<::tt::target::TensorRef>(getOperandThroughDPSOps(op.getInput())); auto output = cache.getOrCreate(op.getResult(), tensorValueToFlatbuffer, kHostAllocatedAddress, kHostAllocatedSize); - auto dim_arg = - arrayAttrToFlatbuffer(cache, op.getDimArg()); + SmallVector dims = op.getReduceDims(); + SmallVector dims32(dims.begin(), dims.end()); + auto dimArg = + op.getReduceDims().empty() ? 0 : toFlatbuffer(cache, dims32); return ::tt::target::ttnn::CreateReductionOp(*cache.fbb, type, in, output, - dim_arg, op.getKeepDim()); + dimArg, op.getKeepDim()); } ::flatbuffers::Offset<::tt::target::ttnn::TransposeOp> @@ -1134,8 +1136,8 @@ emitTTNNOperation(FlatbufferObjectCache &cache, Operation *op, return createOperation(cache, createSliceOp(cache, sliceOp), debugString, locInfo); } - if (auto max_pool2dOp = dyn_cast(op); max_pool2dOp) { - return createOperation(cache, createMaxPool2dOp(cache, max_pool2dOp), + if (auto maxPool2dOp = dyn_cast(op); maxPool2dOp) { + return createOperation(cache, createMaxPool2dOp(cache, maxPool2dOp), debugString, locInfo); } if (auto deallocateOp = dyn_cast(op); deallocateOp) { diff --git a/runtime/lib/ttnn/operations/reduction/reduction.cpp b/runtime/lib/ttnn/operations/reduction/reduction.cpp index 631df3f51..6412c5d34 100644 --- a/runtime/lib/ttnn/operations/reduction/reduction.cpp +++ b/runtime/lib/ttnn/operations/reduction/reduction.cpp @@ -22,11 +22,11 @@ static void runReductionOp( const ::ttnn::Tensor &in = tensorPool.at(op->in()->global_id()); DEBUG_ASSERT(in.is_allocated()); - const auto *fbDimArg = op->dim_arg(); + const auto *fbDim = op->dim(); std::optional<::ttnn::SmallVector> dimArg = - fbDimArg ? std::make_optional(::ttnn::SmallVector(fbDimArg->begin(), - fbDimArg->end())) - : std::nullopt; + fbDim ? std::make_optional( + ::ttnn::SmallVector(fbDim->begin(), fbDim->end())) + : std::nullopt; ::ttnn::Tensor out = ttnnOp( in, dimArg, op->keep_dim(), outputMemoryConfig /* memory_config_arg */, diff --git a/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_high.mlir b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_high.mlir index 565745d05..56982ae4b 100644 --- a/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_high.mlir +++ b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_high.mlir @@ -4,6 +4,6 @@ // CHECK: error: 'ttir.sum' op Reduce dimensions are out of range func.func public @test_reduce_add_invalid_dim_high(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> { %0 = tensor.empty() : tensor<128xf32> - %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32> + %1 = "ttir.sum"(%arg0, %0) <{dim = [2 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32> return %1 : tensor<128xf32> } diff --git a/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_low.mlir b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_low.mlir index bd4a237d4..13b1a4800 100644 --- a/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_low.mlir +++ b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_invalid_dim_low.mlir @@ -4,6 +4,6 @@ // CHECK: error: 'ttir.sum' op Reduce dimensions are out of range func.func public @test_reduce_add_invalid_dim_low(%arg0: tensor<128x10xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> { %0 = tensor.empty() : tensor<128xf32> - %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-3 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32> + %1 = "ttir.sum"(%arg0, %0) <{dim = [-3 : i32], keep_dim = false}> : (tensor<128x10xf32>, tensor<128xf32>) -> tensor<128xf32> return %1 : tensor<128xf32> } diff --git a/test/ttmlir/Dialect/TTIR/reduce_ops/negative_repeating_dims.mlir b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_repeating_dims.mlir index 13649e1e6..7a9bbdb13 100644 --- a/test/ttmlir/Dialect/TTIR/reduce_ops/negative_repeating_dims.mlir +++ b/test/ttmlir/Dialect/TTIR/reduce_ops/negative_repeating_dims.mlir @@ -4,6 +4,6 @@ // CHECK: error: 'ttir.sum' op Reduce dimensions are not unique func.func public @test_reduce_add_repeating_dims(%arg0: tensor<128x10x32x4xf32>, %arg1: tensor<1xf32>) -> tensor<128xf32> { %0 = tensor.empty() : tensor<128xf32> - %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [1 : i32, 2 : i32, 3 : i32, 2 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128xf32>) -> tensor<128xf32> + %1 = "ttir.sum"(%arg0, %0) <{dim = [1 : i32, 2 : i32, 3 : i32, 2 : i32], keep_dim = false}> : (tensor<128x10x32x4xf32>, tensor<128xf32>) -> tensor<128xf32> return %1 : tensor<128xf32> } diff --git a/test/ttmlir/Dialect/TTNN/reduction/max_op_negative.mlir b/test/ttmlir/Dialect/TTNN/reduction/max_op_negative.mlir index ac587303e..492414b7c 100644 --- a/test/ttmlir/Dialect/TTNN/reduction/max_op_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/reduction/max_op_negative.mlir @@ -4,7 +4,7 @@ module { func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> { %0 = tensor.empty() : tensor<128x1x1x1xbf16> // CHECK: error: 'ttnn.max' op Reduce on more than two dimensions is not currently supported by TTNN - %1 = "ttir.max"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16> + %1 = "ttir.max"(%arg0, %0) <{dim = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16> return %1 : tensor<128x1x1x1xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/reduction/mean_op_negative.mlir b/test/ttmlir/Dialect/TTNN/reduction/mean_op_negative.mlir index 768b220bb..0e085cd12 100644 --- a/test/ttmlir/Dialect/TTNN/reduction/mean_op_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/reduction/mean_op_negative.mlir @@ -4,7 +4,7 @@ module { func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> { %0 = tensor.empty() : tensor<128x1x1x1xbf16> // CHECK: error: 'ttnn.mean' op Reduce on more than two dimensions is not currently supported by TTNN - %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16> + %1 = "ttir.mean"(%arg0, %0) <{dim = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16> return %1 : tensor<128x1x1x1xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/reduction/sum_op_negative.mlir b/test/ttmlir/Dialect/TTNN/reduction/sum_op_negative.mlir index c0c634f05..135191ffc 100644 --- a/test/ttmlir/Dialect/TTNN/reduction/sum_op_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/reduction/sum_op_negative.mlir @@ -4,7 +4,7 @@ module { func.func @forward(%arg0: tensor<128x32x10x4xbf16>) -> tensor<128x1x1x1xbf16> { %0 = tensor.empty() : tensor<128x1x1x1xbf16> // CHECK: error: 'ttnn.sum' op Reduce on more than two dimensions is not currently supported by TTNN - %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16> + %1 = "ttir.sum"(%arg0, %0) <{dim = [1: i32, 2: i32, 3: i32], keep_dim = true}> : (tensor<128x32x10x4xbf16>, tensor<128x1x1x1xbf16>) -> tensor<128x1x1x1xbf16> return %1 : tensor<128x1x1x1xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_max.mlir b/test/ttmlir/Dialect/TTNN/simple_max.mlir index 34a0120b2..ea0c8593b 100644 --- a/test/ttmlir/Dialect/TTNN/simple_max.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_max.mlir @@ -3,7 +3,7 @@ module attributes {} { func.func @forward(%arg0: tensor<512x32xbf16>) -> tensor<512xbf16> { %0 = tensor.empty() : tensor<512xbf16> // CHECK: %[[C:.*]] = "ttnn.max"[[C:.*]] - %1 = "ttir.max"(%arg0, %0) <{dim_arg = [1: i32], keep_dim = false}> : (tensor<512x32xbf16>, tensor<512xbf16>) -> tensor<512xbf16> + %1 = "ttir.max"(%arg0, %0) <{dim = [1: i32], keep_dim = false}> : (tensor<512x32xbf16>, tensor<512xbf16>) -> tensor<512xbf16> return %1 : tensor<512xbf16> } } diff --git a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir index a6ab52acb..fa8f5aca6 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir @@ -7,7 +7,7 @@ func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, # %0 = tensor.empty() : tensor<256x32xf32, #layout2> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, - dim_arg = [-1: i32], + dim = [-1: i32], keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<256x32xf32, #layout2>) -> tensor<256x32xf32, #layout2> return %1 : tensor<256x32xf32, #layout2> @@ -18,7 +18,7 @@ func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, # %0 = tensor.empty() : tensor<32x384xf32, #layout3> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, - dim_arg = [-2: i32], + dim = [-2: i32], keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<32x384xf32, #layout3>) -> tensor<32x384xf32, #layout3> return %1 : tensor<32x384xf32, #layout3> @@ -29,7 +29,7 @@ func.func @reduceWH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x32xf32, # %0 = tensor.empty() : tensor<32x32xf32, #layout4> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, - dim_arg = [-1: i32, -2: i32], + dim = [-1: i32, -2: i32], keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<32x32xf32, #layout4>) -> tensor<32x32xf32, #layout4> return %1 : tensor<32x32xf32, #layout4> @@ -39,7 +39,7 @@ func.func @maxReduceWH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x32xf32 %0 = tensor.empty() : tensor<32x32xf32, #layout4> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.max" (%arg0, %0) <{operandSegmentSizes = array, - dim_arg = [-1: i32, -2: i32], + dim = [-1: i32, -2: i32], keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<32x32xf32, #layout4>) -> tensor<32x32xf32, #layout4> return %1 : tensor<32x32xf32, #layout4> diff --git a/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir b/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir index 2038cfa08..806cfc04a 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir @@ -5,7 +5,7 @@ func.func @reduceW(%arg0: tensor<64x256xf32>) -> tensor<64x32xf32> { %0 = tensor.empty() : tensor<64x32xf32> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, - dim_arg = [-1: i32], + dim = [-1: i32], keep_dim = true}> : (tensor<64x256xf32>, tensor<64x32xf32>) -> tensor<64x32xf32> return %1 : tensor<64x32xf32> @@ -15,7 +15,7 @@ func.func @reduceH(%arg0: tensor<256x64xf32>) -> tensor<32x64xf32> { %0 = tensor.empty() : tensor<32x64xf32> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, - dim_arg = [-2: i32], + dim = [-2: i32], keep_dim = true}> : (tensor<256x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> return %1 : tensor<32x64xf32> @@ -25,7 +25,7 @@ func.func @reduceWH(%arg0: tensor<256x64xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, - dim_arg = [-1: i32, -2: i32], + dim = [-1: i32, -2: i32], keep_dim = true}> : (tensor<256x64xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> diff --git a/test/ttmlir/Silicon/TTNN/simple_mean.mlir b/test/ttmlir/Silicon/TTNN/simple_mean.mlir index 476dcd9ab..22ed1d635 100644 --- a/test/ttmlir/Silicon/TTNN/simple_mean.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_mean.mlir @@ -36,4 +36,26 @@ module { %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = true}> : (tensor<128x10xf32>, tensor<128x1xf32>) -> tensor<128x1xf32> return %1 : tensor<128x1xf32> } + + func.func public @mean_into_reshape_dim_array(%arg0: tensor<1x1x49x2048xf32>) -> (tensor<1x2048x1x1xf32> {ttir.name = "AvgPool2d.output_avg_pool2d_0"}) { + // CHECK: "ttnn.mean" + // CHECK-SAME: {dim = [-2 : i32], keep_dim = true} + // CHECK: "ttnn.reshape" + %1 = tensor.empty() : tensor<1x1x1x2048xf32> + %2 = "ttir.mean"(%arg0, %1) <{dim = [-2 : i32], keep_dim = true}> : (tensor<1x1x49x2048xf32>, tensor<1x1x1x2048xf32>) -> tensor<1x1x1x2048xf32> + %3 = tensor.empty() : tensor<1x2048x1x1xf32> + %4 = "ttir.reshape"(%2, %3) <{shape = [1 : i32, 2048 : i32, 1 : i32, 1 : i32]}> : (tensor<1x1x1x2048xf32>, tensor<1x2048x1x1xf32>) -> tensor<1x2048x1x1xf32> + return %4 : tensor<1x2048x1x1xf32> + } + + func.func public @mean_into_reshape_dim_scalar(%arg0: tensor<1x1x49x2048xf32>) -> (tensor<1x2048x1x1xf32> {ttir.name = "AvgPool2d.output_avg_pool2d_0"}) { + // CHECK: "ttnn.mean" + // CHECK-SAME: {dim = -2 : si32, keep_dim = true} + // CHECK: "ttnn.reshape" + %1 = tensor.empty() : tensor<1x1x1x2048xf32> + %2 = "ttir.mean"(%arg0, %1) <{dim = -2 : si32, keep_dim = true}> : (tensor<1x1x49x2048xf32>, tensor<1x1x1x2048xf32>) -> tensor<1x1x1x2048xf32> + %3 = tensor.empty() : tensor<1x2048x1x1xf32> + %4 = "ttir.reshape"(%2, %3) <{shape = [1 : i32, 2048 : i32, 1 : i32, 1 : i32]}> : (tensor<1x1x1x2048xf32>, tensor<1x2048x1x1xf32>) -> tensor<1x2048x1x1xf32> + return %4 : tensor<1x2048x1x1xf32> + } }