Align 'linalg-to-xegpu' pass with patched XeGPU dialect #201
@@ -129,11 +129,35 @@ template <typename T> size_t countUntil(T *ptr, T &&elem) {
 } // namespace

 static cl_device_id getDevice(cl_device_type *devtype) {
-  cl_platform_id platform; // OpenCL platform
-  cl_device_id device;     // device ID
-  CL_SAFE_CALL(clGetPlatformIDs(1, &platform, NULL));
-  CL_SAFE_CALL(clGetDeviceIDs(platform, *devtype, 1, &device, NULL));
-  return device;
Review comment on lines -132 to -136: The old logic searched for a device of the requested type in only one platform (and couldn't find any GPU on my machine). Rewrote the logic to iterate over all available platforms and return the first suitable device.
+  cl_uint numPlatforms;
+  CL_SAFE_CALL(clGetPlatformIDs(0, nullptr, &numPlatforms)) // get num platforms
+
+  std::vector<cl_platform_id> platforms(numPlatforms);
+  CL_SAFE_CALL(clGetPlatformIDs(numPlatforms, platforms.data(),
+                                nullptr)); // get available platforms
+
+  for (cl_uint i = 0; i < numPlatforms; ++i) {
+    // Get GPU device IDs for each platform
+    cl_uint numDevices;
+    cl_int status =
+        clGetDeviceIDs(platforms[i], *devtype, 0, /*devices.data()=*/nullptr,
+                       &numDevices); // get num devices with 'devtype'
+    if (status != CL_SUCCESS) {
+      if (status == CL_DEVICE_NOT_FOUND) {
+        continue; // No GPU devices found on this platform
+      }
+      fprintf(stderr, "CL error %d @ line=%d (%s)\n", status, __LINE__,
+              "Error getting device IDs");
+      abort();
+    }
+
+    std::vector<cl_device_id> devices(numDevices);
+    clGetDeviceIDs(platforms[i], *devtype, numDevices, devices.data(), nullptr);
+    return devices[0];
+  }
+
+  fprintf(stderr, "No suitable devices found.");
+  abort();
 }

 struct GPUCLQUEUE {
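For reference, a minimal usage sketch of the rewritten helper (illustrative only, not part of the diff; it assumes the OpenCL headers and the CL_SAFE_CALL macro defined earlier in this file):

```cpp
// Illustrative only: ask for any GPU device. getDevice() now scans every
// available platform and aborts if none of them exposes a matching device.
cl_device_type devtype = CL_DEVICE_TYPE_GPU;
cl_device_id device = getDevice(&devtype);
```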
@@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
                                              Location loc, ValueRange tiles,
                                              ArrayRef<int64_t> offsets) {
   SmallVector<Value> updatedTiles;
+  // convert static offsets to dynamic because of this IMEX bug:
+  // https://github.com/intel/mlir-extensions/issues/815
+  std::vector<Value> dynOffsets;
+  for (auto &x : offsets) {
+    Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x);
+    dynOffsets.push_back(offset);
+  }
+  ValueRange newOffsets{dynOffsets};
   for (auto tile : tiles) {
-    auto updatedTile =
-        rewriter
-            .create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile,
-                                             /*offsets=*/ValueRange{}, offsets)
-            .getResult();
+    auto updatedTile = rewriter
+                           .create<xegpu::UpdateNdOffsetOp>(
+                               loc, tile.getType(), tile,
+                               /*offsets=*/newOffsets,
+                               SmallVector<int64_t>{ShapedType::kDynamic,
+                                                    ShapedType::kDynamic})
+                           .getResult();
     updatedTiles.push_back(updatedTile);
   }
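Both this hunk and the next one apply the same workaround, and the updated FileCheck patterns further down mirror it. As a sketch (a hypothetical helper, not part of the patch), the pattern amounts to materializing each compile-time offset as an SSA index constant and passing ShapedType::kDynamic placeholders as the static offsets, so IMEX never sees constant offsets (intel/mlir-extensions#815):

```cpp
// Hypothetical helper distilling the workaround: convert static offsets into
// SSA index constants so xegpu::UpdateNdOffsetOp only receives dynamic offsets.
static SmallVector<Value> makeDynamicOffsets(PatternRewriter &rewriter,
                                             Location loc,
                                             ArrayRef<int64_t> offsets) {
  SmallVector<Value> dynOffsets;
  for (int64_t off : offsets)
    dynOffsets.push_back(rewriter.create<arith::ConstantIndexOp>(loc, off));
  return dynOffsets;
}
// Callers would pass the returned values as the op's dynamic offsets and
// SmallVector<int64_t>(offsets.size(), ShapedType::kDynamic) as the static ones.
```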
@@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,

   SmallVector<Value> tiles;
   for (int i = 0; i < loadShape[0]; i += descTile[0]) {
+    // convert static offsets to dynamic because of this IMEX bug:
+    // https://github.com/intel/mlir-extensions/issues/815
+    Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
+      Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
-                          /*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j})
+                          /*offsets=*/ValueRange{newRowOffs, newColOffs},
+                          SmallVector<int64_t>{ShapedType::kDynamic,
+                                               ShapedType::kDynamic})
                       .getResult();
       tiles.push_back(tile);
     }
@@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,

   VectorType vecLoadType =
       VectorType::get(tileType.getShape(), tileType.getElementType());
-  UnitAttr vnniAxisAttr = nullptr;
+  mlir::UnitAttr packedAttr = nullptr;
   if (vnniConf) {
-    vnniAxisAttr = UnitAttr::get(rewriter.getContext());
     vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
                                 *vnniConf);
+    packedAttr = mlir::UnitAttr::get(rewriter.getContext());
   }

+  IntegerAttr transpose_bit = nullptr;
   SmallVector<Value> loadVec;
   for (auto tile : loadTiles) {

     auto loadOp = rewriter.create<xegpu::LoadNdOp>(
-        loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr,
+        loc, vecLoadType, tile, packedAttr, transpose, transpose_bit,
         /*l1_hint=*/hint,
         /*l2_hint=*/hint, /*l3_hint=*/hint);
     loadVec.push_back(loadOp);
@@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

   // Load A sub-tiles.
   SmallVector<Value> loadVecA =
-      loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
+      loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
   auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

   // Load B sub-tiles.
@@ -18,7 +18,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

 // Create output initial value load tiles.
 // CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
-// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0]
+// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
Review comment: IMEX doesn't support constant offsets (see intel/mlir-extensions#815).
 // CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]]

 // Load initial accumulator values.
@@ -31,9 +31,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

 // Create input load tiles.
 // CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
-// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0]
+// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
 // CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
-// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0]
+// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
 // CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]]

 // Create DPAS computation loop over tiled reduction dimension.
@@ -63,7 +63,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

 // Extract DPAS-sized chunks from larger loaded tile A.
 // Tile B is already in the correct shape.
-// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16>
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
Review comment: we do not load the […]
 // CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
 // CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
 // CHECK-COUNT-3: vector.extract_strided_slice
Review comment: updated to the latest commit in the dev branch.