Skip to content

Commit

Permalink
[XeTileToXEGPU] Fix reduction lowering (#977)
Browse files Browse the repository at this point in the history
  • Loading branch information
Garra1980 authored Dec 5, 2024
1 parent 2d70a58 commit 1c83c2f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 78 deletions.
2 changes: 1 addition & 1 deletion lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
unsigned rank = type.getRank();
auto elemType = type.getElementType();

if (rank < 1 || type.getNumElements() == 1)
if (rank < 1)
return elemType;

unsigned sum = 1;
Expand Down
11 changes: 7 additions & 4 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,16 +905,19 @@ struct SgTileReductionOpPattern
sources, shape, op.getKind(), loc, elemTy, rewriter);
llvm::SmallVector<mlir::Value> newOps;
{
// intermediate is a vector of values with type of vector<shape[3]xf16>,
// each value represents a portion of the reduced value. For example,
// intermediate is a vector of values with type of vector<nxf16>
// (where n is max of min(shape[0]/2,16) and 1),
// each element is the reduced value for a row. For example,
// for vector<32x4x1x16> with reduction on dim 1 and dim 3. the
// intermediate values will be two vectors of vector<16xf16>. The values
// intermediate values will be two values of vector<16xf16>. The values
// in the first vector represents the reduction result of the first 16
// rows. Here we will extract each value and splat it to a vector<1x1xf16>
// as results to their consumers.
for (auto v : intermediates) {
auto targetTy = mlir::VectorType::get({1, 1}, elemTy);
for (auto i = 0; i < shape[3]; i++) {
auto vecTy = mlir::dyn_cast<mlir::VectorType>(v.getType());
assert(vecTy && "expect vector type");
for (auto i = 0; i < vecTy.getShape()[0]; i++) {
auto pos = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(i));
auto extractOp =
Expand Down
111 changes: 38 additions & 73 deletions test/Conversion/XeTileToXeGPU/reduction.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \
// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-canonicalization --xetile-blocking \
// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s
module {
gpu.module @test_kernel {
Expand Down Expand Up @@ -111,54 +111,6 @@ module {
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<16xf16>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf16>
//CHECK-COUNT-8: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x1xf16>, vector<1x1xf16>
//CHECK-COUNT-4: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf16>, vector<2x1xf16>
//CHECK-COUNT-2: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf16>, vector<4x1xf16>
Expand Down Expand Up @@ -250,30 +202,6 @@ module {
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<8xf32>
//CHECK: {{.*}} = vector.splat %{{.*}} : vector<1x1xf32>
//CHECK-COUNT-4: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<1x1xf32>, vector<1x1xf32>
//CHECK-COUNT-2: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<2x1xf32>, vector<2x1xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<4x1xf32>, vector<4x1xf32>
Expand All @@ -283,6 +211,43 @@ module {
gpu.return
}

gpu.func @inner_reduction_small_size_1(%arg0: memref<*xf32>, %arg1: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array<i32: 1, 32, 1>, known_grid_size = array<i32: 1, 1, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
%cst = arith.constant dense<0.000000e+00> : vector<1xf32>
%cst_0 = arith.constant dense<true> : vector<1x16xi1>
%cst_1 = arith.constant dense<true> : vector<1x1xi1>
%cst_2 = arith.constant dense<0> : vector<1x1xindex>
%cst_3 = arith.constant dense<0> : vector<1x16xindex>
%cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
%cast_4 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
%0 = xetile.init_tile %cast, %cst_3 : memref<?xf32>, vector<1x16xindex> -> !xetile.tile<1x16xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>
%1 = xetile.load %0, %cst_0 : !xetile.tile<1x16xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>, vector<1x16xi1> -> vector<1x16xf32>
//CHECK: {{.*}} = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
//CHECK: {{.*}} = vector.shape_cast %{{.*}} : vector<1x16xf32> to vector<16xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xf32>, vector<16xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf32>, vector<16xf32>
//CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<8xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3] : vector<8xf32>, vector<8xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [4, 5, 6, 7] : vector<8xf32>, vector<8xf32>
//CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<4xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0, 1] : vector<4xf32>, vector<4xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [2, 3] : vector<4xf32>, vector<4xf32>
//CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<2xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [0] : vector<2xf32>, vector<2xf32>
//CHECK: {{.*}} = vector.shuffle %{{.*}}, %{{.*}} [1] : vector<2xf32>, vector<2xf32>
//CHECK: {{.*}} = arith.addf {{.*}}, {{.*}} : vector<1xf32>
//CHECK: {{.*}} = arith.constant {{.*}} : i32
//CHECK: {{.*}} = vector.extractelement %{{.*}}[{{.*}} : i32] : vector<1xf32>
//CHECK: {{.*}} = vector.splat {{.*}} : vector<1x1xf32>
//CHECK: {{.*}} = vector.shape_cast {{.*}} : vector<1x1xf32> to vector<1xf32>
//CHECK: {{.*}} = arith.addf %{{.*}}, %{{.*}} : vector<1xf32>

%2 = vector.multi_reduction <add>, %1, %cst [1] : vector<1x16xf32> to vector<1xf32>
%3 = vector.shape_cast %2 : vector<1xf32> to vector<1x1xf32>
%4 = xetile.init_tile %cast_4, %cst_2 : memref<?xf32>, vector<1x1xindex> -> !xetile.tile<1x1xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>
xetile.store %3, %4, %cst_1 : vector<1x1xf32>, !xetile.tile<1x1xf32, #xetile.tile_attr<memory_space = 0 : i32, scattered = true>>, vector<1x1xi1>
gpu.return
}

//CHECK: gpu.func @outter_reduction(%[[arg0:.*]]: memref<128x256xf16>, %[[arg1:.*]]: memref<128x256xf16>) {
gpu.func @outter_reduction(%a: memref<128x256xf16>, %b: memref<128x256xf16>) {
//CHECK: %[[c0:.*]] = arith.constant 0 : index
Expand Down

0 comments on commit 1c83c2f

Please sign in to comment.