Use post-order traversal on blocks such that values used inside a loop, #971

Merged 1 commit on Nov 26, 2024

Changes from all commits
128 changes: 70 additions & 58 deletions lib/Transforms/VnniTransformation.cpp
@@ -348,6 +348,9 @@ static void applyVnniTransformOnResults(mlir::OpBuilder &builder,
// the op, and whether it is safe to apply vnni transform on operands too.
static void updateUnknownOp(mlir::OpBuilder &builder, mlir::Operation &op,
LayoutAnalysis &analysis) {
// Ignore ops that have the packed attribute, since they are inserted by the pass.
if (op.hasAttr("packed"))
return;
applyVnniTransformOnResults(builder, &op, analysis);
}
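For reference, the packed unit attribute checked here is the marker the pass attaches to every op it materializes, so they are skipped on later visits. A reduced illustration, mirroring the shape of the test added below:

// Ops emitted by the pass carry {packed} and are ignored by
// updateUnknownOp on subsequent visits.
%0 = vector.shape_cast %arg2 {packed} : vector<16x16xf16> to vector<256xf16>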

@@ -469,82 +472,84 @@ static void updateExtractStrideSliceOp(mlir::OpBuilder &builder,
}
}

static void handleBranchOpInterface(mlir::OpBuilder &builder,
mlir::Block &block,
mlir::RegionBranchOpInterface branch,
mlir::TypeRange argsTypes) {
builder.setInsertionPointToStart(&block);
// Handle terminator ops, e.g., scf.yield. Update
// the types of their successor inputs if the successor
// operands need the vnni format.
static void handleBranchTerminatorOpInterface(
mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
LayoutAnalysis &analysis) {

if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
return;

mlir::Operation *op = branch.getOperation();
llvm::SmallVector<mlir::RegionSuccessor> successors;
llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
branch.getEntrySuccessorRegions(operands, successors);
llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
nullptr);
terminator.getSuccessorRegions(operands, successors);

for (mlir::RegionSuccessor &successor : successors) {
if (block.getParent() != successor.getSuccessor())
if (!successor.isParent())
continue;

mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
mlir::ValueRange inputs = successor.getSuccessorInputs();
for (auto [arg, input] : llvm::zip(operands, inputs)) {
auto idx = mlir::cast<mlir::BlockArgument>(input).getArgNumber();
mlir::Type dstType = argsTypes[idx];
if (dstType == arg.getType()) {
input.setType(dstType);
continue;
} else {
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
for (auto [arg, inp] : llvm::zip(operands, inputs)) {
if (analysis.getLayout(arg)) {
auto vecTy = mlir::cast<mlir::VectorType>(arg.getType());
auto packedTy = getPackedType(vecTy);
inp.setType(packedTy);
}
}
}
}
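As a reduced illustration of the terminator handling (hypothetical IR, not part of this PR's tests): when the layout analysis marks a yielded operand as needing the vnni layout, the matching successor input of the parent region (here, the scf.if result) is retyped to the packed form returned by getPackedType.

// Before: the analysis says the yielded vector needs vnni packing.
%r = scf.if %cond -> (vector<16x16xf16>) {
  ...
  scf.yield %v : vector<16x16xf16>
}
// After: the parent successor input (the scf.if result) carries the
// packed type; the region body is rewritten to match elsewhere.
%r = scf.if %cond -> (vector<8x16x2xf16>) {
  ...
  scf.yield %v : vector<8x16x2xf16>
}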

auto terminator = mlir::cast<mlir::RegionBranchTerminatorOpInterface>(
block.getTerminator());
mlir::SmallVector<mlir::Attribute> operandAttributes(
terminator->getNumOperands(), nullptr);

successors.clear();
terminator.getSuccessorRegions(operandAttributes, successors);
// Handle RegionBranchOps, e.g., scf.for. Update the
// region argument types. If an argument needs to be
// in vnni format but the initArg is not, a vnni
// transform is applied on the initArg.
static void handleBranchOpInterface(mlir::OpBuilder &builder,
mlir::RegionBranchOpInterface branch,
LayoutAnalysis &analysis) {
mlir::Operation *op = branch.getOperation();
llvm::SmallVector<mlir::RegionSuccessor> successors;
llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
branch.getEntrySuccessorRegions(operands, successors);

for (const mlir::RegionSuccessor &successor : successors) {
if (!successor.isParent())
for (mlir::RegionSuccessor &successor : successors) {
if (successor.isParent())
continue;

mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
mlir::ValueRange inputs = successor.getSuccessorInputs();
mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
for (auto [operand, input] : llvm::zip(operands, inputs)) {
input.setType(operand.getType());

for (auto [arg, input] : llvm::zip(operands, inputs)) {
if (analysis.getLayout(input)) {
auto vecTy = mlir::cast<mlir::VectorType>(input.getType());
auto packedTy = getPackedType(vecTy);
input.setType(packedTy);
if (!analysis.getLayout(arg)) {
builder.setInsertionPointAfterValue(arg);
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
}
}
}
}
}
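A hypothetical before/after sketch of this entry-successor handling (names and bounds invented for illustration): the loop-carried argument needs the vnni layout, but the init value does not yet carry it, so the pass packs the init value and retypes the block argument.

// Before: %init flows in unpacked while the analysis wants %acc packed.
%res = scf.for %i = %c0 to %c8 step %c1
    iter_args(%acc = %init) -> (vector<16x16xf16>) { ... }
// After: a vnni transform (shape_cast/shuffle/shape_cast, all tagged
// {packed}) is applied to %init and the block argument is retyped.
%res = scf.for %i = %c0 to %c8 step %c1
    iter_args(%acc = %packed) -> (vector<8x16x2xf16>) { ... }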

static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
LayoutAnalysis &analysis) {
if (auto iface = mlir::dyn_cast_if_present<mlir::RegionBranchOpInterface>(
block.getParentOp())) {
llvm::SmallVector<mlir::Type> types;
for (auto arg : block.getArguments()) {
auto argTy = arg.getType();
if (!analysis.getLayout(arg)) {
types.push_back(argTy);
} else {
auto vecTy = mlir::cast<mlir::VectorType>(argTy);
auto packedTy = getPackedType(vecTy);
types.push_back(packedTy);
if (!mlir::isa<mlir::RegionBranchOpInterface>(block.getParentOp())) {
builder.setInsertionPointToStart(&block);
for (auto &&arg : block.getArguments()) {
if (analysis.getLayout(arg)) {
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
}
}
return handleBranchOpInterface(builder, block, iface, types);
}

builder.setInsertionPointToStart(&block);
for (auto &&arg : block.getArguments()) {
if (analysis.getLayout(arg)) {
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
}
}
}
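The remaining (non-region-branch) case packs block arguments in place, which is what the new test checks for the function argument %arg2. Roughly (illustrative names; the full shuffle mask is elided):

// Block argument of e.g. a func.func entry block that needs packing:
// all uses of %arg are rewired to %p except the transform's own root cast.
%c = vector.shape_cast %arg {packed} : vector<16x16xf16> to vector<256xf16>
%s = vector.shuffle %c, %c [0, 16, 1, 17, ...] {packed} : vector<256xf16>, vector<256xf16>
%p = vector.shape_cast %s {packed} : vector<256xf16> to vector<8x16x2xf16>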

@@ -561,14 +566,21 @@ struct VnniTransformationPass final

mlir::OpBuilder builder(&getContext());
llvm::SmallVector<mlir::Type> operands;
op->walk<mlir::WalkOrder::PreOrder>([&](mlir::Block *block) {
// Process ops in post-order so that the layout info is
// used before being destroyed.
op->walk([&](mlir::Block *block) {
// Iterate block ops in reverse so an op is updated before its operands.
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
// Ignore shape casts as they are generated by the conversion itself.
// Ignore RegionBranchOpInterface as it is handled in `updateBlockTypes`.
if (mlir::isa<mlir::vector::ShapeCastOp, mlir::RegionBranchOpInterface,
mlir::RegionBranchTerminatorOpInterface>(op))
if (auto terminator =
mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
handleBranchTerminatorOpInterface(builder, terminator, analysis);
continue;
}

if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
handleBranchOpInterface(builder, iface, analysis);
continue;
}

if (auto dpas = mlir::dyn_cast<mlir::xegpu::DpasOp>(op)) {
updateDpasOp(builder, dpas, analysis);
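The reverse iteration referred to in the walk above can be seen on a straight-line block (hypothetical IR, op shapes borrowed from the test below): the dpas user is visited before the loads that define its operands, so the layouts recorded for %a and %b are still intact when the dpas is rewritten.

%a = xegpu.load_nd %ta : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%b = xegpu.load_nd %tb : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
// Visited first in the reverse iteration:
%d = xegpu.dpas %a, %b : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>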
32 changes: 31 additions & 1 deletion test/Transforms/VnniTransform/unit-tests.mlir
@@ -374,4 +374,34 @@ func.func @test(%arg1 : !xegpu.tensor_desc<8x16xi16>, %arg2 : !xegpu.tensor_desc
%1 = arith.bitcast %b : vector<16x16xi16> to vector<16x16xf16>
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
return %2 : vector<8x16xf32>
}
}

// -----

//CHECK-LABEL: @test
// CHECK-SAME: (%[[arg0:.*]]: !xegpu.tensor_desc<8x16xf16>, %[[arg1:.*]]: !xegpu.tensor_desc<16x16xf16>, %[[arg2:.*]]: vector<16x16xf16>, %[[arg3:.*]]: i1) -> vector<8x16xf32> {
func.func @test(%arg1 : !xegpu.tensor_desc<8x16xf16>, %arg2 : !xegpu.tensor_desc<16x16xf16>, %arg3 : vector<16x16xf16>, %arg4 : i1) -> vector<8x16xf32> {
//CHECK: %[[r0:.*]] = vector.shape_cast %[[arg2]] {packed} : vector<16x16xf16> to vector<256xf16>
//CHECK: %[[r1:.*]] = vector.shuffle %[[r0]], %[[r0]] [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16>
//CHECK: %[[r2:.*]] = vector.shape_cast %[[r1]] {packed} : vector<256xf16> to vector<8x16x2xf16>
//CHECK: %[[r3:.*]] = xegpu.load_nd %[[arg0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%0 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
//CHECK: %[[r4:.*]] = scf.if %[[arg3]] -> (vector<8x16xf32>)
%1 = scf.if %arg4 -> (vector<8x16xf32>) {
//CHECK: %[[r5:.*]] = xegpu.load_nd %[[arg1]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
%2 = xegpu.load_nd %arg2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
//CHECK: %[[r6:.*]] = arith.addf %[[r5]], %[[r2]] : vector<8x16x2xf16>
%3 = arith.addf %2, %arg3 : vector<16x16xf16>
//CHECK: %[[r7:.*]] = xegpu.dpas %[[r3]], %[[r6]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
//CHECK: scf.yield %[[r7]] : vector<8x16xf32>
scf.yield %4 : vector<8x16xf32>
} else {
//CHECK: %[[r5:.*]] = xegpu.dpas %[[r3]], %[[r2]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%5 = xegpu.dpas %0, %arg3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
//CHECK: scf.yield %[[r5]] : vector<8x16xf32>
scf.yield %5 : vector<8x16xf32>
}
//CHECK: return %[[r4]] : vector<8x16xf32>
return %1 : vector<8x16xf32>
}