[XeGPU][Transforms] Bug fix in OptimizeTranspose. (#976)
charithaintc authored Dec 4, 2024
1 parent e93172a · commit 2d70a58
Showing 2 changed files with 50 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lib/Transforms/OptimizeTranspose.cpp
@@ -535,7 +535,7 @@ struct ScfForOpPattern : public OpConversionPattern<scf::ForOp> {
     // If initArg is an UnrealizedConversionCastOp, new iter args need to be
     // updated. And we keep track of the old -> new iter arg index mapping.
     if (auto unrealizedConversionCast =
-            llvm::dyn_cast<UnrealizedConversionCastOp>(
+            llvm::dyn_cast_if_present<UnrealizedConversionCastOp>(
                 initArg.getDefiningOp())) {
       auto sources = unrealizedConversionCast.getOperands();
       forOpOutputTypeMapping[index] =
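
Why the one-word change matters: initArg.getDefiningOp() returns null whenever the init value is a block argument (for example, an iter_arg of an enclosing scf.for, which is the situation exercised by the nested-loop test added below), and llvm::dyn_cast asserts on a null input. llvm::dyn_cast_if_present accepts null and simply returns a null result, so the pattern now skips such init args instead of crashing. Below is a minimal C++ sketch of the distinction; inspectInitArg is a hypothetical helper written for illustration, not code from the pass.

#include "mlir/IR/BuiltinOps.h" // declares UnrealizedConversionCastOp
#include "mlir/IR/Value.h"
#include "llvm/Support/Casting.h"

using namespace mlir;

// Hypothetical helper illustrating the null-safety difference.
static void inspectInitArg(Value initArg) {
  // getDefiningOp() returns nullptr when `initArg` is a block argument
  // (e.g. a loop iter_arg) rather than the result of an operation.
  Operation *def = initArg.getDefiningOp();

  // llvm::dyn_cast requires a non-null argument and would assert here:
  //   auto bad = llvm::dyn_cast<UnrealizedConversionCastOp>(def);

  // llvm::dyn_cast_if_present tolerates null and yields a null op instead,
  // so for a block argument the branch below is simply not taken.
  if (auto cast = llvm::dyn_cast_if_present<UnrealizedConversionCastOp>(def)) {
    auto sources = cast.getOperands(); // safe: `cast` is non-null here
    (void)sources;
  }
}
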
49 changes: 49 additions & 0 deletions test/Transforms/xegpu-optimize-transpose.mlir
@@ -118,6 +118,55 @@ func.func @test_scf_for_array_len(%arg0 : memref<64x64xf16>, %arg1 : vector<8x16
return %result#0 : vector<8x16xf32>
}

// -----
// CHECK-LABEL: @test_nested_scf_for_array_len(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<64x64xf16>, %[[ARG1:[a-zA-Z0-9]+]]: vector<8x16xf16>, %[[ARG2:[a-zA-Z0-9]+]]: memref<64x64xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK: scf.for {{.*}} {
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %[[C0]]] : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %[[C16]]] : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T3:.*]]:3 = scf.for {{.*}} iter_args(%[[ARG5:.*]] = %[[T2]], %[[ARG6:.*]] = %[[T0]], %[[ARG7:.*]] = %[[T1]]) -> (vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>) {
// CHECK-DAG: %[[T5:.*]] = xegpu.load_nd %[[ARG6]] <{transpose = array<i64: 1, 0>, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<8x16x2xf16>
// CHECK-DAG: %[[T7:.*]] = xegpu.load_nd %[[ARG7]] <{transpose = array<i64: 1, 0>, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>> -> vector<8x16x2xf16>
// CHECK-DAG: %[[T9:.*]] = xegpu.update_nd_offset %[[ARG6]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK-DAG: %[[T10:.*]] = xegpu.update_nd_offset %[[ARG7]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: scf.yield %{{.*}}, %[[T9]], %[[T10]] : vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1 : i64>>
// CHECK: xegpu.store_nd %[[T3]]#0, %{{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1 : i64>>
func.func @test_nested_scf_for_array_len(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>, %arg2: memref<64x64xf32>) {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
scf.for %arg3 = %c0 to %c64 step %c8 {
%0 = xegpu.create_nd_tdesc %arg0[%arg3, %c0] : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
%cst = arith.constant dense<0.000000e+00> : vector<128xf32>
%1 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32>
%2:2 = scf.for %arg4 = %c0 to %c64 step %c32 iter_args(%arg5 = %1, %arg6 = %0) -> (vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>) {
%4 = xegpu.load_nd %arg6 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
%5 = vector.extract %4[0] : vector<16x16xf16> from vector<2x16x16xf16>
%6 = vector.extract %4[1] : vector<16x16xf16> from vector<2x16x16xf16>
%7 = vector.transpose %5, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
%8 = vector.shape_cast %7 {packed} : vector<16x16xf16> to vector<256xf16>
%9 = vector.shuffle %8, %8 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16>
%10 = vector.shape_cast %9 {packed} : vector<256xf16> to vector<8x16x2xf16>
%11 = xegpu.dpas %arg1, %10, %1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
%12 = vector.transpose %6, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
%13 = vector.shape_cast %12 {packed} : vector<16x16xf16> to vector<256xf16>
%14 = vector.shuffle %13, %13 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16>
%15 = vector.shape_cast %14 {packed} : vector<256xf16> to vector<8x16x2xf16>
%16 = xegpu.dpas %arg1, %15, %11 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
%17 = xegpu.update_nd_offset %arg6, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
scf.yield %16, %17 : vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
}
%3 = xegpu.create_nd_tdesc %arg2[%arg3, %c0] : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1 : i64>>
xegpu.store_nd %2#0, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1 : i64>>
}
return
}

// -----
// CHECK-LABEL: func.func @test_scf_for_preop(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<64x64xf16>, %[[ARG1:[a-zA-Z0-9]+]]: vector<8x16xf16>) -> vector<8x16xf32> {
