Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize TensorPrimitives.Tanh/Cosh/Sinh #93093

Merged
merged 6 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,8 @@ public static void AddMultiply(ReadOnlySpan<float> x, float y, ReadOnlySpan<floa
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Cosh(x[i]);
}
}
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<CoshOperator>(x, destination);

/// <summary>Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
Expand Down Expand Up @@ -1012,20 +1000,8 @@ public static void Sigmoid(ReadOnlySpan<float> x, Span<float> destination)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Sinh(x[i]);
}
}
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<SinhOperator>(x, destination);

/// <summary>Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down Expand Up @@ -1177,20 +1153,8 @@ public static float SumOfSquares(ReadOnlySpan<float> x) =>
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Tanh(x[i]);
}
}
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<TanhOperator>(x, destination);

/// <summary>Throws an exception if the <paramref name="input"/> and <paramref name="output"/> spans overlap and don't begin at the same memory location.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;
using System.Security.Cryptography;

namespace System.Numerics.Tensors
{
Expand Down Expand Up @@ -147,15 +148,15 @@ public static void ConvertToHalf(ReadOnlySpan<float> source, Span<Half> destinat
// so we convert the VectorXx<float> to a VectorXx<uint>, and the caller then uses this twice, narrows the combination
// into a VectorXx<ushort>, and then saves that out to the destination `ref Half` reinterpreted as `ref ushort`.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint MinExp = 0x3880_0000u; // Minimum exponent for rounding
const uint Exponent126 = 0x3f00_0000u; // Exponent displacement #1
const uint SingleBiasedExponentMask = 0x7F80_0000; // float.BiasedExponentMask; // Exponent mask
const uint Exponent13 = 0x0680_0000u; // Exponent displacement #2
const float MaxHalfValueBelowInfinity = 65520.0f; // Maximum value that is not Infinity in Half
const uint ExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint SingleSignMask = 0x8000_0000u; // float.SignMask; // Mask for sign bit in float
#pragma warning restore IDE0059
#pragma warning restore IDE0059

static Vector128<uint> SingleToHalfAsWidenedUInt32_Vector128(Vector128<float> value)
{
Expand Down Expand Up @@ -462,13 +463,13 @@ public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destin
// The VectorXx<uint> is created by reading a vector of Halfs as a VectorXx<short> then widened to two VectorXx<int>s and cast to VectorXx<uint>s.
// We loop handling one input vector at a time, producing two output float vectors.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single
const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single
const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half
#pragma warning restore IDE0059
#pragma warning restore IDE0059

static Vector128<float> HalfAsWidenedUInt32ToSingle_Vector128(Vector128<uint> value)
{
Expand Down Expand Up @@ -2992,6 +2993,163 @@ public static Vector512<float> Invoke(Vector512<float> x)
#endif
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
// This code is based on `vrs4_coshf` from amd/aocl-libm-ose
// Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// Spec:
// coshf(|x| > 89.415985107421875) = Infinity
// coshf(Infinity) = infinity
// coshf(-Infinity) = infinity
//
// cosh(x) = (exp(x) + exp(-x))/2
// cosh(-x) = +cosh(x)
//
// checks for special cases
// if ( asint(x) > infinity) return x with overflow exception and
// return x.
// if x is NaN then raise invalid FP operation exception and return x.
//
// coshf = v/2 * exp(x - log(v)) where v = 0x1.0000e8p-1

private const uint SIGN_MASK = 0x7FFFFFFF;
private const uint LOGV = 0x3f317300;
private const uint HALFV = 0x3f800074;
private const uint INVV2 = 0x3e7ffe30;

public static float Invoke(float x) => MathF.Cosh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<float> y = (x.AsUInt32() & Vector128.Create(SIGN_MASK)).AsSingle();
Vector128<float> z = ExpOperator.Invoke(y - Vector128.Create(LOGV).AsSingle());
return Vector128.Create(HALFV).AsSingle() * (z + (Vector128.Create(INVV2).AsSingle() / z));
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<float> y = (x.AsUInt32() & Vector256.Create(SIGN_MASK)).AsSingle();
Vector256<float> z = ExpOperator.Invoke(y - Vector256.Create(LOGV).AsSingle());
return Vector256.Create(HALFV).AsSingle() * (z + (Vector256.Create(INVV2).AsSingle() / z));
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<float> y = (x.AsUInt32() & Vector512.Create(SIGN_MASK)).AsSingle();
Vector512<float> z = ExpOperator.Invoke(y - Vector512.Create(LOGV).AsSingle());
return Vector512.Create(HALFV).AsSingle() * (z + (Vector512.Create(INVV2).AsSingle() / z));
}
#endif
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
// Same as cosh, but with `z -` rather than `z +`, and with the sign
// flipped on the result based on the sign of the input.

private const uint SIGN_MASK = 0x7FFFFFFF;
private const uint LOGV = 0x3f317300;
private const uint HALFV = 0x3f800074;
private const uint INVV2 = 0x3e7ffe30;

public static float Invoke(float x) => MathF.Sinh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<uint> ux = x.AsUInt32();
Vector128<uint> sign = ux & Vector128.Create(~SIGN_MASK);
Vector128<float> y = (ux & Vector128.Create(SIGN_MASK)).AsSingle();
Vector128<float> z = ExpOperator.Invoke(y - Vector128.Create(LOGV).AsSingle());
Vector128<float> result = Vector128.Create(HALFV).AsSingle() * (z - (Vector128.Create(INVV2).AsSingle() / z));
return (sign ^ result.AsUInt32()).AsSingle();
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<uint> ux = x.AsUInt32();
Vector256<uint> sign = ux & Vector256.Create(~SIGN_MASK);
Vector256<float> y = (ux & Vector256.Create(SIGN_MASK)).AsSingle();
Vector256<float> z = ExpOperator.Invoke(y - Vector256.Create(LOGV).AsSingle());
Vector256<float> result = Vector256.Create(HALFV).AsSingle() * (z - (Vector256.Create(INVV2).AsSingle() / z));
return (sign ^ result.AsUInt32()).AsSingle();
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<uint> ux = x.AsUInt32();
Vector512<uint> sign = ux & Vector512.Create(~SIGN_MASK);
Vector512<float> y = (ux & Vector512.Create(SIGN_MASK)).AsSingle();
Vector512<float> z = ExpOperator.Invoke(y - Vector512.Create(LOGV).AsSingle());
Vector512<float> result = Vector512.Create(HALFV).AsSingle() * (z - (Vector512.Create(INVV2).AsSingle() / z));
return (sign ^ result.AsUInt32()).AsSingle();
}
#endif
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
// This code is based on `vrs4_tanhf` from amd/aocl-libm-ose
// Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// To compute vrs4_tanhf(v_f32x4_t x)
// Let y = |x|
// If 0 <= y < 0x1.154246p3
// Let z = e^(-2.0 * y) - 1 -(1)
//
// Using (1), tanhf(y) can be calculated as,
// tanhf(y) = -z / (z + 2.0)
//
// For other cases, call scalar tanhf()
//
// If x < 0, then we use the identity
// tanhf(-x) = -tanhf(x)

private const uint V4_TANHF_SIGN_MASK = 0x7FFFFFFF;

public static float Invoke(float x) => MathF.Tanh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<uint> ux = x.AsUInt32();
Vector128<uint> sign = ux & Vector128.Create(~V4_TANHF_SIGN_MASK);
Vector128<float> y = (ux & Vector128.Create(V4_TANHF_SIGN_MASK)).AsSingle();
Vector128<float> z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f);
return (sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32()).AsSingle();
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<uint> ux = x.AsUInt32();
Vector256<uint> sign = ux & Vector256.Create(~V4_TANHF_SIGN_MASK);
Vector256<float> y = (ux & Vector256.Create(V4_TANHF_SIGN_MASK)).AsSingle();
Vector256<float> z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f);
return (sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32()).AsSingle();
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<uint> ux = x.AsUInt32();
Vector512<uint> sign = ux & Vector512.Create(~V4_TANHF_SIGN_MASK);
Vector512<float> y = (ux & Vector512.Create(V4_TANHF_SIGN_MASK)).AsSingle();
Vector512<float> z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f);
return (sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32()).AsSingle();
}
#endif
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
}

/// <summary>MathF.Exp(x)</summary>
private readonly struct ExpOperator : IUnaryOperator
{
public bool CanVectorize => false;
Expand All @@ -1035,6 +1036,36 @@ public Vector<float> Invoke(Vector<float> x) =>
throw new NotImplementedException();
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Sinh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Cosh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Tanh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Loading