From 4e8cf6f634fea5e912c49dbb2a8063dacfbecc61 Mon Sep 17 00:00:00 2001 From: "Gusthinna Waduge, Charitha Saumya" Date: Wed, 1 Nov 2023 12:30:36 -0700 Subject: [PATCH] Add attributes required for work group level XeTile programming. Following changes are introduced. - WG and SG levels layout attributes for XeTile - Custom printers/parsers for XeTile ops (needed by XeTileToXeGPU lowering) - Additional verifications for tile_mma that were missing in initial version - Updated test cases --- include/imex/Dialect/XeTile/IR/CMakeLists.txt | 5 + include/imex/Dialect/XeTile/IR/XeTileAttrs.td | 96 ++++ include/imex/Dialect/XeTile/IR/XeTileBase.td | 151 ------ .../imex/Dialect/XeTile/IR/XeTileDialect.td | 60 +++ include/imex/Dialect/XeTile/IR/XeTileOps.h | 12 +- include/imex/Dialect/XeTile/IR/XeTileOps.td | 231 ++++---- include/imex/Dialect/XeTile/IR/XeTileTypes.td | 126 +++++ lib/Dialect/XeTile/IR/CMakeLists.txt | 2 + lib/Dialect/XeTile/IR/XeTileDialect.cpp | 188 +++++++ lib/Dialect/XeTile/IR/XeTileOps.cpp | 506 ++++++++++++++---- test/Dialect/XeTile/IR/XeTileOps.mlir | 225 -------- test/Dialect/XeTile/IR/invalid.mlir | 76 ++- test/Dialect/XeTile/IR/ops.mlir | 277 ++++++++++ test/Dialect/XeTile/IR/simple_gemm.mlir | 110 ++-- 14 files changed, 1428 insertions(+), 637 deletions(-) create mode 100644 include/imex/Dialect/XeTile/IR/XeTileAttrs.td delete mode 100644 include/imex/Dialect/XeTile/IR/XeTileBase.td create mode 100644 include/imex/Dialect/XeTile/IR/XeTileDialect.td create mode 100644 include/imex/Dialect/XeTile/IR/XeTileTypes.td create mode 100644 lib/Dialect/XeTile/IR/XeTileDialect.cpp delete mode 100644 test/Dialect/XeTile/IR/XeTileOps.mlir create mode 100644 test/Dialect/XeTile/IR/ops.mlir diff --git a/include/imex/Dialect/XeTile/IR/CMakeLists.txt b/include/imex/Dialect/XeTile/IR/CMakeLists.txt index c14516249..c202d0eb4 100644 --- a/include/imex/Dialect/XeTile/IR/CMakeLists.txt +++ b/include/imex/Dialect/XeTile/IR/CMakeLists.txt @@ -1,2 +1,7 @@ 
add_mlir_dialect(XeTileOps xetile) add_mlir_doc(XeTileOps XeTileDialect Dialects/ -gen-dialect-doc -dialect=xetile) + +set(LLVM_TARGET_DEFINITIONS XeTileOps.td) +mlir_tablegen(XeTileOpsAttrs.h.inc -gen-attrdef-decls) +mlir_tablegen(XeTileOpsAttrs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(XeTileOpsAttrsIncGen) diff --git a/include/imex/Dialect/XeTile/IR/XeTileAttrs.td b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td new file mode 100644 index 000000000..3bb53217f --- /dev/null +++ b/include/imex/Dialect/XeTile/IR/XeTileAttrs.td @@ -0,0 +1,96 @@ +//===------------ XeTileAttr.td - XeTile dialect -------*- tablegen -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines custom attributes used by XeTile dialect. 
+/// +//===----------------------------------------------------------------------===// + +#ifndef _XETILE_ATTR_DEF_TD_INCLUDED_ +#define _XETILE_ATTR_DEF_TD_INCLUDED_ + +include "mlir/IR/AttrTypeBase.td" +include "imex/Dialect/XeTile/IR/XeTileDialect.td" + +class XeTile_Attr<string name, string attrMnemonic, list<Trait> traits = []> + : AttrDef<XeTile_Dialect, name, traits> { + let mnemonic = attrMnemonic; +} + +def XeTile_SubGroupMapAttr : XeTile_Attr<"SubGroupMap", "sg_map"> { + let parameters = (ins + OptionalParameter<"mlir::DenseI32ArrayAttr">:$mma_block_size, + "mlir::DenseI32ArrayAttr":$wi_layout, + "mlir::DenseI32ArrayAttr":$wi_data + ); + let assemblyFormat = "`<` struct(params) `>`"; + let genVerifyDecl = true; + let builders = [ + AttrBuilder<(ins "llvm::ArrayRef<int32_t>":$mma_block_size, + "llvm::ArrayRef<int32_t>":$wi_layout, + "llvm::ArrayRef<int32_t>":$wi_data), + [{ + return $_get($_ctxt, mlir::DenseI32ArrayAttr::get($_ctxt, mma_block_size), + mlir::DenseI32ArrayAttr::get($_ctxt, wi_layout), + mlir::DenseI32ArrayAttr::get($_ctxt, wi_data)); + }]>, + AttrBuilder<(ins "llvm::ArrayRef<int32_t>":$wi_layout, + "llvm::ArrayRef<int32_t>":$wi_data), + [{ + return $_get($_ctxt, mlir::DenseI32ArrayAttr(), mlir::DenseI32ArrayAttr::get($_ctxt, wi_layout), + mlir::DenseI32ArrayAttr::get($_ctxt, wi_data)); + }]> + ]; +} + +def XeTile_WorkGroupMapAttr : XeTile_Attr<"WorkGroupMap", "wg_map"> { + let parameters = (ins + "mlir::DenseI32ArrayAttr":$sg_layout, + "mlir::DenseI32ArrayAttr":$sg_data + ); + let assemblyFormat = "`<` struct(params) `>`"; + let genVerifyDecl = true; + let builders = [ + AttrBuilder<(ins "llvm::ArrayRef<int32_t>":$sg_layout, + "llvm::ArrayRef<int32_t>":$sg_data), + [{ + return $_get($_ctxt, mlir::DenseI32ArrayAttr::get($_ctxt, sg_layout), + mlir::DenseI32ArrayAttr::get($_ctxt, sg_data)); + }]> + ]; +} + +def XeTile_XeMapAttr : XeTile_Attr<"XeMap", "xe_map"> { + let parameters = (ins + XeTile_WorkGroupMapAttr:$wg, + XeTile_SubGroupMapAttr:$sg + ); + let assemblyFormat = "`<` struct(params) `>`"; + let builders = [ + AttrBuilder<(ins "llvm::ArrayRef<int32_t>":$mma_block_size, 
"llvm::ArrayRef":$wi_layout, + "llvm::ArrayRef":$wi_data, + "llvm::ArrayRef":$sg_layout, + "llvm::ArrayRef":$sg_data), + [{ + return $_get($_ctxt, WorkGroupMapAttr::get($_ctxt, sg_layout, sg_data), + SubGroupMapAttr::get($_ctxt, mma_block_size, wi_layout, wi_data)) ; + }]>, + AttrBuilder<(ins "llvm::ArrayRef":$wi_layout, + "llvm::ArrayRef":$wi_data, + "llvm::ArrayRef":$sg_layout, + "llvm::ArrayRef":$sg_data), + [{ + return $_get($_ctxt, WorkGroupMapAttr::get($_ctxt, sg_layout, sg_data), + SubGroupMapAttr::get($_ctxt, wi_layout, wi_data)) ; + }]> + ]; +} + +#endif // _XETILE_ATTR_DEF_TD_INCLUDED_ diff --git a/include/imex/Dialect/XeTile/IR/XeTileBase.td b/include/imex/Dialect/XeTile/IR/XeTileBase.td deleted file mode 100644 index d2a85430f..000000000 --- a/include/imex/Dialect/XeTile/IR/XeTileBase.td +++ /dev/null @@ -1,151 +0,0 @@ -//===- XeTileOps.td - XeTile dialect -------*- tablegen -*-===// -// -// Copyright 2022 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines the XeTile dialect and its types. 
-/// -//===----------------------------------------------------------------------===// - -#ifndef _XeTile_BASE_TD_INCLUDED_ -#define _XeTile_BASE_TD_INCLUDED_ - -include "mlir/IR/OpBase.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/BuiltinTypes.td" -include "mlir/IR/BuiltinTypeInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/CopyOpInterface.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/ShapedOpInterfaces.td" - -// Provide a definition of the 'XeTile' dialect in the ODS framework so that we -// can define our operations. -def XeTile_Dialect : Dialect { - // The namespace of our dialect - let name = "xetile"; - - // A short one-line summary - let summary = "A dialect for enabling tile-base programming at subgroup level"; - - // A longer description - let description = [{ - XeTile provides an abstraction supporting tile-based computation to simplify the - lowering of DNN operation like matrix multiplication. XeTile dialect works at tile sizes - that are larger than the tile sizes supported by the hardware. XeTile dilaect also hides - the auto-padding requirements for out-of-bound memory accesses and, supports arbitrary - input matrix sizes. - }]; - - // The C++ namespace that the dialect class definition resides in. - let cppNamespace = "::imex::xetile"; - - let dependentDialects = [ - "::mlir::memref::MemRefDialect"]; - - // TODO: temporary disable it. - let useDefaultTypePrinterParser = true; -} - -// Base class for dialect operations. This operation inherits from the base -// `Op` class in OpBase.td, and provides: -// * The parent dialect of the operation. -// * The mnemonic for the operation, or the name without the dialect prefix. -// * A list of traits for the operation. 
-class XeTile_Op traits = []> : - Op; - -// common base class for types in XeTile dialect -class XeTile_Type traits = [], - string baseCppClass = "::mlir::Type"> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface], - "::imex::xetile::TileBase"> -{ - let summary = "A type representing a 2D tile"; - let description = [{ - Tile data type in XeTile dialect is used to represent a 2D memory region. - This captures the 2d shape and type of the memory region it points to. - - Syntax: - - ``` - tile-type ::= `vector` `<` vector-dim-list vector-element-type `>` - tile-element-type ::= float-type | integer-type | index-type - tile-dim-list := (static-dim-list `x`)? - static-dim-list ::= decimal-literal `x` decimal-literal - ``` - - Examples: - - ```mlir - // A tile with i32 elements - tile<3x42xi32> - - // A tile with f32 elements - tile<4x5xf32> - ``` - }]; - - let parameters = (ins ArrayRefParameter<"int64_t">:$shape, - "::mlir::Type":$elementType); - - let builders = [ - TypeBuilderWithInferredContext<(ins - "::llvm::ArrayRef":$shape, "::mlir::Type":$elementType), [{ - assert(shape.size()==2); - return $_get(elementType.getContext(), shape, elementType); - }]>, - TypeBuilderWithInferredContext<(ins - "int64_t":$dim0, "int64_t":$dim1, "::mlir::Type":$elementType), [{ - llvm::SmallVector shape{dim0, dim1}; - assert(shape.size()==2); - return $_get(elementType.getContext(), shape, elementType); - }]> - ]; - - let extraClassDeclaration = [{ - using ::mlir::ShapedType::Trait::clone; - using ::mlir::ShapedType::Trait::getElementTypeBitWidth; - using ::mlir::ShapedType::Trait::getRank; - using ::mlir::ShapedType::Trait::getNumElements; - using ::mlir::ShapedType::Trait::isDynamicDim; - using ::mlir::ShapedType::Trait::hasStaticShape; - using ::mlir::ShapedType::Trait::getNumDynamicDims; - using ::mlir::ShapedType::Trait::getDimSize; - using ::mlir::ShapedType::Trait::getDynamicDimIndex; - }]; - - let assemblyFormat = 
"`<` custom($shape, $elementType) `>`"; -} - -// Integer types allowd in XeTile -def XeTile_IntType : AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>; - -// Float types allowed in XeTile -def XeTile_FloatType : AnyTypeOf<[F16, F32, F64, BF16, F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2FNUZ]>; - -// Define the scalar type for XeTile -def XeTile_ScalarType : AnyTypeOf<[XeTile_IntType, XeTile_FloatType]>; - -// define a 2D memref of XeTile scalar type -def XeTile_2DMemRef : MemRefRankOf<[XeTile_ScalarType], [2]>; - -// define the source type for XeTile init_tile -def XeTile_BaseAddrType : AnyTypeOf<[XeTile_2DMemRef, I64]>; - -// define the attribute type allowed for padding values for load op -def XeTile_PaddingValueAttr : AnyAttrOf<[I32Attr, F32Attr]>; - -#endif // _XeTile_BASE_TD_INCLUDED_ diff --git a/include/imex/Dialect/XeTile/IR/XeTileDialect.td b/include/imex/Dialect/XeTile/IR/XeTileDialect.td new file mode 100644 index 000000000..a2511c1a1 --- /dev/null +++ b/include/imex/Dialect/XeTile/IR/XeTileDialect.td @@ -0,0 +1,60 @@ +//===--------------- XeTileOps.td - XeTile dialect -------*- tablegen -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the XeTile dialect. 
+/// +//===----------------------------------------------------------------------===// + +#ifndef _XETILE_BASE_TD_INCLUDED_ +#define _XETILE_BASE_TD_INCLUDED_ + +include "mlir/IR/OpBase.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/ShapedOpInterfaces.td" + +// Provide a definition of the 'XeTile' dialect in the ODS framework so that we +// can define our operations. +def XeTile_Dialect : Dialect { + // The namespace of our dialect + let name = "xetile"; + + // A short one-line summary + let summary = "A dialect for enabling tile-based programming at subgroup level"; + + // A longer description + let description = [{ + XeTile provides an abstraction supporting tile-based computation to simplify the + lowering of DNN operations like matrix multiplication. XeTile dialect works at tile sizes + that are larger than the tile sizes supported by the hardware. XeTile dialect also hides + the auto-padding requirements for out-of-bound memory accesses and supports arbitrary + input matrix sizes. + }]; + + // The C++ namespace that the dialect class definition resides in. + let cppNamespace = "::imex::xetile"; + + let dependentDialects = [ + "::mlir::memref::MemRefDialect"]; + + // TODO: temporarily disable it. 
+ let useDefaultTypePrinterParser = true; + let useDefaultAttributePrinterParser = true; +} + + +#endif // _XETILE_BASE_TD_INCLUDED_ diff --git a/include/imex/Dialect/XeTile/IR/XeTileOps.h b/include/imex/Dialect/XeTile/IR/XeTileOps.h index 6211a5750..02903d123 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileOps.h +++ b/include/imex/Dialect/XeTile/IR/XeTileOps.h @@ -1,4 +1,4 @@ -//===- XeTileOps.h - XeTile dialect -------*- C++ -*-===// +//===---------------------- XeTileOps.h - XeTile dialect -------*- C++ -*-===// // // Copyright 2022 Intel Corporation // Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. @@ -12,8 +12,8 @@ /// //===----------------------------------------------------------------------===// -#ifndef _XeTile_OPS_H_INCLUDED_ -#define _XeTile_OPS_H_INCLUDED_ +#ifndef _XETILE_OPS_H_INCLUDED_ +#define _XETILE_OPS_H_INCLUDED_ #include #include @@ -54,16 +54,18 @@ class TileBase : public mlir::Type, public mlir::ShapedType::Trait { /// Clone this type with the given shape and element type. If the /// provided shape is `None`, the current shape of the type is used. 
TileBase cloneWith(std::optional> shape, - mlir::Type elementType) const; + mlir::Type elementType, mlir::Attribute encoding) const; }; } // namespace xetile } // namespace imex #include +#define GET_ATTRDEF_CLASSES +#include #define GET_TYPEDEF_CLASSES #include #define GET_OP_CLASSES #include -#endif // _XeTile_OPS_H_INCLUDED_ +#endif // _XETILE_OPS_H_INCLUDED_ diff --git a/include/imex/Dialect/XeTile/IR/XeTileOps.td b/include/imex/Dialect/XeTile/IR/XeTileOps.td index bb3ea3b75..87c2d47b4 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileOps.td +++ b/include/imex/Dialect/XeTile/IR/XeTileOps.td @@ -1,4 +1,4 @@ -//===- XeTileOps.td - XeTile dialect -------*- tablegen -*-===// +//===---------------- XeTileOps.td - XeTile dialect -------*- tablegen -*-===// // // Copyright 2022 Intel Corporation // Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. @@ -11,10 +11,20 @@ /// This file defines the operations for the XeTile dialect. /// //===----------------------------------------------------------------------===// -#ifndef _XeTile_OPS_TD_INCLUDED_ -#define _XeTile_OPS_TD_INCLUDED_ +#ifndef _XETILE_OPS_TD_INCLUDED_ +#define _XETILE_OPS_TD_INCLUDED_ -include "imex/Dialect/XeTile/IR/XeTileBase.td" +include "imex/Dialect/XeTile/IR/XeTileDialect.td" +include "imex/Dialect/XeTile/IR/XeTileTypes.td" +include "imex/Dialect/XeTile/IR/XeTileAttrs.td" + +// Base class for dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. 
+class XeTile_Op traits = []> : + Op; def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> { let summary = "Describes an XeTile with reference to a base memref"; @@ -24,17 +34,24 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> memref or an address is used as the base, it is required to specify the shape and strides of the memory region described by the tile. + Optionally, the tile can be described in blocked layout as well. This is done by specifying + an "inner_blocks" attribute which describes the size (rows and cols) of the block. This attribute + is used by later lowering passes to detremine the 2D block load/store sizes. + The operation takes in the following arguments: * source: Source can be static/dynamic shaped memref or an address (i64) * offsets: 2 offsets into the "source" memref or address at which to create the tile. offsets can be operands (e.g., [%c0, %c]), attributes - (e.g., [2, 4]), or mix of operand and attributs (e.g., [%c0, 4] and [2, %c0]). + (e.g., [2, 4]), or mix of operand and attributes (e.g., [%c0, 4] and [2, %c0]). + * dynamic_offsets : This is a subset of "offsets". offsets can contain both static and dynamic + values. "dynamic_offsets" captures the dynamic subset of the offsets. * dynamic_shape : 2 shape arguments specifying the size of 2 dimensions of the "source". This is only required if a dynmaic shaped memref or an address is used as "source". + dynamic_shapes needs to be operands i.e. dynamic SSA values (e.g., [%c128, %c128]). * dynamic_strides : 2 stride arguments specifying the strides of the 2D "source" memory region. This is only required if a dynmaic shaped memref or an address is used as "source". - * dynamic_offsets : This is a subset of "offsets". offsets can contain both static and dynamic - values. "dynamic_offsets" captures the dynamic subset of the offsets. + dynamic_strides needs to be operands i.e. dynamic SSA values (e.g., [%c128, %c1]). 
+ * inner_blocks : Optional 2 element integer array describing [rows, cols] of the blocked layout. For the follwing examples, suppose the tile shape used by the compiler is 32x64. @@ -47,15 +64,25 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> %2 = xetile.init_tile %0[%c128, 512] : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> ``` + Example 2: + Creating an xetile using a static shaped 2D memref and inner blocks attribute. + + ```mlir + %0 = memref.alloc() : memref<1024x1024xf32> + %c128 = arith.constant 128 : index + %2 = xetile.init_tile %0[%c128, 512] { inner_blocks = [8, 16]} : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> + ``` + Example 2: Creating an xetile using a dynamic shaped 2D memref. ```mlir + %c1 = arith.constant 1 : index %c64 = arith.constant 64 : index %c512 = arith.constant 512 : index %c1024 = arith.constant 1024 : index %src = memref.alloc(%c1024, %c512) : memref<?x?xf32> - %1 = xetile.init_tile %src[256, %c64], [1024, 512], [512, 1] : memref<?x?xf32> -> !xetile.tile<32x64xf32> + %1 = xetile.init_tile %src[256, %c64], [%c1024, %c512], [%c512, %c1] : memref<?x?xf32> -> !xetile.tile<32x64xf32> ``` Example 3: @@ -64,9 +91,11 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> ```mlir %src = .... : i64 ... 
+ %c1 = arith.constant 1 : index %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index - %1 = xetile.init_tile %src[%c128, %c256], [1024, 1024], [1024, 1] : i64 -> !xetile.tile<32x64xf32> + %c1024 = arith.constant 1024 : index + %1 = xetile.init_tile %src[%c128, %c256], [%c1024, %c1024], [%c1024, %c1] : i64 -> !xetile.tile<32x64xf32> ``` }]; @@ -83,87 +112,83 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> let builders = [ // creating init_tile op with static memref OpBuilder<(ins "xetile::TileType":$resultType, - "::mlir::Value":$source, - "::llvm::ArrayRef<::mlir::OpFoldResult>":$offsets, - CArg<"::llvm::ArrayRef<::mlir::NamedAttribute>", "{}">:$attrs)>, + "mlir::Value":$source, + "llvm::ArrayRef":$offsets)>, // creating init_tile op with dynamic memref or an address OpBuilder<(ins "xetile::TileType":$resultType, - "::mlir::Value":$source, - "::llvm::ArrayRef<::mlir::OpFoldResult>":$offsets, - "::llvm::ArrayRef<::mlir::Value>":$dynamic_shape, - "::llvm::ArrayRef<::mlir::Value>":$dynamic_strides, - CArg<"::llvm::ArrayRef<::mlir::NamedAttribute>", "{}">:$attrs - )> + "mlir::Value":$source, + "llvm::ArrayRef":$offsets, + "llvm::ArrayRef":$dynamic_shape, + "llvm::ArrayRef":$dynamic_strides)> ]; - let assemblyFormat = [{ - $source `` - custom($offsets, $static_offsets) - (`,` `[` $dynamic_shape^ `]`)? - (`,` `[` $dynamic_strides^ `]`)? - attr-dict `:` qualified(type($source)) `->` qualified(type($tile)) - }]; + // let assemblyFormat = [{ + // $source `` + // custom($offsets, $static_offsets) + // (`,` `[` $dynamic_shape^ `]`)? + // (`,` `[` $dynamic_strides^ `]`)? 
+ // attr-dict `:` qualified(type($source)) `->` qualified(type($tile)) + // }]; + let hasCustomAssemblyFormat = true; let extraClassDeclaration = [{ /// get source type, could be a memref or an integer - ::mlir::Type getSourceType() {return getSource().getType();} + mlir::Type getSourceType() {return getSource().getType();} /// check if the source is a memref bool isSourceMemRef() { - return ::llvm::isa<::mlir::MemRefType>(getSourceType()); + return llvm::isa(getSourceType()); } /// check if the source is an i64 (i.e. pointer) bool isSourceInteger() { - return ::llvm::isa<::mlir::IntegerType>(getSourceType()); + return llvm::isa(getSourceType()); } /// get the element type of the source if it is a memref /// this method will fail if the source is not a memeref - ::mlir::Type getSourceMemrefElemType() { + mlir::Type getSourceMemrefElemType() { assert(isSourceMemRef() && "The source is not a memref."); - return getSourceType().cast<::mlir::MemRefType>().getElementType(); + return getSourceType().cast().getElementType(); } - - /// The result of an init_tile is always a Tile of TileType. 
TileType getType() { return getTile().getType().cast(); } /// Return the element type of the tile - ::mlir::Type getElementType() { + mlir::Type getElementType() { return getType().getElementType(); } /// Return the shape of the tile - ::llvm::ArrayRef getShape() { + llvm::ArrayRef getShape() { return getType().getShape(); } /// check if the offsets are static bool hasStaticOffsets() { - return !::mlir::ShapedType::isDynamicShape(getStaticOffsets()); + return !mlir::ShapedType::isDynamicShape(getStaticOffsets()); } /// check if a given dim in static offsets has a static value bool hasStaticOffsetAtDim(int dim) { - return !::mlir::ShapedType::isDynamic(getStaticOffsets()[dim]); + return !mlir::ShapedType::isDynamic(getStaticOffsets()[dim]); } /// check if the source memref has static shape info /// this method will fail if the source is not a memref bool sourceMemRefHasStaticShape() { assert(isSourceMemRef() && "source is not a memref."); - return getSourceType().cast<::mlir::MemRefType>().hasStaticShape(); + return getSourceType().cast().hasStaticShape(); } /// get the static shape of the source memref /// this method will fail if the source is not a memref or has static shape - ::llvm::ArrayRef getSourceMemrefStaticShape() { + llvm::ArrayRef getSourceMemrefStaticShape() { assert(sourceMemRefHasStaticShape() && "The source memref does not have static shape."); - return getSourceType().cast<::mlir::MemRefType>().getShape(); + return getSourceType().cast().getShape(); } /// check if dynamic shape arguments are present @@ -218,11 +243,8 @@ def XeTile_InitCoopTileOp : XeTile_Op<"init_coop_tile", [Pure]> { def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { let summary = "Loads a tile into a register region"; let description = [{ - "load_tile" operation loads the values of a tile into a register region with similar layout. - Optionally the load operation can be performed in blocked layout as well. 
This is done by - specifying an "inner_blocks" attribute which describes the size (rows and cols) of the block. - Blocking does not change the order of the outer dimension. For exmaple, if a tile [m, n] is loaded - with block factor [MB, NB] the resulting register region has the layout [m/MB, n/NB, MB, NB] + "load_tile" operation loads the values of a tile into a register region with 2D or 4D layout. + 4D layout is used when the tile is in blocked layout. If optional "transpose" 2-element array attribute is specified, the loaded tile will be transposed along the specified non-zero dimension. @@ -232,44 +254,48 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { This operatio has following arguments: * source : source tile that is loaded from - * inner_blocks : optional 2-element array arrtibute to specify the size of the inner blocks - when loaded in the blocked layout * transpose : optional 2-element array attibute to specify along which axis the transpose operation must be applied to the input tile * padding : optional string attribute to specify the padding value if out-of-bounds memory accesses occurs - Example 1: + Example 1: loading into a 2D register region ```mlir %4 = xetile.load_tile %src : !xetile.tile<64x32xf32> -> vector<64x32xf32> ``` - Example 2: + Example 2: loading with transpose and padding attributes enabled. ```mlir - %4 = xetile.load_tile %src { inner_blocks = [8, 16], transpose = [1, 0], padding = 1.0 : f32} - : !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> + %4 = xetile.load_tile %src { transpose = [1, 0], padding = 1.0 : f32} + : !xetile.tile<64x32xf32> -> vector<32x64xf32> + ``` + + Example 3: loading into a 4D register region. 
+ ```mlir + %4 = xetile.load_tile %src : !xetile.tile<64x32xf32> -> vector<8x2x8x16xf32> ``` }]; let arguments = (ins XeTile: $source, - OptionalAttr: $inner_blocks, - OptionalAttr: $transpose, + OptionalAttr: $transpose, OptionalAttr: $padding ); - let results = (outs Builtin_Vector: $value); + let results = (outs XeTile_2DOr4DVector: $value); - let assemblyFormat = [{ - $source attr-dict `:` qualified(type($source)) `->` qualified(type($value)) - }]; + // let assemblyFormat = [{ + // $source attr-dict `:` qualified(type($source)) `->` qualified(type($value)) + // }]; + let hasCustomAssemblyFormat = true; let extraClassDeclaration = [{ - ::mlir::Attribute getPaddingValue() { + // padding value defaults to zero in the appropriate type if its not specified + mlir::Attribute getPaddingValueOrDefault() { if (llvm::isa(getSource().getType().getElementType())) { - auto int32Zero = ::mlir::IntegerAttr::get(mlir::IntegerType::get((*this).getContext(), 32), 0); + auto int32Zero = mlir::IntegerAttr::get(mlir::IntegerType::get((*this).getContext(), 32), 0); return getPadding().value_or(int32Zero); } - auto float32Zero = ::mlir::IntegerAttr::get(mlir::FloatType::getF32((*this).getContext()), 0.0); + auto float32Zero = mlir::FloatAttr::get(mlir::FloatType::getF32((*this).getContext()), 0.0); return getPadding().value_or(float32Zero); } }]; @@ -280,31 +306,34 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> { let summary = "stores a register region into memory"; let description = [{ - "store_tile" operation can be used to store a register region into memory in plain layout. - If a block factor is specified, the blocked vector is stored into memory in plan layout. + "store_tile" operation can be used to store a register region into a 2D memory region + decribed by a tile. The register region can be in 2D or 4D. 4D register region is used + when the stored value is in blocked layout. 
This operation takes the following arguments: * value : vector specifying the values to store - * tile : tile representing the memory region to store into - * innner_blocks : optional 2-element array arrtibute to specify the size of the inner blocks - when stored in the blocked layout + * tile : tile representing the 2D memory region to store into - Example 1: + Example 1: storing a 2D register region + ```mlir + xetile.store_tile %value, %dst : vector<64x32xf32>, !xetile.tile<64x32xf32> + ``` + + Example 2: storing a 4D register region ```mlir - xetile.store_tile %value, %dst { inner_blocks = [8,16] } - : (!tile<64x32xf32>, vector<8x2x8x16xf32>) + xetile.store_tile %value, %dst : vector<8x2x8x16xf32>, !xetile.tile<64x32xf32> ``` }]; let arguments = (ins - Builtin_Vector: $value, - XeTile: $tile, - OptionalAttr: $inner_blocks + XeTile_2DOr4DVector: $value, + XeTile: $tile ); let assemblyFormat = [{ - $value`,`` `$tile attr-dict `:` `(` qualified(type($value)) `,` qualified(type($tile)) `)` + $value`,`` `$tile attr-dict `:` qualified(type($value)) `,` qualified(type($tile)) }]; + // let hasCustomAssemblyFormat = true; } def XeTile_PrefetchTileOp : XeTile_Op<"prefetch_tile", []> { @@ -338,62 +367,58 @@ def XeTile_PrefetchTileOp : XeTile_Op<"prefetch_tile", []> { def XeTile_TileMMAOp : XeTile_Op<"tile_mma", [Pure]> { let summary = "matrix multiplication in blocked layout"; let description = [{ - "tile_mma" operation represents matrix multiplication on tiles. This operation - takes two input matrices (matrix A, matrix B) and an optional accumulator matrix (matrix C) to + "tile_mma" operation represents matrix multiplication on 2D or 4D vectors. This operation + takes two input vectors (matrix A, matrix B) and an optional accumulator vector (matrix C) to perform a general matrix multiplication. C_new = A * B + C - Optionally inputs A, B, C can be in blocked layout where the block factor is specificed by - an optional a_inner_blocks, b_inner_blocks attributes. 
+ Vectors A, B, and C are specified in 4D when they are in blocked layout, i.e. loaded from + memory in blocked layout. Arguments: * a : vector representing input matrix A * b : vector representing input matrix B * c : optional vector representing accumulator matrix C - * a_inner_blocks : optional inner_blocks attribute for matrix A if it is in blocked layout - * b_inner_blocks : optional inner_blocks attribute for matrix B if it is in blocked layout - Example 1: + Example 1: tile_mma on 2D vectors of A and B ```mlir %c_new = xetile.tile_mma %a_vec, %b_vec - : (vector<64x32xf32>, vector<32x128xf32>) -> vector<64x128xf32> + : vector<64x32xf32>, vector<32x128xf32> -> vector<64x128xf32> ``` - Example 2: + Example 2: tile_mma on 2D vectors of A, B, and C ```mlir %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec - : (vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32>) -> vector<64x128xf32> + : vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32> -> vector<64x128xf32> ``` - Example 3: + Example 3: tile_mma on 4D vectors of A, B, and C ```mlir - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec { a_inner_blocks=[8,8], b_inner_blocks=[8,16] } - : (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32>) -> vector<8x8x8x16xf32> + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec + : vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32> -> vector<8x8x8x16xf32> ``` }]; let arguments = (ins - Builtin_Vector: $a, - Builtin_Vector: $b, - Optional: $c, - OptionalAttr: $a_inner_blocks, - OptionalAttr: $b_inner_blocks + XeTile_2DOr4DVector: $a, + XeTile_2DOr4DVector: $b, + Optional: $c ); - let results = (outs Builtin_Vector: $output); + let results = (outs XeTile_2DOr4DVector: $output); - let assemblyFormat = [{ - $a`,` ` `$b (`,` ` `$c^)? attr-dict `:` `(`qualified(type($a))`,` ` `qualified(type($b)) - (`,` ` `qualified(type($c))^)?`)` `->` qualified(type($output)) - }]; + // let assemblyFormat = [{ + // $a`,` ` `$b (`,` ` `$c^)? 
attr-dict `:` `(`qualified(type($a))`,` ` `qualified(type($b)) + // (`,` ` `qualified(type($c))^)?`)` `->` qualified(type($output)) + // }]; + + let hasCustomAssemblyFormat = true; let extraClassDeclaration = [{ mlir::VectorType getAType() { return llvm::cast(getA().getType()); } mlir::VectorType getBType() { return llvm::cast(getB().getType()); } - - mlir::VectorType getCType() { return llvm::cast(getC().getType()); } }]; let hasVerifier = 1; @@ -412,8 +437,8 @@ def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", []> { Example 1: ```mlir - xetile.update_tile_offset %tile, %offset_x, %offset_y - : (tile<32x32xf32>, index, index) + xetile.update_tile_offset %tile, [%offset_x, %offset_y] + : tile<32x32xf32>, index, index ``` }]; @@ -428,13 +453,13 @@ def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", []> { ); let assemblyFormat = [{ - $tile `,` ` ` $offset_x `,` ` ` $offset_y attr-dict `:` - `(` qualified(type($tile)) `,` - ` ` qualified(type($offset_x)) `,` - ` ` qualified(type($offset_y)) `)` + $tile `,` ` ` `[` $offset_x `,` ` ` $offset_y `]` ` ` attr-dict `:` + qualified(type($tile)) `,` + qualified(type($offset_x)) `,` + qualified(type($offset_y)) ` ` `->` qualified(type($result)) }]; } -#endif // _XeTile_OPS_TD_INCLUDED_ +#endif // _XETILE_OPS_TD_INCLUDED_ diff --git a/include/imex/Dialect/XeTile/IR/XeTileTypes.td b/include/imex/Dialect/XeTile/IR/XeTileTypes.td new file mode 100644 index 000000000..09e19671e --- /dev/null +++ b/include/imex/Dialect/XeTile/IR/XeTileTypes.td @@ -0,0 +1,126 @@ +//===---------------- XeTileTypes.td - XeTile dialect -------*- tablegen -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the custom types used by XeTile dialect. +/// +//===----------------------------------------------------------------------===// +#ifndef _XETILE_TYPES_TD_INCLUDED_ +#define _XETILE_TYPES_TD_INCLUDED_ + + +include "imex/Dialect/XeTile/IR/XeTileDialect.td" + +// common base class for types in XeTile dialect +class XeTile_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface], + "::mlir::TensorType"> +{ + let summary = "A type representing a N-D tile"; + let description = [{ + Tile data type in XeTile dialect is used to represent an N-D memory region. This captures the + 2d shape and type of the memory region it points to. Optional encoding attribute can be + attached to the tile type to carry extra information such as data layout information. Optional + inner_blocks attribute can be used to specify the 2D tiling layout within the larger tile. + + Syntax: + + ``` + tile-type ::= `tile` `<` tile-dim-list element-type (`,` `inner_blocks` = `[` int-array `]` )? (`,` encoding )?
`>` + element-type ::= float-type | integer-type + tile-dim-list ::= (decimal-literal `x`)* + int-array ::= int-array-attribute + encoding ::= attribute-value + ``` + + Examples: + + ```mlir + // 2D tile with i32 elements + tile<3x42xi32> + + // 4D tile with f32 elements + tile<4x5x6x7xf32> + + // 2D tile with i16 elements and encoding + tile<64x64xi16, #encoding> + + // 2D tile with i16 elements and inner_blocks + tile<64x64xi16, inner_blocks = [8,16]> + + // 2D tile with i16 elements and inner_blocks and encoding + tile<64x64xi16, inner_blocks = [8,16], #encoding> + ``` + }]; + + let parameters = (ins ArrayRefParameter<"int64_t">:$shape, + "::mlir::Type":$elementType, + OptionalArrayRefParameter<"int64_t">:$inner_blocks, + OptionalParameter<"::mlir::Attribute">:$encoding); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "::llvm::ArrayRef":$shape, + "::mlir::Type":$elementType, + CArg<"llvm::ArrayRef", "{}">:$inner_blocks, + CArg<"::mlir::Attribute", "{}">:$encoding + ), [{ + return $_get(elementType.getContext(), shape, elementType, inner_blocks, encoding); + }]> + ]; + + let extraClassDeclaration = [{ + using TensorType::clone; + using mlir::ShapedType::Trait::getElementTypeBitWidth; + using mlir::ShapedType::Trait::getRank; + using mlir::ShapedType::Trait::getNumElements; + using mlir::ShapedType::Trait::isDynamicDim; + using mlir::ShapedType::Trait::hasStaticShape; + using mlir::ShapedType::Trait::getNumDynamicDims; + using mlir::ShapedType::Trait::getDimSize; + using mlir::ShapedType::Trait::getDynamicDimIndex; + + TileType clone(::mlir::Type elementType) { + return ::llvm::cast(cloneWith(getShape(), elementType)); + } + }]; + + let assemblyFormat = "`<` custom($shape, $elementType, $inner_blocks, $encoding) `>`"; + let genVerifyDecl = true; +} + +// Integer types allowed in XeTile +def XeTile_IntType : AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>; + +// Float types allowed in XeTile +def
XeTile_FloatType : AnyTypeOf<[F16, F32, F64, BF16, TF32]>; + +// Define the scalar type for XeTile +def XeTile_ScalarType : AnyTypeOf<[XeTile_IntType, XeTile_FloatType]>; + +// define a 1D or 2D memref of XeTile scalar type +def XeTile_DynamicMemref : MemRefRankOf<[XeTile_ScalarType], [1, 2]>; + +// define the source type for XeTile init_tile +def XeTile_BaseAddrType : AnyTypeOf<[XeTile_DynamicMemref, UI64, UI32, I64, I32]>; + +// define the value type for XeTile load_tile and store_tile op +def XeTile_2DOr4DVector: VectorOfRankAndType<[2, 4], [XeTile_ScalarType]>; + +// define the attribute type allowed for padding values for load op +def XeTile_PaddingValueAttr : AnyAttrOf<[I32Attr, F32Attr]>; + + + +#endif // _XETILE_TYPES_TD_INCLUDED_ diff --git a/lib/Dialect/XeTile/IR/CMakeLists.txt b/lib/Dialect/XeTile/IR/CMakeLists.txt index 7bfbbb0b2..b735ec3c1 100644 --- a/lib/Dialect/XeTile/IR/CMakeLists.txt +++ b/lib/Dialect/XeTile/IR/CMakeLists.txt @@ -1,11 +1,13 @@ add_mlir_dialect_library(IMEXXeTileDialect XeTileOps.cpp + XeTileDialect.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/imex/Dialect/XeTile DEPENDS MLIRXeTileOpsIncGen + XeTileOpsAttrsIncGen LINK_LIBS PUBLIC MLIRIR diff --git a/lib/Dialect/XeTile/IR/XeTileDialect.cpp b/lib/Dialect/XeTile/IR/XeTileDialect.cpp new file mode 100644 index 000000000..d11bb0568 --- /dev/null +++ b/lib/Dialect/XeTile/IR/XeTileDialect.cpp @@ -0,0 +1,188 @@ +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include + +namespace imex { +namespace xetile { + +// bool TileBase::hasRank() const { return true; } + +// llvm::ArrayRef TileBase::getShape() const { +// return cast().getShape(); +// } + +// TileBase TileBase::cloneWith(std::optional> shape, 
+// Type elementType, mlir::Attribute encoding) +// const { +// return TileType::get(shape.value_or(getShape()), elementType, encoding); +// } + +static mlir::LogicalResult +parseXeTileType(mlir::AsmParser &parser, llvm::SmallVector &shape, + mlir::Type &type, llvm::SmallVector &inner_blocks, + mlir::Attribute &xe_map) { + mlir::Builder odsBuilder(parser.getContext()); + llvm::SMLoc odsLoc = parser.getCurrentLocation(); + (void)odsLoc; + mlir::FailureOr<::imex::xetile::XeMapAttr> _result_xe_map; + mlir::FailureOr> _result_inner_blocks; + llvm::SmallVector dimensions; + if (parser.parseDimensionList(dimensions)) + return mlir::failure(); + + mlir::Type t; + if (parser.parseType(t)) + return mlir::failure(); + + shape = std::move(dimensions); + type = std::move(t); + + bool shouldParseXeMap = true; + + if (mlir::succeeded(parser.parseOptionalComma())) { + // try to parse 'inner_blocks' + if (mlir::succeeded(parser.parseOptionalKeyword("inner_blocks"))) { + if (parser.parseEqual()) + return mlir::failure(); + if (parser.parseLSquare()) + return mlir::failure(); + // Parse variable 'inner_blocks' + _result_inner_blocks = + mlir::FieldParser>::parse(parser); + if (mlir::failed(_result_inner_blocks)) { + parser.emitError(parser.getCurrentLocation(), + "failed to parse XeTile parameter 'inner_blocks' " + "which is to be a `llvm::ArrayRef`"); + return mlir::failure(); + } + if (parser.parseRSquare()) + return mlir::failure(); + + for (auto v : *_result_inner_blocks) + inner_blocks.push_back(v); + + // check if there is an additional comma, if so we need to parse XeMap + if (mlir::failed(parser.parseOptionalComma())) { + shouldParseXeMap = false; + } + } + + // Parse variable 'xe_map' + if (shouldParseXeMap) { + _result_xe_map = + mlir::FieldParser<::imex::xetile::XeMapAttr>::parse(parser); + if (mlir::failed(_result_xe_map)) { + parser.emitError( + parser.getCurrentLocation(), + "failed to parse XeTile encoding parameter which is to " + "be a `::imex::xetile::XeMapAttr`"); +
return mlir::failure(); + } + xe_map = std::move(_result_xe_map->dyn_cast()); + } + } + + return mlir::success(); +} + +static void printXeTileType(mlir::AsmPrinter &printer, + llvm::ArrayRef shape, mlir::Type type, + llvm::ArrayRef inner_blocks, + mlir::Attribute xe_map) { + for (int64_t dim : shape) { + if (mlir::ShapedType::isDynamic(dim)) + printer << '?'; + else + printer << dim; + printer << 'x'; + } + printer << type; + + if (inner_blocks.size()) { + printer << ", "; + printer << "inner_blocks"; + printer << ' ' << "="; + printer << ' ' << "["; + printer.printStrippedAttrOrType(inner_blocks); + printer << "]"; + } + + if (xe_map) { + printer << ", "; + printer.printStrippedAttrOrType(xe_map); + } +} + +mlir::LogicalResult +TileType::verify(llvm::function_ref emitError, + llvm::ArrayRef shape, mlir::Type elementType, + llvm::ArrayRef inner_blocks, + mlir::Attribute encoding) { + if (inner_blocks.size() > 0 && inner_blocks.size() != 2) + emitError() << "expect integer array of size 2 for inner_blocks"; + + if (encoding) { + auto xeMap = llvm::dyn_cast(encoding); + if (!xeMap) + emitError() << "expect xetile::XeMapAttr for encoding"; + } + + return mlir::success(); +} + +mlir::LogicalResult SubGroupMapAttr::verify( + llvm::function_ref emitError, + mlir::DenseI32ArrayAttr mma_block_size, mlir::DenseI32ArrayAttr wi_layout, + mlir::DenseI32ArrayAttr wi_data) { + if (mma_block_size && mma_block_size.size() != 2) + emitError() << "expect integer array of size 2 for mma_block_size"; + if (wi_layout.size() != 2) + emitError() << "expect integer array of size 2 for wi_layout"; + if (wi_data.size() != 2) + emitError() << "expect integer array of size 2 for wi_data"; + return mlir::success(); +} + +mlir::LogicalResult WorkGroupMapAttr::verify( + llvm::function_ref emitError, + mlir::DenseI32ArrayAttr sg_layout, mlir::DenseI32ArrayAttr sg_data) { + if (sg_layout.size() != 2) + emitError() << "expect integer array of size 2 for sg_layout"; + if (sg_data.size() != 2) + 
emitError() << "expect integer array of size 2 for sg_data"; + return mlir::success(); +} + +void XeTileDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include + >(); + addOperations< +#define GET_OP_LIST +#include + >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include + >(); +} + +} // namespace xetile +} // namespace imex + +#include +#define GET_ATTRDEF_CLASSES +#include +#define GET_TYPEDEF_CLASSES +#include diff --git a/lib/Dialect/XeTile/IR/XeTileOps.cpp b/lib/Dialect/XeTile/IR/XeTileOps.cpp index 3ff256bb8..93fedc342 100644 --- a/lib/Dialect/XeTile/IR/XeTileOps.cpp +++ b/lib/Dialect/XeTile/IR/XeTileOps.cpp @@ -15,10 +15,15 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" +#include #include #include #include @@ -29,142 +34,252 @@ namespace imex { namespace xetile { -bool TileBase::hasRank() const { return true; } +template +static mlir::ParseResult parseAttributeHelper(mlir::OpAsmParser &parser, + mlir::OperationState &result, + llvm::StringRef attrKeyword) { + AttrType attr; + mlir::Type ty; + + if (std::is_same::value) { + ty = mlir::Type{}; + } else if (std::is_same::value) { + ty = mlir::Type{}; + } else { + assert(0 && "Unreachable.\n"); + } -llvm::ArrayRef TileBase::getShape() const { - return cast().getShape(); -} + if (parser.parseCustomAttributeWithFallback(attr, ty)) + return mlir::failure(); -TileBase TileBase::cloneWith(std::optional> shape, - Type elementType) const { - return TileType::get(shape.value_or(getShape()), elementType); + if (attr) + result.addAttribute(attrKeyword, attr); + return mlir::success(); } -static mlir::LogicalResult parseShape(mlir::AsmParser &parser, - 
llvm::SmallVector &shape, - mlir::Type &type) { - llvm::SmallVector dimensions; - if (parser.parseDimensionList(dimensions)) +static mlir::ParseResult +parseOptionalAttrDict(mlir::OpAsmParser &parser, mlir::OperationState &result, + llvm::ArrayRef allowedKeys) { + + // try to parse the left brace + if (mlir::failed(parser.parseOptionalLBrace())) + return mlir::success(); + + auto parseElt = [&]() -> mlir::ParseResult { + auto loc = parser.getCurrentLocation(); + llvm::StringRef nameId; + if (parser.parseOptionalKeyword(&nameId, allowedKeys)) + return parser.emitError(loc, "invalid") + << "attribute keyword: " << nameId << ".\n"; + + if (parser.parseEqual()) + return mlir::failure(); + + if (nameId == "transpose") + return parseAttributeHelper(parser, result, + nameId); + if (nameId == "padding") { + return parseAttributeHelper(parser, result, nameId); + } + + assert(0 && "Unreachable!"); + }; + + if (parser.parseCommaSeparatedList(parseElt)) return mlir::failure(); - mlir::Type t; - if (parser.parseType(t)) + if (parser.parseRBrace()) return mlir::failure(); - shape = std::move(dimensions); - type = std::move(t); return mlir::success(); } -static void printShape(mlir::AsmPrinter &printer, llvm::ArrayRef shape, - mlir::Type type) { - for (int64_t dim : shape) { - if (mlir::ShapedType::isDynamic(dim)) - printer << '?'; - else - printer << dim; - printer << 'x'; - } - printer << type; -} - mlir::LogicalResult InitTileOp::verify() { // number of offsets must be 2 because init_tile creates 2D tiles // dynamic_offsets is always a subset of offsets, so checking this is // sufficient - if (getStaticOffsets().size() != 2) { + if (getStaticOffsets().size() != 2) return emitOpError("number of offsets must be 2"); - } // if the source is a memref and has static shape, then dynamic shape and // strides arguments must not be present if (isSourceMemRef() && sourceMemRefHasStaticShape() && - (hasDynamicStrides() || hasDynamicShape())) { + (hasDynamicStrides() || 
hasDynamicShape())) return emitOpError("dynamic shape or strides are not allowed with a static " "shaped memref as source"); - } // if the source is a memref with dynamic shape, then a 2D dynamic shape // argument must be present if (isSourceMemRef() && !sourceMemRefHasStaticShape() && - getDynamicShape().size() != 2) { + getDynamicShape().size() != 2) return emitOpError("memref with a dynamic shape is used as source but " "dynamic shape argument missing or it is not 2D"); - } // if the source is a memref with dynamic shape, then a 2D dynamic strides // argument must be present if (isSourceMemRef() && !sourceMemRefHasStaticShape() && - getDynamicStrides().size() != 2) { + getDynamicStrides().size() != 2) return emitOpError("memref with a dynamic shape is used as source but " "dynamic strides argument missing or it is not 2D"); - } // if the source is an address, the dynamic shape must be 2D - if (isSourceInteger() && getDynamicShape().size() != 2) { + if (isSourceInteger() && getDynamicShape().size() != 2) return emitOpError("address is used as source but dynamic shape argument " "is missing or it is not 2D"); - } // if the source is an address, dynamic strides must be 2D - if (isSourceInteger() && getDynamicStrides().size() != 2) { + if (isSourceInteger() && getDynamicStrides().size() != 2) return emitOpError("address is used as source but dynamic strides argument " "is missing or it is not 2D"); - } return mlir::success(); } -void InitTileOp::build(::mlir::OpBuilder &builder, - ::mlir::OperationState &state, - xetile::TileType resultType, ::mlir::Value source, - ::llvm::ArrayRef<::mlir::OpFoldResult> offsets, - ::llvm::ArrayRef<::mlir::NamedAttribute> attrs) { - ::llvm::SmallVector staticOffsets; - ::llvm::SmallVector<::mlir::Value> dynamicOffsets; +void InitTileOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + xetile::TileType resultType, mlir::Value source, + llvm::ArrayRef offsets) { + llvm::SmallVector staticOffsets; + llvm::SmallVector 
dynamicOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); build(builder, state, resultType, source, dynamicOffsets, staticOffsets, - ::mlir::ValueRange({}), /* empty dynamic shape*/ - ::mlir::ValueRange({})); /* empty dynamic strides*/ - state.addAttributes(attrs); + mlir::ValueRange({}), /* empty dynamic shape*/ + mlir::ValueRange({}) /* empty dynamic strides*/ + ); } -void InitTileOp::build(::mlir::OpBuilder &builder, - ::mlir::OperationState &state, - xetile::TileType resultType, ::mlir::Value source, - ::llvm::ArrayRef<::mlir::OpFoldResult> offsets, - ::llvm::ArrayRef<::mlir::Value> dynamic_shape, - ::llvm::ArrayRef<::mlir::Value> dynamic_strides, - ::llvm::ArrayRef<::mlir::NamedAttribute> attrs) { - ::llvm::SmallVector staticOffsets; - ::llvm::SmallVector<::mlir::Value> dynamicOffsets; +void InitTileOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + xetile::TileType resultType, mlir::Value source, + llvm::ArrayRef offsets, + llvm::ArrayRef dynamic_shape, + llvm::ArrayRef dynamic_strides) { + llvm::SmallVector staticOffsets; + llvm::SmallVector dynamicOffsets; dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); build(builder, state, resultType, source, dynamicOffsets, staticOffsets, dynamic_shape, dynamic_strides); - state.addAttributes(attrs); } -mlir::LogicalResult LoadTileOp::verify() { - int64_t outputRank = - llvm::cast(getValue().getType()).getRank(); +mlir::ParseResult InitTileOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + mlir::OpAsmParser::UnresolvedOperand sourceRawOperands[1]; + llvm::ArrayRef sourceOperands( + sourceRawOperands); + llvm::SMLoc sourceOperandsLoc; + (void)sourceOperandsLoc; + llvm::SmallVector offsetsOperands; + llvm::SMLoc offsetsOperandsLoc; + (void)offsetsOperandsLoc; + mlir::DenseI64ArrayAttr static_offsetsAttr; + llvm::SmallVector + dynamic_shapeOperands; + llvm::SMLoc dynamic_shapeOperandsLoc; + (void)dynamic_shapeOperandsLoc; + llvm::SmallVector 
+ dynamic_stridesOperands; + llvm::SMLoc dynamic_stridesOperandsLoc; + (void)dynamic_stridesOperandsLoc; + mlir::Type sourceRawTypes[1]; + llvm::ArrayRef sourceTypes(sourceRawTypes); + mlir::Type tileRawTypes[1]; + llvm::ArrayRef tileTypes(tileRawTypes); + + sourceOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(sourceRawOperands[0])) + return mlir::failure(); + { + offsetsOperandsLoc = parser.getCurrentLocation(); + auto odsResult = + parseDynamicIndexList(parser, offsetsOperands, static_offsetsAttr); + if (odsResult) + return mlir::failure(); + result.getOrAddProperties().static_offsets = + static_offsetsAttr; + } + if (mlir::succeeded(parser.parseOptionalComma())) { + if (parser.parseLSquare()) + return mlir::failure(); + + dynamic_shapeOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(dynamic_shapeOperands)) + return mlir::failure(); + if (parser.parseRSquare()) + return mlir::failure(); + } + if (mlir::succeeded(parser.parseOptionalComma())) { + if (parser.parseLSquare()) + return mlir::failure(); + + dynamic_stridesOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(dynamic_stridesOperands)) + return mlir::failure(); + if (parser.parseRSquare()) + return mlir::failure(); + } - auto innerBlocks = getInnerBlocksAttr(); - auto transpose = getTransposeAttr(); + if (parser.parseColon()) + return mlir::failure(); - // if load_tile operation in blocked format only support 2D blocks - if (innerBlocks && innerBlocks.size() != 2) { - return emitOpError("inner_blocks must be two dimensional"); - } + if (parser.parseType(sourceRawTypes[0])) + return mlir::failure(); + if (parser.parseArrow()) + return mlir::failure(); - // if blocked load_tile load is specified output must be 4-dimensional - if (innerBlocks && outputRank != 4) { - return emitOpError( - "output must be 4-dimensional if inner_blocks is specified"); + if (parser.parseType(tileRawTypes[0])) + return mlir::failure(); + llvm::copy(llvm::ArrayRef( 
+ {1, static_cast(offsetsOperands.size()), + static_cast(dynamic_shapeOperands.size()), + static_cast(dynamic_stridesOperands.size())}), + result.getOrAddProperties() + .operandSegmentSizes.begin()); + mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType(); + result.addTypes(tileTypes); + if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc, + result.operands)) + return mlir::failure(); + if (parser.resolveOperands(offsetsOperands, odsBuildableType0, + offsetsOperandsLoc, result.operands)) + return mlir::failure(); + if (parser.resolveOperands(dynamic_shapeOperands, odsBuildableType0, + dynamic_shapeOperandsLoc, result.operands)) + return mlir::failure(); + if (parser.resolveOperands(dynamic_stridesOperands, odsBuildableType0, + dynamic_stridesOperandsLoc, result.operands)) + return mlir::failure(); + return mlir::success(); +} + +void InitTileOp::print(mlir::OpAsmPrinter &printer) { + printer << ' '; + printer << getSource(); + printDynamicIndexList(printer, *this, getOffsets(), getStaticOffsetsAttr()); + if (!getDynamicShape().empty()) { + printer << ","; + printer << ' ' << "["; + printer << getDynamicShape(); + printer << "]"; + } + if (!getDynamicStrides().empty()) { + printer << ","; + printer << ' ' << "["; + printer << getDynamicStrides(); + printer << "]"; } + printer << ' ' << ":"; + printer << ' '; + printer << getSource().getType(); + printer << ' ' << "->"; + printer << ' '; + printer << getTile().getType(); +} + +mlir::LogicalResult LoadTileOp::verify() { + auto transpose = getTransposeAttr(); + if (transpose && transpose.size() != 2) { return emitOpError("transpose must be two dimensional"); } @@ -172,40 +287,251 @@ mlir::LogicalResult LoadTileOp::verify() { return mlir::success(); } +mlir::ParseResult LoadTileOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + + mlir::OpAsmParser::UnresolvedOperand sourceTile; + llvm::ArrayRef sourceOperands( + sourceTile); + llvm::SMLoc sourceTileOperandLoc = 
parser.getCurrentLocation(); + if (parser.parseOperand(sourceTile)) + return mlir::failure(); + + // try to parse the optional dictionary attributes + if (parseOptionalAttrDict(parser, result, {"transpose", "padding"})) + return mlir::failure(); + + if (parser.parseColon()) + return mlir::failure(); + + mlir::Type sourceType; + llvm::ArrayRef sourceTypes(sourceType); + if (parser.parseType(sourceType)) + return mlir::failure(); + + if (parser.parseArrow()) + return mlir::failure(); + + mlir::Type valueType; + llvm::ArrayRef outputValueTypes(valueType); + if (parser.parseType(valueType)) + return mlir::failure(); + + result.addTypes(outputValueTypes); + if (parser.resolveOperands(sourceOperands, sourceTypes, sourceTileOperandLoc, + result.operands)) + return mlir::failure(); + return mlir::success(); +} + +static void printPaddingValue(mlir::Attribute paddingValue, + mlir::OpAsmPrinter &printer) { + if (auto floatVal = llvm::dyn_cast(paddingValue)) { + printer << floatVal.getValue() << " : " << floatVal.getType(); + } else if (auto intVal = llvm::dyn_cast(paddingValue)) { + printer << intVal.getValue() << " : " << intVal.getType(); + } +} + +void LoadTileOp::print(mlir::OpAsmPrinter &printer) { + printer << ' '; + printer << getSource(); + bool printSep = false; + + printer << " { "; + if ((*this)->getAttrs().size()) { + if (getTransposeAttr()) { + if (printSep) + printer << ", "; + printer << "transpose = "; + getTransposeAttr().print(printer); + printSep = true; + } + } + if (printSep) + printer << ", "; + printer << "padding = "; + printPaddingValue(getPaddingValueOrDefault(), printer); + printSep = true; + + printer << " } "; + + printer << " : "; + printer << getSource().getType(); + printer << " -> "; + printer << getValue().getType(); +} + +mlir::ParseResult TileMMAOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + + mlir::OpAsmParser::UnresolvedOperand aRawOperands[1]; + llvm::ArrayRef aOperands(aRawOperands); + llvm::SMLoc 
aOperandsLoc; + mlir::OpAsmParser::UnresolvedOperand bRawOperands[1]; + llvm::ArrayRef bOperands(bRawOperands); + llvm::SMLoc bOperandsLoc; + llvm::SmallVector cOperands; + llvm::SMLoc cOperandsLoc; + + mlir::Type aRawTypes[1]; + llvm::ArrayRef aTypes(aRawTypes); + mlir::Type bRawTypes[1]; + llvm::ArrayRef bTypes(bRawTypes); + llvm::SmallVector cTypes; + mlir::Type outputRawTypes[1]; + llvm::ArrayRef outputTypes(outputRawTypes); + + aOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(aRawOperands[0])) + return mlir::failure(); + + if (parser.parseComma()) + return mlir::failure(); + + bOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperand(bRawOperands[0])) + return mlir::failure(); + + // try to parse optional C vector + if (mlir::succeeded(parser.parseOptionalComma())) { + cOperandsLoc = parser.getCurrentLocation(); + mlir::OpAsmParser::UnresolvedOperand operand; + mlir::OptionalParseResult parseResult = + parser.parseOptionalOperand(operand); + + if (parseResult.has_value()) { + if (failed(*parseResult)) + return mlir::failure(); + cOperands.push_back(operand); + } + } + + if (parser.parseColon()) + return mlir::failure(); + + if (parser.parseType(aRawTypes[0])) + return mlir::failure(); + + if (parser.parseComma()) + return mlir::failure(); + + if (parser.parseType(bRawTypes[0])) + return mlir::failure(); + + if (mlir::succeeded(parser.parseOptionalComma())) { + mlir::Type optionalType; + mlir::OptionalParseResult parseResult = + parser.parseOptionalType(optionalType); + + if (parseResult.has_value()) { + if (failed(*parseResult)) + return mlir::failure(); + cTypes.push_back(optionalType); + } + } + + if (parser.parseArrow()) + return mlir::failure(); + + if (parser.parseType(outputRawTypes[0])) + return mlir::failure(); + + result.addTypes(outputTypes); + + if (parser.resolveOperands(aOperands, aTypes, aOperandsLoc, result.operands)) + return mlir::failure(); + + if (parser.resolveOperands(bOperands, bTypes, bOperandsLoc, 
result.operands)) + return mlir::failure(); + + if (parser.resolveOperands(cOperands, cTypes, cOperandsLoc, result.operands)) + return mlir::failure(); + + return mlir::success(); +} + +void TileMMAOp::print(mlir::OpAsmPrinter &printer) { + printer << ' '; + printer << getA(); + printer << ", "; + printer << getB(); + + if (getC()) { + printer << ", "; + printer << getC(); + } + printer << " : "; + printer << getA().getType() << ", "; + printer << getB().getType(); + if (getC()) { + printer << ", "; + printer << getC().getType(); + } + printer << " -> "; + printer << getOutput().getType(); +} + mlir::LogicalResult TileMMAOp::verify() { int64_t aRank = getAType().getRank(); int64_t bRank = getBType().getRank(); mlir::Type aElemType = getAType().getElementType(); mlir::Type bElemType = getBType().getElementType(); + mlir::Type outElemType = getOutput().getType().getElementType(); - // the two vector inputs to tile mma must have same rank + auto aShape = getAType().getShape(); + auto bShape = getBType().getShape(); + auto outShape = getOutput().getType().getShape(); + + // two vectors must have the same rank if (aRank != bRank) - return emitOpError("rank mismatch in tile mma inputs"); + return emitOpError("A and B inputs must have the same rank."); // the two vector inputs must have the same element type if (aElemType != bElemType) - return emitOpError("element type mismatch in tile mma inputs"); + return emitOpError("A and B inputs must have the same type."); + + if (getC() && + (llvm::cast(getC().getType()).getElementType() != + outElemType)) + return emitOpError("C and output vector must have the same type."); + + auto check4DMmaShapes = [](llvm::ArrayRef &A, + llvm::ArrayRef &B, + llvm::ArrayRef &Out) -> bool { + return A[1] == B[0] && A[3] == B[2] && Out[0] == A[0] && Out[1] == B[1] && + Out[2] == A[2] && Out[3] == B[3]; + }; + + auto check2DMmaShapes = [](llvm::ArrayRef &A, + llvm::ArrayRef &B, + llvm::ArrayRef &Out) -> bool { + return A[1] == B[0] && Out[0] == 
A[0] && Out[1] == B[1]; + }; + + // check mma shapes for 4D case + if (aRank == 4 && !check4DMmaShapes(aShape, bShape, outShape)) + return emitOpError("incompatible A, B and output sizes for 4D tile mma op. " + "4D tile mma should have the shape (m x k x Bm x Bk) x " + "(k x n x Bk x Bn) = (m x n x Bm x Bn)."); + + // check mma shape for 2D case + if (aRank == 2 && !check2DMmaShapes(aShape, bShape, outShape)) + return emitOpError( + "incompatible A, B and output sizes for 2D tile mma op. " + "2D tile mma should have the shape (m x k) x (k x n) = (m x n)."); - return mlir::success(); -} + // optional input C must have the same shape as output + if (getC() && + llvm::cast(getC().getType()).getShape() != outShape) + return emitOpError("input C must have the same shape as output."); -void XeTileDialect::initialize() { - addTypes< -#define GET_TYPEDEF_LIST -#include - >(); - addOperations< -#define GET_OP_LIST -#include - >(); + return mlir::success(); } } // namespace xetile } // namespace imex -#include -#define GET_TYPEDEF_CLASSES -#include #define GET_OP_CLASSES #include diff --git a/test/Dialect/XeTile/IR/XeTileOps.mlir b/test/Dialect/XeTile/IR/XeTileOps.mlir deleted file mode 100644 index 7b9067ae1..000000000 --- a/test/Dialect/XeTile/IR/XeTileOps.mlir +++ /dev/null @@ -1,225 +0,0 @@ -// RUN: imex-opt %s | FileCheck %s -// Verify the printed output can be parsed. -// RUN: imex-opt %s | imex-opt | FileCheck %s -// Verify the generic form can be parsed. 
-// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s - -// init_tile with a static shaped memref -// CHECK-LABEL: func @test_init_tile_using_static_memref({{.*}}) { -func.func @test_init_tile_using_static_memref(%src: memref<1024x1024xf32>) { - - %c128 = arith.constant 128 : index - %c256 = arith.constant 256 : index - - // CHECK: xetile.init_tile - // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<32x64xf32> - %1 = xetile.init_tile %src[8, 16] : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> - - // CHECK: xetile.init_tile - // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<32x64xf32> - %2 = xetile.init_tile %src[%c128, %c256] : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> - - // CHECK: xetile.init_tile - // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<32x64xf32> - %3 = xetile.init_tile %src[512, %c128] : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> - - return -} - -// init tile with a dynmaic shaped memref -// CHECK-LABEL: func @test_init_tile_using_dynamic_memref({{.*}}) { -func.func @test_init_tile_using_dynamic_memref(%src: memref, %dim0_size : index, %dim1_size : index, - %dim0_stride : index, %dim1_stride : index ) { - - %c128 = arith.constant 128 : index - %c256 = arith.constant 256 : index - - // CHECK: xetile.init_tile - // CHECK-SAME: memref -> !xetile.tile<32x64xf32> - %1 = xetile.init_tile %src[8, 16], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] - : memref -> !xetile.tile<32x64xf32> - - // CHECK: xetile.init_tile - // CHECK-SAME: memref -> !xetile.tile<32x64xf32> - %2 = xetile.init_tile %src[%c128, %c256], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] - : memref -> !xetile.tile<32x64xf32> - - - // CHECK: xetile.init_tile - // CHECK-SAME: memref -> !xetile.tile<32x64xf32> - %3 = xetile.init_tile %src[%c128, 64], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] - : memref -> !xetile.tile<32x64xf32> - - return -} - -// init tile with an addr -// CHECK-LABEL: func 
@test_init_tile_using_addr({{.*}}) { -func.func @test_init_tile_using_addr(%src: i64, %dim0_size : index, %dim1_size : index, - %dim0_stride : index, %dim1_stride : index ) { - - %c128 = arith.constant 128 : index - %c256 = arith.constant 256 : index - - // CHECK: xetile.init_tile - // CHECK-SAME: i64 -> !xetile.tile<32x64xf32> - %1 = xetile.init_tile %src[8, 16], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] - : i64 -> !xetile.tile<32x64xf32> - - // CHECK: xetile.init_tile - // CHECK-SAME: i64 -> !xetile.tile<32x64xf32> - %2 = xetile.init_tile %src[%c128, %c256], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] - : i64 -> !xetile.tile<32x64xf32> - - - // CHECK: xetile.init_tile - // CHECK-SAME: i64 -> !xetile.tile<32x64xf32> - %3 = xetile.init_tile %src[%c128, 64], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] - : i64 -> !xetile.tile<32x64xf32> - - return -} - -// CHECK-LABEL: func @test_init_coop_tile({{.*}}) { -func.func @test_init_coop_tile(%src: !xetile.tile<64x64xf32>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 64 : index - - // CHECK: xetile.init_coop_tile - // CHECK-SAME: !xetile.tile<64x64xf32>, index, index -> !xetile.tile<8x8xf32> - %1 = xetile.init_coop_tile %src, %c0, %c1 - : !xetile.tile<64x64xf32>, index, index -> !xetile.tile<8x8xf32> - - return -} - - -// CHECK-LABEL: func @test_load_tile({{.*}}) { -func.func @test_load_tile(%src: !xetile.tile<64x32xf32>) { - // CHECK: xetile.load_tile - // CHECK-SAME: !xetile.tile<64x32xf32> -> vector<64x32xf32> - %1 = xetile.load_tile %src : !xetile.tile<64x32xf32> -> vector<64x32xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: {inner_blocks = [8, 16]} : !xetile.tile<64x32xf32> -> vector<8x2x8x16xf32> - %2 = xetile.load_tile %src { inner_blocks = [8, 16] } : !xetile.tile<64x32xf32> -> vector<8x2x8x16xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: {transpose = [1, 0]} : !xetile.tile<64x32xf32> -> vector<32x64xf32> - %3 = xetile.load_tile %src { transpose = [1, 0] } 
: !xetile.tile<64x32xf32> -> vector<32x64xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: {padding = 1.000000e-01 : f32} : !xetile.tile<64x32xf32> -> vector<64x32xf32> - %4 = xetile.load_tile %src { padding = 0.1 : f32 } : !xetile.tile<64x32xf32> -> vector<64x32xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: {inner_blocks = [8, 16], padding = 1.000000e-01 : f32, transpose = [1, 0]} : - // CHECK-SAME: !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> - %5 = xetile.load_tile %src { inner_blocks = [8, 16], transpose = [1, 0], padding = 0.1 : f32 } - : !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> - - return -} - -// CHECK-LABEL: func @test_store_tile({{.*}}) { -func.func @test_store_tile(%value1 : vector<64x32xf32>, - %value2 : vector<8x2x8x16xf32>, %dst: !xetile.tile<64x32xf32>) { - - // CHECK: xetile.store_tile - // CHECK-SAME: (vector<64x32xf32>, !xetile.tile<64x32xf32>) - xetile.store_tile %value1, %dst : (vector<64x32xf32>, !xetile.tile<64x32xf32>) - - // CHECK: xetile.store_tile - // CHECK-SAME: {inner_blocks = [8, 16]} : (vector<8x2x8x16xf32>, !xetile.tile<64x32xf32>) - xetile.store_tile %value2, %dst { inner_blocks = [8, 16] } : (vector<8x2x8x16xf32>, !xetile.tile<64x32xf32>) - - return -} - -// CHECK-LABEL: func @test_coop_prefetch_tile({{.*}}) { -func.func @test_coop_prefetch_tile(%src: !xetile.tile<64x64xf32>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 64 : index - - // CHECK: xetile.init_coop_tile - // CHECK-SAME: !xetile.tile<64x64xf32>, index, index -> !xetile.tile<8x8xf32> - %1 = xetile.init_coop_tile %src, %c0, %c1 - : !xetile.tile<64x64xf32>, index, index -> !xetile.tile<8x8xf32> - - // CHECK: xetile.prefetch_tile - // CHECK-SAME: (!xetile.tile<8x8xf32>) - xetile.prefetch_tile %1 : (!xetile.tile<8x8xf32>) - - return -} - - -// CHECK-LABEL: func @test_tile_mma({{.*}}) { -func.func @test_tile_mma(%a: !xetile.tile<64x32xf32>, %b: !xetile.tile<32x128xf32>, %c : !xetile.tile<64x128xf32>) { - - // CHECK: xetile.load_tile - // 
CHECK-SAME: !xetile.tile<64x32xf32> -> vector<64x32xf32> - %a_vec = xetile.load_tile %a : !xetile.tile<64x32xf32> -> vector<64x32xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: !xetile.tile<32x128xf32> -> vector<32x128xf32> - %b_vec = xetile.load_tile %b : !xetile.tile<32x128xf32> -> vector<32x128xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: !xetile.tile<64x128xf32> -> vector<64x128xf32> - %c_vec = xetile.load_tile %c : !xetile.tile<64x128xf32> -> vector<64x128xf32> - - // CHECK: xetile.tile_mma - // CHECK-SAME: (vector<64x32xf32>, vector<32x128xf32>) -> vector<64x128xf32> - %c_new = xetile.tile_mma %a_vec, %b_vec - : (vector<64x32xf32>, vector<32x128xf32>) -> vector<64x128xf32> - - // CHECK: xetile.tile_mma - // CHECK-SAME: (vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32>) -> vector<64x128xf32> - %c_new_ = xetile.tile_mma %a_vec, %b_vec, %c_vec - : (vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32>) -> vector<64x128xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: {inner_blocks = [8, 8]} : !xetile.tile<64x32xf32> -> vector<8x4x8x8xf32> - %a_vec_1 = xetile.load_tile %a { inner_blocks = [8, 8] } - : !xetile.tile<64x32xf32> -> vector<8x4x8x8xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: {inner_blocks = [8, 16]} : !xetile.tile<32x128xf32> -> vector<4x8x8x16xf32> - %b_vec_1 = xetile.load_tile %b { inner_blocks = [8, 16] } - : !xetile.tile<32x128xf32> -> vector<4x8x8x16xf32> - - // CHECK: xetile.load_tile - // CHECK-SAME: {inner_blocks = [8, 16]} : !xetile.tile<64x128xf32> -> vector<8x8x8x16xf32> - %c_vec_1 = xetile.load_tile %c { inner_blocks = [8, 16] } - : !xetile.tile<64x128xf32> -> vector<8x8x8x16xf32> - - // CHECK: xetile.tile_mma - // CHECK-SAME: {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} - // CHECK-SAME: (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>) -> vector<8x8x8x16xf32> - %c_new_1 = xetile.tile_mma %a_vec_1, %b_vec_1 {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} - : (vector<8x4x8x8xf32>, 
vector<4x8x8x16xf32>) -> vector<8x8x8x16xf32> - - // CHECK: xetile.tile_mma - // CHECK-SAME: {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} - // CHECK-SAME: (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32>) -> vector<8x8x8x16xf32> - %c_new_1_ = xetile.tile_mma %a_vec_1, %b_vec_1, %c_vec_1 {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} - : (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32>) -> vector<8x8x8x16xf32> - - return -} - - -// CHECK-LABEL: func @test_update_tile_offset({{.*}}) { -func.func @test_update_tile_offset(%tile: !xetile.tile<32x32xf32>) { - - %offset_x = arith.constant 0 : index - %offset_y = arith.constant 96 : index - - // CHECK: xetile.update_tile_offset - // CHECK-SAME: (!xetile.tile<32x32xf32>, index, index) - xetile.update_tile_offset %tile, %offset_x, %offset_y - : (!xetile.tile<32x32xf32>, index, index) -> !xetile.tile<32x32xf32> - - return -} diff --git a/test/Dialect/XeTile/IR/invalid.mlir b/test/Dialect/XeTile/IR/invalid.mlir index 7b8cc9d75..bd1e877a5 100644 --- a/test/Dialect/XeTile/IR/invalid.mlir +++ b/test/Dialect/XeTile/IR/invalid.mlir @@ -55,14 +55,6 @@ func.func @init_tile_address_with_invalid_dynamic_strides(%source : i64, %dim0_s : i64 -> !xetile.tile<64x64xf32> } -// ----- -func.func @load_tile_with_invalid_inner_blocks(%tile : !xetile.tile<64x64xf32>) { - // INNER_BLOCKS must be 2D - // expected-error@+1 {{inner_blocks must be two dimensional}} - %1 = xetile.load_tile %tile { inner_blocks = [8,16,4] } - : !xetile.tile<64x64xf32> -> vector<8x4x8x16xf32> -} - // ----- func.func @load_tile_with_invalid_transpose(%tile : !xetile.tile<64x32xf32>) { // TRANSPOSE must be 2D @@ -71,29 +63,67 @@ func.func @load_tile_with_invalid_transpose(%tile : !xetile.tile<64x32xf32>) { : !xetile.tile<64x32xf32> -> vector<32x64xf32> } + // ----- -func.func @load_tile_with_invalid_output_rank(%tile : !xetile.tile<64x64xf32>) { - // if the INNER_BLOCKS is specified in tile_load output must be 4D - // 
expected-error@+1 {{output must be 4-dimensional if inner_blocks is specified}} - %1 = xetile.load_tile %tile { inner_blocks = [8,16] } - : !xetile.tile<64x64xf32> -> vector<8x4xf32> +func.func @tile_mma_incompatible_ranks(%a_vec : vector<8x8x8x8xf32>, + %b_vec : vector<8x8xf32>, %c_vec : vector<8x8x8x8xf32>) { + // the two input vectors must have the same rank + // expected-error@+1 {{A and B inputs must have the same rank.}} + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<8x8x8x8xf32>, vector<8x8xf32>, vector<8x8x8x8xf32> + -> vector<8x8x8x8xf32> +} +// ----- +func.func @tile_mma_input_elem_type_mismatch(%a_vec : vector<8x8xf32>, + %b_vec : vector<8x8xf16>, %c_vec : vector<8x8xf32>) { + // the two input vectors must have the same element type + // expected-error@+1 {{A and B inputs must have the same type.}} + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<8x8xf32>, vector<8x8xf16>, vector<8x8xf32> -> vector<8x8xf32> } // ----- -func.func @tile_mma_input_rank_mismatch(%a_vec : vector<8x8x8x8xf32>, - %b_vec : vector<8x8x8xf32>, %c_vec : vector<8x8x8x8xf32>) { +func.func @tile_mma_output_elem_type_mismatch(%a_vec : vector<8x8xf32>, + %b_vec : vector<8x8xf32>, %c_vec : vector<8x8xf16>) { // the two input vectors must have the same rank - // expected-error@+1 {{rank mismatch in tile mma inputs}} - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec {a_inner_blocks = [8, 8], b_inner_blocks = [8, 8]} - : (vector<8x8x8x8xf32>, vector<8x8x8xf32>, vector<8x8x8x8xf32>) -> vector<8x8x8x8xf32> + // expected-error@+1 {{C and output vector must have the same type.}} + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<8x8xf32>, vector<8x8xf32>, vector<8x8xf16> -> vector<8x8xf32> +} + +// ----- +func.func @tile_mma_incompatible_mma_shapes_4d(%a_vec : vector<8x16x8x32xf16>, + %b_vec : vector<16x8x8x8xf16>, %c_vec : vector<8x8x8x8xf32>) { + // A, B and output shapes must be compatible for a 4D mma + // expected-error@+1 {{incompatible A, B and output sizes for 4D
tile mma op. 4D tile mma should have the shape (m x k x Bm x Bk) x (k x n x Bk x Bn) = (m x n x Bm x Bn).}} + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec + : vector<8x16x8x32xf16>, vector<16x8x8x8xf16>, vector<8x8x8x8xf32> -> vector<8x8x8x8xf32> } // ----- -func.func @tile_mma_input_elem_type_mismatch(%a_vec : vector<8x8x8x8xf16>, - %b_vec : vector<8x8x8x8xf32>, %c_vec : vector<8x8x8x8xf32>) { +func.func @tile_mma_incompatible_mma_shapes_2d(%a_vec : vector<8x16xf16>, + %b_vec : vector<8x8xf16>, %c_vec : vector<8x8xf32>) { // the two input vectors must have the same element type - // expected-error@+1 {{element type mismatch in tile mma inputs}} - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec {a_inner_blocks = [8, 8], b_inner_blocks = [8, 8]} - : (vector<8x8x8x8xf16>, vector<8x8x8x8xf32>, vector<8x8x8x8xf32>) -> vector<8x8x8x8xf32> + // expected-error@+1 {{incompatible A, B and output sizes for 2D tile mma op. 2D tile mma should have the shape (m x k) x (k x n) = (m x n).}} + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<8x16xf16>, vector<8x8xf16>, vector<8x8xf32> -> vector<8x8xf32> } + +// ----- +func.func @tile_mma_input_c_shape_mismatch(%a_vec : vector<8x16xf16>, + %b_vec : vector<16x8xf16>, %c_vec : vector<16x8xf32>) { + // input C must have the same shape as the output + // expected-error@+1 {{input C must have the same shape as output.}} + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : vector<8x16xf16>, vector<16x8xf16>, vector<16x8xf32> -> vector<8x8xf32> +} + +// ----- +// expected-error@+1 {{expect integer array of size 2 for mma_block_size}} +#sg_map_1 = #xetile.sg_map +// expected-error@+1 {{expect integer array of size 2 for wi_layout}} +#sg_map_2 = #xetile.sg_map +// expected-error@+1 {{expect integer array of size 2 for wi_data}} +#sg_map_3 = #xetile.sg_map +// expected-error@+1 {{expect integer array of size 2 for sg_layout}} +#wg_map_1 = #xetile.wg_map +// expected-error@+1 {{expect integer array of size 2 for sg_data}}
+#wg_map_2 = #xetile.wg_map +// expected-error@+1 {{expect integer array of size 2 for inner_blocks}} +#tile1 = !xetile.tile<64x64xf16, inner_blocks = [8, 16, 8]> diff --git a/test/Dialect/XeTile/IR/ops.mlir b/test/Dialect/XeTile/IR/ops.mlir new file mode 100644 index 000000000..cef7228e2 --- /dev/null +++ b/test/Dialect/XeTile/IR/ops.mlir @@ -0,0 +1,277 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed. +// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + +#sg_map = #xetile.sg_map +#wg_map = #xetile.wg_map +#xe_map = #xetile.xe_map + +// init_tile with a static shaped memref +// CHECK-LABEL: func @test_init_tile_using_static_memref({{.*}}) { +func.func @test_init_tile_using_static_memref(%src: memref<1024x1024xf16>) { + + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<32x64xf16> + %1 = xetile.init_tile %src[8, 16] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> + + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<32x64xf16> + %2 = xetile.init_tile %src[%c128, %c256] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> + + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<32x64xf16> + %3 = xetile.init_tile %src[512, %c128] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16> + + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<32x64xf16, inner_blocks = [32, 16]> + %4 = xetile.init_tile %src[512, %c128] : memref<1024x1024xf16> -> !xetile.tile<32x64xf16, inner_blocks = [32, 16]> + + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #xetile.xe_map, + // CHECK-SAME: sg = >> + %5 = xetile.init_tile %src[0, 0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, #xe_map> + + // 
CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<128x128xf16, inner_blocks = [32, 16], + // CHECK-SAME: #xetile.xe_map, + // CHECK-SAME: sg = >> + %6 = xetile.init_tile %src[0, 0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, inner_blocks = [32, 16], #xe_map> + + return +} + +// init tile with a dynamic shaped memref +// CHECK-LABEL: func @test_init_tile_using_dynamic_memref({{.*}}) { +func.func @test_init_tile_using_dynamic_memref(%src: memref, %dim0_size : index, %dim1_size : index, + %dim0_stride : index, %dim1_stride : index ) { + + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + // CHECK: xetile.init_tile + // CHECK-SAME: memref -> !xetile.tile<32x64xf16> + %1 = xetile.init_tile %src[8, 16], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<32x64xf16> + + // CHECK: xetile.init_tile + // CHECK-SAME: memref -> !xetile.tile<32x64xf16> + %2 = xetile.init_tile %src[%c128, %c256], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<32x64xf16> + + + // CHECK: xetile.init_tile + // CHECK-SAME: memref -> !xetile.tile<32x64xf16> + %3 = xetile.init_tile %src[%c128, 64], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<32x64xf16> + + // CHECK: xetile.init_tile + // CHECK-SAME: memref -> !xetile.tile<128x128xf16, #xetile.xe_map, + // CHECK-SAME: sg = >> + %4 = xetile.init_tile %src[0, 0], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<128x128xf16, #xe_map> + + return +} + +// init tile with an addr +// CHECK-LABEL: func @test_init_tile_using_addr({{.*}}) { +func.func @test_init_tile_using_addr(%src: i64, %dim0_size : index, %dim1_size : index, + %dim0_stride : index, %dim1_stride : index ) { + + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + // CHECK: xetile.init_tile + // CHECK-SAME: i64 -> !xetile.tile<32x64xf16> + %1 = xetile.init_tile %src[8, 16],
[%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<32x64xf16> + + // CHECK: xetile.init_tile + // CHECK-SAME: i64 -> !xetile.tile<32x64xf16> + %2 = xetile.init_tile %src[%c128, %c256], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<32x64xf16> + + + // CHECK: xetile.init_tile + // CHECK-SAME: i64 -> !xetile.tile<32x64xf16> + %3 = xetile.init_tile %src[%c128, 64], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<32x64xf16> + + // CHECK: xetile.init_tile %arg0[0, 0], [%arg1, %arg2], [%arg3, %arg4] + // CHECK-SAME: i64 -> !xetile.tile<128x128xf16, #xetile.xe_map, + // CHECK-SAME: sg = >> + %4 = xetile.init_tile %src[0, 0], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<128x128xf16, #xe_map> + + return +} + +// CHECK-LABEL: func @test_init_coop_tile({{.*}}) { +func.func @test_init_coop_tile(%src: !xetile.tile<64x64xf16>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 64 : index + + // CHECK: xetile.init_coop_tile + // CHECK-SAME: !xetile.tile<64x64xf16>, index, index -> !xetile.tile<8x8xf16> + %1 = xetile.init_coop_tile %src, %c0, %c1 + : !xetile.tile<64x64xf16>, index, index -> !xetile.tile<8x8xf16> + + return +} + + +// CHECK-LABEL: func @test_load_tile({{.*}}) { +func.func @test_load_tile(%src: !xetile.tile<64x32xf16>, %src1 : !xetile.tile<128x128xf16, #xe_map>) { + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } : !xetile.tile<64x32xf16> -> vector<64x32xf16> + %1 = xetile.load_tile %src : !xetile.tile<64x32xf16> -> vector<64x32xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } + // CHECK-SAME: !xetile.tile<64x32xf16> -> vector<8x2x8x16xf16> + %2 = xetile.load_tile %src : !xetile.tile<64x32xf16> -> vector<8x2x8x16xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { transpose = [1, 0], padding = 0.000000e+00 : f32 } : !xetile.tile<64x32xf16> -> vector<32x64xf16> + %3 = 
xetile.load_tile %src { transpose = [1, 0] } : !xetile.tile<64x32xf16> -> vector<32x64xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 1.000000e-01 : f32 } : !xetile.tile<64x32xf16> -> vector<64x32xf16> + %4 = xetile.load_tile %src { padding = 0.1 : f32 } : !xetile.tile<64x32xf16> -> vector<64x32xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { transpose = [1, 0], padding = 1.000000e-01 : f32 } + // CHECK-SAME: !xetile.tile<64x32xf16> -> vector<2x8x16x8xf16> + %5 = xetile.load_tile %src { transpose = [1, 0], padding = 0.1 : f32 } + : !xetile.tile<64x32xf16> -> vector<2x8x16x8xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { transpose = [1, 0], padding = 1.000000e-01 : f32 } + // CHECK-SAME: !xetile.tile<128x128xf16, #xetile.xe_map, + // CHECK-SAME: sg = >> -> vector<2x8x16x8xf16> + %6 = xetile.load_tile %src1 { transpose = [1, 0], padding = 0.1 : f32 } + : !xetile.tile<128x128xf16, #xe_map> -> vector<2x8x16x8xf16> + + return +} + +// CHECK-LABEL: func @test_store_tile({{.*}}) { +func.func @test_store_tile(%value1 : vector<64x32xf16>, + %value2 : vector<8x2x8x16xf16>, %value3 : vector<16x8x8x16xf16>, %dst: !xetile.tile<64x32xf16>, %dst1: !xetile.tile<128x128xf16, #xe_map>) { + + // CHECK: xetile.store_tile + // CHECK-SAME: vector<64x32xf16>, !xetile.tile<64x32xf16> + xetile.store_tile %value1, %dst : vector<64x32xf16>, !xetile.tile<64x32xf16> + + // CHECK: xetile.store_tile + // CHECK-SAME: vector<8x2x8x16xf16>, !xetile.tile<64x32xf16> + xetile.store_tile %value2, %dst : vector<8x2x8x16xf16>, !xetile.tile<64x32xf16> + + // CHECK: xetile.store_tile + // CHECK-SAME: vector<16x8x8x16xf16>, !xetile.tile<128x128xf16, #xetile.xe_map, + // CHECK-SAME: sg = >> + xetile.store_tile %value3, %dst1 : vector<16x8x8x16xf16>, !xetile.tile<128x128xf16, #xe_map> + + return +} + +// CHECK-LABEL: func @test_coop_prefetch_tile({{.*}}) { +func.func @test_coop_prefetch_tile(%src: !xetile.tile<64x64xf16>) { + %c0 = arith.constant 0 : index + %c1 = 
arith.constant 64 : index + + // CHECK: xetile.init_coop_tile + // CHECK-SAME: !xetile.tile<64x64xf16>, index, index -> !xetile.tile<8x8xf16> + %1 = xetile.init_coop_tile %src, %c0, %c1 + : !xetile.tile<64x64xf16>, index, index -> !xetile.tile<8x8xf16> + + // CHECK: xetile.prefetch_tile + // CHECK-SAME: (!xetile.tile<8x8xf16>) + xetile.prefetch_tile %1 : (!xetile.tile<8x8xf16>) + + return +} + + +// CHECK-LABEL: func @test_tile_mma({{.*}}) { +func.func @test_tile_mma(%a: !xetile.tile<64x32xf16>, %b: !xetile.tile<32x128xf16>, %c : !xetile.tile<64x128xf16>) { + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } : !xetile.tile<64x32xf16> -> vector<64x32xf16> + %a_vec = xetile.load_tile %a : !xetile.tile<64x32xf16> -> vector<64x32xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } : !xetile.tile<32x128xf16> -> vector<32x128xf16> + %b_vec = xetile.load_tile %b : !xetile.tile<32x128xf16> -> vector<32x128xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } : !xetile.tile<64x128xf16> -> vector<64x128xf16> + %c_vec = xetile.load_tile %c : !xetile.tile<64x128xf16> -> vector<64x128xf16> + + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<64x32xf16>, vector<32x128xf16> -> vector<64x128xf16> + %c_new = xetile.tile_mma %a_vec, %b_vec + : vector<64x32xf16>, vector<32x128xf16> -> vector<64x128xf16> + + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<64x32xf16>, vector<32x128xf16>, vector<64x128xf16> -> vector<64x128xf16> + %c_new_ = xetile.tile_mma %a_vec, %b_vec, %c_vec + : vector<64x32xf16>, vector<32x128xf16>, vector<64x128xf16> -> vector<64x128xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } + // CHECK-SAME: !xetile.tile<64x32xf16> -> vector<8x4x8x8xf16> + %a_vec_1 = xetile.load_tile %a : !xetile.tile<64x32xf16> -> vector<8x4x8x8xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } + // CHECK-SAME: 
!xetile.tile<32x128xf16> -> vector<4x8x8x16xf16> + %b_vec_1 = xetile.load_tile %b : !xetile.tile<32x128xf16> -> vector<4x8x8x16xf16> + + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } + // CHECK-SAME: !xetile.tile<64x128xf16> -> vector<8x8x8x16xf16> + %c_vec_1 = xetile.load_tile %c : !xetile.tile<64x128xf16> -> vector<8x8x8x16xf16> + + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<8x4x8x8xf16>, vector<4x8x8x16xf16> -> vector<8x8x8x16xf16> + %c_new_1 = xetile.tile_mma %a_vec_1, %b_vec_1 + : vector<8x4x8x8xf16>, vector<4x8x8x16xf16> -> vector<8x8x8x16xf16> + + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<8x4x8x8xf16>, vector<4x8x8x16xf16>, vector<8x8x8x16xf16> -> vector<8x8x8x16xf16> + %c_new_1_ = xetile.tile_mma %a_vec_1, %b_vec_1, %c_vec_1 + : vector<8x4x8x8xf16>, vector<4x8x8x16xf16>, vector<8x8x8x16xf16> -> vector<8x8x8x16xf16> + + return +} + + +// CHECK-LABEL: func @test_update_tile_offset({{.*}}) { +func.func @test_update_tile_offset(%tile: !xetile.tile<32x32xf16>, %tile1 : !xetile.tile<128x128xf16, #xe_map>) { + + %offset_x = arith.constant 0 : index + %offset_y = arith.constant 96 : index + %c128 = arith.constant 128 : index + %c0 = arith.constant 0 : index + + // CHECK: xetile.update_tile_offset + // CHECK-SAME: : !xetile.tile<32x32xf16>, index, index -> !xetile.tile<32x32xf16> + %1 = xetile.update_tile_offset %tile, [%offset_x, %offset_y] + : !xetile.tile<32x32xf16>, index, index -> !xetile.tile<32x32xf16> + + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<128x128xf16, #xetile.xe_map, + // CHECK-SAME: sg = >>, index, index + // CHECK-SAME: -> !xetile.tile<128x128xf16, #xetile.xe_map, + // CHECK-SAME: sg = >> + %2 = xetile.update_tile_offset %tile1, [%c128, %c0] + : !xetile.tile<128x128xf16, #xe_map>, index, index -> !xetile.tile<128x128xf16, #xe_map> + + return +} diff --git a/test/Dialect/XeTile/IR/simple_gemm.mlir b/test/Dialect/XeTile/IR/simple_gemm.mlir index 5e02d7313..0799fb90e 100644 --- 
a/test/Dialect/XeTile/IR/simple_gemm.mlir +++ b/test/Dialect/XeTile/IR/simple_gemm.mlir @@ -4,65 +4,95 @@ // Verify the generic form can be parsed. // RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s +#sg_map_a = #xetile.sg_map +#wg_map_a = #xetile.wg_map +#xe_map_a = #xetile.xe_map + +#sg_map_b = #xetile.sg_map +#wg_map_b = #xetile.wg_map +#xe_map_b = #xetile.xe_map + +#sg_map_c = #xetile.sg_map +#wg_map_c = #xetile.wg_map +#xe_map_c = #xetile.xe_map // CHECK-LABEL: func @test_gemm({{.*}}) { func.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - %c16 = arith.constant 16 : index + // %c8 = arith.constant 8 : index + // %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index %c1024 = arith.constant 1024 : index %block_id_x = gpu.block_id x %block_id_y = gpu.block_id y - %m = arith.muli %block_id_x, %c8 : index - %n = arith.muli %block_id_y, %c16 : index + %m = arith.muli %block_id_x, %c128 : index + %n = arith.muli %block_id_y, %c128 : index // intialize C tile and load it - // CHECK : xetile.init_tile - // CHECK-SAME : memref<1024x1024xf32> -> !xetile.tile<8x16xf32> - %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<8x16xf32> - // CHECK : xetile.load_tile - // CHECK-SAME : !xetile.tile<8x16xf32> -> vector<8x16xf32> - %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<8x16xf32> -> vector<8x16xf32> + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<128x128xf32, inner_blocks = [8, 16] + // CHECK-SAME: #xetile.xe_map, + // CHECK-SAME: sg = >> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<128x128xf32, inner_blocks = [8, 16], #xe_map_c> + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } : !xetile.tile<128x128xf32, inner_blocks = [8, 16], + // CHECK-SAME: 
#xetile.xe_map, + // CHECK-SAME: sg = >> -> vector<128x128xf32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<128x128xf32, inner_blocks = [8, 16], #xe_map_c> -> vector<128x128xf32> // initalize A and B tiles - // CHECK : xetile.init_tile - // CHECK-SAME : memref<1024x1024xf16> -> !xetile.tile<8x16xf16> - %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<8x16xf16> - // CHECK : xetile.init_tile - // CHECK-SAME : memref<1024x1024xf16> -> !xetile.tile<16x16xf16> - %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<16x16xf16> + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<128x128xf16, + // CHECK-SAME: inner_blocks = [8, 16], #xetile.xe_map, + // CHECK-SAME: sg = >> + %a_init_tile = xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, inner_blocks = [8, 16], #xe_map_a> + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<128x128xf16, inner_blocks = [16, 16] + // CHECK-SAME: #xetile.xe_map, + // CHECK-SAME: sg = >> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xe_map_b> // compute the value of C tile by iterating over tiles in k-dimension and doing dpas - %out:3 = scf.for %k = %c0 to %c1024 step %c16 + %out:3 = scf.for %k = %c0 to %c1024 step %c128 iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) - -> (!xetile.tile<8x16xf16>, !xetile.tile<16x16xf16>, vector<8x16xf32>) { + -> (!xetile.tile<128x128xf16, inner_blocks = [8, 16], #xe_map_a>, !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xe_map_b>, vector<128x128xf32>) { // load A and B tiles - // CHECK : xetile.load_tile - // CHECK-SAME : !xetile.tile<8x16xf16> -> vector<8x16xf16> - %a_value = xetile.load_tile %a_tile : !xetile.tile<8x16xf16> -> vector<8x16xf16> - // CHECK : xetile.load_tile - // CHECK-SAME : 
!xetile.tile<16x16xf16> -> vector<16x16xf16> - %b_value = xetile.load_tile %b_tile : !xetile.tile<16x16xf16> -> vector<16x16xf16> + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } : !xetile.tile<128x128xf16, inner_blocks = [8, 16], + // CHECK-SAME: #xetile.xe_map, + // CHECK-SAME: sg = >> -> vector<128x128xf16> + %a_value = xetile.load_tile %a_tile : !xetile.tile<128x128xf16, inner_blocks = [8, 16], #xe_map_a> -> vector<128x128xf16> + // CHECK: xetile.load_tile + // CHECK-SAME: { padding = 0.000000e+00 : f32 } : !xetile.tile<128x128xf16, inner_blocks = [16, 16], + // CHECK-SAME: #xetile.xe_map, + // CHECK-SAME: sg = >> -> vector<128x128xf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xe_map_b> -> vector<128x128xf16> // perform dpas and accumulate - // CHECK : xetile.tile_mma - // CHECK-SAME : (vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>) -> vector<8x16xf32> - %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value - : (vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>) -> vector<8x16xf32> + // CHECK: xetile.tile_mma + // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32> // update the offsets for A and B tiles - // CHECK : xetile.update_tile_offset - // CHECK-SAME : (!xetile.tile<8x16xf16>, index, index) -> !xetile.tile<8x16xf16> - %a_next_tile = xetile.update_tile_offset %a_tile, %c0, %c16 - : (!xetile.tile<8x16xf16>, index, index) -> !xetile.tile<8x16xf16> - // CHECK : xetile.update_tile_offset - // CHECK-SAME : (!xetile.tile<16x16xf16>, index, index) -> !xetile.tile<16x16xf16> - %b_next_tile = xetile.update_tile_offset %b_tile, %c16, %c0 - : (!xetile.tile<16x16xf16>, index, index) -> !xetile.tile<16x16xf16> + // CHECK: xetile.update_tile_offset + // CHECK-SAME: 
!xetile.tile<128x128xf16, inner_blocks = [8, 16], #xetile.xe_map, + // CHECK-SAME: sg = >>, index, index + // CHECK-SAME: -> !xetile.tile<128x128xf16, inner_blocks = [8, 16], #xetile.xe_map, + // CHECK-SAME: sg = >> + %a_next_tile = xetile.update_tile_offset %a_tile, [%c0, %c128] + : !xetile.tile<128x128xf16, inner_blocks = [8, 16], #xe_map_a>, index, index -> !xetile.tile<128x128xf16, inner_blocks = [8, 16], #xe_map_a> + // CHECK: xetile.update_tile_offset + // CHECK-SAME: !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xetile.xe_map, + // CHECK-SAME: sg = >>, index, index + // CHECK-SAME: -> !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xetile.xe_map, + // CHECK-SAME: sg = >> + %b_next_tile = xetile.update_tile_offset %b_tile, [%c128, %c0] + : !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xe_map_b>, index, index -> !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xe_map_b> // partial C tile result scf.yield %a_next_tile, %b_next_tile, %c_new_value - : !xetile.tile<8x16xf16>, !xetile.tile<16x16xf16>, vector<8x16xf32> + : !xetile.tile<128x128xf16, inner_blocks = [8, 16], #xe_map_a>, !xetile.tile<128x128xf16, inner_blocks = [16, 16], #xe_map_b>, vector<128x128xf32> } // store the final accumulated C tile result back to memory - // CHECK : xetile.store_tile - // CHECK-SAME : (vector<8x16xf32>, !xetile.tile<8x16xf32>) - xetile.store_tile %out#2, %c_init_tile: (vector<8x16xf32>, !xetile.tile<8x16xf32>) + // CHECK: xetile.store_tile + // CHECK-SAME: vector<128x128xf32>, !xetile.tile<128x128xf32, inner_blocks = [8, 16], #xetile.xe_map, + // CHECK-SAME: sg = >> + xetile.store_tile %out#2, %c_init_tile : vector<128x128xf32>, !xetile.tile<128x128xf32, inner_blocks = [8, 16], #xe_map_c> return }