Skip to content

Commit

Permalink
[TransposeOptimize][XeGPUToVC] Fix constraints on chunkSize (#979)
Browse files Browse the repository at this point in the history
  • Loading branch information
chencha3 authored Dec 11, 2024
1 parent 227a0f7 commit d40c423
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
4 changes: 4 additions & 0 deletions include/imex/Utils/XeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ using namespace mlir::xegpu;

namespace imex {

/// Returns the chunk sizes supported for the given number of SIMD lanes.
/// Chunk sizes 1, 2, 3, 4 and 8 are valid for any lane count; 16, 32 and 64
/// are additionally available only when simdlanes == 1.
/// The returned list is ordered from the largest size to the smallest.
llvm::SmallVector<int> getSupportedChunkSizes(int simdlanes);

/// Callable type that receives two mlir::Value operands, a location and an
/// OpBuilder, and returns a value typed as an mlir::VectorType.
/// NOTE(review): presumably used to combine ("pack") the two operands into a
/// single vector value — confirm against the call sites.
using PackFuncTy = std::function<mlir::TypedValue<mlir::VectorType>(
mlir::Value, mlir::Value, mlir::Location, mlir::OpBuilder &)>;

Expand Down
9 changes: 6 additions & 3 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ static LogicalResult isValid1DBlockSetup(Type elemTy, int elems, Location &loc,
if (bitWidth < 32)
return rewriter.notifyMatchFailure(loc, "only 32-bit data supported.");

if (!llvm::is_contained({1, 2, 3, 4, 8, 16, 32, 64}, elems))
auto validChunkSizes = getSupportedChunkSizes(1);
if (!llvm::is_contained(validChunkSizes, elems))
return rewriter.notifyMatchFailure(
loc, "invalid number of elements. Supports 1, 2, 3, 4, 8, 16, 32, 64.");

Expand All @@ -139,9 +140,11 @@ static LogicalResult isValidScatterSetup(Type elemTy, int simd_lanes,
// return rewriter.notifyMatchFailure(
// loc, "A valid simd lane is 16 or 32 for PVC.");

if (!llvm::is_contained({1, 2, 3, 4, 8, 16, 32, 64}, chunk_size))
auto validChunkSizes = getSupportedChunkSizes(simd_lanes);
if (!llvm::is_contained(validChunkSizes, chunk_size))
return rewriter.notifyMatchFailure(
loc, "invalid chunk size. Supports 1, 2, 3, 4, 8, 16, 32, 64.");
loc, "invalid chunk size. Supports 1, 2, 3, 4, 8"
"(and 16, 32, 64 if simd_lanes == 1).");

auto bitWidth = elemTy.getIntOrFloatBitWidth();

Expand Down
2 changes: 1 addition & 1 deletion lib/Transforms/OptimizeTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,10 +793,10 @@ struct TransposeRewritePattern : public OpRewritePattern<vector::TransposeOp> {
auto numElems = bytes / 4;
// number of elements each simd lane to write
int chunkSize = numElems / simdLanes;
llvm::SmallVector<int> validChunkSizes = {64, 32, 16, 8, 4, 3, 2, 1};

// the numElems has to be evenly divided by simdLanes, and the chunkSize
// has to be in the validChunkSizes.
auto validChunkSizes = imex::getSupportedChunkSizes(simdLanes);
if (numElems % simdLanes != 0 ||
!llvm::is_contained(validChunkSizes, chunkSize))
return failure();
Expand Down
6 changes: 6 additions & 0 deletions lib/Utils/XeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
#include "llvm/Support/FormatVariadic.h"

namespace imex {
/// Returns the supported chunk sizes for the given SIMD lane count, ordered
/// from largest to smallest.
///
/// \param simdlanes number of SIMD lanes of the access being checked.
/// \return {64, 32, 16, 8, 4, 3, 2, 1} when simdlanes == 1, otherwise
///         {8, 4, 3, 2, 1}.
llvm::SmallVector<int> getSupportedChunkSizes(int simdlanes) {
  const bool singleLane = (simdlanes == 1);

  // Any lane count other than one is limited to the small chunk sizes.
  if (!singleLane)
    return {8, 4, 3, 2, 1};

  // A single lane additionally supports the wide chunk sizes 16, 32 and 64.
  return {64, 32, 16, 8, 4, 3, 2, 1};
}

int getOperandIndex(mlir::Operation *op, mlir::Value operand) {
for (auto [i, value] : llvm::enumerate(op->getOperands())) {
if (operand == value)
Expand Down

0 comments on commit d40c423

Please sign in to comment.