diff --git a/include/imex/Conversion/CMakeLists.txt b/include/imex/Conversion/CMakeLists.txt index 9e10ff835..27d3afb25 100644 --- a/include/imex/Conversion/CMakeLists.txt +++ b/include/imex/Conversion/CMakeLists.txt @@ -6,3 +6,4 @@ add_public_tablegen_target(IMEXConversionPassIncGen) add_mlir_doc(Passes IMEXConversionPasses ./ -gen-pass-doc) add_subdirectory(DistToStandard) +add_subdirectory(XeTileToXeGPU) diff --git a/include/imex/Conversion/Passes.h b/include/imex/Conversion/Passes.h index de067431d..a113abe10 100644 --- a/include/imex/Conversion/Passes.h +++ b/include/imex/Conversion/Passes.h @@ -21,6 +21,7 @@ #include #include #include +#include namespace imex { diff --git a/include/imex/Conversion/Passes.td b/include/imex/Conversion/Passes.td index 09c858d9a..88a0f03e3 100644 --- a/include/imex/Conversion/Passes.td +++ b/include/imex/Conversion/Passes.td @@ -330,4 +330,48 @@ def ConvertGPUXToLLVM : Pass<"convert-gpux-to-llvm", "::mlir::ModuleOp"> { } +//===----------------------------------------------------------------------===// +// XeTileToXeGPU +//===----------------------------------------------------------------------===// + +def ConvertXeTileToXeGPU: Pass<"convert-xetile-to-xegpu", "::mlir::ModuleOp"> { + let summary = "Convert from the XeTile dialect to the XeGPU dialect."; + let description = [{ + Convert XeTile dialect operations into the XeGPU dialect operations. It expects + the input code is tiled using xetile-tiling. + + #### Input invariant + + func.func @sglevel_tiled_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<2x1x8x16xf16> + %2 = xetile.load_tile %1 : !xetile.tile<2x1x8x16xf16> -> vector<2x1x8x16xf16> + return + } + + #### Output IR + + func.func @sglevel_tiled_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c64] {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %c8 = arith.constant 8 : index + %c64_0 = arith.constant 64 : index + %1 = xegpu.create_nd_tdesc %arg0[%c8, %c64_0] {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %2 = xegpu.load_nd %0 {mode = vc, l1_hint = uncached, l2_hint = uncached, l3_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 {mode = vc, l1_hint = uncached, l2_hint = uncached, l3_hint = uncached} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + return + } + }]; + + let constructor = "::imex::createConvertXeTileToXeGPUPass()"; + let dependentDialects = ["::imex::xegpu::XeGPUDialect", + "::imex::xetile::XeTileDialect", + "::mlir::vector::VectorDialect", + "::mlir::arith::ArithDialect", + ]; + let options = []; +} + #endif // _IMEX_CONVERSION_PASSES_TD_INCLUDED_ diff --git a/include/imex/Conversion/XeTileToXeGPU/CMakeLists.txt b/include/imex/Conversion/XeTileToXeGPU/CMakeLists.txt new file mode 100644 index 000000000..e69de29bb diff --git a/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h b/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h new file mode 100644 index 000000000..5abf7571c --- /dev/null +++ b/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h @@ -0,0 +1,45 @@ +//===- XeTileToXeGPU.h - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the XeTileToXeGPU conversion, converting the XeTile +/// dialect to the XeGPU dialect. +/// +//===----------------------------------------------------------------------===// + +#ifndef _XeTileToXeGPU_H_INCLUDED_ +#define _XeTileToXeGPU_H_INCLUDED_ + +#include +#include +#include + +#include "XeTileToXeGPUConversion.h" + +namespace mlir { +class MLIRContext; +class ModuleOp; +template class OperationPass; +class RewritePatternSet; +} // namespace mlir + +namespace imex { +class XeGPUTypeConverter; + +/// Populate the given list with patterns rewrite XeTile Ops +void populateXeTileToXeGPUConversionPatterns(XeGPUTypeConverter &converter, + mlir::RewritePatternSet &patterns); + +/// Create a pass to convert the XeTile dialect to the XeGPU dialect. +std::unique_ptr> +createConvertXeTileToXeGPUPass(); + +} // namespace imex + +#endif // _XeTileToXeGPU_H_INCLUDED_ diff --git a/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h b/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h new file mode 100644 index 000000000..708d1a68d --- /dev/null +++ b/include/imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h @@ -0,0 +1,190 @@ +//===- TypeConverter.h - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the SgXeTileToXeGPUConversion, the base class for +/// XeTileToXeGPU conversion, XeGPUTypeConverter, converting types used in +/// XeTile dialect to types used in XeGPU dialect, XeGPUOneToNPatterRewriter a +/// wrapper around ConversionPatterRewriter providng interface for supporting +/// OneToN replace. +/// +//===----------------------------------------------------------------------===// + +#ifndef _XeTileToXeGPUConversion_H_INCLUDED_ +#define _XeTileToXeGPUConversion_H_INCLUDED_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "imex/Dialect/XeGPU/IR/XeGPUOps.h" +#include "imex/Dialect/XeTile/IR/XeTileOps.h" +#include "imex/Utils/DebugUtils.h" +#include "imex/Utils/PassWrapper.h" +#include "imex/Utils/XeCommon.h" + +namespace imex { + +class XeGPUTypeConverter : public imex::XeTypeConverter { +public: + XeGPUTypeConverter(mlir::MLIRContext &context, ValueAttributeMap &map); + + std::optional + convertTileType(xetile::TileType tileTy, + llvm::SmallVectorImpl &resultTypes) override; + + std::optional + convertVectorType(mlir::VectorType vectorTy, + llvm::SmallVectorImpl &resultTypes) override; +}; + +class XeGPUOneToNPatterRewriter : public mlir::PatternRewriter, + public mlir::RewriterBase::Listener { +public: + explicit XeGPUOneToNPatterRewriter(mlir::ConversionPatternRewriter &rewriter, + XeGPUTypeConverter &converter) + : mlir::PatternRewriter(rewriter.getContext()), typeConverter(converter), + rewriter(rewriter) { + setListener(this); + } + + mlir::Block * + applySignatureConversion(mlir::Region *region, + mlir::TypeConverter::SignatureConversion &conversion, + const mlir::TypeConverter *converter = nullptr); + + template + OpTy create(mlir::Location location, Args &&...args) { + return rewriter.create(location, std::forward(args)...); + } + + mlir::FailureOr convertRegionTypes( + mlir::Region *region, const mlir::TypeConverter &converter, + mlir::TypeConverter::SignatureConversion *entryConversion = nullptr) { + return rewriter.convertRegionTypes(region, converter, entryConversion); + } + + void inlineRegionBefore(mlir::Region ®ion, mlir::Region &parent, + mlir::Region::iterator before) override { + rewriter.inlineRegionBefore(region, parent, before); + } + + void replaceOp(mlir::Operation *op, mlir::Operation *newOp) override { + assert(op && newOp && "expected non-null op"); + replaceOp(op, newOp->getResults()); + } + + void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) override; + + void eraseOp(mlir::Operation *op) override { rewriter.eraseOp(op); } + + template + void updateRootInPlace(mlir::Operation *root, CallableT &&callable) { + rewriter.updateRootInPlace(root, callable); + } + + mlir::ConversionPatternRewriter &mlirConversionPatterRewriter() { + return rewriter; + }; + +private: + XeGPUTypeConverter &typeConverter; + mlir::ConversionPatternRewriter &rewriter; +}; + +template +class SgXeTileToXeGPUConversion : public XeConversionPattern { +public: + SgXeTileToXeGPUConversion(mlir::MLIRContext *context, + XeGPUTypeConverter &typeConverter, + mlir::PatternBenefit benefit = 1) + : XeConversionPattern(typeConverter, SourceOp::getOperationName(), + benefit, context) {} + + using RangeT = llvm::ArrayRef; + using OpAdaptor = typename SourceOp::template GenericAdaptor; + + /* + * This overwrites the RewritePattern::matchAndRewrite as it is the entry + * point. It will set up the OpAdaptor such that it contains the converted + * values, and wrap the ConversionPatternRewriter with + * XeGPUOneToNPatterRewriter to provide a clean interface for users. + */ + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const final { + llvm::SmallVector convertedValues; + + // converted into convertionPatternRewriter since applyPartialConversion + // used it + auto &convertionPatternRewriter = + static_cast(rewriter); + + // One-To-One mapping provided by mlir::ConversionPatternRewriter. + // remappedValues contains new values for each operand of the operation. It + // is supposed to be a UnrealizedConversionCastOp (created by the replaceOp + // of XeGPUOneToNPatternRewriter in form of cast newvalues to oldType) for + // each operand that has One-to-N mapping. + llvm::SmallVector remappedValues; + if (mlir::failed(convertionPatternRewriter.getRemappedValues( + op->getOperands(), remappedValues))) { + return op->emitOpError("Failed to get remapped values.\n"); + // return mlir::failure(); + } + + // get the One-to-N converted types. + auto operandTys = op->getOperandTypes(); + mlir::OneToNTypeMapping operandMapping(operandTys); + if (mlir::failed( + typeConverter.computeTypeMapping(operandTys, operandMapping))) { + return op->emitOpError("Failed to compute Type mapping.\n"); + // return mlir::failure(); + } + + // retrive mapped values for each operand. If its type is not convereted + // (convertedTypes.size() == 1) we will reuse the current value. Otherwise, + // it has one-to-n mapping, and the new value should be an + // UnrealizedConversionCastOp. + for (auto [idx, value] : llvm::enumerate(remappedValues)) { + mlir::TypeRange convertedTypes = operandMapping.getConvertedTypes(idx); + if (convertedTypes.size() == 1) { + convertedValues.push_back(value); + } else if (auto castOp = + llvm::dyn_cast_or_null( + value.getDefiningOp())) { + convertedValues.push_back(castOp.getInputs()); + } else { + return op->emitError( + "[SgXeTileToXeGPUConversion::matchAndRewrite] Unexpected that " + "cannot figure out the remapped input value."); + } + } + + auto sourceOp = llvm::dyn_cast(op); + OpAdaptor adaptor(convertedValues, sourceOp); + XeGPUOneToNPatterRewriter OneToNRewriter( + convertionPatternRewriter, getTypeConverter()); + return matchAndRewrite(sourceOp, adaptor, OneToNRewriter); + } + + virtual mlir::LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } +}; + +} // namespace imex + +#endif diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td b/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td index 2d5f97da3..49422f46c 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -20,14 +20,14 @@ def XeGPU_ScatteredAttr : XeGPUAttr<"Scattered", "scattered"> { let assemblyFormat = ""; } -def XeGPU_SgMapAttr: XeGPUAttr<"SgMap", "sg_map"> { +def XeGPU_SgMapAttr: XeGPUAttr<"SubGroupMap", "sg_map"> { let parameters = (ins ArrayRefParameter<"unsigned">:$wiLayout, ArrayRefParameter<"unsigned">:$wiData, ArrayRefParameter<"unsigned">:$mmaBlockSize); // In format of #xegpu.sg_map<{mma_block_size = [2, 4], wi_layout = [2, 4], wi_data = [2, 4]}> - let assemblyFormat = "`<` custom($wiLayout, $wiData, $mmaBlockSize) `>`"; + let assemblyFormat = "`<` custom($wiLayout, $wiData, $mmaBlockSize) `>`"; let genVerifyDecl = true; @@ -52,7 +52,7 @@ def XeGPU_SgMapAttr: XeGPUAttr<"SgMap", "sg_map"> { let skipDefaultBuilders = 1; } -def XeGPU_WgMapAttr: XeGPUAttr<"WgMap", "wg_map"> { +def XeGPU_WgMapAttr: XeGPUAttr<"WorkGroupMap", "wg_map"> { let parameters = (ins ArrayRefParameter<"unsigned">:$sgLayout, ArrayRefParameter<"unsigned">:$sgData); @@ -71,7 +71,7 @@ def XeGPU_WgMapAttr: XeGPUAttr<"WgMap", "wg_map"> { let skipDefaultBuilders = 1; // In format of #xegpu.wg_map<{sg_layout = [2, 4], sg_data = [2, 4]}> - let assemblyFormat = "`<` custom($sgLayout, $sgData) `>`"; + let assemblyFormat = "`<` custom($sgLayout, $sgData) `>`"; } def XeGPU_XeMapAttr: XeGPUAttr<"XeMap", "xe_map"> { @@ -90,8 +90,8 @@ def XeGPU_XeMapAttr: XeGPUAttr<"XeMap", "xe_map"> { assert(sgLayout.size() == 2 && sgData.size() == 2 && "sgLayout and sgData should be 2D arrays.\n"); assert(wiLayout.size() == 2 && wiData.size() == 2 && "wiLayout and wiData should be 2D arrays.\n"); assert((mmaBlockSize.size() == 2 || mmaBlockSize.size() == 0) && "mmaBlockSize can be either empty or a 2D array.\n"); - auto wg = WgMapAttr::get($_ctxt, sgLayout, sgData); - auto sg = SgMapAttr::get($_ctxt, wiLayout, wiData, mmaBlockSize); + auto wg = WorkGroupMapAttr::get($_ctxt, sgLayout, sgData); + auto sg = SubGroupMapAttr::get($_ctxt, wiLayout, wiData, mmaBlockSize); return $_get($_ctxt, wg, sg); }]> ]; diff --git a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td index ca917f9a4..c5f9c7d49 100644 --- a/include/imex/Dialect/XeGPU/IR/XeGPUOps.td +++ b/include/imex/Dialect/XeGPU/IR/XeGPUOps.td @@ -441,7 +441,7 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas"> { ); let results = (outs XeGPU_Vector2DType: $result); let assemblyFormat = [{ - $lhs `,` $rhs (`,` $acc^)? (`{` `mode` `=` $mode^ `}`)? attr-dict `:` + $lhs `,` $rhs (`,` $acc^)? (` ``{` `mode` `=` $mode^ `}`)? attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) (`,` qualified(type($acc))^)? `->` qualified(type($result)) }]; diff --git a/include/imex/Dialect/XeTile/IR/XeTileOps.td b/include/imex/Dialect/XeTile/IR/XeTileOps.td index e106b60ac..590bdc482 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileOps.td +++ b/include/imex/Dialect/XeTile/IR/XeTileOps.td @@ -328,7 +328,7 @@ def XeTile_PrefetchTileOp : XeTile_Op<"prefetch_tile", []> { }]; } -def XeTile_TileMMAOp : XeTile_Op<"tile_mma", [Pure]> { +def XeTile_TileMMAOp : XeTile_Op<"tile_mma", []> { let summary = "matrix multiplication in blocked layout"; let description = [{ "tile_mma" operation represents matrix multiplication on 2D or 4D vectors. This operation diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.h b/include/imex/Dialect/XeTile/Transforms/Passes.h index e7948d9b5..bd20c0cad 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.h +++ b/include/imex/Dialect/XeTile/Transforms/Passes.h @@ -28,16 +28,17 @@ class RewritePatternSet; namespace imex { +class XeTypeConverter; + //===----------------------------------------------------------------------===// /// XeTile passes. //===----------------------------------------------------------------------===// -/// Create a pass for converting XeTile Ops to XeGPU Ops -std::unique_ptr<::mlir::Pass> createXeTileToXeGPUPass(); +std::unique_ptr createXeTileTilingPass(); -/// Populate the given list with patterns that eliminate XeTile ops -void populateXeTileToXeGPUPatterns(::mlir::LLVMTypeConverter &converter, - ::mlir::RewritePatternSet &patterns); +/// +void populateXeTileTilingPatterns(imex::XeTypeConverter &converter, + mlir::RewritePatternSet &patterns); //===----------------------------------------------------------------------===// // Registration diff --git a/include/imex/Dialect/XeTile/Transforms/Passes.td b/include/imex/Dialect/XeTile/Transforms/Passes.td index 872c0db2d..176092f80 100644 --- a/include/imex/Dialect/XeTile/Transforms/Passes.td +++ b/include/imex/Dialect/XeTile/Transforms/Passes.td @@ -17,23 +17,17 @@ include "mlir/Pass/PassBase.td" -//===----------------------------------------------------------------------===// -// XeTileToXeGPU pass -//===----------------------------------------------------------------------===// +def XeTileTiling : Pass<"xetile-tiling", "::mlir::ModuleOp">{ + let summary = "transform XeTile large tiles(input) into register region block layout"; -// def XeTileToXeGPU: Pass<"xetile-xegpu", "::mlir::func::FuncOp"> { -// let summary = "Add a pass for lowering XeTile dialect ops to XeGPU"; -// let description = [{ -// -// #### Input invariant -// -// -// #### Output IR -// -// }]; -// let constructor = "::imex::createXeTileToXeGPUPass()"; -// let dependentDialects = ["::imex::xetile::XeTileDialect"]; -// let options = []; -// } + let description = [{ + This pass transforms XeTile large tiles smaller tiles with blocked layout to map to register region. + This blocked layout is represented by high dimension vectors, inner dimension matches to DPAS size + config, This lowers 2D vector to 4D vector. + }]; + + let constructor = "imex::createXeTileTilingPass()"; + let dependentDialects = ["::imex::xetile::XeTileDialect"]; +} #endif // _XeTile_PASSES_TD_INCLUDED_ diff --git a/include/imex/InitIMEXPasses.h b/include/imex/InitIMEXPasses.h index a8b92bc22..93e40d059 100644 --- a/include/imex/InitIMEXPasses.h +++ b/include/imex/InitIMEXPasses.h @@ -20,6 +20,7 @@ #include // #include #include "imex/Transforms/Passes.h" +#include #include @@ -41,6 +42,7 @@ inline void registerAllPasses() { // Dialect passes registerPTensorPasses(); + registerXeTilePasses(); // register*Passes(); // Dialect pipelines diff --git a/include/imex/Transforms/Passes.td b/include/imex/Transforms/Passes.td index ab9c5b1ab..70298f49d 100644 --- a/include/imex/Transforms/Passes.td +++ b/include/imex/Transforms/Passes.td @@ -108,5 +108,4 @@ def BF16ToGPU : Pass<"bf16-to-gpu", "::mlir::ModuleOp"> { "::mlir::arith::ArithDialect" ]; } - #endif // _IMEX_TRANSFORMS_PASSES_TD_INCLUDED_ diff --git a/include/imex/Utils/DebugUtils.h b/include/imex/Utils/DebugUtils.h new file mode 100644 index 000000000..9c98809d3 --- /dev/null +++ b/include/imex/Utils/DebugUtils.h @@ -0,0 +1,55 @@ + +#ifndef _DEBUGUTILS_H_INCLUDED_ +#define _DEBUGUTILS_H_INCLUDED_ + +#include +#include + +#include +#include + +static std::string getValueAsString(mlir::Value op, bool asOperand = false) { + std::string buf; + buf.clear(); + llvm::raw_string_ostream os(buf); + auto flags = ::mlir::OpPrintingFlags().assumeVerified(); + if (asOperand) + op.printAsOperand(os, flags); + else + op.print(os, flags); + os.flush(); + return buf; +} + +// It construct a string representation for the given array. +// It helps for printing debug information +template +static std::string makeString(T array, bool breakline = false) { + std::string buf; + buf.clear(); + llvm::raw_string_ostream os(buf); + os << "["; + for (size_t i = 1; i < array.size(); i++) { + os << array[i - 1] << ", "; + if (breakline) + os << "\n\t\t"; + } + os << array.back() << "]"; + os.flush(); + return buf; +} + +template static void dumpToFile(T val, std::string name) { + std::string buf; + buf.clear(); + + llvm::raw_string_ostream os(buf); + os << val << "\n"; + os.flush(); + + std::ofstream ofs(name, std::ofstream::out); + ofs << buf; + ofs.close(); +} + +#endif diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h new file mode 100644 index 000000000..69c98f4a5 --- /dev/null +++ b/include/imex/Utils/XeCommon.h @@ -0,0 +1,198 @@ + +//===- XeUtils.h - XeTile/XeGPU Utility Functions --------------------*- C++ +//-*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utility functions used by XeTile/XeGPU dialects. +// +//===----------------------------------------------------------------------===// + +#ifndef _IMEX_XECOMMON_H_ +#define _IMEX_XECOMMON_H_ + +#include +#include +#include +#include + +namespace imex { + +/** + * None: Not associated with dpas + * DPASA: asscociated with dpas A operand + * DPASB: asscociated with dpas B operand + * DPASC: asscociated with dpas C operand + * DPASR: asscociated with dpas result value. + */ +enum class OperandType { None = 0, DPASA = 1, DPASB = 2, DPASC = 4, DPASR = 8 }; + +class ValueAttributeMap { +public: + ValueAttributeMap() {} + + void add(mlir::BlockArgument arg, imex::OperandType type); + void add(mlir::Operation *op, imex::OperandType type); + + imex::OperandType get(mlir::Operation *op); + imex::OperandType get(mlir::BlockArgument arg); + +private: + llvm::DenseMap operationMap; + llvm::DenseMap argumentMap; +}; + +void markDefChainValues(mlir::Value value, imex::OperandType type, + imex::ValueAttributeMap &map); +void markUseChainValues(mlir::Value value, imex::OperandType type, + imex::ValueAttributeMap &map); + +mlir::ValueRange buildUnrealizedCast(mlir::OpBuilder &builder, + mlir::TypeRange resultTypes, + mlir::ValueRange inputs); + +class XeTypeConverter : public mlir::OneToNTypeConverter { +public: + using mlir::OneToNTypeConverter::convertType; + + XeTypeConverter(mlir::MLIRContext &context, imex::ValueAttributeMap &map) + : context(context), map(map) { + addConversion([&](xetile::TileType tileTy, + llvm::SmallVectorImpl &resultTypes) + -> std::optional { + return convertTileType(tileTy, resultTypes); + }); + + addConversion([&](mlir::VectorType vectorTy, + llvm::SmallVectorImpl &resultTypes) + -> std::optional { + return convertVectorType(vectorTy, resultTypes); + }); + } + + virtual std::optional + convertTileType(xetile::TileType tileTy, + llvm::SmallVectorImpl &resultTypes) { + llvm_unreachable("Pending Implementation for convertTileType."); + } + + virtual std::optional + convertVectorType(mlir::VectorType vectorTy, + llvm::SmallVectorImpl &resultTypes) { + llvm_unreachable("Pending Implementation for convertVectorType."); + } + + imex::OperandType get(mlir::Operation *op) { return map.get(op); } + + imex::OperandType get(mlir::BlockArgument arg) { return map.get(arg); } + + bool isA(mlir::Operation *op) { + return (int(map.get(op)) & int(imex::OperandType::DPASA)); + } + + bool isA(mlir::BlockArgument arg) { + return (int(map.get(arg)) & int(imex::OperandType::DPASA)); + } + + bool isAOnly(mlir::Operation *op) { return isA(op) && !isB(op) && !isRC(op); } + + bool isAOnly(mlir::BlockArgument arg) { + return isA(arg) && !isB(arg) && !isRC(arg); + } + + bool isB(mlir::Operation *op) { + return (int(map.get(op)) & int(imex::OperandType::DPASB)); + } + + bool isB(mlir::BlockArgument arg) { + return (int(map.get(arg)) & int(imex::OperandType::DPASB)); + } + + bool isBOnly(mlir::Operation *op) { return isB(op) && !isA(op) && !isRC(op); } + + bool isBOnly(mlir::BlockArgument arg) { + return isB(arg) && !isB(arg) && !isRC(arg); + } + + bool isRC(mlir::Operation *op) { + return int(map.get(op)) & + (int(imex::OperandType::DPASC) | int(imex::OperandType::DPASR)); + } + + bool isRC(mlir::BlockArgument arg) { + return int(map.get(arg)) & + (int(imex::OperandType::DPASC) | int(imex::OperandType::DPASR)); + } + + bool isRCOnly(mlir::Operation *op) { + return isRC(op) && !isA(op) && !isB(op); + } + + bool isRCOnly(mlir::BlockArgument arg) { + return isRC(arg) && !isA(arg) && !isB(arg); + } + +private: + mlir::MLIRContext &context; + imex::ValueAttributeMap ↦ +}; + +// A simple mlir::RewritePattern wrapper with methods for accessing OperandType +class XeConversionPattern : public mlir::RewritePattern { +public: + using mlir::RewritePattern::RewritePattern; + + template + XeConversionPattern(imex::XeTypeConverter &typeConverter, Args &&...args) + : mlir::RewritePattern(std::forward(args)...), + typeConverter(typeConverter) {} + + virtual mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + }; + + imex::OperandType getOperandType(mlir::Operation *op) const { + return typeConverter.get(op); + } + + imex::OperandType getOperandType(mlir::BlockArgument arg) const { + return typeConverter.get(arg); + } + + bool isA(mlir::Operation *op) const { return typeConverter.isAOnly(op); } + + bool isA(mlir::BlockArgument arg) const { return typeConverter.isAOnly(arg); } + + bool isB(mlir::Operation *op) const { return typeConverter.isBOnly(op); } + + bool isB(mlir::BlockArgument arg) const { return typeConverter.isBOnly(arg); } + + bool isRC(mlir::Operation *op) const { return typeConverter.isRCOnly(op); } + + bool isRC(mlir::BlockArgument arg) const { + return typeConverter.isRCOnly(arg); + } + + imex::XeTypeConverter &getTypeConverter() const { return typeConverter; } + + template + std::enable_if_t::value, + ConverterTy &> + getTypeConverter() const { + return static_cast(typeConverter); + } + +protected: + imex::XeTypeConverter &typeConverter; +}; + +} // namespace imex + +#endif diff --git a/include/imex/Utils/XeUtils.h b/include/imex/Utils/XeUtils.h deleted file mode 100644 index 50ba8624b..000000000 --- a/include/imex/Utils/XeUtils.h +++ /dev/null @@ -1,48 +0,0 @@ - -//===- XeUtils.h - XeTile/XeGPU Utility Functions --------------------*- C++ -//-*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header file defines utility functions used by XeTile/XeGPU dialects. -// -//===----------------------------------------------------------------------===// - -#ifndef _IMEX_XEUTILS_H_ -#define _IMEX_XEUTILS_H_ - -#include "mlir/IR/Value.h" - -static std::string getValueAsString(::mlir::Value op, bool asOperand = false) { - std::string buf; - buf.clear(); - llvm::raw_string_ostream os(buf); - auto flags = ::mlir::OpPrintingFlags().assumeVerified(); - if (asOperand) - op.printAsOperand(os, flags); - else - op.print(os, flags); - os.flush(); - return buf; -} - -template static std::string makeString(T array) { - std::string buf; - buf.clear(); - llvm::raw_string_ostream os(buf); - os << "["; - for (auto i = 1; i < array.size(); i++) - os << array[i - 1] << ", "; - if (array.size()) - os << array[array.size() - 1]; - os << "]"; - os.flush(); - return buf; -} - -#endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 7907974a2..ab837afa9 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(GPUToSPIRV) add_subdirectory(GPUToGPUX) add_subdirectory(GPUXToLLVM) add_subdirectory(XeGPUToSPIRV) +add_subdirectory(XeTileToXeGPU) diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index 7fd24904f..3fde5759e 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -70,6 +70,10 @@ class TensorDialect; namespace gpu { class GPUDialect; } // namespace gpu + +namespace vector { +class VectorDialect; +} } // namespace mlir namespace imex { @@ -85,6 +89,14 @@ namespace gpux { class GPUXDialect; } // namespace gpux +namespace xegpu { +class XeGPUDialect; +} + +namespace xetile { +class XeTileDialect; +} + #define GEN_PASS_CLASSES #include diff --git a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp new file mode 100644 index 000000000..44fc6d914 --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp @@ -0,0 +1,87 @@ +//===- ArithOpConversion.cpp - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the ArithOpConversionPattern, used in XeTileToXeGPU +/// conversion, converting the Arith Ops. +/// +//===----------------------------------------------------------------------===// + +#include "ArithOpConversion.h" + +namespace imex { + +class SgArithConstantOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion< + mlir::arith::ConstantOp>::SgXeTileToXeGPUConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::arith::ConstantOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = op.getResult(); + auto resultTy = result.getType(); + + if (!resultTy.isa()) + return mlir::failure(); + + auto vectorTy = resultTy.cast(); + + // We only interesting 4D vectors + if (vectorTy.getRank() != 4) + return mlir::failure(); + + auto shape = vectorTy.getShape(); + auto subVectorTy = ::mlir::VectorType::get({shape[2], shape[3]}, + vectorTy.getElementType()); + + auto valueAttr = op.getValue(); + if (!valueAttr.isa()) + return mlir::failure(); + + auto denseElementsAttr = valueAttr.cast(); + if (!denseElementsAttr.isSplat()) + return mlir::failure(); + + auto splatVal = denseElementsAttr.getSplatValue(); + + rewriter.setInsertionPoint(op); + llvm::SmallVector newOps; + for (auto i = 0; i < shape[0]; i++) { + for (auto j = 0; j < shape[1]; j++) { + auto newOp = rewriter.create( + loc, subVectorTy, + mlir::DenseElementsAttr::get(subVectorTy, splatVal)); + newOps.push_back(newOp); + } + } + + rewriter.replaceOp(op, newOps); + return mlir::success(); + } +}; + +bool isLegalArithOp(mlir::Operation *op) { + if (llvm::isa(op)) { + auto constOp = llvm::cast(op); + auto resultTy = constOp.getResult().getType(); + if (resultTy.isa() && + resultTy.cast().getRank() == 4) + return false; + } + return true; +} + +void populateArithOpConversionPatterns(imex::XeGPUTypeConverter &converter, + mlir::RewritePatternSet &patterns) { + patterns.add(patterns.getContext(), converter); +} + +} // namespace imex diff --git a/lib/Conversion/XeTileToXeGPU/ArithOpConversion.h b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.h new file mode 100644 index 000000000..f3733ae78 --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/ArithOpConversion.h @@ -0,0 +1,28 @@ +//===- ArithOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the ArithOpConversionPattern, used in XeTileToXeGPU +/// conversion, converting the Arith Ops. +/// +//===----------------------------------------------------------------------===// +#ifndef _ArithOpConversion_H_INCLUDED_ +#define _ArithOpConversion_H_INCLUDED_ + +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h" +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" + +namespace imex { +bool isLegalArithOp(mlir::Operation *op); + +void populateArithOpConversionPatterns(imex::XeGPUTypeConverter &converter, + mlir::RewritePatternSet &patterns); + +} // namespace imex +#endif diff --git a/lib/Conversion/XeTileToXeGPU/CMakeLists.txt b/lib/Conversion/XeTileToXeGPU/CMakeLists.txt new file mode 100644 index 000000000..ec85a4362 --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(IMEXXeTileToXeGPU + ArithOpConversion.cpp + SCFOpConversion.cpp + XeTileToXeGPU.cpp + XeTileOpConversion.cpp + XeTileToXeGPUConversion.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/imex/Conversion/XeTileToXeGPU + + DEPENDS + IMEXConversionPassIncGen + + LINK_LIBS PUBLIC + IMEXXeGPUDialect +) diff --git a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp new file mode 100644 index 000000000..a89595b2b --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.cpp @@ -0,0 +1,119 @@ +//===- ArithOpConversion.cpp - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the ArithOpConversionPattern, used in XeTileToXeGPU +/// conversion, converting the Arith Ops. +/// +//===----------------------------------------------------------------------===// +#include +#include + +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" + +namespace imex { + +struct SgSCFForOpBlockPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion::SgXeTileToXeGPUConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor, + imex::XeGPUOneToNPatterRewriter &rewriter) const override { + auto loc = op.getLoc(); + + llvm::SmallVector convertedArgs; + // OpAdaptor is defined with ValueRange, so it contains results after + // One-to-N mapping + for (auto values : adaptor.getInitArgs()) + convertedArgs.append(values.begin(), values.end()); + + auto argumentTys = op.getRegion().getArgumentTypes(); + mlir::OneToNTypeMapping argumentMapping(argumentTys); + // compute the type conversion (signature) for SCFFor body arguments. + // argumentMapping is essentially a TypeConverter::SignatureConversion + if (mlir::failed( + typeConverter.computeTypeMapping(argumentTys, argumentMapping))) { + op.emitOpError("Failed to compute the type mapping for arguments.\n"); + return mlir::failure(); + } + + // apply the signature convertion for SCFFor body arguments, an + // UnrealizedConversionCastOp will be inserted by typeConverted by the + // method registered in Materialization methods + if (mlir::failed(rewriter.convertRegionTypes(&op.getRegion(), typeConverter, + &argumentMapping))) { + op.emitOpError("Failed to convert region types.\n"); + return mlir::failure(); + } + + auto newOp = rewriter.create(loc, op.getLowerBound(), + op.getUpperBound(), + op.getStep(), convertedArgs); + + newOp.getBody()->erase(); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + rewriter.replaceOp(op, newOp.getResults()); + return mlir::success(); + } +}; + +struct SgSCFYieldOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion< + mlir::scf::YieldOp>::SgXeTileToXeGPUConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::scf::YieldOp op, OpAdaptor adaptor, + imex::XeGPUOneToNPatterRewriter &rewriter) const override { + llvm::SmallVector convertedResults; + for (auto values : adaptor.getResults()) + convertedResults.append(values.begin(), values.end()); + + auto newOp = + rewriter.create(op.getLoc(), convertedResults); + + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +bool isLegalSCFOp(mlir::Operation *op) { + bool result = true; + if (llvm::isa(op)) { + auto forOp = llvm::cast(op); + for (auto arg : forOp.getInitArgs()) { + auto type = arg.getType(); + result &= !type.isa(); + + if (type.isa()) + result &= (type.cast().getRank() != 4); + } + } + + if (llvm::isa(op)) { + auto yieldOp = llvm::cast(op); + for (auto arg : yieldOp.getResults()) { + auto type = arg.getType(); + result &= !type.isa(); + if (type.isa()) + result &= (type.cast().getRank() != 4); + } + } + return result; +} + +void populateSCFOpConversionPatterns(imex::XeGPUTypeConverter &converter, + mlir::RewritePatternSet &patterns) { + patterns.add( + patterns.getContext(), converter); +} + +} // namespace imex diff --git a/lib/Conversion/XeTileToXeGPU/SCFOpConversion.h b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.h new file mode 100644 index 000000000..07996b955 --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/SCFOpConversion.h @@ -0,0 +1,28 @@ +//===- ArithOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the ArithOpConversionPattern, used in XeTileToXeGPU +/// conversion, converting the Arith Ops. +/// +//===----------------------------------------------------------------------===// +#ifndef _SCFOpConversion_H_INCLUDED_ +#define _SCFOpConversion_H_INCLUDED_ + +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h" +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" + +namespace imex { +bool isLegalSCFOp(mlir::Operation *op); + +void populateSCFOpConversionPatterns(imex::XeGPUTypeConverter &converter, + mlir::RewritePatternSet &patterns); + +} // namespace imex +#endif diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp new file mode 100644 index 000000000..2e5fc75bf --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp @@ -0,0 +1,340 @@ +//===- XeTileOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements ConversionPatterns for XeTileOps, used in XeTileToXeGPU +/// conversion, converting the XeTile dialect to the XeGPU dialect. +/// +//===----------------------------------------------------------------------===// + +#include + +#include "ArithOpConversion.h" +#include "SCFOpConversion.h" +#include "XeTileOpConversion.h" + +namespace imex { + +// Sg-level XeTile::init_tile -> XeGPU::init_tile +class SgInitTileOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion< + xetile::InitTileOp>::SgXeTileToXeGPUConversion; + + mlir::LogicalResult + matchAndRewrite(xetile::InitTileOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto source = op.getSource(); + auto resultTile = op.getResult(); + auto resTileType = resultTile.getType(); + auto resTileShape = resTileType.getShape(); + auto indexType = rewriter.getIndexType(); + + llvm::SmallVector offsets; + auto staticOffsets = op.getStaticOffsets(); + auto dynamicOffsets = op.getOffsets(); + for (int i = 0, j = 0; i != staticOffsets.size(); i++) { + if (mlir::ShapedType::isDynamic(staticOffsets[i])) { + offsets.push_back(dynamicOffsets[j++]); + } else { + offsets.push_back(rewriter.create( + op.getLoc(), rewriter.getIndexAttr(staticOffsets[i]))); + } + } + + auto offsetsX = offsets[0]; + auto offsetsY = offsets[1]; + + if (resTileType.getRank() != 4) + return mlir::failure(); + + auto createIndexConstant = [&](mlir::Type type, int64_t value) { + auto attr = rewriter.getIndexAttr(value); + return rewriter.create(loc, type, attr); + }; + + auto tDescTy = xegpu::TensorDescType::get( + {resTileShape[2], resTileShape[3]}, resTileType.getElementType(), + imex::xegpu::MemoryScope::GLOBAL /*memory scope*/); + + rewriter.setInsertionPoint(op); + llvm::SmallVector xegpuOps; + for (int i = 0; i < resTileShape[0]; i++) { + for (int j = 0; j < resTileShape[1]; j++) { + auto subOffX = createIndexConstant(indexType, (resTileShape[2] * i)); + auto subOffY = createIndexConstant(indexType, (resTileShape[3] * j)); + auto tDescOffsetX = + rewriter.createOrFold(loc, subOffX, offsetsX); + auto tDescOffsetY = + rewriter.createOrFold(loc, subOffY, offsetsY); + mlir::SmallVector tDescOffsets{tDescOffsetX, + tDescOffsetY}; + + constexpr int64_t kDynamic = std::numeric_limits::min(); + + // TODO: this needs improvement, it assumes the source is static + // memeref. + auto createNdOp = rewriter.create( + op.getLoc(), tDescTy /*resultTy*/, source /*source*/, + tDescOffsets /*offsets*/, true /*boboundary_check*/, + imex::xegpu::Mode::VC /*mode*/); + + xegpuOps.push_back(createNdOp); + } + } + + rewriter.replaceOp(op, xegpuOps); + return mlir::success(); + } +}; + +// Sg-level XeTile::prefetch_tile -> XeGPU::prefetch_2d +struct SgPrefetchTileOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion< + xetile::PrefetchTileOp>::SgXeTileToXeGPUConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::PrefetchTileOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const override { + auto tileTy = op.getTile().getType(); + auto tiles = adaptor.getTile(); + if (tileTy.getRank() != 4) + return mlir::failure(); + auto shape = tileTy.getShape(); + + if (shape[0] * shape[1] != tiles.size()) { + op.emitOpError("Failed to lower LoadTileOp because shape[0] * shape[1] " + "!= sources.size()."); + return mlir::failure(); + } + + auto elementTy = tileTy.getElementType(); + auto subVectorTy = mlir::VectorType::get({shape[2], shape[3]}, elementTy); + + auto L1 = xegpu::CacheReadHintAttr::get(op.getContext(), + xegpu::CacheReadHint::CACHED); + auto L2 = xegpu::CacheReadHintAttr::get(op.getContext(), + xegpu::CacheReadHint::CACHED); + auto L3 = xegpu::CacheReadHintAttr::get(op.getContext(), + xegpu::CacheReadHint::CACHED); + + for (int i = 0; i < shape[0]; i++) { + for (int j = 0; j < shape[1]; j++) { + auto tile = tiles[i * shape[1] + j]; + rewriter.create( + op.getLoc(), subVectorTy, tile, mlir::IntegerAttr(), + mlir::DenseI64ArrayAttr(), L1, L2, L3, imex::xegpu::Mode::VC); + } + } + + rewriter.eraseOp(op); + + return mlir::success(); + } +}; + +// Sg-level XeTile::load_tile -> XeGPU::load_2d +struct SgLoadTileOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion< + xetile::LoadTileOp>::SgXeTileToXeGPUConversion; + + mlir::LogicalResult + matchAndRewrite(xetile::LoadTileOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const override { + auto resultTy = op.getValue().getType(); + auto tileTy = op.getSource().getType(); + + if (resultTy.getRank() != 4 || tileTy.getRank() != 4) + return mlir::failure(); + + auto shape = resultTy.getShape(); + auto sources = adaptor.getSource(); + + if (shape[0] * shape[1] != sources.size()) { + op.emitOpError("Failed to lower LoadTileOp because shape[0] * shape[1] " + "!= sources.size()."); + return mlir::failure(); + } + + auto elementTy = resultTy.getElementType(); + + // TODO: move these two into architecture abstracture in future. + const int SIMD_WIDTH_IN_BITS = 32; + int vnniFactor = SIMD_WIDTH_IN_BITS / elementTy.getIntOrFloatBitWidth(); + + int vnniAxis = 1; + mlir::IntegerAttr vnniAxisAttr; + auto transposeAttr = op.getTransposeAttr(); + auto L1 = xegpu::CacheReadHintAttr::get(op.getContext(), + xegpu::CacheReadHint::UNCACHED); + auto L2 = xegpu::CacheReadHintAttr::get(op.getContext(), + xegpu::CacheReadHint::UNCACHED); + auto L3 = xegpu::CacheReadHintAttr::get(op.getContext(), + xegpu::CacheReadHint::UNCACHED); + + llvm::SmallVector newShape = {shape[2], shape[3]}; + // needs vnni transform; + if (vnniFactor > 1 && (isA(op) || isB(op))) { + if (isB(op)) + vnniAxis = 0; + newShape[vnniAxis] /= vnniFactor; + newShape.push_back(vnniFactor); + vnniAxisAttr = + rewriter.getIntegerAttr(rewriter.getIntegerType(32), vnniAxis); + } + + auto subVectorTy = + ::mlir::VectorType::get(newShape, resultTy.getElementType()); + + rewriter.setInsertionPoint(op); + + llvm::SmallVector<::mlir::Value> xegpuOps; + for (int i = 0; i < shape[0]; i++) { + for (int j = 0; j < shape[1]; j++) { + auto tile = sources[i * shape[1] + j]; + auto ldOp = rewriter.create( + op.getLoc(), subVectorTy, tile, vnniAxisAttr, transposeAttr, L1, L2, + L3, imex::xegpu::Mode::VC); + xegpuOps.push_back(ldOp); + } + } + + rewriter.replaceOp(op, xegpuOps); + return mlir::success(); + } +}; + +// Sg-level XeTile::store_tile -> XeGPU::store_2d +struct SgStoreTileOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion< + xetile::StoreTileOp>::SgXeTileToXeGPUConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::StoreTileOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const override { + auto tiles = adaptor.getTile(); + auto values = adaptor.getValue(); + + if (tiles.size() != values.size()) { + op.emitOpError() << "Failed to lower the StoreOp, because tile and block " + "size doesn't match." + << "tiles: " << tiles.size() << ", " + << "values: " << values.size() << "\n"; + return mlir::failure(); + } + + auto context = op.getContext(); + auto L1 = xegpu::CacheWriteHintAttr::get(context, + xegpu::CacheWriteHint::UNCACHED); + auto L2 = xegpu::CacheWriteHintAttr::get(context, + xegpu::CacheWriteHint::UNCACHED); + auto L3 = xegpu::CacheWriteHintAttr::get(context, + xegpu::CacheWriteHint::UNCACHED); + for (size_t i = 0; i < tiles.size(); i++) + rewriter.create(op.getLoc(), tiles[i], values[i], L1, + L2, L3, imex::xegpu::Mode::VC); + + rewriter.eraseOp(op); + return ::mlir::success(); + } +}; + +// Sg-level XeTile::tile_mma-> XeGPU::dpas +struct SgTileMMAOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion::SgXeTileToXeGPUConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::TileMMAOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const override { + + auto aShape = op.getAType().getShape(); + auto bShape = op.getBType().getShape(); + + if (aShape.size() != 4 || bShape.size() != 4) { + op.emitOpError() << "Operand A and B for mma should be 4d.\n"; + return mlir::failure(); + } + + if (aShape[3] != bShape[2] || aShape[1] != bShape[0]) { + op.emitOpError() << "A and B size doesn't match. A should be m x k, and " + "B should be k x n"; + return mlir::failure(); + } + + uint64_t M = aShape[0]; + uint64_t K = aShape[1]; + uint64_t N = bShape[1]; + + auto loc = op.getLoc(); + auto AValues = adaptor.getA(); + auto BValues = adaptor.getB(); + auto CValues = adaptor.getC(); + + auto elemTy = op.getOutput().getType().getElementType(); + auto subCTy = mlir::VectorType::get({aShape[2], bShape[3]}, elemTy); + + mlir::SmallVector xegpuOps; + for (uint64_t i = 0; i < M; i++) { + for (uint64_t j = 0; j < N; j++) { + mlir::Value tmpC; + if (op.getC()) + tmpC = CValues[i * N + j]; // init with acc + for (uint64_t k = 0; k < K; k++) { + auto aVec = AValues[i * K + k]; + auto bVec = BValues[k * N + j]; + tmpC = rewriter.create( + loc, subCTy /*result*/, aVec /*lhs*/, bVec /*rhs*/, tmpC /*acc*/, + imex::xegpu::Mode::VC); + } + xegpuOps.push_back(tmpC); + } + } + rewriter.replaceOp(op, xegpuOps); + return mlir::success(); + } +}; + +struct SgUpdateTileOffsetOpPattern + : public SgXeTileToXeGPUConversion { + using SgXeTileToXeGPUConversion< + xetile::UpdateTileOffsetOp>::SgXeTileToXeGPUConversion; + + mlir::LogicalResult + matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor, + XeGPUOneToNPatterRewriter &rewriter) const override { + auto offsetX = op.getOffsetX(); + auto offsetY = op.getOffsetY(); + auto tiles = adaptor.getTile(); + + llvm::SmallVector xegpuOps; + for (auto tile : tiles) { + auto xegpuTile = rewriter.create( + op.getLoc(), tile.getType(), tile, mlir::ValueRange{offsetX, offsetY}, + imex::xegpu::Mode::VC); + xegpuOps.push_back(xegpuTile); + } + rewriter.replaceOp(op, xegpuOps); + return mlir::success(); + } +}; + +void populateXeTileOpConversionPatterns(imex::XeGPUTypeConverter &converter, + mlir::RewritePatternSet &patterns) { + patterns.insert(patterns.getContext(), + converter); +} + +} // namespace imex diff --git a/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h new file mode 100644 index 000000000..9e46f1a7f --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/XeTileOpConversion.h @@ -0,0 +1,27 @@ +//===- XeTileOpConversion.h - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines ConversionPatterns for XeTileOps, used in XeTileToXeGPU +/// conversion, converting the XeTile dialect to the XeGPU dialect. +/// +//===----------------------------------------------------------------------===// +#ifndef _XeTileOpConversion_H_INCLUDED_ +#define _XeTileOpConversion_H_INCLUDED_ + +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h" + +namespace imex { + +void populateXeTileOpConversionPatterns(imex::XeGPUTypeConverter &converter, + mlir::RewritePatternSet &patterns); + +} // namespace imex + +#endif diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp new file mode 100644 index 000000000..a4ca36c6d --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPU.cpp @@ -0,0 +1,109 @@ +//===- XeTileToXeGPU.cpp - XeTileToXeGPU conversion -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the XeTileToXeGPU conversion, converting the XeTile +/// dialect to the XeGPU dialect. +/// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#include "../PassDetail.h" +#include "ArithOpConversion.h" +#include "SCFOpConversion.h" +#include "XeTileOpConversion.h" + +namespace imex { + +class XeTileConversionTarget : public mlir::ConversionTarget { +public: + explicit XeTileConversionTarget(mlir::MLIRContext &context) + : mlir::ConversionTarget(context) { + addIllegalOp(); + + addLegalOp(); + + addLegalDialect(); + + addDynamicallyLegalDialect( + [&](mlir::Operation *op) { return isLegalArithOp(op); }); + + addDynamicallyLegalDialect( + [&](mlir::Operation *op) { return isLegalSCFOp(op); }); + } +}; + +// Full Pass +struct ConvertXeTileToXeGPUPass // convert XeTile to XeGPU + : public ::imex::ConvertXeTileToXeGPUBase { + ConvertXeTileToXeGPUPass() = default; + + void runOnOperation() override { + mlir::ModuleOp mod = getOperation(); + mlir::MLIRContext &context = getContext(); + + // skip functions with XeTile.TileType inputs and outputs + bool hasTileTyInFuncTy = false; + mod.walk([&](mlir::func::FuncOp op) { + auto funcTy = op.getFunctionType(); + hasTileTyInFuncTy |= std::any_of( + funcTy.getInputs().begin(), funcTy.getInputs().end(), + [](mlir::Type ty) { return llvm::isa(ty); }); + hasTileTyInFuncTy |= std::any_of( + funcTy.getResults().begin(), funcTy.getInputs().end(), + [](mlir::Type ty) { return llvm::isa(ty); }); + }); + + if (hasTileTyInFuncTy) { + mod.emitOpError( + "Currently FunctionType with xetile.TileType is not supported."); + return signalPassFailure(); + } + + imex::ValueAttributeMap map; + mod.walk([&](imex::xetile::TileMMAOp op) { + markDefChainValues(op.getA(), OperandType::DPASA, map); + markDefChainValues(op.getB(), OperandType::DPASB, map); + markDefChainValues(op.getC(), OperandType::DPASC, map); + markUseChainValues(op.getOutput(), OperandType::DPASR, map); + }); + + XeGPUTypeConverter typeConverter(context, map); + XeTileConversionTarget target(context); + mlir::RewritePatternSet patterns(&context); + + populateXeTileToXeGPUConversionPatterns(typeConverter, patterns); + + if (mlir::failed( + mlir::applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + } +}; + +/// Populate the given list with patterns that convert XeTile to XeGPU +void populateXeTileToXeGPUConversionPatterns( + imex::XeGPUTypeConverter &converter, mlir::RewritePatternSet &patterns) { + populateSCFOpConversionPatterns(converter, patterns); + populateArithOpConversionPatterns(converter, patterns); + populateXeTileOpConversionPatterns(converter, patterns); +} + +/// Create a pass that convert XeTile to XeGPU +std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>> +createConvertXeTileToXeGPUPass() { + return std::make_unique(); +} + +} // namespace imex diff --git a/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp new file mode 100644 index 000000000..3e4817948 --- /dev/null +++ b/lib/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.cpp @@ -0,0 +1,165 @@ +//===- XeTileToXeGPUConversion.cpp - XeTileToXeGPU conversion -------*- C++ +//-*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the SgXeTileToXeGPUConversion, the base class for +/// XeTileToXeGPU conversion, XeGPUTypeConverter, converting types used in +/// XeTile dialect to types used in XeGPU dialect, XeGPUOneToNPatterRewriter a +/// wrapper around ConversionPatterRewriter providng interface for supporting +/// OneToN replace. +/// +//===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include "../PassDetail.h" + +namespace imex { + +static bool isIdentityConversion(mlir::Type originalType, + mlir::TypeRange convertedTypes) { + return convertedTypes.size() == 1 && convertedTypes[0] == originalType; +} + +static llvm::SmallVector +buildUnrealizedBackwardsCasts(mlir::ValueRange convertedValues, + const mlir::OneToNTypeMapping &typeConversion, + mlir::RewriterBase &rewriter) { + + // assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); + + // Create unrealized cast op for each converted result of the op. + llvm::SmallVector recastValues; + mlir::TypeRange originalTypes = typeConversion.getOriginalTypes(); + recastValues.reserve(originalTypes.size()); + auto convertedValueIt = convertedValues.begin(); + for (auto [idx, originalType] : llvm::enumerate(originalTypes)) { + mlir::TypeRange convertedTypes = typeConversion.getConvertedTypes(idx); + size_t numConvertedValues = convertedTypes.size(); + if (isIdentityConversion(originalType, convertedTypes)) { + // Identity conversion: take result as is. + recastValues.push_back(*convertedValueIt); + } else { + // Non-identity conversion: cast back to source type. + mlir::ValueRange recastValue = buildUnrealizedCast( + rewriter, originalType, + mlir::ValueRange{convertedValueIt, + convertedValueIt + numConvertedValues}); + assert(recastValue.size() == 1); + recastValues.push_back(recastValue.front()); + } + convertedValueIt += numConvertedValues; + } + + return recastValues; +} + +XeGPUTypeConverter::XeGPUTypeConverter(mlir::MLIRContext &context, + imex::ValueAttributeMap &map) + : XeTypeConverter(context, map) { + addConversion( + [&](mlir::IndexType type) -> std::optional { return type; }); + + addConversion( + [&](mlir::MemRefType type) -> std::optional { return type; }); + + addArgumentMaterialization( + [&](mlir::OpBuilder &builder, mlir::Type resultType, + mlir::ValueRange inputs, + mlir::Location loc) -> std::optional { + return builder + .create(loc, resultType, inputs) + .getResult(0); + }); + + addSourceMaterialization( + [&](mlir::OpBuilder &builder, mlir::Type resultType, + mlir::ValueRange inputs, + mlir::Location loc) -> std::optional { + return builder + .create(loc, resultType, inputs) + .getResult(0); + }); +} + +std::optional XeGPUTypeConverter::convertTileType( + xetile::TileType tileTy, llvm::SmallVectorImpl &resultTypes) { + if (tileTy.getRank() == 2) { + resultTypes.push_back(tileTy); + return mlir::success(); + } else if (tileTy.getRank() == 4) { + auto shape = tileTy.getShape(); + auto tdescTy = xegpu::TensorDescType::get({shape[2], shape[3]}, + tileTy.getElementType()); + auto numElements = shape[0] * shape[1]; + resultTypes.assign(numElements, tdescTy); + return mlir::success(); + } + return std::nullopt; +} + +std::optional XeGPUTypeConverter::convertVectorType( + mlir::VectorType vectorTy, llvm::SmallVectorImpl &resultTypes) { + if (vectorTy.getRank() == 4) { + auto shape = vectorTy.getShape(); + auto vecTy = + mlir::VectorType::get({shape[2], shape[3]}, vectorTy.getElementType()); + auto numElements = shape[0] * shape[1]; + resultTypes.assign(numElements, vecTy); + return mlir::success(); + } else if (vectorTy.getRank() == 2) { + resultTypes.push_back(vectorTy); + return mlir::success(); + } + return std::nullopt; +} + +mlir::Block *XeGPUOneToNPatterRewriter::applySignatureConversion( + mlir::Region *region, mlir::TypeConverter::SignatureConversion &conversion, + const mlir::TypeConverter *converter) { + return rewriter.applySignatureConversion(region, conversion, converter); +} + +void XeGPUOneToNPatterRewriter::replaceOp(mlir::Operation *op, + mlir::ValueRange newValues) { + // It is one-to-one mapping, let the ConvertionPatternRewriter handle it + // directly. + if (newValues.size() == op->getNumResults()) { + rewriter.replaceOp(op, newValues); + } else { // it is one-to-N mapping, so create unrealizedCasts to make it as + // one-to-one mapping + llvm::SmallVector recastValues; + auto resultTys = op->getResultTypes(); + mlir::OneToNTypeMapping resultMapping(resultTys); + if (mlir::succeeded( + typeConverter.computeTypeMapping(resultTys, resultMapping))) { + auto castValues = + buildUnrealizedBackwardsCasts(newValues, resultMapping, rewriter); + rewriter.replaceOp(op, castValues); + } else { + llvm_unreachable("It is an unexpected failure of failing to convert the " + "result types."); + } + } +} + +} // namespace imex diff --git a/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 9a5b686c2..4453f1607 100644 --- a/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -27,7 +27,7 @@ #include -#include "imex/Utils/XeUtils.h" +#include "imex/Utils/DebugUtils.h" namespace imex { namespace xegpu { @@ -166,10 +166,10 @@ static void printArrayElement(mlir::AsmPrinter &printer, } static mlir::LogicalResult -parseSgMapAttrElements(mlir::AsmParser &parser, - llvm::SmallVector &layout, - llvm::SmallVector &data, - llvm::SmallVector &mmaBlockSize) { +parseSubGroupMapAttrElements(mlir::AsmParser &parser, + llvm::SmallVector &layout, + llvm::SmallVector &data, + llvm::SmallVector &mmaBlockSize) { auto parseElt = [&]() -> mlir::LogicalResult { return mlir::AsmParser::KeywordSwitch(parser) .Case("mma_block_size", @@ -185,8 +185,9 @@ parseSgMapAttrElements(mlir::AsmParser &parser, return parseArrayList(parser, data, true); }) .Default([&](llvm::StringRef keyword, llvm::SMLoc) { - parser.emitError(parser.getCurrentLocation(), - "SgMapAttr Parser meet an unexpected keywoard: ") + parser.emitError( + parser.getCurrentLocation(), + "SubGroupMapAttr Parser meet an unexpected keywoard: ") << keyword << "\n"; return mlir::failure(); }); @@ -202,10 +203,9 @@ parseSgMapAttrElements(mlir::AsmParser &parser, return mlir::success(); } -static void printSgMapAttrElements(mlir::AsmPrinter &printer, - llvm::ArrayRef layout, - llvm::ArrayRef data, - llvm::ArrayRef mmaBlockSize) { +static void printSubGroupMapAttrElements( + mlir::AsmPrinter &printer, llvm::ArrayRef layout, + llvm::ArrayRef data, llvm::ArrayRef mmaBlockSize) { printer << "{"; if (mmaBlockSize.size()) { printArrayElement(printer, "mma_block_size", mmaBlockSize); @@ -218,9 +218,9 @@ static void printSgMapAttrElements(mlir::AsmPrinter &printer, } static mlir::LogicalResult -parseWgMapAttrElements(mlir::AsmParser &parser, - llvm::SmallVector &layout, - llvm::SmallVector &data) { +parseWorkGroupMapAttrElements(mlir::AsmParser &parser, + llvm::SmallVector &layout, + llvm::SmallVector &data) { auto parseElt = [&]() -> mlir::LogicalResult { return mlir::AsmParser::KeywordSwitch(parser) .Case("sg_layout", @@ -232,8 +232,9 @@ parseWgMapAttrElements(mlir::AsmParser &parser, return parseArrayList(parser, data, true); }) .Default([&](llvm::StringRef keyword, llvm::SMLoc) { - parser.emitError(parser.getCurrentLocation(), - "WgMapAttr Parser meet an unexpected keywoard: ") + parser.emitError( + parser.getCurrentLocation(), + "WorkGroupMapAttr Parser meet an unexpected keywoard: ") << keyword << "\n"; return mlir::failure(); }); @@ -248,9 +249,9 @@ parseWgMapAttrElements(mlir::AsmParser &parser, return mlir::success(); } -static void printWgMapAttrElements(mlir::AsmPrinter &printer, - llvm::ArrayRef layout, - llvm::ArrayRef data) { +static void printWorkGroupMapAttrElements(mlir::AsmPrinter &printer, + llvm::ArrayRef layout, + llvm::ArrayRef data) { printer << "{"; printArrayElement(printer, "sg_layout", layout); printer << "," << ' '; @@ -258,28 +259,27 @@ static void printWgMapAttrElements(mlir::AsmPrinter &printer, printer << "}"; } -mlir::LogicalResult -SgMapAttr::verify(llvm::function_ref emitError, - llvm::ArrayRef layout, - llvm::ArrayRef data, - llvm::ArrayRef mmaBlockSize) { +mlir::LogicalResult SubGroupMapAttr::verify( + llvm::function_ref emitError, + llvm::ArrayRef layout, llvm::ArrayRef data, + llvm::ArrayRef mmaBlockSize) { if (mmaBlockSize.size() != 2 && mmaBlockSize.size() != 0) { emitError() - << "Failed to parse SgMapAttr: mma_block_size should be a " + << "Failed to parse SubGroupMapAttr: mma_block_size should be a " "`llvm::ArrayRef` with size 2 or empty. But it got " << mmaBlockSize.size() << ".\n"; return mlir::failure(); } if (layout.size() != 2) { - emitError() << "Failed to parse SgMapAttr: missing wi_layout which " + emitError() << "Failed to parse SubGroupMapAttr: missing wi_layout which " "is to be a `llvm::ArrayRef` with size 2.\n"; return mlir::failure(); } if (data.size() != 2) { - emitError() << "Failed to parse SgMapAttr: missing wi_data which is " + emitError() << "Failed to parse SubGroupMapAttr: missing wi_data which is " "to be a `llvm::ArrayRef` with size 2.\n"; return mlir::failure(); } @@ -287,18 +287,17 @@ SgMapAttr::verify(llvm::function_ref emitError, return mlir::success(); } -mlir::LogicalResult -WgMapAttr::verify(llvm::function_ref emitError, - llvm::ArrayRef layout, - llvm::ArrayRef data) { +mlir::LogicalResult WorkGroupMapAttr::verify( + llvm::function_ref emitError, + llvm::ArrayRef layout, llvm::ArrayRef data) { if (layout.size() != 2) { - emitError() << "Failed to parse WgMapAttr: missing sg_layout which " + emitError() << "Failed to parse WorkGroupMapAttr: missing sg_layout which " "is to be a `llvm::ArrayRef` with size 2.\n"; return mlir::failure(); } if (data.size() != 2) { - emitError() << "Failed to parse WgMapAttr: missing sg_data which is " + emitError() << "Failed to parse WorkGroupMapAttr: missing sg_data which is " "to be a `llvm::ArrayRef` with size 2.\n"; return mlir::failure(); } @@ -306,8 +305,8 @@ WgMapAttr::verify(llvm::function_ref emitError, } mlir::Attribute XeMapAttr::parse(mlir::AsmParser &parser, mlir::Type type) { - imex::xegpu::WgMapAttr wg; - imex::xegpu::SgMapAttr sg; + imex::xegpu::WorkGroupMapAttr wg; + imex::xegpu::SubGroupMapAttr sg; // Parse literal '<' if (parser.parseLess()) return {}; @@ -322,10 +321,10 @@ mlir::Attribute XeMapAttr::parse(mlir::AsmParser &parser, mlir::Type type) { llvm::SmallVector mmaBlockSize; llvm::SmallVector wiLayout; llvm::SmallVector wiData; - if (mlir::failed(parseSgMapAttrElements( - parser, mmaBlockSize, wiLayout, wiData))) + if (mlir::failed(parseSubGroupMapAttrElements( + parser, wiLayout, wiData, mmaBlockSize))) return mlir::failure(); - sg = imex::xegpu::SgMapAttr::get( + sg = imex::xegpu::SubGroupMapAttr::get( parser.getContext(), wiLayout, wiData, mmaBlockSize); return mlir::success(!!sg); }) @@ -335,11 +334,11 @@ mlir::Attribute XeMapAttr::parse(mlir::AsmParser &parser, mlir::Type type) { return mlir::failure(); llvm::SmallVector sgLayout; llvm::SmallVector sgData; - if (mlir::failed( - parseWgMapAttrElements(parser, sgLayout, sgData))) + if (mlir::failed(parseWorkGroupMapAttrElements( + parser, sgLayout, sgData))) return mlir::failure(); - wg = imex::xegpu::WgMapAttr::get(parser.getContext(), - sgLayout, sgData); + wg = imex::xegpu::WorkGroupMapAttr::get(parser.getContext(), + sgLayout, sgData); return mlir::success(!!wg); }) .Default([&](llvm::StringRef keyword, llvm::SMLoc) { @@ -370,7 +369,8 @@ void XeMapAttr::print(mlir::AsmPrinter &printer) const { printer << "<"; if (getWg()) { printer << "wg = "; - printWgMapAttrElements(printer, getWg().getSgLayout(), getWg().getSgData()); + printWorkGroupMapAttrElements(printer, getWg().getSgLayout(), + getWg().getSgData()); printSep = true; } @@ -378,8 +378,9 @@ void XeMapAttr::print(mlir::AsmPrinter &printer) const { if (printSep) printer << ", "; printer << "sg = "; - printSgMapAttrElements(printer, getSg().getMmaBlockSize(), - getSg().getWiLayout(), getSg().getWiData()); + printSubGroupMapAttrElements(printer, getSg().getWiLayout(), + getSg().getWiData(), + getSg().getMmaBlockSize()); } printer << ">"; diff --git a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 4aea1a89f..f9ef3a0d8 100644 --- a/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -27,7 +27,7 @@ #include #include -#include "imex/Utils/XeUtils.h" +#include "imex/Utils/DebugUtils.h" #define DEBUG_TYPE "xegpu" @@ -71,8 +71,8 @@ static void transpose(llvm::ArrayRef trans, }; static bool isMappingAttr(mlir::Attribute attr) { - return attr && (llvm::isa(attr) || - llvm::isa(attr) || + return attr && (llvm::isa(attr) || + llvm::isa(attr) || llvm::isa(attr)); } @@ -614,8 +614,8 @@ mlir::LogicalResult LoadNDOp::verify() { auto valueShape = valueTy.getShape().vec(); if (mode == imex::xegpu::Mode::SIMT) { - imex::xegpu::WgMapAttr wgMap; - imex::xegpu::SgMapAttr sgMap; + imex::xegpu::WorkGroupMapAttr wgMap; + imex::xegpu::SubGroupMapAttr sgMap; auto encoding = tdescTy.getEncoding(); if (!isMappingAttr(encoding)) { @@ -627,8 +627,8 @@ mlir::LogicalResult LoadNDOp::verify() { wgMap = xeMapAttr.getWg(); sgMap = xeMapAttr.getSg(); } else { - wgMap = llvm::dyn_cast(encoding); - sgMap = llvm::dyn_cast(encoding); + wgMap = llvm::dyn_cast(encoding); + sgMap = llvm::dyn_cast(encoding); } if (wgMap) { @@ -636,12 +636,13 @@ mlir::LogicalResult LoadNDOp::verify() { auto sgLayout = wgMap.getSgLayout(); for (size_t i = 0; i < sgData.size(); i++) { if (tdescShape[i] % sgLayout[i] != 0 || - tdescShape[i] % sgData[i] != 0 || tdescShape[i] % sgData[i] != 0) - return emitOpError( - "Invalid WgMapAttr. It should meet the following conditions: " - "tdescShape[i] % sgLayout[i] == 0 && " - "tdescShape[i] % sgData[i] == 0 && " - "tdescShape[i] % sgData[i] == 0"); + tdescShape[i] % sgData[i] != 0 || + tdescShape[i] % (sgLayout[i] * sgData[i]) != 0) + return emitOpError("Invalid WorkGroupMapAttr. It should meet the " + "following conditions: " + "tdescShape[i] % sgLayout[i] == 0 && " + "tdescShape[i] % sgData[i] == 0 && " + "tdescShape[i] % (sgLayout[i] *sgData[i]) == 0"); tdescShape[i] /= sgLayout[i]; } } @@ -654,22 +655,22 @@ mlir::LogicalResult LoadNDOp::verify() { if (tdescShape[i] % blockSize[i] != 0 || blockSize[i] % wiLayout[i] != 0 || blockSize[i] % wiData[i] != 0 || blockSize[i] % (wiLayout[i] * wiData[i]) != 0) { - return emitOpError( - "Invalid SgMapAttr. It should meet the following conditions: " - "tdescShape[i] % blockSize[i] == 0 && " - "blockSize[i] % wiLayout[i] == 0 && " - "blockSize[i] % wiData[i] == 0 && " - "blockSize[i] % (wiLayout[i] * wiData[i]) == 0 "); + return emitOpError("Invalid SubGroupMapAttr. It should meet the " + "following conditions: " + "tdescShape[i] % blockSize[i] == 0 && " + "blockSize[i] % wiLayout[i] == 0 && " + "blockSize[i] % wiData[i] == 0 && " + "blockSize[i] % (wiLayout[i] * wiData[i]) == 0 "); } } for (size_t i = 0; i < wiLayout.size(); i++) { if (tdescShape[i] % wiData[i] != 0 || tdescShape[i] % (wiLayout[i] * wiData[i]) != 0) { - return emitOpError( - "Invalid SgMapAttr. It should meet the following conditions: " - "tdescShape[i] % wiData[i] == 0 && " - "tdescShape[i] % (wiLayout[i] * wiData[i]) == 0 "); + return emitOpError("Invalid SubGroupMapAttr. It should meet the " + "following conditions: " + "tdescShape[i] % wiData[i] == 0 && " + "tdescShape[i] % (wiLayout[i] * wiData[i]) == 0 "); } tdescShape[i] /= wiLayout[i]; } @@ -838,25 +839,29 @@ mlir::LogicalResult StoreNDOp::verify() { "SIMT mode operators.\n"); } - imex::xegpu::WgMapAttr wgMap; - imex::xegpu::SgMapAttr sgMap; + imex::xegpu::WorkGroupMapAttr wgMap; + imex::xegpu::SubGroupMapAttr sgMap; std::vector shape = dstTy.getShape().vec(); if (auto xeMapAttr = llvm::dyn_cast(encoding)) { wgMap = xeMapAttr.getWg(); sgMap = xeMapAttr.getSg(); } else { - wgMap = llvm::dyn_cast(encoding); - sgMap = llvm::dyn_cast(encoding); + wgMap = llvm::dyn_cast(encoding); + sgMap = llvm::dyn_cast(encoding); } if (wgMap) { auto sgData = wgMap.getSgData(); auto sgLayout = wgMap.getSgLayout(); for (size_t i = 0; i < sgData.size(); i++) { - assert(shape[i] % sgLayout[i] == 0); - assert(shape[i] % sgData[i] == 0); - assert(shape[i] % (sgLayout[i] * sgData[i]) == 0); + if (shape[i] % sgLayout[i] != 0 || shape[i] % sgData[i] != 0 || + shape[i] % (sgLayout[i] * sgData[i]) != 0) + return emitOpError("Invalid WorkGroupMapAttr. It should meet the " + "following conditions: " + "tdescShape[i] % sgLayout[i] == 0 && " + "tdescShape[i] % sgData[i] == 0 && " + "tdescShape[i] % (sgLayout[i] *sgData[i]) == 0"); shape[i] /= sgLayout[i]; } } @@ -867,24 +872,24 @@ mlir::LogicalResult StoreNDOp::verify() { auto wiData = sgMap.getWiData(); for (size_t i = 0; i < shape.size(); i++) { if (blockSize[i] % (wiLayout[i] * wiData[i]) != 0 || - blockSize[i] % wiLayout[i] != 0 || blockSize[i] % wiData[i] == 0 || - shape[i] % blockSize[i] == 0) { - return emitOpError( - "Invalid SgMapAttr. It should meet the following conditions: " - "tdescShape[i] % blockSize[i] == 0 && " - "blockSize[i] % wiLayout[i] == 0 && " - "blockSize[i] % wiData[i] == 0 && " - "blockSize[i] % (wiLayout[i] * wiData[i]) == 0 "); + blockSize[i] % wiLayout[i] != 0 || blockSize[i] % wiData[i] != 0 || + shape[i] % blockSize[i] != 0) { + return emitOpError("Invalid SubGroupMapAttr. It should meet the " + "following conditions: " + "tdescShape[i] % blockSize[i] == 0 && " + "blockSize[i] % wiLayout[i] == 0 && " + "blockSize[i] % wiData[i] == 0 && " + "blockSize[i] % (wiLayout[i] * wiData[i]) == 0 "); } } for (size_t i = 0; i < wiLayout.size(); i++) { if (shape[i] % wiData[i] != 0 || shape[i] % (wiLayout[i] * wiData[i]) != 0) { - return emitOpError( - "Invalid SgMapAttr. It should meet the following conditions: " - "tdescShape[i] % wiData[i] == 0 && " - "tdescShape[i] % (wiLayout[i] * wiData[i]) == 0 "); + return emitOpError("Invalid SubGroupMapAttr. It should meet the " + "following conditions: " + "tdescShape[i] % wiData[i] == 0 && " + "tdescShape[i] % (wiLayout[i] * wiData[i]) == 0 "); } shape[i] /= wiLayout[i]; } @@ -983,6 +988,10 @@ mlir::LogicalResult DpasOp::verify() { "lhs and rhs rank does not match for dpas op, or their rank is not 3."); } + if (lhsRank < 3) { + return emitOpError("dpas op requires 3d vector. Rank is not 3"); + } + return mlir::success(); } diff --git a/lib/Dialect/XeTile/CMakeLists.txt b/lib/Dialect/XeTile/CMakeLists.txt index 218c20c88..9f57627c3 100644 --- a/lib/Dialect/XeTile/CMakeLists.txt +++ b/lib/Dialect/XeTile/CMakeLists.txt @@ -1,2 +1,2 @@ add_subdirectory(IR) -#add_subdirectory(Transforms) +add_subdirectory(Transforms) diff --git a/lib/Dialect/XeTile/Transforms/CMakeLists.txt b/lib/Dialect/XeTile/Transforms/CMakeLists.txt index 59e45befd..31df10494 100644 --- a/lib/Dialect/XeTile/Transforms/CMakeLists.txt +++ b/lib/Dialect/XeTile/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(IMEXXeTileTransforms - # FIXME.cpp + XeTileTiling.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/imex/Dialect/XeTile diff --git a/lib/Dialect/XeTile/Transforms/XeTileTiling.cpp b/lib/Dialect/XeTile/Transforms/XeTileTiling.cpp new file mode 100644 index 000000000..54bb2371f --- /dev/null +++ b/lib/Dialect/XeTile/Transforms/XeTileTiling.cpp @@ -0,0 +1,375 @@ +//===- LowerXeTileToBlockLayout.cppp - LowerXeTileToBlockLayout Pass -------*- +// C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains lowering transformation for XeTile large tiles into +/// smaller tiles with blocked layout that maps to register region. +/// This blocked layout is represented by high dimension vectors, inner +/// dimension matches to DPAS size config. +/// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h" +#include "imex/Dialect/XeTile/Transforms/Passes.h" +#include "imex/Utils/DebugUtils.h" + +#include "PassDetail.h" +#include "XeTileTiling.h" + +using namespace mlir; +using namespace imex; +namespace imex { +#define GEN_PASS_DEF_XETILETILING +#include "imex/Dialect/XeTile/Transforms/Passes.h.inc" +} // namespace imex + +namespace imex { + +// DPAS block size as per HW config - TODO, define platform specific sizes +#define M_SIZE 8 +#define K_SIZE 16 +#define N_SIZE 16 + +struct ArithConstantOpPattern + : public XeTileConversion { + using XeTileConversion::XeTileConversion; + + mlir::LogicalResult + matchAndRewrite(mlir::arith::ConstantOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = op.getResult(); + auto resultTy = result.getType(); + + if (!resultTy.isa()) + return mlir::failure(); + + auto vectorTy = resultTy.cast(); + + // We only interesting 2D vectors, and the one used as C + if (vectorTy.getRank() != 2) + return mlir::failure(); + + auto valueAttr = op.getValue(); + if (!valueAttr.isa()) + return mlir::failure(); + + auto denseElementsAttr = valueAttr.cast(); + if (!denseElementsAttr.isSplat()) + return mlir::failure(); + + auto splatVal = denseElementsAttr.getSplatValue(); + + auto shape = vectorTy.getShape(); + + // TODO: a place holder for getting inner_blocks + llvm::SmallVector inner_blocks = {M_SIZE, N_SIZE}; + + // set blockSizes default to inner_block size + llvm::SmallVector blockSizes(inner_blocks); + if (isA(op)) { + blockSizes = {M_SIZE, K_SIZE}; + } else if (isB(op)) { + blockSizes = {K_SIZE, N_SIZE}; + } else if (isRC(op)) { + blockSizes = {M_SIZE, N_SIZE}; + } + + auto vecTy = ::mlir::VectorType::get({shape[0] / blockSizes[0], + shape[1] / blockSizes[1], + blockSizes[0], blockSizes[1]}, + vectorTy.getElementType()); + + auto newOp = rewriter.create( + loc, vecTy, mlir::DenseElementsAttr::get(vecTy, splatVal)); + + // rewriter.replaceOp(op, newOp); + rewriter.replaceOpWithIf(op, newOp->getResults(), [&](mlir::OpOperand &op) { + auto *owner = op.getOwner(); + + // the direct user is an xetile operator + if (llvm::isa(owner->getDialect())) + return true; + + // the direct user is an scf::ForOp, but the corresponding argument + // is used by an xetile operator + if (auto forOp = llvm::dyn_cast(owner)) { + auto arg = forOp.getRegionIterArgForOpOperand(op); + + auto haveXeTileUsers = std::any_of( + arg.user_begin(), arg.user_end(), [&](mlir::Operation *op) { + return llvm::isa(op->getDialect()); + }); + + if (auto yieldOp = llvm::dyn_cast( + forOp.getRegion().front().getTerminator())) { + auto idx = forOp.getResultForOpOperand(op).getResultNumber(); + auto definingOp = yieldOp.getResults()[idx].getDefiningOp(); + haveXeTileUsers |= + llvm::isa(definingOp->getDialect()); + } + + return haveXeTileUsers; + } + + return false; + }); + + return mlir::success(); + } +}; + +struct SCFForOpPattern : public XeTileConversion { + using XeTileConversion::XeTileConversion; + + ::mlir::LogicalResult + matchAndRewrite(mlir::scf::ForOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const override { + auto newOp = rewriter.create( + op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), adaptor.getInitArgs()); + mlir::Block *block = op.getBody(); + mlir::Block *newBlock = newOp.getBody(); + rewriter.mergeBlocks(block, newBlock, newBlock->getArguments()); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +struct InitTileOpPattern : public XeTileConversion { + using XeTileConversion::XeTileConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::InitTileOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const override { + auto tileTy = op.getTile().getType(); + if (tileTy.getRank() != 2) { + op.emitWarning( + "Skipped InitTileOp because the result tile is not rank 2.\n"); + return mlir::failure(); + } + + auto shape = tileTy.getShape(); + + // TODO: a place holder for getting inner_blocks + llvm::SmallVector inner_blocks = {M_SIZE, N_SIZE}; + + // set blockSizes default to inner_block size + llvm::SmallVector blockSizes(inner_blocks); + if (isA(op)) { + blockSizes = {M_SIZE, K_SIZE}; + } else if (isB(op)) { + blockSizes = {K_SIZE, N_SIZE}; + } else if (isRC(op)) { + blockSizes = {M_SIZE, N_SIZE}; + } + + auto newTileTy = imex::xetile::TileType::get({shape[0] / blockSizes[0], + shape[1] / blockSizes[1], + blockSizes[0], blockSizes[1]}, + tileTy.getElementType()); + + auto newOp = rewriter.create<::imex::xetile::InitTileOp>( + op.getLoc(), newTileTy, op.getSource(), op.getOffsets(), + op.getStaticOffsetsAttr(), op.getDynamicShape(), + op.getDynamicStrides()); + + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +struct LoadTileOpPattern : public XeTileConversion { + using XeTileConversion::XeTileConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::LoadTileOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resultTy = op.getResult().getType(); + + if (resultTy.getRank() != 2) { + op.emitWarning("skipped because the result is not 2D."); + return mlir::failure(); + } + + auto shape = resultTy.getShape(); + + // TODO: a place holder for getting inner_blocks + llvm::SmallVector inner_blocks = {M_SIZE, N_SIZE}; + + // set blockSizes default to inner_block size + llvm::SmallVector blockSizes(inner_blocks); + + if (isA(op)) { + blockSizes = {M_SIZE, K_SIZE}; + } else if (isB(op)) { + blockSizes = {K_SIZE, N_SIZE}; + } else if (isRC(op)) { + blockSizes = {M_SIZE, N_SIZE}; + } + + auto vecTy = ::mlir::VectorType::get({shape[0] / blockSizes[0], + shape[1] / blockSizes[1], + blockSizes[0], blockSizes[1]}, + resultTy.getElementType()); + + auto newOp = rewriter.create<::imex::xetile::LoadTileOp>( + loc, vecTy, adaptor.getSource(), op.getTransposeAttr(), + op.getPaddingAttr()); + + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +struct StoreTileOpPattern : public XeTileConversion { + using XeTileConversion::XeTileConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::StoreTileOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const override { + + auto newOp = rewriter.create<::imex::xetile::StoreTileOp>( + op.getLoc(), adaptor.getValue(), adaptor.getTile()); + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +struct TileMMAOpPattern : public XeTileConversion { + using XeTileConversion::XeTileConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::TileMMAOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resultTy = op.getOutput().getType(); + + if (resultTy.getRank() != 2) { + op.emitWarning("skipped because the result is not 2D."); + return mlir::failure(); + } + + auto shape = resultTy.getShape(); + auto vecTy = ::mlir::VectorType::get( + {shape[0] / M_SIZE, shape[1] / N_SIZE, M_SIZE, N_SIZE}, + resultTy.getElementType()); + + auto newOp = rewriter.create( + loc, vecTy, adaptor.getA(), adaptor.getB(), adaptor.getC()); + + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +struct UpdateTileOffsetOpPattern + : public XeTileConversion { + using XeTileConversion::XeTileConversion; + + ::mlir::LogicalResult + matchAndRewrite(xetile::UpdateTileOffsetOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const override { + + auto newOp = rewriter.create<::imex::xetile::UpdateTileOffsetOp>( + op.getLoc(), adaptor.getTile().getType(), adaptor.getTile(), + adaptor.getOffsetX(), adaptor.getOffsetY()); + + rewriter.replaceOp(op, newOp); + return mlir::success(); + } +}; + +void populateXeTileTilingPatterns(imex::XeTypeConverter &converter, + mlir::RewritePatternSet &patterns) { + + patterns.insert(patterns.getContext(), converter); +} + +// Lowers XeTile to blocked layout with high-dim vector +struct XeTileTilingPass + : public imex::impl::XeTileTilingBase { + + XeTileTilingPass() = default; + +public: + void runOnOperation() override { + mlir::MLIRContext &context = getContext(); + auto mod = this->getOperation(); + + // skip functions with XeTile.TileType inputs and outputs + bool hasTileTyInFuncTy = false; + mod.walk([&](mlir::func::FuncOp op) { + auto funcTy = op.getFunctionType(); + hasTileTyInFuncTy |= std::any_of( + funcTy.getInputs().begin(), funcTy.getInputs().end(), + [](mlir::Type ty) { return llvm::isa(ty); }); + hasTileTyInFuncTy |= std::any_of( + funcTy.getResults().begin(), funcTy.getInputs().end(), + [](mlir::Type ty) { return llvm::isa(ty); }); + }); + + if (hasTileTyInFuncTy) { + mod.emitOpError( + "Currently FunctionType with xetile.TileType is not supported."); + return signalPassFailure(); + } + + imex::ValueAttributeMap map; + mod.walk([&](imex::xetile::TileMMAOp op) { + markDefChainValues(op.getA(), OperandType::DPASA, map); + markDefChainValues(op.getB(), OperandType::DPASB, map); + if (bool(op.getC())) + markDefChainValues(op.getC(), OperandType::DPASC, map); + markUseChainValues(op.getOutput(), OperandType::DPASR, map); + }); + + mlir::RewritePatternSet patterns(&context); + XeTypeConverter typeConverter(context, map); + + populateXeTileTilingPatterns(typeConverter, patterns); + + mlir::GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.maxIterations = 2; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + if (failed( + applyPatternsAndFoldGreedily(mod, std::move(patterns), config))) { + return signalPassFailure(); + } + } +}; + +/// Create a pass +std::unique_ptr<::mlir::Pass> createXeTileTilingPass() { + return std::make_unique(); +} +} // namespace imex diff --git a/lib/Dialect/XeTile/Transforms/XeTileTiling.h b/lib/Dialect/XeTile/Transforms/XeTileTiling.h new file mode 100644 index 000000000..41412ea8d --- /dev/null +++ b/lib/Dialect/XeTile/Transforms/XeTileTiling.h @@ -0,0 +1,64 @@ +//===- XeTileTranformBase.h - -------*- C++ -*-===// +//===- XeTileTranformBase.h - -------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// +/////===----------------------------------------------------------------------===// +#ifndef _XeTileTranformBase_H_INCLUDED_ +#define _XeTileTranformBase_H_INCLUDED_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h" +#include "imex/Dialect/XeTile/IR/XeTileOps.h" +#include "imex/Utils/DebugUtils.h" +#include "imex/Utils/PassWrapper.h" +#include "imex/Utils/XeCommon.h" + +#include "PassDetail.h" + +namespace imex { + +template +class XeTileConversion : public imex::XeConversionPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + using OpPatternRewriter = typename mlir::PatternRewriter; + + XeTileConversion(mlir::MLIRContext *context, XeTypeConverter &typeConverter, + mlir::PatternBenefit benefit = 1) + : XeConversionPattern(typeConverter, SourceOp::getOperationName(), + benefit, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override final { + auto sourceOp = llvm::cast(op); + OpAdaptor adaptor(op->getOperands(), sourceOp); + return matchAndRewrite(sourceOp, adaptor, rewriter); + } + + virtual mlir::LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + OpPatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } +}; + +} // namespace imex + +#endif diff --git a/lib/Utils/CMakeLists.txt b/lib/Utils/CMakeLists.txt index e97fa5f48..3a317f99c 100644 --- a/lib/Utils/CMakeLists.txt +++ b/lib/Utils/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_library(IMEXUtil FuncUtils.cpp TypeConversion.cpp + XeCommon.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/imex/Utils diff --git a/lib/Utils/XeCommon.cpp b/lib/Utils/XeCommon.cpp new file mode 100644 index 000000000..4c47381ea --- /dev/null +++ b/lib/Utils/XeCommon.cpp @@ -0,0 +1,314 @@ +//===- XeCommon.cpp - --------------*- C++ -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements XeTypeConverter, ValueAttributeMap and some other +/// routines used by Xe related dialects. +/// +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include "imex/Dialect/XeGPU/IR/XeGPUOps.h" +#include "imex/Dialect/XeTile/IR/XeTileOps.h" +#include "imex/Utils/DebugUtils.h" +#include "imex/Utils/XeCommon.h" + +namespace imex { + +void ValueAttributeMap::add(mlir::BlockArgument arg, imex::OperandType type) { + if (argumentMap.count(arg) == 0) + argumentMap[arg] = int(type); + else + argumentMap[arg] |= int(type); +} + +void ValueAttributeMap::add(mlir::Operation *op, imex::OperandType type) { + if (operationMap.count(op) == 0) + operationMap[op] = int(type); + else + operationMap[op] |= int(type); +} + +imex::OperandType ValueAttributeMap::get(mlir::Operation *op) { + if (operationMap.count(op) == 0) + return OperandType::None; + return OperandType(operationMap[op]); +} + +imex::OperandType ValueAttributeMap::get(mlir::BlockArgument arg) { + if (argumentMap.count(arg) == 0) + return OperandType::None; + return OperandType(argumentMap[arg]); +} + +static bool isConvertibleOp(mlir::Operation *op) { + if (llvm::isa(op) || + llvm::isa(op)) { + return true; + } + return false; +} + +static int getOperandIndex(mlir::Operation *op, mlir::Value operand) { + for (auto [i, value] : llvm::enumerate(op->getOperands())) { + if (operand == value) + return i; + } + return -1; +}; + +static mlir::Value getOperandForArg(mlir::scf::ForOp &forOp, + mlir::Value &value) { + auto arg = llvm::dyn_cast_or_null(value); + if (arg && arg.getArgNumber() >= forOp.getNumInductionVars()) { + auto &iterOperand = forOp.getOpOperandForRegionIterArg(arg); + auto numCtrlOperands = forOp.getNumControlOperands(); + auto operandIdx = iterOperand.getOperandNumber(); + return forOp.getInitArgs()[operandIdx - numCtrlOperands]; + } + return mlir::Value(); +}; + +static mlir::BlockArgument getArgForOperand(mlir::scf::ForOp &forOp, + mlir::Value operand) { + auto idx = getOperandIndex(forOp, operand); + auto numControls = forOp.getNumControlOperands(); + assert(idx >= numControls); + return forOp.getRegionIterArg(idx - numControls); +}; + +static mlir::Operation *getDefineOrParentOp(mlir::Value value) { + if (llvm::isa(value)) + return value.getDefiningOp(); + if (auto arg = llvm::dyn_cast_or_null(value)) + return arg.getOwner()->getParentOp(); + return NULL; +}; + +enum class ChainType { DefChain, UseChain }; + +// It traverse operators and arguments in the chain, +// ops will carry on operators in the chain +// arg will carry on arguments in the chain +// skippedOps carry the operators to be skipped during the traverse +template +void traverseDefUseChain(mlir::Value value, + llvm::SmallVector &ops, + llvm::SmallVector &args, + llvm::SmallVector skippedOps = {}) { + + llvm::SmallVector queue; + + auto isSkipped = [&](mlir::Operation *op) { + return std::find(skippedOps.begin(), skippedOps.end(), op) != + skippedOps.end(); + }; + + auto visitDef = [&](mlir::Value value) { + if (auto arg = llvm::dyn_cast_or_null(value)) + args.push_back(arg); + + auto *op = getDefineOrParentOp(value); + if (op == nullptr || isSkipped(op)) + return; + + // we don't track scf.for since it is composited op + if (!llvm::isa(op)) + ops.push_back(op); + + if (isConvertibleOp(op)) { + queue.append(op->operand_begin(), op->operand_end()); + } else if (auto forOp = llvm::dyn_cast_or_null(op)) { + auto opr = getOperandForArg(forOp, value); + if (bool(opr)) + queue.push_back(opr); + } else if (llvm::isa(value) && + !llvm::isa(op) && + !llvm::isa(op)) { + op->emitError("\nUnexpected operator of an BlockArgument.\n"); + llvm_unreachable("Unexpected case for when handling a BlockArgument.\n"); + } + + return; + }; + + auto visitUsers = [&](mlir::Value value) { + if (!bool(value)) + return; + for (mlir::Operation *user : value.getUsers()) { + if (isSkipped(user)) + continue; + ops.push_back(user); + // YieldOp indicats results of a SCF ForOp, IfOp is currently not handled. + if (llvm::isa(user)) { + auto *parentOp = user->getParentOp(); + auto idx = getOperandIndex(user, value); + if (llvm::isa(parentOp) && idx >= 0) { + auto opResult = parentOp->getResult(idx); + queue.push_back(opResult); + } else { + llvm_unreachable( + "Meet an unexpected/unprocessed op in preOrderVisist.\n"); + } + } else if (auto forOp = llvm::dyn_cast_or_null(user)) { + auto arg = getArgForOperand(forOp, value); + args.push_back(arg); + queue.push_back(arg); + } else if (isConvertibleOp(user)) { + queue.append(user->result_begin(), user->result_end()); + } + } + }; + + if (bool(value)) + queue.push_back(value); + + while (queue.size()) { + auto value = queue.pop_back_val(); + if (!bool(value)) + continue; + + if (chain == ChainType::DefChain) { + visitDef(value); + continue; + } + + if (chain == ChainType::UseChain) { + visitUsers(value); + continue; + } + } +} + +static bool isInterestingTarget(mlir::Operation *op) { + auto constantOp = llvm::dyn_cast_or_null(op); + return llvm::isa(op->getDialect()) || + (constantOp && + llvm::isa(constantOp.getResult().getType())); +} + +static bool isInterestingTarget(mlir::BlockArgument arg) { + auto ty = arg.getType(); + return llvm::isa(ty) || + llvm::isa(ty); +} + +/* + * markDefChainValues iterates over the values in the def-chain of the given + * value, and mark/record them with the OperandType type in ValueAttributeMap. + */ +void markDefChainValues(mlir::Value value, imex::OperandType type, + imex::ValueAttributeMap &map) { + + auto mark = [&](llvm::SmallVector &array) { + for (auto v : array) { + if (isInterestingTarget(v)) + map.add(v, type); + } + }; + + llvm::SmallVector postOrderOps; + llvm::SmallVector postOrderArgs; + traverseDefUseChain(value, postOrderOps, postOrderArgs); + + // mark the interested ops with the type in map. + // here we only interested in ops from xetile dialect. + // and the arith.consantOp for initializing C of mma. + mark(postOrderOps); + + // mark the interested arguments we only interested in + // args with TileType and VectorType + mark(postOrderArgs); + + llvm::SmallVector TopDownVisits; + for (auto op : postOrderOps) { + if (isConvertibleOp(op)) + TopDownVisits.append(op->result_begin(), op->result_end()); + } + + for (auto arg : postOrderArgs) { + if (isInterestingTarget(arg)) + TopDownVisits.push_back(arg); + } + + llvm::SmallVector preOrderOps; + llvm::SmallVector preOrderArgs; + + for (auto v : TopDownVisits) { + // If the value just has one user, it should have been visited in + // postOrderVisit. This is how it is added to the vector. + if (llvm::hasSingleElement(v.getUsers())) + continue; + // get Operators and arguments directly and indirectly using value v + // but skip those already visited in postOrderVisit + traverseDefUseChain(v, preOrderOps, preOrderArgs, + postOrderOps); + } + + mark(preOrderOps); + mark(preOrderArgs); +} + +/* + * markUseChainValues iterates over the values in use-chain of the given value, + * and mark/record them with the OperandType type in ValueAttributeMap. + */ +void markUseChainValues(mlir::Value value, imex::OperandType type, + imex::ValueAttributeMap &map) { + + auto mark = [&](llvm::SmallVector &array) { + for (auto v : array) { + if (isInterestingTarget(v)) + map.add(v, type); + } + }; + + llvm::SmallVector preOrderOps; + llvm::SmallVector preOrderArgs; + traverseDefUseChain(value, preOrderOps, preOrderArgs); + + mark(preOrderOps); + mark(preOrderArgs); + + llvm::SmallVector BottomUpVisits; + + // We don't do postOrderVisit on arguments since they only have one defining + // op which should be visited in preOrderVisit. + for (auto op : preOrderOps) { + if (isConvertibleOp(op)) + BottomUpVisits.append(op->operand_begin(), op->operand_end()); + } + + llvm::SmallVector postOrderOps; + llvm::SmallVector postOrderArgs; + + for (auto v : BottomUpVisits) { + traverseDefUseChain(v, postOrderOps, postOrderArgs, + preOrderOps); + } + + mark(postOrderOps); + mark(postOrderArgs); +} + +mlir::ValueRange buildUnrealizedCast(mlir::OpBuilder &builder, + mlir::TypeRange resultTypes, + mlir::ValueRange inputs) { + mlir::Location loc = builder.getUnknownLoc(); + if (!inputs.empty()) + loc = inputs.front().getLoc(); + auto castOp = builder.create( + loc, resultTypes, inputs); + return castOp->getResults(); +} + +} // namespace imex diff --git a/test/Conversion/XeTileToXeGPU/sg_level_gemm_1k_1k_1k_f16_f32.mlir b/test/Conversion/XeTileToXeGPU/sg_level_gemm_1k_1k_1k_f16_f32.mlir new file mode 100644 index 000000000..47400e5cd --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_gemm_1k_1k_1k_f16_f32.mlir @@ -0,0 +1,718 @@ +// RUN: imex-opt --xetile-tiling --convert-xetile-to-xegpu --remove-dead-values %s | FileCheck %s + +// CHECK-LABEL: func @test_gemm({{.*}}) { +func.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // %c8 = arith.constant 8 : index + // %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c64 : index + %n = arith.muli %block_id_y, %c64 : index + // intialize C tile and load it + //CHECK: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi %2, %c16_14 : index + //CHECK-NEXT: arith.addi %3, %c16_15 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<64x64xf32> + //CHECK: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<64x64xf32> -> vector<64x64xf32> + // initalize A and B tiles + // CHECK: arith.constant 0 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<64x64xf16> + // CHECK: arith.constant 0 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.addi {{.*}} : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<64x64xf16> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + // CHECK: scf.for + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + %out:3 = scf.for %k = %c0 to %c1024 step %c64 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<64x64xf16>, !xetile.tile<64x64xf16>, vector<64x64xf32>) { + + // load A and B tiles + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %a_value = xetile.load_tile %a_tile : !xetile.tile<64x64xf16> -> vector<64x64xf16> + // CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + // CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<64x64xf16> -> vector<64x64xf16> + // perform dpas and accumulate + // CHECK: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + // CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value : vector<64x64xf16>, vector<64x64xf16>, vector<64x64xf32> -> vector<64x64xf32> + // update the offsets for A and B tiles + // CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c64] + : !xetile.tile<64x64xf16>, index, index -> !xetile.tile<64x64xf16> + + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c64, %c0] + : !xetile.tile<64x64xf16>, index, index -> !xetile.tile<64x64xf16> + // partial C tile result + // CHECK: scf.yield + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + // CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<64x64xf16>, !xetile.tile<64x64xf16>, vector<64x64xf32> + } + // store the final accumulated C tile result back to memory + //CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xetile.store_tile %out#2, %c_init_tile: vector<64x64xf32>, !xetile.tile<64x64xf32> + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_load_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_level_load_tile.mlir new file mode 100644 index 000000000..c333b564e --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_load_tile.mlir @@ -0,0 +1,18 @@ +// RUN: imex-opt --split-input-file --xetile-tiling --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +func.func @sglevel_tiled_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<16x16xf16> + + //CHECK: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %2 = xetile.load_tile %1 : !xetile.tile<16x16xf16> -> vector<16x16xf16> + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_scf_for.mlir b/test/Conversion/XeTileToXeGPU/sg_level_scf_for.mlir new file mode 100644 index 000000000..4a09b538f --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_scf_for.mlir @@ -0,0 +1,33 @@ +// RUN: imex-opt --split-input-file --xetile-tiling --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +// CHECK: sglevel +func.func @sglevel_tiled_gemm(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + //CHECK: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<16x16xf16> + %2 = arith.constant dense<0.0> : vector<16x16xf16> + //CHECK: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, vector<8x16xf16>, vector<8x16xf16> + %nexta, %res = scf.for %k= %c0 to %c1024 step %c64 iter_args(%subA = %1, %subB = %2) -> (!xetile.tile<16x16xf16>, vector<16x16xf16>) { + //CHECK: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xetile.load_tile %subA : !xetile.tile<16x16xf16> -> vector<16x16xf16> + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %5 = xetile.update_tile_offset %subA, [%c0, %c64]: !xetile.tile<16x16xf16>, index, index -> !xetile.tile<16x16xf16> + //CHECK: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, vector<8x16xf16>, vector<8x16xf16> + scf.yield %5, %3: !xetile.tile<16x16xf16>, vector<16x16xf16> + } + + //CHECK: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %5 = xetile.init_tile %b[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<16x16xf16> + //CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + //CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xetile.store_tile %res, %5: vector<16x16xf16>, !xetile.tile<16x16xf16> + + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_store.mlir b/test/Conversion/XeTileToXeGPU/sg_level_store.mlir new file mode 100644 index 000000000..833140e3e --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_store.mlir @@ -0,0 +1,49 @@ +// RUN: imex-opt --split-input-file --xetile-tiling --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +func.func @sglevel_tiled_store(%a: memref<1024x1024xf32>) { + // CHECK: arith.constant dense<0.000000e+00> : vector<8x16xf32> + // CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + // CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + // CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + // CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + // CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + // CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + // CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + %result = arith.constant dense<0.0>: vector<32x32xf32> + + // CHECK: arith.constant 0 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xetile.init_tile %a[0, 32] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32> + + // CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xetile.store_tile %result, %1: vector<32x32xf32>, !xetile.tile<32x32xf32> + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_tile_mma.mlir b/test/Conversion/XeTileToXeGPU/sg_level_tile_mma.mlir new file mode 100644 index 000000000..e654e7b4f --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_tile_mma.mlir @@ -0,0 +1,112 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +func.func @sglevel_tiled_gemm(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<4x2x8x16xf16> + + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %2 = xetile.load_tile %1 : !xetile.tile<4x2x8x16xf16> -> vector<4x2x8x16xf16> + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 80 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %3 = xetile.init_tile %b[%c64, %c0] : memref<1024x1024xf16> -> !xetile.tile<2x4x16x16xf16> + + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %4 = xetile.load_tile %3 : !xetile.tile<2x4x16x16xf16> -> vector<2x4x16x16xf16> + + //CHECK: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %6 = xetile.tile_mma %2, %4: vector<4x2x8x16xf16>, vector<2x4x16x16xf16> -> vector<4x4x8x16xf32> + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_tiled_gemm.mlir b/test/Conversion/XeTileToXeGPU/sg_level_tiled_gemm.mlir new file mode 100644 index 000000000..56b2334d3 --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_tiled_gemm.mlir @@ -0,0 +1,798 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s + +func.func @sglevel_tiled_gemm(%a: memref<1024x1024xf16>, + %b: memref<1024x1024xf16>, + %c: memref<1024x1024xf32>) { + %sg_x = gpu.thread_id x + %sg_y = gpu.thread_id y + + //CHECK: arith.constant 0 : index + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c1024 = arith.constant 1024 : index + + scf.for %i = %c0 to %c1024 step %c128 { + scf.for %j = %c0 to %c1024 step %c128 { + %tile_0_dim_0 = arith.constant 128 : index + %tile_0_dim_1 = arith.constant 64 : index + %dl_0_dim_0 = arith.constant 2 : index + %dl_0_dim_1 = arith.constant 1 : index + %tile_0_block_size_dim_0 = arith.divsi %tile_0_dim_0, %dl_0_dim_0 : index + %tile_0_block_size_dim_1 = arith.divsi %tile_0_dim_1, %dl_0_dim_1 : index + %x0 = arith.remsi %sg_x, %dl_0_dim_0 : index + %y0 = arith.remsi %sg_y, %dl_0_dim_1 : index + %tmp_offset_0_dim_0 = arith.muli %x0, %tile_0_block_size_dim_0 : index + %tmp_offset_0_dim_1 = arith.muli %y0, %tile_0_block_size_dim_1 : index + %offset_0_dim_0 = arith.addi %i, %tmp_offset_0_dim_0 : index + %offset_0_dim_1 = arith.addi %c0, %tmp_offset_0_dim_1: index + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%offset_0_dim_0, %offset_0_dim_1] : memref<1024x1024xf16> -> !xetile.tile<8x4x8x16xf16> + + %tile_1_dim_0 = arith.constant 64 : index + %tile_1_dim_1 = arith.constant 128 : index + %dl_1_dim_0 = arith.constant 1 : index + %dl_1_dim_1 = arith.constant 2 : index + %tile_1_block_size_dim_0 = arith.divsi %tile_1_dim_0, %dl_1_dim_0 : index + %tile_1_block_size_dim_1 = arith.divsi %tile_1_dim_1, %dl_1_dim_1 : index + %x1 = arith.remsi %sg_x, %dl_1_dim_0 : index + %y1 = arith.remsi %sg_y, %dl_1_dim_1 : index + %tmp_offset_1_dim_0 = arith.muli %x1, %tile_1_block_size_dim_0 : index + %tmp_offset_1_dim_1 = arith.muli %y1, %tile_1_block_size_dim_1 : index + %offset_1_dim_0 = arith.addi %c0, %tmp_offset_1_dim_0 : index + %offset_1_dim_1 = arith.addi %j, %tmp_offset_1_dim_1: index + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: arith.addi {{.*}}: index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xetile.init_tile %b[%offset_1_dim_0, %offset_1_dim_1] : memref<1024x1024xf16> -> !xetile.tile<4x4x16x16xf16> + + //CHECK: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + %3 = arith.constant dense<0.0> : vector<8x4x8x16xf32> + %tmp0, %tmp1, %result = scf.for %k= %c0 to %c1024 step %c64 iter_args(%subA = %1, %subB = %2, %subC = %3) -> (!xetile.tile<8x4x8x16xf16>, !xetile.tile<4x4x16x16xf16>, vector<8x4x8x16xf32>) { + + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %4 = xetile.load_tile %subA : !xetile.tile<8x4x8x16xf16> -> vector<8x4x8x16xf16> + + + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %5 = xetile.load_tile %subB : !xetile.tile<4x4x16x16xf16> -> vector<4x4x16x16xf16> + + //CHECK: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %6 = xetile.tile_mma %4, %5, %subC: vector<8x4x8x16xf16>, vector<4x4x16x16xf16>, vector<8x4x8x16xf32> -> vector<8x4x8x16xf32> + + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %7 = xetile.update_tile_offset %subA, [%c0, %c64] : !xetile.tile<8x4x8x16xf16>, index, index -> !xetile.tile<8x4x8x16xf16> // simply update the type since relative offsets are used + + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %8 = xetile.update_tile_offset %subB, [%c64, %c0] : !xetile.tile<4x4x16x16xf16>, index, index -> !xetile.tile<4x4x16x16xf16> // simply update the type since relative offsets are used + + //CHECK: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + scf.yield %7, %8, %6: !xetile.tile<8x4x8x16xf16>, !xetile.tile<4x4x16x16xf16>, vector<8x4x8x16xf32> // simply update the type + } + + %tile_2_dim_0 = arith.constant 128 : index + %tile_2_dim_1 = arith.constant 128 : index + %dl_2_dim_0 = arith.constant 2 : index + %dl_2_dim_1 = arith.constant 2 : index + %tile_2_block_size_dim_0 = arith.divsi %tile_2_dim_0, %dl_2_dim_0 : index + %tile_2_block_size_dim_1 = arith.divsi %tile_2_dim_1, %dl_2_dim_1 : index + %x2 = arith.remsi %sg_x, %dl_2_dim_0 : index + %y2 = arith.remsi %sg_y, %dl_2_dim_1 : index + + %tmp_offset_2_dim_0 = arith.muli %x2, %tile_2_block_size_dim_0 : index + %tmp_offset_2_dim_1 = arith.muli %y2, %tile_2_block_size_dim_1 : index + %offset_3_dim_0 = arith.addi %i, %tmp_offset_2_dim_0 : index + %offset_3_dim_1 = arith.addi %j, %tmp_offset_2_dim_1: index + + + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc %arg2[%261, %262] {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + + %9 = xetile.init_tile %c[%offset_3_dim_0, %offset_3_dim_1] : memref<1024x1024xf32> -> !xetile.tile<8x4x8x16xf32> + + //CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xetile.store_tile %result, %9: vector<8x4x8x16xf32>, !xetile.tile<8x4x8x16xf32> + scf.yield + } + scf.yield + } + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_tiled_load_tile.mlir b/test/Conversion/XeTileToXeGPU/sg_level_tiled_load_tile.mlir new file mode 100644 index 000000000..8a7a9fbd0 --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_tiled_load_tile.mlir @@ -0,0 +1,16 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +func.func @sglevel_tiled_load_tile(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<2x1x8x16xf16> + //CHECK: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %2 = xetile.load_tile %1 : !xetile.tile<2x1x8x16xf16> -> vector<2x1x8x16xf16> + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_tiled_scf_for.mlir b/test/Conversion/XeTileToXeGPU/sg_level_tiled_scf_for.mlir new file mode 100644 index 000000000..5a4cb1d86 --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_tiled_scf_for.mlir @@ -0,0 +1,43 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +// CHECK: sglevel +func.func @sglevel_tiled_gemm(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>) { + //CHECK: arith.constant 0 : index + %c0 = arith.constant 0 : index + //CHECK: arith.constant 64 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + //CHECK: arith.constant 0 : index + //CHECK: arith.constant 64 : index + //CHECK: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK: arith.constant 8 : index + //CHECK: arith.constant 64 : index + //CHECK: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<2x1x8x16xf16> + //CHECK: arith.constant dense<0.000000e+00> : vector<8x16xf16> + //CHECK: arith.constant dense<0.000000e+00> : vector<8x16xf16> + %2 = arith.constant dense<0.0> : vector<2x1x8x16xf16> + + //CHECK: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, vector<8x16xf16>, vector<8x16xf16> + %nexta, %res = scf.for %k= %c0 to %c1024 step %c64 iter_args(%subA = %1, %subB = %2) -> (!xetile.tile<2x1x8x16xf16>, vector<2x1x8x16xf16>) { + //CHECK: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK: xegpu.load_nd {{.*}} {mode = vc, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xetile.load_tile %subA : !xetile.tile<2x1x8x16xf16> -> vector<2x1x8x16xf16> + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %5 = xetile.update_tile_offset %subA, [%c0, %c64]: !xetile.tile<2x1x8x16xf16>, index, index -> !xetile.tile<2x1x8x16xf16> + //CHECK: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, vector<8x16xf16>, vector<8x16xf16> + scf.yield %5, %3: !xetile.tile<2x1x8x16xf16>, vector<2x1x8x16xf16> + } + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 64 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %5 = xetile.init_tile %b[%c0, %c64] : memref<1024x1024xf16> -> !xetile.tile<2x1x8x16xf16> + //CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + xetile.store_tile %res, %5: vector<2x1x8x16xf16>, !xetile.tile<2x1x8x16xf16> + + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_tiled_simple.mlir b/test/Conversion/XeTileToXeGPU/sg_level_tiled_simple.mlir new file mode 100644 index 000000000..cb8d84377 --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_tiled_simple.mlir @@ -0,0 +1,106 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +func.func @sglevel_tiled_gemm(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>, %c: memref<1024x1024xf32>) { + //CHECK: arith.constant 0 : index + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + %tile_0_dim_0 = arith.constant 128 : index + %tile_0_dim_1 = arith.constant 64 : index + %dl_0_dim_0 = arith.constant 2 : index + %dl_0_dim_1 = arith.constant 1 : index + %tile_0_block_size_dim_0 = arith.divsi %tile_0_dim_0, %dl_0_dim_0 : index + %tile_0_block_size_dim_1 = arith.divsi %tile_0_dim_1, %dl_0_dim_1 : index + %x0 = arith.remsi %c0, %dl_0_dim_0 : index + %y0 = arith.remsi %c0, %dl_0_dim_1 : index + %tmp_offset_0_dim_0 = arith.muli %x0, %tile_0_block_size_dim_0 : index + %tmp_offset_0_dim_1 = arith.muli %y0, %tile_0_block_size_dim_1 : index + %offset_0_dim_0 = arith.addi %c64, %tmp_offset_0_dim_0 : index + %offset_0_dim_1 = arith.addi %c0, %tmp_offset_0_dim_1: index + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi %6, %c0_1 : index + //CHECK-NEXT: arith.addi %7, %c0_2 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi %6, %c8 : index + //CHECK-NEXT: arith.addi %7, %c0_3 : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%offset_0_dim_0, %offset_0_dim_1] : memref<1024x1024xf16> -> !xetile.tile<2x1x8x16xf16> + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xetile.init_tile %b[%offset_0_dim_0, %offset_0_dim_1] : memref<1024x1024xf16> -> !xetile.tile<1x2x16x16xf16> + + //CHECK: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + %3 = arith.constant dense<0.0> : vector<2x2x8x16xf32> + + //CHECK: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + %nexta, %nextb, %res = scf.for %k= %c0 to %c1024 step %c64 + iter_args(%subA = %1, %subB = %2, %subC = %3) -> (!xetile.tile<2x1x8x16xf16>, !xetile.tile<1x2x16x16xf16>, vector<2x2x8x16xf32>) { + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %4 = xetile.load_tile %subA : !xetile.tile<2x1x8x16xf16> -> vector<2x1x8x16xf16> + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %5 = xetile.load_tile %subB : !xetile.tile<1x2x16x16xf16> -> vector<1x2x16x16xf16> + //CHECK: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %6 = xetile.tile_mma %4, %5, %subC: vector<2x1x8x16xf16>, vector<1x2x16x16xf16>, vector<2x2x8x16xf32> -> vector<2x2x8x16xf32> + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %7 = xetile.update_tile_offset %subA, [%c0, %c64] : !xetile.tile<2x1x8x16xf16>, index, index -> !xetile.tile<2x1x8x16xf16> + //CHECK: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: xegpu.update_nd_offset {{.*}} {mode = vc} : !xegpu.tensor_desc<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %8 = xetile.update_tile_offset %subB, [%c64, %c0] : !xetile.tile<1x2x16x16xf16>, index, index -> !xetile.tile<1x2x16x16xf16> + //CHECK: !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<8x16xf16>, + //CHECK-SAME: !xegpu.tensor_desc<16x16xf16>, !xegpu.tensor_desc<16x16xf16>, + //CHECK-SAME: vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + scf.yield %7, %8, %6: !xetile.tile<2x1x8x16xf16>, !xetile.tile<1x2x16x16xf16>, vector<2x2x8x16xf32> + } + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %5 = xetile.init_tile %c[%offset_0_dim_0, %offset_0_dim_1] : memref<1024x1024xf32> -> !xetile.tile<2x2x8x16xf32> + //CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xetile.store_tile %res, %5: vector<2x2x8x16xf32>, !xetile.tile<2x2x8x16xf32> + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_tiled_store.mlir b/test/Conversion/XeTileToXeGPU/sg_level_tiled_store.mlir new file mode 100644 index 000000000..7a47e7216 --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_tiled_store.mlir @@ -0,0 +1,170 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +func.func @sglevel_tiled_store(%a: memref<1024x1024xf32>) { + // CHECK: arith.constant 0 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 0 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 8 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 16 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 24 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 40 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 32 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 48 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 64 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-NEXT: arith.constant 56 : index + // CHECK-NEXT: arith.constant 80 : index + // CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32> + %1 = xetile.init_tile %a[0, 32] : memref<1024x1024xf32> -> !xetile.tile<8x4x8x16xf32> + + //CHECK: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + %result = arith.constant dense<0.0>: vector<8x4x8x16xf32> + + + //CHECK: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + //CHECK-NEXT: xegpu.store_nd {{.*}} {mode = vc, {{.*}}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xetile.store_tile %result, %1: vector<8x4x8x16xf32>, !xetile.tile<8x4x8x16xf32> + return +} diff --git a/test/Conversion/XeTileToXeGPU/sg_level_tiled_tile_mma.mlir b/test/Conversion/XeTileToXeGPU/sg_level_tiled_tile_mma.mlir new file mode 100644 index 000000000..9a0e4cc76 --- /dev/null +++ b/test/Conversion/XeTileToXeGPU/sg_level_tiled_tile_mma.mlir @@ -0,0 +1,496 @@ +// RUN: imex-opt --split-input-file --convert-xetile-to-xegpu --remove-dead-values %s -verify-diagnostics -o -| FileCheck %s +func.func @sglevel_tiled_gemm(%a: memref<1024x1024xf16>, %b: memref<1024x1024xf16>) { + //CHECK: arith.constant 0 : index + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + + %tile_0_dim_0 = arith.constant 128 : index + %tile_0_dim_1 = arith.constant 64 : index + %dl_0_dim_0 = arith.constant 2 : index + %dl_0_dim_1 = arith.constant 1 : index + %tile_0_block_size_dim_0 = arith.divsi %tile_0_dim_0, %dl_0_dim_0 : index + %tile_0_block_size_dim_1 = arith.divsi %tile_0_dim_1, %dl_0_dim_1 : index + %x0 = arith.remsi %c0, %dl_0_dim_0 : index + %y0 = arith.remsi %c0, %dl_0_dim_1 : index + %tmp_offset_0_dim_0 = arith.muli %x0, %tile_0_block_size_dim_0 : index + %tmp_offset_0_dim_1 = arith.muli %y0, %tile_0_block_size_dim_1 : index + %offset_0_dim_0 = arith.addi %c64, %tmp_offset_0_dim_0 : index + %offset_0_dim_1 = arith.addi %c0, %tmp_offset_0_dim_1: index + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 8 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 24 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 40 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK-NEXT: arith.constant 56 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xetile.init_tile %a[%offset_0_dim_0, %offset_0_dim_1] : memref<1024x1024xf16> -> !xetile.tile<8x4x8x16xf16> + + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 1, {{.*}}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16> + %2 = xetile.load_tile %1 : !xetile.tile<8x4x8x16xf16> -> vector<8x4x8x16xf16> + + %tile_1_dim_0 = arith.constant 64 : index + %tile_1_dim_1 = arith.constant 128 : index + %dl_1_dim_0 = arith.constant 1 : index + %dl_1_dim_1 = arith.constant 2 : index + %tile_1_block_size_dim_0 = arith.divsi %tile_1_dim_0, %dl_1_dim_0 : index + %tile_1_block_size_dim_1 = arith.divsi %tile_1_dim_1, %dl_1_dim_1 : index + %x1 = arith.remsi %c0, %dl_1_dim_0 : index + %y1 = arith.remsi %c0, %dl_1_dim_1 : index + %tmp_offset_1_dim_0 = arith.muli %x1, %tile_1_block_size_dim_0 : index + %tmp_offset_1_dim_1 = arith.muli %y1, %tile_1_block_size_dim_1 : index + %offset_1_dim_0 = arith.addi %c0, %tmp_offset_1_dim_0 : index + %offset_1_dim_1 = arith.addi %c0, %tmp_offset_1_dim_1: index + + //CHECK: arith.constant 0 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 0 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 16 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 32 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.constant 48 : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: arith.addi {{.*}} : index + //CHECK-NEXT: xegpu.create_nd_tdesc {{.*}} {mode = vc, boundary_check = true} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16> + %3 = xetile.init_tile %b[%offset_1_dim_0, %offset_1_dim_1] : memref<1024x1024xf16> -> !xetile.tile<4x4x16x16xf16> + + //CHECK: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + //CHECK-NEXT: xegpu.load_nd {{.*}} {mode = vc, vnni_axis = 0, {{.*}}} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16> + %4 = xetile.load_tile %3 : !xetile.tile<4x4x16x16xf16> -> vector<4x4x16x16xf16> + + + //CHECK: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + //CHECK-NEXT: arith.constant dense<0.000000e+00> : vector<8x16xf32> + %subC = arith.constant dense<0.0> : vector<8x4x8x16xf32> + + + //CHECK: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + //CHECK-NEXT: xegpu.dpas {{.*}} {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32> + %6 = xetile.tile_mma %2, %4, %subC: vector<8x4x8x16xf16>, vector<4x4x16x16xf16>, vector<8x4x8x16xf32> -> vector<8x4x8x16xf32> + + + return +} diff --git a/test/Dialect/XeTile/Transforms/sg_gemm_1k_1k_1k_f16_f32.mlir b/test/Dialect/XeTile/Transforms/sg_gemm_1k_1k_1k_f16_f32.mlir new file mode 100644 index 000000000..231c7471b --- /dev/null +++ b/test/Dialect/XeTile/Transforms/sg_gemm_1k_1k_1k_f16_f32.mlir @@ -0,0 +1,65 @@ +// RUN: imex-opt --xetile-tiling %s | FileCheck %s + +// CHECK-LABEL: func @test_gemm({{.*}}) { +func.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // %c8 = arith.constant 8 : index + // %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c64 : index + %n = arith.muli %block_id_y, %c64 : index + // intialize C tile and load it + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<8x4x8x16xf32> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<64x64xf32> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x4x8x16xf32> -> vector<8x4x8x16xf32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<64x64xf32> -> vector<64x64xf32> + // initalize A and B tiles + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<8x4x8x16xf16> + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<64x64xf16> + // CHECK : xetile.init_tile + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<4x4x16x16xf16> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<64x64xf16> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + // CHECK: (!xetile.tile<8x4x8x16xf16>, !xetile.tile<4x4x16x16xf16>, vector<8x4x8x16xf32>) + %out:3 = scf.for %k = %c0 to %c1024 step %c64 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<64x64xf16>, !xetile.tile<64x64xf16>, vector<64x64xf32>) { + + // load A and B tiles + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x4x8x16xf16> -> vector<8x4x8x16xf16> + %a_value = xetile.load_tile %a_tile : !xetile.tile<64x64xf16> -> vector<64x64xf16> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<4x4x16x16xf16> -> vector<4x4x16x16xf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<64x64xf16> -> vector<64x64xf16> + // perform dpas and accumulate + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<8x4x8x16xf16>, vector<4x4x16x16xf16>, vector<8x4x8x16xf32> -> vector<8x4x8x16xf32> + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value : vector<64x64xf16>, vector<64x64xf16>, vector<64x64xf32> -> vector<64x64xf32> + // update the offsets for A and B tiles + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<8x4x8x16xf16>, index, index -> !xetile.tile<8x4x8x16xf16> + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c64] + : !xetile.tile<64x64xf16>, index, index -> !xetile.tile<64x64xf16> + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<4x4x16x16xf16>, index, index -> !xetile.tile<4x4x16x16xf16> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c64, %c0] + : !xetile.tile<64x64xf16>, index, index -> !xetile.tile<64x64xf16> + // partial C tile result + // CHECK: !xetile.tile<8x4x8x16xf16>, !xetile.tile<4x4x16x16xf16>, vector<8x4x8x16xf32> + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<64x64xf16>, !xetile.tile<64x64xf16>, vector<64x64xf32> + } + // store the final accumulated C tile result back to memory + // CHECK: vector<8x4x8x16xf32>, !xetile.tile<8x4x8x16xf32> + xetile.store_tile %out#2, %c_init_tile: vector<64x64xf32>, !xetile.tile<64x64xf32> + return +} diff --git a/test/Dialect/XeTile/Transforms/sg_gemm_1k_1k_1k_i8_i32.mlir b/test/Dialect/XeTile/Transforms/sg_gemm_1k_1k_1k_i8_i32.mlir new file mode 100644 index 000000000..0b035e3f4 --- /dev/null +++ b/test/Dialect/XeTile/Transforms/sg_gemm_1k_1k_1k_i8_i32.mlir @@ -0,0 +1,68 @@ +// RUN: imex-opt --xetile-tiling %s | FileCheck %s + + +// CHECK-LABEL: func @test_gemm({{.*}}) { +func.func @test_gemm(%A: memref<1024x1024xi8>, %B: memref<1024x1024xi8>, %C: memref<1024x1024xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // %c8 = arith.constant 8 : index + // %c16 = arith.constant 16 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c64 : index + %n = arith.muli %block_id_y, %c64 : index + // intialize C tile and load it + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xi32> -> !xetile.tile<8x4x8x16xi32> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xi32> -> !xetile.tile<64x64xi32> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x4x8x16xi32> -> vector<8x4x8x16xi32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<64x64xi32> -> vector<64x64xi32> + // initalize A and B tiles + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xi8> -> !xetile.tile<8x4x8x16xi8> + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xi8> -> !xetile.tile<64x64xi8> + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xi8> -> !xetile.tile<4x4x16x16xi8> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xi8> -> !xetile.tile<64x64xi8> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + // CHECK: scf.for + // CHECK-SAME: !xetile.tile<8x4x8x16xi8>, !xetile.tile<4x4x16x16xi8>, vector<8x4x8x16xi32> + %out:3 = scf.for %k = %c0 to %c1024 step %c64 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<64x64xi8>, !xetile.tile<64x64xi8>, vector<64x64xi32>) { + + // load A and B tiles + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x4x8x16xi8> -> vector<8x4x8x16xi8> + %a_value = xetile.load_tile %a_tile : !xetile.tile<64x64xi8> -> vector<64x64xi8> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<4x4x16x16xi8> -> vector<4x4x16x16xi8> + %b_value = xetile.load_tile %b_tile : !xetile.tile<64x64xi8> -> vector<64x64xi8> + // perform dpas and accumulate + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<8x4x8x16xi8>, vector<4x4x16x16xi8>, vector<8x4x8x16xi32> -> vector<8x4x8x16xi32> + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value : vector<64x64xi8>, vector<64x64xi8>, vector<64x64xi32> -> vector<64x64xi32> + // update the offsets for A and B tiles + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<8x4x8x16xi8>, index, index -> !xetile.tile<8x4x8x16xi8> + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c64] + : !xetile.tile<64x64xi8>, index, index -> !xetile.tile<64x64xi8> + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<4x4x16x16xi8>, index, index -> !xetile.tile<4x4x16x16xi8> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c64, %c0] + : !xetile.tile<64x64xi8>, index, index -> !xetile.tile<64x64xi8> + // partial C tile result + // CHECK: scf.yield + // CHECK-SAME: !xetile.tile<8x4x8x16xi8>, !xetile.tile<4x4x16x16xi8>, vector<8x4x8x16xi32> + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<64x64xi8>, !xetile.tile<64x64xi8>, vector<64x64xi32> + } + // store the final accumulated C tile result back to memory + // CHECK: xetile.store_tile + // CHECK-SAME: vector<8x4x8x16xi32>, !xetile.tile<8x4x8x16xi32> + xetile.store_tile %out#2, %c_init_tile {innner_blocks = [8, 16]}: vector<64x64xi32>, !xetile.tile<64x64xi32> + return +} diff --git a/test/Dialect/XeTile/Transforms/sg_gemm_4k_4k_4k_f16_f32.mlir b/test/Dialect/XeTile/Transforms/sg_gemm_4k_4k_4k_f16_f32.mlir new file mode 100644 index 000000000..4f8aaf1ac --- /dev/null +++ b/test/Dialect/XeTile/Transforms/sg_gemm_4k_4k_4k_f16_f32.mlir @@ -0,0 +1,65 @@ +// RUN: imex-opt --xetile-tiling %s | FileCheck %s + +// CHECK-LABEL: func @test_gemm({{.*}}) { +func.func @test_gemm(%A: memref<4096x4096xf16>, %B: memref<4096x4096xf16>, %C: memref<4096x4096xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %c128 = arith.constant 128 : index + %c4096 = arith.constant 4096 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c64 : index + %n = arith.muli %block_id_y, %c128 : index + // intialize C tile and load it + // CHECK: xetile.init_tile + // CHECK-SAME: memref<4096x4096xf32> -> !xetile.tile<8x8x8x16xf32> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xf32> -> !xetile.tile<64x128xf32> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x8x8x16xf32> -> vector<8x8x8x16xf32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<64x128xf32> -> vector<64x128xf32> + // initalize A and B tiles + // CHECK: xetile.init_tile + // CHECK-SAME: memref<4096x4096xf16> -> !xetile.tile<8x8x8x16xf16> + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xf16> -> !xetile.tile<64x128xf16> + // CHECK : xetile.init_tile %arg1[%c0, %3] : memref<4096x4096xf16> -> !xetile.tile<8x8x16x16xf16> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xf16> -> !xetile.tile<128x128xf16> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + // CHECK: !xetile.tile<8x8x8x16xf16>, !xetile.tile<8x8x16x16xf16>, vector<8x8x8x16xf32> + %out:3 = scf.for %k = %c0 to %c4096 step %c128 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<64x128xf16>, !xetile.tile<128x128xf16>, vector<64x128xf32>) { + + // load A and B tiles + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x8x8x16xf16> -> vector<8x8x8x16xf16> + %a_value = xetile.load_tile %a_tile : !xetile.tile<64x128xf16> -> vector<64x128xf16> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x8x16x16xf16> -> vector<8x8x16x16xf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<128x128xf16> -> vector<128x128xf16> + // perform dpas and accumulate + // CHECK: xetile.tile_mma + // CHECK-SAME : vector<8x8x8x16xf16>, vector<8x8x16x16xf16>, vector<8x8x8x16xf32> -> vector<8x8x8x16xf32> + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<64x128xf16>, vector<128x128xf16>, vector<64x128xf32> -> vector<64x128xf32> + // update the offsets for A and B tiles + // CHECK: xetile.update_tile_offset + // CHECK-SAME : !xetile.tile<8x8x8x16xf16>, index, index -> !xetile.tile<8x8x8x16xf16> + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c128] + : !xetile.tile<64x128xf16>, index, index -> !xetile.tile<64x128xf16> + // CHECK: xetile.update_tile_offset + // CHECK-SAME : !xetile.tile<8x8x16x16xf16>, index, index -> !xetile.tile<8x8x16x16xf16> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c128, %c0] + : !xetile.tile<128x128xf16>, index, index -> !xetile.tile<128x128xf16> + // partial C tile result + // CHECK: !xetile.tile<8x8x8x16xf16>, !xetile.tile<8x8x16x16xf16>, vector<8x8x8x16xf32> + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<64x128xf16>, !xetile.tile<128x128xf16>, vector<64x128xf32> + } + // store the final accumulated C tile result back to memory + // CHECK: xetile.store_tile + // CHECK-SAME: vector<8x8x8x16xf32>, !xetile.tile<8x8x8x16xf32> + xetile.store_tile %out#2, %c_init_tile : vector<64x128xf32>, !xetile.tile<64x128xf32> + return +} diff --git a/test/Dialect/XeTile/Transforms/sg_gemm_4k_4k_4k_i8_i32.mlir b/test/Dialect/XeTile/Transforms/sg_gemm_4k_4k_4k_i8_i32.mlir new file mode 100644 index 000000000..ef3c5dc1a --- /dev/null +++ b/test/Dialect/XeTile/Transforms/sg_gemm_4k_4k_4k_i8_i32.mlir @@ -0,0 +1,67 @@ +// RUN: imex-opt --xetile-tiling %s | FileCheck %s + + +// CHECK-LABEL: func @test_gemm({{.*}}) { +func.func @test_gemm(%A: memref<4096x4096xi8>, %B: memref<4096x4096xi8>, %C: memref<4096x4096xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // %c8 = arith.constant 8 : index + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c128 : index + %n = arith.muli %block_id_y, %c256 : index + // intialize C tile and load it + // CHECK: xetile.init_tile + // CHECK-SAME: memref<4096x4096xi32> -> !xetile.tile<16x16x8x16xi32> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<4096x4096xi32> -> !xetile.tile<128x256xi32> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<16x16x8x16xi32> -> vector<16x16x8x16xi32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<128x256xi32> -> vector<128x256xi32> + // initalize A and B tiles + // CHECK: xetile.init_tile + // CHECK-SAME: memref<4096x4096xi8> -> !xetile.tile<16x16x8x16xi8> + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<4096x4096xi8> -> !xetile.tile<128x256xi8> + // CHECK: xetile.init_tile + // CHECK-SAME: memref<4096x4096xi8> -> !xetile.tile<16x16x16x16xi8> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<4096x4096xi8> -> !xetile.tile<256x256xi8> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + // CHECK: (!xetile.tile<16x16x8x16xi8>, !xetile.tile<16x16x16x16xi8>, vector<16x16x8x16xi32>) + %out:3 = scf.for %k = %c0 to %c4096 step %c256 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<128x256xi8>, !xetile.tile<256x256xi8>, vector<128x256xi32>) { + + // load A and B tiles + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<16x16x8x16xi8> -> vector<16x16x8x16xi8> + %a_value = xetile.load_tile %a_tile : !xetile.tile<128x256xi8> -> vector<128x256xi8> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<16x16x16x16xi8> -> vector<16x16x16x16xi8> + %b_value = xetile.load_tile %b_tile : !xetile.tile<256x256xi8> -> vector<256x256xi8> + // perform dpas and accumulate + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<16x16x8x16xi8>, vector<16x16x16x16xi8>, vector<16x16x8x16xi32> -> vector<16x16x8x16xi32> + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : vector<128x256xi8>, vector<256x256xi8>, vector<128x256xi32> -> vector<128x256xi32> + // update the offsets for A and B tiles + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<16x16x8x16xi8>, index, index -> !xetile.tile<16x16x8x16xi8> + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c256] + : !xetile.tile<128x256xi8>, index, index -> !xetile.tile<128x256xi8> + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<16x16x16x16xi8>, index, index -> !xetile.tile<16x16x16x16xi8> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c256, %c0] + : !xetile.tile<256x256xi8>, index, index -> !xetile.tile<256x256xi8> + // partial C tile result + // CHECK: !xetile.tile<16x16x8x16xi8>, !xetile.tile<16x16x16x16xi8>, vector<16x16x8x16xi32> + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<128x256xi8>, !xetile.tile<256x256xi8>, vector<128x256xi32> + } + // store the final accumulated C tile result back to memory + // CHECK: xetile.store_tile + // CHECK-SAME: vector<16x16x8x16xi32>, !xetile.tile<16x16x8x16xi32> + xetile.store_tile %out#2, %c_init_tile : vector<128x256xi32>, !xetile.tile<128x256xi32> + return +}