Skip to content

Commit

Permalink
Add pattern for IndexCastOp in XeGPUToVC pass. (#963)
Browse files Browse the repository at this point in the history
  • Loading branch information
nmostafa authored Nov 16, 2024
1 parent d8a52ff commit aa75f70
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 11 deletions.
38 changes: 27 additions & 11 deletions lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,6 @@ class VectorShapeCastPattern : public OpConversionPattern<ShapeCastOp> {
matchAndRewrite(ShapeCastOp shapeCastOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *converter = getTypeConverter();

Type dstType = converter->convertType(shapeCastOp.getType());

if (!dstType)
Expand All @@ -661,6 +660,22 @@ class VectorShapeCastPattern : public OpConversionPattern<ShapeCastOp> {
}
};

/// Conversion pattern for index-cast style ops (`OpTy` is the concrete op
/// class). The op is re-created as the same op kind, but with the result type
/// that the pattern's type converter produces for the original result type.
/// The match fails when the type converter cannot convert that type.
template <typename OpTy>
class IndexCastPattern : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy indexCastOp, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Ask the type converter what the result type should become.
    auto *typeConv = this->getTypeConverter();
    Type convertedTy = typeConv->convertType(indexCastOp.getType());

    if (convertedTy) {
      // Same op kind, converted result type, fed by the adaptor's
      // (already remapped) input operand.
      rewriter.replaceOpWithNewOp<OpTy>(indexCastOp, convertedTy,
                                        adaptor.getIn());
      return success();
    }

    // No legal target type: leave the op untouched.
    return failure();
  }
};

class SCFForPattern : public OpConversionPattern<ForOp> {
public:
using OpConversionPattern<ForOp>::OpConversionPattern;
Expand Down Expand Up @@ -875,6 +890,14 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
target.addDynamicallyLegalDialect<scf::SCFDialect>(
[&](Operation *op) { return isLegalXeGPUSCFOp(op, typeConverter); });

// Legality rule for arith.index_cast / arith.index_castui:
//  - if the first result is a vector, the op is legal only when the type
//    converter reports that vector type as legal;
//  - any op whose first result is not a vector is always legal.
target.addDynamicallyLegalOp<arith::IndexCastOp, arith::IndexCastUIOp>(
[&](Operation *op) {
if (auto vecTy = dyn_cast<VectorType>(op->getResult(0).getType())) {
return typeConverter.isLegal(vecTy);
}
// Non-vector result: nothing for this pass to rewrite.
return true;
});

target.addDynamicallyLegalOp<arith::MaximumFOp>([&](arith::MaximumFOp op) {
if (auto vecTy = dyn_cast<VectorType>(op.getType())) {
if (vecTy.getRank() != 1)
Expand Down Expand Up @@ -921,16 +944,6 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
unsigned rank = type.getRank();
auto elemType = type.getElementType();

if (llvm::isa<IndexType>(elemType))
elemType = IntegerType::get(&getContext(), 64);

auto scalarType = llvm::dyn_cast_or_null<spirv::ScalarType>(elemType);
if (!scalarType && !elemType.isBF16()) {
llvm::dbgs() << type
<< " illegal: cannot convert non-scalar element type\n";
return nullptr;
}

if (rank < 1 || type.getNumElements() == 1)
return elemType;

Expand All @@ -951,6 +964,9 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
patterns.add<VectorShapeCastPattern, SCFForPattern>(typeConverter,
patterns.getContext());

// Register the index-cast rewrite for both the signed (IndexCastOp) and
// unsigned (IndexCastUIOp) variants, sharing the pass's type converter.
patterns.add<IndexCastPattern<arith::IndexCastOp>,
IndexCastPattern<arith::IndexCastUIOp>>(typeConverter,
patterns.getContext());
// Ops to llvm.genx only Patterns
patterns.add<NbarrierWaitPattern, CompilerHintPattern,
ElementwiseToVCPattern<arith::MaximumFOp>, DpasPattern,
Expand Down
36 changes: 36 additions & 0 deletions test/Conversion/XeGPUToVC/nd-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Tests ops on nd vectors that should be linearized.

// RUN: imex-opt -convert-xegpu-to-vc -split-input-file %s | FileCheck %s --check-prefixes=CHECK
module @gemm attributes {gpu.container_module} {
gpu.module @test_kernel {

// The input below applies arith.index_cast to vector<1x16xi32> values that
// come from vector.shape_cast ops. The expected output has the index casts
// operating on vector<16xi32> and consuming the arith.addi results directly,
// immediately followed by gpu.return.
// CHECK-LABEL: gpu.func @test_index_cast
// CHECK: %[[c1024:.*]] = arith.constant 1024 : i32
// CHECK: %[[bid:.*]] = gpu.block_id x
// CHECK: %[[cst:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
// CHECK: %[[r0:.*]] = arith.index_cast %[[bid]] : index to i32
// CHECK: %[[r1:.*]] = arith.muli %[[r0]], %[[c1024]] : i32
// CHECK: %[[r2:.*]] = vector.splat %[[r1]] : vector<16xi32>
// CHECK: %[[r3:.*]] = arith.addi %[[r2]], %[[r2]] : vector<16xi32>
// CHECK: %[[r4:.*]] = arith.addi %[[r2]], %[[cst]] : vector<16xi32>
// CHECK: %[[r5:.*]] = arith.index_cast %[[r3]] : vector<16xi32> to vector<16xindex>
// CHECK: %[[r6:.*]] = arith.index_cast %[[r4]] : vector<16xi32> to vector<16xindex>
// CHECK-NEXT: gpu.return
gpu.func @test_index_cast() kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
%c1024_i32 = arith.constant 1024 : i32
%block_id_x = gpu.block_id x
%cst_0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
// Scalar index -> i32 cast; the CHECK lines expect this one unchanged.
%23 = arith.index_cast %block_id_x : index to i32
%24 = arith.muli %23, %c1024_i32 : i32
%25 = vector.splat %24 : vector<16xi32>
%26 = arith.addi %25, %25 : vector<16xi32>
// Reshape the 1-D sums to 1x16 so the index casts below see 2-D vectors.
%27 = vector.shape_cast %26 : vector<16xi32> to vector<1x16xi32>
%28 = arith.addi %25, %cst_0 : vector<16xi32>
%29 = vector.shape_cast %28 : vector<16xi32> to vector<1x16xi32>
// 2-D vector index casts: the ops under test.
%30 = arith.index_cast %27 : vector<1x16xi32> to vector<1x16xindex>
%31 = arith.index_cast %29 : vector<1x16xi32> to vector<1x16xindex>

gpu.return
}
}
}

0 comments on commit aa75f70

Please sign in to comment.