Skip to content

Commit

Permalink
Vectorize TensorPrimitives.Exp (dotnet#93018)
Browse files Browse the repository at this point in the history
* Vectorize TensorPrimitives.Exp

* Update src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
  • Loading branch information
tannergooding authored Oct 4, 2023
1 parent b4bb155 commit bdd7b7a
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,20 +322,8 @@ public static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Exp(ReadOnlySpan<float> x, Span<float> destination)
{
    int length = x.Length;

    // The destination must have room for one result per input element.
    if (destination.Length < length)
    {
        ThrowHelper.ThrowArgument_DestinationTooShort();
    }

    ValidateInputOutputSpanNonOverlapping(x, destination);

    // Scalar evaluation: destination[k] = e^(x[k]) for every input element.
    int index = 0;
    while (index < length)
    {
        destination[index] = MathF.Exp(x[index]);
        index++;
    }
}
public static void Exp(ReadOnlySpan<float> x, Span<float> destination)
{
    // Hand the whole operation to the shared span-into-span helper, with
    // ExpOperator supplying the per-element (and per-vector) computation.
    InvokeSpanIntoSpan<ExpOperator>(x, destination);
}

/// <summary>Searches for the index of the largest single-precision floating-point number in the specified tensor.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2579,6 +2579,286 @@ public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y)
#endif
}

// Unary operator computing e^x for single-precision values. The scalar overload calls
// MathF.Exp; each vector overload widens both halves of the input to double, evaluates
// a polynomial there, and narrows the two results back into one float vector.
private readonly struct ExpOperator : IUnaryOperator
{
// This code is based on `vrs4_expf` from amd/aocl-libm-ose
// Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// Implementation Notes:
// 1. Argument Reduction:
//      e^x = 2^(x/ln2)                                  --- (1)
//
//      Let x/ln(2) = z                                  --- (2)
//
//      Let z = n + r, where n is an integer             --- (3)
//                     |r| <= 1/2
//
//      From (1), (2) and (3),
//        e^x = 2^z
//            = 2^(n+r)
//            = (2^n)*(2^r)                              --- (4)
//
// 2. Polynomial Evaluation
//      From (4),
//        r   = z - n
//        2^r = C1 + C2*r + C3*r^2 + C4*r^3 + C5*r^4 + C6*r^5
//
// 3. Reconstruction
//      Thus,
//        e^x = (2^n) * (2^r)

// Bit pattern of 87.0f. Any lane whose magnitude bits exceed this sends the whole
// vector through the range-fixup path at the end of each vector overload.
private const uint V_ARG_MAX = 0x42AE0000;
// Clears the sign bit, leaving the bit pattern of |x|.
private const uint V_MASK = 0x7FFFFFFF;

// Inputs below this (approximately ln(2^-150)) have their result forced to 0.
private const float V_EXPF_MIN = -103.97208f;
// Inputs above this (approximately ln(2^128)) have their result forced to +Infinity.
private const float V_EXPF_MAX = 88.72284f;

// 1.5 * 2^52. Adding it to z rounds z to the nearest integer (under the default rounding
// mode) and leaves that integer in the low bits of the double's bit pattern.
private const double V_EXPF_HUGE = 6755399441055744;
// 1 / ln(2), i.e. log2(e). NOTE: despite the name, this is the reciprocal of ln(2).
private const double V_TBL_LN2 = 1.4426950408889634;

// Coefficients of the degree-5 polynomial that approximates 2^r (note 2 above).
private const double C1 = 1.0000000754895704;
private const double C2 = 0.6931472254087585;
private const double C3 = 0.2402210737432219;
private const double C4 = 0.05550297297702539;
private const double C5 = 0.009676036358193323;
private const double C6 = 0.001341000536524434;

// Scalar path: defers to MathF.Exp.
public static float Invoke(float x) => MathF.Exp(x);

// Computes e^x for each of the 4 float lanes of x.
public static Vector128<float> Invoke(Vector128<float> x)
{
// Convert x to double precision (xl = lower half of the lanes, xu = upper half)
(Vector128<double> xl, Vector128<double> xu) = Vector128.Widen(x);

// z = x * (1 / ln(2)), so that e^x = 2^z
Vector128<double> v_tbl_ln2 = Vector128.Create(V_TBL_LN2);

Vector128<double> zl = xl * v_tbl_ln2;
Vector128<double> zu = xu * v_tbl_ln2;

Vector128<double> v_expf_huge = Vector128.Create(V_EXPF_HUGE);

// Adding the huge constant rounds z to an integer held in the low mantissa bits
Vector128<double> dnl = zl + v_expf_huge;
Vector128<double> dnu = zu + v_expf_huge;

// n = int (z), obtained by reinterpreting the bits rather than converting;
// only the low bits are meaningful and they are isolated by the shift below
Vector128<ulong> nl = dnl.AsUInt64();
Vector128<ulong> nu = dnu.AsUInt64();

// dn = double(n), recovered by removing the huge constant again
dnl -= v_expf_huge;
dnu -= v_expf_huge;

// r = z - dn
Vector128<double> c1 = Vector128.Create(C1);
Vector128<double> c2 = Vector128.Create(C2);
Vector128<double> c3 = Vector128.Create(C3);
Vector128<double> c4 = Vector128.Create(C4);
Vector128<double> c5 = Vector128.Create(C5);
Vector128<double> c6 = Vector128.Create(C6);

Vector128<double> rl = zl - dnl;

Vector128<double> rl2 = rl * rl;
Vector128<double> rl4 = rl2 * rl2;

// poly = 2^r, with the terms grouped in pairs: (C4*r + C3)*r^2 + (C6*r + C5)*r^4 + (C2*r + C1)
Vector128<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));


// Same evaluation for the upper half
Vector128<double> ru = zu - dnu;

Vector128<double> ru2 = ru * ru;
Vector128<double> ru4 = ru2 * ru2;

Vector128<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));

// result = (float)[poly + (n << 52)]
// Shifting n left by 52 aligns it with the double exponent field (only the low 12 bits
// of n survive the shift), so the integer add scales poly by 2^n.
Vector128<float> ret = Vector128.Narrow(
(polyl.AsUInt64() + Vector128.ShiftLeft(nl, 52)).AsDouble(),
(polyu.AsUInt64() + Vector128.ShiftLeft(nu, 52)).AsDouble()
);

// Range fixup: only needed when some lane has |x| > 87 (compared as unsigned bit
// patterns, so infinities and NaNs also take this path)
if (Vector128.GreaterThanAny(x.AsUInt32() & Vector128.Create(V_MASK), Vector128.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
Vector128<float> infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX));

ret = Vector128.ConditionalSelect(
infinityMask,
Vector128.Create(float.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : ret   -- AndNot(ret, mask) clears every lane where mask is set
ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN)));
}

return ret;
}

// Computes e^x for each of the 8 float lanes of x; same algorithm as the Vector128 overload.
public static Vector256<float> Invoke(Vector256<float> x)
{
// Convert x to double precision (xl = lower half of the lanes, xu = upper half)
(Vector256<double> xl, Vector256<double> xu) = Vector256.Widen(x);

// z = x * (1 / ln(2)), so that e^x = 2^z
Vector256<double> v_tbl_ln2 = Vector256.Create(V_TBL_LN2);

Vector256<double> zl = xl * v_tbl_ln2;
Vector256<double> zu = xu * v_tbl_ln2;

Vector256<double> v_expf_huge = Vector256.Create(V_EXPF_HUGE);

// Adding the huge constant rounds z to an integer held in the low mantissa bits
Vector256<double> dnl = zl + v_expf_huge;
Vector256<double> dnu = zu + v_expf_huge;

// n = int (z), obtained by reinterpreting the bits rather than converting
Vector256<ulong> nl = dnl.AsUInt64();
Vector256<ulong> nu = dnu.AsUInt64();

// dn = double(n), recovered by removing the huge constant again
dnl -= v_expf_huge;
dnu -= v_expf_huge;

// r = z - dn
Vector256<double> c1 = Vector256.Create(C1);
Vector256<double> c2 = Vector256.Create(C2);
Vector256<double> c3 = Vector256.Create(C3);
Vector256<double> c4 = Vector256.Create(C4);
Vector256<double> c5 = Vector256.Create(C5);
Vector256<double> c6 = Vector256.Create(C6);

Vector256<double> rl = zl - dnl;

Vector256<double> rl2 = rl * rl;
Vector256<double> rl4 = rl2 * rl2;

// poly = 2^r, with the terms grouped in pairs: (C4*r + C3)*r^2 + (C6*r + C5)*r^4 + (C2*r + C1)
Vector256<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));


// Same evaluation for the upper half
Vector256<double> ru = zu - dnu;

Vector256<double> ru2 = ru * ru;
Vector256<double> ru4 = ru2 * ru2;

Vector256<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));

// result = (float)[poly + (n << 52)]
// The shifted n lands in the double exponent field, so the integer add scales poly by 2^n.
Vector256<float> ret = Vector256.Narrow(
(polyl.AsUInt64() + Vector256.ShiftLeft(nl, 52)).AsDouble(),
(polyu.AsUInt64() + Vector256.ShiftLeft(nu, 52)).AsDouble()
);

// Range fixup: only needed when some lane has |x| > 87 (compared as unsigned bit
// patterns, so infinities and NaNs also take this path)
if (Vector256.GreaterThanAny(x.AsUInt32() & Vector256.Create(V_MASK), Vector256.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
Vector256<float> infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX));

ret = Vector256.ConditionalSelect(
infinityMask,
Vector256.Create(float.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : ret   -- AndNot(ret, mask) clears every lane where mask is set
ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN)));
}

return ret;
}

#if NET8_0_OR_GREATER
// Computes e^x for each of the 16 float lanes of x; same algorithm as the Vector128 overload.
public static Vector512<float> Invoke(Vector512<float> x)
{
// Convert x to double precision (xl = lower half of the lanes, xu = upper half)
(Vector512<double> xl, Vector512<double> xu) = Vector512.Widen(x);

// z = x * (1 / ln(2)), so that e^x = 2^z
Vector512<double> v_tbl_ln2 = Vector512.Create(V_TBL_LN2);

Vector512<double> zl = xl * v_tbl_ln2;
Vector512<double> zu = xu * v_tbl_ln2;

Vector512<double> v_expf_huge = Vector512.Create(V_EXPF_HUGE);

// Adding the huge constant rounds z to an integer held in the low mantissa bits
Vector512<double> dnl = zl + v_expf_huge;
Vector512<double> dnu = zu + v_expf_huge;

// n = int (z), obtained by reinterpreting the bits rather than converting
Vector512<ulong> nl = dnl.AsUInt64();
Vector512<ulong> nu = dnu.AsUInt64();

// dn = double(n), recovered by removing the huge constant again
dnl -= v_expf_huge;
dnu -= v_expf_huge;

// r = z - dn
Vector512<double> c1 = Vector512.Create(C1);
Vector512<double> c2 = Vector512.Create(C2);
Vector512<double> c3 = Vector512.Create(C3);
Vector512<double> c4 = Vector512.Create(C4);
Vector512<double> c5 = Vector512.Create(C5);
Vector512<double> c6 = Vector512.Create(C6);

Vector512<double> rl = zl - dnl;

Vector512<double> rl2 = rl * rl;
Vector512<double> rl4 = rl2 * rl2;

// poly = 2^r, with the terms grouped in pairs: (C4*r + C3)*r^2 + (C6*r + C5)*r^4 + (C2*r + C1)
Vector512<double> polyl = (c4 * rl + c3) * rl2
+ ((c6 * rl + c5) * rl4
+ (c2 * rl + c1));


// Same evaluation for the upper half
Vector512<double> ru = zu - dnu;

Vector512<double> ru2 = ru * ru;
Vector512<double> ru4 = ru2 * ru2;

Vector512<double> polyu = (c4 * ru + c3) * ru2
+ ((c6 * ru + c5) * ru4
+ (c2 * ru + c1));

// result = (float)[poly + (n << 52)]
// The shifted n lands in the double exponent field, so the integer add scales poly by 2^n.
Vector512<float> ret = Vector512.Narrow(
(polyl.AsUInt64() + Vector512.ShiftLeft(nl, 52)).AsDouble(),
(polyu.AsUInt64() + Vector512.ShiftLeft(nu, 52)).AsDouble()
);

// Range fixup: only needed when some lane has |x| > 87 (compared as unsigned bit
// patterns, so infinities and NaNs also take this path)
if (Vector512.GreaterThanAny(x.AsUInt32() & Vector512.Create(V_MASK), Vector512.Create(V_ARG_MAX)))
{
// (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
Vector512<float> infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX));

ret = Vector512.ConditionalSelect(
infinityMask,
Vector512.Create(float.PositiveInfinity),
ret
);

// (x < V_EXPF_MIN) ? 0 : ret   -- AndNot(ret, mask) clears every lane where mask is set
ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN)));
}

return ret;
}
#endif
}

private readonly struct LogOperator : IUnaryOperator
{
// This code is based on `vrs4_logf` from amd/aocl-libm-ose
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,19 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
}

// Exp operator for this target: only the scalar overload is usable, and
// CanVectorize reports false so the vector overload is not expected to run.
private readonly struct ExpOperator : IUnaryOperator
{
    public bool CanVectorize
    {
        get { return false; }
    }

    public float Invoke(float x)
    {
        return MathF.Exp(x);
    }

    public Vector<float> Invoke(Vector<float> x)
    {
        // The vectorized algorithm needs a vector shift-left, which is only
        // available on .NET 7 or later, so no vector implementation exists here.
        throw new NotImplementedException();
    }
}

private readonly struct LogOperator : IUnaryOperator
{
public bool CanVectorize => false;
Expand Down

0 comments on commit bdd7b7a

Please sign in to comment.