Skip to content

Commit

Permalink
Vectorize TensorPrimitives.Tanh/Cosh/Sinh (dotnet#93093)
Browse files Browse the repository at this point in the history
* Vectorize TensorPrimitives.Tanh/Cosh/Sinh

Tanh and Cosh are based on AOCL-LibM.

AOCL-LibM doesn't appear to have a sinh implementation, so this Sinh is just based on the sinh formula based on exp(x).

I also augmented the tests further, including:
- Added more tests for sinh/cosh/tanh
- Add an equality routine that supports comparing larger values with a tolerance
- Tightened the tolerance for most functions
- Changed some tests to be theories to be consistent with style elsewhere in the tests
- Fixed some use of Math to be MathF

* Remove unnecessary special-handling path from cosh

* Remove unnecessary special-handling path from tanh

* Redo sinh based on cosh

* Address PR feedback
  • Loading branch information
stephentoub authored and michaelgsharp committed Oct 20, 2023
1 parent 6c63ae7 commit bc4d0cd
Show file tree
Hide file tree
Showing 4 changed files with 480 additions and 203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,8 @@ public static void AddMultiply(ReadOnlySpan<float> x, float y, ReadOnlySpan<floa
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Cosh(x[i]);
}
}
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<CoshOperator>(x, destination);

/// <summary>Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
Expand Down Expand Up @@ -1012,20 +1000,8 @@ public static void Sigmoid(ReadOnlySpan<float> x, Span<float> destination)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Sinh(x[i]);
}
}
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<SinhOperator>(x, destination);

/// <summary>Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down Expand Up @@ -1177,20 +1153,8 @@ public static float SumOfSquares(ReadOnlySpan<float> x) =>
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Tanh(x[i]);
}
}
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<TanhOperator>(x, destination);

/// <summary>Throws an exception if the <paramref name="input"/> and <paramref name="output"/> spans overlap and don't begin at the same memory location.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;
using System.Security.Cryptography;

namespace System.Numerics.Tensors
{
Expand Down Expand Up @@ -147,15 +148,15 @@ public static void ConvertToHalf(ReadOnlySpan<float> source, Span<Half> destinat
// so we convert the VectorXx<float> to a VectorXx<uint>, and the caller then uses this twice, narrows the combination
// into a VectorXx<ushort>, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding
const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1
const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask
const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2
const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half
const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float
#pragma warning restore IDE0059
#pragma warning restore IDE0059

static Vector128<uint> SingleToHalfAsWidenedUInt32_Vector128(Vector128<float> value)
{
Expand Down Expand Up @@ -462,13 +463,13 @@ public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destin
// The VectorXx<uint> is created by reading a vector of Halfs as a VectorXx<short> then widened to two VectorXx<int>s and cast to VectorXx<uint>s.
// We loop handling one input vector at a time, producing two output float vectors.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single
const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single
const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half
#pragma warning restore IDE0059
#pragma warning restore IDE0059

static Vector128<float> HalfAsWidenedUInt32ToSingle_Vector128(Vector128<uint> value)
{
Expand Down Expand Up @@ -2992,6 +2993,156 @@ public static Vector512<float> Invoke(Vector512<float> x)
#endif
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
// This code is based on `vrs4_coshf` from amd/aocl-libm-ose
// Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// Spec:
// coshf(|x| > 89.415985107421875) = Infinity
// coshf(Infinity) = infinity
// coshf(-Infinity) = infinity
//
// cosh(x) = (exp(x) + exp(-x))/2
// cosh(-x) = +cosh(x)
//
// checks for special cases
// if ( asint(x) > infinity) return x with overflow exception and
// return x.
// if x is NaN then raise invalid FP operation exception and return x.
//
// coshf = v/2 * exp(x - log(v)) where v = 0x1.0000e8p-1

private const float LOGV = 0.693161f;
private const float HALFV = 1.0000138f;
private const float INVV2 = 0.24999309f;

public static float Invoke(float x) => MathF.Cosh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<float> y = Vector128.Abs(x);
Vector128<float> z = ExpOperator.Invoke(y - Vector128.Create(LOGV));
return Vector128.Create(HALFV) * (z + (Vector128.Create(INVV2) / z));
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<float> y = Vector256.Abs(x);
Vector256<float> z = ExpOperator.Invoke(y - Vector256.Create(LOGV));
return Vector256.Create(HALFV) * (z + (Vector256.Create(INVV2) / z));
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<float> y = Vector512.Abs(x);
Vector512<float> z = ExpOperator.Invoke(y - Vector512.Create(LOGV));
return Vector512.Create(HALFV) * (z + (Vector512.Create(INVV2) / z));
}
#endif
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
// Same as cosh, but with `z -` rather than `z +`, and with the sign
// flipped on the result based on the sign of the input.

private const uint SIGN_MASK = 0x7FFFFFFF;
private const float LOGV = 0.693161f;
private const float HALFV = 1.0000138f;
private const float INVV2 = 0.24999309f;

public static float Invoke(float x) => MathF.Sinh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<float> y = Vector128.Abs(x);
Vector128<float> z = ExpOperator.Invoke(y - Vector128.Create(LOGV));
Vector128<float> result = Vector128.Create(HALFV) * (z - (Vector128.Create(INVV2) / z));
Vector128<uint> sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK);
return (sign ^ result.AsUInt32()).AsSingle();
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<float> y = Vector256.Abs(x);
Vector256<float> z = ExpOperator.Invoke(y - Vector256.Create(LOGV));
Vector256<float> result = Vector256.Create(HALFV) * (z - (Vector256.Create(INVV2) / z));
Vector256<uint> sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK);
return (sign ^ result.AsUInt32()).AsSingle();
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<float> y = Vector512.Abs(x);
Vector512<float> z = ExpOperator.Invoke(y - Vector512.Create(LOGV));
Vector512<float> result = Vector512.Create(HALFV) * (z - (Vector512.Create(INVV2) / z));
Vector512<uint> sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK);
return (sign ^ result.AsUInt32()).AsSingle();
}
#endif
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
// This code is based on `vrs4_tanhf` from amd/aocl-libm-ose
// Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// To compute vrs4_tanhf(v_f32x4_t x)
// Let y = |x|
// If 0 <= y < 0x1.154246p3
// Let z = e^(-2.0 * y) - 1 -(1)
//
// Using (1), tanhf(y) can be calculated as,
// tanhf(y) = -z / (z + 2.0)
//
// For other cases, call scalar tanhf()
//
// If x < 0, then we use the identity
// tanhf(-x) = -tanhf(x)

private const uint SIGN_MASK = 0x7FFFFFFF;

public static float Invoke(float x) => MathF.Tanh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<float> y = Vector128.Abs(x);
Vector128<float> z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f);
Vector128<uint> sign = x.AsUInt32() & Vector128.Create(~SIGN_MASK);
return (sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32()).AsSingle();
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<float> y = Vector256.Abs(x);
Vector256<float> z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f);
Vector256<uint> sign = x.AsUInt32() & Vector256.Create(~SIGN_MASK);
return (sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32()).AsSingle();
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<float> y = Vector512.Abs(x);
Vector512<float> z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f);
Vector512<uint> sign = x.AsUInt32() & Vector512.Create(~SIGN_MASK);
return (sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32()).AsSingle();
}
#endif
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
}

/// <summary>MathF.Exp(x)</summary>
private readonly struct ExpOperator : IUnaryOperator
{
public bool CanVectorize => false;
Expand All @@ -1035,6 +1036,36 @@ public Vector<float> Invoke(Vector<float> x) =>
throw new NotImplementedException();
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Sinh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Cosh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Tanh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Loading

0 comments on commit bc4d0cd

Please sign in to comment.