Skip to content

Commit

Permalink
[midend/lib/Conversion/ConvVectorization] fix some code style
Browse files Browse the repository at this point in the history
  • Loading branch information
FloatingcloudKnight committed Dec 27, 2024
1 parent 5266ebc commit fee59fd
Showing 1 changed file with 48 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
// Get Strides.
SmallVector<int64_t, 2> strides = {1, 1};
if (op->hasAttr("strides")) {
strides.clear();
for (auto value : op->getAttrOfType<mlir::DenseIntElementsAttr>("strides").getValues<int64_t>()) {
strides.push_back(value);
}
strides.clear();
for (auto value : op->getAttrOfType<mlir::DenseIntElementsAttr>("strides")
.getValues<int64_t>()) {
strides.push_back(value);
}
}
bool stride1 = strides[0] != 1;
bool stride2 = strides[1] != 1;
Expand All @@ -93,14 +94,17 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
// Get Dilations.
SmallVector<int64_t, 2> dilations = {1, 1};
if (op->hasAttr("dilations")) {
dilations.clear();
for (auto value : op->getAttrOfType<mlir::DenseIntElementsAttr>("dilations").getValues<int64_t>()) {
dilations.push_back(value);
}
dilations.clear();
for (auto value :
op->getAttrOfType<mlir::DenseIntElementsAttr>("dilations")
.getValues<int64_t>()) {
dilations.push_back(value);
}
}
bool dilated1 = dilations[0] != 1;
bool dilated2 = dilations[1] != 1;
Value dilHeight = rewriter.create<arith::ConstantIndexOp>(loc, dilations[0]);
Value dilHeight =
rewriter.create<arith::ConstantIndexOp>(loc, dilations[0]);
Value dilWidth = rewriter.create<arith::ConstantIndexOp>(loc, dilations[1]);

// Get i1 as the element type for mask vector.
Expand All @@ -115,7 +119,7 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
const Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
const Value c2 = rewriter.create<arith::ConstantIndexOp>(loc, 2);
const Value c3 = rewriter.create<arith::ConstantIndexOp>(loc, 3);
const Value vl_step = rewriter.create<arith::ConstantIndexOp>(loc, vecsize);
const Value vlStep = rewriter.create<arith::ConstantIndexOp>(loc, vecsize);
const Value zero =
buddy::insertZeroConstantOp(ctx, rewriter, loc, elementTy);

Expand All @@ -136,12 +140,11 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
Value width_o = rewriter.create<memref::DimOp>(loc, output, c2);

// Calculate the upper bound for vectorized processing
// - Subtract `vl_step` is to avoid overflow at the vectorization tail.
// - Subtract `vlStep` is to avoid overflow at the vectorization tail.
// - Add 1 to ensure the final loop runs when the workload length
// is divisible by the vector size.
Value upperBound_tmp =
rewriter.create<arith::SubIOp>(loc, channels, vl_step);
Value upperBound = rewriter.create<arith::AddIOp>(loc, upperBound_tmp, c1);
Value upperBoundTmp = rewriter.create<arith::SubIOp>(loc, channels, vlStep);
Value upperBound = rewriter.create<arith::AddIOp>(loc, upperBoundTmp, c1);

SmallVector<Value, 8> lowerBounds(4, c0);
SmallVector<Value, 8> uperBounds{batch, height_o, width_o, f_o};
Expand All @@ -150,20 +153,20 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
rewriter, loc, lowerBounds, uperBounds, steps,
[&](OpBuilder &builder, Location loc, ValueRange ivs) {
// Create strides variables.
Value tmp_ivs1 = ivs[1];
if(stride1){
tmp_ivs1 = builder.create<arith::MulIOp>(loc, ivs[1], strHeight);
Value tmpIvs1 = ivs[1];
if (stride1) {
tmpIvs1 = builder.create<arith::MulIOp>(loc, ivs[1], strHeight);
}
Value tmp_ivs2 = ivs[2];
if(stride2){
tmp_ivs2 = builder.create<arith::MulIOp>(loc, ivs[2], strWidth);
Value tmpIvs2 = ivs[2];
if (stride2) {
tmpIvs2 = builder.create<arith::MulIOp>(loc, ivs[2], strWidth);
}
Value tmp_result = builder.create<memref::LoadOp>(
loc, elementTy, output,
ValueRange{ivs[0], ivs[1], ivs[2], ivs[3]});
// Create vecsize mining loop.
auto iter_val = builder.create<scf::ForOp>(
loc, c0, upperBound, /*Step=*/vl_step, ValueRange{c0, tmp_result},
loc, c0, upperBound, /*Step=*/vlStep, ValueRange{c0, tmp_result},
[&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
ValueRange itrArgs) {
auto tmp0 = nestedBuilder.create<affine::AffineForOp>(
Expand All @@ -173,23 +176,27 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
[&](OpBuilder &builder, Location loc, Value iv0,
ValueRange itrArgs0) {
// Create dilated[0] variables.
Value tmp_ivs3 = iv0;
if(dilated1){
tmp_ivs3 = builder.create<arith::MulIOp>(loc, iv0, dilHeight);
Value tmpIvs3 = iv0;
if (dilated1) {
tmpIvs3 =
builder.create<arith::MulIOp>(loc, iv0, dilHeight);
}
Value inputHeight = builder.create<arith::AddIOp>(loc, tmp_ivs1, tmp_ivs3);
Value inputHeight =
builder.create<arith::AddIOp>(loc, tmpIvs1, tmpIvs3);
auto tmp1 = builder.create<affine::AffineForOp>(
loc, ValueRange{c0}, builder.getDimIdentityMap(),
ValueRange{width_k}, builder.getDimIdentityMap(),
/*Step=*/1, ValueRange{itrArgs0[0]},
[&](OpBuilder &builder, Location loc, Value iv1,
ValueRange itrArgs1) {
// Create dilated[1] variables.
Value tmp_ivs4 = iv1;
if(dilated2){
tmp_ivs4 = builder.create<arith::MulIOp>(loc, iv1, dilWidth);
Value tmpIvs4 = iv1;
if (dilated2) {
tmpIvs4 = builder.create<arith::MulIOp>(loc, iv1,
dilWidth);
}
Value inputWidth = builder.create<arith::AddIOp>(loc, tmp_ivs2, tmp_ivs4);
Value inputWidth = builder.create<arith::AddIOp>(
loc, tmpIvs2, tmpIvs4);
Value inputVector = builder.create<vector::LoadOp>(
loc, vectorTy, input,
ValueRange{ivs[0], inputHeight, inputWidth,
Expand Down Expand Up @@ -226,7 +233,7 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
nestedLoc, tmp1.getResult(0));
});
Value idx =
builder.create<arith::AddIOp>(loc, itrArgs[0], vl_step);
builder.create<arith::AddIOp>(loc, itrArgs[0], vlStep);
builder.create<scf::YieldOp>(
loc, ValueRange{idx, tmp0.getResult(0)});
});
Expand All @@ -250,25 +257,27 @@ class Conv2dNhwcFhwcVectorizationPattern : public ConversionPattern {
[&](OpBuilder &builder, Location loc, Value iv0,
ValueRange itrArgs0) {
// Create dilated[0] variables.
Value tmp_ivs3 = iv0;
if(dilated1){
tmp_ivs3 = builder.create<arith::MulIOp>(loc, iv0, dilHeight);
Value tmpIvs3 = iv0;
if (dilated1) {
tmpIvs3 =
builder.create<arith::MulIOp>(loc, iv0, dilHeight);
}
Value inputHeight =
builder.create<arith::AddIOp>(loc, tmp_ivs1, tmp_ivs3);
builder.create<arith::AddIOp>(loc, tmpIvs1, tmpIvs3);
auto tmp1 = builder.create<affine::AffineForOp>(
loc, ValueRange{c0}, builder.getDimIdentityMap(),
ValueRange{width_k}, builder.getDimIdentityMap(),
/*Step=*/1, ValueRange{itrArgs0[0]},
[&](OpBuilder &builder, Location loc, Value iv1,
ValueRange itrArgs1) {
// Create dilated[1] variables.
Value tmp_ivs4 = iv1;
if(dilated2){
tmp_ivs4 = builder.create<arith::MulIOp>(loc, iv1, dilWidth);
Value tmpIvs4 = iv1;
if (dilated2) {
tmpIvs4 = builder.create<arith::MulIOp>(loc, iv1,
dilWidth);
}
Value inputWidth =
builder.create<arith::AddIOp>(loc, tmp_ivs2, tmp_ivs4);
Value inputWidth = builder.create<arith::AddIOp>(
loc, tmpIvs2, tmpIvs4);
Value inputVec = builder.create<MaskedLoadOp>(
loc, vectorTy, input,
ValueRange{ivs[0], inputHeight, inputWidth,
Expand Down

0 comments on commit fee59fd

Please sign in to comment.