From a637f3ec054e20992f44e096a5d53c7179baef24 Mon Sep 17 00:00:00 2001 From: Mahesha S Date: Mon, 23 Dec 2024 18:23:43 +0000 Subject: [PATCH] [XeGPUToVC] Fix 'offset' computation for 'base address + offset' calculation Current implementation of computing 'offset' fails for sub-byte types. This patch generalizes the implementation of 'offset' computation so that it works even for sub-byte types. --- include/imex/Utils/VCUtils.h | 12 +++++ lib/Conversion/XeGPUToVC/LSCPatterns.cpp | 3 +- lib/Conversion/XeGPUToVC/XeGPUToVC.cpp | 65 ++++++++++++++---------- lib/Utils/VCUtils.cpp | 27 ++++++++++ 4 files changed, 77 insertions(+), 30 deletions(-) diff --git a/include/imex/Utils/VCUtils.h b/include/imex/Utils/VCUtils.h index 20ac41811..c78df991c 100644 --- a/include/imex/Utils/VCUtils.h +++ b/include/imex/Utils/VCUtils.h @@ -15,6 +15,7 @@ #ifndef VC_UTILS_H #define VC_UTILS_H +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -64,6 +65,11 @@ using namespace mlir; #define dense_vector_val(attr, vecTy) \ rewriter.create(loc, DenseElementsAttr::get(vecTy, attr)) +#define divi(a, b) rewriter.createOrFold(loc, a, b) +#define muli(a, b) rewriter.createOrFold(loc, a, b) +#define addi(a, b) rewriter.createOrFold(loc, a, b) +#define subi(a, b) rewriter.createOrFold(loc, a, b) + /// This function adds necessary Func Declaration for Imported VC-intrinsics /// functions and sets linkage attributes to those declaration /// to support SPIRV compilation @@ -78,4 +84,10 @@ func::CallOp createFuncCall(PatternRewriter &rewriter, Location loc, StringRef funcName, TypeRange resultType, ValueRange operands, bool emitCInterface); +Value getOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc, + Type addrTy, Value offset, unsigned eTyBitWidth); + +Value getVecOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc, + unsigned vecSize, Type addrTy, Value offset, + unsigned eTyBitWidth); #endif // XEGPU_VC_UTILS_H diff --git a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp index eafb886de..79b155ac1 100644 --- a/lib/Conversion/XeGPUToVC/LSCPatterns.cpp +++ b/lib/Conversion/XeGPUToVC/LSCPatterns.cpp @@ -727,9 +727,8 @@ auto getElemBitWidth = [](TensorDescType tdescTy) -> unsigned { }; auto isLowPrecision = [](TensorDescType tdescTy) -> bool { - // Note: Handling for sub 8bit types is unclear so report as false auto width = getElemBitWidth(tdescTy); - return width < 32 && width >= 8; + return width < 32 && width >= 4; }; auto getScaled1DTdesc = diff --git a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp index 7ce8422cc..ab3e6daae 100644 --- a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp +++ b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp @@ -97,10 +97,6 @@ static Value castValueTo(Value val, Type toType, Location loc, return val; } -#define muli(a, b) rewriter.createOrFold(loc, a, b) -#define addi(a, b) rewriter.createOrFold(loc, a, b) -#define subi(a, b) rewriter.createOrFold(loc, a, b) - // Given an n-dim memref, a tensor descriptor with tile rank of 2 defines a // 2d memory region with respect to the two inner-most dimensions. Other // outer dimensions affect the base address of the 2d plane. For 2d, we @@ -135,8 +131,10 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter, // size auto effectiveRank = strides.size(); int64_t ranksToAdjust = effectiveRank; - auto bytesPerElem = - op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth() / 8; + auto eTyBitWidth = + op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth(); + auto bytesPerElem = eTyBitWidth / 8; + Value eTyBitWidthVal = index_val(eTyBitWidth); Value bytesPerElemVal = index_val(bytesPerElem); // We only need combine ranks that are larger than tileRank (e.g., if we the @@ -147,8 +145,12 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter, auto computeBase = [&](Value base) { for (auto i = 0; i < ranksToAdjust; i++) { - auto factor = muli(strides[i], bytesPerElemVal); + Value factor; Value offsetVal; + if (eTyBitWidth < 8) + factor = muli(strides[i], eTyBitWidthVal); + else + factor = muli(strides[i], bytesPerElemVal); if (offsets[i].is()) { offsetVal = offsets[i].get(); } else { @@ -156,6 +158,8 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter, llvm::cast(offsets[i].get()).getInt()); } auto linearOffset = muli(offsetVal, factor); + if (eTyBitWidth < 8) + linearOffset = divi(linearOffset, index_val(8)); base = addi(base, linearOffset); } @@ -176,7 +180,7 @@ class CreateNdDescPattern : public OpConversionPattern { auto tdescTy = op.getType(); auto scope = tdescTy.getMemorySpace(); auto rank = tdescTy.getRank(); - auto elemBytes = tdescTy.getElementType().getIntOrFloatBitWidth() / 8; + auto eTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth(); // SLM has to use 32-bit address, while ugm needs to use 64-bit address. auto addrTy = @@ -209,12 +213,13 @@ class CreateNdDescPattern : public OpConversionPattern { auto payloadTy = VectorType::get(simd_lanes, addrTy); // adjust base address to get absolute offset in unit of bytes. - // the computation is simply: base + linearOffset * elemBytes + // the computation is simply: base + linearOffset (in bytes) Value offset = getValueOrConstantOp(op.getMixedOffsets().back(), loc, rewriter); offset = castValueTo(offset, addrTy, loc, rewriter); - Value factor = integer_val(elemBytes, addrTy); - auto payload = addi(base, muli(offset, factor)); + Value numOffsetBytes = + getOffsetInUnitOfBytes(rewriter, loc, addrTy, offset, eTyBitWidth); + auto payload = addi(base, numOffsetBytes); // convert the payload into vector type payload = rewriter.create(loc, payloadTy, payload); @@ -251,12 +256,15 @@ class CreateNdDescPattern : public OpConversionPattern { if (v) { auto value = ofr.get(); value = rewriter.create(loc, i32Ty, value); - if (mul > 1) - value = rewriter.create(loc, value, i32_val(mul)); + if (mul > 8) + value = + rewriter.create(loc, value, i32_val(mul / 8)); + else if (mul >= 4) + value = getOffsetInUnitOfBytes(rewriter, loc, i32Ty, value, mul); return (!minus) ? value : subi(value, i32_val(minus)); } else { int value = cast(ofr.get()).getInt(); - return i32_val(value * mul - minus); + return i32_val(((value * mul) / 8) - minus); } }; @@ -267,8 +275,9 @@ class CreateNdDescPattern : public OpConversionPattern { // is in rows. auto matrixShape = op.getMixedSizes(); auto size = matrixShape.size(); - auto surfaceW = encodeShapeAndOffset(matrixShape[size - 1], elemBytes, 1); - auto surfaceH = encodeShapeAndOffset(matrixShape[size - 2], 1, 1); + auto surfaceW = + encodeShapeAndOffset(matrixShape[size - 1], eTyBitWidth, 1); + auto surfaceH = encodeShapeAndOffset(matrixShape[size - 2], 8, 1); // encode the pitch, which is in bytes minus 1 auto matrixStrides = op.getMixedStrides(); @@ -279,7 +288,7 @@ class CreateNdDescPattern : public OpConversionPattern { assert(isOneOrUnknow(matrixStrides[size - 1]) && "Fast Changing Dimension can only have stride of 1."); auto surfaceP = - encodeShapeAndOffset(matrixStrides[size - 2], elemBytes, 1); + encodeShapeAndOffset(matrixStrides[size - 2], eTyBitWidth, 1); payload = rewriter.create(loc, surfaceW, payload, 2); payload = rewriter.create(loc, surfaceH, payload, 3); @@ -287,8 +296,8 @@ class CreateNdDescPattern : public OpConversionPattern { // encode the offset, they are in elements auto offsets = op.getMixedOffsets(); - auto offsetX = encodeShapeAndOffset(offsets[size - 1], 1, 0); - auto offsetY = encodeShapeAndOffset(offsets[size - 2], 1, 0); + auto offsetX = encodeShapeAndOffset(offsets[size - 1], 8, 0); + auto offsetY = encodeShapeAndOffset(offsets[size - 2], 8, 0); payload = rewriter.create(loc, offsetX, payload, 5); payload = rewriter.create(loc, offsetY, payload, 6); @@ -337,11 +346,11 @@ class UpdateNDOffsetPattern : public OpConversionPattern { } // update offset from unit of elements to unit of bytes - auto elemBytes = tdescTy.getElementType().getIntOrFloatBitWidth() / 8; - auto factor = integer_val(elemBytes, addrTy); auto offset = getValueOrConstantOp(offsets.back(), loc, rewriter); offset = castValueTo(offset, addrTy, loc, rewriter); - offset = muli(offset, factor); + auto eTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth(); + offset = + getOffsetInUnitOfBytes(rewriter, loc, addrTy, offset, eTyBitWidth); // convert offset to vector type and update the payload const int simd_lanes = 1; @@ -409,10 +418,10 @@ class CreateDescPattern : public OpConversionPattern { auto payloadTy = vecTy(simd_lanes, addrTy); // offset is represented in number of elements, need to scale it to bytes - auto elemBytes = elemTy.getIntOrFloatBitWidth() / 8; - auto factor = dense_vector_int_val(elemBytes, addrTy, simd_lanes); + auto eTyBitWidth = elemTy.getIntOrFloatBitWidth(); Value offsets = castValueTo(adaptor.getOffsets(), payloadTy, loc, rewriter); - offsets = muli(factor, offsets); + offsets = getVecOffsetInUnitOfBytes(rewriter, loc, simd_lanes, addrTy, + offsets, eTyBitWidth); // create a payload with the base address broadcasted to all simd lanes Value payload = rewriter.create(loc, payloadTy, base); @@ -444,10 +453,10 @@ class UpdateOffsetOpPattern : public OpConversionPattern { auto simd_lanes = tdescTy.getShape()[0]; auto payloadTy = VectorType::get(simd_lanes, addrTy); - auto elemBytes = elemTy.getIntOrFloatBitWidth() / 8; - Value factor = dense_vector_int_val(elemBytes, addrTy, simd_lanes); + auto eTyBitWidth = elemTy.getIntOrFloatBitWidth(); Value offsets = castValueTo(adaptor.getOffsets(), payloadTy, loc, rewriter); - offsets = muli(factor, offsets); + offsets = getVecOffsetInUnitOfBytes(rewriter, loc, simd_lanes, addrTy, + offsets, eTyBitWidth); auto payload = addi(adaptor.getTensorDesc(), offsets); rewriter.replaceOp(op, payload); diff --git a/lib/Utils/VCUtils.cpp b/lib/Utils/VCUtils.cpp index f2c4f2c2f..b5cedc2b6 100644 --- a/lib/Utils/VCUtils.cpp +++ b/lib/Utils/VCUtils.cpp @@ -77,3 +77,30 @@ func::CallOp createFuncCall(PatternRewriter &rewriter, Location loc, true /*isVectorComputeFunctionINTEL=true*/, emitCInterface); return rewriter.create(loc, fn, resultType, operands); } + +Value getOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc, + Type addrTy, Value offset, unsigned eTyBitWidth) { + if (eTyBitWidth >= 8) { + unsigned eTyBytes = eTyBitWidth / 8; + Value factor = integer_val(eTyBytes, addrTy); + return muli(offset, factor); + } else { + Value eight = integer_val(8, addrTy); + Value bw = integer_val(eTyBitWidth, addrTy); + return divi(muli(offset, bw), eight); + } +} + +Value getVecOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc, + unsigned vecSize, Type addrTy, Value offset, + unsigned eTyBitWidth) { + if (eTyBitWidth >= 8) { + unsigned eTyBytes = eTyBitWidth / 8; + Value factor = dense_vector_int_val(eTyBytes, addrTy, vecSize); + return muli(offset, factor); + } else { + Value eight = dense_vector_int_val(8, addrTy, vecSize); + Value bw = dense_vector_int_val(eTyBitWidth, addrTy, vecSize); + return divi(muli(offset, bw), eight); + } +}