[AArch64] Lower partial add reduction to udot or svdot (llvm#101010)
This patch introduces lowering of the partial add reduction intrinsic to
a udot or svdot for AArch64. It also adds a
`shouldExpandPartialReductionIntrinsic` target hook, from which AArch64
returns false in the cases it can lower the intrinsic itself.
SamTebbs33 authored Sep 2, 2024
1 parent df3d70b commit 44cfbef
Showing 7 changed files with 217 additions and 25 deletions.
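For background, @llvm.experimental.vector.partial.reduce.add adds a wide vector into a narrower accumulator. Its generic expansion, which this patch moves into SelectionDAG::getPartialReduceAdd, splits the wide operand into accumulator-sized subvectors and adds them all into the accumulator. A minimal sketch of that expansion in LLVM IR, with fixed-width types chosen purely for readability (the new AArch64 lowering itself targets scalable types):

declare <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32>, i64)

define <4 x i32> @partial_reduce_add_sketch(<4 x i32> %acc, <16 x i32> %wide) {
  ; Split the wide operand into accumulator-sized subvectors.
  %e0 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 0)
  %e1 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 4)
  %e2 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 8)
  %e3 = call <4 x i32> @llvm.vector.extract.v4i32.v16i32(<16 x i32> %wide, i64 12)
  ; Add every subvector into the accumulator.
  %s0 = add <4 x i32> %acc, %e0
  %s1 = add <4 x i32> %s0, %e1
  %s2 = add <4 x i32> %s1, %e2
  %s3 = add <4 x i32> %s2, %e3
  ret <4 x i32> %s3
}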
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1594,6 +1594,11 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);

/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
/// its operands and ReducedTy is the intrinsic's return type.
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2);

/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
SDValue expandVAArg(SDNode *Node);

7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -453,6 +453,13 @@ class TargetLoweringBase {
return true;
}

/// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
/// should be expanded using generic code in SelectionDAGBuilder.
virtual bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
return true;
}

/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
30 changes: 30 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -74,6 +74,7 @@
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <limits>
#include <optional>
#include <set>
@@ -2439,6 +2440,35 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}

SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2) {
EVT FullTy = Op2.getValueType();

unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;

// Collect all of the subvectors
std::deque<SDValue> Subvectors = {Op1};
for (unsigned I = 0; I < ScaleFactor; I++) {
auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
Subvectors.push_back(
getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
}

// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
getNode(ISD::ADD, DL, ReducedTy, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}

assert(Subvectors.size() == 1 &&
"There should only be one subvector after tree flattening");

return Subvectors[0];
}

SDValue SelectionDAG::expandVAArg(SDNode *Node) {
SDLoc dl(Node);
const TargetLowering &TLI = getTargetLoweringInfo();
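As a concrete trace of the tree-flattening loop in getPartialReduceAdd above: with accumulator A and four extracted subvectors E0..E3, the deque starts as [A, E0, E1, E2, E3]; each iteration adds the two front entries and pushes the sum at the back, so the final value is (E1 + E2) + (E3 + (A + E0)) rather than one long dependency chain of adds.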
31 changes: 6 additions & 25 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8038,34 +8038,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
SDValue OpNode = getValue(I.getOperand(1));
EVT ReducedTy = EVT::getEVT(I.getType());
EVT FullTy = OpNode.getValueType();

unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;

// Collect all of the subvectors
std::deque<SDValue> Subvectors;
Subvectors.push_back(getValue(I.getOperand(0)));
for (unsigned i = 0; i < ScaleFactor; i++) {
auto SourceIndex = DAG.getVectorIdxConstant(i * Stride, sdl);
Subvectors.push_back(DAG.getNode(ISD::EXTRACT_SUBVECTOR, sdl, ReducedTy,
{OpNode, SourceIndex}));
}

// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(DAG.getNode(ISD::ADD, sdl, ReducedTy,
{Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
visitTargetIntrinsic(I, Intrinsic);
return;
}

assert(Subvectors.size() == 1 &&
"There should only be one subvector after tree flattening");

setValue(&I, Subvectors[0]);
setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
getValue(I.getOperand(0)),
getValue(I.getOperand(1))));
return;
}
case Intrinsic::experimental_cttz_elts: {
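When the target opts out of the generic expansion in SelectionDAGBuilder, the call is emitted as an ordinary target intrinsic node via visitTargetIntrinsic, so the partial reduction survives into DAG combining; the AArch64 combine added below then either matches it to a dot-product node or falls back to the same DAG.getPartialReduceAdd expansion.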
70 changes: 70 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1988,6 +1988,15 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}

bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
return true;

EVT VT = EVT::getEVT(I->getType());
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
}

bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -21763,6 +21772,61 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}

SDValue tryLowerPartialReductionToDot(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {

assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
getIntrinsicID(N) ==
Intrinsic::experimental_vector_partial_reduce_add &&
"Expected a partial reduction node");

if (!Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();

SDLoc DL(N);

// The narrower of the two operands. Used as the accumulator
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
if (MulOp->getOpcode() != ISD::MUL)
return SDValue();

auto ExtA = MulOp->getOperand(0);
auto ExtB = MulOp->getOperand(1);
bool IsSExt = ExtA->getOpcode() == ISD::SIGN_EXTEND;
bool IsZExt = ExtA->getOpcode() == ISD::ZERO_EXTEND;
if (ExtA->getOpcode() != ExtB->getOpcode() || (!IsSExt && !IsZExt))
return SDValue();

auto A = ExtA->getOperand(0);
auto B = ExtB->getOperand(0);
if (A.getValueType() != B.getValueType())
return SDValue();

unsigned Opcode = 0;

if (IsSExt)
Opcode = AArch64ISD::SDOT;
else if (IsZExt)
Opcode = AArch64ISD::UDOT;

assert(Opcode != 0 && "Unexpected dot product case encountered.");

EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();

// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);

if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);

return SDValue();
}

static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -21771,6 +21835,12 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
}
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
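Note that tryLowerPartialReductionToDot above requires both multiply operands to be extended in the same way, so a case that mixes sign and zero extension takes the getPartialReduceAdd fallback instead. An illustrative example of such a case (not one of the tests added by this commit):

define <vscale x 4 x i32> @not_dotp_mixed_ext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
entry:
  %a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
  %mult = mul <vscale x 16 x i32> %a.wide, %b.wide
  %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
  ret <vscale x 4 x i32> %partial.reduce
}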
3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -998,6 +998,9 @@ class AArch64TargetLowering : public TargetLowering {

bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;

bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;

bool shouldExpandCttzElements(EVT VT) const override;

/// If a change in streaming mode is required on entry to/return from a
96 changes: 96 additions & 0 deletions llvm/test/CodeGen/AArch64/partial-reduce-dot-product.ll
@@ -0,0 +1,96 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s

define <vscale x 4 x i32> @dotp(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 2 x i64> @dotp_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: udot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = zext <vscale x 8 x i16> %b to <vscale x 8 x i64>
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 4 x i32> @dotp_sext(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: dotp_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
%mult = mul nuw nsw <vscale x 16 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 2 x i64> @dotp_wide_sext(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
; CHECK-LABEL: dotp_wide_sext:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sdot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
%b.wide = sext <vscale x 8 x i16> %b to <vscale x 8 x i64>
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}

define <vscale x 4 x i32> @not_dotp(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_dotp:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.h, z1.h, #0xff
; CHECK-NEXT: and z2.h, z2.h, #0xff
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: uunpklo z3.s, z1.h
; CHECK-NEXT: uunpklo z4.s, z2.h
; CHECK-NEXT: uunpkhi z1.s, z1.h
; CHECK-NEXT: uunpkhi z2.s, z2.h
; CHECK-NEXT: mla z0.s, p0/m, z3.s, z4.s
; CHECK-NEXT: mla z0.s, p0/m, z1.s, z2.s
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i32>
%mult = mul nuw nsw <vscale x 8 x i32> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %mult)
ret <vscale x 4 x i32> %partial.reduce
}

define <vscale x 2 x i64> @not_dotp_wide(<vscale x 2 x i64> %acc, <vscale x 4 x i16> %a, <vscale x 4 x i16> %b) {
; CHECK-LABEL: not_dotp_wide:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: and z1.s, z1.s, #0xffff
; CHECK-NEXT: and z2.s, z2.s, #0xffff
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: uunpklo z3.d, z1.s
; CHECK-NEXT: uunpklo z4.d, z2.s
; CHECK-NEXT: uunpkhi z1.d, z1.s
; CHECK-NEXT: uunpkhi z2.d, z2.s
; CHECK-NEXT: mla z0.d, p0/m, z3.d, z4.d
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
; CHECK-NEXT: ret
entry:
%a.wide = zext <vscale x 4 x i16> %a to <vscale x 4 x i64>
%b.wide = zext <vscale x 4 x i16> %b to <vscale x 4 x i64>
%mult = mul nuw nsw <vscale x 4 x i64> %a.wide, %b.wide
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %mult)
ret <vscale x 2 x i64> %partial.reduce
}
