// Patch summary (free-text preamble placed ahead of the first `diff --git` header).
//
// This unified diff touches four files:
//
//  1. include/aie/Dialect/AIEVec/IR/AIEVecOps.td
//     Flips `hasVerifier` from 0 to 1 on AIEVec_ExtElemOp, so the op now needs the
//     C++ `ExtElemOp::verify()` that item 3 adds.
//
//  2. lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
//     Adds the `ExtractElemOpConversion` pattern for `aievec::ExtElemOp` and lists it
//     in `populateAIEVecToLLVMConversionPatterns`. As written, matchAndRewrite:
//       - computes the source vector size as element bit width times
//         `getVectorLaneSize(srcType)`; any value other than 512 emits a warning on
//         the op and returns failure();
//       - creates an i32 constant with value 1 (the in-code comment calls it the
//         "sign" constant);
//       - dispatches on the result bit width (8, 16 or 32) and force-casts the
//         operands to (vector<64xi8> | vector<32xi16> | vector<16xi32>, i32, i32);
//         the created op yields an i32. Any other width emits a warning and fails;
//       - in the first post-processing branch (the in-code else-branch comment
//         implies it is the integer case), replaces the op with a narrowing op when
//         the result is narrower than 32 bits, else with the i32 value directly;
//       - in the else branch ("Float types"), first narrows to i16 when the result
//         is 16 bits wide, then replaces the op with a cast to the result type.
//
//  3. lib/Dialect/AIEVec/IR/AIEVecOps.cpp
//     Adds `ExtElemOp::verify()`: it emits an error if the source is not a vector
//     type, or if the result type differs from the source vector's element type.
//
//  4. test/Conversion/AIEVecToLLVM/extract_elem.mlir (new file)
//     FileCheck tests run through `aie-opt -split-input-file -convert-aievec-to-llvm`
//     for i8, i16, i32, bf16 and f32. They expect the
//     "xllvm.intr.aie2.vextract.elem{8,16,32}.I512" intrinsics, `llvm.trunc` for
//     sub-32-bit results, and `llvm.bitcast` on both sides for bf16/f32.
//
// NOTE(review): the line structure of this patch looks collapsed (many diff lines
//   share one physical line); hunk line counts will not match until the original
//   newlines are restored - confirm against the original commit before applying.
// NOTE(review): C++ template-argument lists appear to be missing from the payload
//   (e.g. `cast(src.getType())`, `rewriter.create(`, `resultType.isa()`,
//   `SmallVector operands(`); presumably stripped during extraction - TODO confirm.
// NOTE(review): the "Float types" branch only narrows 16-bit results; a non-integer
//   result of another sub-32-bit width would reach the final cast as an i32 -
//   verify whether such a type can reach this pattern.
// NOTE(review): both warning messages end in ".\n"; presumably the diagnostic engine
//   already terminates messages - confirm against the project's diagnostic style.
diff --git a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td index 86414ff6cf..79cb1c5e00 100644 --- a/include/aie/Dialect/AIEVec/IR/AIEVecOps.td +++ b/include/aie/Dialect/AIEVec/IR/AIEVecOps.td @@ -729,7 +729,7 @@ def AIEVec_ExtElemOp: intrinsic. Extract element determined by index from vector. `$result = ext_elem($source, $index)`.}]; let assemblyFormat = "$source `,` $index attr-dict `:` type($source) `,` type($index) `,` type($result)"; - let hasVerifier = 0; + let hasVerifier = 1; } def AIEVec_NegOp: diff --git a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp index 7914c9c80e..d9cd31f7a6 100644 --- a/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp +++ b/lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp @@ -1364,6 +1364,87 @@ class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern { } }; +class ExtractElemOpConversion + : public mlir::ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Type resultType = op.getResult().getType(); + unsigned resultBitWidth = resultType.getIntOrFloatBitWidth(); + + Value src = adaptor.getSource(); + VectorType srcType = cast(src.getType()); + Type srcScalarType = srcType.getElementType(); + unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth(); + int srcLanes = getVectorLaneSize(srcType); + int srcVectorSize = srcBitWidth * srcLanes; + + if (srcVectorSize != 512) { + op.emitWarning() << "aievec.ext_elem conversion with source vector size " + << srcVectorSize << " is not supported.\n"; + return failure(); + } + + // create constant for sign + auto signCst = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); + + // create xllvm intrinsic + Value extElemOp = nullptr; + SmallVector operands( + {adaptor.getSource(), 
adaptor.getIndex(), signCst}); + if (resultBitWidth == 8) { + extElemOp = rewriter.create( + loc, rewriter.getI32Type(), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({64}, rewriter.getI8Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } else if (resultBitWidth == 16) { + extElemOp = rewriter.create( + loc, rewriter.getI32Type(), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({32}, rewriter.getI16Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } else if (resultBitWidth == 32) { + extElemOp = rewriter.create( + loc, rewriter.getI32Type(), + forceCastOperandsToSignature( + rewriter, loc, operands, + {VectorType::get({16}, rewriter.getI32Type()), + rewriter.getI32Type(), rewriter.getI32Type()})); + } else { + op.emitWarning() << "aievec.ext_elem conversion with result bit width " + << resultBitWidth << " is not implemented.\n"; + return failure(); + } + + // create truncation op (and bitcast op) + if (resultType.isa()) { + if (resultBitWidth < 32) { + rewriter.replaceOpWithNewOp(op, resultType, extElemOp); + } else { + rewriter.replaceOp(op, extElemOp); + } + } else { + // Float types + if (resultBitWidth == 16) { + extElemOp = rewriter.create(loc, rewriter.getI16Type(), + extElemOp); + } + rewriter.replaceOpWithNewOp(op, resultType, extElemOp); + } + + return success(); + } +}; + class FMAElemOpConversion : public mlir::ConvertOpToLLVMPattern { public: @@ -1595,6 +1676,7 @@ void populateAIEVecToLLVMConversionPatterns( FMAElemOpConversion, MatMulOpConversion, ShiftOpConversion, + ExtractElemOpConversion, FoldAIECastOps>(converter); patterns.add(converter, aie2Fp32EmulationOption); // clang-format on diff --git a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp index cd1a0074c3..b8db4437d9 100644 --- a/lib/Dialect/AIEVec/IR/AIEVecOps.cpp +++ b/lib/Dialect/AIEVec/IR/AIEVecOps.cpp @@ -1506,6 +1506,30 @@ ParseResult UnpackOp::parse(OpAsmParser &parser, 
OperationState &result) { return parsePackUnpackOp(parser, result); } +//===----------------------------------------------------------------------===// +// ExtElemOp +//===----------------------------------------------------------------------===// + +// Verify Extract Element op. +LogicalResult ExtElemOp::verify() { + // Verify the types + VectorType sourceType = llvm::dyn_cast(getSource().getType()); + + if (!sourceType) + return emitError("source requires vector type"); + + // The element type of vectors must always be the same + Type stype = sourceType.getElementType(); + Type rtype = getResult().getType(); + + if (stype != rtype) { + return emitError("the type of result must be the same as the element " + "type of source vector"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ShiftOp //===----------------------------------------------------------------------===// diff --git a/test/Conversion/AIEVecToLLVM/extract_elem.mlir b/test/Conversion/AIEVecToLLVM/extract_elem.mlir new file mode 100644 index 0000000000..8114b8422d --- /dev/null +++ b/test/Conversion/AIEVecToLLVM/extract_elem.mlir @@ -0,0 +1,86 @@ +// RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm | FileCheck %s + +func.func @i8_extract_elem(%arg0 : vector<64xi8>, %index : i32) -> i8 { + %0 = aievec.ext_elem %arg0, %index : vector<64xi8>, i32, i8 + return %0 : i8 +} + +// CHECK-LABEL: @i8_extract_elem +// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>, +// CHECK-SAME: %[[INDEX:.*]]: i32 +// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem8.I512"( +// CHECK-SAME: %[[ARG0]], %[[INDEX]], %[[CST]]) : +// CHECK-SAME: (vector<64xi8>, i32, i32) -> i32 +// CHECK-NEXT: %[[RES:.*]] = llvm.trunc %[[VEXTELEM]] : i32 to i8 +// CHECK-NEXT: return %[[RES]] : i8 + +// ----- + +func.func @i16_extract_elem(%arg0 : vector<32xi16>, %index : i32) -> i16 { + %0 = aievec.ext_elem %arg0, 
%index : vector<32xi16>, i32, i16 + return %0 : i16 +} + +// CHECK-LABEL: @i16_extract_elem +// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>, +// CHECK-SAME: %[[INDEX:.*]]: i32 +// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem16.I512"( +// CHECK-SAME: %[[ARG0]], %[[INDEX]], %[[CST]]) : +// CHECK-SAME: (vector<32xi16>, i32, i32) -> i32 +// CHECK-NEXT: %[[RES:.*]] = llvm.trunc %[[VEXTELEM]] : i32 to i16 +// CHECK-NEXT: return %[[RES]] : i16 + +// ----- + +func.func @i32_extract_elem(%arg0 : vector<16xi32>, %index : i32) -> i32 { + %0 = aievec.ext_elem %arg0, %index : vector<16xi32>, i32, i32 + return %0 : i32 +} + +// CHECK-LABEL: @i32_extract_elem +// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>, +// CHECK-SAME: %[[INDEX:.*]]: i32 +// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem32.I512"( +// CHECK-SAME: %[[ARG0]], %[[INDEX]], %[[CST]]) : +// CHECK-SAME: (vector<16xi32>, i32, i32) -> i32 +// CHECK-NEXT: return %[[VEXTELEM]] : i32 + +// ----- + +func.func @bf16_extract_elem(%arg0 : vector<32xbf16>, %index : i32) -> bf16 { + %0 = aievec.ext_elem %arg0, %index : vector<32xbf16>, i32, bf16 + return %0 : bf16 +} + +// CHECK-LABEL: @bf16_extract_elem +// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>, +// CHECK-SAME: %[[INDEX:.*]]: i32 +// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<32xi16> +// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem16.I512"( +// CHECK-SAME: %[[BITCAST]], %[[INDEX]], %[[CST]]) : +// CHECK-SAME: (vector<32xi16>, i32, i32) -> i32 +// CHECK-NEXT: %[[TRUNC:.*]] = llvm.trunc %[[VEXTELEM]] : i32 to i16 +// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[TRUNC]] : i16 to bf16 +// CHECK-NEXT: return %[[RES]] : bf16 + +// ----- + +func.func @f32_extract_elem(%arg0 : vector<16xf32>, %index : i32) -> f32 { + %0 = 
aievec.ext_elem %arg0, %index : vector<16xf32>, i32, f32 + return %0 : f32 +} + +// CHECK-LABEL: @f32_extract_elem +// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, +// CHECK-SAME: %[[INDEX:.*]]: i32 +// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<16xi32> +// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem32.I512"( +// CHECK-SAME: %[[BITCAST]], %[[INDEX]], %[[CST]]) : +// CHECK-SAME: (vector<16xi32>, i32, i32) -> i32 +// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VEXTELEM]] : i32 to f32 +// CHECK-NEXT: return %[[RES]] : f32