From 44cfbef1b3cb0dd33886cc27441930008a245963 Mon Sep 17 00:00:00 2001
From: Sam Tebbs
Date: Mon, 2 Sep 2024 14:06:14 +0100
Subject: [PATCH] [AArch64] Lower partial add reduction to udot or svdot
 (#101010)

This patch introduces lowering of the partial add reduction intrinsic to
a udot or svdot for AArch64. This also involves adding a
`shouldExpandPartialReductionIntrinsic` target hook, which AArch64 will
return false from in the cases that it can be lowered.
---
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  5 +
 llvm/include/llvm/CodeGen/TargetLowering.h    |  7 ++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 30 ++++++
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 31 ++----
 .../Target/AArch64/AArch64ISelLowering.cpp    | 70 ++++++++++++++
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  3 +
 .../AArch64/partial-reduce-dot-product.ll     | 96 +++++++++++++++++++
 7 files changed, 217 insertions(+), 25 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 1514d92b36b3c2..7ee8ca18c2c1de 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1594,6 +1594,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
+  /// its operands and ReducedTy is the intrinsic's return type.
+  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                              SDValue Op2);
+
   /// Expand the specified \c ISD::VAARG node as the Legalize pass would.
   SDValue expandVAArg(SDNode *Node);
 
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index eda38cd8a564d6..e17d68d2690c86 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,13 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
+  /// should be expanded using generic code in SelectionDAGBuilder.
+  virtual bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
+    return true;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7f57b6db40ef49..aa468fa9ebb4c3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -74,6 +74,7 @@
 #include <cassert>
 #include <cstdint>
 #include <cstdlib>
+#include <deque>
 #include <limits>
 #include <set>
 #include <string>
@@ -2439,6 +2440,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                                          SDValue Op2) {
+  EVT FullTy = Op2.getValueType();
+
+  unsigned Stride = ReducedTy.getVectorMinNumElements();
+  unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
+
+  // Collect all of the subvectors
+  std::deque<SDValue> Subvectors = {Op1};
+  for (unsigned I = 0; I < ScaleFactor; I++) {
+    auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
+    Subvectors.push_back(
+        getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+  }
+
+  // Flatten the subvector tree
+  while (Subvectors.size() > 1) {
+    Subvectors.push_back(
+        getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
+    Subvectors.pop_front();
+    Subvectors.pop_front();
+  }
+
+  assert(Subvectors.size() == 1 &&
+         "There should only be one subvector after tree flattening");
+
+  return Subvectors[0];
+}
+
 SDValue SelectionDAG::expandVAArg(SDNode *Node) {
   SDLoc dl(Node);
   const TargetLowering &TLI = getTargetLoweringInfo();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 4b326ba76f97f2..382a555aa656f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8038,34 +8038,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
-    SDValue OpNode = getValue(I.getOperand(1));
-    EVT ReducedTy = EVT::getEVT(I.getType());
-    EVT FullTy = OpNode.getValueType();
-    unsigned Stride = ReducedTy.getVectorMinNumElements();
-    unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
-
-    // Collect all of the subvectors
-    std::deque<SDValue> Subvectors;
-    Subvectors.push_back(getValue(I.getOperand(0)));
-    for (unsigned i = 0; i < ScaleFactor; i++) {
-      auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
-      Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
-                                       {OpNode, SourceIndex}));
-    }
-
-    // Flatten the subvector tree
-    while (Subvectors.size() > 1) {
-      Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
-                                       {Subvectors[0], Subvectors[1]}));
-      Subvectors.pop_front();
-      Subvectors.pop_front();
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
     }
 
-    assert(Subvectors.size() == 1 &&
-           "There should only be one subvector after tree flattening");
-
-    setValue(&I, Subvectors[0]);
+    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
+                                         getValue(I.getOperand(0)),
+                                         getValue(I.getOperand(1))));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 11aca69db0a148..1735ff5cd69748 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1988,6 +1988,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
   return false;
 }
 
+bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
+    const IntrinsicInst *I) const {
+  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
+    return true;
+
+  EVT VT = EVT::getEVT(I->getType());
+  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21763,6 +21772,61 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
+SDValue tryLowerPartialReductionToDot(SDNode *N,
+                                      const AArch64Subtarget *Subtarget,
+                                      SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  if (!Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // The narrower of the two operands. Used as the accumulator
+  auto NarrowOp = N->getOperand(1);
+  auto MulOp = N->getOperand(2);
+  if (MulOp->getOpcode() != ISD::MUL)
+    return SDValue();
+
+  auto ExtA = MulOp->getOperand(0);
+  auto ExtB = MulOp->getOperand(1);
+  bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+  bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
+  if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
+    return SDValue();
+
+  auto A = ExtA->getOperand(0);
+  auto B = ExtB->getOperand(0);
+  if (A.getValueType() != B.getValueType())
+    return SDValue();
+
+  unsigned Opcode = 0;
+
+  if (IsSExt)
+    Opcode = AArch64ISD::SDOT;
+  else if (IsZExt)
+    Opcode = AArch64ISD::UDOT;
+
+  assert(Opcode != 0 && "Unexpected dot product case encountered.");
+
+  EVT ReducedType = N->getValueType(0);
+  EVT MulSrcType = A.getValueType();
+
+  // Dot products operate on chunks of four elements so there must be four times
+  // as many elements in the wide type
+  if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
+    return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
+
+  if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
+    return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
+
+  return SDValue();
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21771,6 +21835,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
+  case Intrinsic::experimental_vector_partial_reduce_add: {
+    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+      return Dot;
+    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                   N->getOperand(1), N->getOperand(2));
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 39d5df0de0eec7..f9d45b02d30e30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,6 +998,9 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
 
+  bool
+  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   /// If a change in streaming mode is required on entry to/return from a
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..b1354ab210f727
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,96 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    udot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %accc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
+; CHECK-LABEL: dotp_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  %b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+  %mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %accc, <vscale x 16 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: dotp_wide_sext:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    sdot z0.d, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+  %b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
+  %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
+; CHECK-LABEL: not_dotp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.h, z1.h, #0xff
+; CHECK-NEXT:    and z2.h, z2.h, #0xff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    uunpklo z3.s, z1.h
+; CHECK-NEXT:    uunpklo z4.s, z2.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpkhi z2.s, z2.h
+; CHECK-NEXT:    mla z0.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    mla z0.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
+  %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
+  %mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
+  ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
+; CHECK-LABEL: not_dotp_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    and z1.s, z1.s, #0xffff
+; CHECK-NEXT:    and z2.s, z2.s, #0xffff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z3.d, z1.s
+; CHECK-NEXT:    uunpklo z4.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpkhi z2.d, z2.s
+; CHECK-NEXT:    mla z0.d, p0/m, z3.d, z4.d
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
+  %b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
+  %mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
+  %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
+  ret <vscale x 2 x i64> %partial.reduce
+}
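
The commit message above introduces the `shouldExpandPartialReductionIntrinsic` hook and the shared `SelectionDAG::getPartialReduceAdd` expansion helper. As a hedged reader's aid, not part of the patch itself, here is a minimal sketch of how another backend could adopt the same pattern. `MyTargetLowering`, `MYTARGETISD::VDOT`, `performPartialReduceAddCombine` and the `v4i32` type choice are hypothetical placeholders; only the hook, the helper and the intrinsic ID come from the patch.

// Hypothetical sketch; names prefixed with "My"/"MYTARGET" are placeholders.
// 1) Claim the intrinsic so SelectionDAGBuilder emits an INTRINSIC_WO_CHAIN
//    node instead of the generic extract_subvector + add expansion.
bool MyTargetLowering::shouldExpandPartialReductionIntrinsic(
    const IntrinsicInst *I) const {
  if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
    return true;
  // Only claim the accumulator type this target has an instruction for.
  return EVT::getEVT(I->getType()) != MVT::v4i32;
}

// 2) Match the resulting node in the target's DAG combine for
//    INTRINSIC_WO_CHAIN. Operand 1 is the accumulator and operand 2 the wide
//    vector, mirroring the AArch64 combine in this patch. Anything that does
//    not match falls back to the same expansion the builder would have used.
static SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG) {
  SDValue Acc = N->getOperand(1);
  SDValue MulOp = N->getOperand(2);
  if (MulOp.getOpcode() == ISD::MUL)
    return DAG.getNode(MYTARGETISD::VDOT, SDLoc(N), N->getValueType(0), Acc,
                       MulOp.getOperand(0), MulOp.getOperand(1));
  return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0), Acc, MulOp);
}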