reopen the ConvertConv2DToImg2Col Pattern & Add another constraint
Change-Id: I5882508d983d2f8f1c9fa5b1423d41d04c2d22de
daozhuo.feng committed Apr 18, 2024
1 parent 8323eea commit 187dec5
Showing 1 changed file with 149 additions and 147 deletions.
lib/Dialect/Top/Transforms/ProcessorOptimize/OptimizeBM1684X.cpp (296 changes: 149 additions & 147 deletions)
@@ -396,7 +396,8 @@ class ConvertMatMul2Attention : public OpRewritePattern<top::MatMulOp> {
}
} else if (module::isBM1684X()) {
if (len / n > 2048 * 320 * 4 ||
(len_weight0 * 2 + len_weight1 + len_weight2) / head > 1024 * 128 * 4) {
(len_weight0 * 2 + len_weight1 + len_weight2) / head >
1024 * 128 * 4) {
return failure();
}
}
@@ -725,145 +726,147 @@ class WhereBroadcastToTile : public OpRewritePattern<top::WhereOp> {
}
};


/**
* Based on comprehensive testing with the full model-zoo,
* the benefits brought by this pattern are minimal, yet it results in a significant performance degradation in models such as clip.
* Therefore, this pattern will be temporarily commented out.
*/
// class ConvertConv2DToImg2Col final : public OpRewritePattern<top::ConvOp> {
// using OpRewritePattern::OpRewritePattern;
// LogicalResult matchAndRewrite(top::ConvOp convOp,
// PatternRewriter &rewriter) const override {
// Value input = convOp.getInput();
// Value filter = convOp.getFilter();
// Value bias = convOp.getBias();
// Value output = convOp.getOutput();
// auto inputType = llvm::cast<ShapedType>(input.getType());
// auto filterType = llvm::cast<ShapedType>(filter.getType());
// auto outputType = llvm::cast<ShapedType>(output.getType());
// bool with_bias = !module::isNone(bias);
// auto strides = module::getI64Array(convOp.getStrides());
// // note: current support Conv2D
// if (!filterType.hasStaticShape() || !inputType.hasStaticShape() ||
// module::getShape(output).size() != 4) {
// return failure();
// }


// auto hasAllOneValues = [&](mlir::ArrayAttr attr) -> bool {
// return llvm::all_of(
// attr.getAsRange<IntegerAttr>(),
// [](IntegerAttr element) { return element.getInt() == 1; });
// };
// if (convOp.getDilations().has_value() &&
// !hasAllOneValues(convOp.getDilations().value()))
// return failure();

// auto filterShape = filterType.getShape();
// auto outputShape = outputType.getShape();

// const int n = outputShape[0];
// const int oc = outputShape[1];
// const int oh = outputShape[2];
// const int ow = outputShape[3];
// const int ic = filterShape[1];
// const int kh = filterShape[2];
// const int kw = filterShape[3];
// if (!(ic <= 3 && kh >= 16 && kw >= 16 && strides->at(0) == kh &&
// strides->at(1) == kw)) {
// return failure();
// }
// int id = 0;
// auto loc_name = module::getName(convOp.getOperation()).str();
// // 1. Input->Reshape+permute+Reshape(reorder the input)
// SmallVector<int64_t> colTensorShape = {n, ic, oh, kh, ow, kw};
// auto reshapeOp = rewriter.create<top::ReshapeOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get(colTensorShape, inputType.getElementType()),
// ValueRange{input});
// std::vector<int64_t> order = {0, 2, 3, 1, 4, 5};
// std::vector<NamedAttribute> attrs;
// attrs.emplace_back(
// rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order)));

// auto perMuteOp_0 = rewriter.create<top::PermuteOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get({n, oh, kh, ic, ow, kw},
// inputType.getElementType()),
// ValueRange{reshapeOp}, attrs);
// order = {0, 1, 4, 3, 2, 5};
// attrs.clear();
// attrs.emplace_back(
// rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order)));
// auto perMuteOp = rewriter.create<top::PermuteOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get({n, oh, ow, ic, kh, kw},
// inputType.getElementType()),
// ValueRange{perMuteOp_0}, attrs);

// auto reshapeOp_2 = rewriter.create<top::ReshapeOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get({n, oh * ow, ic * kh * kw},
// inputType.getElementType()),
// ValueRange{perMuteOp});
// // 2. filter->reshape
// auto reshapeOp_3 = rewriter.create<top::ReshapeOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get({oc, ic * kh * kw}, filterType.getElementType()),
// ValueRange{filter});
// std::vector<Value> operands;
// operands.emplace_back(reshapeOp_2);
// operands.emplace_back(reshapeOp_3);
// // 3. bias->reshape
// if (with_bias) {
// auto reshapeOp_4 = rewriter.create<top::ReshapeOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get(
// {1, 1, oc},
// llvm::cast<ShapedType>(bias.getType()).getElementType()),
// ValueRange{bias});
// operands.emplace_back(reshapeOp_4);
// } else {
// operands.emplace_back(bias);
// }

// attrs.clear();
// attrs.emplace_back(
// rewriter.getNamedAttr("right_transpose", rewriter.getBoolAttr(true)));
// attrs.emplace_back(
// rewriter.getNamedAttr("output_transpose", rewriter.getBoolAttr(false)));
// // 4. matmul
// auto matmulOp = rewriter.create<top::MatMulOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()),
// operands, attrs);
// attrs.clear();
// attrs.emplace_back(
// rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr({0, 2, 1})));
// // 5. permute
// auto perMuteOp_2 = rewriter.create<top::PermuteOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType()),
// ValueRange{matmulOp}, attrs);
// // 6. reshape the output
// auto reshapeOp_5 = rewriter.create<top::ReshapeOp>(
// NameLoc::get(
// rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
// RankedTensorType::get({n, oc, oh, ow}, outputType.getElementType()),
// ValueRange{perMuteOp_2});
// rewriter.replaceOp(convOp, ArrayRef<Value>{reshapeOp_5});
// return success();
// }
// };
 * the benefits brought by this pattern are minimal, yet it results in a
 * significant performance degradation in models such as CLIP. The pattern is
 * therefore re-enabled here with an additional kernel-size constraint
 * instead of being left commented out.
 */
class ConvertConv2DToImg2Col final : public OpRewritePattern<top::ConvOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(top::ConvOp convOp,
PatternRewriter &rewriter) const override {
Value input = convOp.getInput();
Value filter = convOp.getFilter();
Value bias = convOp.getBias();
Value output = convOp.getOutput();
auto inputType = llvm::cast<ShapedType>(input.getType());
auto filterType = llvm::cast<ShapedType>(filter.getType());
auto outputType = llvm::cast<ShapedType>(output.getType());
bool with_bias = !module::isNone(bias);
auto strides = module::getI64Array(convOp.getStrides());
// note: currently only Conv2D is supported
if (!filterType.hasStaticShape() || !inputType.hasStaticShape() ||
module::getShape(output).size() != 4) {
return failure();
}

auto hasAllOneValues = [&](mlir::ArrayAttr attr) -> bool {
return llvm::all_of(
attr.getAsRange<IntegerAttr>(),
[](IntegerAttr element) { return element.getInt() == 1; });
};
if (convOp.getDilations().has_value() &&
!hasAllOneValues(convOp.getDilations().value()))
return failure();

auto filterShape = filterType.getShape();
auto outputShape = outputType.getShape();

const int n = outputShape[0];
const int oc = outputShape[1];
const int oh = outputShape[2];
const int ow = outputShape[3];
const int ic = filterShape[1];
const int kh = filterShape[2];
const int kw = filterShape[3];
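// The pattern is restricted to non-overlapping "patchify" convolutions
// (stride == kernel, at most 3 input channels, large kernels), such as the
// patch-embedding convs of ViT/CLIP-style models.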
// Once kh or kw reaches 29, the last dimension of the reshaped filter
// (ic * kh * kw, e.g. 3 * 29 * 29 = 2523) becomes quite large. Using it as
// the right matrix of the matmul, particularly with right_transpose set,
// can lead to performance degradation. This constraint is primarily aimed
// at the CLIP model; further improvements will be considered later.
if (!(ic <= 3 && kh >= 16 && kh < 29 && kw >= 16 && kw < 29 &&
strides->at(0) == kh && strides->at(1) == kw)) {
return failure();
}
int id = 0;
auto loc_name = module::getName(convOp.getOperation()).str();
// 1. input -> reshape + permute + permute + reshape (reorder the input into im2col layout)
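// The reshape implicitly assumes exact tiling: with stride == kernel and
// no padding, ih == oh * kh and iw == ow * kw, so the (n, ic, ih, iw)
// input can be viewed as (n, ic, oh, kh, ow, kw) without moving data; the
// two permutes below then reorder it to (n, oh, ow, ic, kh, kw).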
SmallVector<int64_t> colTensorShape = {n, ic, oh, kh, ow, kw};
auto reshapeOp = rewriter.create<top::ReshapeOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get(colTensorShape, inputType.getElementType()),
ValueRange{input});
std::vector<int64_t> order = {0, 2, 3, 1, 4, 5};
std::vector<NamedAttribute> attrs;
attrs.emplace_back(
rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order)));

auto perMuteOp_0 = rewriter.create<top::PermuteOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get({n, oh, kh, ic, ow, kw},
inputType.getElementType()),
ValueRange{reshapeOp}, attrs);
order = {0, 1, 4, 3, 2, 5};
attrs.clear();
attrs.emplace_back(
rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr(order)));
auto perMuteOp = rewriter.create<top::PermuteOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get({n, oh, ow, ic, kh, kw},
inputType.getElementType()),
ValueRange{perMuteOp_0}, attrs);

auto reshapeOp_2 = rewriter.create<top::ReshapeOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get({n, oh * ow, ic * kh * kw},
inputType.getElementType()),
ValueRange{perMuteOp});
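// reshapeOp_2 is the im2col ("col") matrix of shape (n, oh * ow,
// ic * kh * kw): row y * ow + x holds the flattened patch that produces
// output position (y, x).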
// 2. filter -> reshape to (oc, ic * kh * kw)
auto reshapeOp_3 = rewriter.create<top::ReshapeOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get({oc, ic * kh * kw}, filterType.getElementType()),
ValueRange{filter});
std::vector<Value> operands;
operands.emplace_back(reshapeOp_2);
operands.emplace_back(reshapeOp_3);
// 3. bias -> reshape to (1, 1, oc) so it broadcasts over the matmul output
if (with_bias) {
auto reshapeOp_4 = rewriter.create<top::ReshapeOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get(
{1, 1, oc},
llvm::cast<ShapedType>(bias.getType()).getElementType()),
ValueRange{bias});
operands.emplace_back(reshapeOp_4);
} else {
operands.emplace_back(bias);
}

attrs.clear();
attrs.emplace_back(
rewriter.getNamedAttr("right_transpose", rewriter.getBoolAttr(true)));
attrs.emplace_back(
rewriter.getNamedAttr("output_transpose", rewriter.getBoolAttr(false)));
// 4. matmul
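// With right_transpose set, the (oc, ic * kh * kw) filter acts as the
// transposed right operand:
// (n, oh * ow, ic * kh * kw) x (ic * kh * kw, oc) -> (n, oh * ow, oc).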
auto matmulOp = rewriter.create<top::MatMulOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()),
operands, attrs);
attrs.clear();
attrs.emplace_back(
rewriter.getNamedAttr("order", rewriter.getI64ArrayAttr({0, 2, 1})));
// 5. permute (n, oh * ow, oc) -> (n, oc, oh * ow)
auto perMuteOp_2 = rewriter.create<top::PermuteOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType()),
ValueRange{matmulOp}, attrs);
// 6. reshape the output to NCHW (n, oc, oh, ow)
auto reshapeOp_5 = rewriter.create<top::ReshapeOp>(
NameLoc::get(
rewriter.getStringAttr(loc_name + "_" + std::to_string(id++))),
RankedTensorType::get({n, oc, oh, ow}, outputType.getElementType()),
ValueRange{perMuteOp_2});
rewriter.replaceOp(convOp, ArrayRef<Value>{reshapeOp_5});
return success();
}
};
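
The numbered steps above are easiest to sanity-check outside the compiler. Below is a minimal standalone C++ sketch, independent of MLIR and tpu-mlir (all shapes and names are made up for illustration), that materializes the same reshape/permute/matmul chain for a bias-free patchify convolution and asserts it matches a direct convolution:

// Standalone sketch, not part of this commit: checks that for a patchify
// conv (stride == kernel, no padding/dilation/bias) the img2col rewrite
// reproduces the direct convolution. Build: g++ -std=c++11 sketch.cpp
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int n = 1, ic = 3, oc = 4, oh = 2, ow = 3, kh = 16, kw = 16;
  const int ih = oh * kh, iw = ow * kw;    // exact tiling, stride == kernel
  const int P = oh * ow, K = ic * kh * kw; // col-matrix rows and columns
  std::vector<float> in(n * ic * ih * iw), w(oc * K);
  for (size_t i = 0; i < in.size(); ++i) in[i] = std::sin(0.01f * float(i));
  for (size_t i = 0; i < w.size(); ++i) w[i] = std::cos(0.02f * float(i));

  // Step 1: view the NCHW input as (n, ic, oh, kh, ow, kw) and permute to
  // (n, oh, ow, ic, kh, kw); both steps are fused here into one gather
  // that fills the (n, oh*ow, ic*kh*kw) col matrix.
  std::vector<float> col((size_t)n * P * K);
  for (int b = 0; b < n; ++b)
    for (int y = 0; y < oh; ++y)
      for (int x = 0; x < ow; ++x)
        for (int c = 0; c < ic; ++c)
          for (int r = 0; r < kh; ++r)
            for (int s = 0; s < kw; ++s)
              col[((size_t)(b * P + y * ow + x) * ic + c) * kh * kw +
                  r * kw + s] =
                  in[((b * ic + c) * ih + y * kh + r) * iw + x * kw + s];

  // Steps 2 and 4: filter viewed as (oc, K); matmul with right_transpose,
  // i.e. out[b][p][o] = sum_k col[b][p][k] * w[o][k].
  std::vector<float> out((size_t)n * P * oc, 0.0f);
  for (int b = 0; b < n; ++b)
    for (int p = 0; p < P; ++p)
      for (int o = 0; o < oc; ++o)
        for (int k = 0; k < K; ++k)
          out[(b * P + p) * oc + o] +=
              col[((size_t)b * P + p) * K + k] * w[o * K + k];

  // Reference: direct stride==kernel convolution. Steps 5 and 6 of the
  // pattern (permute + reshape) correspond to the output indexing
  // out[b][y*ow + x][o] <-> NCHW result[b][o][y][x].
  for (int b = 0; b < n; ++b)
    for (int o = 0; o < oc; ++o)
      for (int y = 0; y < oh; ++y)
        for (int x = 0; x < ow; ++x) {
          float ref = 0.0f;
          for (int c = 0; c < ic; ++c)
            for (int r = 0; r < kh; ++r)
              for (int s = 0; s < kw; ++s)
                ref += in[((b * ic + c) * ih + y * kh + r) * iw +
                          x * kw + s] *
                       w[((o * ic + c) * kh + r) * kw + s];
          assert(std::fabs(ref - out[(b * P + y * ow + x) * oc + o]) < 1e-3f);
        }
  std::puts("img2col rewrite matches direct convolution");
  return 0;
}

The check holds exactly because, with stride == kernel and no padding, every input element belongs to exactly one patch, so the reshape in step 1 is a pure reinterpretation of the data.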

/* To reduce data movement, split the matmul
into multiple matmuls if the pattern below is matched:
@@ -1030,9 +1033,9 @@ class SplitMatMulPattern : public OpRewritePattern<top::MatMulOp> {
}
std::vector<int64_t> new_right_shape(right_shape);
new_right_shape[new_right_shape.size() - 1] = slice_width[idx];
auto new_filter =
top::WeightOp::create_float(op, "_filter_" + std::to_string(id),
*new_filter_f32, new_right_shape, storage_type);
auto new_filter = top::WeightOp::create_float(
op, "_filter_" + std::to_string(id), *new_filter_f32, new_right_shape,
storage_type);
operands.emplace_back(new_filter);

if (with_bias) {
@@ -1104,12 +1107,11 @@ namespace top {
using namespace bm1684x;
void populateOptimizeBM1684XPatterns(RewritePatternSet *patterns) {
patterns->add<MergeScale2Conv>(patterns->getContext(), /*PatternBenefit*/ 9);
patterns
->add<ConvertGLMTilePermute, ConvertMatMulWithRightTranspose,
ConvertMatMul2Attention, ReshapeReorderPattern,
ConvertMultiInputAdd, WhereBroadcastToTile,
SplitMatMulPattern, ConvertScaleOp, ConcatToSwapDimInner>(
patterns->getContext(), 8);
patterns->add<ConvertGLMTilePermute, ConvertMatMulWithRightTranspose,
ConvertMatMul2Attention, ReshapeReorderPattern,
ConvertMultiInputAdd, WhereBroadcastToTile,
ConvertConv2DToImg2Col, SplitMatMulPattern, ConvertScaleOp,
ConcatToSwapDimInner>(patterns->getContext(), 8);
}
} // namespace top
} // namespace tpu_mlir
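
For context, here is a sketch of how a pass would typically drive this pattern set through MLIR's greedy rewriter. The pass wrapper below is assumed for illustration and is not the actual tpu-mlir pass; only populateOptimizeBM1684XPatterns and the standard MLIR driver API are taken from the source.

// Hypothetical pass body; MyOptimizePass is an assumed name.
void MyOptimizePass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  top::populateOptimizeBM1684XPatterns(&patterns);
  // The greedy driver applies the patterns until fixpoint; a failure to
  // converge is surfaced as a pass failure.
  if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                          std::move(patterns))))
    signalPassFailure();
}

Because MergeScale2Conv is registered with PatternBenefit 9 while the others use 8, the driver tries it first whenever several patterns match the same op.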
