Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XeTileToXeGPU] Enable dynamic memref in XeTileToXeGPU lowering. #969

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/imex/Utils/XeCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,42 @@ llvm::SmallVector<int64_t> defaultStrides(llvm::ArrayRef<int64_t> shape);
/// capability).
bool isVectorAnyINTELType(mlir::Type type);

/// Convert an OpFoldResult to a Value.
/// - If `ofr` already holds a Value, that Value is returned unchanged and
///   `type` is not applied to it.
/// - Otherwise `ofr` is expected to hold an integer attribute, which is
///   materialized as an arith::ConstantOp created at `loc` via `rewriter`.
///   When `type` is non-null the constant is rebuilt with that type (a simple
///   type conversion); when it is null the attribute's own type is kept.
mlir::Value getValueOrConstantOp(mlir::OpFoldResult ofr, mlir::Location loc,
mlir::PatternRewriter &rewriter,
mlir::Type type = nullptr);

// A universal getter for the offsets, shapes, or strides (OSS) of
// xetile::InitTileOp and xegpu::CreateNdDescOp.
// OSS information can reach InitTileOp and CreateNdDescOp in several ways,
// which matters most for the shapes and strides:
// 1. For static memrefs: the shapes and strides are inherent in the
// memref data type.

// 2. For dynamic memrefs and i64/i32 sources: the shapes and strides are
// provided via the operands `sizes` and `strides` respectively. These
// operands can themselves take two different forms:

// 2.1 Constant: a constant attribute is passed.
// 2.2 Value: an SSA value is passed.

// This function takes such a mixed list and returns every entry as a Value:
// entries that are already Values are passed through, and constant entries
// are materialized as index-typed arith constants at `loc`.

// Callers can pass the result of getMixedOffsets(), getMixedSizes(), or
// getMixedStrides() to this utility to obtain them as Values.
// Since both xetile::InitTileOp and xegpu::CreateNdDescOp implement the
// OffsetSizeAndStrideOpInterface, getMixedOffsets(), getMixedSizes(), and
// getMixedStrides() already take care of the scenarios listed above.

llvm::SmallVector<mlir::Value> getStridesOrOffsetsOrShapesInValueType(
mlir::PatternRewriter &rewriter,
::llvm::SmallVector<mlir::OpFoldResult> mixedOSS, mlir::Location loc);

} // namespace imex

#endif
73 changes: 13 additions & 60 deletions lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,26 @@
///
//===----------------------------------------------------------------------===//

#include <imex/Conversion/XeGPUToVC/XeGPUToVC.h>

#include "imex/Conversion/XeGPUToVC/XeGPUToVC.h"
#include "imex/Conversion/ArithToVC/ArithToVC.h"
#include "imex/Conversion/MathToVC/MathToVC.h"
#include "imex/Utils/VCUtils.h"
#include "imex/Utils/XeCommon.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

#include "imex/Conversion/ArithToVC/ArithToVC.h"
#include "imex/Conversion/MathToVC/MathToVC.h"
#include "imex/Utils/VCUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -71,23 +69,6 @@ static bool isOneOrUnknow(OpFoldResult ofr) {
return !val || *val == 1;
}

// Turn an OpFoldResult into a Value. A Value is handed back as-is; an
// integer attribute becomes an arith::ConstantOp, re-typed to `type` when
// one is supplied.
static Value getValueOrConstantOp(OpFoldResult ofr, Location loc,
                                  ConversionPatternRewriter &rewriter,
                                  Type type = nullptr) {
  // Nothing to materialize when an SSA value is already present.
  if (ofr.is<Value>())
    return ofr.get<Value>();

  auto originalAttr = cast<IntegerAttr>(ofr.get<Attribute>());

  // Keep the attribute's own type unless the caller requested another one.
  IntegerAttr constantAttr =
      type ? IntegerAttr::get(type, originalAttr.getInt()) : originalAttr;

  return rewriter.create<arith::ConstantOp>(loc, constantAttr);
}

static Value castValueTo(Value val, Type toType, Location loc,
ConversionPatternRewriter &rewriter) {

Expand Down Expand Up @@ -120,36 +101,6 @@ static Value castValueTo(Value val, Type toType, Location loc,
#define addi(a, b) rewriter.createOrFold<arith::AddIOp>(loc, a, b)
#define subi(a, b) rewriter.createOrFold<arith::SubIOp>(loc, a, b)

// Universal getter for the effective strides of a CreateNdDescOp.
//
// Stride information can reach CreateNdDescOp in more than one way:
// 1. Static memrefs: the strides are inherent in the memref data type.
// 2. Dynamic memrefs and i64/i32 sources: the strides come from the `strides`
//    operand, which may itself be given as
//    2.1 a constant attribute, or
//    2.2 an SSA value.
// This helper gathers the strides for all of these scenarios and returns them
// as Values.
//
// getMixedStrides() is used rather than handling each case by hand: it comes
// from OffsetSizeAndStrideOpInterface and already accounts for the memref
// type.
SmallVector<Value> getStridesInValueType(ConversionPatternRewriter &rewriter,
                                         CreateNdDescOp op) {
  SmallVector<Value> result;
  // Each mixed entry is either returned directly (Value) or materialized as
  // an index-typed constant (attribute).
  for (OpFoldResult mixedStride : op.getMixedStrides())
    result.push_back(
        getValueOrConstantOp(mixedStride, op.getLoc(), rewriter, indexTy));
  return result;
}

// Given an n-dim memref, a tensor descriptor with tile rank of 2 defines a
// 2d memory region with respect to the two inner-most dimensions. Other
// outer dimensions affect the base address of the 2d plane. For 2d, we
Expand Down Expand Up @@ -177,9 +128,11 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,
auto tileRank = op.getTensorDesc().getType().getRank();
auto offsets = op.getMixedOffsets();

auto strides = getStridesInValueType(rewriter, op);
auto strides = getStridesOrOffsetsOrShapesInValueType(
rewriter, op.getMixedStrides(), loc);

// Calculate the effective rank of the source based on strides arrayref size
// Calculate the effective rank of the source based on strides arrayref
// size
auto effectiveRank = strides.size();
int64_t ranksToAdjust = effectiveRank;
auto bytesPerElem =
Expand Down
34 changes: 27 additions & 7 deletions lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

#include "XeTileOpConversion.h"
#include "imex/Utils/XeArch.h"
#include "imex/Utils/XeCommon.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h>
Expand All @@ -28,7 +30,7 @@
#include <mlir/IR/BuiltinTypeInterfaces.h>

namespace imex {

using namespace mlir;
using mlir::vector::CreateMaskOp;
using mlir::vector::ExtractOp;
using mlir::vector::ExtractStridedSliceOp;
Expand Down Expand Up @@ -479,14 +481,32 @@ class SgInitTileOpPattern : public XeOneToNConversion<xetile::InitTileOp> {
tDescOffsets.push_back(tDescOffsetX);
tDescOffsets.push_back(tDescOffsetY);

// TODO: this needs improvement, it assumes the source is static
// memeref.
// Handle memref source.
if (auto MemRefTypedSource =
mlir::cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
// Handle the case where the shape is static.
if (MemRefTypedSource.getType().hasStaticShape()) {
auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
op.getLoc(), tDescTy /*resultTy*/,
MemRefTypedSource
/*source*/,
tDescOffsets /*offsets*/);

xegpuOps[blocks[1] * i + j] = createNdOp;
} else {
// Handle the case where the shape is dynamic.
auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
loc, tDescTy, MemRefTypedSource, tDescOffsets,
op.getMixedSizes(), op.getMixedStrides());
xegpuOps[blocks[1] * i + j] = createNdOp;
}
} else if (auto intSourceType =
mlir::dyn_cast<mlir::TypedValue<mlir::IntegerType>>(
source)) {
// Handle the case where the source is an integer.
auto createNdOp = rewriter.create<mlir::xegpu::CreateNdDescOp>(
op.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
tDescOffsets /*offsets*/);

loc, tDescTy, intSourceType, tDescOffsets, op.getMixedSizes(),
op.getMixedStrides());
xegpuOps[blocks[1] * i + j] = createNdOp;
} else {
return mlir::failure();
Expand Down
30 changes: 30 additions & 0 deletions lib/Utils/XeCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,34 @@ bool isVectorAnyINTELType(mlir::Type type) {
spirvSupportedSizes.end());
}

// Turn an OpFoldResult into a Value.
// A Value is handed back as-is (`type` is not applied to it). An integer
// attribute becomes an arith::ConstantOp created at `loc`; if `type` is
// non-null the constant carries that type, otherwise the attribute's own type.
mlir::Value getValueOrConstantOp(mlir::OpFoldResult ofr, mlir::Location loc,
                                 mlir::PatternRewriter &rewriter,
                                 mlir::Type type) {
  // Nothing to materialize when an SSA value is already present.
  if (ofr.is<mlir::Value>())
    return ofr.get<mlir::Value>();

  auto originalAttr =
      llvm::cast<mlir::IntegerAttr>(ofr.get<mlir::Attribute>());

  // Keep the attribute's own type unless the caller requested another one.
  mlir::IntegerAttr constantAttr =
      type ? mlir::IntegerAttr::get(type, originalAttr.getInt())
           : originalAttr;

  return rewriter.create<mlir::arith::ConstantOp>(loc, constantAttr);
}

// Convert a mixed list of offsets, shapes, or strides (OSS) into Values.
// Entries that are already Values are passed through unchanged; constant
// integer attributes are materialized as index-typed arith::ConstantOps at
// `loc` via `rewriter`. The result has one Value per input entry, in the same
// order.
llvm::SmallVector<mlir::Value> getStridesOrOffsetsOrShapesInValueType(
    mlir::PatternRewriter &rewriter,
    ::llvm::SmallVector<mlir::OpFoldResult> mixedOSS, mlir::Location loc) {
  llvm::SmallVector<mlir::Value> valueVec;
  // The output size is known up front, so avoid regrowth while appending.
  valueVec.reserve(mixedOSS.size());

  // Constants are always created with the index type.
  const mlir::Type indexType = rewriter.getIndexType();
  for (const mlir::OpFoldResult &oss : mixedOSS)
    valueVec.push_back(getValueOrConstantOp(oss, loc, rewriter, indexType));

  return valueVec;
}

} // namespace imex
35 changes: 35 additions & 0 deletions test/Conversion/XeTileToXeGPU/sg_init_tile.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: imex-opt --split-input-file --xetile-init-duplicate --xetile-blocking \
// RUN: --cse --convert-xetile-to-xegpu --cse %s -verify-diagnostics -o -| FileCheck %s

// Exercises the lowering of xetile.init_tile to xegpu.create_nd_tdesc for
// three kinds of source: a statically shaped memref, a dynamically shaped
// memref, and a raw i64 address.
gpu.module @test_kernel {
//CHECK: gpu.func @sg_init_tile(%[[arg0:.*]]: memref<1024x1024xf32>, %[[arg1:.*]]: memref<?x?xf32>) {
gpu.func @sg_init_tile(%a: memref<1024x1024xf32>, %b: memref<?x?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index

// Value stored through every tile created below.
%result = arith.constant dense<0.0>: vector<32x32xf32>

// Derive an i64 address from %a; it serves as the integer source further down.
// CHECK: %[[SRC_AS_INDEX:.*]] = memref.extract_aligned_pointer_as_index {{.*}} : memref<1024x1024xf32> -> index
%src_as_index = memref.extract_aligned_pointer_as_index %a : memref<1024x1024xf32> -> index
// CHECK-NEXT: %[[SRC_AS_INT:.*]] = arith.index_cast %[[SRC_AS_INDEX]] : index to i64
%src_as_int = arith.index_cast %src_as_index : index to i64

// Case 1: static memref source. Only offsets are expected on the descriptor.
//CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[arg0]][{{.*}}] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%static_memref_src = xetile.init_tile %a[0, 32] : memref<1024x1024xf32> -> !xetile.tile<32x32xf32>

// Case 2: dynamic memref source. Sizes and strides are expected to be forwarded.
// CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[arg1]][{{.*}}], [%c1024, %c1024], [%c1024, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%dynamic_memref_src = xetile.init_tile %b[%c0, %c32],[%c1024, %c1024], [%c1024, %c1] : memref<?x?xf32> -> !xetile.tile<32x32xf32>

// Case 3: i64 source. Sizes and strides are expected to be forwarded as well.
// CHECK-COUNT-8: {{.*}} = xegpu.create_nd_tdesc %[[SRC_AS_INT]][{{.*}}], [%c1024, %c1024], [%c1024, %c1] : i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
%int_src = xetile.init_tile %src_as_int[%c0, %c32], [%c1024, %c1024], [%c1024, %c1] : i64 -> !xetile.tile<32x32xf32>

// Store through each of the three tiles so that every init_tile has a user.
//CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>>
xetile.store_tile %result, %static_memref_src: vector<32x32xf32>, !xetile.tile<32x32xf32>
xetile.store_tile %result, %dynamic_memref_src: vector<32x32xf32>, !xetile.tile<32x32xf32>
xetile.store_tile %result, %int_src: vector<32x32xf32>, !xetile.tile<32x32xf32>

gpu.return
}
}
Loading
Loading