Skip to content

Commit

Permalink
[aievec] Add missing fixes for bump (#1456)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain authored May 3, 2024
1 parent 7b22ed2 commit 5419d5f
Showing 1 changed file with 99 additions and 36 deletions.
135 changes: 99 additions & 36 deletions lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2179,6 +2179,33 @@ struct LowerTruncOpPattern : OpConversionPattern<SrcOpTy> {
using LowerTruncFOpPattern = LowerTruncOpPattern<arith::TruncFOp>;
using LowerTruncIOpPattern = LowerTruncOpPattern<arith::TruncIOp>;

// If `op` is the last operation in the sequence:
// %0 = unrealized_conversion_cast <%IN> : <native type>, !emitc.opaque_type
// %1 = emitc.call_opaque <funcName>, %0...
// %2 = unrealized_conversion_cast %1 : !emitc.opaque_type, <native type>
// return the value <%IN>.
static std::optional<Value>
getUnOpaquedOperandOfEmitCOpaqueCallOp(Operation *op, StringRef funcName) {
auto uccOp = dyn_cast<UnrealizedConversionCastOp>(op);
if (!uccOp)
return {};

auto inVal = uccOp.getInputs()[0];
if (!isa<emitc::OpaqueType>(inVal.getType()))
return {};

auto callOp = inVal.getDefiningOp<emitc::CallOpaqueOp>();
if (callOp.getCallee() != funcName)
return {};

auto callOperandsUccOp =
callOp.getOperands()[0].getDefiningOp<UnrealizedConversionCastOp>();
if (!callOperandsUccOp)
return {};

return callOperandsUccOp.getInputs()[0];
}

// Check there is an operation chain like-
//
// %cst_0 = arith.constant dense<1.000000e+00> : vector<16xbf16>
Expand Down Expand Up @@ -2253,14 +2280,21 @@ static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
if (!addLvalOp || !addRvalOp)
return false;

if (!((isa<math::ExpOp>(addLvalOp) && isa<arith::ConstantOp>(addRvalOp)) ||
(isa<math::ExpOp>(addRvalOp) && isa<arith::ConstantOp>(addLvalOp)) ||
(isa<emitc::CallOpaqueOp>(addLvalOp) &&
cast<emitc::CallOpaqueOp>(addLvalOp).getCallee() == "getExpBf16" &&
isa<arith::ConstantOp>(addRvalOp)) ||
(isa<emitc::CallOpaqueOp>(addRvalOp) &&
cast<emitc::CallOpaqueOp>(addRvalOp).getCallee() == "getExpBf16" &&
isa<arith::ConstantOp>(addLvalOp))))
auto addLvalExpOp = dyn_cast<math::ExpOp>(addLvalOp);
auto addRvalExpOp = dyn_cast<math::ExpOp>(addRvalOp);
auto addLvalExpOpIn =
getUnOpaquedOperandOfEmitCOpaqueCallOp(addLvalOp, "getExpBf16")
.value_or(nullptr);
auto addRvalExpOpIn =
getUnOpaquedOperandOfEmitCOpaqueCallOp(addRvalOp, "getExpBf16")
.value_or(nullptr);
if (!addLvalExpOpIn && addLvalExpOp)
addLvalExpOpIn = addLvalExpOp.getOperand();
if (!addRvalExpOpIn && addRvalExpOp)
addRvalExpOpIn = addRvalExpOp.getOperand();

if (!((addLvalExpOpIn && isa<arith::ConstantOp>(addRvalOp)) ||
(addRvalExpOpIn && isa<arith::ConstantOp>(addLvalOp))))
return false;

constOp = isa<arith::ConstantOp>(addLvalOp)
Expand All @@ -2273,24 +2307,11 @@ static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
return false;

auto expOp = isa<math::ExpOp>(addLvalOp)
? cast<math::ExpOp>(addLvalOp)
: (isa<emitc::CallOpaqueOp>(addLvalOp)
? cast<emitc::CallOpaqueOp>(addLvalOp)
: (isa<math::ExpOp>(addRvalOp)
? cast<math::ExpOp>(addRvalOp)
: cast<emitc::CallOpaqueOp>(addRvalOp)));

auto expOperand =
isa<math::ExpOp>(expOp)
? cast<math::ExpOp>(expOp).getOperand()
: *(cast<emitc::CallOpaqueOp>(expOp).getOperands().begin());
negOp = dyn_cast<arith::NegFOp>(expOperand.getDefiningOp());

if (!negOp)
return false;
auto expOperand = addLvalExpOpIn ? addLvalExpOpIn : addRvalExpOpIn;

return true;
negOp = expOperand.getDefiningOp<arith::NegFOp>();

return negOp != nullptr;
}

// Convert the operation chain like-
Expand Down Expand Up @@ -2336,10 +2357,24 @@ struct ComputeSigmoidOpPattern : OpConversionPattern<arith::DivFOp> {
rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);

rewriter.setInsertionPoint(divfOp);
SmallVector<Value> sigmoidOperands = {negOp.getOperand()};
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
divfOp, TypeRange{adaptor.getLhs().getType()}, "getSigmoidBf16",
nullptr, nullptr, sigmoidOperands);
Type vecOpaqueTy;
if (laneSize == 16)
vecOpaqueTy =
emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
else
vecOpaqueTy =
emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
auto opaquedOperand =
rewriter
.create<UnrealizedConversionCastOp>(divfOp.getLoc(), vecOpaqueTy,
negOp.getOperand())
.getResult(0);
SmallVector<Value> sigmoidOperands = {opaquedOperand};
auto callOp = rewriter.create<emitc::CallOpaqueOp>(
divfOp.getLoc(), TypeRange{vecOpaqueTy}, "getSigmoidBf16", nullptr,
nullptr, sigmoidOperands);
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
divfOp, TypeRange{adaptor.getLhs().getType()}, callOp.getResults());

return success();
}
Expand Down Expand Up @@ -2372,10 +2407,24 @@ struct ComputeCeilOpPattern : OpConversionPattern<math::CeilOp> {
rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);

rewriter.setInsertionPoint(ceilOp);
SmallVector<Value> ceilOperands = {adaptor.getOperand()};
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
ceilOp, TypeRange{ceilOp.getResult().getType()}, "getCeilBf16", nullptr,
Type vecOpaqueTy;
if (laneSize == 16)
vecOpaqueTy =
emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
else
vecOpaqueTy =
emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
auto opaquedOperand =
rewriter
.create<UnrealizedConversionCastOp>(ceilOp.getLoc(), vecOpaqueTy,
adaptor.getOperand())
.getResult(0);
SmallVector<Value> ceilOperands = {opaquedOperand};
auto callOp = rewriter.create<emitc::CallOpaqueOp>(
ceilOp.getLoc(), TypeRange{vecOpaqueTy}, "getCeilBf16", nullptr,
nullptr, ceilOperands);
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
ceilOp, TypeRange{ceilOp.getResult().getType()}, callOp.getResults());

return success();
}
Expand Down Expand Up @@ -2408,10 +2457,24 @@ struct ComputeFloorOpPattern : OpConversionPattern<math::FloorOp> {
rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);

rewriter.setInsertionPoint(floorOp);
SmallVector<Value> floorOperands = {adaptor.getOperand()};
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
floorOp, TypeRange{floorOp.getResult().getType()}, "getFloorBf16",
nullptr, nullptr, floorOperands);
Type vecOpaqueTy;
if (laneSize == 16)
vecOpaqueTy =
emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
else
vecOpaqueTy =
emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
auto opaquedOperand =
rewriter
.create<UnrealizedConversionCastOp>(floorOp.getLoc(), vecOpaqueTy,
adaptor.getOperand())
.getResult(0);
SmallVector<Value> floorOperands = {opaquedOperand};
auto callOp = rewriter.create<emitc::CallOpaqueOp>(
floorOp.getLoc(), TypeRange{vecOpaqueTy}, "getFloorBf16", nullptr,
nullptr, floorOperands);
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
floorOp, TypeRange{floorOp.getResult().getType()}, callOp.getResults());

return success();
}
Expand Down

0 comments on commit 5419d5f

Please sign in to comment.