Skip to content

Commit

Permalink
Refactor conversion of aievec.mul_elem to support combined precision (#…
Browse files Browse the repository at this point in the history
…643)

* Refactor AIE-ML acc datatype emission
* Refactor arith.muli/mulf to aievec.mul_elem conversion pattern to make it extensible and clean
  - Reorganize the existing case-by-case patterns and decouple the pattern that requires two inputs to be the same type
  - Make it a cleaner pattern considering lhs/rhs/out datatype
  - Verified that all the dut.cc are identical before/after the refactor
* Add convertValueToTargetTypeAieML() which can be helpful for handling the vector lane mismatch issue later on.
* Add CPP emission for aievec.unpack op
* Add VectorToAIEVec lit tests to cover the lowering patterns
* Add new combined precision tosa tests for element-wise multiply:
  - i8xi16_mul_elem_v32 (out=i32, lane=32) (cycle count=144, PM=272), PASS
  - i8xi16_mul_elem_v16 (out=i32, lane=16) (cycle count=792, PM=368), XFAIL
    - No intent to work on this at the moment, but keep a record there
  - i16xi32_mul_elem (out=i32, lane=16) (cycle count=408, PM=384), PASS
  - i8xi32_mul_elem (out=i32, lane=16) (cycle count=728, PM=368), PASS
  • Loading branch information
jamestcl-amd authored Sep 19, 2023
1 parent bb7653e commit 679c3ce
Show file tree
Hide file tree
Showing 20 changed files with 697 additions and 177 deletions.
292 changes: 170 additions & 122 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,34 +87,78 @@ extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
return {};
}

// Create MulElemOp for i8 and bf16 types in aie-ml. The corresponding intrinsic
// is mul_elem_16_2, which indicates that we need to concatenate zero vectors
// for both mul operands before creating MulElemOp.
static aievec::MulElemOp createMulElemAieML(ConversionPatternRewriter &rewriter,
Value lval, Value rval,
VectorType srcType,
unsigned bitWidth, Location loc) {
Type accType = getVectorOpDestType(srcType, /*AIEML =*/true);
VectorType vecType =
createVectorType(512 / bitWidth, srcType.getElementType());

arith::ConstantOp zeroConstOp = nullptr;
zeroConstOp = rewriter.create<arith::ConstantOp>(
loc, srcType.getElementType(),
rewriter.getZeroAttr(srcType.getElementType()));
auto broadcastZeroOp = rewriter.create<aievec::BroadcastScalarOp>(
loc, vecType, zeroConstOp->getResult(0));
auto extOp = rewriter.create<aievec::ExtOp>(loc, srcType,
broadcastZeroOp.getResult(), 0);

SmallVector<Value> lSources = {lval, extOp->getResult(0)};
SmallVector<Value> rSources = {rval, extOp->getResult(0)};
auto lConcatOp = rewriter.create<aievec::ConcatOp>(loc, vecType, lSources);
auto rConcatOp = rewriter.create<aievec::ConcatOp>(loc, vecType, rSources);

auto mulElemOp = rewriter.create<aievec::MulElemOp>(
loc, accType, lConcatOp->getResult(0), rConcatOp->getResult(0));
return mulElemOp;
// Convert a input value to a target vector type. This function can insert
// multiple aievec ops depending on the combination of input and output vector
// types.
static std::optional<Value>
convertValueToTargetTypeAieML(ConversionPatternRewriter &rewriter, Location loc,
Value inputVal, VectorType tgtType) {
VectorType srcType = cast<VectorType>(inputVal.getType());
auto srcElemType = srcType.getElementType();
unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
unsigned srcLaneSize = getVectorLaneSize(srcType);

auto tgtElemType = tgtType.getElementType();
unsigned tgtBitWidth = tgtElemType.getIntOrFloatBitWidth();
unsigned tgtLaneSize = getVectorLaneSize(tgtType);

if (srcType == tgtType)
return inputVal;

if ((srcElemType == tgtElemType) && (srcLaneSize != tgtLaneSize)) {
// TODO: relax the condition below?
if ((srcLaneSize == 16 && tgtLaneSize == 32 &&
isa<FloatType>(srcElemType)) ||
(srcLaneSize == 32 && tgtLaneSize == 64 &&
isa<IntegerType>(srcElemType))) {
auto zeroConstOp = rewriter.create<arith::ConstantOp>(
loc, srcType.getElementType(),
rewriter.getZeroAttr(srcType.getElementType()));
auto broadcastZeroOp = rewriter.create<aievec::BroadcastScalarOp>(
loc, tgtType, zeroConstOp->getResult(0));
auto extOp = rewriter.create<aievec::ExtOp>(
loc, srcType, broadcastZeroOp->getResult(0), 0);

SmallVector<Value> inputSources = {inputVal, extOp->getResult(0)};
aievec::ConcatOp concatOp =
rewriter.create<aievec::ConcatOp>(loc, tgtType, inputSources);

return concatOp.getResult();
}
} else if ((srcElemType != tgtElemType) && (srcLaneSize == tgtLaneSize) &&
isa<IntegerType>(srcElemType) && isa<IntegerType>(tgtElemType)) {
if (srcBitWidth == 16 && tgtBitWidth == 32 && srcLaneSize == 16) {
// Case 1: vector<16xi16> to vector<16xi32> conversion by aievec.ups +
// aievec.cast
auto accType = getVectorOpDestType(srcType, /*AIEML =*/true);
auto upsOp = rewriter.create<aievec::UPSOp>(loc, accType, inputVal);
auto castOp = rewriter.create<aievec::CastOp>(
loc, tgtType, upsOp.getResult(), /*isResAcc*/ false);
return castOp.getResult();
} else if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) {
// Case 2: vector<16xi8> to vector<16xi32> conversion by aievec.concat +
// aievec.ups + aievec.cast + aievec.ext
// FIXME: Should use undef_xxx() for the second input of concat
auto concatOutType = createVectorType(32, srcElemType);
auto concatOp = rewriter.create<aievec::ConcatOp>(
loc, concatOutType, SmallVector<Value>({inputVal, inputVal}));
auto accType = getVectorOpDestType(concatOutType, /*AIEML =*/true);
auto upsOp =
rewriter.create<aievec::UPSOp>(loc, accType, concatOp.getResult());
auto castType = createVectorType(32, tgtElemType);
auto castOp = rewriter.create<aievec::CastOp>(
loc, castType, upsOp.getResult(), /*isResAcc*/ false);
auto extOp =
rewriter.create<aievec::ExtOp>(loc, tgtType, castOp.getResult(), 0);
return extOp.getResult();
} else if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) {
// Case 3: vector<32xi8> to vector<32xi16> conversion by aievec.unpack
auto unpackOp = rewriter.create<aievec::UnpackOp>(loc, tgtType, inputVal);
return unpackOp.getResult();
}
}

return std::nullopt;
}

// Return the list of attributes that configure an `aievec.select` op to
Expand Down Expand Up @@ -546,8 +590,8 @@ struct ConvertMulFToAIEVecMulElemOpPattern
if (!resultType)
return failure();

// FIXME: Verify it is not a part of FMA
auto isAddOp = [&](Operation *op) { return isa<arith::AddFOp>(op); };
// Verify it is not a part of FMA
if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
return failure();

Expand All @@ -560,42 +604,68 @@ struct ConvertMulFToAIEVecMulElemOpPattern
if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
return failure();

aievec::MulElemOp mulElemOp = nullptr;

if (resultElWidth == 16) {
mulElemOp =
createMulElemAieML(rewriter, adaptor.getLhs(), adaptor.getRhs(),
resultType, resultElWidth, mulOp.getLoc());
rewriter.replaceOpWithNewOp<aievec::SRSOp>(
mulOp, resultType, mulElemOp.getResult(), shiftParam);
// Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
auto lval = adaptor.getLhs();
auto rval = adaptor.getRhs();
if (auto lvalExtOp = lval.getDefiningOp<arith::ExtFOp>()) {
lval = lvalExtOp->getOperand(0);
}
if (auto rvalExtOp = rval.getDefiningOp<arith::ExtFOp>()) {
rval = rvalExtOp->getOperand(0);
}
VectorType lSrcType = cast<VectorType>(lval.getType());
VectorType rSrcType = cast<VectorType>(rval.getType());
unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
Type accType = getVectorOpDestType(lSrcType, /*AIEML =*/true);
if (rBitWidth > lBitWidth) {
accType = getVectorOpDestType(rSrcType, /*AIEML =*/true);
}
// Only support the same lhs/rhs type at the moment
if (lSrcType != rSrcType) {
return failure();
}
// Only support two bfloat16 inputs at the moment
if (lBitWidth != 16 || rBitWidth != 16) {
return failure();
}
// float type
else {
auto lhs = dyn_cast<arith::ExtFOp>(adaptor.getLhs().getDefiningOp());
auto rhs = dyn_cast<arith::ExtFOp>(adaptor.getRhs().getDefiningOp());

if (!lhs || !rhs)
return failure();

auto lval = lhs->getOperand(0);
auto rval = rhs->getOperand(0);

VectorType lSrcType = cast<VectorType>(lval.getType());
VectorType rSrcType = cast<VectorType>(rval.getType());
// Prepare lhr/rhs for the aievec.mul_elem op
VectorType targetInputType =
createVectorType(512 / lBitWidth, lSrcType.getElementType());
if (rBitWidth > lBitWidth) {
targetInputType =
createVectorType(512 / rBitWidth, rSrcType.getElementType());
}
auto lValConverted = convertValueToTargetTypeAieML(rewriter, mulOp.getLoc(),
lval, targetInputType);
auto rValConverted = convertValueToTargetTypeAieML(rewriter, mulOp.getLoc(),
rval, targetInputType);
if (!lValConverted || !rValConverted)
return failure();

unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
// Create an aievec.mul_elem op
aievec::MulElemOp mulElemOp = rewriter.create<aievec::MulElemOp>(
mulOp.getLoc(), accType, *lValConverted, *rValConverted);

if (lBitWidth != 16 || rBitWidth != 16)
return failure();
// Create an aievec.cast or an aievec.srs op
auto mulElemResultType = mulElemOp.getType();
auto mulElemResultElWidth =
mulElemResultType.getElementType().getIntOrFloatBitWidth();

mulElemOp = createMulElemAieML(rewriter, lval, rval, lSrcType, lBitWidth,
mulOp.getLoc());
if (mulElemResultElWidth == resultElWidth) {
rewriter.replaceOpWithNewOp<aievec::CastOp>(
mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
} else if (mulElemResultElWidth > resultElWidth) {
rewriter.replaceOpWithNewOp<aievec::SRSOp>(
mulOp, resultType, mulElemOp.getResult(), shiftParam);
} else {
return failure();
}

return success();
}

unsigned shiftParam;
};

Expand All @@ -617,8 +687,8 @@ struct ConvertMulIToAIEVecMulElemOpPattern
if (!resultType)
return failure();

// FIXME: Verify it is not a part of MAC
auto isAddOp = [&](Operation *op) { return isa<arith::AddIOp>(op); };
// Verify it is not a part of MAC
if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
return failure();

Expand All @@ -631,79 +701,57 @@ struct ConvertMulIToAIEVecMulElemOpPattern
((laneSize != 16 && laneSize != 32) || resultElWidth != 32))
return failure();

// Deal with the case with sext op for i8 and i16:
// Case 1:
// Transfer -
// %1 = arith.extsi %a : vector<32xi8> to vector<32xi32>
// %2 = arith.extsi %b : vector<32xi8> to vector<32xi32>
// %3 = arith.muli %1, %2 : vector<32xi32>
// to -
// aievec.mul_elem(%a, %b) : vector<64xi8>, vector<64xi8>, vector<32xi32>
//
// Case 2:
// Transfer -
// %1 = arith.extsi %a : vector<32xi16> to vector<32xi32>
// %2 = arith.extsi %b : vector<32xi16> to vector<32xi32>
// %3 = arith.muli %1, %2 : vector<32xi32>
// to -
// aievec.mul_elem(%a, %b) : vector<32xi16>, vector<32xi16>, vector<32xi32>
if (laneSize == 32 && (resultElWidth == 32 || resultElWidth == 8)) {
if (resultElWidth == 32) {
auto lhs = dyn_cast<arith::ExtSIOp>(adaptor.getLhs().getDefiningOp());
auto rhs = dyn_cast<arith::ExtSIOp>(adaptor.getRhs().getDefiningOp());

if (!lhs || !rhs)
return failure();

auto lval = lhs->getOperand(0);
auto rval = rhs->getOperand(0);

VectorType lSrcType = cast<VectorType>(lval.getType());
VectorType rSrcType = cast<VectorType>(rval.getType());
// Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
auto lval = adaptor.getLhs();
auto rval = adaptor.getRhs();
if (auto lvalExtOp = lval.getDefiningOp<arith::ExtSIOp>()) {
lval = lvalExtOp->getOperand(0);
}
if (auto rvalExtOp = rval.getDefiningOp<arith::ExtSIOp>()) {
rval = rvalExtOp->getOperand(0);
}
VectorType lSrcType = cast<VectorType>(lval.getType());
VectorType rSrcType = cast<VectorType>(rval.getType());
unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
Type accType = getVectorOpDestType(lSrcType, /*AIEML =*/true);
if (rBitWidth > lBitWidth) {
accType = getVectorOpDestType(rSrcType, /*AIEML =*/true);
}

unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
// Prepare lhr/rhs for the aievec.mul_elem op
VectorType targetInputType =
createVectorType(512 / lBitWidth, lSrcType.getElementType());
if (rBitWidth > lBitWidth) {
targetInputType =
createVectorType(512 / rBitWidth, rSrcType.getElementType());
}
auto lValConverted = convertValueToTargetTypeAieML(rewriter, mulOp.getLoc(),
lval, targetInputType);
auto rValConverted = convertValueToTargetTypeAieML(rewriter, mulOp.getLoc(),
rval, targetInputType);
if (!lValConverted || !rValConverted)
return failure();

if ((lBitWidth != 8 || rBitWidth != 8) &&
(lBitWidth != 16 || rBitWidth != 16))
return failure();
// Create an aievec.mul_elem op
aievec::MulElemOp mulElemOp = rewriter.create<aievec::MulElemOp>(
mulOp.getLoc(), accType, *lValConverted, *rValConverted);

aievec::MulElemOp mulElemOp = nullptr;
if (lBitWidth == 8) {
mulElemOp = createMulElemAieML(rewriter, lval, rval, lSrcType,
lBitWidth, mulOp.getLoc());
} else {
Type accType = getVectorOpDestType(lSrcType, /*AIEML =*/true);
mulElemOp = rewriter.create<aievec::MulElemOp>(mulOp.getLoc(),
accType, lval, rval);
}
rewriter.replaceOpWithNewOp<aievec::CastOp>(
mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
// Case 3:
// Transfer -
// %1 = arith muli %a, %b : vector<32xi8>
// to -
// aievec.mul_elem(%a, %b) : vector<64xi8>, vector<64xi8>,
// vector<32xi32>
} else {
auto lval = adaptor.getLhs();
auto rval = adaptor.getRhs();
VectorType srcType = cast<VectorType>(lval.getType());
unsigned bitWidth = srcType.getElementType().getIntOrFloatBitWidth();
auto mulElemOp = createMulElemAieML(rewriter, lval, rval, srcType,
bitWidth, mulOp.getLoc());
rewriter.replaceOpWithNewOp<aievec::SRSOp>(
mulOp, srcType, mulElemOp.getResult(), shiftParam);
}
} else {
Type accType = getVectorOpDestType(cast<VectorType>(mulOp.getType()),
/*AIEML =*/true);
// Create an aievec.cast or an aievec.srs op
auto mulElemResultType = mulElemOp.getType();
auto mulElemResultElWidth =
mulElemResultType.getElementType().getIntOrFloatBitWidth();

auto mulElemOp = rewriter.create<aievec::MulElemOp>(
mulOp.getLoc(), accType, adaptor.getLhs(), adaptor.getRhs());
if (mulElemResultElWidth == resultElWidth) {
rewriter.replaceOpWithNewOp<aievec::CastOp>(
mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
} else if (mulElemResultElWidth > resultElWidth) {
rewriter.replaceOpWithNewOp<aievec::SRSOp>(
mulOp, resultType, mulElemOp.getResult(), shiftParam);
} else {
return failure();
}

return success();
}

Expand Down
Loading

0 comments on commit 679c3ce

Please sign in to comment.