From 732c13183896b0d47d8373f21f2b0b40b1c96aba Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 26 Nov 2024 21:49:48 +0000 Subject: [PATCH] Use post order traverse on blocks such that values used inside a loop, but defined outside the loop can be handled correctly. --- lib/Transforms/VnniTransformation.cpp | 128 ++++++++++-------- test/Transforms/VnniTransform/unit-tests.mlir | 32 ++++- 2 files changed, 101 insertions(+), 59 deletions(-) diff --git a/lib/Transforms/VnniTransformation.cpp b/lib/Transforms/VnniTransformation.cpp index 6f241b867..d1b8af067 100644 --- a/lib/Transforms/VnniTransformation.cpp +++ b/lib/Transforms/VnniTransformation.cpp @@ -348,6 +348,9 @@ static void applyVnniTransformOnResults(mlir::OpBuilder &builder, // the op, and whether it is safe to apply vnni transform on operands too. static void updateUnknownOp(mlir::OpBuilder &builder, mlir::Operation &op, LayoutAnalysis &analysis) { + // Ignore ops with the packed attribute, since they are inserted by the pass. + if (op.hasAttr("packed")) + return; applyVnniTransformOnResults(builder, &op, analysis); } @@ -469,82 +472,84 @@ static void updateExtractStrideSliceOp(mlir::OpBuilder &builder, } } -static void handleBranchOpInterface(mlir::OpBuilder &builder, - mlir::Block &block, - mlir::RegionBranchOpInterface branch, - mlir::TypeRange argsTypes) { - builder.setInsertionPointToStart(&block); +// Handle terminator ops, e.g., scf.yield. Update +// the types of their successor inputs if the successor +// operands need vnni format. 
+static void handleBranchTerminatorOpInterface( + mlir::OpBuilder &builder, + mlir::RegionBranchTerminatorOpInterface terminator, + LayoutAnalysis &analysis) { + + if (!mlir::isa(terminator->getParentOp())) + return; - mlir::Operation *op = branch.getOperation(); llvm::SmallVector successors; - llvm::SmallVector operands(op->getNumOperands(), nullptr); - branch.getEntrySuccessorRegions(operands, successors); + llvm::SmallVector operands(terminator->getNumOperands(), + nullptr); + terminator.getSuccessorRegions(operands, successors); for (mlir::RegionSuccessor &successor : successors) { - if (block.getParent() != successor.getSuccessor()) + if (!successor.isParent()) continue; - mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor); + mlir::OperandRange operands = terminator.getSuccessorOperands(successor); mlir::ValueRange inputs = successor.getSuccessorInputs(); - for (auto [arg, input] : llvm::zip(operands, inputs)) { - auto idx = mlir::cast(input).getArgNumber(); - mlir::Type dstType = argsTypes[idx]; - if (dstType == arg.getType()) { - input.setType(dstType); - continue; - } else { - auto cast = mlir::cast>(arg); - auto &&[newArg, root] = applyVnniTransform(builder, cast); - arg.replaceAllUsesExcept(newArg, root); + for (auto [arg, inp] : llvm::zip(operands, inputs)) { + if (analysis.getLayout(arg)) { + auto vecTy = mlir::cast(arg.getType()); + auto packedTy = getPackedType(vecTy); + inp.setType(packedTy); } } } +} - auto terminator = mlir::cast( - block.getTerminator()); - mlir::SmallVector operandAttributes( - terminator->getNumOperands(), nullptr); - - successors.clear(); - terminator.getSuccessorRegions(operandAttributes, successors); +// Handle RegionBranchOps, e.g., scf.for. Update the +// region argument types; if an argument needs to be +// in vnni format but its initArg is not, a vnni +// transform is applied to the initArg. 
+static void handleBranchOpInterface(mlir::OpBuilder &builder, + mlir::RegionBranchOpInterface branch, + LayoutAnalysis &analysis) { + mlir::Operation *op = branch.getOperation(); + llvm::SmallVector successors; + llvm::SmallVector operands(op->getNumOperands(), nullptr); + branch.getEntrySuccessorRegions(operands, successors); - for (const mlir::RegionSuccessor &successor : successors) { - if (!successor.isParent()) + for (mlir::RegionSuccessor &successor : successors) { + if (successor.isParent()) continue; + mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor); mlir::ValueRange inputs = successor.getSuccessorInputs(); - mlir::OperandRange operands = terminator.getSuccessorOperands(successor); - for (auto [operand, input] : llvm::zip(operands, inputs)) { - input.setType(operand.getType()); + + for (auto [arg, input] : llvm::zip(operands, inputs)) { + if (analysis.getLayout(input)) { + auto vecTy = mlir::cast(input.getType()); + auto packedTy = getPackedType(vecTy); + input.setType(packedTy); + if (!analysis.getLayout(arg)) { + builder.setInsertionPointAfterValue(arg); + auto cast = mlir::cast>(arg); + auto &&[newArg, root] = applyVnniTransform(builder, cast); + arg.replaceAllUsesExcept(newArg, root); + } + } } } } static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block, LayoutAnalysis &analysis) { - if (auto iface = mlir::dyn_cast_if_present( - block.getParentOp())) { - llvm::SmallVector types; - for (auto arg : block.getArguments()) { - auto argTy = arg.getType(); - if (!analysis.getLayout(arg)) { - types.push_back(argTy); - } else { - auto vecTy = mlir::cast(argTy); - auto packedTy = getPackedType(vecTy); - types.push_back(packedTy); + if (!mlir::isa(block.getParentOp())) { + builder.setInsertionPointToStart(&block); + for (auto &&arg : block.getArguments()) { + if (analysis.getLayout(arg)) { + auto cast = mlir::cast>(arg); + auto &&[newArg, root] = applyVnniTransform(builder, cast); + arg.replaceAllUsesExcept(newArg, 
root); } } - return handleBranchOpInterface(builder, block, iface, types); - } - - builder.setInsertionPointToStart(&block); - for (auto &&arg : block.getArguments()) { - if (analysis.getLayout(arg)) { - auto cast = mlir::cast>(arg); - auto &&[newArg, root] = applyVnniTransform(builder, cast); - arg.replaceAllUsesExcept(newArg, root); - } } } @@ -561,14 +566,21 @@ struct VnniTransformationPass final mlir::OpBuilder builder(&getContext()); llvm::SmallVector operands; - op->walk([&](mlir::Block *block) { + // process ops in post-order so that the layout info is + // used before being destroyed. + op->walk([&](mlir::Block *block) { // Iterate block ops in reverse so op is updated before it's operands. for (mlir::Operation &op : llvm::reverse(block->getOperations())) { - // Ignore shape casts as they are generated by the conversion itself. - // Ignore RegionBranchOpInterface as it handled in `updateBlockTypes`. - if (mlir::isa(op)) + if (auto terminator = + mlir::dyn_cast(op)) { + handleBranchTerminatorOpInterface(builder, terminator, analysis); continue; + } + + if (auto iface = mlir::dyn_cast(op)) { + handleBranchOpInterface(builder, iface, analysis); + continue; + } if (auto dpas = mlir::dyn_cast(op)) { updateDpasOp(builder, dpas, analysis); diff --git a/test/Transforms/VnniTransform/unit-tests.mlir b/test/Transforms/VnniTransform/unit-tests.mlir index 481ed7345..d413b65ec 100644 --- a/test/Transforms/VnniTransform/unit-tests.mlir +++ b/test/Transforms/VnniTransform/unit-tests.mlir @@ -374,4 +374,34 @@ func.func @test(%arg1 : !xegpu.tensor_desc<8x16xi16>, %arg2 : !xegpu.tensor_desc %1 = arith.bitcast %b : vector<16x16xi16> to vector<16x16xf16> %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> return %2 : vector<8x16xf32> -} +} + +// ----- + +//CHECK-LABEL: @test +// CHECK-SAME: (%[[arg0:.*]]: !xegpu.tensor_desc<8x16xf16>, %[[arg1:.*]]: !xegpu.tensor_desc<16x16xf16>, %[[arg2:.*]]: vector<16x16xf16>, %[[arg3:.*]]: i1) -> vector<8x16xf32> 
{ +func.func @test(%arg1 : !xegpu.tensor_desc<8x16xf16>, %arg2 : !xegpu.tensor_desc<16x16xf16>, %arg3 : vector<16x16xf16>, %arg4 : i1) -> vector<8x16xf32> { + //CHECK: %[[r0:.*]] = vector.shape_cast %[[arg2]] {packed} : vector<16x16xf16> to vector<256xf16> + //CHECK: %[[r1:.*]] = vector.shuffle %[[r0]], %[[r0]] [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16> + //CHECK: %[[r2:.*]] = vector.shape_cast %[[r1]] {packed} : vector<256xf16> to vector<8x16x2xf16> + //CHECK: %[[r3:.*]] = xegpu.load_nd %[[arg0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %0 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK: %[[r4:.*]] = scf.if %[[arg3]] -> (vector<8x16xf32>) + %1 = scf.if %arg4 -> (vector<8x16xf32>) { + //CHECK: %[[r5:.*]] = xegpu.load_nd %[[arg1]] <{packed}> : 
!xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %2 = xegpu.load_nd %arg2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + //CHECK: %[[r6:.*]] = arith.addf %[[r5]], %[[r2]] : vector<8x16x2xf16> + %3 = arith.addf %2, %arg3 : vector<16x16xf16> + //CHECK: %[[r7:.*]] = xegpu.dpas %[[r3]], %[[r6]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + //CHECK: scf.yield %[[r7]] : vector<8x16xf32> + scf.yield %4 : vector<8x16xf32> + } else { + //CHECK: %[[r5:.*]] = xegpu.dpas %[[r3]], %[[r2]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + %5 = xegpu.dpas %0, %arg3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + //CHECK: scf.yield %[[r5]] : vector<8x16xf32> + scf.yield %5 : vector<8x16xf32> + } + //CHECK: return %[[r4]] : vector<8x16xf32> + return %1 : vector<8x16xf32> +}