diff --git a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
index 9e080849d..8fa5b779e 100644
--- a/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
+++ b/lib/Conversion/XeGPUToVC/XeGPUToVC.cpp
@@ -661,22 +661,6 @@ class VectorShapeCastPattern : public OpConversionPattern<vector::ShapeCastOp> {
   }
 };
 
-template <typename OpTy>
-class IndexCastPattern : public OpConversionPattern<OpTy> {
-public:
-  using OpConversionPattern<OpTy>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(OpTy indexCastOp, typename OpTy::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto *converter = OpConversionPattern<OpTy>::getTypeConverter();
-    Type dstType = converter->convertType(indexCastOp.getType());
-    if (!dstType)
-      return failure();
-    rewriter.replaceOpWithNewOp<OpTy>(indexCastOp, dstType, adaptor.getIn());
-    return success();
-  }
-};
-
 class SCFForPattern : public OpConversionPattern<scf::ForOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -823,14 +807,6 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
     target.addDynamicallyLegalDialect<scf::SCFDialect>(
         [&](Operation *op) { return isLegalXeGPUSCFOp(op, typeConverter); });
 
-    target.addDynamicallyLegalOp<arith::IndexCastOp, arith::IndexCastUIOp>(
-        [&](Operation *op) {
-          if (auto vecTy = dyn_cast<VectorType>(op->getResult(0).getType())) {
-            return typeConverter.isLegal(vecTy);
-          }
-          return true;
-        });
-
     target.addIllegalOp(); // TODO: can we change it to addDynamicLegalOp?
 
@@ -883,9 +859,6 @@ struct XeGPUToVCPass : public imex::impl::ConvertXeGPUToVCBase<XeGPUToVCPass> {
     patterns.add(typeConverter, patterns.getContext());
 
-    patterns.add<IndexCastPattern<arith::IndexCastOp>,
-                 IndexCastPattern<arith::IndexCastUIOp>>(typeConverter,
-                                                         patterns.getContext());
     // Ops to llvm.genx only Patterns
     patterns.add(patterns.getContext());
diff --git a/lib/Transforms/VectorLinearize.cpp b/lib/Transforms/VectorLinearize.cpp
index bf357b92c..1040db35e 100644
--- a/lib/Transforms/VectorLinearize.cpp
+++ b/lib/Transforms/VectorLinearize.cpp
@@ -34,33 +34,93 @@ namespace imex {
 
 namespace {
 
-// rewrite arith.constant op in form of vector<1xmxindex> into 1D form
-// (vector<mxindex>)
-struct ArithConstantOpConversion final
+// Cloned from upstream with isLessThanTargetBitWidth check removed.
+struct ConstantOpConversion final
     : public mlir::OpConversionPattern<mlir::arith::ConstantOp> {
-  using mlir::OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
   mlir::LogicalResult
   matchAndRewrite(mlir::arith::ConstantOp constOp, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    auto value = llvm::dyn_cast<mlir::DenseElementsAttr>(constOp.getValue());
-    if (!value || value.getType().getRank() != 2)
-      return mlir::failure();
-    auto type = value.getType();
-    auto shape = type.getShape();
-    auto elemTy = type.getElementType();
-    if (shape[0] != 1 || !elemTy.isIndex())
+    auto resType =
+        getTypeConverter()->convertType<mlir::VectorType>(constOp.getType());
+
+    if (resType.isScalable() &&
+        !mlir::isa<mlir::SplatElementsAttr>(constOp.getValue()))
+      return rewriter.notifyMatchFailure(
+          constOp,
+          "Cannot linearize a constant scalable vector that's not a splat");
+
+    if (!resType)
+      return rewriter.notifyMatchFailure(constOp, "can't convert return type");
+    auto dstElementsAttr =
+        mlir::dyn_cast<mlir::DenseElementsAttr>(constOp.getValue());
+    if (!dstElementsAttr)
+      return rewriter.notifyMatchFailure(constOp, "unsupported attr type");
+
+    dstElementsAttr = dstElementsAttr.reshape(resType);
+    rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(constOp, resType,
+                                                         dstElementsAttr);
+    return mlir::success();
+  }
+};
+
+// Cloned from upstream with isLessThanTargetBitWidth check removed.
+struct VectorizableOpConversion final
+    : public mlir::OpTraitConversionPattern<mlir::OpTrait::Vectorizable> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+  mlir::LogicalResult
+  matchAndRewrite(mlir::Operation *op, llvm::ArrayRef<mlir::Value> operands,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::FailureOr<mlir::Operation *> newOp =
+        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+    if (failed(newOp))
       return mlir::failure();
-    auto newTy = mlir::VectorType::get({shape[1]}, elemTy);
-    value = value.reshape(newTy);
-    auto newOp =
-        rewriter.create<mlir::arith::ConstantOp>(constOp.getLoc(), value);
-    auto castOp = rewriter.create<mlir::vector::ShapeCastOp>(constOp.getLoc(),
-                                                             type, newOp);
-    rewriter.replaceOp(constOp, castOp);
+
+    rewriter.replaceOp(op, (*newOp)->getResults());
     return mlir::success();
   }
 };
 
+// Cloned from upstream with isLessThanTargetBitWidth check removed.
+static void populateVectorLinearizeTypeConversionsAndLegality(
+    mlir::TypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
+    mlir::ConversionTarget &target) {
+
+  typeConverter.addConversion(
+      [](mlir::VectorType type) -> std::optional<mlir::Type> {
+        if (!mlir::vector::isLinearizableVector(type))
+          return type;
+
+        return mlir::VectorType::get(type.getNumElements(),
+                                     type.getElementType(), type.isScalable());
+      });
+
+  auto materializeCast = [](mlir::OpBuilder &builder, mlir::Type type,
+                            mlir::ValueRange inputs,
+                            mlir::Location loc) -> mlir::Value {
+    if (inputs.size() != 1 ||
+        !mlir::isa<mlir::VectorType>(inputs.front().getType()) ||
+        !mlir::isa<mlir::VectorType>(type))
+      return nullptr;
+
+    return builder.create<mlir::vector::ShapeCastOp>(loc, type, inputs.front());
+  };
+  typeConverter.addArgumentMaterialization(materializeCast);
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+  target.markUnknownOpDynamicallyLegal(
+      [=](mlir::Operation *op) -> std::optional<bool> {
+        if ((mlir::isa<mlir::arith::ConstantOp>(op) ||
+             op->hasTrait<mlir::OpTrait::Vectorizable>())) {
+          return typeConverter.isLegal(op);
+        }
+        return std::nullopt;
+      });
+
+  patterns.add<ConstantOpConversion, VectorizableOpConversion>(
+      typeConverter, patterns.getContext());
+}
+
 struct VectorLoadOpConversion final
     : public mlir::OpConversionPattern<mlir::vector::LoadOp> {
   using mlir::OpConversionPattern::OpConversionPattern;
@@ -513,38 +573,19 @@ struct VectorLinearizePass final
           return (op && op.getAggregate().getType().getRank() == 1);
         });
 
-    // borrowed from upstream with hacking for index type. Currently
-    // we only target vector<1xmxindex> to vector<mxindex> conversion. It is
-    // unclear whether others are valid or not; thus they are left untouched.
-    target.addDynamicallyLegalOp<mlir::arith::ConstantOp>(
-        [&](mlir::arith::ConstantOp op) -> bool {
-          auto vecTy = mlir::dyn_cast<mlir::VectorType>(op.getType());
-          if (!vecTy || vecTy.getRank() == 0)
-            return true;
-
-          auto elemTy = vecTy.getElementType();
-          if (elemTy.isIndex()) {
-            if (vecTy.getRank() == 2 && vecTy.getShape()[0] == 1)
-              return false;
-            return true;
-          }
-          return !mlir::vector::isLinearizableVector(vecTy);
-        });
-
-    patterns.add(typeConverter, context);
+                 VectorStoreOpConversion, VectorCreateMaskOpConversion>(
+        typeConverter, context);
 
     // Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes.
     mlir::vector::populateVectorTransposeLoweringPatterns(
         patterns,
         mlir::vector::VectorTransformsOptions().setVectorTransposeLowering(
             mlir::vector::VectorTransposeLowering::Shuffle16x16));
 
-    unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max();
-    mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
-        typeConverter, patterns, target, targetVectBitWidth);
+    populateVectorLinearizeTypeConversionsAndLegality(typeConverter, patterns,
+                                                      target);
 
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns))))
       return signalPassFailure();
diff --git a/test/Transforms/vector-linearize.mlir b/test/Transforms/vector-linearize.mlir
index 17dfe5102..dcfa5217c 100644
--- a/test/Transforms/vector-linearize.mlir
+++ b/test/Transforms/vector-linearize.mlir
@@ -284,3 +284,95 @@ func.func @test_vector_store_load_4x4(%buffer: memref<4x4xf32>) {
   vector.store %0, %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32>
   return
 }
+
+// -----
+// CHECK-LABEL: @test_linearize_index
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex> {
+// CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[T1:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex>
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[CST]] : vector<4xindex>
+// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[T4:.*]] = arith.muli %[[T3]], %[[T0]] : vector<4xi32>
+// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : vector<4xi32> to vector<4xindex>
+// CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<4xindex> to vector<2x2xindex>
+// CHECK: return %[[T6]] : vector<2x2xindex>
+func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>) -> vector<2x2xindex> {
+  %0 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xindex>
+// Arith and math ops are handled in generic way, check some of them
+  %1 = arith.addi %arg0, %0 : vector<2x2xindex>
+  %2 = arith.index_cast %1 : vector<2x2xindex> to vector<2x2xi32>
+  %3 = arith.muli %2, %arg1 : vector<2x2xi32>
+  %4 = arith.index_cast %3 : vector<2x2xi32> to vector<2x2xindex>
+  return %4 : vector<2x2xindex>
+}
+
+// -----
+// CHECK-LABEL: @add_kernel_f32
+// CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<16xindex>
+// CHECK: %[[T0:.*]] = vector.splat %{{.*}} : vector<16xindex>
+// CHECK: %[[T1:.*]] = arith.addi %[[T0]], %[[CST0]] : vector<16xindex>
+// CHECK: %[[T2:.*]] = arith.addi %[[T0]], %[[CST1]] : vector<16xindex>
+// CHECK: %[[T3:.*]] = arith.index_cast %[[T1]] : vector<16xindex> to vector<16xi32>
+// CHECK: %[[T4:.*]] = arith.index_cast %[[T2]] : vector<16xindex> to vector<16xi32>
+// CHECK: %[[T5:.*]] = vector.splat %{{.*}} : vector<16xi32>
+// CHECK: %[[T6:.*]] = arith.addi %[[T5]], %[[T3]] : vector<16xi32>
+// CHECK: %[[T7:.*]] = arith.addi %[[T5]], %[[T4]] : vector<16xi32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T6]] : vector<16xi32> to vector<16xindex>
+// CHECK: %[[T9:.*]] = arith.index_cast %[[T7]] : vector<16xi32> to vector<16xindex>
+gpu.module @add_kernel_f32 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @add_kernel_f32(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %cst = arith.constant dense : vector<16xi1>
+    %c32 = arith.constant 32 : index
+    %c1024_i32 = arith.constant 1024 : i32
+    %cst_0 = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+    %cst_1 = arith.constant dense<[[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : vector<1x16xindex>
+    %thread_id_x = gpu.thread_id x
+    %thread_id_y = gpu.thread_id y
+    %block_dim_y = gpu.block_dim y
+    %0 = arith.muli %thread_id_x, %block_dim_y : index
+    %1 = arith.addi %0, %thread_id_y : index
+    %cast = memref.cast %arg0 : memref<*xf32> to memref
+    %cast_2 = memref.cast %arg1 : memref<*xf32> to memref
+    %cast_3 = memref.cast %arg2 : memref<*xf32> to memref
+    %2 = arith.remsi %1, %c32 : index
+    %3 = arith.muli %2, %c32 : index
+    %4 = vector.splat %3 : vector<1x16xindex>
+    %5 = arith.addi %4, %cst_0 : vector<1x16xindex>
+    %6 = arith.addi %4, %cst_1 : vector<1x16xindex>
+    %7 = arith.index_cast %5 : vector<1x16xindex> to vector<1x16xi32>
+    %8 = arith.index_cast %6 : vector<1x16xindex> to vector<1x16xi32>
+    %block_id_x = gpu.block_id x
+    %9 = arith.index_cast %block_id_x : index to i32
+    %10 = arith.muli %9, %c1024_i32 : i32
+    %11 = vector.splat %10 : vector<1x16xi32>
+    %12 = arith.addi %11, %7 : vector<1x16xi32>
+    %13 = arith.addi %11, %8 : vector<1x16xi32>
+    %14 = arith.index_cast %12 : vector<1x16xi32> to vector<1x16xindex>
+    %15 = arith.index_cast %13 : vector<1x16xi32> to vector<1x16xindex>
+    %16 = vector.shape_cast %14 : vector<1x16xindex> to vector<16xindex>
+    %17 = xegpu.create_tdesc %cast, %16 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %18 = vector.shape_cast %15 : vector<1x16xindex> to vector<16xindex>
+    %19 = xegpu.create_tdesc %cast, %18 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %20 = xegpu.load %17, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %21 = vector.shape_cast %20 : vector<16xf32> to vector<1x16xf32>
+    %22 = xegpu.load %19, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %23 = vector.shape_cast %22 : vector<16xf32> to vector<1x16xf32>
+    %24 = xegpu.create_tdesc %cast_2, %16 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %25 = xegpu.create_tdesc %cast_2, %18 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %26 = xegpu.load %24, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %27 = vector.shape_cast %26 : vector<16xf32> to vector<1x16xf32>
+    %28 = xegpu.load %25, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %29 = vector.shape_cast %28 : vector<16xf32> to vector<1x16xf32>
+    %30 = arith.addf %21, %27 : vector<1x16xf32>
+    %31 = arith.addf %23, %29 : vector<1x16xf32>
+    %32 = xegpu.create_tdesc %cast_3, %16 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %33 = xegpu.create_tdesc %cast_3, %18 : memref, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %34 = vector.shape_cast %30 : vector<1x16xf32> to vector<16xf32>
+    xegpu.store %34, %32, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>
+    %35 = vector.shape_cast %31 : vector<1x16xf32> to vector<16xf32>
+    xegpu.store %35, %33, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>
+    gpu.return
+  }
+}
\ No newline at end of file