From ecb3517dc49e85dea18543482e8e744c5d69e4e7 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 3 Dec 2024 21:08:34 +0000 Subject: [PATCH 1/2] Add support for xetile.atomic_rmw op in init-duplicate pass --- include/imex/Utils/XeCommon.h | 56 ++++++++++++++++++- .../XeTile/Transforms/InitDuplicate.cpp | 7 ++- .../XeTile/Transforms/init_duplicate.mlir | 31 ++++++++++ 3 files changed, 91 insertions(+), 3 deletions(-) create mode 100644 test/Dialect/XeTile/Transforms/init_duplicate.mlir diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h index 8298126dc..af4e94f33 100644 --- a/include/imex/Utils/XeCommon.h +++ b/include/imex/Utils/XeCommon.h @@ -81,6 +81,8 @@ class TileUsageAnalysis { Usage[op] |= (uint)UsageType::PREFETCH; } else if (llvm::isa(user)) { Usage[op] |= (uint)UsageType::STORE; + } if (llvm::isa(user)) { + Usage[op] |= (uint)UsageType::ATOMICRMW; } else if (llvm::isa(user)) { Usage[op] |= (uint)UsageType::OTHER; } else if (auto forOp = @@ -162,6 +164,17 @@ class TileUsageAnalysis { return false; } + bool isForAtomicRMW(imex::xetile::InitTileOp op) { + if (Usage.count(op)) { + bool load = Usage[op] & UsageType::LOAD; + bool store = Usage[op] & UsageType::STORE; + bool prefetch = Usage[op] & UsageType::PREFETCH; + bool atomic_rmw = Usage[op] & UsageType::ATOMICRMW; + return !load && !store && !prefetch && atomic_rmw; + } + return false; + } + // bool isForLoadAndPrefetch(imex::xetile::InitTileOp op) { if (Usage.count(op)) { @@ -193,6 +206,28 @@ class TileUsageAnalysis { return false; } + bool isForLoadAndAtomicRMW(imex::xetile::InitTileOp op) { + if (Usage.count(op)) { + bool load = Usage[op] & UsageType::LOAD; + bool store = Usage[op] & UsageType::STORE; + bool prefetch = Usage[op] & UsageType::PREFETCH; + bool atomic_rmw = Usage[op] & UsageType::ATOMICRMW; + return load && !store && !prefetch && atomic_rmw; + } + return false; + } + + bool isForAtomicRMWAndStore(imex::xetile::InitTileOp op) { + if (Usage.count(op)) { + bool load = Usage[op] & UsageType::LOAD; + bool store = Usage[op] & UsageType::STORE; + bool prefetch = Usage[op] & UsageType::PREFETCH; + bool atomic_rmw = Usage[op] & UsageType::ATOMICRMW; + return !load && store && !prefetch && atomic_rmw; + } + return false; + } + private: enum UsageType { None = 0, @@ -202,7 +237,8 @@ class TileUsageAnalysis { DPAS_A = 8, DPAS_B = 16, DPAS_C = 32, - OTHER = 64 + ATOMICRMW = 64, + OTHER = 128 }; llvm::DenseMap Usage; @@ -526,6 +562,12 @@ class XeConversionPattern : public mlir::RewritePattern { return llvm::cast(analysis).isForPrefetch(op); } + template >> + bool isForAtomicRMW(imex::xetile::InitTileOp op) const { + return llvm::cast(analysis).isForAtomicRMW(op); + } + template >> bool isForLoadAndPrefetch(imex::xetile::InitTileOp op) const { @@ -537,6 +579,18 @@ class XeConversionPattern : public mlir::RewritePattern { bool isForLoadAndStore(imex::xetile::InitTileOp op) const { return llvm::cast(analysis).isForLoadAndStore(op); } + + template >> + bool isForLoadAndAtomicRMW(imex::xetile::InitTileOp op) const { + return llvm::cast(analysis).isForLoadAndAtomicRMW(op); + } + + template >> + bool isForAtomicRMWAndStore(imex::xetile::InitTileOp op) const { + return llvm::cast(analysis).isForAtomicRMWAndStore(op); + } }; /// Clone `shape` with the last two elements swapped. diff --git a/lib/Dialect/XeTile/Transforms/InitDuplicate.cpp b/lib/Dialect/XeTile/Transforms/InitDuplicate.cpp index 1614ef598..e74caa1fb 100644 --- a/lib/Dialect/XeTile/Transforms/InitDuplicate.cpp +++ b/lib/Dialect/XeTile/Transforms/InitDuplicate.cpp @@ -56,11 +56,14 @@ class XeTileInitDuplicatePass op->walk([&](imex::xetile::InitTileOp op) { mlir::OpBuilder rewriter(op); if (usageAnalysis.isForLoadAndStore(op) || - usageAnalysis.isForLoadAndPrefetch(op)) { + usageAnalysis.isForLoadAndPrefetch(op) || + usageAnalysis.isForLoadAndAtomicRMW(op) || + usageAnalysis.isForAtomicRMWAndStore(op)) { mlir::Operation *cloneOp = rewriter.clone(*op); for (auto user : op->getUsers()) { if (llvm::isa(user) || - llvm::dyn_cast(user)) { + llvm::dyn_cast(user) || + llvm::dyn_cast(user)) { auto *targetOp = llvm::dyn_cast_if_present(user); targetOp->replaceUsesOfWith(op->getResults()[0], cloneOp->getResults()[0]); diff --git a/test/Dialect/XeTile/Transforms/init_duplicate.mlir b/test/Dialect/XeTile/Transforms/init_duplicate.mlir new file mode 100644 index 000000000..7297e1ad7 --- /dev/null +++ b/test/Dialect/XeTile/Transforms/init_duplicate.mlir @@ -0,0 +1,31 @@ +// RUN: imex-opt --split-input-file --xetile-init-duplicate %s -verify-diagnostics -o -| FileCheck %s + +gpu.module @test_kernel { + //CHECK: gpu.func @init_duplicate(%[[value:.*]]: vector<32x64xf32>, %[[arg0:.*]]: memref<256x256xf32>) + gpu.func @init_duplicate(%value: vector<32x64xf32>, %arg0: memref<256x256xf32>) { + // CHECK: %[[c0:.*]] = arith.constant 0 : index + // CHECK: %[[INITTILE_0:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32> + // CHECK: %[[INITTILE_1:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32> + // CHECK: %[[ATOMICRMW:.*]] = xetile.atomic_rmw addf %[[value]], %[[INITTILE_1]] : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> + // CHECK: xetile.store_tile %[[ATOMICRMW]], %[[INITTILE_0]] : vector<32x64xf32>, !xetile.tile<32x64xf32> + %c0 = arith.constant 0 : index + %tile = xetile.init_tile %arg0[%c0, %c0] : memref<256x256xf32> -> !xetile.tile<32x64xf32> + %rmw = xetile.atomic_rmw addf %value, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> + xetile.store_tile %rmw, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> + gpu.return + } + + //CHECK: gpu.func @init_duplicate_1(%[[value:.*]]: vector<32x64xf32>, %[[arg0:.*]]: memref<256x256xf32>) + gpu.func @init_duplicate_1(%value: vector<32x64xf32>, %arg0: memref<256x256xf32>) { + // CHECK: %[[c0:.*]] = arith.constant 0 : index + // CHECK: %[[INITTILE_0:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32> + // CHECK: %[[INITTILE_1:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32> + // CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILE_1]] : !xetile.tile<32x64xf32> -> vector<32x64xf32> + // CHECK: %[[ATOMICRMW:.*]] = xetile.atomic_rmw addf %[[value]], %[[INITTILE_0]] : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> + %c0 = arith.constant 0 : index + %tile = xetile.init_tile %arg0[%c0, %c0] : memref<256x256xf32> -> !xetile.tile<32x64xf32> + %load = xetile.load_tile %tile : !xetile.tile<32x64xf32> -> vector<32x64xf32> + %rmw = xetile.atomic_rmw addf %value, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> + gpu.return +} +} From b4e8c519540a6eee2afae6c49777bbdb86b72326 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Tue, 3 Dec 2024 21:15:05 +0000 Subject: [PATCH 2/2] Fix pre-commit --- include/imex/Utils/XeCommon.h | 4 ++-- test/Dialect/XeTile/Transforms/init_duplicate.mlir | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h index af4e94f33..77fdfdcd2 100644 --- a/include/imex/Utils/XeCommon.h +++ b/include/imex/Utils/XeCommon.h @@ -81,7 +81,7 @@ class TileUsageAnalysis { Usage[op] |= (uint)UsageType::PREFETCH; } else if (llvm::isa(user)) { Usage[op] |= (uint)UsageType::STORE; - } if (llvm::isa(user)) { + } else if (llvm::isa(user)) { Usage[op] |= (uint)UsageType::ATOMICRMW; } else if (llvm::isa(user)) { Usage[op] |= (uint)UsageType::OTHER; @@ -238,7 +238,7 @@ class TileUsageAnalysis { DPAS_B = 16, DPAS_C = 32, ATOMICRMW = 64, - OTHER = 128 + OTHER = 128 }; llvm::DenseMap Usage; diff --git a/test/Dialect/XeTile/Transforms/init_duplicate.mlir b/test/Dialect/XeTile/Transforms/init_duplicate.mlir index 7297e1ad7..d6b941c7a 100644 --- a/test/Dialect/XeTile/Transforms/init_duplicate.mlir +++ b/test/Dialect/XeTile/Transforms/init_duplicate.mlir @@ -8,11 +8,11 @@ gpu.module @test_kernel { // CHECK: %[[INITTILE_1:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32> // CHECK: %[[ATOMICRMW:.*]] = xetile.atomic_rmw addf %[[value]], %[[INITTILE_1]] : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> // CHECK: xetile.store_tile %[[ATOMICRMW]], %[[INITTILE_0]] : vector<32x64xf32>, !xetile.tile<32x64xf32> - %c0 = arith.constant 0 : index + %c0 = arith.constant 0 : index %tile = xetile.init_tile %arg0[%c0, %c0] : memref<256x256xf32> -> !xetile.tile<32x64xf32> %rmw = xetile.atomic_rmw addf %value, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> xetile.store_tile %rmw, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> - gpu.return + gpu.return } //CHECK: gpu.func @init_duplicate_1(%[[value:.*]]: vector<32x64xf32>, %[[arg0:.*]]: memref<256x256xf32>) @@ -22,10 +22,10 @@ gpu.module @test_kernel { // CHECK: %[[INITTILE_1:.*]] = xetile.init_tile %[[arg0]][%[[c0]], %[[c0]]] : memref<256x256xf32> -> !xetile.tile<32x64xf32> // CHECK: %[[LOADTILE:.*]] = xetile.load_tile %[[INITTILE_1]] : !xetile.tile<32x64xf32> -> vector<32x64xf32> // CHECK: %[[ATOMICRMW:.*]] = xetile.atomic_rmw addf %[[value]], %[[INITTILE_0]] : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> - %c0 = arith.constant 0 : index + %c0 = arith.constant 0 : index %tile = xetile.init_tile %arg0[%c0, %c0] : memref<256x256xf32> -> !xetile.tile<32x64xf32> %load = xetile.load_tile %tile : !xetile.tile<32x64xf32> -> vector<32x64xf32> %rmw = xetile.atomic_rmw addf %value, %tile : vector<32x64xf32>, !xetile.tile<32x64xf32> -> vector<32x64xf32> - gpu.return + gpu.return } }