diff --git a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
index 61fb8efc4..01e2a053d 100644
--- a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
+++ b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
@@ -645,7 +645,6 @@ class VectorShapeCastPattern : public OpConversionPattern {
   matchAndRewrite(ShapeCastOp shapeCastOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto *converter = getTypeConverter();
-
     Type dstType = converter->convertType(shapeCastOp.getType());
 
     if (!dstType)
@@ -661,6 +660,22 @@ class VectorShapeCastPattern : public OpConversionPattern {
   }
 };
 
+template
+class IndexCastPattern : public OpConversionPattern {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpTy indexCastOp, typename OpTy::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = OpConversionPattern::getTypeConverter();
+    Type dstType = converter->convertType(indexCastOp.getType());
+    if (!dstType)
+      return failure();
+    rewriter.replaceOpWithNewOp(indexCastOp, dstType, adaptor.getIn());
+    return success();
+  }
+};
+
 class SCFForPattern : public OpConversionPattern {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -875,6 +890,14 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase {
     target.addDynamicallyLegalDialect(
         [&](Operation *op) { return isLegalXeGPUSCFOp(op, typeConverter); });
 
+    target.addDynamicallyLegalOp(
+        [&](Operation *op) {
+          if (auto vecTy = dyn_cast(op->getResult(0).getType())) {
+            return typeConverter.isLegal(vecTy);
+          }
+          return true;
+        });
+
     target.addDynamicallyLegalOp([&](arith::MaximumFOp op) {
       if (auto vecTy = dyn_cast(op.getType())) {
         if (vecTy.getRank() != 1)
@@ -921,16 +944,6 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase {
       unsigned rank = type.getRank();
       auto elemType = type.getElementType();
 
-      if (llvm::isa(elemType))
-        elemType = IntegerType::get(&getContext(), 64);
-
-      auto scalarType = llvm::dyn_cast_or_null(elemType);
-      if (!scalarType && !elemType.isBF16()) {
-        llvm::dbgs() << type
-                     << " illegal: cannot convert non-scalar element type\n";
-        return nullptr;
-      }
-
       if (rank < 1 || type.getNumElements() == 1)
         return elemType;
 
@@ -951,6 +964,9 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase {
 
     patterns.add(typeConverter,
                  patterns.getContext());
+    patterns.add,
+                 IndexCastPattern>(typeConverter,
+                                   patterns.getContext());
 
     // Ops to llvm.genx only Patterns
     patterns.add, DpasPattern,
diff --git a/test/Conversion/XeGPUToVC/nd-ops.mlir b/test/Conversion/XeGPUToVC/nd-ops.mlir
new file mode 100644
index 000000000..010fb18b0
--- /dev/null
+++ b/test/Conversion/XeGPUToVC/nd-ops.mlir
@@ -0,0 +1,36 @@
+// Tests ops on nd vectors that should be linearized.
+
+// RUN: imex-opt -convert-xegpu-to-vc -split-input-file %s | FileCheck %s --check-prefixes=CHECK
+module @gemm attributes {gpu.container_module} {
+  gpu.module @test_kernel {
+
+    // CHECK-LABEL: gpu.func @test_index_cast
+    // CHECK: %[[c1024:.*]] = arith.constant 1024 : i32
+    // CHECK: %[[bid:.*]] = gpu.block_id x
+    // CHECK: %[[cst:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+    // CHECK: %[[r0:.*]] = arith.index_cast %[[bid]] : index to i32
+    // CHECK: %[[r1:.*]] = arith.muli %[[r0]], %[[c1024]] : i32
+    // CHECK: %[[r2:.*]] = vector.splat %[[r1]] : vector<16xi32>
+    // CHECK: %[[r3:.*]] = arith.addi %[[r2]], %[[r2]] : vector<16xi32>
+    // CHECK: %[[r4:.*]] = arith.addi %[[r2]], %[[cst]] : vector<16xi32>
+    // CHECK: %[[r5:.*]] = arith.index_cast %[[r3]] : vector<16xi32> to vector<16xindex>
+    // CHECK: %[[r6:.*]] = arith.index_cast %[[r4]] : vector<16xi32> to vector<16xindex>
+    // CHECK-NEXT: gpu.return
+    gpu.func @test_index_cast() kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
+      %c1024_i32 = arith.constant 1024 : i32
+      %block_id_x = gpu.block_id x
+      %cst_0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+      %23 = arith.index_cast %block_id_x : index to i32
+      %24 = arith.muli %23, %c1024_i32 : i32
+      %25 = vector.splat %24 : vector<16xi32>
+      %26 = arith.addi %25, %25 : vector<16xi32>
+      %27 = vector.shape_cast %26 : vector<16xi32> to vector<1x16xi32>
+      %28 = arith.addi %25, %cst_0 : vector<16xi32>
+      %29 = vector.shape_cast %28 : vector<16xi32> to vector<1x16xi32>
+      %30 = arith.index_cast %27 : vector<1x16xi32> to vector<1x16xindex>
+      %31 = arith.index_cast %29 : vector<1x16xi32> to vector<1x16xindex>
+
+      gpu.return
+    }
+  }
+}