diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp index cd4153f66..2dd8b3293 100644 --- a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -501,7 +501,8 @@ class SgInitTileOpPattern : public XeOneToNConversion { }; static mlir::xegpu::CachePolicy -translateCachePolicy(imex::xetile::CachePolicyAttr val, mlir::xegpu::CachePolicy defaultVal) { +translateCachePolicy(imex::xetile::CachePolicyAttr val, + mlir::xegpu::CachePolicy defaultVal) { if (!val) return defaultVal; @@ -522,13 +523,13 @@ translateCachePolicy(imex::xetile::CachePolicyAttr val, mlir::xegpu::CachePolicy llvm_unreachable("Invalid CachePolicy value"); } -template -static auto -getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal = mlir::xegpu::CachePolicy::CACHED) { +template +static auto getCachePolicy(OpTy op, mlir::xegpu::CachePolicy defaultVal = + mlir::xegpu::CachePolicy::CACHED) { auto getCachePolicyAttr = [&](imex::xetile::CachePolicyAttr val) { - return mlir::xegpu::CachePolicyAttr::get(op.getContext(), - translateCachePolicy(val, defaultVal)); + return mlir::xegpu::CachePolicyAttr::get( + op.getContext(), translateCachePolicy(val, defaultVal)); }; auto L1 = getCachePolicyAttr(op.getL1HintAttr()); @@ -689,7 +690,8 @@ struct SgStoreTileOpPattern : public XeOneToNConversion { << "values: " << values.size() << "\n"; } - auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); + auto [L1, L2, L3] = + getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); for (size_t i = 0; i < tiles.size(); i++) rewriter.create(op.getLoc(), values[i], tiles[i], @@ -721,12 +723,13 @@ struct SgStoreScatterOpPattern auto maskTy = mlir::VectorType::get(innerBlk[0] * innerBlk[1], rewriter.getIntegerType(1)); auto transposeAttr = mlir::UnitAttr(); - auto [L1, L2, L3] = getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); + auto [L1, L2, L3] = + getCachePolicy(op, mlir::xegpu::CachePolicy::WRITE_BACK); for (auto [v, t, m] : llvm::zip(values, tdescs, masks)) { m = rewriter.create(op.getLoc(), maskTy, m); v = rewriter.create(op.getLoc(), vecTy, v); - rewriter.create( - op.getLoc(), v, t, m, transposeAttr, L1, L2, L3); + rewriter.create(op.getLoc(), v, t, m, + transposeAttr, L1, L2, L3); } rewriter.eraseOp(op); return mlir::success(); diff --git a/lib/Dialect/XeTile/Transforms/Blocking.cpp b/lib/Dialect/XeTile/Transforms/Blocking.cpp index 24e9e0a3e..3cc88910b 100644 --- a/lib/Dialect/XeTile/Transforms/Blocking.cpp +++ b/lib/Dialect/XeTile/Transforms/Blocking.cpp @@ -290,7 +290,8 @@ struct LoadTileOpPattern tileTy.getElementType()); mlir::Value newOp = rewriter.create( op.getLoc(), vecTy, adaptor.getSource(), - op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); newOp = addUnpackOp(newOp, rewriter); rewriter.replaceOp(op, newOp); return mlir::success(); @@ -324,7 +325,8 @@ struct LoadGatherOpPattern auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter); mlir::Value newOp = rewriter.create( op.getLoc(), vecTy, source, mask, - op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + op.getPadding().value_or(mlir::Attribute()), op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); newOp = addUnpackOp(newOp, rewriter); rewriter.replaceOp(op, newOp); return mlir::success(); @@ -352,7 +354,9 @@ struct StoreTileOpPattern // its inputs has not been updated yet. if (blockSize && valTy.getRank() == 2) { value = addPackOp(value, blockSize.asArrayRef(), rewriter); - rewriter.replaceOpWithNewOp(op, value, tile, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + rewriter.replaceOpWithNewOp( + op, value, tile, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); return mlir::success(); } return mlir::failure(); @@ -381,8 +385,9 @@ struct StoreScatterOpPattern value = addPackOp(value, blockSize.asArrayRef(), rewriter); auto mask = addPackOp(adaptor.getMask(), blockSize.asArrayRef(), rewriter); - rewriter.replaceOpWithNewOp(op, value, tile, - mask, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + rewriter.replaceOpWithNewOp( + op, value, tile, mask, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); return mlir::success(); } return mlir::failure();