diff --git a/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt
new file mode 100644
index 00000000000000..a8f2d0192cfec9
--- /dev/null
+++ b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt
@@ -0,0 +1,2 @@
+M:System.Numerics.Tensors.TensorPrimitives.ConvertToHalf(System.ReadOnlySpan{System.Single},System.Span{System.Half})
+M:System.Numerics.Tensors.TensorPrimitives.ConvertToSingle(System.ReadOnlySpan{System.Half},System.Span{System.Single})
\ No newline at end of file
diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj
index b0449170860b45..b87938ff24586b 100644
--- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj
+++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj
@@ -10,6 +10,7 @@
Once this package has shipped a stable version, the following line
should be removed in order to re-enable validation. -->
true
+ ReferenceAssemblyExclusions.txt
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs
index 6bda8b2c900e16..69562bee25124d 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs
@@ -316,39 +316,12 @@ public static void Exp(ReadOnlySpan x, Span destination) =>
///
public static int IndexOfMax(ReadOnlySpan x)
{
- int result = -1;
-
- if (!x.IsEmpty)
+ if (x.IsEmpty)
{
- result = 0;
- float max = float.NegativeInfinity;
-
- for (int i = 0; i < x.Length; i++)
- {
- float current = x[i];
-
- if (current != max)
- {
- if (float.IsNaN(current))
- {
- return i;
- }
-
- if (max < current)
- {
- result = i;
- max = current;
- }
- }
- else if (IsNegative(max) && !IsNegative(current))
- {
- result = i;
- max = current;
- }
- }
+ return -1;
}
- return result;
+ return IndexOfMinMaxCore(x);
}
/// Searches for the index of the single-precision floating-point number with the largest magnitude in the specified tensor.
@@ -367,43 +340,12 @@ public static int IndexOfMax(ReadOnlySpan x)
///
public static int IndexOfMaxMagnitude(ReadOnlySpan x)
{
- int result = -1;
-
- if (!x.IsEmpty)
+ if (x.IsEmpty)
{
- result = 0;
- float max = float.NegativeInfinity;
- float maxMag = float.NegativeInfinity;
-
- for (int i = 0; i < x.Length; i++)
- {
- float current = x[i];
- float currentMag = Math.Abs(current);
-
- if (currentMag != maxMag)
- {
- if (float.IsNaN(currentMag))
- {
- return i;
- }
-
- if (maxMag < currentMag)
- {
- result = i;
- max = current;
- maxMag = currentMag;
- }
- }
- else if (IsNegative(max) && !IsNegative(current))
- {
- result = i;
- max = current;
- maxMag = currentMag;
- }
- }
+ return -1;
}
- return result;
+ return IndexOfMinMaxCore(x);
}
/// Searches for the index of the smallest single-precision floating-point number in the specified tensor.
@@ -421,39 +363,12 @@ public static int IndexOfMaxMagnitude(ReadOnlySpan x)
///
public static int IndexOfMin(ReadOnlySpan x)
{
- int result = -1;
-
- if (!x.IsEmpty)
+ if (x.IsEmpty)
{
- result = 0;
- float min = float.PositiveInfinity;
-
- for (int i = 0; i < x.Length; i++)
- {
- float current = x[i];
-
- if (current != min)
- {
- if (float.IsNaN(current))
- {
- return i;
- }
-
- if (current < min)
- {
- result = i;
- min = current;
- }
- }
- else if (IsNegative(current) && !IsNegative(min))
- {
- result = i;
- min = current;
- }
- }
+ return -1;
}
- return result;
+ return IndexOfMinMaxCore(x);
}
/// Searches for the index of the single-precision floating-point number with the smallest magnitude in the specified tensor.
@@ -472,43 +387,12 @@ public static int IndexOfMin(ReadOnlySpan x)
///
public static int IndexOfMinMagnitude(ReadOnlySpan x)
{
- int result = -1;
-
- if (!x.IsEmpty)
+ if (x.IsEmpty)
{
- result = 0;
- float min = float.PositiveInfinity;
- float minMag = float.PositiveInfinity;
-
- for (int i = 0; i < x.Length; i++)
- {
- float current = x[i];
- float currentMag = Math.Abs(current);
-
- if (currentMag != minMag)
- {
- if (float.IsNaN(currentMag))
- {
- return i;
- }
-
- if (currentMag < minMag)
- {
- result = i;
- min = current;
- minMag = currentMag;
- }
- }
- else if (IsNegative(current) && !IsNegative(min))
- {
- result = i;
- min = current;
- minMag = currentMag;
- }
- }
+ return -1;
}
- return result;
+ return IndexOfMinMaxCore(x);
}
/// Computes the element-wise natural (base e) logarithm of single-precision floating-point numbers in the specified tensor.
diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs
index d77cd743a0713f..e82f25c889c522 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs
@@ -7,7 +7,6 @@
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;
-using System.Security.Cryptography;
namespace System.Numerics.Tensors
{
@@ -1062,7 +1061,6 @@ private static float MinMaxCore(ReadOnlySpan x)
if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count)
{
ref float xRef = ref MemoryMarshal.GetReference(x);
-
// Load the first vector as the initial set of results, and bail immediately
// to scalar handling if it contains any NaNs (which don't compare equally to themselves).
Vector512 result = Vector512.LoadUnsafe(ref xRef, 0), current;
@@ -1227,6 +1225,229 @@ private static float MinMaxCore(ReadOnlySpan x)
}
}
+ private static int IndexOfMinMaxCore(ReadOnlySpan x) where TIndexOfMinMax : struct, IIndexOfOperator
+ {
+ // This matches the IEEE 754:2019 `maximum`/`minimum` functions.
+ // It propagates NaN inputs back to the caller and
+ // otherwise returns the index of the greater of the inputs.
+ // It treats +0 as greater than -0 as per the specification.
+
+#if NET8_0_OR_GREATER
+ if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count)
+ {
+ ref float xRef = ref MemoryMarshal.GetReference(x);
+ Vector512 resultIndex = Vector512.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+ Vector512 curIndex = resultIndex;
+ Vector512 increment = Vector512.Create(Vector512.Count);
+
+ // Load the first vector as the initial set of results, and bail immediately
+ // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
+ Vector512 result = Vector512.LoadUnsafe(ref xRef);
+ Vector512 current;
+
+ Vector512 nanMask = ~Vector512.Equals(result, result);
+ if (nanMask != Vector512.Zero)
+ {
+ return IndexOfFirstMatch(nanMask);
+ }
+
+ int oneVectorFromEnd = x.Length - Vector512.Count;
+ int i = Vector512.Count;
+
+ // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
+ while (i <= oneVectorFromEnd)
+ {
+ // Load the next vector, and early exit on NaN.
+ current = Vector512.LoadUnsafe(ref xRef, (uint)i);
+ curIndex += increment;
+
+ nanMask = ~Vector512.Equals(current, current);
+ if (nanMask != Vector512.Zero)
+ {
+ return i + IndexOfFirstMatch(nanMask);
+ }
+
+ TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex);
+
+ i += Vector512.Count;
+ }
+
+ // If any elements remain, handle them in one final vector.
+ if (i != x.Length)
+ {
+ current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count));
+ curIndex += Vector512.Create(x.Length - i);
+
+ nanMask = ~Vector512.Equals(current, current);
+ if (nanMask != Vector512.Zero)
+ {
+ return curIndex[IndexOfFirstMatch(nanMask)];
+ }
+
+ TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex);
+ }
+
+ // Aggregate the lanes in the vector to create the final scalar result.
+ return TIndexOfMinMax.Invoke(result, resultIndex);
+ }
+#endif
+
+ if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count)
+ {
+ ref float xRef = ref MemoryMarshal.GetReference(x);
+ Vector256 resultIndex = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7);
+ Vector256 curIndex = resultIndex;
+ Vector256 increment = Vector256.Create(Vector256.Count);
+
+ // Load the first vector as the initial set of results, and bail immediately
+ // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
+ Vector256 result = Vector256.LoadUnsafe(ref xRef);
+ Vector256 current;
+
+ Vector256 nanMask = ~Vector256.Equals(result, result);
+ if (nanMask != Vector256.Zero)
+ {
+ return IndexOfFirstMatch(nanMask);
+ }
+
+ int oneVectorFromEnd = x.Length - Vector256.Count;
+ int i = Vector256.Count;
+
+ // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
+ while (i <= oneVectorFromEnd)
+ {
+ // Load the next vector, and early exit on NaN.
+ current = Vector256.LoadUnsafe(ref xRef, (uint)i);
+ curIndex += increment;
+
+ nanMask = ~Vector256.Equals(current, current);
+ if (nanMask != Vector256.Zero)
+ {
+ return i + IndexOfFirstMatch(nanMask);
+ }
+
+ TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex);
+
+ i += Vector256.Count;
+ }
+
+ // If any elements remain, handle them in one final vector.
+ if (i != x.Length)
+ {
+ current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count));
+ curIndex += Vector256.Create(x.Length - i);
+
+ nanMask = ~Vector256.Equals(current, current);
+ if (nanMask != Vector256.Zero)
+ {
+ return curIndex[IndexOfFirstMatch(nanMask)];
+ }
+
+ TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex);
+ }
+
+ // Aggregate the lanes in the vector to create the final scalar result.
+ return TIndexOfMinMax.Invoke(result, resultIndex);
+ }
+
+ if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count)
+ {
+ ref float xRef = ref MemoryMarshal.GetReference(x);
+ Vector128 resultIndex = Vector128.Create(0, 1, 2, 3);
+ Vector128 curIndex = resultIndex;
+ Vector128 increment = Vector128.Create(Vector128.Count);
+
+ // Load the first vector as the initial set of results, and bail immediately
+ // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
+ Vector128 result = Vector128.LoadUnsafe(ref xRef);
+ Vector128 current;
+
+ Vector128 nanMask = ~Vector128.Equals(result, result);
+ if (nanMask != Vector128.Zero)
+ {
+ return IndexOfFirstMatch(nanMask);
+ }
+
+ int oneVectorFromEnd = x.Length - Vector128.Count;
+ int i = Vector128.Count;
+
+ // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
+ while (i <= oneVectorFromEnd)
+ {
+ // Load the next vector, and early exit on NaN.
+ current = Vector128.LoadUnsafe(ref xRef, (uint)i);
+ curIndex += increment;
+
+ nanMask = ~Vector128.Equals(current, current);
+ if (nanMask != Vector128.Zero)
+ {
+ return i + IndexOfFirstMatch(nanMask);
+ }
+
+ TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex);
+
+ i += Vector128.Count;
+ }
+
+ // If any elements remain, handle them in one final vector.
+ if (i != x.Length)
+ {
+ curIndex += Vector128.Create(x.Length - i);
+
+ current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count));
+
+ nanMask = ~Vector128.Equals(current, current);
+ if (nanMask != Vector128.Zero)
+ {
+ return curIndex[IndexOfFirstMatch(nanMask)];
+ }
+
+ TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex);
+ }
+
+ // Aggregate the lanes in the vector to create the final scalar result.
+ return TIndexOfMinMax.Invoke(result, resultIndex);
+ }
+
+ // Scalar path used when either vectorization is not supported or the input is too small to vectorize.
+ float curResult = x[0];
+ int curIn = 0;
+ if (float.IsNaN(curResult))
+ {
+ return curIn;
+ }
+
+ for (int i = 1; i < x.Length; i++)
+ {
+ float current = x[i];
+ if (float.IsNaN(current))
+ {
+ return i;
+ }
+
+ curIn = TIndexOfMinMax.Invoke(ref curResult, current, curIn, i);
+ }
+
+ return curIn;
+ }
+
+ private static int IndexOfFirstMatch(Vector128 mask)
+ {
+ return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());
+ }
+
+ private static int IndexOfFirstMatch(Vector256 mask)
+ {
+ return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());
+ }
+
+#if NET8_0_OR_GREATER
+ private static int IndexOfFirstMatch(Vector512 mask)
+ {
+ return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits());
+ }
+#endif
+
/// Performs an element-wise operation on and writes the results to .
/// Specifies the operation to perform on each element loaded from .
private static void InvokeSpanIntoSpan(
@@ -7629,6 +7850,26 @@ private static Vector512 IsNegative(Vector512 vector) =>
Vector512.LessThan(vector.AsInt32(), Vector512.Zero).AsSingle();
#endif
+ /// Gets whether the specified is positive.
+ private static bool IsPositive(float f) => float.IsPositive(f);
+
+ /// Gets whether each specified is positive.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static Vector128 IsPositive(Vector128 vector) =>
+ Vector128.GreaterThan(vector.AsInt32(), Vector128.AllBitsSet).AsSingle();
+
+ /// Gets whether each specified is positive.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static Vector256 IsPositive(Vector256 vector) =>
+ Vector256.GreaterThan(vector.AsInt32(), Vector256.AllBitsSet).AsSingle();
+
+#if NET8_0_OR_GREATER
+ /// Gets whether each specified is positive.
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static Vector512 IsPositive(Vector512 vector) =>
+ Vector512.GreaterThan(vector.AsInt32(), Vector512.AllBitsSet).AsSingle();
+#endif
+
/// Finds and returns the first NaN value in .
/// The vector must have already been validated to contain a NaN.
private static float GetFirstNaN(Vector128 vector)
@@ -7637,6 +7878,14 @@ private static float GetFirstNaN(Vector128 vector)
return vector.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits()));
}
+ /// Finds and returns the first NaN index value in .
+ /// The vector must have already been validated to contain a NaN.
+ private static int GetFirstNaNIndex(Vector128 vector, Vector128 index)
+ {
+ Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN");
+ return index.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits()));
+ }
+
/// Finds and returns the first NaN value in .
/// The vector must have already been validated to contain a NaN.
private static float GetFirstNaN(Vector256 vector)
@@ -7645,6 +7894,14 @@ private static float GetFirstNaN(Vector256 vector)
return vector.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits()));
}
+ /// Finds and returns the first NaN index value in .
+ /// The vector must have already been validated to contain a NaN.
+ private static int GetFirstNaNIndex(Vector256 vector, Vector256 index)
+ {
+ Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN");
+ return index.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits()));
+ }
+
#if NET8_0_OR_GREATER
/// Finds and returns the first NaN value in .
/// The vector must have already been validated to contain a NaN.
@@ -7653,6 +7910,14 @@ private static float GetFirstNaN(Vector512 vector)
Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN");
return vector.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits()));
}
+
+ /// Finds and returns the first NaN index value in .
+ /// The vector must have already been validated to contain a NaN.
+ private static int GetFirstNaNIndex(Vector512 vector, Vector512 index)
+ {
+ Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN");
+ return index.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits()));
+ }
#endif
/// Gets the base 2 logarithm of .
@@ -7824,6 +8089,514 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) =>
#endif
}
+ private interface IIndexOfOperator
+ {
+ static abstract int Invoke(ref float result, float current, int resultIndex, int curIndex);
+ static abstract int Invoke(Vector128 result, Vector128 resultIndex);
+ static abstract void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex);
+ static abstract int Invoke(Vector256 result, Vector256 resultIndex);
+ static abstract void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex);
+#if NET8_0_OR_GREATER
+ static abstract int Invoke(Vector512 result, Vector512 resultIndex);
+ static abstract void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex);
+#endif
+ }
+
+ /// Returns the index of MathF.Max(x, y)
+ private readonly struct IndexOfMaxOperator : IIndexOfOperator
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector128 result, Vector128 maxIndex)
+ {
+ Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1));
+ Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1));
+
+ Invoke(ref result, tmpResult, ref maxIndex, tmpIndex);
+
+ tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2));
+ tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2));
+
+ Invoke(ref result, tmpResult, ref maxIndex, tmpIndex);
+ return maxIndex.ToScalar();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex)
+ {
+ Vector128 greaterThanMask = Vector128.GreaterThan(max, current);
+
+ Vector128 equalMask = Vector128.Equals(max, current);
+ if (equalMask.AsInt32() != Vector128.Zero)
+ {
+ Vector128 negativeMask = IsNegative(current);
+ Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex);
+
+ greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle());
+ }
+
+ max = ElementWiseSelect(greaterThanMask, max, current);
+
+ maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector256 result, Vector256 maxIndex)
+ {
+ // Max the upper/lower halves of the Vector256
+ Vector128 resultLower = result.GetLower();
+ Vector128 indexLower = maxIndex.GetLower();
+
+ Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper());
+ return Invoke(resultLower, indexLower);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex)
+ {
+ Vector256 greaterThanMask = Vector256.GreaterThan(max, current);
+
+ Vector256 equalMask = Vector256.Equals(max, current);
+ if (equalMask.AsInt32() != Vector256.Zero)
+ {
+ Vector256 negativeMask = IsNegative(current);
+ Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex);
+
+ greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle());
+ }
+
+ max = ElementWiseSelect(greaterThanMask, max, current);
+
+ maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex);
+ }
+
+#if NET8_0_OR_GREATER
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector512 result, Vector512 resultIndex)
+ {
+ // Max the upper/lower halves of the Vector512
+ Vector256 resultLower = result.GetLower();
+ Vector256 indexLower = resultIndex.GetLower();
+
+ Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper());
+ return Invoke(resultLower, indexLower);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex)
+ {
+ Vector512 greaterThanMask = Vector512.GreaterThan(max, current);
+
+ Vector512 equalMask = Vector512.Equals(max, current);
+ if (equalMask.AsInt32() != Vector512.Zero)
+ {
+ Vector512 negativeMask = IsNegative(current);
+ Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex);
+
+ greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle());
+ }
+
+ max = ElementWiseSelect(greaterThanMask, max, current);
+
+ maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex);
+ }
+#endif
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(ref float result, float current, int resultIndex, int curIndex)
+ {
+ if (result == current)
+ {
+ if (IsNegative(result) && !IsNegative(current))
+ {
+ result = current;
+ return curIndex;
+ }
+ }
+ else if (current > result)
+ {
+ result = current;
+ return curIndex;
+ }
+
+ return resultIndex;
+ }
+ }
+
+ private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector128 result, Vector128 maxIndex)
+ {
+ Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1));
+ Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1));
+
+ Invoke(ref result, tmpResult, ref maxIndex, tmpIndex);
+
+ tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2));
+ tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2));
+
+ Invoke(ref result, tmpResult, ref maxIndex, tmpIndex);
+ return maxIndex.ToScalar();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex)
+ {
+ Vector128 maxMag = Vector128.Abs(max), currentMag = Vector128.Abs(current);
+
+ Vector128 greaterThanMask = Vector128.GreaterThan(maxMag, currentMag);
+
+ Vector128 equalMask = Vector128.Equals(max, current);
+ if (equalMask.AsInt32() != Vector128.Zero)
+ {
+ Vector128 negativeMask = IsNegative(current);
+ Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex);
+
+ greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle());
+ }
+
+ max = ElementWiseSelect(greaterThanMask, max, current);
+
+ maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector256 result, Vector256 maxIndex)
+ {
+ // Max the upper/lower halves of the Vector256
+ Vector128 resultLower = result.GetLower();
+ Vector128 indexLower = maxIndex.GetLower();
+
+ Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper());
+ return Invoke(resultLower, indexLower);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex)
+ {
+ Vector256 maxMag = Vector256.Abs(max), currentMag = Vector256.Abs(current);
+
+ Vector256 greaterThanMask = Vector256.GreaterThan(maxMag, currentMag);
+
+ Vector256 equalMask = Vector256.Equals(max, current);
+ if (equalMask.AsInt32() != Vector256.Zero)
+ {
+ Vector256 negativeMask = IsNegative(current);
+ Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex);
+
+ greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle());
+ }
+
+ max = ElementWiseSelect(greaterThanMask, max, current);
+
+ maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex);
+ }
+
+#if NET8_0_OR_GREATER
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector512 result, Vector512 resultIndex)
+ {
+ // Max the upper/lower halves of the Vector512
+ Vector256 resultLower = result.GetLower();
+ Vector256 indexLower = resultIndex.GetLower();
+
+ Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper());
+ return Invoke(resultLower, indexLower);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex)
+ {
+ Vector512 maxMag = Vector512.Abs(max), currentMag = Vector512.Abs(current);
+ Vector512 greaterThanMask = Vector512.GreaterThan(maxMag, currentMag);
+
+ Vector512 equalMask = Vector512.Equals(max, current);
+ if (equalMask.AsInt32() != Vector512.Zero)
+ {
+ Vector512 negativeMask = IsNegative(current);
+ Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex);
+
+ greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle());
+ }
+
+ max = ElementWiseSelect(greaterThanMask, max, current);
+
+ maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex);
+ }
+#endif
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(ref float result, float current, int resultIndex, int curIndex)
+ {
+ float curMaxAbs = MathF.Abs(result);
+ float currentAbs = MathF.Abs(current);
+
+ if (curMaxAbs == currentAbs)
+ {
+ if (IsNegative(result) && !IsNegative(current))
+ {
+ result = current;
+ return curIndex;
+ }
+ }
+ else if (currentAbs > curMaxAbs)
+ {
+ result = current;
+ return curIndex;
+ }
+
+ return resultIndex;
+ }
+ }
+
+ /// Returns the index of MathF.Min(x, y)
+ private readonly struct IndexOfMinOperator : IIndexOfOperator
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector128 result, Vector128 resultIndex)
+ {
+ Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1));
+ Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1));
+
+ Invoke(ref result, tmpResult, ref resultIndex, tmpIndex);
+
+ tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2));
+ tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2));
+
+ Invoke(ref result, tmpResult, ref resultIndex, tmpIndex);
+ return resultIndex.ToScalar();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex)
+ {
+ Vector128 lessThanMask = Vector128.LessThan(result, current);
+
+ Vector128 equalMask = Vector128.Equals(result, current);
+ if (equalMask.AsInt32() != Vector128.Zero)
+ {
+ Vector128 negativeMask = IsNegative(current);
+ Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex);
+
+ lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle());
+ }
+
+ result = ElementWiseSelect(lessThanMask, result, current);
+
+ resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex);
+ }
+
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector256 result, Vector256 resultIndex)
+ {
+ // Min the upper/lower halves of the Vector256
+ Vector128 resultLower = result.GetLower();
+ Vector128 indexLower = resultIndex.GetLower();
+
+ Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper());
+ return Invoke(resultLower, indexLower);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex)
+ {
+ Vector256 lessThanMask = Vector256.LessThan(result, current);
+
+ Vector256 equalMask = Vector256.Equals(result, current);
+ if (equalMask.AsInt32() != Vector256.Zero)
+ {
+ Vector256 negativeMask = IsNegative(current);
+ Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex);
+
+ lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle());
+ }
+
+ result = ElementWiseSelect(lessThanMask, result, current);
+
+ resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex);
+ }
+
+#if NET8_0_OR_GREATER
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector512 result, Vector512 resultIndex)
+ {
+ // Min the upper/lower halves of the Vector512
+ Vector256 resultLower = result.GetLower();
+ Vector256 indexLower = resultIndex.GetLower();
+
+ Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper());
+ return Invoke(resultLower, indexLower);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex)
+ {
+ Vector512 lessThanMask = Vector512.LessThan(result, current);
+
+ Vector512 equalMask = Vector512.Equals(result, current);
+ if (equalMask.AsInt32() != Vector512.Zero)
+ {
+ Vector512 negativeMask = IsNegative(current);
+ Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex);
+
+ lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle());
+ }
+
+ result = ElementWiseSelect(lessThanMask, result, current);
+
+ resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex);
+ }
+#endif
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(ref float result, float current, int resultIndex, int curIndex)
+ {
+ if (result == current)
+ {
+ if (IsPositive(result) && !IsPositive(current))
+ {
+ result = current;
+ return curIndex;
+ }
+ }
+ else if (current < result)
+ {
+ result = current;
+ return curIndex;
+ }
+
+ return resultIndex;
+ }
+ }
+
+ private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector128 result, Vector128 resultIndex)
+ {
+ Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1));
+ Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1));
+
+ Invoke(ref result, tmpResult, ref resultIndex, tmpIndex);
+
+ tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2));
+ tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2));
+
+ Invoke(ref result, tmpResult, ref resultIndex, tmpIndex);
+ return resultIndex.ToScalar();
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex)
+ {
+ Vector128 minMag = Vector128.Abs(result), currentMag = Vector128.Abs(current);
+
+ Vector128 lessThanMask = Vector128.LessThan(minMag, currentMag);
+
+ Vector128 equalMask = Vector128.Equals(result, current);
+ if (equalMask.AsInt32() != Vector128.Zero)
+ {
+ Vector128 negativeMask = IsNegative(current);
+ Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex);
+
+ lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle());
+ }
+
+ result = ElementWiseSelect(lessThanMask, result, current);
+
+ resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static int Invoke(Vector256 result, Vector256 resultIndex)
+ {
+ // Min the upper/lower halves of the Vector256
+ Vector128 resultLower = result.GetLower();
+ Vector128 indexLower = resultIndex.GetLower();
+
+ Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper());
+ return Invoke(resultLower, indexLower);
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public static void Invoke(ref Vector256 result, Vector256