Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aling 'linalg-to-xegpu' pass with patched XeGPU dialect #201

Merged
merged 13 commits into from
Aug 7, 2024
4 changes: 2 additions & 2 deletions cmake/imex.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ if (NOT DEFINED IMEX_INCLUDES)

# TODO: Change to main https://github.com/intel/mlir-extensions when all the
# required functionality is merged.
gc_fetch_content(imex 496b240093b5e132b60c5ee69878300fe69be300 https://github.com/Menooker/mlir-extensions
SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=0
gc_fetch_content(imex d5bbd635dee500b8cff138686833bacfac5ade78 https://github.com/Menooker/mlir-extensions
Copy link
Contributor Author

@dchigarev dchigarev Aug 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated to the latest commit in dev branch

SET IMEX_CHECK_LLVM_VERSION=ON IMEX_ENABLE_L0_RUNTIME=${IMEX_ENABLE_L0_RUNTIME}
)

set(IMEX_INCLUDES
Expand Down
2 changes: 1 addition & 1 deletion include/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ if(GC_ENABLE_DNNL)
list(APPEND TABLEGEN_MACROS -DGC_HAS_ONEDNN_DIALECT)
endif()
if(GC_ENABLE_IMEX)
list(APPEND TABLEGEN_MACROS -DGC_USE_IMEX)
list(APPEND TABLEGEN_MACROS -DGC_ENABLE_IMEX)
dchigarev marked this conversation as resolved.
Show resolved Hide resolved
endif()

set(LLVM_TARGET_DEFINITIONS Passes.td)
Expand Down
4 changes: 2 additions & 2 deletions include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
}
#endif

#ifdef GC_USE_IMEX
#ifdef GC_ENABLE_IMEX
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
let summary = "Convert linalg dialect to XeGPU dialect.";
let description = [{
Expand All @@ -59,6 +59,6 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
"DPAS register block sizes MxNxK">,
];
}
#endif
#endif // GC_ENABLE_IMEX

#endif // GC_DIALECT_GC_PASSES
2 changes: 1 addition & 1 deletion lib/gc/ExecutionEngine/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ else()
endif()

set(GC_PASSES GcInterface GcPasses)
if(GC_UNABLE_GPU)
if(GC_ENABLE_IMEX)
list(APPEND GC_PASSES GcGpuPasses)
endif()

Expand Down
39 changes: 28 additions & 11 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
Location loc, ValueRange tiles,
ArrayRef<int64_t> offsets) {
SmallVector<Value> updatedTiles;
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
std::vector<Value> dynOffsets;
for (auto &x : offsets) {
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x);
dynOffsets.push_back(offset);
}
ValueRange newOffsets{dynOffsets};
for (auto tile : tiles) {
auto updatedTile =
rewriter
.create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile,
/*offsets=*/ValueRange{}, offsets)
.getResult();
auto updatedTile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, tile.getType(), tile,
/*offsets=*/newOffsets,
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
updatedTiles.push_back(updatedTile);
}

Expand Down Expand Up @@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,

SmallVector<Value> tiles;
for (int i = 0; i < loadShape[0]; i += descTile[0]) {
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
auto tile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, descType, rootTile,
/*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j})
/*offsets=*/ValueRange{newRowOffs, newColOffs},
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
tiles.push_back(tile);
}
Expand Down Expand Up @@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,

VectorType vecLoadType =
VectorType::get(tileType.getShape(), tileType.getElementType());
UnitAttr vnniAxisAttr = nullptr;
mlir::UnitAttr packedAttr = nullptr;
if (vnniConf) {
vnniAxisAttr = UnitAttr::get(rewriter.getContext());
vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
*vnniConf);
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
}

IntegerAttr transpose_bit = nullptr;
SmallVector<Value> loadVec;
for (auto tile : loadTiles) {

auto loadOp = rewriter.create<xegpu::LoadNdOp>(
loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr,
loc, vecLoadType, tile, packedAttr, transpose, transpose_bit,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
loadVec.push_back(loadOp);
Expand Down Expand Up @@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

// Load A sub-tiles.
SmallVector<Value> loadVecA =
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vnniConfA can't be used during loading since vnniAxis=1 is now longer supported. However we still need this config to compute proper tiles for xegpu.dpas later in the code.

loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

// Load B sub-tiles.
Expand Down
2 changes: 1 addition & 1 deletion src/gc-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ target_link_libraries(gc-opt PRIVATE

if(GC_ENABLE_IMEX)
include(imex)
target_compile_options(gc-opt PRIVATE -DGC_USE_IMEX)
target_compile_options(gc-opt PRIVATE -DGC_ENABLE_IMEX)
get_property(IMEX_INCLUDES GLOBAL PROPERTY IMEX_INCLUDES)
target_include_directories(gc-opt PRIVATE ${IMEX_INCLUDES})
target_link_libraries(gc-opt PRIVATE
Expand Down
6 changes: 3 additions & 3 deletions src/gc-opt/gc-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

#ifdef GC_USE_IMEX
#ifdef GC_ENABLE_IMEX
#include <imex/InitIMEXDialects.h>
#include <imex/InitIMEXPasses.h>
#endif
Expand All @@ -38,7 +38,7 @@ void registerCPUPipeline();
} // namespace mlir::gc

int main(int argc, char *argv[]) {
#ifdef GC_USE_IMEX
#ifdef GC_ENABLE_IMEX
imex::registerTransformsPasses();
// Conversion passes
imex::registerConvertGPUToGPUX();
Expand All @@ -59,7 +59,7 @@ int main(int argc, char *argv[]) {
registry.insert<mlir::linalgx::LinalgxDialect>();
registry.insert<mlir::microkernel::MicrokernelDialect>();
mlir::registerAllDialects(registry);
#ifdef GC_USE_IMEX
#ifdef GC_ENABLE_IMEX
registry.insert<::imex::xetile::XeTileDialect, ::imex::gpux::GPUXDialect>();
#endif
mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry);
Expand Down
3 changes: 3 additions & 0 deletions test/mlir/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ set(GC_OPT_TEST_DEPENDS

if(GC_ENABLE_IMEX)
include(imex)
if (IMEX_ENABLE_L0_RUNTIME)
list(APPEND GC_OPT_TEST_DEPENDS level-zero-runtime)
endif()
list(APPEND GC_OPT_TEST_DEPENDS GcOpenclRuntime)
endif()

Expand Down
8 changes: 4 additions & 4 deletions test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Create output initial value load tiles.
// CHECK: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [0, 0]
// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imex doesn't support constant offsets (see intel/mlir-extensions#815)

// CHECK-COUNT-7: xegpu.update_nd_offset %[[rootC]]

// Load initial accumulator values.
Expand All @@ -31,9 +31,9 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Create input load tiles.
// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [0, 0]
// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [0, 0]
// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
// CHECK-COUNT-1: xegpu.update_nd_offset %[[rootB]]

// Create DPAS computation loop over tiled reduction dimension.
Expand Down Expand Up @@ -63,7 +63,7 @@ func.func @matmul(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: mem

// Extract DPAS-sized chunks from larger loaded tile A.
// Tile B is already in the correct shape.
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x8x2xf16> to vector<512xf16>
// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do not load the A matrix via vnni_axis=1 anymore (see packed_attr)

// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
// CHECK-COUNT-3: vector.extract_strided_slice
Expand Down
2 changes: 1 addition & 1 deletion test/mlir/test/gc/Transforms/GPU/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
if not config.gc_use_imex:
if not config.gc_enable_imex:
config.unsupported = True
Loading