Skip to content

Commit

Permalink
Align 'linalg-to-xegpu' pass with patched XeGPU dialect (#201)
Browse files Browse the repository at this point in the history
This PR updates the linalg-to-xegpu pass to make it compatible with the xegpu-to-vc-func pass from IMEX.

The PR also adds a simple e2e test for the linalg->xegpu->gpu exe pipeline.

---------

Signed-off-by: dchigarev <[email protected]>
  • Loading branch information
dchigarev authored Aug 7, 2024
1 parent a58150d commit dd1a80d
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 23 deletions.
2 changes: 1 addition & 1 deletion cmake/imex.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ if (NOT DEFINED IMEX_INCLUDES)

# TODO: Change to main https://github.com/intel/mlir-extensions when all the
# required functionality is merged.
gc_fetch_content(imex 496b240093b5e132b60c5ee69878300fe69be300 https://github.com/Menooker/mlir-extensions
gc_fetch_content(imex d5bbd635dee500b8cff138686833bacfac5ade78 https://github.com/Menooker/mlir-extensions
SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0
)

Expand Down
2 changes: 1 addition & 1 deletion include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
"DPAS register block sizes MxNxK">,
];
}
#endif
#endif // GC_USE_IMEX

def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
"func::FuncOp"> {
Expand Down
2 changes: 1 addition & 1 deletion lib/gc/ExecutionEngine/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ else()
endif()

set(GC_PASSES GcInterface GcPasses)
if(GC_UNABLE_GPU)
if(GC_ENABLE_IMEX)
list(APPEND GC_PASSES GcGpuPasses)
endif()

Expand Down
34 changes: 29 additions & 5 deletions lib/gc/ExecutionEngine/OpenCLRuntime/OpenCLRuntimeWrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,35 @@ template <typename T> size_t countUntil(T *ptr, T &&elem) {
} // namespace

// Returns an OpenCL device ID for a device of the type pointed to by `devtype`.
//
// NOTE(review): this hunk is rendered without +/- markers. The file header
// reports 29 additions and 5 deletions, which matches the five statements
// directly below the signature (presumably the removed implementation)
// followed by the 29 lines after them (presumably the added implementation).
// Confirm against the real diff before reading this as a single function body.
static cl_device_id getDevice(cl_device_type *devtype) {
// --- Presumably the removed implementation: asks for a single platform and
// --- then for a single device of type `*devtype` on that platform.
cl_platform_id platform; // OpenCL platform
cl_device_id device; // device ID
CL_SAFE_CALL(clGetPlatformIDs(1, &platform, NULL));
CL_SAFE_CALL(clGetDeviceIDs(platform, *devtype, 1, &device, NULL));
return device;
// --- Presumably the added implementation: enumerates every platform and
// --- returns the first device of type `*devtype` that any of them reports.
cl_uint numPlatforms;
// NOTE(review): no trailing semicolon after this CL_SAFE_CALL; whether that
// is valid depends on the macro definition, which is not visible here.
CL_SAFE_CALL(clGetPlatformIDs(0, nullptr, &numPlatforms)) // get num platforms

// Second call fills the vector with `numPlatforms` platform IDs.
std::vector<cl_platform_id> platforms(numPlatforms);
CL_SAFE_CALL(clGetPlatformIDs(numPlatforms, platforms.data(),
nullptr)); // get available platforms

for (cl_uint i = 0; i < numPlatforms; ++i) {
// Query how many devices of the requested type this platform has.
cl_uint numDevices;
cl_int status =
clGetDeviceIDs(platforms[i], *devtype, 0, /*devices.data()=*/nullptr,
&numDevices); // get num devices with 'devtype'
if (status != CL_SUCCESS) {
if (status == CL_DEVICE_NOT_FOUND) {
continue; // No devices of the requested type on this platform; try the next
}
// Any other failure is reported on stderr and terminates the process.
fprintf(stderr, "CL error %d @ line=%d (%s)\n", status, __LINE__,
"Error getting device IDs");
abort();
}

// NOTE(review): the status returned by this second clGetDeviceIDs call is
// not checked, and devices[0] is read without verifying numDevices > 0.
std::vector<cl_device_id> devices(numDevices);
clGetDeviceIDs(platforms[i], *devtype, numDevices, devices.data(), nullptr);
return devices[0];
}

// Reached only when no platform yielded a device of the requested type.
// NOTE(review): the message has no trailing newline.
fprintf(stderr, "No suitable devices found.");
abort();
}

struct GPUCLQUEUE {
Expand Down
39 changes: 28 additions & 11 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
Location loc, ValueRange tiles,
ArrayRef<int64_t> offsets) {
SmallVector<Value> updatedTiles;
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
std::vector<Value> dynOffsets;
for (auto &x : offsets) {
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x);
dynOffsets.push_back(offset);
}
ValueRange newOffsets{dynOffsets};
for (auto tile : tiles) {
auto updatedTile =
rewriter
.create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile,
/*offsets=*/ValueRange{}, offsets)
.getResult();
auto updatedTile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, tile.getType(), tile,
/*offsets=*/newOffsets,
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
updatedTiles.push_back(updatedTile);
}

Expand Down Expand Up @@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,

SmallVector<Value> tiles;
for (int i = 0; i < loadShape[0]; i += descTile[0]) {
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
auto tile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, descType, rootTile,
/*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j})
/*offsets=*/ValueRange{newRowOffs, newColOffs},
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
tiles.push_back(tile);
}
Expand Down Expand Up @@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,

VectorType vecLoadType =
VectorType::get(tileType.getShape(), tileType.getElementType());
UnitAttr vnniAxisAttr = nullptr;
mlir::UnitAttr packedAttr = nullptr;
if (vnniConf) {
vnniAxisAttr = UnitAttr::get(rewriter.getContext());
vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
*vnniConf);
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
}

IntegerAttr transpose_bit = nullptr;
SmallVector<Value> loadVec;
for (auto tile : loadTiles) {

auto loadOp = rewriter.create<xegpu::LoadNdOp>(
loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr,
loc, vecLoadType, tile, packedAttr, transpose, transpose_bit,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
loadVec.push_back(loadOp);
Expand Down Expand Up @@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

// Load A sub-tiles.
SmallVector<Value> loadVecA =
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

// Load B sub-tiles.
Expand Down
8 changes: 4 additions & 4 deletions test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Create output initial value load tiles.
// CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
// CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]]

// Load initial accumulator values.
Expand All @@ -31,9 +31,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Create input load tiles.
// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
// CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]]

// Create DPAS computation loop over tiled reduction dimension.
Expand Down Expand Up @@ -63,7 +63,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Extract DPAS-sized chunks from larger loaded tile A.
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16>
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK-COUNT-3: vector.extract_strided_slice
Expand Down
Loading

0 comments on commit dd1a80d

Please sign in to comment.