Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Garra1980 committed Nov 22, 2024
1 parent ca096a1 commit 0a4d133
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
23 changes: 13 additions & 10 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
};

static mlir::xegpu::CachePolicy
translateCachePolicy(imex::xetile::CachePolicyAttr val, mlir::xegpu::CachePolicy defaultVal) {
translateCachePolicy(imex::xetile::CachePolicyAttr val,
mlir::xegpu::CachePolicy defaultVal) {
if (!val)
return defaultVal;

Expand All @@ -522,13 +523,13 @@ translateCachePolicy(imex::xetile::CachePolicyAttr val, mlir::xegpu::CachePolicy
llvm_unreachable("Invalid CachePolicy value");
}

template<typename OpTy>
static auto
getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal = mlir::xegpu::CachePolicy::CACHED) {
template <typename OpTy>
static auto getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal =
mlir::xegpu::CachePolicy::CACHED) {

auto getCachePolicyAttr = [&](imex::xetile::CachePolicyAttr val) {
return mlir::xegpu::CachePolicyAttr::get(op.getContext(),
translateCachePolicy(val, defaultVal));
return mlir::xegpu::CachePolicyAttr::get(
op.getContext(), translateCachePolicy(val, defaultVal));
};

auto L1 = getCachePolicyAttr(op.getL1HintAttr());
Expand Down Expand Up @@ -689,7 +690,8 @@ struct SgStoreTileOpPattern : public XeOneToNConversion<xetile::StoreTileOp> {
<< "values: " << values.size() << "\n";
}

auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);
auto [L1, L2, L3] =
getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);

for (size_t i = 0; i < tiles.size(); i++)
rewriter.create<mlir::xegpu::StoreNdOp>(op.getLoc(), values[i], tiles[i],
Expand Down Expand Up @@ -721,12 +723,13 @@ struct SgStoreScatterOpPattern
auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1],
rewriter.getIntegerType(1));
auto transposeAttr = mlir::UnitAttr();
auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);
auto [L1, L2, L3] =
getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK);
for (auto [v, t, m] : llvm::zip(values, tdescs, masks)) {
m = rewriter.create<ShapeCastOp>(op.getLoc(), maskTy, m);
v = rewriter.create<ShapeCastOp>(op.getLoc(), vecTy, v);
rewriter.create<mlir::xegpu::StoreScatterOp>(
op.getLoc(), v, t, m, transposeAttr, L1, L2, L3);
rewriter.create<mlir::xegpu::StoreScatterOp>(op.getLoc(), v, t, m,
transposeAttr, L1, L2, L3);
}
rewriter.eraseOp(op);
return mlir::success();
Expand Down
15 changes: 10 additions & 5 deletions lib/Dialect/XeTile/Transforms/Blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ struct LoadTileOpPattern
tileTy.getElementType());
mlir::Value newOp = rewriter.create<xetile::LoadTileOp>(
op.getLoc(), vecTy, adaptor.getSource(),
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
newOp = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, newOp);
return mlir::success();
Expand Down Expand Up @@ -324,7 +325,8 @@ struct LoadGatherOpPattern
auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
mlir::Value newOp = rewriter.create<xetile::LoadGatherOp>(
op.getLoc(), vecTy, source, mask,
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
newOp = addUnpackOp(newOp, rewriter);
rewriter.replaceOp(op, newOp);
return mlir::success();
Expand Down Expand Up @@ -352,7 +354,9 @@ struct StoreTileOpPattern
// its inputs has not been updated yet.
if (blockSize && valTy.getRank() == 2) {
value = addPackOp(value, blockSize.asArrayRef(), rewriter);
rewriter.replaceOpWithNewOp<xetile::StoreTileOp>(op, value, tile, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
rewriter.replaceOpWithNewOp<xetile::StoreTileOp>(
op, value, tile, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
return mlir::success();
}
return mlir::failure();
Expand Down Expand Up @@ -381,8 +385,9 @@ struct StoreScatterOpPattern
value = addPackOp(value, blockSize.asArrayRef(), rewriter);
auto mask =
addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter);
rewriter.replaceOpWithNewOp<xetile::StoreScatterOp>(op, value, tile,
mask, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
rewriter.replaceOpWithNewOp<xetile::StoreScatterOp>(
op, value, tile, mask, op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
return mlir::success();
}
return mlir::failure();
Expand Down

0 comments on commit 0a4d133

Please sign in to comment.