From 7f1f3e681b0c7a3ba947ff6da676211d6fcab5c3 Mon Sep 17 00:00:00 2001 From: James Lin Date: Tue, 7 May 2024 11:08:16 -0500 Subject: [PATCH] [aievec] to-llvm flow for aievec.broadcast_scalar op (#1446) * This PR adds support for the aievec.broadcast_scalar op going through the to-llvm flow. * Add aievec-to-llvm conversion pattern/tests for the aievec.broadcast_scalar op. * Add op verifier for the aievec.broadcast_scalar op. * Add target external llvm translation tests. --- include/aie/Dialect/AIEVec/IR/AIEVecOps.td | 2 +- .../aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td | 16 ++++++++-- lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp | 22 ++++++++++--- lib/Dialect/AIEVec/IR/AIEVecOps.cpp | 10 ++++-- .../AIEVecToLLVM/broadcast_scalar.mlir | 31 ++++++++++++++++++- test/Target/LLVMIR/aievec.mlir | 16 ++++++++++ 6 files changed, 85 insertions(+), 12 deletions(-) diff --git a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td index 0350376c04..86414ff6cf 100644 --- a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td +++ b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td @@ -352,7 +352,7 @@ def AIEVec_BroadcastScalarOp: AIEVec_Op<"broadcast_scalar", [ Pure ]>, - Arguments<(ins AnyType:$source)>, + Arguments<(ins AnyTypeOf<[BF16, F32, I32, I16, I8]>:$source)>, Results<(outs AnyVector:$result)> { let summary = "AIE-ML broadcast scalar"; let description = [{ diff --git a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td index 5dbab4c89f..b288539aac 100644 --- a/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td +++ b/include/aie/Dialect/XLLVM/IR/XLLVMAIE2IntrOps.td @@ -134,17 +134,27 @@ def Vector16AccFloatToV16BF16IntrOp : def VectorBroadcast8I512IntrOp : AIEVec2_IntrOp<"vbroadcast8.I512", [TypeIs<"res", VectorOfLengthAndType<[64], [I8]>>]>, - Arguments<(ins I32:$value)>; + Arguments<(ins I32:$src)>; + +def VectorBroadcast16I512IntrOp : + AIEVec2_IntrOp<"vbroadcast16.I512", + [TypeIs<"res", 
VectorOfLengthAndType<[32], [I16]>>]>, + Arguments<(ins I32:$src)>; def VectorBroadcast32I512IntrOp : AIEVec2_IntrOp<"vbroadcast32.I512", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, - Arguments<(ins I32:$value)>; + Arguments<(ins I32:$src)>; def VectorBroadcast16BF512IntrOp : AIEVec2_IntrOp<"vbroadcast16.bf512", [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>, - Arguments<(ins BF16:$value)>; + Arguments<(ins BF16:$src)>; + +def VectorBroadcastfloatI512IntrOp : + AIEVec2_IntrOp<"vbroadcastfloat.I512", + [TypeIs<"res", VectorOfLengthAndType<[16], [F32]>>]>, + Arguments<(ins F32:$src)>; // ----- EXT ----- diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp index 2d4c4ae32a..7914c9c80e 100644 --- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp +++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp @@ -1249,6 +1249,15 @@ class BroadcastScalarOpConversion VectorType resultType = cast(result.getType()); Type resultScaTy = resultType.getElementType(); unsigned resultBitWidth = resultScaTy.getIntOrFloatBitWidth(); + int resultLanes = getVectorLaneSize(resultType); + int resultVectorSize = resultBitWidth * resultLanes; + + if (resultVectorSize != 512) { + op.emitWarning() + << "aievec.broadcast_scalar conversion with result vector size " + << resultVectorSize << " is not implemented.\n"; + return failure(); + } // Integer types if (llvm::isa(resultScaTy)) { @@ -1257,16 +1266,15 @@ class BroadcastScalarOpConversion unsigned srcBitWidth = srcType.getIntOrFloatBitWidth(); if (srcBitWidth < 32) { - src = rewriter.create(loc, rewriter.getI32Type(), - adaptor.getSource()); - } else if (srcBitWidth > 32) { - src = rewriter.create(loc, rewriter.getI32Type(), - adaptor.getSource()); + src = rewriter.create(loc, rewriter.getI32Type(), src); } if (resultBitWidth == 8) { rewriter.replaceOpWithNewOp( op, VectorType::get({64}, rewriter.getI8Type()), src); + } else if (resultBitWidth == 16) { + rewriter.replaceOpWithNewOp( + 
op, VectorType::get({32}, rewriter.getI16Type()), src); } else if (resultBitWidth == 32) { rewriter.replaceOpWithNewOp( op, VectorType::get({16}, rewriter.getI32Type()), src); @@ -1282,6 +1290,10 @@ class BroadcastScalarOpConversion rewriter.replaceOpWithNewOp( op, VectorType::get({32}, rewriter.getBF16Type()), adaptor.getSource()); + } else if (resultBitWidth == 32) { + rewriter.replaceOpWithNewOp( + op, VectorType::get({16}, rewriter.getF32Type()), + adaptor.getSource()); } else { op.emitWarning() << "aievec.broadcast_scalar conversion with result bitwidth " diff --git a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp index 4799207b8f..cd1a0074c3 100644 --- a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp +++ b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp @@ -541,8 +541,14 @@ LogicalResult BroadcastScalarOp::verify() { if (!resultType) return emitError("requires vector type"); - if (!sourceType) - return emitError("requires source type"); + if (!sourceType.isa()) + return emitError("requires source type to be integer or float"); + + Type resultElemType = resultType.getElementType(); + if (sourceType != resultElemType) { + return emitError("the element type of result vector must be the same as " + "the source type"); + } return success(); } diff --git a/test/Conversion/AIEVecToLLVM/broadcast_scalar.mlir b/test/Conversion/AIEVecToLLVM/broadcast_scalar.mlir index 63902da0bf..7dee049265 100644 --- a/test/Conversion/AIEVecToLLVM/broadcast_scalar.mlir +++ b/test/Conversion/AIEVecToLLVM/broadcast_scalar.mlir @@ -15,6 +15,21 @@ func.func @i8_broadcast_scalar(%arg0 : i8) -> vector<64xi8> { // ----- +func.func @i16_broadcast_scalar(%arg0 : i16) -> vector<32xi16> { + %0 = aievec.broadcast_scalar %arg0 : i16, vector<32xi16> + return %0 : vector<32xi16> +} + +// CHECK-LABEL: @i16_broadcast_scalar +// CHECK-SAME: %[[ARG0:.*]]: i16 +// CHECK: %[[VAL:.*]] = llvm.sext %[[ARG0]] : i16 to i32 +// CHECK-NEXT: %[[VBROADCAST:.*]] = "xllvm.intr.aie2.vbroadcast16.I512"( +// 
CHECK-SAME: %[[VAL]]) : +// CHECK-SAME: (i32) -> vector<32xi16> +// CHECK-NEXT: return %[[VBROADCAST]] : vector<32xi16> + +// ----- + func.func @i32_broadcast_scalar(%arg0 : i32) -> vector<16xi32> { %0 = aievec.broadcast_scalar %arg0 : i32, vector<16xi32> return %0 : vector<16xi32> @@ -39,4 +54,18 @@ func.func @bf16_broadcast_scalar(%arg0 : bf16) -> vector<32xbf16> { // CHECK: %[[VBROADCAST:.*]] = "xllvm.intr.aie2.vbroadcast16.bf512"( // CHECK-SAME: %[[ARG0]]) : // CHECK-SAME: (bf16) -> vector<32xbf16> -// CHECK-NEXT: return %[[VBROADCAST]] : vector<32xbf16> \ No newline at end of file +// CHECK-NEXT: return %[[VBROADCAST]] : vector<32xbf16> + +// ----- + +func.func @f32_broadcast_scalar(%arg0 : f32) -> vector<16xf32> { + %0 = aievec.broadcast_scalar %arg0 : f32, vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: @f32_broadcast_scalar +// CHECK-SAME: %[[ARG0:.*]]: f32 +// CHECK: %[[VBROADCAST:.*]] = "xllvm.intr.aie2.vbroadcastfloat.I512"( +// CHECK-SAME: %[[ARG0]]) : +// CHECK-SAME: (f32) -> vector<16xf32> +// CHECK-NEXT: return %[[VBROADCAST]] : vector<16xf32> diff --git a/test/Target/LLVMIR/aievec.mlir b/test/Target/LLVMIR/aievec.mlir index 013f31ce93..2e40dee5c1 100644 --- a/test/Target/LLVMIR/aievec.mlir +++ b/test/Target/LLVMIR/aievec.mlir @@ -152,6 +152,14 @@ llvm.func @vbroadcast8_i512(%val : i32) -> vector<64xi8> { llvm.return %0 : vector<64xi8> } +// CHECK-LABEL: define <32 x i16> @vbroadcast16_i512 +llvm.func @vbroadcast16_i512(%val : i32) -> vector<32xi16> { + // CHECK: call <32 x i16> @llvm.aie2.vbroadcast16.I512( + // CHECK-SAME: i32 %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.vbroadcast16.I512"(%val) : (i32) -> vector<32xi16> + llvm.return %0 : vector<32xi16> +} + // CHECK-LABEL: define <16 x i32> @vbroadcast32_i512 llvm.func @vbroadcast32_i512(%val : i32) -> vector<16xi32> { // CHECK: call <16 x i32> @llvm.aie2.vbroadcast32.I512( @@ -168,6 +176,14 @@ llvm.func @vbroadcast16_bf512(%val : bf16) -> vector<32xbf16> { llvm.return %0 : 
vector<32xbf16> } +// CHECK-LABEL: define <16 x float> @vbroadcastfloat_i512 +llvm.func @vbroadcastfloat_i512(%val : f32) -> vector<16xf32> { + // CHECK: call <16 x float> @llvm.aie2.vbroadcastfloat.I512( + // CHECK-SAME: float %{{[0-9]+}}) + %0 = "xllvm.intr.aie2.vbroadcastfloat.I512"(%val) : (f32) -> vector<16xf32> + llvm.return %0 : vector<16xf32> +} + // ----- EXT ----- // CHECK-LABEL: define <8 x i32> @ext_i256_i512