diff --git a/include/imex/Dialect/XeTile/IR/XeTileBase.td b/include/imex/Dialect/XeTile/IR/XeTileBase.td new file mode 100644 index 000000000..d2a85430f --- /dev/null +++ b/include/imex/Dialect/XeTile/IR/XeTileBase.td @@ -0,0 +1,151 @@ +//===- XeTileOps.td - XeTile dialect -------*- tablegen -*-===// +// +// Copyright 2022 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines the XeTile dialect and its types. +/// +//===----------------------------------------------------------------------===// + +#ifndef _XeTile_BASE_TD_INCLUDED_ +#define _XeTile_BASE_TD_INCLUDED_ + +include "mlir/IR/OpBase.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/ShapedOpInterfaces.td" + +// Provide a definition of the 'XeTile' dialect in the ODS framework so that we +// can define our operations. +def XeTile_Dialect : Dialect { + // The namespace of our dialect + let name = "xetile"; + + // A short one-line summary + let summary = "A dialect for enabling tile-base programming at subgroup level"; + + // A longer description + let description = [{ + XeTile provides an abstraction supporting tile-based computation to simplify the + lowering of DNN operation like matrix multiplication. XeTile dialect works at tile sizes + that are larger than the tile sizes supported by the hardware. 
XeTile dialect also hides + the auto-padding requirements for out-of-bound memory accesses and supports arbitrary + input matrix sizes. + }]; + + // The C++ namespace that the dialect class definition resides in. + let cppNamespace = "::imex::xetile"; + + let dependentDialects = [ + "::mlir::memref::MemRefDialect"]; + + // TODO: temporary disable it. + let useDefaultTypePrinterParser = true; +} + +// Base class for dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class XeTile_Op traits = []> : + Op; + +// common base class for types in XeTile dialect +class XeTile_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface], + "::imex::xetile::TileBase"> +{ + let summary = "A type representing a 2D tile"; + let description = [{ + Tile data type in XeTile dialect is used to represent a 2D memory region. + This captures the 2d shape and type of the memory region it points to. + + Syntax: + + ``` + tile-type ::= `tile` `<` tile-dim-list tile-element-type `>` + tile-element-type ::= float-type | integer-type | index-type + tile-dim-list ::= (static-dim-list `x`)? 
+ static-dim-list ::= decimal-literal `x` decimal-literal + ``` + + Examples: + + ```mlir + // A tile with i32 elements + tile<3x42xi32> + + // A tile with f32 elements + tile<4x5xf32> + ``` + }]; + + let parameters = (ins ArrayRefParameter<"int64_t">:$shape, + "::mlir::Type":$elementType); + + let builders = [ + TypeBuilderWithInferredContext<(ins + "::llvm::ArrayRef":$shape, "::mlir::Type":$elementType), [{ + assert(shape.size()==2); + return $_get(elementType.getContext(), shape, elementType); + }]>, + TypeBuilderWithInferredContext<(ins + "int64_t":$dim0, "int64_t":$dim1, "::mlir::Type":$elementType), [{ + llvm::SmallVector shape{dim0, dim1}; + assert(shape.size()==2); + return $_get(elementType.getContext(), shape, elementType); + }]> + ]; + + let extraClassDeclaration = [{ + using ::mlir::ShapedType::Trait::clone; + using ::mlir::ShapedType::Trait::getElementTypeBitWidth; + using ::mlir::ShapedType::Trait::getRank; + using ::mlir::ShapedType::Trait::getNumElements; + using ::mlir::ShapedType::Trait::isDynamicDim; + using ::mlir::ShapedType::Trait::hasStaticShape; + using ::mlir::ShapedType::Trait::getNumDynamicDims; + using ::mlir::ShapedType::Trait::getDimSize; + using ::mlir::ShapedType::Trait::getDynamicDimIndex; + }]; + + let assemblyFormat = "`<` custom($shape, $elementType) `>`"; +} + +// Integer types allowd in XeTile +def XeTile_IntType : AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>; + +// Float types allowed in XeTile +def XeTile_FloatType : AnyTypeOf<[F16, F32, F64, BF16, F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2FNUZ]>; + +// Define the scalar type for XeTile +def XeTile_ScalarType : AnyTypeOf<[XeTile_IntType, XeTile_FloatType]>; + +// define a 2D memref of XeTile scalar type +def XeTile_2DMemRef : MemRefRankOf<[XeTile_ScalarType], [2]>; + +// define the source type for XeTile init_tile +def XeTile_BaseAddrType : AnyTypeOf<[XeTile_2DMemRef, I64]>; + +// define the attribute type allowed for 
padding values for load op +def XeTile_PaddingValueAttr : AnyAttrOf<[I32Attr, F32Attr]>; + +#endif // _XeTile_BASE_TD_INCLUDED_ diff --git a/include/imex/Dialect/XeTile/IR/XeTileOps.td b/include/imex/Dialect/XeTile/IR/XeTileOps.td index 9162a66c1..bb3ea3b75 100644 --- a/include/imex/Dialect/XeTile/IR/XeTileOps.td +++ b/include/imex/Dialect/XeTile/IR/XeTileOps.td @@ -8,184 +8,125 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file defines the basic operations for the XeTile dialect. +/// This file defines the operations for the XeTile dialect. /// //===----------------------------------------------------------------------===// - #ifndef _XeTile_OPS_TD_INCLUDED_ #define _XeTile_OPS_TD_INCLUDED_ -include "mlir/IR/OpBase.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/BuiltinTypes.td" -include "mlir/IR/BuiltinTypeInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/CopyOpInterface.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/ShapedOpInterfaces.td" - -// Provide a definition of the 'XeTile' dialect in the ODS framework so that we -// can define our operations. -def XeTile_Dialect : Dialect { - // The namespace of our dialect - let name = "xetile"; - - // A short one-line summary - let summary = "A dialect for enabling tile-base programming at subgroup level"; - - // A longer description - let description = [{ - XeTile provides an abstraction supporting tile-based computation to simplify the - lowering of DNN operation like matrix multiplication. XeTile dialect works at tile sizes - that are larger than the tile sizes supported by the hardware. XeTile dilaect also hides - the auto-padding requirements for out-of-bound memory accesses and, supports arbitrary - input matrix sizes. 
- }]; - - // The C++ namespace that the dialect class definition resides in. - let cppNamespace = "::imex::xetile"; - - let dependentDialects = [ - "::mlir::memref::MemRefDialect"]; - - // TODO: temporary disable it. - let useDefaultTypePrinterParser = true; -} - -// Base class for dialect operations. This operation inherits from the base -// `Op` class in OpBase.td, and provides: -// * The parent dialect of the operation. -// * The mnemonic for the operation, or the name without the dialect prefix. -// * A list of traits for the operation. -class XeTile_Op traits = []> : - Op; - -// common base class for types in XeTile dialect -class XeTile_Type traits = [], - string baseCppClass = "::mlir::Type"> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface], - "::imex::xetile::TileBase"> -{ - let summary = "A type representing a 2D tile"; - let description = [{ - XeTile is the XeTile dialect's representation of a 2D tile. - XeTile is a 2 dimensional block of data. - - Syntax: - - ``` - tile-type ::= `vector` `<` vector-dim-list vector-element-type `>` - tile-element-type ::= float-type | integer-type | index-type - tile-dim-list := (static-dim-list `x`)? 
- static-dim-list ::= decimal-literal `x` decimal-literal - ``` - - Examples: - - ```mlir - // A tile with i32 elements - tile<3x42xi32> - - // A tile with f32 elements - tile<4x5xf32> - ``` - }]; - - let parameters = (ins ArrayRefParameter<"int64_t">:$shape, - "::mlir::Type":$elementType); +include "imex/Dialect/XeTile/IR/XeTileBase.td" - let builders = [ - TypeBuilderWithInferredContext<(ins - "::llvm::ArrayRef":$shape, "::mlir::Type":$elementType), [{ - assert(shape.size()==2); - return $_get(elementType.getContext(), shape, elementType); - }]>, - TypeBuilderWithInferredContext<(ins - "int64_t":$dim0, "int64_t":$dim1, "::mlir::Type":$elementType), [{ - llvm::SmallVector shape{dim0, dim1}; - assert(shape.size()==2); - return $_get(elementType.getContext(), shape, elementType); - }]> - ]; - - let extraClassDeclaration = [{ - using ::mlir::ShapedType::Trait::clone; - using ::mlir::ShapedType::Trait::getElementTypeBitWidth; - using ::mlir::ShapedType::Trait::getRank; - using ::mlir::ShapedType::Trait::getNumElements; - using ::mlir::ShapedType::Trait::isDynamicDim; - using ::mlir::ShapedType::Trait::hasStaticShape; - using ::mlir::ShapedType::Trait::getNumDynamicDims; - using ::mlir::ShapedType::Trait::getDimSize; - using ::mlir::ShapedType::Trait::getDynamicDimIndex; - }]; - - let assemblyFormat = "`<` custom($shape, $elementType) `>`"; -} - -def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure]> { +def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> { let summary = "Describes an XeTile with reference to a base memref"; let description = [{ - The "init_tile" operation is used to describe a reduced-size view of a 2D base - memref. This operation takes in a memref and returns an xetile. + The "init_tile" operation is used to describe a 2D region (i.e. tile) in gloabl memory. + This operation takes in a 2D memref or an address and return an xetile. 
If dynamic-shaped + memref or an address is used as the base, it is required to specify the shape and strides + of the memory region described by the tile. The operation takes in the following arguments: - * source: a 2D "base" memref represent a memory region. - * offsets: memref-rank number of offsets into the "base" memref at which to + * source: Source can be static/dynamic shaped memref or an address (i64) + * offsets: 2 offsets into the "source" memref or address at which to create the tile. offsets can be operands (e.g., [%c0, %c]), attributes (e.g., [2, 4]), or mix of operand and attributs (e.g., [%c0, 4] and [2, %c0]). + * dynamic_shape : 2 shape arguments specifying the size of 2 dimensions of the "source". + This is only required if a dynamic shaped memref or an address is used as "source". + * dynamic_strides : 2 stride arguments specifying the strides of the 2D "source" memory region. + This is only required if a dynamic shaped memref or an address is used as "source". + * dynamic_offsets : This is a subset of "offsets". offsets can contain both static and dynamic + values. "dynamic_offsets" captures the dynamic subset of the offsets. + For the following examples, suppose the tile shape used by the compiler is 32x64. - Example 1 (suppose the tile shape used by the compiler is 32x64): + Example 1: + Creating an xetile using a static shaped 2D memref. 
```mlir - %0 = memref.alloc() : memref<1024x1024xf32> + %c64 = arith.constant 64 : index + %c512 = arith.constant 512 : index + %c1024 = arith.constant 1024 : index + %src = memref.alloc(%c1024, %c512) : memref + %1 = xetile.init_tile %src[256, %c64], [1024, 512], [512, 1] : memref -> !xetile.tile<32x64xf32> + ``` + + Example 3: + Creating an xetile using an address + + ```mlir + %src = .... : i64 + ... %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index - %1 = xetile.init_tile %0[%c128, %c256] : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> + %1 = xetile.init_tile %src[%c128, %c256], [1024, 1024], [1024, 1] : i64 -> !xetile.tile<32x64xf32> ``` }]; - let arguments = (ins AnyMemRef:$base, + let arguments = (ins XeTile_BaseAddrType:$source, Variadic:$offsets, - DenseI64ArrayAttr:$static_offsets); + DenseI64ArrayAttr:$static_offsets, + Variadic:$dynamic_shape, + Variadic:$dynamic_strides + ); let results = (outs XeTile: $tile); + let builders = [ + // creating init_tile op with static memref + OpBuilder<(ins "xetile::TileType":$resultType, + "::mlir::Value":$source, + "::llvm::ArrayRef<::mlir::OpFoldResult>":$offsets, + CArg<"::llvm::ArrayRef<::mlir::NamedAttribute>", "{}">:$attrs)>, + // creating init_tile op with dynamic memref or an address + OpBuilder<(ins "xetile::TileType":$resultType, + "::mlir::Value":$source, + "::llvm::ArrayRef<::mlir::OpFoldResult>":$offsets, + "::llvm::ArrayRef<::mlir::Value>":$dynamic_shape, + "::llvm::ArrayRef<::mlir::Value>":$dynamic_strides, + CArg<"::llvm::ArrayRef<::mlir::NamedAttribute>", "{}">:$attrs + )> + ]; + let assemblyFormat = [{ - $base `` + $source `` custom($offsets, $static_offsets) - attr-dict `:` qualified(type($base)) `->` qualified(type($tile)) + (`,` `[` $dynamic_shape^ `]`)? + (`,` `[` $dynamic_strides^ `]`)? 
+ attr-dict `:` qualified(type($source)) `->` qualified(type($tile)) }]; + let extraClassDeclaration = [{ - /// get the type of the base memref - ::mlir::MemRefType getBaseType() { return getBase().getType().cast<::mlir::MemRefType>(); } + /// get source type, could be a memref or an integer + ::mlir::Type getSourceType() {return getSource().getType();} - /// Return the element type of the base memref - ::mlir::Type getBaseElementType() { - return getBaseType().getElementType(); + /// check if the source is a memref + bool isSourceMemRef() { + return ::llvm::isa<::mlir::MemRefType>(getSourceType()); } - /// Return the shape of the base memref - ::llvm::ArrayRef getStaticBaseShape() { - return getBaseType().getShape(); + /// check if the source is an i64 (i.e. pointer) + bool isSourceInteger() { + return ::llvm::isa<::mlir::IntegerType>(getSourceType()); } + /// get the element type of the source if it is a memref + /// this method will fail if the source is not a memeref + ::mlir::Type getSourceMemrefElemType() { + assert(isSourceMemRef() && "The source is not a memref."); + return getSourceType().cast<::mlir::MemRefType>().getElementType(); + } + + + /// The result of an init_tile is always a Tile of TileType. TileType getType() { return getTile().getType().cast(); @@ -201,23 +142,38 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure]> { return getType().getShape(); } - /// Whether the given dimension size indicates a dynamic dimension. - static constexpr bool isDynamic(int64_t dValue) { - return dValue == ::mlir::ShapedType::kDynamic; + /// check if the offsets are static + bool hasStaticOffsets() { + return !::mlir::ShapedType::isDynamicShape(getStaticOffsets()); + } + + /// check if a given dim in static offsets has a static value + bool hasStaticOffsetAtDim(int dim) { + return !::mlir::ShapedType::isDynamic(getStaticOffsets()[dim]); } - /// Whether the given shape has any size that indicates a dynamic dimension. 
- static bool isDynamicShape(::llvm::ArrayRef dSizes) { - return ::llvm::any_of(dSizes, [](int64_t dSize) { return isDynamic(dSize); }); + /// check if the source memref has static shape info + /// this method will fail if the source is not a memref + bool sourceMemRefHasStaticShape() { + assert(isSourceMemRef() && "source is not a memref."); + return getSourceType().cast<::mlir::MemRefType>().hasStaticShape(); } - bool hasStaticOffsets() { - return !isDynamicShape(getStaticOffsets()); + /// get the static shape of the source memref + /// this method will fail if the source is not a memref or does not have a static shape + ::llvm::ArrayRef getSourceMemrefStaticShape() { + assert(sourceMemRefHasStaticShape() && "The source memref does not have static shape."); + return getSourceType().cast<::mlir::MemRefType>().getShape(); + } + + /// check if dynamic shape arguments are present + bool hasDynamicShape() { + return getDynamicShape().size(); } - int getNumOfStaticOffsets() { - return std::accumulate(getStaticOffsets().begin(), getStaticOffsets().end(), 0, - [](int64_t a, int64_t b) { return isDynamic(b)? a: a+1;}); + /// check if dynamic stride arguments are present + bool hasDynamicStrides() { + return getDynamicStrides().size(); } }]; @@ -264,36 +220,58 @@ def XeTile_LoadTileOp : XeTile_Op<"load_tile", []> { let description = [{ "load_tile" operation loads the values of a tile into a register region with similar layout. Optionally the load operation can be performed in blocked layout as well. This is done by - specifying a block factor which describes the size (rows and cols) of the block. Blocking - does not change the order of the outer dimension. For exmaple, if a tile [m, n] is loaded + specifying an "inner_blocks" attribute which describes the size (rows and cols) of the block. + Blocking does not change the order of the outer dimension. 
For example, if a tile [m, n] is loaded with block factor [MB, NB] the resulting register region has the layout [m/MB, n/NB, MB, NB] - "load_tile" also supports transpose. + If optional "transpose" 2-element array attribute is specified, the loaded tile will be + transposed along the specified non-zero dimension. + + If optional "padding" value is specified, out-of-bounds memory accesses will be padded with the + specified padding values. This value defaults to "0.0f". This operatio has following arguments: * source : source tile that is loaded from - * block factor : optional 2-element array arrtibute to specify the size of the inner blocks + * inner_blocks : optional 2-element array attribute to specify the size of the inner blocks when loaded in the blocked layout - * transpose : optional boolean attibute to specify if the output of the load will be - trasnposed or not + * transpose : optional 2-element array attribute to specify along which axis the transpose + operation must be applied to the input tile + * padding : optional i32 or f32 attribute to specify the padding value if out-of-bounds + memory accesses occur Example 1: ```mlir - %4 = xetile.load_tile %src {inner_blocks = [8, 16], TRANSPOSE = true} - : tile<64x32xf32> -> vector<2x8x16x8xf32> + %4 = xetile.load_tile %src : !xetile.tile<64x32xf32> -> vector<64x32xf32> + ``` + + Example 2: + ```mlir + %4 = xetile.load_tile %src { inner_blocks = [8, 16], transpose = [1, 0], padding = 1.0 : f32} + : !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> ``` }]; let arguments = (ins XeTile: $source, OptionalAttr: $inner_blocks, - OptionalAttr: $transpose + OptionalAttr: $transpose, + OptionalAttr: $padding ); - let results = (outs Builtin_Vector: $result); + let results = (outs Builtin_Vector: $value); let assemblyFormat = [{ - $source (`inner_blocks` `=` $inner_blocks^)? ` ` (`TRANSPOSE` `=` $transpose^)? 
- attr-dict `:` qualified(type($source)) `->` qualified(type($result)) + $source attr-dict `:` qualified(type($source)) `->` qualified(type($value)) + }]; + + let extraClassDeclaration = [{ + ::mlir::Attribute getPaddingValue() { // explicit "padding" attribute if present, otherwise a zero attribute picked by the tile element type + if (llvm::isa<::mlir::IntegerType>(getSource().getType().getElementType())) { // integer tiles default to an i32 zero + auto int32Zero = ::mlir::IntegerAttr::get(mlir::IntegerType::get((*this).getContext(), 32), 0); + return getPadding().value_or(int32Zero); + } + auto float32Zero = ::mlir::FloatAttr::get(mlir::FloatType::getF32((*this).getContext()), 0.0); // FloatAttr, not IntegerAttr: IntegerAttr::get only accepts integer/index types + return getPadding().value_or(float32Zero); + } }]; let hasVerifier = 1; @@ -306,27 +284,26 @@ def XeTile_StoreTileOp : XeTile_Op<"store_tile", []> { If a block factor is specified, the blocked vector is stored into memory in plan layout. This operation takes the following arguments: - * tile : tile to store into - * block : vector specifying the valur to store - * block factor : optional 2-element array arrtibute to specify the size of the inner blocks + * value : vector specifying the values to store + * tile : tile representing the memory region to store into + * inner_blocks : optional 2-element array attribute to specify the size of the inner blocks when stored in the blocked layout Example 1: ```mlir - xetile.store_tile %dst, %value {inner_blocks = [8,16]} + xetile.store_tile %value, %dst { inner_blocks = [8,16] } : (!tile<64x32xf32>, vector<8x2x8x16xf32>) ``` }]; let arguments = (ins + Builtin_Vector: $value, XeTile: $tile, - Builtin_Vector: $block, OptionalAttr: $inner_blocks ); let assemblyFormat = [{ - $tile`,`` `$block (`inner_blocks` `=` $inner_blocks^)? 
attr-dict - `:` `(` qualified(type($tile)) `,` qualified(type($block)) `)` + $value`,`` `$tile attr-dict `:` `(` qualified(type($value)) `,` qualified(type($tile)) `)` }]; } @@ -362,28 +339,34 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", [Pure]> { let summary = "matrix multiplication in blocked layout"; let description = [{ "tile_mma" operation represents matrix multiplication on tiles. This operation - takes two input matrices (matrix A, matrix B) and an accumulator matrix (matrix C) to + takes two input matrices (matrix A, matrix B) and an optional accumulator matrix (matrix C) to perform a general matrix multiplication. C_new = A * B + C Optionally inputs A, B, C can be in blocked layout where the block factor is specificed by - an optional inner_blocks attribute. + an optional a_inner_blocks, b_inner_blocks attributes. Arguments: * a : vector representing input matrix A * b : vector representing input matrix B - * c : vector representing accumulator matrix C - * a_inner_blocks : options block factor for matrix A if it is in blocked layout - * b_inner_blocks : options block factor for matrix B if it is in blocked layout + * c : optional vector representing accumulator matrix C + * a_inner_blocks : optional inner_blocks attribute for matrix A if it is in blocked layout + * b_inner_blocks : optional inner_blocks attribute for matrix B if it is in blocked layout Example 1: + ```mlir + %c_new = xetile.tile_mma %a_vec, %b_vec + : (vector<64x32xf32>, vector<32x128xf32>) -> vector<64x128xf32> + ``` + + Example 2: ```mlir %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec : (vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32>) -> vector<64x128xf32> ``` - Example 2: + Example 3: ```mlir - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec a_inner_blocks=[8,8] b_inner_blocks=[8,16] + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec { a_inner_blocks=[8,8], b_inner_blocks=[8,16] } : (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32>) -> vector<8x8x8x16xf32> 
``` @@ -393,7 +376,7 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", [Pure]> { let arguments = (ins Builtin_Vector: $a, Builtin_Vector: $b, - Builtin_Vector: $c, + Optional: $c, OptionalAttr: $a_inner_blocks, OptionalAttr: $b_inner_blocks ); @@ -401,10 +384,8 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", [Pure]> { let results = (outs Builtin_Vector: $output); let assemblyFormat = [{ - $a`,` ` `$b`,` ` `$c - (`a_inner_blocks` `=` $a_inner_blocks^)? - (`b_inner_blocks` `=` $b_inner_blocks^)? attr-dict - `:` `(`qualified(type($a))`,` ` `qualified(type($b))`,` ` `qualified(type($c))`)` `->` qualified(type($output)) + $a`,` ` `$b (`,` ` `$c^)? attr-dict `:` `(`qualified(type($a))`,` ` `qualified(type($b)) + (`,` ` `qualified(type($c))^)?`)` `->` qualified(type($output)) }]; let extraClassDeclaration = [{ @@ -442,9 +423,16 @@ def XeTile_UpdateTileOffsetOp : XeTile_Op<"update_tile_offset", []> { Index: $offset_y ); + let results = (outs + XeTile: $result + ); + let assemblyFormat = [{ $tile `,` ` ` $offset_x `,` ` ` $offset_y attr-dict `:` - `(` qualified(type($tile)) `,` ` ` qualified(type($offset_x)) `,` ` ` qualified(type($offset_y)) `)` + `(` qualified(type($tile)) `,` + ` ` qualified(type($offset_x)) `,` + ` ` qualified(type($offset_y)) `)` + `->` qualified(type($result)) }]; } diff --git a/lib/Dialect/XeTile/IR/XeTileOps.cpp b/lib/Dialect/XeTile/IR/XeTileOps.cpp index 50c103212..3ff256bb8 100644 --- a/lib/Dialect/XeTile/IR/XeTileOps.cpp +++ b/lib/Dialect/XeTile/IR/XeTileOps.cpp @@ -12,8 +12,10 @@ /// //===----------------------------------------------------------------------===// +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" @@ -67,35 +69,94 @@ static void printShape(mlir::AsmPrinter &printer, llvm::ArrayRef shape, } mlir::LogicalResult InitTileOp::verify() { - 
auto baseTy = getBaseType(); - - // base memref must either have a static shape or a strided layout - // otherwise, we can not get the shape info to create the tile - auto stridedLayout = - ::llvm::dyn_cast<::mlir::StridedLayoutAttr>(baseTy.getLayout()); - if (!baseTy.hasStaticShape() && !stridedLayout) { - return emitOpError("base memref does not have a static shape or stride " - "layout information."); + + // number of offsets must be 2 because init_tile creates 2D tiles + // dynamic_offsets is always a subset of offsets, so checking this is + // sufficient + if (getStaticOffsets().size() != 2) { + return emitOpError("number of offsets must be 2"); + } + + // if the source is a memref and has static shape, then dynamic shape and + // strides arguments must not be present + if (isSourceMemRef() && sourceMemRefHasStaticShape() && + (hasDynamicStrides() || hasDynamicShape())) { + return emitOpError("dynamic shape or strides are not allowed with a static " + "shaped memref as source"); + } + + // if the source is a memref with dynamic shape, then a 2D dynamic shape + // argument must be present + if (isSourceMemRef() && !sourceMemRefHasStaticShape() && + getDynamicShape().size() != 2) { + return emitOpError("memref with a dynamic shape is used as source but " + "dynamic shape argument missing or it is not 2D"); + } + + // if the source is a memref with dynamic shape, then a 2D dynamic strides + // argument must be present + if (isSourceMemRef() && !sourceMemRefHasStaticShape() && + getDynamicStrides().size() != 2) { + return emitOpError("memref with a dynamic shape is used as source but " + "dynamic strides argument missing or it is not 2D"); } - // offsets must be 2D. 
- int numOffsets = getNumOfStaticOffsets() + getOffsets().size(); - if (numOffsets != 2) { - return emitOpError("offsets of the init_tile must be 2D."); + // if the source is an address, the dynamic shape must be 2D + if (isSourceInteger() && getDynamicShape().size() != 2) { + return emitOpError("address is used as source but dynamic shape argument " + "is missing or it is not 2D"); + } + + // if the source is an address, dynamic strides must be 2D + if (isSourceInteger() && getDynamicStrides().size() != 2) { + return emitOpError("address is used as source but dynamic strides argument " + "is missing or it is not 2D"); } return mlir::success(); } +void InitTileOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, + xetile::TileType resultType, ::mlir::Value source, + ::llvm::ArrayRef<::mlir::OpFoldResult> offsets, + ::llvm::ArrayRef<::mlir::NamedAttribute> attrs) { + ::llvm::SmallVector staticOffsets; + ::llvm::SmallVector<::mlir::Value> dynamicOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + build(builder, state, resultType, source, dynamicOffsets, staticOffsets, + ::mlir::ValueRange({}), /* empty dynamic shape*/ + ::mlir::ValueRange({})); /* empty dynamic strides*/ + state.addAttributes(attrs); +} + +void InitTileOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, + xetile::TileType resultType, ::mlir::Value source, + ::llvm::ArrayRef<::mlir::OpFoldResult> offsets, + ::llvm::ArrayRef<::mlir::Value> dynamic_shape, + ::llvm::ArrayRef<::mlir::Value> dynamic_strides, + ::llvm::ArrayRef<::mlir::NamedAttribute> attrs) { + ::llvm::SmallVector staticOffsets; + ::llvm::SmallVector<::mlir::Value> dynamicOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + build(builder, state, resultType, source, dynamicOffsets, staticOffsets, + dynamic_shape, dynamic_strides); + state.addAttributes(attrs); +} + mlir::LogicalResult LoadTileOp::verify() { int64_t outputRank = - 
llvm::cast(getResult().getType()).getRank(); + llvm::cast(getValue().getType()).getRank(); auto innerBlocks = getInnerBlocksAttr(); + auto transpose = getTransposeAttr(); // if load_tile operation in blocked format only support 2D blocks if (innerBlocks && innerBlocks.size() != 2) { - return emitOpError("inner_blocks must be two dimensional if specified"); + return emitOpError("inner_blocks must be two dimensional"); } // if blocked load_tile load is specified output must be 4-dimensional @@ -104,6 +165,10 @@ mlir::LogicalResult LoadTileOp::verify() { "output must be 4-dimensional if inner_blocks is specified"); } + if (transpose && transpose.size() != 2) { + return emitOpError("transpose must be two dimensional"); + } + return mlir::success(); } diff --git a/test/Dialect/XeTile/IR/XeTileOps.mlir b/test/Dialect/XeTile/IR/XeTileOps.mlir index 77fb5b594..7b9067ae1 100644 --- a/test/Dialect/XeTile/IR/XeTileOps.mlir +++ b/test/Dialect/XeTile/IR/XeTileOps.mlir @@ -4,8 +4,9 @@ // Verify the generic form can be parsed. 
// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s -// CHECK-LABEL: func @test_init_tile({{.*}}) { -func.func @test_init_tile(%src: memref<1024x1024xf32>) { +// init_tile with a static shaped memref +// CHECK-LABEL: func @test_init_tile_using_static_memref({{.*}}) { +func.func @test_init_tile_using_static_memref(%src: memref<1024x1024xf32>) { %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index @@ -18,6 +19,64 @@ func.func @test_init_tile(%src: memref<1024x1024xf32>) { // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<32x64xf32> %2 = xetile.init_tile %src[%c128, %c256] : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<32x64xf32> + %3 = xetile.init_tile %src[512, %c128] : memref<1024x1024xf32> -> !xetile.tile<32x64xf32> + + return +} + +// init tile with a dynamic shaped memref +// CHECK-LABEL: func @test_init_tile_using_dynamic_memref({{.*}}) { +func.func @test_init_tile_using_dynamic_memref(%src: memref, %dim0_size : index, %dim1_size : index, + %dim0_stride : index, %dim1_stride : index ) { + + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + // CHECK: xetile.init_tile + // CHECK-SAME: memref -> !xetile.tile<32x64xf32> + %1 = xetile.init_tile %src[8, 16], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<32x64xf32> + + // CHECK: xetile.init_tile + // CHECK-SAME: memref -> !xetile.tile<32x64xf32> + %2 = xetile.init_tile %src[%c128, %c256], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<32x64xf32> + + + // CHECK: xetile.init_tile + // CHECK-SAME: memref -> !xetile.tile<32x64xf32> + %3 = xetile.init_tile %src[%c128, 64], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<32x64xf32> + + return +} + +// init tile with an addr +// CHECK-LABEL: func @test_init_tile_using_addr({{.*}}) { +func.func
@test_init_tile_using_addr(%src: i64, %dim0_size : index, %dim1_size : index, + %dim0_stride : index, %dim1_stride : index ) { + + %c128 = arith.constant 128 : index + %c256 = arith.constant 256 : index + + // CHECK: xetile.init_tile + // CHECK-SAME: i64 -> !xetile.tile<32x64xf32> + %1 = xetile.init_tile %src[8, 16], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<32x64xf32> + + // CHECK: xetile.init_tile + // CHECK-SAME: i64 -> !xetile.tile<32x64xf32> + %2 = xetile.init_tile %src[%c128, %c256], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<32x64xf32> + + + // CHECK: xetile.init_tile + // CHECK-SAME: i64 -> !xetile.tile<32x64xf32> + %3 = xetile.init_tile %src[%c128, 64], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<32x64xf32> + return } @@ -42,16 +101,22 @@ func.func @test_load_tile(%src: !xetile.tile<64x32xf32>) { %1 = xetile.load_tile %src : !xetile.tile<64x32xf32> -> vector<64x32xf32> // CHECK: xetile.load_tile - // CHECK-SAME: inner_blocks = [8, 16] : !xetile.tile<64x32xf32> -> vector<8x2x8x16xf32> - %2 = xetile.load_tile %src inner_blocks = [8, 16] : !xetile.tile<64x32xf32> -> vector<8x2x8x16xf32> + // CHECK-SAME: {inner_blocks = [8, 16]} : !xetile.tile<64x32xf32> -> vector<8x2x8x16xf32> + %2 = xetile.load_tile %src { inner_blocks = [8, 16] } : !xetile.tile<64x32xf32> -> vector<8x2x8x16xf32> + + // CHECK: xetile.load_tile + // CHECK-SAME: {transpose = [1, 0]} : !xetile.tile<64x32xf32> -> vector<32x64xf32> + %3 = xetile.load_tile %src { transpose = [1, 0] } : !xetile.tile<64x32xf32> -> vector<32x64xf32> // CHECK: xetile.load_tile - // CHECK-SAME: TRANSPOSE = true : !xetile.tile<64x32xf32> -> vector<32x64xf32> - %3 = xetile.load_tile %src TRANSPOSE = true : !xetile.tile<64x32xf32> -> vector<32x64xf32> + // CHECK-SAME: {padding = 1.000000e-01 : f32} : !xetile.tile<64x32xf32> -> vector<64x32xf32> + %4 = xetile.load_tile %src { padding = 0.1 : f32 } :
!xetile.tile<64x32xf32> -> vector<64x32xf32> // CHECK: xetile.load_tile - // CHECK-SAME: inner_blocks = [8, 16] TRANSPOSE = true : !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> - %4 = xetile.load_tile %src inner_blocks = [8, 16] TRANSPOSE = true : !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> + // CHECK-SAME: {inner_blocks = [8, 16], padding = 1.000000e-01 : f32, transpose = [1, 0]} : + // CHECK-SAME: !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> + %5 = xetile.load_tile %src { inner_blocks = [8, 16], transpose = [1, 0], padding = 0.1 : f32 } + : !xetile.tile<64x32xf32> -> vector<2x8x16x8xf32> return } @@ -61,12 +126,12 @@ func.func @test_store_tile(%value1 : vector<64x32xf32>, %value2 : vector<8x2x8x16xf32>, %dst: !xetile.tile<64x32xf32>) { // CHECK: xetile.store_tile - // CHECK-SAME: (!xetile.tile<64x32xf32>, vector<64x32xf32>) - xetile.store_tile %dst, %value1 : (!xetile.tile<64x32xf32>, vector<64x32xf32>) + // CHECK-SAME: (vector<64x32xf32>, !xetile.tile<64x32xf32>) + xetile.store_tile %value1, %dst : (vector<64x32xf32>, !xetile.tile<64x32xf32>) // CHECK: xetile.store_tile - // CHECK-SAME: inner_blocks = [8, 16] : (!xetile.tile<64x32xf32>, vector<8x2x8x16xf32>) - xetile.store_tile %dst, %value2 inner_blocks = [8,16] : (!xetile.tile<64x32xf32>, vector<8x2x8x16xf32>) + // CHECK-SAME: {inner_blocks = [8, 16]} : (vector<8x2x8x16xf32>, !xetile.tile<64x32xf32>) + xetile.store_tile %value2, %dst { inner_blocks = [8, 16] } : (vector<8x2x8x16xf32>, !xetile.tile<64x32xf32>) return } @@ -104,29 +169,41 @@ func.func @test_tile_mma(%a: !xetile.tile<64x32xf32>, %b: !xetile.tile<32x128xf3 // CHECK-SAME: !xetile.tile<64x128xf32> -> vector<64x128xf32> %c_vec = xetile.load_tile %c : !xetile.tile<64x128xf32> -> vector<64x128xf32> + // CHECK: xetile.tile_mma + // CHECK-SAME: (vector<64x32xf32>, vector<32x128xf32>) -> vector<64x128xf32> + %c_new = xetile.tile_mma %a_vec, %b_vec + : (vector<64x32xf32>, vector<32x128xf32>) -> vector<64x128xf32> + // CHECK: xetile.tile_mma //
CHECK-SAME: (vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32>) -> vector<64x128xf32> - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec + %c_new_ = xetile.tile_mma %a_vec, %b_vec, %c_vec : (vector<64x32xf32>, vector<32x128xf32>, vector<64x128xf32>) -> vector<64x128xf32> // CHECK: xetile.load_tile - // CHECK-SAME: inner_blocks = [8, 8] : !xetile.tile<64x32xf32> -> vector<8x4x8x8xf32> - %a_vec_1 = xetile.load_tile %a inner_blocks = [8, 8] + // CHECK-SAME: {inner_blocks = [8, 8]} : !xetile.tile<64x32xf32> -> vector<8x4x8x8xf32> + %a_vec_1 = xetile.load_tile %a { inner_blocks = [8, 8] } : !xetile.tile<64x32xf32> -> vector<8x4x8x8xf32> // CHECK: xetile.load_tile - // CHECK-SAME: inner_blocks = [8, 16] : !xetile.tile<32x128xf32> -> vector<4x8x8x16xf32> - %b_vec_1 = xetile.load_tile %b inner_blocks = [8,16] + // CHECK-SAME: {inner_blocks = [8, 16]} : !xetile.tile<32x128xf32> -> vector<4x8x8x16xf32> + %b_vec_1 = xetile.load_tile %b { inner_blocks = [8, 16] } : !xetile.tile<32x128xf32> -> vector<4x8x8x16xf32> // CHECK: xetile.load_tile - // CHECK-SAME: inner_blocks = [8, 16] : !xetile.tile<64x128xf32> -> vector<8x8x8x16xf32> - %c_vec_1 = xetile.load_tile %c inner_blocks = [8,16] + // CHECK-SAME: {inner_blocks = [8, 16]} : !xetile.tile<64x128xf32> -> vector<8x8x8x16xf32> + %c_vec_1 = xetile.load_tile %c { inner_blocks = [8, 16] } : !xetile.tile<64x128xf32> -> vector<8x8x8x16xf32> // CHECK: xetile.tile_mma + // CHECK-SAME: {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} + // CHECK-SAME: (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>) -> vector<8x8x8x16xf32> + %c_new_1 = xetile.tile_mma %a_vec_1, %b_vec_1 {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} + : (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>) -> vector<8x8x8x16xf32> + + // CHECK: xetile.tile_mma + // CHECK-SAME: {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} // CHECK-SAME: (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32>) -> vector<8x8x8x16xf32> - %c_new_1 = xetile.tile_mma %a_vec_1,
%b_vec_1, %c_vec_1 a_inner_blocks=[8,8] b_inner_blocks=[8,16] + %c_new_1_ = xetile.tile_mma %a_vec_1, %b_vec_1, %c_vec_1 {a_inner_blocks = [8, 8], b_inner_blocks = [8, 16]} : (vector<8x4x8x8xf32>, vector<4x8x8x16xf32>, vector<8x8x8x16xf32>) -> vector<8x8x8x16xf32> return @@ -141,7 +218,8 @@ func.func @test_update_tile_offset(%tile: !xetile.tile<32x32xf32>) { // CHECK: xetile.update_tile_offset // CHECK-SAME: (!xetile.tile<32x32xf32>, index, index) - xetile.update_tile_offset %tile, %offset_x, %offset_y : (!xetile.tile<32x32xf32>, index, index) + xetile.update_tile_offset %tile, %offset_x, %offset_y + : (!xetile.tile<32x32xf32>, index, index) -> !xetile.tile<32x32xf32> return } diff --git a/test/Dialect/XeTile/IR/invalid.mlir b/test/Dialect/XeTile/IR/invalid.mlir index fdd706c4b..7b8cc9d75 100644 --- a/test/Dialect/XeTile/IR/invalid.mlir +++ b/test/Dialect/XeTile/IR/invalid.mlir @@ -2,60 +2,91 @@ // ----- -func.func @init_tile_with_unranked_memref(%source : memref) { - // the source memref must be ranked - // expected-error@+1 {{base memref does not have a static shape or stride layout information.}} - %1 = xetile.init_tile %source[0, 0] - : memref -> !xetile.tile<32x64xf32> - - return +func.func @init_tile_with_invalid_offsets(%source : memref<64x64xf32>, %offset : index) { + // the offsets of the init_tile must be 2D + // expected-error@+1 {{number of offsets must be 2}} + %1 = xetile.init_tile %source[%offset, %offset, %offset] + : memref<64x64xf32> -> !xetile.tile<8x8xf32> } +// ----- +func.func @init_tile_static_memref_with_invalid_dynamic_shape(%source : memref<1024x1024xf32>, + %dim0_size : index, %dim1_size : index) { + // for source memref with static shape, dynamic shape arguments should not be present + // expected-error@+1 {{dynamic shape or strides are not allowed with a static shaped memref as source}} + %1 = xetile.init_tile %source[0, 0], [%dim0_size, %dim1_size] + : memref<1024x1024xf32> -> !xetile.tile<64x64xf32> +} +// ----- +func.func
@init_tile_dynamic_memref_with_invalid_dynamic_shape(%source : memref, + %dim0_size : index, %dim1_size : index, %dim0_stride : index, %dim1_stride : index) { + // for source memref with dynamic shape, dynamic shape arguments should be 2D + // expected-error@+1 {{memref with a dynamic shape is used as source but dynamic shape argument missing or it is not 2D}} + %1 = xetile.init_tile %source[0, 0], [%dim0_size], [%dim0_stride, %dim1_stride] + : memref -> !xetile.tile<64x64xf32> +} // ----- -func.func @init_tile_with_invalid_offsets(%source : memref<64x64x64xf32>, %offset : index) { - // the offsets of the init_tile must be 2D - // expected-error@+1 {{offsets of the init_tile must be 2D.}} - %1 = xetile.init_tile %source[%offset, %offset, %offset] - : memref<64x64x64xf32> -> !xetile.tile<8x8xf32> +func.func @init_tile_dynamic_memref_with_invalid_dynamic_strides(%source : memref, + %dim0_size : index, %dim1_size : index, %dim0_stride : index, %dim1_stride : index) { + // for source memref with dynamic shape, dynamic strides arguments should be 2D + // expected-error@+1 {{memref with a dynamic shape is used as source but dynamic strides argument missing or it is not 2D}} + %1 = xetile.init_tile %source[0, 0], [%dim0_size, %dim1_size], [%dim0_stride] + : memref -> !xetile.tile<64x64xf32> +} + - return +// ----- +func.func @init_tile_address_with_invalid_dynamic_shape(%source : i64, %dim0_size : index, %dim1_size : index, + %dim0_stride : index, %dim1_stride : index) { + // for source address, dynamic shape arguments should be 2D + // expected-error@+1 {{address is used as source but dynamic shape argument is missing or it is not 2D}} + %1 = xetile.init_tile %source[0, 0], [%dim0_size], [%dim0_stride, %dim1_stride] + : i64 -> !xetile.tile<64x64xf32> } +// ----- +func.func @init_tile_address_with_invalid_dynamic_strides(%source : i64, %dim0_size : index, %dim1_size : index, + %dim0_stride : index, %dim1_stride : index) { + // for source address, dynamic strides arguments
should be 2D + // expected-error@+1 {{address is used as source but dynamic strides argument is missing or it is not 2D}} + %1 = xetile.init_tile %source[0, 0], [%dim0_size, %dim1_size], [%dim0_stride] + : i64 -> !xetile.tile<64x64xf32> +} // ----- func.func @load_tile_with_invalid_inner_blocks(%tile : !xetile.tile<64x64xf32>) { - - // inner_blocks must be 2D - // expected-error@+1 {{inner_blocks must be two dimensional if specified}} - %1 = xetile.load_tile %tile inner_blocks = [8,16,4] + // INNER_BLOCKS must be 2D + // expected-error@+1 {{inner_blocks must be two dimensional}} + %1 = xetile.load_tile %tile { inner_blocks = [8,16,4] } : !xetile.tile<64x64xf32> -> vector<8x4x8x16xf32> +} - return +// ----- +func.func @load_tile_with_invalid_transpose(%tile : !xetile.tile<64x32xf32>) { + // TRANSPOSE must be 2D + // expected-error@+1 {{transpose must be two dimensional}} + %1 = xetile.load_tile %tile { transpose = [1, 0 , 0] } + : !xetile.tile<64x32xf32> -> vector<32x64xf32> } // ----- func.func @load_tile_with_invalid_output_rank(%tile : !xetile.tile<64x64xf32>) { - - // if the inner_blocks is specified in tile_load output must be 4D + // if the INNER_BLOCKS is specified in tile_load output must be 4D // expected-error@+1 {{output must be 4-dimensional if inner_blocks is specified}} - %1 = xetile.load_tile %tile inner_blocks = [8,16] + %1 = xetile.load_tile %tile { inner_blocks = [8,16] } : !xetile.tile<64x64xf32> -> vector<8x4xf32> - return } // ----- func.func @tile_mma_input_rank_mismatch(%a_vec : vector<8x8x8x8xf32>, %b_vec : vector<8x8x8xf32>, %c_vec : vector<8x8x8x8xf32>) { - // the two input vectors must have the same rank // expected-error@+1 {{rank mismatch in tile mma inputs}} - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec {a_inner_blocks = [8, 8], b_inner_blocks = [8, 8]} : (vector<8x8x8x8xf32>, vector<8x8x8xf32>, vector<8x8x8x8xf32>) -> vector<8x8x8x8xf32> - - return } // ----- @@ -63,8 +94,6 @@
func.func @tile_mma_input_elem_type_mismatch(%a_vec : vector<8x8x8x8xf16>, %b_vec : vector<8x8x8x8xf32>, %c_vec : vector<8x8x8x8xf32>) { // the two input vectors must have the same element type // expected-error@+1 {{element type mismatch in tile mma inputs}} - %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec + %c_new = xetile.tile_mma %a_vec, %b_vec, %c_vec {a_inner_blocks = [8, 8], b_inner_blocks = [8, 8]} : (vector<8x8x8x8xf16>, vector<8x8x8x8xf32>, vector<8x8x8x8xf32>) -> vector<8x8x8x8xf32> - - return } diff --git a/test/Dialect/XeTile/IR/simple_gemm.mlir b/test/Dialect/XeTile/IR/simple_gemm.mlir new file mode 100644 index 000000000..5e02d7313 --- /dev/null +++ b/test/Dialect/XeTile/IR/simple_gemm.mlir @@ -0,0 +1,68 @@ +// RUN: imex-opt %s | FileCheck %s +// Verify the printed output can be parsed. +// RUN: imex-opt %s | imex-opt | FileCheck %s +// Verify the generic form can be parsed. +// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s + + +// CHECK-LABEL: func @test_gemm({{.*}}) { +func.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + %c1024 = arith.constant 1024 : index + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %m = arith.muli %block_id_x, %c8 : index + %n = arith.muli %block_id_y, %c16 : index + // initialize C tile and load it + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf32> -> !xetile.tile<8x16xf32> + %c_init_tile = xetile.init_tile %C[%m, %n] : memref<1024x1024xf32> -> !xetile.tile<8x16xf32> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x16xf32> -> vector<8x16xf32> + %c_init_value = xetile.load_tile %c_init_tile : !xetile.tile<8x16xf32> -> vector<8x16xf32> + // initialize A and B tiles + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<8x16xf16> + %a_init_tile =
xetile.init_tile %A[%m, %c0] : memref<1024x1024xf16> -> !xetile.tile<8x16xf16> + // CHECK: xetile.init_tile + // CHECK-SAME: memref<1024x1024xf16> -> !xetile.tile<16x16xf16> + %b_init_tile = xetile.init_tile %B[%c0, %n] : memref<1024x1024xf16> -> !xetile.tile<16x16xf16> + // compute the value of C tile by iterating over tiles in k-dimension and doing dpas + %out:3 = scf.for %k = %c0 to %c1024 step %c16 + iter_args(%a_tile = %a_init_tile, %b_tile = %b_init_tile, %c_value = %c_init_value) + -> (!xetile.tile<8x16xf16>, !xetile.tile<16x16xf16>, vector<8x16xf32>) { + + // load A and B tiles + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<8x16xf16> -> vector<8x16xf16> + %a_value = xetile.load_tile %a_tile : !xetile.tile<8x16xf16> -> vector<8x16xf16> + // CHECK: xetile.load_tile + // CHECK-SAME: !xetile.tile<16x16xf16> -> vector<16x16xf16> + %b_value = xetile.load_tile %b_tile : !xetile.tile<16x16xf16> -> vector<16x16xf16> + // perform dpas and accumulate + // CHECK: xetile.tile_mma + // CHECK-SAME: (vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>) -> vector<8x16xf32> + %c_new_value = xetile.tile_mma %a_value, %b_value, %c_value + : (vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>) -> vector<8x16xf32> + // update the offsets for A and B tiles + // CHECK: xetile.update_tile_offset + // CHECK-SAME: (!xetile.tile<8x16xf16>, index, index) -> !xetile.tile<8x16xf16> + %a_next_tile = xetile.update_tile_offset %a_tile, %c0, %c16 + : (!xetile.tile<8x16xf16>, index, index) -> !xetile.tile<8x16xf16> + // CHECK: xetile.update_tile_offset + // CHECK-SAME: (!xetile.tile<16x16xf16>, index, index) -> !xetile.tile<16x16xf16> + %b_next_tile = xetile.update_tile_offset %b_tile, %c16, %c0 + : (!xetile.tile<16x16xf16>, index, index) -> !xetile.tile<16x16xf16> + // partial C tile result + scf.yield %a_next_tile, %b_next_tile, %c_new_value + : !xetile.tile<8x16xf16>, !xetile.tile<16x16xf16>, vector<8x16xf32> + } + // store the final accumulated C tile result
back to memory + // CHECK: xetile.store_tile + // CHECK-SAME: (vector<8x16xf32>, !xetile.tile<8x16xf32>) + xetile.store_tile %out#2, %c_init_tile: (vector<8x16xf32>, !xetile.tile<8x16xf32>) + return +}