Skip to content

Commit

Permalink
[XeTileToXeGPU] Enable dynamic memref in XeTileToXeGPU lowering.
Browse files Browse the repository at this point in the history
Supports static and dynamic memrefs with strides as well as integer source.
  • Loading branch information
mshahneo committed Nov 25, 2024
1 parent c3dc222 commit 9f495a6
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 67 deletions.
36 changes: 36 additions & 0 deletions include/imex/Utils/XeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,42 @@ llvm::SmallVector<int64_t> defaultStrides(llvm::ArrayRef<int64_t> shape);
/// capability).
bool isVectorAnyINTELType(mlir::Type type);

/// convert OpFoldResult to Value by replacing integer
/// attributes with arith::ConstantOps. It also performs
/// simple type conversions
mlir::Value getValueOrConstantOp(mlir::OpFoldResult ofr, mlir::Location loc,
mlir::PatternRewriter &rewriter,
mlir::Type type = nullptr);

// A universal get method for offsets or shapes or strides (OSS) of
// xetile::InitTileOp and xegpu::CreateNdDescOp op.
// OSS (Offsets, Shapes, Strides) information provided
// to InitTileOp & CreateNdDescOp is multifaceted. In other words oss info
// provided to InitTileOp & CreateNdDescOp in multiple ways, especially the
// shapes and strides:
// 1. For static memrefs: the shapes and strides info are inherent in the
// memref data type

// 2. For dynamic memrefs and i64/i32 source: the shapes and strides info is
// provided via the operands `sizes` and `strides` repectively, however these
// operands can also take two different types:

// 2.1 Constant type: constant attribute can be passed
// 2.2 Value type: a value type can be passed

// This function collects these info based on different scenarios and returns
// them in Value types.

// One can pass the result of getMixedOffsets(), getMixedSizes(),
// getMixedStrides() to the following utility to get them as Value types.
// Since both xetile::InitTileOp and xegpu::CreateNdDescOp ops implement the
// OffsetSizeAndStrideOpInterface, getMixedOffsets(), getMixedSizes(),
// getMixedStrides() takes care of the different scenarios mentioned above.

llvm::SmallVector<mlir::Value> getStridesOrOffsetsOrShapesInValueType(
mlir::PatternRewriter &rewriter,
::llvm::SmallVector<mlir::OpFoldResult> mixedOSS, mlir::Location loc);

} // namespace imex

#endif
73 changes: 13 additions & 60 deletions lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,26 @@
///
//===----------------------------------------------------------------------===//

#include <imex/Conversion/XeGPUToVC/XeGPUToVC.h>

#include "imex/Conversion/XeGPUToVC/XeGPUToVC.h"
#include "imex/Conversion/ArithToVC/ArithToVC.h"
#include "imex/Conversion/MathToVC/MathToVC.h"
#include "imex/Utils/VCUtils.h"
#include "imex/Utils/XeCommon.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

#include "imex/Conversion/ArithToVC/ArithToVC.h"
#include "imex/Conversion/MathToVC/MathToVC.h"
#include "imex/Utils/VCUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -71,23 +69,6 @@ static bool isOneOrUnknow(OpFoldResult ofr) {
return !val || *val == 1;
}

// convert OpFoldResult to Value by replacing integer
// attributes with arith::ConstantOps. It also performs
// simple type conversions
static Value getValueOrConstantOp(OpFoldResult ofr, Location loc,
ConversionPatternRewriter &rewriter,
Type type = nullptr) {
if (ofr.is<Value>())
return ofr.get<Value>();

auto intAttr = cast<IntegerAttr>(ofr.get<Attribute>());

if (type)
intAttr = IntegerAttr::get(type, intAttr.getInt());

return rewriter.create<arith::ConstantOp>(loc, intAttr);
}

static Value castValueTo(Value val, Type toType, Location loc,
ConversionPatternRewriter &rewriter) {

Expand Down Expand Up @@ -120,36 +101,6 @@ static Value castValueTo(Value val, Type toType, Location loc,
#define addi(a, b) rewriter.createOrFold<arith::AddIOp>(loc, a, b)
#define subi(a, b) rewriter.createOrFold<arith::SubIOp>(loc, a, b)

// A universal get method for strides of CreateNdDescOp op

// Get the effective strides of the CreateNdDescOp op
// Stride information provided to CreateNdDescOp is multifaceted
// In other words strides info provided to CreateNdDescOp in multiple ways:
// 1. For static memrefs: the strides info are inherent in the memref data type
// 2. For dynamic memrefs and i64/i32 source: the strides info is provided via
// the operand `strides`, however the strides operand can also take two
// different types:
// 2.1 Constant type: constant attribute can be passed
// 2.2 Value type: a value type can be passed
// This function collects these info based on different scenarios and returns
// them in Value types.

// We use getMixedStrides() to collect the strides info instead of handling
// aforementioned cases manually, since, it already uses
// OffsetSizeAndStrideOpInterface, getMixedStrides() already takes care of
// memref type.
SmallVector<Value> getStridesInValueType(ConversionPatternRewriter &rewriter,
CreateNdDescOp op) {
SmallVector<Value> stridesVal;
auto mixedStrides = op.getMixedStrides();
for (size_t i = 0; i < mixedStrides.size(); i++) {
auto stride =
getValueOrConstantOp(mixedStrides[i], op.getLoc(), rewriter, indexTy);
stridesVal.push_back(stride);
}
return stridesVal;
}

// Given an n-dim memref, a tensor descriptor with tile rank of 2 defines a
// 2d memory region with respect to the two inner-most dimensions. Other
// outer dimensions affect the base address of the 2d plane. For 2d, we
Expand Down Expand Up @@ -177,9 +128,11 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,
auto tileRank = op.getTensorDesc().getType().getRank();
auto offsets = op.getMixedOffsets();

auto strides = getStridesInValueType(rewriter, op);
auto strides = getStridesOrOffsetsOrShapesInValueType(
rewriter, op.getMixedStrides(), loc);

// Calculate the effective rank of the source based on strides arrayref size
// Calculate the effective rank of the source based on strides arrayref
// size
auto effectiveRank = strides.size();
int64_t ranksToAdjust = effectiveRank;
auto bytesPerElem =
Expand Down
34 changes: 27 additions & 7 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

#include "XeTileOpConversion.h"
#include "imex/Utils/XeArch.h"
#include "imex/Utils/XeCommon.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h>
Expand All @@ -28,7 +30,7 @@
#include <mlir/IR/BuiltinTypeInterfaces.h>

namespace imex {

using namespace mlir;
using mlir::vector::CreateMaskOp;
using mlir::vector::ExtractOp;
using mlir::vector::ExtractStridedSliceOp;
Expand Down Expand Up @@ -479,14 +481,32 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
tDescOffsets.push_back(tDescOffsetX);
tDescOffsets.push_back(tDescOffsetY);

// TODO: this needs improvement, it assumes the source is static
// memeref.
// Handle memref source.
if (auto MemRefTypedSource =
mlir::cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
// Hnadle the case where the shape is static.
if (MemRefTypedSource.getType().hasStaticShape()) {
auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
op.getLoc(), tDescTy /*resultTy*/,
MemRefTypedSource
/*source*/,
tDescOffsets /*offsets*/);

xegpuOps[blocks[1] * i + j] = createNdOp;
} else {
// Handle the case where the shape is dynamic.
auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
loc, tDescTy, MemRefTypedSource, tDescOffsets,
op.getMixedSizes(), op.getMixedStrides());
xegpuOps[blocks[1] * i + j] = createNdOp;
}
} else if (auto intSourceType =
mlir::dyn_cast<mlir::TypedValue<mlir::IntegerType>>(
source)) {
// Handle the case where the source is an integer.
auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
op.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
tDescOffsets /*offsets*/);

loc, tDescTy, intSourceType, tDescOffsets, op.getMixedSizes(),
op.getMixedStrides());
xegpuOps[blocks[1] * i + j] = createNdOp;
} else {
return mlir::failure();
Expand Down
30 changes: 30 additions & 0 deletions lib/Utils/XeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,34 @@ bool isVectorAnyINTELType(mlir::Type type) {
spirvSupportedSizes.end());
}

// convert OpFoldResult to Value by replacing integer
// attributes with arith::ConstantOps. It also performs
// simple type conversions
mlir::Value getValueOrConstantOp(mlir::OpFoldResult ofr, mlir::Location loc,
mlir::PatternRewriter &rewriter,
mlir::Type type) {
if (ofr.is<mlir::Value>())
return ofr.get<mlir::Value>();

auto intAttr = llvm::cast<mlir::IntegerAttr>(ofr.get<mlir::Attribute>());

if (type)
intAttr = mlir::IntegerAttr::get(type, intAttr.getInt());

return rewriter.create<mlir::arith::ConstantOp>(loc, intAttr);
}

llvm::SmallVector<mlir::Value> getStridesOrOffsetsOrShapesInValueType(
mlir::PatternRewriter &rewriter,
::llvm::SmallVector<mlir::OpFoldResult> mixedOSS, mlir::Location loc) {
llvm::SmallVector<mlir::Value> valueVec;
// auto mixedStrides = op.getMixedStrides();
for (size_t i = 0; i < mixedOSS.size(); i++) {
auto oss = getValueOrConstantOp(mixedOSS[i], loc, rewriter,
rewriter.getIndexType());
valueVec.push_back(oss);
}
return valueVec;
}

} // namespace imex
35 changes: 35 additions & 0 deletions test/Conversion/XeTileToXeGPU/sg_init_tile.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \
// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s

gpu.module @test_kernel {
//CHECK: gpu.func @sg_init_tile(%[[arg0:.*]]: memref<1024x1024xf32>, %[[arg1:.*]]: memref<?x?xf32>) {
gpu.func @sg_init_tile(%a: memref<1024x1024xf32>, %b: memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index

%result = arith.constant dense<0.0>: vector<32x32xf32>

// CHECK: %[[SRC_AS_INDEX:.*]] = memref.extract_aligned_pointer_as_index {{.*}} : memref<1024x1024xf32> -> index
%src_as_index = memref.extract_aligned_pointer_as_index %a : memref<1024x1024xf32> -> index
// CHECK-NEXT: %[[SRC_AS_INT:.*]] = arith.index_cast %[[SRC_AS_INDEX]] : index to i64
%src_as_int = arith.index_cast %src_as_index : index to i64

//CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%static_memref_src = xetile.init_tile %a[0, 32] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32>

// CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[arg1]][{{.*}}], [%c1024, %c1024], [%c1024, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%dynamic_memref_src = xetile.init_tile %b[%c0, %c32],[%c1024, %c1024], [%c1024, %c1] : memref<?x?xf32> -> !xetile.tile<32x32xf32>

// CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[SRC_AS_INT]][{{.*}}], [%c1024, %c1024], [%c1024, %c1] : i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%int_src = xetile.init_tile %src_as_int[%c0, %c32], [%c1024, %c1024], [%c1024, %c1] : i64 -> !xetile.tile<32x32xf32>

//CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
xetile.store_tile %result, %static_memref_src: vector<32x32xf32>, !xetile.tile<32x32xf32>
xetile.store_tile %result, %dynamic_memref_src: vector<32x32xf32>, !xetile.tile<32x32xf32>
xetile.store_tile %result, %int_src: vector<32x32xf32>, !xetile.tile<32x32xf32>

gpu.return
}
}
Loading

0 comments on commit 9f495a6

Please sign in to comment.