-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Align 'linalg-to-xegpu' pass with patched XeGPU dialect #201
Changes from 7 commits
964398e
2778459
0f25517
05aa8d6
829b9d4
3660cdc
52eb013
48914ac
2f5561c
8184f5d
be7fdf0
a94205a
4cf3457
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter, | |
Location loc, ValueRange tiles, | ||
ArrayRef<int64_t> offsets) { | ||
SmallVector<Value> updatedTiles; | ||
// convert static offsets to dynamic because of this IMEX bug: | ||
// https://github.com/intel/mlir-extensions/issues/815 | ||
std::vector<Value> dynOffsets; | ||
for (auto &x : offsets) { | ||
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x); | ||
dynOffsets.push_back(offset); | ||
} | ||
ValueRange newOffsets{dynOffsets}; | ||
for (auto tile : tiles) { | ||
auto updatedTile = | ||
rewriter | ||
.create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile, | ||
/*offsets=*/ValueRange{}, offsets) | ||
.getResult(); | ||
auto updatedTile = rewriter | ||
.create<xegpu::UpdateNdOffsetOp>( | ||
loc, tile.getType(), tile, | ||
/*offsets=*/newOffsets, | ||
SmallVector<int64_t>{ShapedType::kDynamic, | ||
ShapedType::kDynamic}) | ||
.getResult(); | ||
updatedTiles.push_back(updatedTile); | ||
} | ||
|
||
|
@@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter, | |
|
||
SmallVector<Value> tiles; | ||
for (int i = 0; i < loadShape[0]; i += descTile[0]) { | ||
// convert static offsets to dynamic because of this IMEX bug: | ||
// https://github.com/intel/mlir-extensions/issues/815 | ||
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i); | ||
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) { | ||
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j); | ||
auto tile = rewriter | ||
.create<xegpu::UpdateNdOffsetOp>( | ||
loc, descType, rootTile, | ||
/*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j}) | ||
/*offsets=*/ValueRange{newRowOffs, newColOffs}, | ||
SmallVector<int64_t>{ShapedType::kDynamic, | ||
ShapedType::kDynamic}) | ||
.getResult(); | ||
tiles.push_back(tile); | ||
} | ||
|
@@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles, | |
|
||
VectorType vecLoadType = | ||
VectorType::get(tileType.getShape(), tileType.getElementType()); | ||
UnitAttr vnniAxisAttr = nullptr; | ||
mlir::UnitAttr packedAttr = nullptr; | ||
if (vnniConf) { | ||
vnniAxisAttr = UnitAttr::get(rewriter.getContext()); | ||
vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(), | ||
*vnniConf); | ||
packedAttr = mlir::UnitAttr::get(rewriter.getContext()); | ||
} | ||
|
||
IntegerAttr transpose_bit = nullptr; | ||
SmallVector<Value> loadVec; | ||
for (auto tile : loadTiles) { | ||
|
||
auto loadOp = rewriter.create<xegpu::LoadNdOp>( | ||
loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr, | ||
loc, vecLoadType, tile, packedAttr, transpose, transpose_bit, | ||
/*l1_hint=*/hint, | ||
/*l2_hint=*/hint, /*l3_hint=*/hint); | ||
loadVec.push_back(loadOp); | ||
|
@@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, | |
|
||
// Load A sub-tiles. | ||
SmallVector<Value> loadVecA = | ||
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint); | ||
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType()); | ||
|
||
// Load B sub-tiles. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem | |
|
||
// Create output initial value load tiles. | ||
// CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]] | ||
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0] | ||
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IMEX doesn't support constant offsets (see intel/mlir-extensions#815). |
||
// CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]] | ||
|
||
// Load initial accumulator values. | ||
|
@@ -31,9 +31,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem | |
|
||
// Create input load tiles. | ||
// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]] | ||
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0] | ||
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0] | ||
// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]] | ||
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0] | ||
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0] | ||
// CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]] | ||
|
||
// Create DPAS computation loop over tiled reduction dimension. | ||
|
@@ -63,7 +63,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem | |
|
||
// Extract DPAS-sized chunks from larger loaded tile A. | ||
// Tile B is already in the correct shape. | ||
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16> | ||
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16> | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. we do not load the |
||
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16> | ||
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16> | ||
// CHECK-COUNT-3: vector.extract_strided_slice | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
if not config.gc_use_imex: | ||
if not config.gc_enable_imex: | ||
config.unsupported = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to the latest commit in the dev branch.