Skip to content

Commit

Permalink
[aievec] to-llvm flow for aievec.ext_elem op (#1442)
Browse files Browse the repository at this point in the history
* This PR add the support for aievec.ext_elem op going through the to-llvm flow.
* Add aievec-to-llvm conversion pattern for the aievec.ext_elem op.
* Add aievec-to-llvm conversion tests for the aievec.ext_elem op.
* Add op verifier for the aievec.ext_elem op.
  • Loading branch information
jamestcl-amd authored May 7, 2024
1 parent 7f1f3e6 commit 7cb4396
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 1 deletion.
2 changes: 1 addition & 1 deletion include/aie/Dialect/AIEVec/IR/AIEVecOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def AIEVec_ExtElemOp:
intrinsic. Extract element determined by index from vector.
`$result = ext_elem($source, $index)`.}];
let assemblyFormat = "$source `,` $index attr-dict `:` type($source) `,` type($index) `,` type($result)";
let hasVerifier = 0;
let hasVerifier = 1;
}

def AIEVec_NegOp:
Expand Down
82 changes: 82 additions & 0 deletions lib/Conversion/AIEVecToLLVM/AIEVecToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,87 @@ class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
}
};

class ExtractElemOpConversion
: public mlir::ConvertOpToLLVMPattern<aievec::ExtElemOp> {
public:
using ConvertOpToLLVMPattern<aievec::ExtElemOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(aievec::ExtElemOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

Type resultType = op.getResult().getType();
unsigned resultBitWidth = resultType.getIntOrFloatBitWidth();

Value src = adaptor.getSource();
VectorType srcType = cast<VectorType>(src.getType());
Type srcScalarType = srcType.getElementType();
unsigned srcBitWidth = srcScalarType.getIntOrFloatBitWidth();
int srcLanes = getVectorLaneSize(srcType);
int srcVectorSize = srcBitWidth * srcLanes;

if (srcVectorSize != 512) {
op.emitWarning() << "aievec.ext_elem conversion with source vector size "
<< srcVectorSize << " is not supported.\n";
return failure();
}

// create constant for sign
auto signCst = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));

// create xllvm intrinsic
Value extElemOp = nullptr;
SmallVector<Value> operands(
{adaptor.getSource(), adaptor.getIndex(), signCst});
if (resultBitWidth == 8) {
extElemOp = rewriter.create<xllvm::VectorExtractElem8I512IntrOp>(
loc, rewriter.getI32Type(),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({64}, rewriter.getI8Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else if (resultBitWidth == 16) {
extElemOp = rewriter.create<xllvm::VectorExtractElem16I512IntrOp>(
loc, rewriter.getI32Type(),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({32}, rewriter.getI16Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else if (resultBitWidth == 32) {
extElemOp = rewriter.create<xllvm::VectorExtractElem32I512IntrOp>(
loc, rewriter.getI32Type(),
forceCastOperandsToSignature(
rewriter, loc, operands,
{VectorType::get({16}, rewriter.getI32Type()),
rewriter.getI32Type(), rewriter.getI32Type()}));
} else {
op.emitWarning() << "aievec.ext_elem conversion with result bit width "
<< resultBitWidth << " is not implemented.\n";
return failure();
}

// create truncation op (and bitcast op)
if (resultType.isa<IntegerType>()) {
if (resultBitWidth < 32) {
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, resultType, extElemOp);
} else {
rewriter.replaceOp(op, extElemOp);
}
} else {
// Float types
if (resultBitWidth == 16) {
extElemOp = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI16Type(),
extElemOp);
}
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, resultType, extElemOp);
}

return success();
}
};

class FMAElemOpConversion
: public mlir::ConvertOpToLLVMPattern<aievec::FMAElemOp> {
public:
Expand Down Expand Up @@ -1595,6 +1676,7 @@ void populateAIEVecToLLVMConversionPatterns(
FMAElemOpConversion,
MatMulOpConversion,
ShiftOpConversion,
ExtractElemOpConversion,
FoldAIECastOps>(converter);
patterns.add<MulElemOpConversion>(converter, aie2Fp32EmulationOption);
// clang-format on
Expand Down
24 changes: 24 additions & 0 deletions lib/Dialect/AIEVec/IR/AIEVecOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1506,6 +1506,30 @@ ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
return parsePackUnpackOp(parser, result);
}

//===----------------------------------------------------------------------===//
// ExtElemOp
//===----------------------------------------------------------------------===//

// Verify Extract Element op.
LogicalResult ExtElemOp::verify() {
// Verify the types
VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());

if (!sourceType)
return emitError("source requires vector type");

// The element type of vectors must always be the same
Type stype = sourceType.getElementType();
Type rtype = getResult().getType();

if (stype != rtype) {
return emitError("the type of result must be the same as the element "
"type of source vector");
}

return success();
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
Expand Down
86 changes: 86 additions & 0 deletions test/Conversion/AIEVecToLLVM/extract_elem.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// RUN: aie-opt %s -split-input-file -convert-aievec-to-llvm | FileCheck %s

func.func @i8_extract_elem(%arg0 : vector<64xi8>, %index : i32) -> i8 {
%0 = aievec.ext_elem %arg0, %index : vector<64xi8>, i32, i8
return %0 : i8
}

// CHECK-LABEL: @i8_extract_elem
// CHECK-SAME: %[[ARG0:.*]]: vector<64xi8>,
// CHECK-SAME: %[[INDEX:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem8.I512"(
// CHECK-SAME: %[[ARG0]], %[[INDEX]], %[[CST]]) :
// CHECK-SAME: (vector<64xi8>, i32, i32) -> i32
// CHECK-NEXT: %[[RES:.*]] = llvm.trunc %[[VEXTELEM]] : i32 to i8
// CHECK-NEXT: return %[[RES]] : i8

// -----

func.func @i16_extract_elem(%arg0 : vector<32xi16>, %index : i32) -> i16 {
%0 = aievec.ext_elem %arg0, %index : vector<32xi16>, i32, i16
return %0 : i16
}

// CHECK-LABEL: @i16_extract_elem
// CHECK-SAME: %[[ARG0:.*]]: vector<32xi16>,
// CHECK-SAME: %[[INDEX:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem16.I512"(
// CHECK-SAME: %[[ARG0]], %[[INDEX]], %[[CST]]) :
// CHECK-SAME: (vector<32xi16>, i32, i32) -> i32
// CHECK-NEXT: %[[RES:.*]] = llvm.trunc %[[VEXTELEM]] : i32 to i16
// CHECK-NEXT: return %[[RES]] : i16

// -----

func.func @i32_extract_elem(%arg0 : vector<16xi32>, %index : i32) -> i32 {
%0 = aievec.ext_elem %arg0, %index : vector<16xi32>, i32, i32
return %0 : i32
}

// CHECK-LABEL: @i32_extract_elem
// CHECK-SAME: %[[ARG0:.*]]: vector<16xi32>,
// CHECK-SAME: %[[INDEX:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem32.I512"(
// CHECK-SAME: %[[ARG0]], %[[INDEX]], %[[CST]]) :
// CHECK-SAME: (vector<16xi32>, i32, i32) -> i32
// CHECK-NEXT: return %[[VEXTELEM]] : i32

// -----

func.func @bf16_extract_elem(%arg0 : vector<32xbf16>, %index : i32) -> bf16 {
%0 = aievec.ext_elem %arg0, %index : vector<32xbf16>, i32, bf16
return %0 : bf16
}

// CHECK-LABEL: @bf16_extract_elem
// CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16>,
// CHECK-SAME: %[[INDEX:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : vector<32xbf16> to vector<32xi16>
// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem16.I512"(
// CHECK-SAME: %[[BITCAST]], %[[INDEX]], %[[CST]]) :
// CHECK-SAME: (vector<32xi16>, i32, i32) -> i32
// CHECK-NEXT: %[[TRUNC:.*]] = llvm.trunc %[[VEXTELEM]] : i32 to i16
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[TRUNC]] : i16 to bf16
// CHECK-NEXT: return %[[RES]] : bf16

// -----

func.func @f32_extract_elem(%arg0 : vector<16xf32>, %index : i32) -> f32 {
%0 = aievec.ext_elem %arg0, %index : vector<16xf32>, i32, f32
return %0 : f32
}

// CHECK-LABEL: @f32_extract_elem
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>,
// CHECK-SAME: %[[INDEX:.*]]: i32
// CHECK: %[[CST:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-NEXT: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : vector<16xf32> to vector<16xi32>
// CHECK-NEXT: %[[VEXTELEM:.*]] = "xllvm.intr.aie2.vextract.elem32.I512"(
// CHECK-SAME: %[[BITCAST]], %[[INDEX]], %[[CST]]) :
// CHECK-SAME: (vector<16xi32>, i32, i32) -> i32
// CHECK-NEXT: %[[RES:.*]] = llvm.bitcast %[[VEXTELEM]] : i32 to f32
// CHECK-NEXT: return %[[RES]] : f32

0 comments on commit 7cb4396

Please sign in to comment.