Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize TensorPrimitives.Tanh/Cosh/Sinh #93093

Merged
merged 6 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,8 @@ public static void AddMultiply(ReadOnlySpan<float> x, float y, ReadOnlySpan<floa
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Cosh(x[i]);
}
}
public static void Cosh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<CoshOperator>(x, destination);

/// <summary>Computes the cosine similarity between the two specified non-empty, equal-length tensors of single-precision floating-point numbers.</summary>
/// <param name="x">The first tensor, represented as a span.</param>
Expand Down Expand Up @@ -1012,20 +1000,8 @@ public static void Sigmoid(ReadOnlySpan<float> x, Span<float> destination)
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Sinh(x[i]);
}
}
public static void Sinh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<SinhOperator>(x, destination);

/// <summary>Computes the softmax function over the specified non-empty tensor of single-precision floating-point numbers.</summary>
/// <param name="x">The tensor, represented as a span.</param>
Expand Down Expand Up @@ -1177,20 +1153,8 @@ public static float SumOfSquares(ReadOnlySpan<float> x) =>
/// operating systems or architectures.
/// </para>
/// </remarks>
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination)
{
if (x.Length > destination.Length)
{
ThrowHelper.ThrowArgument_DestinationTooShort();
}

ValidateInputOutputSpanNonOverlapping(x, destination);

for (int i = 0; i < x.Length; i++)
{
destination[i] = MathF.Tanh(x[i]);
}
}
public static void Tanh(ReadOnlySpan<float> x, Span<float> destination) =>
InvokeSpanIntoSpan<TanhOperator>(x, destination);

/// <summary>Throws an exception if the <paramref name="input"/> and <paramref name="output"/> spans overlap and don't begin at the same memory location.</summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2992,6 +2992,183 @@ public static Vector512<float> Invoke(Vector512<float> x)
#endif
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
public static float Invoke(float x) => MathF.Sinh(x);
public static Vector128<float> Invoke(Vector128<float> x) => (ExpOperator.Invoke(x) - ExpOperator.Invoke(-x)) / Vector128.Create(2f);
public static Vector256<float> Invoke(Vector256<float> x) => (ExpOperator.Invoke(x) - ExpOperator.Invoke(-x)) / Vector256.Create(2f);
#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x) => (ExpOperator.Invoke(x) - ExpOperator.Invoke(-x)) / Vector512.Create(2f);
#endif
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
// This code is based on `vrs4_coshf` from amd/aocl-libm-ose
// Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// Spec:
// coshf(|x| > 89.415985107421875) = Infinity
// coshf(Infinity) = infinity
// coshf(-Infinity) = infinity
//
// cosh(x) = (exp(x) + exp(-x))/2
// cosh(-x) = +cosh(x)
//
// checks for special cases
// if ( asint(x) > infinity) return x with overflow exception and
// return x.
// if x is NaN then raise invalid FP operation exception and return x.
//
// coshf = v/2 * exp(x - log(v)) where v = 0x1.0000e8p-1

private const uint SIGN_MASK = 0x7FFFFFFF;
private const uint ARG_MAX = 0x42B2D4FC;
private const uint LOGV = 0x3f317300;
private const uint HALFV = 0x3f800074;
private const uint INVV2 = 0x3e7ffe30;

public static float Invoke(float x) => MathF.Cosh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<uint> ux = x.AsUInt32() & Vector128.Create(SIGN_MASK);
if (Vector128.GreaterThanAny(ux, Vector128.Create(ARG_MAX)))
{
return Vector128.Create(
MathF.Cosh(x.GetElement(0)),
MathF.Cosh(x.GetElement(1)),
MathF.Cosh(x.GetElement(2)),
MathF.Cosh(x.GetElement(3)));
}

Vector128<float> y = ux.AsSingle();
Vector128<float> z = ExpOperator.Invoke(y - Vector128.Create(LOGV).AsSingle());
return Vector128.Create(HALFV).AsSingle() * (z + Vector128.Create(INVV2).AsSingle() * 1f / z);
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<uint> ux = x.AsUInt32() & Vector256.Create(SIGN_MASK);
if (Vector256.GreaterThanAny(ux, Vector256.Create(ARG_MAX)))
{
return Vector256.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
}

Vector256<float> y = ux.AsSingle();
Vector256<float> z = ExpOperator.Invoke(y - Vector256.Create(LOGV).AsSingle());
return Vector256.Create(HALFV).AsSingle() * (z + Vector256.Create(INVV2).AsSingle() * 1f / z);
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<uint> ux = x.AsUInt32() & Vector512.Create(SIGN_MASK);
if (Vector512.GreaterThanAny(ux, Vector512.Create(ARG_MAX)))
{
return Vector512.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
}

Vector512<float> y = ux.AsSingle();
Vector512<float> z = ExpOperator.Invoke(y - Vector512.Create(LOGV).AsSingle());
return Vector512.Create(HALFV).AsSingle() * (z + Vector512.Create(INVV2).AsSingle() * 1f / z);
}
#endif
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
// This code is based on `vrs4_tanhf` from amd/aocl-libm-ose
// Copyright (C) 2008-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Licensed under the BSD 3-Clause "New" or "Revised" License
// See THIRD-PARTY-NOTICES.TXT for the full license text

// To compute vrs4_tanhf(v_f32x4_t x)
// Let y = |x|
// If 0 <= y < 0x1.154246p3
// Let z = e^(-2.0 * y) - 1 -(1)
//
// Using (1), tanhf(y) can be calculated as,
// tanhf(y) = -z / (z + 2.0)
//
// For other cases, call scalar tanhf()
//
// If x < 0, then we use the identity
// tanhf(-x) = -tanhf(x)

private const uint V4_TANHF_ARG_MAX = 0x410AA123;
private const uint V4_TANHF_SIGN_MASK = 0x7FFFFFFF;

public static float Invoke(float x) => MathF.Tanh(x);

public static Vector128<float> Invoke(Vector128<float> x)
{
Vector128<uint> ux = x.AsUInt32();
Vector128<uint> sign = ux & Vector128.Create(~V4_TANHF_SIGN_MASK);

ux &= Vector128.Create(V4_TANHF_SIGN_MASK);
if (Vector128.GreaterThanAny(ux, Vector128.Create(V4_TANHF_ARG_MAX)))
{
return Vector128.Create(
MathF.Tanh(x.GetElement(0)),
MathF.Tanh(x.GetElement(1)),
MathF.Tanh(x.GetElement(2)),
MathF.Tanh(x.GetElement(3)));
}

Vector128<float> y = ux.AsSingle();
Vector128<float> z = ExpOperator.Invoke(Vector128.Create(-2f) * y) - Vector128.Create(1f);
Vector128<uint> result = sign ^ (-z / (z + Vector128.Create(2f))).AsUInt32();

return result.AsSingle();
}

public static Vector256<float> Invoke(Vector256<float> x)
{
Vector256<uint> ux = x.AsUInt32();
Vector256<uint> sign = ux & Vector256.Create(~V4_TANHF_SIGN_MASK);

ux &= Vector256.Create(V4_TANHF_SIGN_MASK);
if (Vector256.GreaterThanAny(ux, Vector256.Create(V4_TANHF_ARG_MAX)))
{
return Vector256.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
}

Vector256<float> y = ux.AsSingle();
Vector256<float> z = ExpOperator.Invoke(Vector256.Create(-2f) * y) - Vector256.Create(1f);
Vector256<uint> result = sign ^ (-z / (z + Vector256.Create(2f))).AsUInt32();

return result.AsSingle();
}

#if NET8_0_OR_GREATER
public static Vector512<float> Invoke(Vector512<float> x)
{
Vector512<uint> ux = x.AsUInt32();
Vector512<uint> sign = ux & Vector512.Create(~V4_TANHF_SIGN_MASK);

ux &= Vector512.Create(V4_TANHF_SIGN_MASK);
if (Vector512.GreaterThanAny(ux, Vector512.Create(V4_TANHF_ARG_MAX)))
{
return Vector512.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
}

Vector512<float> y = ux.AsSingle();
Vector512<float> z = ExpOperator.Invoke(Vector512.Create(-2f) * y) - Vector512.Create(1f);
Vector512<uint> result = sign ^ (-z / (z + Vector512.Create(2f))).AsUInt32();

return result.AsSingle();
}
#endif
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,7 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
}

/// <summary>MathF.Exp(x)</summary>
private readonly struct ExpOperator : IUnaryOperator
{
public bool CanVectorize => false;
Expand All @@ -1035,6 +1036,36 @@ public Vector<float> Invoke(Vector<float> x) =>
throw new NotImplementedException();
}

/// <summary>MathF.Sinh(x)</summary>
private readonly struct SinhOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Sinh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Cosh(x)</summary>
private readonly struct CoshOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Cosh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Tanh(x)</summary>
private readonly struct TanhOperator : IUnaryOperator
{
public bool CanVectorize => false;
public float Invoke(float x) => MathF.Tanh(x);
public Vector<float> Invoke(Vector<float> x) =>
// requires ShiftLeft (.NET 7+)
throw new NotImplementedException();
}

/// <summary>MathF.Log(x)</summary>
private readonly struct LogOperator : IUnaryOperator
{
Expand Down
Loading