From 2d70a589cce17d772ed692cb41c4457a8c7b5a38 Mon Sep 17 00:00:00 2001 From: Charitha Saumya <136391709+charithaintc@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:23:44 -0800 Subject: [PATCH] [XeGPU][Transforms] Bug fix in OptimizeTranspose. (#976) --- lib/Transforms/OptimizeTranspose.cpp | 2 +- test/Transforms/xegpu-optimize-transpose.mlir | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/lib/Transforms/OptimizeTranspose.cpp b/lib/Transforms/OptimizeTranspose.cpp index 3bc0052a2..6e4913e96 100644 --- a/lib/Transforms/OptimizeTranspose.cpp +++ b/lib/Transforms/OptimizeTranspose.cpp @@ -535,7 +535,7 @@ struct ScfForOpPattern : public OpConversionPattern { // If initArg is a UnrealizedConversionCastOp, new iter args need to be // updated. And we keep track of the old -> new iter arg indice mapping. if (auto unrealizedConversionCast = - llvm::dyn_cast( + llvm::dyn_cast_if_present( initArg.getDefiningOp())) { auto sources = unrealizedConversionCast.getOperands(); forOpOutputTypeMapping[index] = diff --git a/test/Transforms/xegpu-optimize-transpose.mlir b/test/Transforms/xegpu-optimize-transpose.mlir index 2e55a7ea0..2dbb381ee 100644 --- a/test/Transforms/xegpu-optimize-transpose.mlir +++ b/test/Transforms/xegpu-optimize-transpose.mlir @@ -118,6 +118,55 @@ func.func @test_scf_for_array_len(%arg0 : memref<64x64xf16>, %arg1 : vector<8x16 return %result#0 : vector<8x16xf32> } +// ----- +// CHECK-LABEL: @test_nested_scf_for_array_len( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<64x64xf16>, %[[ARG1:[a-zA-Z0-9]+]]: vector<8x16xf16>, %[[ARG2:[a-zA-Z0-9]+]]: memref<64x64xf32>) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK: scf.for {{.*}} { +// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %[[C0]]] : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%{{.*}}, %[[C16]]] : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T3:.*]]:3 = scf.for {{.*}} iter_args(%[[ARG5:.*]] = %[[T2]], %[[ARG6:.*]] = %[[T0]], %[[ARG7:.*]] = %[[T1]]) -> (vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>) { +// CHECK-DAG: %[[T5:.*]] = xegpu.load_nd %[[ARG6]] <{transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<8x16x2xf16> +// CHECK-DAG: %[[T7:.*]] = xegpu.load_nd %[[ARG7]] <{transpose = array, transpose_bit_width = 32 : i32}> : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<8x16x2xf16> +// CHECK-DAG: %[[T9:.*]] = xegpu.update_nd_offset %[[ARG6]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK-DAG: %[[T10:.*]] = xegpu.update_nd_offset %[[ARG7]], [{{.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: scf.yield %{{.*}}, %[[T9]], %[[T10]] : vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> +// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> +// CHECK: xegpu.store_nd %[[T3]]#0, %{{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> +func.func @test_nested_scf_for_array_len(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>, %arg2: memref<64x64xf32>) { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + scf.for %arg3 = %c0 to %c64 step %c8 { + %0 = xegpu.create_nd_tdesc %arg0[%arg3, %c0] : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + %cst = arith.constant dense<0.000000e+00> : vector<128xf32> + %1 = vector.shape_cast %cst : vector<128xf32> to vector<8x16xf32> + %2:2 = scf.for %arg4 = %c0 to %c64 step %c32 iter_args(%arg5 = %1, %arg6 = %0) -> (vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>) { + %4 = xegpu.load_nd %arg6 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<2x16x16xf16> + %5 = vector.extract %4[0] : vector<16x16xf16> from vector<2x16x16xf16> + %6 = vector.extract %4[1] : vector<16x16xf16> from vector<2x16x16xf16> + %7 = vector.transpose %5, [1, 0] : vector<16x16xf16> to vector<16x16xf16> + %8 = vector.shape_cast %7 {packed} : vector<16x16xf16> to vector<256xf16> + %9 = vector.shuffle %8, %8 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16> + %10 = vector.shape_cast %9 {packed} : vector<256xf16> to vector<8x16x2xf16> + %11 = xegpu.dpas %arg1, %10, %1 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %12 = vector.transpose %6, [1, 0] : vector<16x16xf16> to vector<16x16xf16> + %13 = vector.shape_cast %12 {packed} : vector<16x16xf16> to vector<256xf16> + %14 = vector.shuffle %13, %13 [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16> + %15 = vector.shape_cast %14 {packed} : vector<256xf16> to vector<8x16x2xf16> + %16 = xegpu.dpas %arg1, %15, %11 : vector<8x16xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %17 = xegpu.update_nd_offset %arg6, [%c32, %c0] : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + scf.yield %16, %17 : vector<8x16xf32>, !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> + } + %3 = xegpu.create_nd_tdesc %arg2[%arg3, %c0] : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + xegpu.store_nd %2#0, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + } + return +} + // ----- // CHECK-LABEL: func.func @test_scf_for_preop( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<64x64xf16>, %[[ARG1:[a-zA-Z0-9]+]]: vector<8x16xf16>) -> vector<8x16xf32> {