Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XeGPUToVC] Fix 'offset' computation for 'base address + offset' calc… #992

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/imex/Utils/VCUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef VC_UTILS_H
#define VC_UTILS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -64,6 +65,11 @@ using namespace mlir;
#define dense_vector_val(attr, vecTy) \
rewriter.create<arith::ConstantOp>(loc, DenseElementsAttr::get(vecTy, attr))

// Integer arithmetic shorthands. Each one expands to a `createOrFold` call and
// therefore requires variables named `rewriter` and `loc` to be in scope at
// the point of use. Note that `divi` builds a *signed* division
// (arith::DivSIOp).
#define divi(a, b) rewriter.createOrFold<arith::DivSIOp>(loc, a, b)
#define muli(a, b) rewriter.createOrFold<arith::MulIOp>(loc, a, b)
#define addi(a, b) rewriter.createOrFold<arith::AddIOp>(loc, a, b)
#define subi(a, b) rewriter.createOrFold<arith::SubIOp>(loc, a, b)

/// This function adds necessary Func Declaration for Imported VC-intrinsics
/// functions and sets linkage attributes to those declaration
/// to support SPIRV compilation
Expand All @@ -78,4 +84,10 @@ func::CallOp createFuncCall(PatternRewriter &rewriter, Location loc,
StringRef funcName, TypeRange resultType,
ValueRange operands, bool emitCInterface);

/// Scales a scalar `offset`, given in number of elements, to a byte count.
/// `eTyBitWidth` is the element width in bits. For widths of 8 bits or more
/// the offset is multiplied by `eTyBitWidth / 8`; for narrower elements it is
/// multiplied by the bit width and then divided by 8 (signed division).
/// The constants are created with type `addrTy`, so `offset` is presumably
/// expected to already have that type -- confirm against callers.
Value getOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
Type addrTy, Value offset, unsigned eTyBitWidth);

/// Vector counterpart of the scalar element-to-byte offset conversion.
/// Scales `offset`, given in number of elements, to byte counts using dense
/// constant vectors of `vecSize` lanes with element type `addrTy`.
/// `eTyBitWidth` is the element width in bits. For widths of 8 bits or more
/// the offset is multiplied by `eTyBitWidth / 8`; for narrower elements it is
/// multiplied by the bit width and then divided by 8 (signed division).
/// `offset` is presumably a vector of `vecSize` x `addrTy` -- confirm against
/// callers.
Value getVecOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
unsigned vecSize, Type addrTy, Value offset,
unsigned eTyBitWidth);
#endif // VC_UTILS_H
3 changes: 1 addition & 2 deletions lib/Conversion/XeGPUToVC/LSCPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,9 +727,8 @@ auto getElemBitWidth = [](TensorDescType tdescTy) -> unsigned {
};

auto isLowPrecision = [](TensorDescType tdescTy) -> bool {
// Note: Handling for sub 8bit types is unclear so report as false
auto width = getElemBitWidth(tdescTy);
return width < 32 && width >= 8;
return width < 32 && width >= 4;
};

auto getScaled1DTdesc =
Expand Down
65 changes: 37 additions & 28 deletions lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ static Value castValueTo(Value val, Type toType, Location loc,
return val;
}

#define muli(a, b) rewriter.createOrFold<arith::MulIOp>(loc, a, b)
#define addi(a, b) rewriter.createOrFold<arith::AddIOp>(loc, a, b)
#define subi(a, b) rewriter.createOrFold<arith::SubIOp>(loc, a, b)

// Given an n-dim memref, a tensor descriptor with tile rank of 2 defines a
// 2d memory region with respect to the two inner-most dimensions. Other
// outer dimensions affect the base address of the 2d plane. For 2d, we
Expand Down Expand Up @@ -135,8 +131,10 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,
// size
auto effectiveRank = strides.size();
int64_t ranksToAdjust = effectiveRank;
auto bytesPerElem =
op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth() / 8;
auto eTyBitWidth =
op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
auto bytesPerElem = eTyBitWidth / 8;
Value eTyBitWidthVal = index_val(eTyBitWidth);
Value bytesPerElemVal = index_val(bytesPerElem);

// We only need combine ranks that are larger than tileRank (e.g., if we the
Expand All @@ -147,15 +145,21 @@ static Value adjustBasePointer(ConversionPatternRewriter &rewriter,

auto computeBase = [&](Value base) {
for (auto i = 0; i < ranksToAdjust; i++) {
auto factor = muli(strides[i], bytesPerElemVal);
Value factor;
Value offsetVal;
if (eTyBitWidth < 8)
factor = muli(strides[i], eTyBitWidthVal);
else
factor = muli(strides[i], bytesPerElemVal);
if (offsets[i].is<Value>()) {
offsetVal = offsets[i].get<Value>();
} else {
offsetVal = index_val(
llvm::cast<IntegerAttr>(offsets[i].get<Attribute>()).getInt());
}
auto linearOffset = muli(offsetVal, factor);
if (eTyBitWidth < 8)
linearOffset = divi(linearOffset, index_val(8));
base = addi(base, linearOffset);
}

Expand All @@ -176,7 +180,7 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
auto tdescTy = op.getType();
auto scope = tdescTy.getMemorySpace();
auto rank = tdescTy.getRank();
auto elemBytes = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
auto eTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();

// SLM has to use 32-bit address, while ugm needs to use 64-bit address.
auto addrTy =
Expand Down Expand Up @@ -209,12 +213,13 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
auto payloadTy = VectorType::get(simd_lanes, addrTy);

// adjust base address to get absolute offset in unit of bytes.
// the computation is simply: base + linearOffset * elemBytes
// the computation is simply: base + linearOffset (in bytes)
Value offset =
getValueOrConstantOp(op.getMixedOffsets().back(), loc, rewriter);
offset = castValueTo(offset, addrTy, loc, rewriter);
Value factor = integer_val(elemBytes, addrTy);
auto payload = addi(base, muli(offset, factor));
Value numOffsetBytes =
getOffsetInUnitOfBytes(rewriter, loc, addrTy, offset, eTyBitWidth);
auto payload = addi(base, numOffsetBytes);

// convert the payload into vector type
payload = rewriter.create<vector::BroadcastOp>(loc, payloadTy, payload);
Expand Down Expand Up @@ -251,12 +256,15 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
if (v) {
auto value = ofr.get<Value>();
value = rewriter.create<arith::IndexCastUIOp>(loc, i32Ty, value);
if (mul > 1)
value = rewriter.create<arith::MulIOp>(loc, value, i32_val(mul));
if (mul > 8)
value =
rewriter.create<arith::MulIOp>(loc, value, i32_val(mul / 8));
else if (mul >= 4)
value = getOffsetInUnitOfBytes(rewriter, loc, i32Ty, value, mul);
return (!minus) ? value : subi(value, i32_val(minus));
} else {
int value = cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
return i32_val(value * mul - minus);
return i32_val(((value * mul) / 8) - minus);
}
};

Expand All @@ -267,8 +275,9 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
// is in rows.
auto matrixShape = op.getMixedSizes();
auto size = matrixShape.size();
auto surfaceW = encodeShapeAndOffset(matrixShape[size - 1], elemBytes, 1);
auto surfaceH = encodeShapeAndOffset(matrixShape[size - 2], 1, 1);
auto surfaceW =
encodeShapeAndOffset(matrixShape[size - 1], eTyBitWidth, 1);
auto surfaceH = encodeShapeAndOffset(matrixShape[size - 2], 8, 1);

// encode the pitch, which is in bytes minus 1
auto matrixStrides = op.getMixedStrides();
Expand All @@ -279,16 +288,16 @@ class CreateNdDescPattern : public OpConversionPattern<CreateNdDescOp> {
assert(isOneOrUnknow(matrixStrides[size - 1]) &&
"Fast Changing Dimension can only have stride of 1.");
auto surfaceP =
encodeShapeAndOffset(matrixStrides[size - 2], elemBytes, 1);
encodeShapeAndOffset(matrixStrides[size - 2], eTyBitWidth, 1);

payload = rewriter.create<vector::InsertOp>(loc, surfaceW, payload, 2);
payload = rewriter.create<vector::InsertOp>(loc, surfaceH, payload, 3);
payload = rewriter.create<vector::InsertOp>(loc, surfaceP, payload, 4);

// encode the offset, they are in elements
auto offsets = op.getMixedOffsets();
auto offsetX = encodeShapeAndOffset(offsets[size - 1], 1, 0);
auto offsetY = encodeShapeAndOffset(offsets[size - 2], 1, 0);
auto offsetX = encodeShapeAndOffset(offsets[size - 1], 8, 0);
auto offsetY = encodeShapeAndOffset(offsets[size - 2], 8, 0);
payload = rewriter.create<vector::InsertOp>(loc, offsetX, payload, 5);
payload = rewriter.create<vector::InsertOp>(loc, offsetY, payload, 6);

Expand Down Expand Up @@ -337,11 +346,11 @@ class UpdateNDOffsetPattern : public OpConversionPattern<UpdateNdOffsetOp> {
}

// update offset from unit of elements to unit of bytes
auto elemBytes = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
auto factor = integer_val(elemBytes, addrTy);
auto offset = getValueOrConstantOp(offsets.back(), loc, rewriter);
offset = castValueTo(offset, addrTy, loc, rewriter);
offset = muli(offset, factor);
auto eTyBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
offset =
getOffsetInUnitOfBytes(rewriter, loc, addrTy, offset, eTyBitWidth);

// convert offset to vector type and update the payload
const int simd_lanes = 1;
Expand Down Expand Up @@ -409,10 +418,10 @@ class CreateDescPattern : public OpConversionPattern<CreateDescOp> {
auto payloadTy = vecTy(simd_lanes, addrTy);

// offset is represented in number of elements, need to scale it to bytes
auto elemBytes = elemTy.getIntOrFloatBitWidth() / 8;
auto factor = dense_vector_int_val(elemBytes, addrTy, simd_lanes);
auto eTyBitWidth = elemTy.getIntOrFloatBitWidth();
Value offsets = castValueTo(adaptor.getOffsets(), payloadTy, loc, rewriter);
offsets = muli(factor, offsets);
offsets = getVecOffsetInUnitOfBytes(rewriter, loc, simd_lanes, addrTy,
offsets, eTyBitWidth);

// create a payload with the base address broadcasted to all simd lanes
Value payload = rewriter.create<vector::BroadcastOp>(loc, payloadTy, base);
Expand Down Expand Up @@ -444,10 +453,10 @@ class UpdateOffsetOpPattern : public OpConversionPattern<UpdateOffsetOp> {
auto simd_lanes = tdescTy.getShape()[0];
auto payloadTy = VectorType::get(simd_lanes, addrTy);

auto elemBytes = elemTy.getIntOrFloatBitWidth() / 8;
Value factor = dense_vector_int_val(elemBytes, addrTy, simd_lanes);
auto eTyBitWidth = elemTy.getIntOrFloatBitWidth();
Value offsets = castValueTo(adaptor.getOffsets(), payloadTy, loc, rewriter);
offsets = muli(factor, offsets);
offsets = getVecOffsetInUnitOfBytes(rewriter, loc, simd_lanes, addrTy,
offsets, eTyBitWidth);

auto payload = addi(adaptor.getTensorDesc(), offsets);
rewriter.replaceOp(op, payload);
Expand Down
27 changes: 27 additions & 0 deletions lib/Utils/VCUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,30 @@ func::CallOp createFuncCall(PatternRewriter &rewriter, Location loc,
true /*isVectorComputeFunctionINTEL=true*/, emitCInterface);
return rewriter.create<func::CallOp>(loc, fn, resultType, operands);
}

/// Converts a scalar `offset` counted in elements into an offset counted in
/// bytes, where each element is `eTyBitWidth` bits wide. All constants are
/// materialized with type `addrTy`.
Value getOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
                             Type addrTy, Value offset, unsigned eTyBitWidth) {
  // Sub-byte elements: work in bits first, then reduce to bytes with a
  // signed division by 8.
  if (eTyBitWidth < 8) {
    Value bitsPerByte = integer_val(8, addrTy);
    Value bitsPerElem = integer_val(eTyBitWidth, addrTy);
    Value offsetInBits = muli(offset, bitsPerElem);
    return divi(offsetInBits, bitsPerByte);
  }

  // Byte-sized or wider elements: a single multiply by the element size.
  unsigned bytesPerElem = eTyBitWidth / 8;
  Value scale = integer_val(bytesPerElem, addrTy);
  return muli(offset, scale);
}

/// Vector form of the element-to-byte offset conversion: scales `offset`
/// (counted in elements of `eTyBitWidth` bits) to bytes, using dense constant
/// vectors of `vecSize` lanes with element type `addrTy`.
Value getVecOffsetInUnitOfBytes(PatternRewriter &rewriter, Location loc,
                                unsigned vecSize, Type addrTy, Value offset,
                                unsigned eTyBitWidth) {
  // Sub-byte elements: work in bits first, then reduce to bytes with a
  // signed division by 8.
  if (eTyBitWidth < 8) {
    Value bitsPerByte = dense_vector_int_val(8, addrTy, vecSize);
    Value bitsPerElem = dense_vector_int_val(eTyBitWidth, addrTy, vecSize);
    Value offsetInBits = muli(offset, bitsPerElem);
    return divi(offsetInBits, bitsPerByte);
  }

  // Byte-sized or wider elements: a single multiply by the element size.
  unsigned bytesPerElem = eTyBitWidth / 8;
  Value scale = dense_vector_int_val(bytesPerElem, addrTy, vecSize);
  return muli(offset, scale);
}
Loading