diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h index 797f6be2d..c24b6b525 100644 --- a/include/imex/Utils/XeCommon.h +++ b/include/imex/Utils/XeCommon.h @@ -31,6 +31,10 @@ using namespace mlir::xegpu; namespace imex { +// valid chunk sizes are 1, 2, 3, 4, 8 if simdLanes > 1. +// 16, 32, and 64 are only available if simdLanes == 1. +llvm::SmallVector getSupportedChunkSizes(int simdlanes); + using PackFuncTy = std::function( mlir::Value, mlir::Value, mlir::Location, mlir::OpBuilder &)>; diff --git a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp index d637c0cd1..eafb886de 100644 --- a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp +++ b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp @@ -115,7 +115,8 @@ static LogicalResult isValid1DBlockSetup(Type elemTy, int elems, Location &loc, if (bitWidth < 32) return rewriter.notifyMatchFailure(loc, "only 32-bit data supported."); - if (!llvm::is_contained({1, 2, 3, 4, 8, 16, 32, 64}, elems)) + auto validChunkSizes = getSupportedChunkSizes(1); + if (!llvm::is_contained(validChunkSizes, elems)) return rewriter.notifyMatchFailure( loc, "invalid number of elements. Supports 1, 2, 3, 4, 8, 16, 32, 64."); @@ -139,9 +140,11 @@ static LogicalResult isValidScatterSetup(Type elemTy, int simd_lanes, // return rewriter.notifyMatchFailure( // loc, "A valid simd lane is 16 or 32 for PVC."); - if (!llvm::is_contained({1, 2, 3, 4, 8, 16, 32, 64}, chunk_size)) + auto validChunkSizes = getSupportedChunkSizes(simd_lanes); + if (!llvm::is_contained(validChunkSizes, chunk_size)) return rewriter.notifyMatchFailure( - loc, "invalid chunk size. Supports 1, 2, 3, 4, 8, 16, 32, 64."); + loc, "invalid chunk size. Supports 1, 2, 3, 4, 8" + "(and 16, 32, 64 if simd_lanes == 1)."); auto bitWidth = elemTy.getIntOrFloatBitWidth(); diff --git a/lib/Transforms/OptimizeTranspose.cpp b/lib/Transforms/OptimizeTranspose.cpp index 6e4913e96..ac9a6d8a5 100644 --- a/lib/Transforms/OptimizeTranspose.cpp +++ b/lib/Transforms/OptimizeTranspose.cpp @@ -793,10 +793,10 @@ struct TransposeRewritePattern : public OpRewritePattern { auto numElems = bytes / 4; // number of elements each simd lane to write int chunkSize = numElems / simdLanes; - llvm::SmallVector validChunkSizes = {64, 32, 16, 8, 4, 3, 2, 1}; // the numElems has to be evenly divided by simdLanes, and the chunkSize // has to be in the validChunkSizes. + auto validChunkSizes = imex::getSupportedChunkSizes(simdLanes); if (numElems % simdLanes != 0 || !llvm::is_contained(validChunkSizes, chunkSize)) return failure(); diff --git a/lib/Utils/XeCommon.cpp b/lib/Utils/XeCommon.cpp index c262ac97a..190802a26 100644 --- a/lib/Utils/XeCommon.cpp +++ b/lib/Utils/XeCommon.cpp @@ -22,6 +22,12 @@ #include "llvm/Support/FormatVariadic.h" namespace imex { +llvm::SmallVector getSupportedChunkSizes(int simdlanes) { + if (simdlanes == 1) + return {64, 32, 16, 8, 4, 3, 2, 1}; + return {8, 4, 3, 2, 1}; +} + int getOperandIndex(mlir::Operation *op, mlir::Value operand) { for (auto [i, value] : llvm::enumerate(op->getOperands())) { if (operand == value)