From 3b9094dc85e24fd0c954984854821fade65c91a9 Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Sun, 1 Dec 2024 14:25:56 +0000 Subject: [PATCH 1/3] add initial conversions for xevm block ops --- include/gc/Conversion/Passes.h | 2 +- include/gc/Dialect/LLVMIR/XeVMOps.td | 55 +++ lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 278 ++++++++++++- lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp | 31 +- .../LLVMIR/XeVM/XeVMToLLVMIRTranslation.cpp | 49 +++ src/gc-opt/gc-opt.cpp | 3 + .../gc/Conversion/XeVMToLLVM/blockload2d.mlir | 381 ++++++++++++++++++ .../XeVMToLLVM/blockprefetch2d.mlir | 200 +++++++++ .../Conversion/XeVMToLLVM/blockstore2d.mlir | 42 ++ 9 files changed, 1033 insertions(+), 8 deletions(-) create mode 100644 test/mlir/test/gc/Conversion/XeVMToLLVM/blockload2d.mlir create mode 100644 test/mlir/test/gc/Conversion/XeVMToLLVM/blockprefetch2d.mlir create mode 100644 test/mlir/test/gc/Conversion/XeVMToLLVM/blockstore2d.mlir diff --git a/include/gc/Conversion/Passes.h b/include/gc/Conversion/Passes.h index d589d7d9..a2e1b4a2 100644 --- a/include/gc/Conversion/Passes.h +++ b/include/gc/Conversion/Passes.h @@ -9,7 +9,7 @@ #ifndef GC_CONVERSION_PASSES_H #define GC_CONVERSION_PASSES_H -#include "gc/Conversion/XeVMToLLVM.h" +#include "gc/Conversion/XeVMToLLVM/XeVMToLLVM.h" namespace mlir { diff --git a/include/gc/Dialect/LLVMIR/XeVMOps.td b/include/gc/Dialect/LLVMIR/XeVMOps.td index efc33728..db9ff9cf 100644 --- a/include/gc/Dialect/LLVMIR/XeVMOps.td +++ b/include/gc/Dialect/LLVMIR/XeVMOps.td @@ -19,6 +19,15 @@ def XeVM_Dialect : Dialect { let name = "xevm"; let cppNamespace = "::mlir::xevm"; let dependentDialects = ["LLVM::LLVMDialect"]; + + let extraClassDeclaration = [{ + /// Get the name for the attribute used to specify cache control + /// decorations. + static constexpr ::llvm::StringRef getCacheControlsAttrName() { + return ::llvm::StringLiteral("xevm.DecorationCacheControlINTEL"); + } + }]; + let useDefaultAttributePrinterParser = 1; } @@ -161,6 +170,52 @@ def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">, let hasVerifier = 1; } +def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">, + Arguments<(ins + Arg:$ptr, + I32:$base_width, + I32:$base_height, + I32:$base_pitch, + I32:$x, + I32:$y, + I32Attr:$elem_size_in_bits, + I32Attr:$tile_width, + I32Attr:$tile_height, + I32Attr:$v_blocks, + DefaultValuedAttr:$l1_cache_control, + DefaultValuedAttr:$l3_cache_control + )> { + + let summary = "2D block prefetch"; + + let description = [{ + The `xevm.blockprefetch2d` operation prefetches a two dimensional tile + from a larger matrix residing in memory. The parameters are: + $ptr - the base address of the matrix containing the tile to prefetch + $base_width, $base_height, $base_pitch - the shape of the matrix + $x, $y, $tile_width, $tile_height - the starting offsets and shape of tile to prefetch + $elem_size_in_bits - the size in bits of the matrix element + - 32 for f32, bf32 + - 16 for f16, int16, bf16 + - 8 for int8, int4, int2 + $v_blocks - number of tiles to prefetch + $cache_control - an enumerator that sets the L1 and L3 cache behaviour + + Notes: + - coordinate is provided in elements, while width and pitch are provided in bytes. + }]; + + let assemblyFormat = [{ + operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,` + `tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `l1_cache_control` `=` $l1_cache_control `,` + `l3_cache_control` `=` $l3_cache_control `}` + attr-dict `:` `(` type(operands) `)` + }]; + + let hasVerifier = 1; +} + + def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> { let description = [{ GPU target attribute for controlling compilation of targets. All diff --git a/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 0a0ab302..8b374d06 100644 --- a/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -11,9 +11,18 @@ #include "gc/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/Support/FormatVariadic.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "xevm-to-llvm" @@ -26,6 +35,231 @@ using namespace mlir; using namespace xevm; namespace { +struct LLVMFuncAttributeOptions { + bool isConvergent = false; + bool isNoUnwind = false; + bool isWillReturn = false; + LLVM::MemoryEffectsAttr memEffectsAttr{}; +}; +// static constexpr LLVMFuncAttributeOptions convergentAttrs = { +// true, false, false, {}}; +// static constexpr LLVMFuncAttributeOptions noUnwindAttrs = { +// false, true, false, {}}; +static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = { + false, true, true, {}}; +// static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = +// { +// true, true, true, {}}; + +std::string getTypeMangling(Type ty, bool isUnsigned = false) { + return TypeSwitch(ty) + .Case([isUnsigned](VectorType ty) -> std::string { + return "Dv" + std::to_string(ty.getNumElements()) + "_" + + getTypeMangling(ty.getElementType(), isUnsigned); + }) + .Case([](Float16Type) -> std::string { return "Dh"; }) + .Case([](Float32Type) -> std::string { return "f"; }) + .Case([](Float64Type) -> std::string { return "d"; }) + .Case([isUnsigned](IntegerType ty) -> std::string { + switch (ty.getWidth()) { + case 8: + return isUnsigned ? "h" : "c"; + case 16: + return isUnsigned ? "t" : "s"; + case 32: + return isUnsigned ? "j" : "i"; + case 64: + return isUnsigned ? "m" : "l"; + default: + llvm_unreachable("unhandled integer type"); + } + }); +} + +template +static std::optional +getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op, + const bool isLoad) { + if ((op.getL1CacheControlAttr() == + xevm::L1StoreCacheControlAttr::get( + rewriter.getContext(), xevm::L1StoreCacheControl::DEFAULT) && + op.getL3CacheControlAttr() == + xevm::L3StoreCacheControlAttr::get( + rewriter.getContext(), xevm::L3StoreCacheControl::DEFAULT)) || + + (op.getL1CacheControlAttr() == + xevm::L1LoadCacheControlAttr::get( + rewriter.getContext(), xevm::L1LoadCacheControl::DEFAULT) && + op.getL3CacheControlAttr() == + xevm::L3LoadCacheControlAttr::get( + rewriter.getContext(), xevm::L3LoadCacheControl::DEFAULT))) { + return {}; + } + constexpr int32_t decorationCacheControlArity{4}; + constexpr int32_t loadCacheControlKey{6442}; + constexpr int32_t storeCacheControlKey{6443}; + constexpr int32_t l1Level{0}; + constexpr int32_t l3Level{1}; + const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; + SmallVector decorationsL1{ + controlKey, l1Level, static_cast(op.getL1CacheControl()), 0}; + SmallVector decorationsL3{ + controlKey, l3Level, static_cast(op.getL3CacheControl()), 0}; + auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); + auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); + + SmallVector combinedAttrs = {arrayAttrL1, arrayAttrL3}; + return rewriter.getArrayAttr(combinedAttrs); +} + +static LLVM::CallOp createDeviceFunctionCall( + ConversionPatternRewriter &rewriter, StringRef funcName, Type retType, + ArrayRef argTypes, ArrayRef args, + mlir::ArrayRef> paramAttrs, + LLVMFuncAttributeOptions funcAttributeOptions) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + MLIRContext *ctx = rewriter.getContext(); + Location loc = UnknownLoc::get(ctx); + + LLVM::LLVMFuncOp funcOp = + LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, retType); + funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); + funcOp.setConvergent(funcAttributeOptions.isConvergent); + funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind); + funcOp.setWillReturn(funcAttributeOptions.isWillReturn); + + if (funcAttributeOptions.memEffectsAttr) + funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr); + + for (auto [idx, attrName] : paramAttrs) + funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); + + // if (!passthroughAttrs.getFnAttributes().empty()) + // funcOp->setAttrs(passthroughAttrs.getFnAttributes().getDictionary(ctx)); + + auto callOp = rewriter.create(loc, funcOp, args); + callOp->setAttrs(funcOp->getAttrs()); + + return callOp; +} + +template +class LoadStorePrefetchNdToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isLoad = std::is_same_v; + constexpr bool isStore = std::is_same_v; + constexpr bool isPrefetch = std::is_same_v; + auto loc = op.getLoc(); + VectorType vecType; + if constexpr (isLoad) { + vecType = op.getRes().getType(); + } else if constexpr (isStore) { + vecType = op.getStoredVal().getType(); + } + + auto i32Type = rewriter.getI32Type(); + bool vnni = false; + bool transpose = false; + if constexpr (isLoad) { + vnni = op.getVnniTransform(); + transpose = op.getTranspose(); + } + + Value byteCoord = + rewriter.create(loc, VectorType::get(2, i32Type)); + Value zero = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(0)); + Value one = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(1)); + byteCoord = rewriter.create( + loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); + byteCoord = rewriter.create( + loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); + SmallVector args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), + op.getBasePitch(), byteCoord}; + SmallVector retTypes; + Value spvLoadDstPtr; + std::string funcName, bitWidthId; + SmallVector, 4> paramAttrs; + if constexpr (isPrefetch) { // Prefetch + funcName = "intel_sub_group_2d_block_prefetch"; + paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())}; + } else { + auto vecElemType = vecType.getElementType(); + auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); + Value numElems = rewriter.create( + loc, i32Type, vecType.getNumElements()); + auto dstOrSrcPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, + numElems); + args.push_back(dstOrSrcPtr); + if constexpr (isLoad) { // Load + funcName = "intel_sub_group_2d_block_read"; + bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true); + if (vnni) + funcName += "_transform"; + else if (transpose) + funcName += "_transpose"; + spvLoadDstPtr = dstOrSrcPtr; + retTypes.push_back(vecType); + paramAttrs = { + std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()), + }; + } else { // Store + funcName = "intel_sub_group_2d_block_write"; + bitWidthId = (vecElemBitWidth == 32) + ? "j" + : ((vecElemBitWidth == 16) ? "t" : "h"); + rewriter.create(loc, op.getStoredVal(), dstOrSrcPtr); + paramAttrs = { + std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), + std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()), + }; + } + } + + // !X = !{i32 %decoration_kind%, i32 %level%, i32 %control%, i32 %operand of + // the instruction to decorate%} + funcName = + llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(), + op.getTileHeight(), op.getTileWidth(), op.getVBlocks()) + .str(); + funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(), + funcName, isPrefetch ? "" : "P", bitWidthId) + .str(); + SmallVector argTypes; + for (auto arg : args) { + argTypes.push_back(arg.getType()); + } + LLVM::CallOp call = createDeviceFunctionCall( + rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), + argTypes, args, paramAttrs, noUnwindWillReturnAttrs); + if (std::optional optCacheControls = + getCacheControlMetadata(rewriter, op, isLoad || isPrefetch)) { + call->setAttr(xevm::XeVMDialect::getCacheControlsAttrName(), + *optCacheControls); + } + if constexpr (isLoad) + rewriter.replaceOp( + op, rewriter.create(loc, vecType, spvLoadDstPtr)); + else + rewriter.eraseOp(op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + struct ConvertXeVMToLLVMPass : public impl::ConvertXeVMToLLVMPassBase { using Base::Base; @@ -37,19 +271,51 @@ struct ConvertXeVMToLLVMPass void runOnOperation() override { ConversionTarget target(getContext()); target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); - RewritePatternSet pattern(&getContext()); - mlir::populateXeVMToLLVMConversionPatterns(pattern); - if (failed( - applyPartialConversion(getOperation(), target, std::move(pattern)))) + target.addIllegalDialect(); + RewritePatternSet patterns(&getContext()); + mlir::populateXeVMToLLVMConversionPatterns(patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; } // namespace +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + void mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) { - /*TODO*/ + patterns.add, + LoadStorePrefetchNdToOCLPattern, + LoadStorePrefetchNdToOCLPattern>( + patterns.getContext()); } +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert XeVM to LLVM. +struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const final { + populateXeVMToLLVMConversionPatterns(patterns); + } +}; +} // namespace + void mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { - /*TODO*/ + registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) { + dialect->addInterfaces(); + }); } diff --git a/lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp b/lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp index f7ca5a63..370f9df5 100644 --- a/lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp +++ b/lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -24,7 +24,8 @@ namespace { constexpr uint32_t subgroupSize = 16; template LogicalResult verifyMatrixInput(Op op) { - static_assert(llvm::is_one_of::value, + static_assert(llvm::is_one_of::value, "Unexpected template parameter"); std::optional width = getConstantIntValue(op.getBaseWidth()); @@ -279,6 +280,34 @@ LogicalResult BlockStore2dOp::verify() { return success(); } +LogicalResult BlockPrefetch2dOp::verify() { + if (verifyMatrixInput(*this).failed()) + return failure(); + + uint32_t tileWidth = getTileWidth(); + switch (getElemSizeInBits()) { + case 8: + if (tileWidth != 16 && tileWidth != 32) + return emitOpError("tile_width for 8 bit elements should be equal to " + "16 or 32"); + break; + case 16: + if (tileWidth != 16) + return emitOpError("tile_width for 16 bit elements should be equal " + "to 16"); + break; + case 32: + if (tileWidth != 8 && tileWidth != 16) + return emitOpError( + "tile_width for 32 bit elements should be equal to 8 or 16"); + break; + default: + llvm_unreachable("unexpected element size"); + } + + return success(); +} + LogicalResult XeVMTargetAttr::verify(function_ref emitError, int O, StringRef triple, StringRef chip) { diff --git a/lib/gc/Target/LLVMIR/XeVM/XeVMToLLVMIRTranslation.cpp b/lib/gc/Target/LLVMIR/XeVM/XeVMToLLVMIRTranslation.cpp index 228299c7..f818458c 100644 --- a/lib/gc/Target/LLVMIR/XeVM/XeVMToLLVMIRTranslation.cpp +++ b/lib/gc/Target/LLVMIR/XeVM/XeVMToLLVMIRTranslation.cpp @@ -17,6 +17,11 @@ #include "mlir/IR/Operation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" + #include "llvm/IR/ConstantRange.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Support/raw_ostream.h" @@ -46,6 +51,19 @@ class XeVMDialectLLVMIRTranslationInterface amendOperation(Operation *op, ArrayRef instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { + StringRef attrName = attribute.getName().getValue(); + if (attrName == xevm::XeVMDialect::getCacheControlsAttrName()) { + auto cacheControlsArray = dyn_cast(attribute.getValue()); + if (cacheControlsArray.size() != 2) { + return op->emitOpError( + "Expected both L1 and L3 cache control attributes!"); + } + if (instructions.size() != 1) { + return op->emitOpError("Expecting a single instruction"); + } + return handleDecorationCacheControl(op, instructions.front(), + cacheControlsArray.getValue()); + } auto func = dyn_cast(op); if (!func) return failure(); @@ -53,6 +71,37 @@ class XeVMDialectLLVMIRTranslationInterface return success(); } + +private: + template + static llvm::Metadata *getConstantIntMD(llvm::Type *type, IntTy val) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(type, val)); + } + + static LogicalResult handleDecorationCacheControl(Operation *op, + llvm::Instruction *inst, + ArrayRef attrs) { + SmallVector decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); + llvm::transform(attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + auto valuesArray = dyn_cast(attr).getValue(); + std::array metadata; + llvm::transform( + valuesArray, metadata.begin(), + [i32Ty](Attribute valueAttr) { + return getConstantIntMD( + i32Ty, cast(valueAttr).getValue()); + }); + return llvm::MDNode::get(ctx, metadata); + }); + constexpr llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } }; } // namespace diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index 2d2a0288..147729ff 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -24,6 +24,7 @@ #ifdef GC_HAS_ONEDNN_DIALECT #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif +#include "gc/Conversion/Passes.h" #include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" @@ -56,6 +57,7 @@ int main(int argc, char *argv[]) { mlir::registerAllPasses(); mlir::gc::registerCPUPipeline(); mlir::gc::registerGraphCompilerPasses(); + mlir::registerGCConversionPasses(); mlir::cpuruntime::registerCPURuntimePasses(); mlir::microkernel::registerMicrokernelPasses(); @@ -72,6 +74,7 @@ int main(int argc, char *argv[]) { registry.insert<::imex::xetile::XeTileDialect, ::imex::gpux::GPUXDialect>(); #endif mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); + mlir::registerConvertXeVMToLLVMInterface(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Graph Compiler modular optimizer driver\n", registry)); } diff --git a/test/mlir/test/gc/Conversion/XeVMToLLVM/blockload2d.mlir b/test/mlir/test/gc/Conversion/XeVMToLLVM/blockload2d.mlir new file mode 100644 index 00000000..6b875b47 --- /dev/null +++ b/test/mlir/test/gc/Conversion/XeVMToLLVM/blockload2d.mlir @@ -0,0 +1,381 @@ +// RUN: gc-opt -convert-xevm-to-llvm -split-input-file %s | FileCheck %s + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { + // CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement %arg4, [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> + // CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32> + // CHECK: [[EIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-NEXT: [[DEST:%.*]] = llvm.alloca [[EIGHT]] x i16 : (i32) -> !llvm.ptr + // CHECK-NEXT: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, [[COORD1]], [[DEST]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi16> + + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_8b_16r32x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_8b_32r32x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=32, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r16x1cPU3AS1viiiDv2_iPh(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi8> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi8> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x1cPU3AS1viiiDv2_iPh(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi8> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi8> + llvm.return +} + +// ----- + +// COM: This case come from the 06 tutorial of FP8 flash attention. +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r16x4cPU3AS1viiiDv2_iPh(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi8> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=4, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi8> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_16r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r32x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<4xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<4xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_16r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_32b_16r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_32r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=32, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r2x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<1xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=2, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<1xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_8b_16r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_8b_32r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<64xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=32, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<64xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_16r16x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_read_16b_32r16x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<64xi16> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<64xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r8x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_16r8x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_32r8x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=32, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transform_8b_32r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=1, transpose=false, vnni_transform=true, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transform_8b_32r16x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=2, transpose=false, vnni_transform=true, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transform_8b_32r16x4cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=4, transpose=false, vnni_transform=true, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=true, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=1, transpose=false, vnni_transform=true, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=2, transpose=false, vnni_transform=true, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x2cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi32> + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=2, transpose=false, vnni_transform=true, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d_(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=1, transpose=true, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 1 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 1 : i32, 0 : i32{{\]\]}} + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=L1UC, l3_cache_control=L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 1 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 2 : i32, 0 : i32{{\]\]}} + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=L1UC, l3_cache_control=L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 2 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 1 : i32, 0 : i32{{\]\]}} + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=L1C, l3_cache_control=L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 2 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 2 : i32, 0 : i32{{\]\]}} + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=L1C, l3_cache_control=L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 3 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 1 : i32, 0 : i32{{\]\]}} + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=L1S, l3_cache_control=L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 3 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 2 : i32, 0 : i32{{\]\]}} + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=L1S, l3_cache_control=L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 4 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 2 : i32, 0 : i32{{\]\]}} + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=L1IAR, l3_cache_control=L3C} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} + +// ----- + +llvm.func @xevm.blockload2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockload2d( + // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt( + // CHECK-NOT: xevm.DecorationCacheControlINTEL + %0 = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16> + llvm.return +} diff --git a/test/mlir/test/gc/Conversion/XeVMToLLVM/blockprefetch2d.mlir b/test/mlir/test/gc/Conversion/XeVMToLLVM/blockprefetch2d.mlir new file mode 100644 index 00000000..02439986 --- /dev/null +++ b/test/mlir/test/gc/Conversion/XeVMToLLVM/blockprefetch2d.mlir @@ -0,0 +1,200 @@ +// RUN: gc-opt -convert-xevm-to-llvm -split-input-file %s | FileCheck %s + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.func @xevm.blockprefetch2d(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) { + // CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement %arg4, [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> + // CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32> + // CHECK-NEXT: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, [[COORD1]]) + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6442 : i32, 0 : i32, 1 : i32, 0 : i32{{\]}}, {{\[}}6442 : i32, 1 : i32, 1 : i32, 0 : i32{{\]\]}} + // CHECK: (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, l1_cache_control=L1UC, l3_cache_control=L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_8b_16r32x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_8b_32r32x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=32, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z46intel_sub_group_2d_block_prefetch_16b_16r16x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z46intel_sub_group_2d_block_prefetch_16b_32r16x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_32b_8r8x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_32b_8r16x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_32b_16r8x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z46intel_sub_group_2d_block_prefetch_32b_16r16x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=16, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_32b_32r8x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=32, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z46intel_sub_group_2d_block_prefetch_32b_32r16x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=32, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_8b_8r32x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_8b_16r32x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_8b_32r32x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=32, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_8r16x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z46intel_sub_group_2d_block_prefetch_16b_16r16x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z46intel_sub_group_2d_block_prefetch_16b_32r16x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=32, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z44intel_sub_group_2d_block_prefetch_32b_8r8x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_32b_16r8x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=16, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_32b_32r8x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=32, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_8b_32r16x1cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_8b_32r16x2cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=2, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} + +// ----- + +llvm.func @xevm.blockprefetch2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) { + // CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_8b_32r16x4cPU3AS1viiiDv2_i(%arg0, %arg1, %arg2, %arg3, {{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> () + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=32, v_blocks=4, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + llvm.return +} diff --git a/test/mlir/test/gc/Conversion/XeVMToLLVM/blockstore2d.mlir b/test/mlir/test/gc/Conversion/XeVMToLLVM/blockstore2d.mlir new file mode 100644 index 00000000..31ddb9a1 --- /dev/null +++ b/test/mlir/test/gc/Conversion/XeVMToLLVM/blockstore2d.mlir @@ -0,0 +1,42 @@ +// RUN: gc-opt -convert-xevm-to-llvm -split-input-file %s | FileCheck %s + +llvm.func @xevm.blockstore2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) { + // CHECK: llvm.func @xevm.blockstore2d(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: vector<8xi8>) { + // CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: [[UNDEF:%.*]] = llvm.mlir.undef : vector<2xi32> + // CHECK-NEXT: [[COORD0:%.*]] = llvm.insertelement %arg4, [[UNDEF]][[[ZERO]] : i32] : vector<2xi32> + // CHECK-NEXT: [[COORD1:%.*]] = llvm.insertelement %arg5, [[COORD0]][[[ONE]] : i32] : vector<2xi32> + // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-NEXT: [[STOREVALPTR:%.*]] = llvm.alloca [[C8]] x i8 : (i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %arg6, [[STOREVALPTR]] : vector<8xi8>, !llvm.ptr + // CHECK-NEXT: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_write_8b_8r16x1cPU3AS1viiiDv2_iPh(%arg0, %arg1, %arg2, %arg3, [[COORD1]], [[STOREVALPTR]]) + // CHECK: xevm.DecorationCacheControlINTEL = {{\[\[}}6443 : i32, 0 : i32, 1 : i32, 0 : i32{{\]}}, {{\[}}6443 : i32, 1 : i32, 1 : i32, 0 : i32{{\]\]}} + // CHECK: : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + xevm.blockstore2d %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=L1UC, l3_cache_control=L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi8>) + llvm.return +} + +// ----- + +llvm.func @xevm.blockstore2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { + // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_write_8b_8r32x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + xevm.blockstore2d %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) + llvm.return +} + +// ----- + +llvm.func @xevm.blockstore2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_16b_8r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + xevm.blockstore2d %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>) + llvm.return +} + +// ----- + +llvm.func @xevm.blockstore2d(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi32>) { + // CHECK: llvm.call spir_funccc @_Z42intel_sub_group_2d_block_write_32b_8r16x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + xevm.blockstore2d %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + llvm.return +} From de647428a2c25a4048906c1b08576c48ba2c4490 Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Tue, 3 Dec 2024 11:49:53 +0000 Subject: [PATCH 2/3] integration test with xevm --- lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 2 - lib/gc/ExecutionEngine/Driver/CMakeLists.txt | 1 + lib/gc/ExecutionEngine/Driver/Driver.cpp | 5 +++ lib/gc/Transforms/GPU/Pipeline.cpp | 4 +- test/mlir/test/gc/gpu-runner/xevm.mlir | 45 ++++++++++++++++++++ 5 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 test/mlir/test/gc/gpu-runner/xevm.mlir diff --git a/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 8b374d06..68bbfe29 100644 --- a/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -226,8 +226,6 @@ class LoadStorePrefetchNdToOCLPattern : public OpConversionPattern { } } - // !X = !{i32 %decoration_kind%, i32 %level%, i32 %control%, i32 %operand of - // the instruction to decorate%} funcName = llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(), op.getTileHeight(), op.getTileWidth(), op.getVBlocks()) diff --git a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt index d6020e8b..00d458e1 100644 --- a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt @@ -40,4 +40,5 @@ gc_add_mlir_library(GcJitWrapper ${conversion_libs} ${GC_PASSES} GcAnalysis + MLIRXeVMToLLVMIRTranslation ) diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index 16da521d..bb718042 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -11,6 +11,8 @@ #ifdef GC_HAS_ONEDNN_DIALECT #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif +#include "gc/Conversion/Passes.h" +#include "gc/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" @@ -26,11 +28,13 @@ namespace gc { static DialectRegistry initDialects() { mlir::registerAllPasses(); mlir::gc::registerGraphCompilerPasses(); + mlir::registerGCConversionPasses(); mlir::cpuruntime::registerCPURuntimePasses(); mlir::DialectRegistry registry; registry.insert(); mlir::registerAllDialects(registry); mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); + mlir::registerConvertXeVMToLLVMInterface(registry); #ifdef GC_HAS_ONEDNN_DIALECT registry.insert(); #endif @@ -38,6 +42,7 @@ static DialectRegistry initDialects() { llvm::InitializeNativeTargetAsmPrinter(); llvm::InitializeNativeTargetAsmParser(); mlir::registerAllToLLVMIRTranslations(registry); + mlir::registerXeVMDialectTranslation(registry); return registry; } diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp index 130b25a1..591cafe1 100644 --- a/lib/gc/Transforms/GPU/Pipeline.cpp +++ b/lib/gc/Transforms/GPU/Pipeline.cpp @@ -8,8 +8,8 @@ #include +#include "gc/Conversion/Passes.h" #include "gc/Transforms/Passes.h" - #include "imex/Conversion/Passes.h" #include "imex/Transforms/Passes.h" @@ -110,7 +110,7 @@ void populateGPUPipeline(OpPassManager &pm, pm.addPass(createArithToLLVMConversionPass()); pm.addPass(createConvertFuncToLLVMPass()); pm.addPass(createConvertMathToLLVMPass()); - + pm.addPass(createConvertXeVMToLLVMPass()); if (pipelineOpts.useGpuRuntime) { pm.addPass(createGpuToGpuOcl({pipelineOpts.callFinish})); } else { diff --git a/test/mlir/test/gc/gpu-runner/xevm.mlir b/test/mlir/test/gc/gpu-runner/xevm.mlir new file mode 100644 index 00000000..9053b8f8 --- /dev/null +++ b/test/mlir/test/gc/gpu-runner/xevm.mlir @@ -0,0 +1,45 @@ +// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s + +module{ + +func.func @load_store(%src: memref<8x16xf32>, %dst: memref<8x16xf32>) -> memref<8x16xf32> { + %constant = arith.constant 1.23 : f32 + %c0 = arith.constant 0 : index + memref.store %constant, %dst[%c0, %c0] : memref<8x16xf32> + + %0 = memref.extract_aligned_pointer_as_index %src : memref<8x16xf32> -> index + %1 = arith.index_cast %0 : index to i64 + %ptr_generic = llvm.inttoptr %1 : i64 to !llvm.ptr + %ptr = llvm.addrspacecast %ptr_generic : !llvm.ptr to !llvm.ptr<1> + + + %base_width = arith.constant 16 : i32 + %base_height = arith.constant 16 : i32 + %base_pitch = arith.constant 16 : i32 + %x = arith.constant 0 : i32 + %y = arith.constant 0 : i32 + + %loaded = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + + %dst_ptr_as_idx = memref.extract_aligned_pointer_as_index %dst : memref<8x16xf32> -> index + %dst_ptr_as_i64 = arith.index_cast %dst_ptr_as_idx : index to i64 + %dst_ptr_generic = llvm.inttoptr %dst_ptr_as_i64 : i64 to !llvm.ptr + %dst_ptr = llvm.addrspacecast %dst_ptr_generic : !llvm.ptr to !llvm.ptr<1> + + xevm.blockstore2d %dst_ptr, %base_width, %base_height, %base_pitch, %x, %y, %loaded {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + + return %dst : memref<8x16xf32> +} + +func.func @main() { + %src = memref.alloc() : memref<8x16xf32> + %dst = memref.alloc() : memref<8x16xf32> + %gpu_res = call @load_store(%src, %dst) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> + %cast = memref.cast %gpu_res : memref<8x16xf32> to memref<*xf32> + call @printMemrefF32(%cast) : (memref<*xf32>) -> () + return +} + +func.func private @printMemrefF32(%ptr : memref<*xf32>) + +} From d9df05767cfc48f8a3f8fd2217534a8b89e7a9c4 Mon Sep 17 00:00:00 2001 From: Artem Kroviakov Date: Fri, 13 Dec 2024 09:12:45 +0000 Subject: [PATCH 3/3] lower via upstream path --- include/gc/Conversion/XeVMToLLVM/XeVMToLLVM.h | 2 +- include/gc/ExecutionEngine/Driver/Driver.h | 2 +- include/gc/Transforms/Passes.td | 34 +++++++ lib/gc/ExecutionEngine/Driver/CMakeLists.txt | 2 + lib/gc/ExecutionEngine/Driver/Driver.cpp | 12 ++- lib/gc/Target/LLVM/CMakeLists.txt | 3 + lib/gc/Transforms/CMakeLists.txt | 3 +- lib/gc/Transforms/GPU/CMakeLists.txt | 2 + lib/gc/Transforms/GPU/XeVMAttachTarget.cpp | 82 +++++++++++++++++ src/gc-cpu-runner/CMakeLists.txt | 2 +- src/gc-cpu-runner/gc-cpu-runner.cpp | 9 +- src/gc-opt/CMakeLists.txt | 8 ++ src/gc-opt/gc-opt.cpp | 11 +++ test/mlir/test/gc/gpu-runner/xevm.mlir | 91 ++++++++++--------- 14 files changed, 207 insertions(+), 56 deletions(-) create mode 100644 lib/gc/Transforms/GPU/XeVMAttachTarget.cpp diff --git a/include/gc/Conversion/XeVMToLLVM/XeVMToLLVM.h b/include/gc/Conversion/XeVMToLLVM/XeVMToLLVM.h index 5081d216..b693d151 100644 --- a/include/gc/Conversion/XeVMToLLVM/XeVMToLLVM.h +++ b/include/gc/Conversion/XeVMToLLVM/XeVMToLLVM.h @@ -17,7 +17,7 @@ class RewritePatternSet; class Pass; #define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS -#include "mlir/Conversion/Passes.h.inc" +#include "gc/Conversion/Passes.h.inc" void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns); diff --git a/include/gc/ExecutionEngine/Driver/Driver.h b/include/gc/ExecutionEngine/Driver/Driver.h index ee8630b5..7ae185f4 100644 --- a/include/gc/ExecutionEngine/Driver/Driver.h +++ b/include/gc/ExecutionEngine/Driver/Driver.h @@ -18,7 +18,7 @@ namespace mlir { class DialectRegistry; namespace gc { -const DialectRegistry &initCompilerAndGetDialects(); +DialectRegistry &initCompilerAndGetDialects(); // the pointers to XXXMemRefType using GeneralMemrefPtr = void *; diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 1f30b877..3acbfe35 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -261,4 +261,38 @@ def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> { ]; } +def GpuXeVMAttachTarget: Pass<"xevm-attach-target", ""> { + let summary = "Attaches a XeVM target attribute to a GPU Module."; + let description = [{ + This pass searches for all GPU Modules in the immediate regions and attaches + a XeVM target if the module matches the name specified by the `module` argument. + + Example: + ``` + // File: in.mlir: + gpu.module @xevm_module_1 {...} + gpu.module @xevm_module_2 {...} + gpu.module @xevm_module_1 {...} + // mlir-opt --xevm-attach-target="module=xevm.* chip=pvc" in.mlir + gpu.module @xevm_module_1 {...} + gpu.module @xevm_module_2 {...} + gpu.module @xevm_module_1 [#xevm.target] {...} + ``` + }]; + let options = [ + Option<"moduleMatcher", "module", "std::string", + /*default=*/ [{""}], + "Regex used to identify the modules to attach the target to.">, + Option<"triple", "triple", "std::string", + /*default=*/ "\"spirv64-unknown-unknown\"", + "Target triple.">, + Option<"chip", "chip", "std::string", + /*default=*/"\"pvc\"", + "Target chip.">, + Option<"optLevel", "O", "unsigned", + /*default=*/"2", + "Optimization level."> + ]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt index 00d458e1..cc8cac60 100644 --- a/lib/gc/ExecutionEngine/Driver/CMakeLists.txt +++ b/lib/gc/ExecutionEngine/Driver/CMakeLists.txt @@ -25,6 +25,7 @@ else() MLIRToLLVMIRTranslationRegistration ) endif() +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) set(GC_PASSES GcInterface GcPasses) if(GC_ENABLE_IMEX) @@ -38,6 +39,7 @@ gc_add_mlir_library(GcJitWrapper ${MLIR_LINK_COMPONENTS} ${dialect_libs} ${conversion_libs} + ${extension_libs} ${GC_PASSES} GcAnalysis MLIRXeVMToLLVMIRTranslation diff --git a/lib/gc/ExecutionEngine/Driver/Driver.cpp b/lib/gc/ExecutionEngine/Driver/Driver.cpp index bb718042..7c10f324 100644 --- a/lib/gc/ExecutionEngine/Driver/Driver.cpp +++ b/lib/gc/ExecutionEngine/Driver/Driver.cpp @@ -12,9 +12,11 @@ #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif #include "gc/Conversion/Passes.h" +#include "gc/Target/LLVM/XeVM/Target.h" #include "gc/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/All.h" @@ -34,19 +36,23 @@ static DialectRegistry initDialects() { registry.insert(); mlir::registerAllDialects(registry); mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); + mlir::registerAllExtensions(registry); + // Adds missing `LLVMTranslationDialectInterface` registration for dialect for + // gpu.module op + mlir::registerAllToLLVMIRTranslations(registry); mlir::registerConvertXeVMToLLVMInterface(registry); + mlir::registerXeVMDialectTranslation(registry); + mlir::xevm::registerXeVMTargetInterfaceExternalModels(registry); #ifdef GC_HAS_ONEDNN_DIALECT registry.insert(); #endif llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); llvm::InitializeNativeTargetAsmParser(); - mlir::registerAllToLLVMIRTranslations(registry); - mlir::registerXeVMDialectTranslation(registry); return registry; } -const DialectRegistry &initCompilerAndGetDialects() { +DialectRegistry &initCompilerAndGetDialects() { static DialectRegistry reg = initDialects(); return reg; } diff --git a/lib/gc/Target/LLVM/CMakeLists.txt b/lib/gc/Target/LLVM/CMakeLists.txt index 81d4d563..6e619e3a 100644 --- a/lib/gc/Target/LLVM/CMakeLists.txt +++ b/lib/gc/Target/LLVM/CMakeLists.txt @@ -7,6 +7,9 @@ gc_add_mlir_dialect_library(MLIRXeVMTarget ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR ${PROJECT_SOURCE_DIR}/include/gc/Dialect/LLVMIR + LINK_COMPONENTS + SPIRVCodeGen + LINK_LIBS PUBLIC MLIRIR MLIRExecutionEngineUtils diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 111d0c75..7d1007b0 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -30,7 +30,8 @@ gc_add_mlir_library(GcPasses DEPENDS GraphCompilerPassIncGen - + GCConversionPassIncGen + LINK_LIBS PUBLIC ${mlir_dialect_libs} ${mlir_conversion_libs} diff --git a/lib/gc/Transforms/GPU/CMakeLists.txt b/lib/gc/Transforms/GPU/CMakeLists.txt index f4b286b9..f7d10b3d 100644 --- a/lib/gc/Transforms/GPU/CMakeLists.txt +++ b/lib/gc/Transforms/GPU/CMakeLists.txt @@ -17,6 +17,7 @@ gc_add_mlir_library(GcGpuPasses GpuToGpuOcl.cpp LinalgToXeGPU.cpp Pipeline.cpp + XeVMAttachTarget.cpp DEPENDS GraphCompilerPassIncGen @@ -31,6 +32,7 @@ gc_add_mlir_library(GcGpuPasses MLIRMathToSPIRV MLIRControlFlowToSPIRV MLIRMemRefTransforms + MLIRXeVMToLLVMIRTranslation GcInterface GcUtilsIR ${IMEX_LIBS} diff --git a/lib/gc/Transforms/GPU/XeVMAttachTarget.cpp b/lib/gc/Transforms/GPU/XeVMAttachTarget.cpp new file mode 100644 index 00000000..6218120e --- /dev/null +++ b/lib/gc/Transforms/GPU/XeVMAttachTarget.cpp @@ -0,0 +1,82 @@ +//===- XeVMAttachTarget.cpp - Attach an XeVM target -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the `GpuXeVMAttachTarget` pass, attaching `#xevm.target` +// attributes to GPU modules. +// +//===----------------------------------------------------------------------===// + +#include "gc/Dialect/LLVMIR/XeVMDialect.h" + +#include "gc/Target/LLVM/XeVM/Target.h" +#include "gc/Transforms/Passes.h" + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/Regex.h" +#include + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_GPUXEVMATTACHTARGET +#include "gc/Transforms/Passes.h.inc" +} // namespace gc +} // namespace mlir + +using namespace mlir::xevm; +using namespace mlir; + +namespace { +struct XeVMAttachTarget + : public gc::impl::GpuXeVMAttachTargetBase { + using Base::Base; + + DictionaryAttr getFlags(OpBuilder &builder) const; + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; +} // namespace + +DictionaryAttr XeVMAttachTarget::getFlags(OpBuilder &builder) const { + UnitAttr unitAttr = builder.getUnitAttr(); + SmallVector flags; + auto addFlag = [&](StringRef flag) { + flags.push_back(builder.getNamedAttr(flag, unitAttr)); + }; + if (!flags.empty()) + return builder.getDictionaryAttr(flags); + return nullptr; +} + +void XeVMAttachTarget::runOnOperation() { + OpBuilder builder(&getContext()); + auto target = builder.getAttr(optLevel, triple, chip); + llvm::Regex matcher(moduleMatcher); + for (Region ®ion : getOperation()->getRegions()) + for (Block &block : region.getBlocks()) + for (auto module : block.getOps()) { + // Check if the name of the module matches. + if (!moduleMatcher.empty() && !matcher.match(module.getName())) + continue; + // Create the target array. + SmallVector targets; + if (std::optional attrs = module.getTargets()) + targets.append(attrs->getValue().begin(), attrs->getValue().end()); + targets.push_back(target); + // Remove any duplicate targets. + targets.erase(llvm::unique(targets), targets.end()); + // Update the target attribute array. + module.setTargetsAttr(builder.getArrayAttr(targets)); + } +} diff --git a/src/gc-cpu-runner/CMakeLists.txt b/src/gc-cpu-runner/CMakeLists.txt index a0037d6b..a148cc7d 100644 --- a/src/gc-cpu-runner/CMakeLists.txt +++ b/src/gc-cpu-runner/CMakeLists.txt @@ -32,5 +32,5 @@ if(GC_DEV_LINK_LLVM_DYLIB) endif() gc_add_mlir_tool(gc-cpu-runner gc-cpu-runner.cpp) -target_link_libraries(gc-cpu-runner PRIVATE GcCpuRuntime) +target_link_libraries(gc-cpu-runner PRIVATE GcJitWrapper GcCpuRuntime) mlir_check_all_link_libraries(gc-cpu-runner) diff --git a/src/gc-cpu-runner/gc-cpu-runner.cpp b/src/gc-cpu-runner/gc-cpu-runner.cpp index 3e154343..2b69e427 100644 --- a/src/gc-cpu-runner/gc-cpu-runner.cpp +++ b/src/gc-cpu-runner/gc-cpu-runner.cpp @@ -17,6 +17,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include "gc/ExecutionEngine/Driver/Driver.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/ExecutionEngine/JitRunner.h" @@ -34,13 +36,8 @@ int main(int argc, char **argv) { // keeps GCCPURuntime linked gc_runtime_keep_alive = 0; llvm::InitLLVM y(argc, argv); - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - llvm::InitializeNativeTargetAsmParser(); - - mlir::DialectRegistry registry; + mlir::DialectRegistry ®istry{mlir::gc::initCompilerAndGetDialects()}; registry.insert(); - mlir::registerAllToLLVMIRTranslations(registry); return mlir::JitRunnerMain(argc, argv, registry); } \ No newline at end of file diff --git a/src/gc-opt/CMakeLists.txt b/src/gc-opt/CMakeLists.txt index d6a9b656..fb2af4de 100644 --- a/src/gc-opt/CMakeLists.txt +++ b/src/gc-opt/CMakeLists.txt @@ -29,6 +29,11 @@ if(GC_DEV_LINK_LLVM_DYLIB) else() set(MLIR_LINK_COMPONENTS MLIROptLib + MLIRBuiltinToLLVMIRTranslation + MLIRExecutionEngine + MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRToLLVMIRTranslationRegistration ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) @@ -38,9 +43,12 @@ add_llvm_executable(gc-opt gc-opt.cpp) llvm_update_compile_flags(gc-opt) mlir_check_all_link_libraries(gc-opt) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + target_link_libraries(gc-opt PUBLIC GcInterface) target_link_libraries(gc-opt PRIVATE ${dialect_libs} + ${extension_libs} ${conversion_libs} ${MLIR_LINK_COMPONENTS} GcPasses diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp index 147729ff..047c611c 100644 --- a/src/gc-opt/gc-opt.cpp +++ b/src/gc-opt/gc-opt.cpp @@ -25,9 +25,14 @@ #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h" #endif #include "gc/Conversion/Passes.h" +#include "gc/Target/LLVM/XeVM/Target.h" +#include "gc/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" + #include "gc/Transforms/Microkernel/MicrokernelPasses.h" #include "gc/Transforms/Passes.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" @@ -74,7 +79,13 @@ int main(int argc, char *argv[]) { registry.insert<::imex::xetile::XeTileDialect, ::imex::gpux::GPUXDialect>(); #endif mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry); + mlir::registerAllExtensions(registry); + // Adds missing `LLVMTranslationDialectInterface` registration for dialect for + // gpu.module op + mlir::registerAllToLLVMIRTranslations(registry); mlir::registerConvertXeVMToLLVMInterface(registry); + mlir::registerXeVMDialectTranslation(registry); + mlir::xevm::registerXeVMTargetInterfaceExternalModels(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "Graph Compiler modular optimizer driver\n", registry)); } diff --git a/test/mlir/test/gc/gpu-runner/xevm.mlir b/test/mlir/test/gc/gpu-runner/xevm.mlir index 9053b8f8..86780e77 100644 --- a/test/mlir/test/gc/gpu-runner/xevm.mlir +++ b/test/mlir/test/gc/gpu-runner/xevm.mlir @@ -1,45 +1,50 @@ // RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils %s | FileCheck %s - -module{ - -func.func @load_store(%src: memref<8x16xf32>, %dst: memref<8x16xf32>) -> memref<8x16xf32> { - %constant = arith.constant 1.23 : f32 - %c0 = arith.constant 0 : index - memref.store %constant, %dst[%c0, %c0] : memref<8x16xf32> - - %0 = memref.extract_aligned_pointer_as_index %src : memref<8x16xf32> -> index - %1 = arith.index_cast %0 : index to i64 - %ptr_generic = llvm.inttoptr %1 : i64 to !llvm.ptr - %ptr = llvm.addrspacecast %ptr_generic : !llvm.ptr to !llvm.ptr<1> - - - %base_width = arith.constant 16 : i32 - %base_height = arith.constant 16 : i32 - %base_pitch = arith.constant 16 : i32 - %x = arith.constant 0 : i32 - %y = arith.constant 0 : i32 - - %loaded = xevm.blockload2d %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> - - %dst_ptr_as_idx = memref.extract_aligned_pointer_as_index %dst : memref<8x16xf32> -> index - %dst_ptr_as_i64 = arith.index_cast %dst_ptr_as_idx : index to i64 - %dst_ptr_generic = llvm.inttoptr %dst_ptr_as_i64 : i64 to !llvm.ptr - %dst_ptr = llvm.addrspacecast %dst_ptr_generic : !llvm.ptr to !llvm.ptr<1> - - xevm.blockstore2d %dst_ptr, %base_width, %base_height, %base_pitch, %x, %y, %loaded {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, l1_cache_control=Default, l3_cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) - - return %dst : memref<8x16xf32> -} - -func.func @main() { - %src = memref.alloc() : memref<8x16xf32> - %dst = memref.alloc() : memref<8x16xf32> - %gpu_res = call @load_store(%src, %dst) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32> - %cast = memref.cast %gpu_res : memref<8x16xf32> to memref<*xf32> - call @printMemrefF32(%cast) : (memref<*xf32>) -> () - return -} - -func.func private @printMemrefF32(%ptr : memref<*xf32>) - + +module @gemm attributes {gpu.container_module} { + + gpu.module @kernels { + gpu.func @load_store(%src: memref<8x16xf32>, %dst: memref<8x16xf32>) kernel { + %constant = arith.constant 1.23 : f32 + %c0 = arith.constant 0 : index + memref.store %constant, %dst[%c0, %c0] : memref<8x16xf32> + gpu.return + } + } + gpu.module @kernel { + gpu.func @store_constant(%ptr: !llvm.ptr) kernel { + %const_val = arith.constant 42.0 : f32 + llvm.store %const_val, %ptr : f32, !llvm.ptr + gpu.return + } + } + + + func.func @test(%src : memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} { + %token0 = gpu.wait async + %c1 = arith.constant 1 : index + %c16 = arith.constant 1 : index + %memref_0 = gpu.alloc [%token0] host_shared () : memref<8x16xf32> + memref.copy %src, %memref_0 : memref<8x16xf32> to memref<8x16xf32> + %0 = memref.extract_aligned_pointer_as_index %memref_0 : memref<8x16xf32> -> index + %1 = arith.index_cast %0 : index to i64 + %2 = llvm.inttoptr %1 : i64 to !llvm.ptr + %token1 = gpu.wait async + %5 = gpu.launch_func async [%token1] @kernel::@store_constant blocks in (%c1, %c1, %c1) threads in (%c1, %c16, %c1) args(%2 : !llvm.ptr) + gpu.wait [%5] + // gpu.dealloc %memref_0 : memref<8x16xf32> + return %memref_0 : memref<8x16xf32> + } + + func.func @main() attributes {llvm.emit_c_interface} { + %A = memref.alloc() : memref<8x16xf32> + %B = call @test(%A) : (memref<8x16xf32>) -> memref<8x16xf32> + %B_cast = memref.cast %B : memref<8x16xf32> to memref<*xf32> + %A_cast = memref.cast %A : memref<8x16xf32> to memref<*xf32> + call @printMemrefF32(%A_cast) : (memref<*xf32>) -> () + call @printMemrefF32(%B_cast) : (memref<*xf32>) -> () + + memref.dealloc %A : memref<8x16xf32> + return + } + func.func private @printMemrefF32(%ptr : memref<*xf32>) }