Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XeGPU][Transforms] Improve OptimizeTranspose pass to handle array_length > 1 case for B transpose. #965

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/imex/Utils/XeArch.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace imex {
// Closed integer interval [min, max], used to describe supported
// hardware parameter ranges (e.g. DPAS dimensions).
struct Range {
  int min;
  int max;
  // True when val lies inside the closed interval [min, max].
  // const-qualified: membership queries never modify the range, and this
  // makes contains() callable on const Range objects.
  bool contains(int val) const { return val >= min && val <= max; }
};

// DPAS m x n x k
Expand Down
82 changes: 8 additions & 74 deletions include/imex/Utils/XeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ mlir::TypedValue<mlir::VectorType> stack(mlir::Value vecUp, mlir::Value vecDown,
bool isSupportedModule(mlir::gpu::GPUModuleOp mod);

int getOperandIndex(mlir::Operation *op, mlir::Value operand);

// Obtain the index of the result in the operation. If the result is not found,
// return -1.
int getResultIndex(mlir::Operation *op, mlir::Value result);

mlir::BlockArgument getArgForOperand(mlir::scf::ForOp &op, mlir::Value operand);

mlir::ValueRange buildUnrealizedCast(mlir::OpBuilder &builder,
Expand Down Expand Up @@ -90,20 +95,15 @@ class TileUsageAnalysis {
op->walk<mlir::WalkOrder::PreOrder>([&](imex::xetile::LoadTileOp op) {
Usage[op] = (uint)UsageType::None;
llvm::SmallVector<mlir::Value> q({op});
bool transposeBeforeDPAS = false;
while (q.size()) {
auto curr = q.pop_back_val();
for (mlir::Operation *user : curr.getUsers()) {
if (auto mma = llvm::dyn_cast_if_present<xetile::TileMMAOp>(user)) {
auto idx = getOperandIndex(mma, curr);
if (idx == 0)
Usage[op] |= transposeBeforeDPAS
? (uint)UsageType::TRANSPOSE_DPAS_A
: (uint)UsageType::DPAS_A;
Usage[op] |= (uint)UsageType::DPAS_A;
else if (idx == 1)
Usage[op] |= transposeBeforeDPAS
? (uint)UsageType::TRANSPOSE_DPAS_B
: (uint)UsageType::DPAS_B;
Usage[op] |= (uint)UsageType::DPAS_B;
else if (idx == 2)
Usage[op] |= (uint)UsageType::DPAS_C;
else
Expand All @@ -115,14 +115,6 @@ class TileUsageAnalysis {
} else if (auto pack =
llvm::dyn_cast_if_present<xetile::TilePackOp>(user)) {
q.push_back(pack);
} else if (auto transpose =
llvm::dyn_cast_if_present<xetile::TransposeOp>(user)) {
// Transpose op is found in between LoadTileOp and TileMMAOp. This
// info is needed for downstream optimizations.
transposeBeforeDPAS = true;
q.push_back(transpose);
} else if (mlir::OpTrait::hasElementwiseMappableTraits(user)) {
q.push_back(user->getResult(0));
}
}
}
Expand All @@ -143,20 +135,6 @@ class TileUsageAnalysis {
return false;
}

// Returns true when this load was classified (by the usage walk) as
// feeding DPAS operand A through a transpose, i.e. its Usage entry has
// the TRANSPOSE_DPAS_A bit set. Unclassified loads report false.
bool isForTransposeDPASA(imex::xetile::LoadTileOp op) {
  // Single find() instead of count()+operator[]: avoids hashing the key
  // twice and cannot accidentally default-insert an entry.
  auto it = Usage.find(op);
  return it != Usage.end() && (it->second & UsageType::TRANSPOSE_DPAS_A);
}

// Returns true when this load was classified (by the usage walk) as
// feeding DPAS operand B through a transpose, i.e. its Usage entry has
// the TRANSPOSE_DPAS_B bit set. Unclassified loads report false.
bool isForTransposeDPASB(imex::xetile::LoadTileOp op) {
  // Single find() instead of count()+operator[]: avoids hashing the key
  // twice and cannot accidentally default-insert an entry.
  auto it = Usage.find(op);
  return it != Usage.end() && (it->second & UsageType::TRANSPOSE_DPAS_B);
}

bool isForDPASC(imex::xetile::LoadTileOp op) {
if (Usage.count(op)) {
return Usage[op] & UsageType::DPAS_C;
Expand Down Expand Up @@ -224,9 +202,7 @@ class TileUsageAnalysis {
DPAS_A = 8,
DPAS_B = 16,
DPAS_C = 32,
OTHER = 64,
TRANSPOSE_DPAS_A = 128,
TRANSPOSE_DPAS_B = 256
OTHER = 64
};

llvm::DenseMap<mlir::Operation *, uint> Usage;
Expand Down Expand Up @@ -526,18 +502,6 @@ class XeConversionPattern : public mlir::RewritePattern {
return llvm::cast<TileUsageAnalysis>(analysis).isForDPASB(op);
}

// Convenience forwarder to TileUsageAnalysis::isForTransposeDPASA.
// The defaulted enable_if template argument is SFINAE: this member is
// only usable when the pattern was instantiated with
// AnalysisT == TileUsageAnalysis.
template <typename = typename std::enable_if<
              std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForTransposeDPASA(imex::xetile::LoadTileOp op) const {
  return llvm::cast<TileUsageAnalysis>(analysis).isForTransposeDPASA(op);
}

// Convenience forwarder to TileUsageAnalysis::isForTransposeDPASB.
// The defaulted enable_if template argument is SFINAE: this member is
// only usable when the pattern was instantiated with
// AnalysisT == TileUsageAnalysis.
template <typename = typename std::enable_if<
              std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForTransposeDPASB(imex::xetile::LoadTileOp op) const {
  return llvm::cast<TileUsageAnalysis>(analysis).isForTransposeDPASB(op);
}

template <typename = typename std::enable_if<
std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForDPASC(imex::xetile::LoadTileOp op) const {
Expand Down Expand Up @@ -573,36 +537,6 @@ class XeConversionPattern : public mlir::RewritePattern {
bool isForLoadAndStore(imex::xetile::InitTileOp op) const {
return llvm::cast<TileUsageAnalysis>(analysis).isForLoadAndStore(op);
}

// Returns true when the tile created by this InitTileOp is consumed by
// exactly one LoadTileOp whose result reaches DPAS operand B through a
// transpose. Only available when AnalysisT == TileUsageAnalysis (SFINAE
// via the defaulted template argument).
template <typename = typename std::enable_if<
              std::is_same_v<AnalysisT, TileUsageAnalysis>>>
bool isForLoadTransposeDPASB(imex::xetile::InitTileOp op) const {
  // Only meaningful for tiles that are actually loaded.
  if (!isForLoad(op))
    return false;
  // Walk the InitTileOp and collect all LoadTileOps fed by it, following
  // scf.for iter-args so loads inside loop bodies are found too.
  llvm::SmallVector<mlir::Operation *> loadOps;
  op->walk<mlir::WalkOrder::PreOrder>([&](imex::xetile::InitTileOp op) {
    llvm::SmallVector<mlir::Value> q({op});
    while (q.size()) {
      auto curr = q.pop_back_val();
      for (mlir::Operation *user : curr.getUsers()) {
        if (llvm::isa<imex::xetile::LoadTileOp>(user)) {
          loadOps.push_back(user);
        } else if (auto forOp =
                       llvm::dyn_cast_if_present<mlir::scf::ForOp>(user)) {
          // The tile is threaded through the loop; continue the search
          // from the corresponding block argument.
          auto arg = getArgForOperand(forOp, curr);
          q.push_back(arg);
        }
      }
    }
  });
  // Guard: the walk only follows LoadTileOp and scf.for users, so it can
  // come back empty even though isForLoad() was true. Report false
  // conservatively instead of indexing into an empty vector (UB).
  if (loadOps.empty())
    return false;
  // If more than one loadOp, return false. TODO : Handle this case
  if (loadOps.size() > 1)
    return false;
  auto loadOp = llvm::dyn_cast<imex::xetile::LoadTileOp>(loadOps[0]);
  // Finally check if the loadOp is propagated to transpose op and DPAS B
  return isForTransposeDPASB(loadOp);
}
};

/// Clone `shape` with the last two elements swapped.
Expand Down
5 changes: 0 additions & 5 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,11 +433,6 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
? getBlockArrayLength(op, elemTy, innerBlk[0],
innerBlk[1], shape[1])
: 1;
// If this tile is used in load -> transpose -> DPASB chain, optimize
// transpose optimization requires array_length to be 1.
if (isForLoadTransposeDPASB(op))
array_length = 1;

auto width = array_length * innerBlk[1];

llvm::SmallVector<int64_t, 2> blocks(
Expand Down
66 changes: 43 additions & 23 deletions lib/Transforms/HoistTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,13 @@ struct HoistTransposeBeforeExtractStridedSliceOpPattern
transposeOp.getVector().getDefiningOp());
if (!extractOp)
return mlir::failure();
// Source of extract op must be a load op.
auto loadOp = llvm::dyn_cast<mlir::xegpu::LoadNdOp>(
extractOp.getVector().getDefiningOp());
if (!loadOp)
auto sourceOfExtract = extractOp.getVector().getDefiningOp();
if (!sourceOfExtract)
return mlir::failure();
// Check if the load is already transposed by previous application of this
// Check if the source is already transposed by previous application of this
// pattern.
mlir::vector::TransposeOp transposedLoad = nullptr;
for (auto user : loadOp->getUsers()) {
for (auto user : sourceOfExtract->getUsers()) {
if (auto transposeUser =
llvm::dyn_cast<mlir::vector::TransposeOp>(user)) {
transposedLoad = transposeUser;
Expand All @@ -98,7 +96,7 @@ struct HoistTransposeBeforeExtractStridedSliceOpPattern
// If not found, create a new transpose op.
if (!transposedLoad)
transposedLoad = rewriter.create<mlir::vector::TransposeOp>(
transposeOp.getLoc(), loadOp.getResult(),
transposeOp.getLoc(), sourceOfExtract->getResult(0),
llvm::ArrayRef<int64_t>({1, 0}));
// Extract the required slice from the transposed load and replace the
// original transpose op with it.
Expand All @@ -124,25 +122,47 @@ struct HoistTransposePass final
mlir::Operation *op = getOperation();
llvm::SmallDenseSet<mlir::vector::TransposeOp> transposeOps;

// Visit ExtractStridedSliceOp and check if it is followed by a TransposeOp.
auto visitExtractStridedSliceOp =
[&](mlir::vector::ExtractStridedSliceOp extractStridedSliceOp)
-> mlir::vector::TransposeOp {
// If extract op has more than one user, skip.
if (!extractStridedSliceOp->hasOneUse())
return nullptr;
// If the user is not a transpose op, skip.
auto transposeOp = llvm::dyn_cast_if_present<mlir::vector::TransposeOp>(
*extractStridedSliceOp->user_begin());
if (!(transposeOp &&
transposeOp.getPermutation() == llvm::ArrayRef<int64_t>({1, 0})))
return nullptr;
return transposeOp;
};

op->walk([&](mlir::xegpu::LoadNdOp loadOp) -> mlir::WalkResult {
// Check all users of the load op are ExtractStridedSliceOp followed by a
// TransposeOp.
// Check all users of the load op are,
// 1. ExtractStridedSliceOp -> TransposeOp chain
// 2. ExtractOp -> ExtractStridedSliceOp -> TransposeOp chain
for (auto user : loadOp->getUsers()) {
auto extractOp =
llvm::dyn_cast_if_present<mlir::vector::ExtractStridedSliceOp>(
user);
if (!extractOp)
return mlir::WalkResult::skip();
// If extract op has more than one user, skip.
if (!extractOp->hasOneUse())
return mlir::WalkResult::skip();
// If the user is not a transpose op, skip.
auto transposeOp = llvm::dyn_cast_if_present<mlir::vector::TransposeOp>(
*extractOp->user_begin());
if (!(transposeOp &&
transposeOp.getPermutation() == llvm::ArrayRef<int64_t>({1, 0})))
if (auto extractStridedSliceOp =
llvm::dyn_cast_if_present<mlir::vector::ExtractStridedSliceOp>(
user)) {
auto found = visitExtractStridedSliceOp(extractStridedSliceOp);
if (found)
transposeOps.insert(found);
} else if (auto extractOp =
llvm::dyn_cast_if_present<mlir::vector::ExtractOp>(
user)) {
for (auto extractUser : extractOp->getUsers()) {
if (auto extractStridedSliceOp = llvm::dyn_cast_if_present<
mlir::vector::ExtractStridedSliceOp>(extractUser)) {
auto found = visitExtractStridedSliceOp(extractStridedSliceOp);
if (found)
transposeOps.insert(found);
}
}
} else {
return mlir::WalkResult::skip();
transposeOps.insert(transposeOp);
}
}
return mlir::WalkResult::advance();
});
Expand Down
Loading
Loading