From bb05d0c45091f352f1d5ed9647bda1b993a19951 Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Fri, 1 Nov 2024 21:02:27 +0000 Subject: [PATCH] Create new general pooling op and decomposition pattern that converts to maxpool2d --- include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt | 6 + include/ttmlir/Dialect/TTIR/IR/TTIROps.h | 2 + include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 27 ++ .../ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td | 4 + .../ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td | 21 ++ .../ttmlir/Dialect/TTIR/Transforms/Passes.td | 7 - .../StableHLOToTTIRPatterns.cpp | 210 ++++++----- .../TTIRToTTIRDecomposition/CMakeLists.txt | 1 + .../TTIRToTTIRDecomposition.cpp | 336 +++++++++++++++--- .../TTIRToTTIRDecompositionPass.cpp | 1 + lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp | 125 +++++-- lib/Dialect/TTIR/IR/TTIRDialect.cpp | 2 + lib/Dialect/TTIR/IR/TTIROps.cpp | 38 ++ lib/Dialect/TTIR/Transforms/Transforms.cpp | 137 ------- lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp | 1 - .../Dialect/TTNN/pooling/complex_pooling.mlir | 19 + .../TTNN/{ => pooling}/simple_maxpool2d.mlir | 0 .../Dialect/TTNN/pooling/simple_pooling.mlir | 15 + 18 files changed, 616 insertions(+), 336 deletions(-) create mode 100644 include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td create mode 100644 test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir rename test/ttmlir/Dialect/TTNN/{ => pooling}/simple_maxpool2d.mlir (100%) create mode 100644 test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir diff --git a/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt b/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt index e04bf3f3e3..15bf10223e 100644 --- a/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt +++ b/include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt @@ -8,6 +8,12 @@ mlir_tablegen(TTIROpsAttrs.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(TTIROpsAttrsIncGen) add_dependencies(mlir-headers TTIROpsAttrsIncGen) +set(LLVM_TARGET_DEFINITIONS TTIROpsEnums.td) +mlir_tablegen(TTIROpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(TTIROpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRTTIROpsEnumsIncGen) +add_dependencies(mlir-headers MLIRTTIROpsEnumsIncGen) + set(LLVM_TARGET_DEFINITIONS TTIROpsInterfaces.td) mlir_tablegen(TTIROpsInterfaces.h.inc -gen-op-interface-decls) mlir_tablegen(TTIROpsInterfaces.cpp.inc -gen-op-interface-defs) diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h index cd102dbb96..f23fd6d881 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.h +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.h @@ -16,6 +16,8 @@ #include "TTIROpsInterfaces.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc" + #define GET_ATTRDEF_CLASSES #include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.h.inc" diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 3e0d27598c..b50e863522 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -694,6 +694,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> { }]; } +def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> { + let summary = "General pooling op"; + let description = [{ + General pooling op + }]; + + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, + TTIR_PoolingMethodAttr:$pooling_method, + DenseI64ArrayAttr:$window_dimensions, + + // Default stride of 1 over every dimension + DefaultValuedOptionalAttr(getWindowDimensions().size(), 1)">:$window_strides, + // Default dilation of 1 over every dimension + 
DefaultValuedOptionalAttr(getWindowDimensions().size(), 1)">:$base_dilations, + // Default dilation of 1 over every dimension + DefaultValuedOptionalAttr(getWindowDimensions().size(), 1)">:$window_dilations, + // Default padding of 0 over every dimension + DefaultValuedOptionalAttr(getWindowDimensions().size() * 2, 0)">:$padding, + TT_OperandConstraintArrayAttr:$operand_constraints + ); + + let results = (outs Variadic); + + let hasVerifier = 1; +} def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> { let summary = "Applies a 2D max pooling over an input signal composed of several input planes."; diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td index 60943af269..66cc6bc32f 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td @@ -7,6 +7,10 @@ include "mlir/IR/AttrTypeBase.td" include "ttmlir/Dialect/TTIR/IR/TTIRBase.td" +include "mlir/IR/EnumAttr.td" +include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td" + +def TTIR_PoolingMethodAttr : EnumAttr; def TTIR_ConvolutionLayoutAttr : AttrDef { let mnemonic = "convolution_layout"; diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td new file mode 100644 index 0000000000..473675d3e0 --- /dev/null +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td @@ -0,0 +1,21 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_TTIR_ENUMS_TD +#define TTMLIR_TTIR_ENUMS_TD + +include "mlir/IR/EnumAttr.td" + +def TTIR_AveragePoolingMethod : I32EnumAttrCase<"Average", 0>; +def TTIR_MaxPoolingMethod : I32EnumAttrCase<"Max", 1>; + +def TTIR_PoolingMethod : I32EnumAttr<"PoolingMethod", "TTIR PoolingMethod", [ + TTIR_AveragePoolingMethod, + TTIR_MaxPoolingMethod + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tt::ttir"; +} + +#endif diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td index 1cee4cbb5c..63ccb0d28a 100644 --- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td +++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td @@ -70,13 +70,6 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> { ]; } -def TTIRSlidingWindow2dFixShapes: Pass<"ttir-sliding-window-2d-fix-shapes", "::mlir::ModuleOp"> { - let summary = "Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C)"; - let description = [{ - Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e (1, 1, N*H*W, C) --> (N, H, W, C) - }]; -} - def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> { let summary = "Split compound layouts."; let description = [{ diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 0d750aed80..716cf8ab29 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include #include #include "mlir/Dialect/Traits.h" @@ -473,11 +474,74 @@ class StableHLOToTTIRReduceWindowOpConversionPattern rewriter.eraseOp(op); } - bool 
isMaxPool2d(mlir::stablehlo::ReduceWindowOp &srcOp) const { + bool isMaxPool(mlir::stablehlo::ReduceWindowOp &srcOp) const { if (srcOp.getBody().getBlocks().size() != 1) { return false; } + // Find constant input(s) + Operation *init_value; + for (uint64_t i = 0; i < srcOp.getInitValues().size(); i++) { + init_value = srcOp.getInitValues()[i].getDefiningOp(); + auto name = init_value->getName().getStringRef().str(); + (void)name; + while (init_value->getOpOperands().size() == 1) { + init_value = init_value->getOpOperand(0).get().getDefiningOp(); + } + if (!isa(init_value)) { + return false; + } + + stablehlo::ConstantOp init_value_op = + mlir::cast(init_value); + + if (init_value_op.getValueAttr().size() != 1) { + return false; + } + + // Constant operand must be -inf if this is to be a max pool + // since bfloat16 is not a type we acually have I must compare the raw + // bits + if (init_value_op.getResult().getType().getElementType().isBF16()) { + // Collect the values into a vector + std::vector values; + for (int64_t i = 0; i < init_value_op.getValueAttr().size(); ++i) { + values.push_back( + init_value_op.getValueAttr().getValues()[i]); + } + + auto denseValues = ::mlir::DenseElementsAttr::get( + init_value_op.getValueAttr().getShapedType(), values); + uint16_t bfloat_bits = + static_cast(*denseValues.getRawData().data()); + if (bfloat_bits != 0xff80) { // This is -inf in bfloat16 + return false; + } + } else if (init_value_op.getValue().getType().isF32()) { + if (*init_value_op.getValue().value_begin() != + -std::numeric_limits::infinity()) { + return false; + } + } else if (init_value_op.getValue().getType().isF64()) { + if (*init_value_op.getValue().value_begin() != + -std::numeric_limits::infinity()) { + return false; + } + } else if (init_value_op.getValue().getType().isInteger(32)) { + if (*init_value_op.getValue().value_begin() != + std::numeric_limits::min()) { + return false; + } + } else if (init_value_op.getValue().getType().isInteger(64)) { + if (*init_value_op.getValue().value_begin() != + std::numeric_limits::min()) { + return false; + } + } else { + return false; + } + } + Block &block = *srcOp.getBody().getBlocks().begin(); uint32_t op_idx = 0; for (Operation &op : block) { @@ -501,105 +565,57 @@ class StableHLOToTTIRReduceWindowOpConversionPattern mlir::stablehlo::ReduceWindowOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (isMaxPool2d(srcOp)) { + RankedTensorType outputType = mlir::cast( + getTypeConverter()->convertType(srcOp.getResult(0).getType())); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - RankedTensorType outputType = mlir::cast( - getTypeConverter()->convertType(srcOp.getResult(0).getType())); + ValueRange inputs = adaptor.getInputs()[0]; + ValueRange outputs = {outputTensor}; + + auto window_dimensions = adaptor.getWindowDimensionsAttr(); + auto window_strides = adaptor.getWindowStridesAttr(); + auto base_dilations = adaptor.getBaseDilationsAttr(); + auto window_dilations = adaptor.getWindowDilationsAttr(); + auto padding_ = adaptor.getPaddingAttr(); + + // Generate defaults if they dont exist + window_strides = window_strides + ? window_strides + : rewriter.getDenseI64ArrayAttr(SmallVector( + window_dimensions.size(), 1)); + base_dilations = base_dilations + ? base_dilations + : rewriter.getDenseI64ArrayAttr(SmallVector( + window_dimensions.size(), 1)); + window_dilations = + window_dilations ? 
window_dilations + : rewriter.getDenseI64ArrayAttr(SmallVector( + window_dimensions.size(), 1)); + auto padding = + padding_ ? rewriter.getDenseI64ArrayAttr( + SmallVector(padding_.getValues())) + : rewriter.getDenseI64ArrayAttr( + SmallVector(window_dimensions.size() * 2, 1)); + + auto operand_constraints = rewriter.getArrayAttr(SmallVector( + adaptor.getOperands().size(), rewriter.getAttr( + OperandConstraint::AnyDeviceTile))); + + mlir::tt::ttir::PoolingMethod pooling_method; + if (isMaxPool(srcOp)) { + pooling_method = mlir::tt::ttir::PoolingMethod::Max; + } else { + return failure(); + } - tensor::EmptyOp outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + rewriter.replaceOpWithNewOp( + srcOp, outputType, inputs, outputs, + pooling_method, window_dimensions, window_strides, + base_dilations, window_dilations, padding, operand_constraints); - // The generalized ReduceWindow allows for kernel_size, strides, dilation, - // and padding to act on all 4 input dimensions. Since we only support - // channel-last pooling, we select the middle two values for H and W. - // And fail if the others are not 1 (or 0 in the case of padding). - std::vector window_dimensions = adaptor.getWindowDimensions(); - if (window_dimensions[0] != 1 || window_dimensions[3] != 1) { - return failure(); - } - IntegerAttr kernel_height_attr = rewriter.getSI32IntegerAttr( - static_cast(window_dimensions[1])); - IntegerAttr kernel_width_attr = rewriter.getSI32IntegerAttr( - static_cast(window_dimensions[2])); - - std::vector strides = - adaptor.getWindowStrides() - .value_or(ArrayRef({1, 1, 1, 1})) - .vec(); - - if (strides[0] != 1 || strides[3] != 1) { - return failure(); - } - IntegerAttr stride_height_attr = - rewriter.getSI32IntegerAttr(static_cast(strides[1])); - IntegerAttr stride_width_attr = - rewriter.getSI32IntegerAttr(static_cast(strides[2])); - - std::vector dilation = - adaptor.getBaseDilations() - .value_or(ArrayRef({1, 1, 1, 1})) - .vec(); - - if (dilation[0] != 1 || dilation[3] != 1) { - return failure(); - } - IntegerAttr dilation_height_attr = - rewriter.getSI32IntegerAttr(static_cast(dilation[1])); - IntegerAttr dilation_width_attr = - rewriter.getSI32IntegerAttr(static_cast(dilation[2])); - - // Padding here is in the form ((., .), (top, bottom), (left, right), (., - // .)) one for each of (N, H, W, C). Since we only support maxpool2d, the - // first and last padding tuples must be zero to be valid. This list is - // flattened so we can use a single iterator to get the values. - std::vector padding = {0, 0, 0, 0}; - if (adaptor.getPadding().has_value()) { - uint32_t pad_idx = 0; - for (auto iter = adaptor.getPadding()->value_begin(); - iter < adaptor.getPadding()->value_end(); iter++) { - - // TTIR requires left, right, top, bottom - if (pad_idx == 2) { - padding[2] = *iter; - } else if (pad_idx == 3) { - padding[3] = *iter; - } else if (pad_idx == 4) { - padding[0] = *iter; - } else if (pad_idx == 5) { - padding[1] = *iter; - } else if (*iter != 0) { - // Padding on the channel or batch is > 1. TTIR/TTNN does not - // support this. - return failure(); - } - pad_idx++; - } - } - ::llvm::ArrayRef input_shape = - mlir::cast(adaptor.getInputs()[0].getType()) - .getShape(); - - // Dead ttir.constant sticks around and fails verification. 
Removing it - // like so since its behind another op - recursiveErase(rewriter, adaptor.getInitValues()[0].getDefiningOp()); - rewriter.replaceOpWithNewOp( - srcOp, outputType, srcOp.getInputs()[0], outputTensor, - kernel_height_attr, kernel_width_attr, stride_height_attr, - stride_width_attr, dilation_height_attr, dilation_width_attr, - rewriter.getBoolAttr(false), rewriter.getSI32IntegerAttr(padding[0]), - rewriter.getSI32IntegerAttr(padding[1]), - rewriter.getSI32IntegerAttr(padding[2]), - rewriter.getSI32IntegerAttr(padding[3]), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile))), - rewriter.getSI32IntegerAttr(input_shape[1]), - rewriter.getSI32IntegerAttr(input_shape[2])); - - return success(); - } - return failure(); + return success(); + } }; diff --git a/lib/Conversion/TTIRToTTIRDecomposition/CMakeLists.txt b/lib/Conversion/TTIRToTTIRDecomposition/CMakeLists.txt index d727b3d1c9..5d166565fa 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/CMakeLists.txt +++ b/lib/Conversion/TTIRToTTIRDecomposition/CMakeLists.txt @@ -1,3 +1,4 @@ +set(CMAKE_BUILD_TYPE Debug) add_mlir_library(TTMLIRTTIRToTTIRDecomposition TTIRToTTIRDecomposition.cpp TTIRToTTIRDecompositionPass.cpp diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index d361fce1f3..c3b8a1a9c4 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -23,6 +23,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/LogicalResult.h" #include +#include #include using namespace mlir; @@ -142,46 +143,46 @@ generateTranspose(Value input, int64_t dim0, int64_t dim1, dim1_attr, operandConstraints); } -static std::vector generateKernelTransposeIndices( - ttir::ConvolutionOp op, - const std::vector ttnn_convolution_kernel_layout) { +static std::vector +generateTransposeIndicesLeftToRight(std::vector current_layout, + const std::vector desired_layout) { std::vector transpose_indices; - - std::vector kernel_layout( - ttnn_convolution_kernel_layout.size(), - ConvolutionKernelDimension::INVALID_KERNEL_DIM); - kernel_layout[op.getConvolutionLayout().getKernelOutputFeatureDimension()] = - ConvolutionKernelDimension::OUTPUT_FEATURES; - kernel_layout[op.getConvolutionLayout().getKernelInputFeatureDimension()] = - ConvolutionKernelDimension::INPUT_FEATURES; - - int64_t spatial_count = 0; - for (int64_t spatial_dim : - op.getConvolutionLayout().getKernelSpatialDimensions()) { - kernel_layout[spatial_dim] = spatial_count; - spatial_count++; + for (int64_t i = 0; i < static_cast(current_layout.size()); i++) { + if (current_layout[i] != desired_layout[i]) { + int64_t dim0 = i; + int64_t dim1 = std::find(current_layout.begin(), current_layout.end(), + desired_layout[i]) - + current_layout.begin(); + transpose_indices.push_back(std::make_tuple(dim0, dim1)); + std::swap(current_layout[dim0], current_layout[dim1]); + } } - const std::vector desired_kernel_layout = - ttnn_convolution_kernel_layout; - for (int64_t i = 0; i < static_cast(kernel_layout.size()); i++) { - if (kernel_layout[i] != desired_kernel_layout[i]) { + return transpose_indices; +} + +static std::vector +generateTransposeIndicesRightToLeft(std::vector current_layout, + const std::vector desired_layout) { + std::vector transpose_indices; + for (int64_t i = static_cast(current_layout.size()) - 1; i >= 0; + i--) { 
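+    // Walk the dimensions right to left so the transposes emitted here undo,
+    // in reverse order, the sequence produced by the left-to-right variant
+    // above.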
+ if (current_layout[i] != desired_layout[i]) { int64_t dim0 = i; - int64_t dim1 = std::find(kernel_layout.begin(), kernel_layout.end(), - desired_kernel_layout[i]) - - kernel_layout.begin(); + int64_t dim1 = std::find(current_layout.begin(), current_layout.end(), + desired_layout[i]) - + current_layout.begin(); transpose_indices.push_back(std::make_tuple(dim0, dim1)); - std::swap(kernel_layout[dim0], kernel_layout[dim1]); + std::swap(current_layout[dim0], current_layout[dim1]); } } return transpose_indices; } -static std::vector generateInputTransposeIndices( +static std::vector generateConvInputTransposeIndices( ttir::ConvolutionOp op, const std::vector ttnn_convolution_layout) { - std::vector transpose_indices; std::vector input_layout(ttnn_convolution_layout.size(), ConvolutionDimension::INVALID_DIM); @@ -197,19 +198,32 @@ static std::vector generateInputTransposeIndices( spatial_count++; } - const std::vector desired_input_layout = ttnn_convolution_layout; - for (int64_t i = 0; i < static_cast(input_layout.size()); i++) { - if (input_layout[i] != desired_input_layout[i]) { - int64_t dim0 = i; - int64_t dim1 = std::find(input_layout.begin(), input_layout.end(), - desired_input_layout[i]) - - input_layout.begin(); - transpose_indices.push_back(std::make_tuple(dim0, dim1)); - std::swap(input_layout[dim0], input_layout[dim1]); - } + return generateTransposeIndicesLeftToRight(input_layout, + ttnn_convolution_layout); +} + +static std::vector generateConvKernelTransposeIndices( + ttir::ConvolutionOp op, + const std::vector ttnn_convolution_kernel_layout) { + std::vector transpose_indices; + + std::vector kernel_layout( + ttnn_convolution_kernel_layout.size(), + ConvolutionKernelDimension::INVALID_KERNEL_DIM); + kernel_layout[op.getConvolutionLayout().getKernelOutputFeatureDimension()] = + ConvolutionKernelDimension::OUTPUT_FEATURES; + kernel_layout[op.getConvolutionLayout().getKernelInputFeatureDimension()] = + ConvolutionKernelDimension::INPUT_FEATURES; + + int64_t spatial_count = 0; + for (int64_t spatial_dim : + op.getConvolutionLayout().getKernelSpatialDimensions()) { + kernel_layout[spatial_dim] = spatial_count; + spatial_count++; } - return transpose_indices; + return generateTransposeIndicesLeftToRight(kernel_layout, + ttnn_convolution_kernel_layout); } /** @@ -219,7 +233,7 @@ static std::vector generateInputTransposeIndices( * that were applied to the input but in reverse order. This makes optimizing * away the inserted transposes easier. 
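+ * For example, if the input needed transposes of dims (1, 2) and then (2, 3)
+ * to reach the layout TTNN expects, the output is restored by applying
+ * (3, 2) and then (2, 1).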
*/ -static std::vector generateOutputTransposeIndices( +static std::vector generateConvOutputTransposeIndices( ttir::ConvolutionOp op, const std::vector ttnn_convolution_layout) { std::vector transpose_indices; @@ -238,21 +252,8 @@ static std::vector generateOutputTransposeIndices( spatial_count++; } - std::vector output_layout = ttnn_convolution_layout; - - for (int64_t i = static_cast(desired_output_layout.size()) - 1; - i >= 0; i--) { - if (desired_output_layout[i] != output_layout[i]) { - int64_t dim0 = i; - int64_t dim1 = std::find(output_layout.begin(), output_layout.end(), - desired_output_layout[i]) - - output_layout.begin(); - transpose_indices.push_back(std::make_tuple(dim0, dim1)); - std::swap(output_layout[dim0], output_layout[dim1]); - } - } - - return transpose_indices; + return generateTransposeIndicesRightToLeft(desired_output_layout, + ttnn_convolution_layout); } static Value @@ -373,13 +374,13 @@ struct ConvolutionToConv2dPattern outputType.getElementType()); auto input_transpose_indices = - generateInputTransposeIndices(op, conv2d_layout); + generateConvInputTransposeIndices(op, conv2d_layout); Value input = generateTransposeSequence(adaptor.getInput(), rewriter, input_transpose_indices, adaptor.getOperandConstraints()); auto kernel_transpose_indices = - generateKernelTransposeIndices(op, conv2d_kernel_layout); + generateConvKernelTransposeIndices(op, conv2d_kernel_layout); Value weight = generateTransposeSequence(adaptor.getWeight(), rewriter, kernel_transpose_indices, adaptor.getOperandConstraints()); @@ -391,7 +392,7 @@ struct ConvolutionToConv2dPattern padding_bottom_attr, adaptor.getOperandConstraints()); auto output_transpose_indices = - generateOutputTransposeIndices(op, conv2d_layout); + generateConvOutputTransposeIndices(op, conv2d_layout); Value output = generateTransposeSequence(new_conv.getResult(), rewriter, output_transpose_indices, adaptor.getOperandConstraints()); @@ -402,9 +403,230 @@ struct ConvolutionToConv2dPattern } }; +struct PoolingToPool2dPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + std::vector getSpatialDims(ttir::PoolingOp op) const { + std::vector spatial_dims; + for (int64_t i = 0; + i < static_cast(op.getWindowDimensions().size()); i++) { + if (op.getWindowDimensions()[i] > 1) { + spatial_dims.push_back(i); + } + } + return spatial_dims; + } + + LogicalResult is2d(ttir::PoolingOp op) const { + + // Window dimensions must be 4 in length + if (op.getWindowDimensions().size() != 4) { + return failure(); + } + + // Window strides must be 4 in length + if (op.getWindowStrides().size() != 4) { + return failure(); + } + + // Operand rank(s) must be 4 + for (Value operand : op.getInputs()) { + auto operand_type = mlir::cast(operand.getType()); + if (operand_type.getRank() != 4) { + return failure(); + } + } + + // Exactly two of the window dimensions must be greater than 1 + std::vector true_window_dimensions_indices = getSpatialDims(op); + + if (true_window_dimensions_indices.size() != 2) { + return failure(); + } + + // Exactly two of the window strides must be greater than 1 + std::vector true_window_stride_indices; + for (int64_t i = 0; i < static_cast(op.getWindowStrides().size()); + i++) { + if (op.getWindowStrides()[i] > 1) { + true_window_stride_indices.push_back(i); + } + } + + if (true_window_stride_indices.size() != 2) { + return failure(); + } + + // The indices of the true window dimensions and strides must be the same + if ((true_window_dimensions_indices[0] != 
true_window_stride_indices[0] || + true_window_dimensions_indices[1] != true_window_stride_indices[1]) && + (true_window_dimensions_indices[0] != true_window_stride_indices[1] || + true_window_dimensions_indices[1] != true_window_stride_indices[0])) { + return failure(); + } + + // NOTE: To model later + // + // There are many scenarios in which this is a valid maxpool2d + // - There are two window dimensions > 1 and two window strides > 1 which + // are at the exact same indices + // - There is one window dimension > 1 and one window stride > 1 which are + // at the exact same index (other spatial dimension is thus arbitrary so we + // can just pick an adjacent one) + // - There is two window dimensions > 1 and one window stride > 1 where the + // single window stride is at one of the two indices where the window + // dimensions lie, the other window stride (1) will be at the index of the + // other spatial dimension + // - There is one window dimension > 1 and two window strides > 1 where the + // single window dimension is at one of the two indices where the window + // strides lie, the other window dimension (1) will be at the index of the + // other stride dimension + // - If it is 1x1 with 1x1 strides this is a nop, which is true for all + // pooling ops + + // Padding must be 8 in length + if (op.getPadding().size() != 8) { + return failure(); + } + + return success(); + } + + template + void rewritePool2d(ttir::PoolingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + const int64_t SPATIAL_H = -3; + const int64_t SPATIAL_W = -2; + const int64_t NON_SPATIAL = -1; + + auto inputType = + mlir::cast(adaptor.getInputs()[0].getType()); + assert(inputType.getRank() == 4 && "Input must be 4D tensor"); + std::vector desired_layout(inputType.getRank(), NON_SPATIAL); + desired_layout[inputType.getRank() - 3] = SPATIAL_H; + desired_layout[inputType.getRank() - 2] = SPATIAL_W; + + int64_t non_spatial_count = 0; + for (int64_t i = 0; i < static_cast(desired_layout.size()); i++) { + if (desired_layout[i] == NON_SPATIAL) { + desired_layout[i] = non_spatial_count; + non_spatial_count++; + } + } + + std::vector spatial_dims = getSpatialDims(op); + + std::vector current_layout(inputType.getRank(), NON_SPATIAL); + current_layout[spatial_dims[0]] = SPATIAL_H; + current_layout[spatial_dims[1]] = SPATIAL_W; + + non_spatial_count = 0; + for (int64_t i = 0; i < static_cast(current_layout.size()); i++) { + if (current_layout[i] == NON_SPATIAL) { + current_layout[i] = non_spatial_count; + non_spatial_count++; + } + } + + auto input_transpose_indices = + generateTransposeIndicesLeftToRight(current_layout, desired_layout); + auto output_transpose_indices = + generateTransposeIndicesRightToLeft(desired_layout, current_layout); + + auto kernel_height_attr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowDimensions()[spatial_dims[0]])); + auto kernel_width_attr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowDimensions()[spatial_dims[1]])); + + auto stride_height_attr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowStrides()[spatial_dims[0]])); + + auto stride_width_attr = rewriter.getSI32IntegerAttr( + static_cast(op.getWindowStrides()[spatial_dims[1]])); + + auto dilation_height_attr = rewriter.getSI32IntegerAttr( + adaptor.getWindowDilations()[spatial_dims[0]]); + auto dilation_width_attr = rewriter.getSI32IntegerAttr( + adaptor.getWindowDilations()[spatial_dims[1]]); + auto ceil_mode_attr = rewriter.getBoolAttr(false); + + auto padding_top_attr = + 
rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatial_dims[0]]); + auto padding_bottom_attr = + rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatial_dims[0] + 1]); + auto padding_left_attr = + rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatial_dims[1]]); + auto padding_right_attr = + rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatial_dims[1] + 1]); + auto operand_constraints = adaptor.getOperandConstraints(); + + std::vector outputs; + for (Value input : adaptor.getInputs()) { + input = generateTransposeSequence( + input, rewriter, input_transpose_indices, operand_constraints); + + auto outputType = mlir::cast(op.getResult(0).getType()); + auto newOutputShape = outputType.getShape().vec(); + for (TransposeDims dims : input_transpose_indices) { + std::swap(newOutputShape[std::get<0>(dims)], + newOutputShape[std::get<1>(dims)]); + } + auto newOutputType = + outputType.cloneWith(newOutputShape, outputType.getElementType()); + auto outputTensor = rewriter.create( + op.getLoc(), newOutputType.getShape(), + newOutputType.getElementType()); + + auto new_pool = rewriter.create( + op.getLoc(), newOutputType, input, outputTensor, kernel_height_attr, + kernel_width_attr, stride_height_attr, stride_width_attr, + dilation_height_attr, dilation_width_attr, ceil_mode_attr, + padding_top_attr, padding_bottom_attr, padding_left_attr, + padding_right_attr, operand_constraints, + rewriter.getSI32IntegerAttr(inputType.getShape()[spatial_dims[0]]), + rewriter.getSI32IntegerAttr(inputType.getShape()[spatial_dims[1]])); + + Value output = generateTransposeSequence(new_pool.getResult(), rewriter, + output_transpose_indices, + operand_constraints); + outputs.push_back(output); + } + + for (int64_t i = 0; i < static_cast(outputs.size()); i++) { + rewriter.replaceAllUsesWith(op.getResult(i), outputs[i]); + } + + rewriter.eraseOp(op); + } + + LogicalResult + matchAndRewrite(ttir::PoolingOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(is2d(op))) { + return failure(); + } + + switch (op.getPoolingMethod()) { + case ttir::PoolingMethod::Max: { + rewritePool2d(op, adaptor, rewriter); + return success(); + } + + default: + return failure(); + } + + return failure(); + } +}; + void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx, RewritePatternSet &patterns, TypeConverter &typeConverter) { + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); } diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp index e621e6b285..12738df431 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecompositionPass.cpp @@ -48,6 +48,7 @@ struct TTIRToTTIRDecompositionPass // These are the ops we intend to remove entirely with this pass target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); TypeConverter typeConverter; // All types map 1:1. 
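(Note, not part of the patch: a standalone C++ sketch of the swap-based transpose generation used by generateTransposeIndicesLeftToRight above. The helper name transposesLeftToRight, the main() harness, and the NCHW example are illustrative assumptions; the sentinel values mirror the ones in rewritePool2d.)

// Standalone illustration of how a layout mismatch is turned into a list of
// (dim0, dim1) transpose pairs by successive swaps.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

static std::vector<std::tuple<int64_t, int64_t>>
transposesLeftToRight(std::vector<int64_t> current,
                      const std::vector<int64_t> &desired) {
  std::vector<std::tuple<int64_t, int64_t>> result;
  for (int64_t i = 0; i < static_cast<int64_t>(current.size()); i++) {
    if (current[i] != desired[i]) {
      int64_t j = std::find(current.begin(), current.end(), desired[i]) -
                  current.begin();
      result.emplace_back(i, j);         // one ttir.transpose of dims (i, j)
      std::swap(current[i], current[j]); // layout after that transpose
    }
  }
  return result;
}

int main() {
  // Sentinels as in rewritePool2d: SPATIAL_H = -3, SPATIAL_W = -2,
  // non-spatial dims numbered 0, 1, ...
  std::vector<int64_t> nchw = {0, 1, -3, -2}; // pooling input: N, C, H, W
  std::vector<int64_t> nhwc = {0, -3, -2, 1}; // channel-last layout for max_pool2d
  for (auto [d0, d1] : transposesLeftToRight(nchw, nhwc))
    std::cout << "transpose dims " << d0 << " and " << d1 << "\n";
  // Prints (1, 2) then (2, 3): NCHW -> NHCW -> NHWC.
  return 0;
}

Applying those two ttir.transpose ops takes an NCHW pooling input to the channel-last layout ttir.max_pool2d expects; generateTransposeIndicesRightToLeft emits the same pairs in reverse so the output can be restored to the original layout.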
diff --git a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp index c0864c8d86..9ed56a1ea3 100644 --- a/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp +++ b/lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp @@ -23,6 +23,9 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/LogicalResult.h" +#include +#include +#include using namespace mlir; using namespace mlir::tt; @@ -501,11 +504,35 @@ class ConstantOpConversionPattern return legalityResult; } - if (valueAttr.isSplat()) { + auto shape = op.getType().getShape(); + auto elementType = op.getType().getElementType(); + auto numElements = valueAttr.size(); + + // Collect the values into a vector + std::vector values; + for (int64_t i = 0; i < numElements; ++i) { + values.push_back(valueAttr.getValues()[i]); + } + + auto denseValues = ::mlir::DenseElementsAttr::get( + ::mlir::RankedTensorType::get(shape, elementType), values); + + if (denseValues.isSplat()) { Value device = getOrInsertDevice(rewriter, op); - float fillValue = valueAttr.getElementType().isInteger() - ? static_cast(valueAttr.getSplatValue()) - : valueAttr.getSplatValue(); + float fillValue = 0; + if (op.getResult().getType().getElementType().isBF16()) { + uint32_t raw_shifted_bfloat = + static_cast(*denseValues.getRawData().data()) << 16; + fillValue = *reinterpret_cast(&raw_shifted_bfloat); + } else if (op.getResult().getType().getElementType().isF32()) { + fillValue = static_cast(*denseValues.getRawData().data()); + } else if (op.getResult().getType().getElementType().isInteger()) { + fillValue = static_cast(denseValues.getSplatValue()); + } else { + return rewriter.notifyMatchFailure( + op, "TTNN doesn't currently support tensor creation from values " + "which are not integer or floating point numbers"); + } if (fillValue == 0) { rewriter.replaceOpWithNewOp( op, this->getTypeConverter()->convertType(op.getType()), device); @@ -556,33 +583,31 @@ class MatmulOpConversionPattern : public OpConversionPattern { }; // ANCHOR_END: adding_an_op_matmul_op_rewriter -class Conv2dOpConversionPattern : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; +static ttnn::ReshapeOp generateReshape(Value input, ArrayRef newShape, + PatternRewriter &rewriter) { + auto inputType = mlir::cast(input.getType()); + auto outputType = inputType.cloneWith(newShape, inputType.getElementType()); - ttnn::ReshapeOp generateReshape(ttir::Conv2dOp op, Value input, - ArrayRef newShape, - PatternRewriter &rewriter) const { - auto inputType = mlir::cast(input.getType()); - auto outputType = inputType.cloneWith(newShape, inputType.getElementType()); + std::vector newShapeI32(newShape.begin(), newShape.end()); + return rewriter.create( + input.getLoc(), outputType, input, rewriter.getI32ArrayAttr(newShapeI32)); +} - std::vector newShapeI32(newShape.begin(), newShape.end()); - return rewriter.create( - input.getLoc(), outputType, input, - rewriter.getI32ArrayAttr(newShapeI32)); - } +static ttnn::ReshapeOp generateNHWFlatten(Value input, + PatternRewriter &rewriter) { + std::vector shape = + mlir::cast(input.getType()).getShape().vec(); - ttnn::ReshapeOp generateNHWFlatten(ttir::Conv2dOp op, Value input, - PatternRewriter &rewriter) const { - std::vector shape = - mlir::cast(input.getType()).getShape().vec(); + assert(shape.size() == 4 && "Must have 4-dim tensor as conv2d input"); - assert(shape.size() == 4 && "Must have 4-dim tensor as conv2d input"); + std::vector newShape = {1, 1, shape[0] * shape[1] * shape[2], 
+ shape[3]}; + return generateReshape(input, newShape, rewriter); +} - std::vector newShape = {1, 1, shape[0] * shape[1] * shape[2], - shape[3]}; - return generateReshape(op, input, newShape, rewriter); - } +class Conv2dOpConversionPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ttir::Conv2dOp op, OpAdaptor adaptor, @@ -638,7 +663,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern { std::vector flattenedInputShape = { 1, 1, input_shape[0] * input_shape[1] * input_shape[2], input_shape[3]}; - Value flattenedInput = generateNHWFlatten(op, adaptor.getInput(), rewriter); + Value flattenedInput = generateNHWFlatten(adaptor.getInput(), rewriter); std::vector flattenedOutputShape = { 1, 1, output_shape[0] * output_shape[1] * output_shape[2], @@ -663,7 +688,7 @@ class Conv2dOpConversionPattern : public OpConversionPattern { stride_height, stride_width, padding_height, padding_width, dilation_height, dilation_width, groups); - Value output = generateReshape(op, new_conv, output_shape, rewriter); + Value output = generateReshape(new_conv, output_shape, rewriter); rewriter.replaceOp(op, output); return success(); @@ -702,15 +727,41 @@ class MaxPool2dOpConversionPattern "ttir::MaxPool2dOp must have original_width set before translating " "to TTNN dialect."); - rewriter.replaceOpWithNewOp( - op, this->getTypeConverter()->convertType(op.getType()), - adaptor.getInput(), adaptor.getOutput(), device, batch_size, - adaptor.getOriginalHeightAttr(), adaptor.getOriginalWidthAttr(), - channels, adaptor.getKernelHeightAttr(), adaptor.getKernelWidthAttr(), - adaptor.getStrideHeightAttr(), adaptor.getStrideWidthAttr(), - adaptor.getDilationHeightAttr(), adaptor.getDilationWidthAttr(), - adaptor.getCeilModeAttr(), adaptor.getPaddingTopAttr(), - adaptor.getPaddingRightAttr()); + Value flattenedInput = generateNHWFlatten(adaptor.getInput(), rewriter); + + auto output_ty = + mlir::cast(adaptor.getOutput().getType()); + llvm::ArrayRef output_shape = output_ty.getShape(); + + std::vector flattenedOutputShape = { + 1, 1, output_shape[0] * output_shape[1] * output_shape[2], + output_shape[3]}; + + output_ty = mlir::cast(getTypeConverter()->convertType( + output_ty.cloneWith(flattenedOutputShape, output_ty.getElementType()))); + + // Using a tensor::EmptyOp so that the rewriter for EmptyOp can handle the + // attribute determination + auto poolDPSOutput = rewriter.replaceOpWithNewOp( + adaptor.getOutput().getDefiningOp(), flattenedOutputShape, + output_ty.getElementType()); + + // Must set the type to the output type to maintain the layout attributes + poolDPSOutput.getResult().setType(output_ty); + + auto new_pool = rewriter.create( + op.getLoc(), output_ty, flattenedInput, poolDPSOutput, device, + batch_size, adaptor.getOriginalHeightAttr(), + adaptor.getOriginalWidthAttr(), channels, adaptor.getKernelHeightAttr(), + adaptor.getKernelWidthAttr(), adaptor.getStrideHeightAttr(), + adaptor.getStrideWidthAttr(), adaptor.getDilationHeightAttr(), + adaptor.getDilationWidthAttr(), adaptor.getCeilModeAttr(), + adaptor.getPaddingTopAttr(), adaptor.getPaddingRightAttr()); + + Value output = generateReshape(new_pool, output_shape, rewriter); + + rewriter.replaceOp(op, output); + return success(); } }; diff --git a/lib/Dialect/TTIR/IR/TTIRDialect.cpp b/lib/Dialect/TTIR/IR/TTIRDialect.cpp index 73d259ea3c..b935b613bb 100644 --- a/lib/Dialect/TTIR/IR/TTIRDialect.cpp +++ b/lib/Dialect/TTIR/IR/TTIRDialect.cpp @@ -11,6 +11,8 @@ #include 
"llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.cpp.inc" + #define GET_ATTRDEF_CLASSES #include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.cpp.inc" diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index 3ae2a7badd..864f236dae 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -110,6 +110,44 @@ ::mlir::LogicalResult mlir::tt::ttir::ConvolutionOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// PoolingOp +//===----------------------------------------------------------------------===// + +::mlir::LogicalResult mlir::tt::ttir::PoolingOp::verify() { + + uint32_t inputRank = mlir::cast(getInputs()[0].getType()).getRank(); + + for (auto input : getInputs()) { + auto inputType = mlir::cast(input.getType()); + if (inputType.getRank() != inputRank) { + return emitOpError("All input tensors must have the same rank"); + } + } + + if (getWindowStrides().size() != inputRank) { + return emitOpError("Window strides must have the same number of elements " + "as the rank of the input tensor"); + } + + if (getWindowDilations().size() != inputRank) { + return emitOpError("Window dilations must have the same number of elements " + "as the rank of the input tensor"); + } + + if (getWindowDimensions().size() != inputRank) { + return emitOpError("Window dimensions must have the same number of elements " + "as the rank of the input tensor"); + } + + if (getPadding().size() != 2*inputRank) { + return emitOpError("Padding must have the same number of elements as twice " + "the rank of the input tensor"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // MaxPool2dOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TTIR/Transforms/Transforms.cpp b/lib/Dialect/TTIR/Transforms/Transforms.cpp index 084f1a90d4..0a34de9c4d 100644 --- a/lib/Dialect/TTIR/Transforms/Transforms.cpp +++ b/lib/Dialect/TTIR/Transforms/Transforms.cpp @@ -14,141 +14,4 @@ namespace mlir::tt::ttir { #define GEN_PASS_DEF_TTIRSLIDINGWINDOW2DFIXSHAPES #include "ttmlir/Dialect/TTIR/Transforms/Passes.h.inc" -//===----------------------------------------------------------------------===// -// Helper methods -//===----------------------------------------------------------------------===// - -std::vector collapseNHW(std::vector shape) { - std::vector collapsed(shape.size(), 1); - - int64_t NHW = 1; - for (uint32_t i = 0; i < shape.size() - 1; i++) { - NHW *= shape[i]; - } - collapsed[collapsed.size() - 2] = NHW; - collapsed[collapsed.size() - 1] = shape[shape.size() - 1]; - return collapsed; -} - -//===----------------------------------------------------------------------===// -// Sliding window pass -//===----------------------------------------------------------------------===// - -template -class UncollapsedSlidingWindow2dPatternRewriter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - ReshapeOp createReshapeOp(PatternRewriter &rewriter, Location loc, - Value input, ::llvm::ArrayRef shapei64, - ::mlir::ArrayAttr operandConstraints) const { - auto ty = mlir::cast(input.getType()); - auto output = - rewriter.create(loc, shapei64, ty.getElementType()); - - auto shape_attr = rewriter.getI32ArrayAttr( - {static_cast(shapei64[0]), static_cast(shapei64[1]), - static_cast(shapei64[2]), static_cast(shapei64[3])}); - return 
rewriter.create( - loc, output.getType(), input, output, shape_attr, operandConstraints); - } - - MaxPool2dOp createMaxPool2dOp(PatternRewriter &rewriter, MaxPool2dOp op, - Value input, int32_t input_height, - int32_t input_width, - RankedTensorType new_result_type) const { - auto output = rewriter.create( - op->getLoc(), new_result_type.getShape(), - new_result_type.getElementType()); - - auto input_height_attr = rewriter.getSI32IntegerAttr(input_height); - auto input_width_attr = rewriter.getSI32IntegerAttr(input_width); - - MaxPool2dOp new_maxpool = rewriter.create( - op.getLoc(), new_result_type, input, output, op.getKernelHeightAttr(), - op.getKernelWidthAttr(), op.getStrideHeightAttr(), - op.getStrideWidthAttr(), op.getDilationHeightAttr(), - op.getDilationWidthAttr(), op.getCeilModeAttr(), - op.getPaddingLeftAttr(), op.getPaddingRightAttr(), - op.getPaddingTopAttr(), op.getPaddingBottomAttr(), - op.getOperandConstraints(), input_height_attr, input_width_attr); - - return new_maxpool; - } - - LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final { - ::llvm::ArrayRef input_shape = - mlir::cast(op.getInput().getType()).getShape(); - - if (input_shape.size() != 4) { - return failure(); - } - - if (input_shape[0] == 1 && input_shape[1] == 1) { - return failure(); - } - - if (!llvm::isa(op)) { - return failure(); - } - - // By this point we are certain that the input tensor is not in the form (1, - // 1, N*H*W, C) And so we must insert reshapes on the input/output - - std::vector new_input_shape = collapseNHW(input_shape); - ::llvm::ArrayRef new_input_shape_array(new_input_shape); - - ReshapeOp input_reshape = - createReshapeOp(rewriter, op.getLoc(), op.getInput(), - new_input_shape_array, op.getOperandConstraints()); - - std::vector new_result_shape = - collapseNHW(op.getResult().getType().getShape().vec()); - ::llvm::ArrayRef new_result_shape_array(new_result_shape); - - RankedTensorType new_result_type = RankedTensorType::get( - new_result_shape_array, op.getResult().getType().getElementType(), - op.getResult().getType().getEncoding()); - - Operation *new_op = createMaxPool2dOp( - rewriter, mlir::cast(op), input_reshape, - static_cast(input_shape[1]), - static_cast(input_shape[2]), new_result_type); - - ReshapeOp output_reshape = createReshapeOp( - rewriter, op.getLoc(), new_op->getResult(0), - op.getResult().getType().getShape().vec(), op.getOperandConstraints()); - - rewriter.replaceOp(op, output_reshape); - return success(); - } -}; - -class TTIRSlidingWindow2dFixShapes - : public impl::TTIRSlidingWindow2dFixShapesBase< - TTIRSlidingWindow2dFixShapes> { -public: - using impl::TTIRSlidingWindow2dFixShapesBase< - TTIRSlidingWindow2dFixShapes>::TTIRSlidingWindow2dFixShapesBase; - - void runOnOperation() final { - { - RewritePatternSet patterns(&getContext()); - patterns.add>( - &getContext()); - FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) { - signalPassFailure(); - return; - } - } - } - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - } -}; - } // namespace mlir::tt::ttir diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp index 7f3baaeaf7..e598d93c43 100644 --- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp +++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp @@ -27,7 +27,6 @@ void createTTNNPipelineTTIRPasses( // function. 
Removes all private functions.
   pm.addPass(mlir::createInlinerPass());
 
-  pm.addPass(mlir::tt::ttir::createTTIRSlidingWindow2dFixShapes());
   pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc(systemDescOptions));
 
   ttir::TTIRImplicitDeviceOptions implicitDeviceOptions;
diff --git a/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir b/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir
new file mode 100644
index 0000000000..8a7eb509d8
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir
@@ -0,0 +1,19 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x32x128x128xbf16>, %arg1: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> {
+    %0 = tensor.empty() : tensor<1x32x64x64xbf16>
+    %1 = tensor.empty() : tensor<1x32x64x64xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
+    %2, %3 = "ttir.pooling"(%arg0, %arg1, %0, %1) <{
+        operandSegmentSizes = array<i32: 2, 2>,
+        pooling_method = #ttir<pooling_method Max>,
+        window_dimensions = array<i64: 1, 1, 2, 2>,
+        window_strides = array<i64: 1, 1, 2, 2>,
+        operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>)
+
+    %4 = tensor.empty() : tensor<1x32x64x64xbf16>
+    %6 = "ttir.add"(%2, %3, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16>
+    return %6 : tensor<1x32x64x64xbf16>
+  }
+}
diff --git a/test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir
similarity index 100%
rename from test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir
rename to test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir
diff --git a/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir
new file mode 100644
index 0000000000..a2d6141de2
--- /dev/null
+++ b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir
@@ -0,0 +1,15 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
+#any_device = #tt.operand_constraint
+module attributes {} {
+  func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> {
+    %0 = tensor.empty() : tensor<1x32x64x64xbf16>
+    // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
+    %1 = "ttir.pooling"(%arg0, %0) <{
+        operandSegmentSizes = array<i32: 1, 1>,
+        pooling_method = #ttir<pooling_method Max>,
+        window_dimensions = array<i64: 1, 1, 2, 2>,
+        window_strides = array<i64: 1, 1, 2, 2>,
+        operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16>
+    return %1 : tensor<1x32x64x64xbf16>
+  }
+}
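(Also not part of the patch: a small standalone sketch of the bfloat16 bit manipulation used above, i.e. the 0xff80 == -inf comparison in isMaxPool and the shift-by-16 splat conversion in ConstantOpConversionPattern. The helper name bf16BitsToFloat is an illustrative assumption, and memcpy stands in for the reinterpret_cast in the patch to keep the sketch well-defined.)

#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 keeps the sign, exponent, and top 7 mantissa bits of a float32,
// so widening the raw 16 bits into the high half of a 32-bit word recovers
// the float value.
static float bf16BitsToFloat(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  std::cout << bf16BitsToFloat(0xff80) << "\n"; // -inf, the max-pool init value
  std::cout << bf16BitsToFloat(0x3f80) << "\n"; // 1
  return 0;
}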