Commit
Create new general pooling op and decomposition pattern that converts to maxpool2d
LPanosTT committed Nov 4, 2024
1 parent 6988418 commit bb05d0c
Showing 18 changed files with 616 additions and 336 deletions.
6 changes: 6 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/CMakeLists.txt
@@ -8,6 +8,12 @@ mlir_tablegen(TTIROpsAttrs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(TTIROpsAttrsIncGen)
add_dependencies(mlir-headers TTIROpsAttrsIncGen)

set(LLVM_TARGET_DEFINITIONS TTIROpsEnums.td)
mlir_tablegen(TTIROpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(TTIROpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTTIROpsEnumsIncGen)
add_dependencies(mlir-headers MLIRTTIROpsEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS TTIROpsInterfaces.td)
mlir_tablegen(TTIROpsInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(TTIROpsInterfaces.cpp.inc -gen-op-interface-defs)
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.h
Expand Up @@ -16,6 +16,8 @@

#include "TTIROpsInterfaces.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.h.inc"

27 changes: 27 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -694,6 +694,33 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
}];
}

def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> {
  let summary = "General pooling op";
  let description = [{
    General pooling op
  }];

  let arguments = (ins
      Variadic<AnyRankedTensor>:$inputs,
      Variadic<AnyRankedTensor>:$outputs,
      TTIR_PoolingMethodAttr:$pooling_method,
      DenseI64ArrayAttr:$window_dimensions,

      // Default stride of 1 over every dimension
      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$window_strides,
      // Default dilation of 1 over every dimension
      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$base_dilations,
      // Default dilation of 1 over every dimension
      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size(), 1)">:$window_dilations,
      // Default padding of 0 over every dimension
      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "SmallVector<int64_t>(getWindowDimensions().size() * 2, 0)">:$padding,
      TT_OperandConstraintArrayAttr:$operand_constraints
  );

  let results = (outs Variadic<AnyRankedTensor>);

  let hasVerifier = 1;
}
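For orientation, here is a minimal sketch, not part of the commit, of how client code might construct this op directly. It follows the same argument order as the replaceOpWithNewOp<ttir::PoolingOp> call in the StableHLO lowering further down; `builder`, `loc`, `input`, `output`, and `constraints` are hypothetical values owned by the surrounding code.

// Sketch only: assumes the ODS-generated builder takes the arguments in the
// order declared above (inputs, outputs, pooling_method, window attributes,
// operand_constraints).
auto pooled = builder.create<mlir::tt::ttir::PoolingOp>(
    loc, /*resultTypes=*/mlir::TypeRange{output.getType()},
    /*inputs=*/mlir::ValueRange{input},
    /*outputs=*/mlir::ValueRange{output},
    /*pooling_method=*/mlir::tt::ttir::PoolingMethod::Max,
    // 2x2 max pooling over the H and W dimensions of an NHWC tensor.
    /*window_dimensions=*/builder.getDenseI64ArrayAttr({1, 2, 2, 1}),
    /*window_strides=*/builder.getDenseI64ArrayAttr({1, 2, 2, 1}),
    /*base_dilations=*/builder.getDenseI64ArrayAttr({1, 1, 1, 1}),
    /*window_dilations=*/builder.getDenseI64ArrayAttr({1, 1, 1, 1}),
    // Flattened (low, high) padding pair per dimension, all zero here.
    /*padding=*/builder.getDenseI64ArrayAttr({0, 0, 0, 0, 0, 0, 0, 0}),
    /*operand_constraints=*/constraints);

Note that the strides, dilations, and padding are DefaultValuedOptionalAttr whose defaults are computed from getWindowDimensions().size(), so all four could be omitted when they are trivial.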

def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
4 changes: 4 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsAttrs.td
@@ -7,6 +7,10 @@

include "mlir/IR/AttrTypeBase.td"
include "ttmlir/Dialect/TTIR/IR/TTIRBase.td"
include "mlir/IR/EnumAttr.td"
include "ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td"

def TTIR_PoolingMethodAttr : EnumAttr<TTIR_Dialect, TTIR_PoolingMethod, "pooling_method">;

def TTIR_ConvolutionLayoutAttr : AttrDef<TTIR_Dialect, "ConvolutionLayout", [], "::mlir::Attribute"> {
let mnemonic = "convolution_layout";
21 changes: 21 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsEnums.td
@@ -0,0 +1,21 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TTIR_ENUMS_TD
#define TTMLIR_TTIR_ENUMS_TD

include "mlir/IR/EnumAttr.td"

def TTIR_AveragePoolingMethod : I32EnumAttrCase<"Average", 0>;
def TTIR_MaxPoolingMethod : I32EnumAttrCase<"Max", 1>;

def TTIR_PoolingMethod : I32EnumAttr<"PoolingMethod", "TTIR PoolingMethod", [
  TTIR_AveragePoolingMethod,
  TTIR_MaxPoolingMethod
]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::tt::ttir";
}

#endif
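The -gen-enum-decls/-gen-enum-defs TableGen backends wired up in the CMakeLists.txt change above emit a C++ enum class plus string conversion helpers for this definition. A rough illustration of the conventional generated API follows; the helper names reflect MLIR's usual enum-generation conventions and are an assumption, not a transcript of the generated header.

#include <cassert>
#include <optional>

#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" // pulls in TTIROpsEnums.h.inc

// Round-trip a pooling method through the generated stringify/symbolize
// helpers (names assumed from MLIR's enum-generation conventions).
void enumDemo() {
  mlir::tt::ttir::PoolingMethod m = mlir::tt::ttir::PoolingMethod::Max;
  llvm::StringRef text = mlir::tt::ttir::stringifyPoolingMethod(m); // "Max"
  std::optional<mlir::tt::ttir::PoolingMethod> parsed =
      mlir::tt::ttir::symbolizePoolingMethod(text);
  assert(parsed && *parsed == m);
}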
7 changes: 0 additions & 7 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -70,13 +70,6 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
];
}

-def TTIRSlidingWindow2dFixShapes: Pass<"ttir-sliding-window-2d-fix-shapes", "::mlir::ModuleOp"> {
-  let summary = "Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e. (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e. (1, 1, N*H*W, C) --> (N, H, W, C)";
-  let description = [{
-    Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N,H,W on the input: i.e. (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output: i.e. (1, 1, N*H*W, C) --> (N, H, W, C)
-  }];
-}

def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> {
let summary = "Split compound layouts.";
let description = [{
210 changes: 113 additions & 97 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -2,6 +2,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <limits>
#include <vector>

#include "mlir/Dialect/Traits.h"
@@ -473,11 +474,74 @@ class StableHLOToTTIRReduceWindowOpConversionPattern
rewriter.eraseOp(op);
}

-  bool isMaxPool2d(mlir::stablehlo::ReduceWindowOp &srcOp) const {
+  bool isMaxPool(mlir::stablehlo::ReduceWindowOp &srcOp) const {
if (srcOp.getBody().getBlocks().size() != 1) {
return false;
}

+    // Find constant input(s)
+    Operation *init_value;
+    for (uint64_t i = 0; i < srcOp.getInitValues().size(); i++) {
+      init_value = srcOp.getInitValues()[i].getDefiningOp();
+      auto name = init_value->getName().getStringRef().str();
+      (void)name;
+      while (init_value->getOpOperands().size() == 1) {
+        init_value = init_value->getOpOperand(0).get().getDefiningOp();
+      }
+      if (!isa<stablehlo::ConstantOp>(init_value)) {
+        return false;
+      }
+
+      stablehlo::ConstantOp init_value_op =
+          mlir::cast<stablehlo::ConstantOp>(init_value);
+
+      if (init_value_op.getValueAttr().size() != 1) {
+        return false;
+      }
+
+      // The constant operand must be -inf if this is to be a max pool;
+      // since bfloat16 is not a type we actually have, compare the raw
+      // bits.
+      if (init_value_op.getResult().getType().getElementType().isBF16()) {
+        // Collect the values into a vector
+        std::vector<mlir::Attribute> values;
+        for (int64_t i = 0; i < init_value_op.getValueAttr().size(); ++i) {
+          values.push_back(
+              init_value_op.getValueAttr().getValues<mlir::Attribute>()[i]);
+        }
+
+        auto denseValues = ::mlir::DenseElementsAttr::get(
+            init_value_op.getValueAttr().getShapedType(), values);
+        uint16_t bfloat_bits =
+            static_cast<uint16_t>(*denseValues.getRawData().data());
+        if (bfloat_bits != 0xff80) { // This is -inf in bfloat16
+          return false;
+        }
+      } else if (init_value_op.getValue().getType().isF32()) {
+        if (*init_value_op.getValue().value_begin<float>() !=
+            -std::numeric_limits<float>::infinity()) {
+          return false;
+        }
+      } else if (init_value_op.getValue().getType().isF64()) {
+        if (*init_value_op.getValue().value_begin<double>() !=
+            -std::numeric_limits<double>::infinity()) {
+          return false;
+        }
+      } else if (init_value_op.getValue().getType().isInteger(32)) {
+        if (*init_value_op.getValue().value_begin<int32_t>() !=
+            std::numeric_limits<int32_t>::min()) {
+          return false;
+        }
+      } else if (init_value_op.getValue().getType().isInteger(64)) {
+        if (*init_value_op.getValue().value_begin<int64_t>() !=
+            std::numeric_limits<int64_t>::min()) {
+          return false;
+        }
+      } else {
+        return false;
+      }
+    }

Block &block = *srcOp.getBody().getBlocks().begin();
uint32_t op_idx = 0;
for (Operation &op : block) {
@@ -501,105 +565,57 @@ class StableHLOToTTIRReduceWindowOpConversionPattern
mlir::stablehlo::ReduceWindowOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

-    if (isMaxPool2d(srcOp)) {
-
-      RankedTensorType outputType = mlir::cast<RankedTensorType>(
-          getTypeConverter()->convertType(srcOp.getResult(0).getType()));
-
-      tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
-          srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
-
-      // The generalized ReduceWindow allows for kernel_size, strides,
-      // dilation, and padding to act on all 4 input dimensions. Since we only
-      // support channel-last pooling, we select the middle two values for H
-      // and W, and fail if the others are not 1 (or 0 in the case of
-      // padding).
-      std::vector<int64_t> window_dimensions = adaptor.getWindowDimensions();
-      if (window_dimensions[0] != 1 || window_dimensions[3] != 1) {
-        return failure();
-      }
-      IntegerAttr kernel_height_attr = rewriter.getSI32IntegerAttr(
-          static_cast<int32_t>(window_dimensions[1]));
-      IntegerAttr kernel_width_attr = rewriter.getSI32IntegerAttr(
-          static_cast<int32_t>(window_dimensions[2]));
-
-      std::vector<int64_t> strides =
-          adaptor.getWindowStrides()
-              .value_or(ArrayRef<int64_t>({1, 1, 1, 1}))
-              .vec();
-
-      if (strides[0] != 1 || strides[3] != 1) {
-        return failure();
-      }
-      IntegerAttr stride_height_attr =
-          rewriter.getSI32IntegerAttr(static_cast<int32_t>(strides[1]));
-      IntegerAttr stride_width_attr =
-          rewriter.getSI32IntegerAttr(static_cast<int32_t>(strides[2]));
-
-      std::vector<int64_t> dilation =
-          adaptor.getBaseDilations()
-              .value_or(ArrayRef<int64_t>({1, 1, 1, 1}))
-              .vec();
-
-      if (dilation[0] != 1 || dilation[3] != 1) {
-        return failure();
-      }
-      IntegerAttr dilation_height_attr =
-          rewriter.getSI32IntegerAttr(static_cast<int32_t>(dilation[1]));
-      IntegerAttr dilation_width_attr =
-          rewriter.getSI32IntegerAttr(static_cast<int32_t>(dilation[2]));
-
-      // Padding here is in the form ((., .), (top, bottom), (left, right),
-      // (., .)), one pair for each of (N, H, W, C). Since we only support
-      // maxpool2d, the first and last padding tuples must be zero to be
-      // valid. This list is flattened so we can use a single iterator to get
-      // the values.
-      std::vector<int32_t> padding = {0, 0, 0, 0};
-      if (adaptor.getPadding().has_value()) {
-        uint32_t pad_idx = 0;
-        for (auto iter = adaptor.getPadding()->value_begin<int64_t>();
-             iter < adaptor.getPadding()->value_end<int64_t>(); iter++) {
-
-          // TTIR requires left, right, top, bottom
-          if (pad_idx == 2) {
-            padding[2] = *iter;
-          } else if (pad_idx == 3) {
-            padding[3] = *iter;
-          } else if (pad_idx == 4) {
-            padding[0] = *iter;
-          } else if (pad_idx == 5) {
-            padding[1] = *iter;
-          } else if (*iter != 0) {
-            // Padding on the channel or batch is nonzero. TTIR/TTNN does not
-            // support this.
-            return failure();
-          }
-          pad_idx++;
-        }
-      }
-      ::llvm::ArrayRef<int64_t> input_shape =
-          mlir::cast<mlir::RankedTensorType>(adaptor.getInputs()[0].getType())
-              .getShape();
-
-      // Dead ttir.constant sticks around and fails verification. Removing it
-      // like so since it's behind another op.
-      recursiveErase(rewriter, adaptor.getInitValues()[0].getDefiningOp());
-      rewriter.replaceOpWithNewOp<mlir::tt::ttir::MaxPool2dOp>(
-          srcOp, outputType, srcOp.getInputs()[0], outputTensor,
-          kernel_height_attr, kernel_width_attr, stride_height_attr,
-          stride_width_attr, dilation_height_attr, dilation_width_attr,
-          rewriter.getBoolAttr(false), rewriter.getSI32IntegerAttr(padding[0]),
-          rewriter.getSI32IntegerAttr(padding[1]),
-          rewriter.getSI32IntegerAttr(padding[2]),
-          rewriter.getSI32IntegerAttr(padding[3]),
-          rewriter.getArrayAttr(
-              SmallVector<Attribute>(adaptor.getOperands().size() + 1,
-                                     rewriter.getAttr<OperandConstraintAttr>(
-                                         OperandConstraint::AnyDeviceTile))),
-          rewriter.getSI32IntegerAttr(input_shape[1]),
-          rewriter.getSI32IntegerAttr(input_shape[2]));
-
-      return success();
-    }
-    return failure();
+    RankedTensorType outputType = mlir::cast<RankedTensorType>(
+        getTypeConverter()->convertType(srcOp.getResult(0).getType()));
+
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+
+    ValueRange inputs = adaptor.getInputs()[0];
+    ValueRange outputs = {outputTensor};
+
+    auto window_dimensions = adaptor.getWindowDimensionsAttr();
+    auto window_strides = adaptor.getWindowStridesAttr();
+    auto base_dilations = adaptor.getBaseDilationsAttr();
+    auto window_dilations = adaptor.getWindowDilationsAttr();
+    auto padding_ = adaptor.getPaddingAttr();
+
+    // Generate defaults if they don't exist
+    window_strides = window_strides
+                         ? window_strides
+                         : rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
+                               window_dimensions.size(), 1));
+    base_dilations = base_dilations
+                         ? base_dilations
+                         : rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
+                               window_dimensions.size(), 1));
+    window_dilations =
+        window_dilations ? window_dilations
+                         : rewriter.getDenseI64ArrayAttr(SmallVector<int64_t>(
+                               window_dimensions.size(), 1));
+    auto padding =
+        padding_ ? rewriter.getDenseI64ArrayAttr(
+                       SmallVector<int64_t>(padding_.getValues<int64_t>()))
+                 : rewriter.getDenseI64ArrayAttr(
+                       SmallVector<int64_t>(window_dimensions.size() * 2, 1));
+
+    auto operand_constraints = rewriter.getArrayAttr(SmallVector<Attribute>(
+        adaptor.getOperands().size(), rewriter.getAttr<OperandConstraintAttr>(
+                                          OperandConstraint::AnyDeviceTile)));
+
+    mlir::tt::ttir::PoolingMethod pooling_method;
+    if (isMaxPool(srcOp)) {
+      pooling_method = mlir::tt::ttir::PoolingMethod::Max;
+    } else {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<ttir::PoolingOp>(
+        srcOp, outputType, inputs, outputs, pooling_method, window_dimensions,
+        window_strides, base_dilations, window_dilations, padding,
+        operand_constraints);
+
+    return success();
  }
};
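One subtle point in isMaxPool above: the bfloat16 branch widens only the first raw byte of the constant to uint16_t, yet still compares equal to 0xff80. That works on a little-endian host with a signed char type, because the low byte of bf16 -inf (0x80) sign-extends to 0xff80 during the integral conversion. A standalone snippet to convince yourself (my illustration, not part of the commit):

#include <cstdint>
#include <cstdio>

int main() {
  // bfloat16 -inf has the bit pattern 0xff80 (sign = 1, exponent all ones,
  // mantissa zero). On a little-endian host its raw bytes are {0x80, 0xff}.
  const char raw[2] = {'\x80', '\xff'};
  // Reading only the first byte and widening it, as the pattern does with
  // DenseElementsAttr::getRawData(), yields 0xff80 anyway: assuming char is
  // signed (as on x86), the value -128 sign-extends to 0xff80.
  uint16_t bits = static_cast<uint16_t>(raw[0]);
  std::printf("%#x\n", bits); // prints 0xff80
  return 0;
}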

1 change: 1 addition & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/CMakeLists.txt
@@ -1,3 +1,4 @@
set(CMAKE_BUILD_TYPE Debug)
add_mlir_library(TTMLIRTTIRToTTIRDecomposition
TTIRToTTIRDecomposition.cpp
TTIRToTTIRDecompositionPass.cpp
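The decomposition pattern itself (TTIRToTTIRDecomposition.cpp) is not rendered in this view, but the commit title says it converts the general pooling op to ttir.max_pool2d. Based on the MaxPool2dOp builder call this commit removes from the StableHLO lowering above, a minimal sketch of such a pattern could plausibly look like the following; the matching logic, getter names, and the struct itself are my assumptions, not the committed code.

#include "mlir/IR/PatternMatch.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

namespace {
// Hypothetical rewrite: lower a 4-D, channel-last ttir.pooling with
// pooling_method = Max onto the existing ttir.max_pool2d op. The final
// builder call mirrors the operand order of the MaxPool2dOp creation that
// this commit removes from the StableHLO lowering.
struct PoolingToMaxPool2d
    : public mlir::OpRewritePattern<mlir::tt::ttir::PoolingOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::tt::ttir::PoolingOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (op.getPoolingMethod() != mlir::tt::ttir::PoolingMethod::Max)
      return mlir::failure();

    llvm::ArrayRef<int64_t> window = op.getWindowDimensions();
    llvm::ArrayRef<int64_t> strides = op.getWindowStrides();
    // Only NHWC pooling that is trivial on N and C maps onto max_pool2d.
    if (window.size() != 4 || window[0] != 1 || window[3] != 1 ||
        strides[0] != 1 || strides[3] != 1)
      return mlir::failure();

    auto si32 = [&](int64_t v) {
      return rewriter.getSI32IntegerAttr(static_cast<int32_t>(v));
    };

    // Padding is a flattened (low, high) pair per (N, H, W, C) dimension;
    // ttir.max_pool2d wants left, right, top, bottom.
    llvm::ArrayRef<int64_t> pad = op.getPadding();
    if (pad.size() != 8)
      return mlir::failure();
    llvm::SmallVector<int64_t, 4> lrtb = {pad[4], pad[5], pad[2], pad[3]};

    mlir::Value input = op.getInputs()[0];
    auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());

    rewriter.replaceOpWithNewOp<mlir::tt::ttir::MaxPool2dOp>(
        op, op.getResult(0).getType(), input, op.getOutputs()[0],
        si32(window[1]), si32(window[2]), si32(strides[1]), si32(strides[2]),
        si32(op.getWindowDilations()[1]), si32(op.getWindowDilations()[2]),
        rewriter.getBoolAttr(false), si32(lrtb[0]), si32(lrtb[1]),
        si32(lrtb[2]), si32(lrtb[3]), op.getOperandConstraints(),
        si32(inputType.getShape()[1]), si32(inputType.getShape()[2]));
    return mlir::success();
  }
};
} // namespace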