diff --git a/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt new file mode 100644 index 00000000000000..a8f2d0192cfec9 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/ReferenceAssemblyExclusions.txt @@ -0,0 +1,2 @@ +M:System.Numerics.Tensors.TensorPrimitives.ConvertToHalf(System.ReadOnlySpan{System.Single},System.Span{System.Half}) +M:System.Numerics.Tensors.TensorPrimitives.ConvertToSingle(System.ReadOnlySpan{System.Half},System.Span{System.Single}) \ No newline at end of file diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index b0449170860b45..b87938ff24586b 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -10,6 +10,7 @@ Once this package has shipped a stable version, the following line should be removed in order to re-enable validation. 
--> true + ReferenceAssemblyExclusions.txt diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs index 6bda8b2c900e16..69562bee25124d 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs @@ -316,39 +316,12 @@ public static void Exp(ReadOnlySpan x, Span destination) => /// public static int IndexOfMax(ReadOnlySpan x) { - int result = -1; - - if (!x.IsEmpty) + if (x.IsEmpty) { - result = 0; - float max = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - float current = x[i]; - - if (current != max) - { - if (float.IsNaN(current)) - { - return i; - } - - if (max < current) - { - result = i; - max = current; - } - } - else if (IsNegative(max) && !IsNegative(current)) - { - result = i; - max = current; - } - } + return -1; } - return result; + return IndexOfMinMaxCore(x); } /// Searches for the index of the single-precision floating-point number with the largest magnitude in the specified tensor. 
@@ -367,43 +340,12 @@ public static int IndexOfMax(ReadOnlySpan x) /// public static int IndexOfMaxMagnitude(ReadOnlySpan x) { - int result = -1; - - if (!x.IsEmpty) + if (x.IsEmpty) { - result = 0; - float max = float.NegativeInfinity; - float maxMag = float.NegativeInfinity; - - for (int i = 0; i < x.Length; i++) - { - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != maxMag) - { - if (float.IsNaN(currentMag)) - { - return i; - } - - if (maxMag < currentMag) - { - result = i; - max = current; - maxMag = currentMag; - } - } - else if (IsNegative(max) && !IsNegative(current)) - { - result = i; - max = current; - maxMag = currentMag; - } - } + return -1; } - return result; + return IndexOfMinMaxCore(x); } /// Searches for the index of the smallest single-precision floating-point number in the specified tensor. @@ -421,39 +363,12 @@ public static int IndexOfMaxMagnitude(ReadOnlySpan x) /// public static int IndexOfMin(ReadOnlySpan x) { - int result = -1; - - if (!x.IsEmpty) + if (x.IsEmpty) { - result = 0; - float min = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - float current = x[i]; - - if (current != min) - { - if (float.IsNaN(current)) - { - return i; - } - - if (current < min) - { - result = i; - min = current; - } - } - else if (IsNegative(current) && !IsNegative(min)) - { - result = i; - min = current; - } - } + return -1; } - return result; + return IndexOfMinMaxCore(x); } /// Searches for the index of the single-precision floating-point number with the smallest magnitude in the specified tensor. 
@@ -472,43 +387,12 @@ public static int IndexOfMin(ReadOnlySpan x) /// public static int IndexOfMinMagnitude(ReadOnlySpan x) { - int result = -1; - - if (!x.IsEmpty) + if (x.IsEmpty) { - result = 0; - float min = float.PositiveInfinity; - float minMag = float.PositiveInfinity; - - for (int i = 0; i < x.Length; i++) - { - float current = x[i]; - float currentMag = Math.Abs(current); - - if (currentMag != minMag) - { - if (float.IsNaN(currentMag)) - { - return i; - } - - if (currentMag < minMag) - { - result = i; - min = current; - minMag = currentMag; - } - } - else if (IsNegative(current) && !IsNegative(min)) - { - result = i; - min = current; - minMag = currentMag; - } - } + return -1; } - return result; + return IndexOfMinMaxCore(x); } /// Computes the element-wise natural (base e) logarithm of single-precision floating-point numbers in the specified tensor. diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index d77cd743a0713f..e82f25c889c522 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -7,7 +7,6 @@ using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.Arm; using System.Runtime.Intrinsics.X86; -using System.Security.Cryptography; namespace System.Numerics.Tensors { @@ -1062,7 +1061,6 @@ private static float MinMaxCore(ReadOnlySpan x) if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) { ref float xRef = ref MemoryMarshal.GetReference(x); - // Load the first vector as the initial set of results, and bail immediately // to scalar handling if it contains any NaNs (which don't compare equally to themselves). 
Vector512 result = Vector512.LoadUnsafe(ref xRef, 0), current; @@ -1227,6 +1225,229 @@ private static float MinMaxCore(ReadOnlySpan x) } } + private static int IndexOfMinMaxCore(ReadOnlySpan x) where TIndexOfMinMax : struct, IIndexOfOperator + { + // This matches the IEEE 754:2019 `maximum`/`minimum` functions. + // It propagates NaN inputs back to the caller and + // otherwise returns the index of the greater of the inputs. + // It treats +0 as greater than -0 as per the specification. + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated && x.Length >= Vector512.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector512 resultIndex = Vector512.Create(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + Vector512 curIndex = resultIndex; + Vector512 increment = Vector512.Create(Vector512.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector512 result = Vector512.LoadUnsafe(ref xRef); + Vector512 current; + + Vector512 nanMask = ~Vector512.Equals(result, result); + if (nanMask != Vector512.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector512.Count; + int i = Vector512.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector512.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector512.Count; + } + + // If any elements remain, handle them in one final vector. 
+ if (i != x.Length) + { + current = Vector512.LoadUnsafe(ref xRef, (uint)(x.Length - Vector512.Count)); + curIndex += Vector512.Create(x.Length - i); + + nanMask = ~Vector512.Equals(current, current); + if (nanMask != Vector512.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } +#endif + + if (Vector256.IsHardwareAccelerated && x.Length >= Vector256.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector256 resultIndex = Vector256.Create(0, 1, 2, 3, 4, 5, 6, 7); + Vector256 curIndex = resultIndex; + Vector256 increment = Vector256.Create(Vector256.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector256 result = Vector256.LoadUnsafe(ref xRef); + Vector256 current; + + Vector256 nanMask = ~Vector256.Equals(result, result); + if (nanMask != Vector256.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector256.Count; + int i = Vector256.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector256.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector256.Count; + } + + // If any elements remain, handle them in one final vector. 
+ if (i != x.Length) + { + current = Vector256.LoadUnsafe(ref xRef, (uint)(x.Length - Vector256.Count)); + curIndex += Vector256.Create(x.Length - i); + + nanMask = ~Vector256.Equals(current, current); + if (nanMask != Vector256.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } + + if (Vector128.IsHardwareAccelerated && x.Length >= Vector128.Count) + { + ref float xRef = ref MemoryMarshal.GetReference(x); + Vector128 resultIndex = Vector128.Create(0, 1, 2, 3); + Vector128 curIndex = resultIndex; + Vector128 increment = Vector128.Create(Vector128.Count); + + // Load the first vector as the initial set of results, and bail immediately + // to scalar handling if it contains any NaNs (which don't compare equally to themselves). + Vector128 result = Vector128.LoadUnsafe(ref xRef); + Vector128 current; + + Vector128 nanMask = ~Vector128.Equals(result, result); + if (nanMask != Vector128.Zero) + { + return IndexOfFirstMatch(nanMask); + } + + int oneVectorFromEnd = x.Length - Vector128.Count; + int i = Vector128.Count; + + // Aggregate additional vectors into the result as long as there's at least one full vector left to process. + while (i <= oneVectorFromEnd) + { + // Load the next vector, and early exit on NaN. + current = Vector128.LoadUnsafe(ref xRef, (uint)i); + curIndex += increment; + + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return i + IndexOfFirstMatch(nanMask); + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + + i += Vector128.Count; + } + + // If any elements remain, handle them in one final vector. 
+ if (i != x.Length) + { + curIndex += Vector128.Create(x.Length - i); + + current = Vector128.LoadUnsafe(ref xRef, (uint)(x.Length - Vector128.Count)); + + nanMask = ~Vector128.Equals(current, current); + if (nanMask != Vector128.Zero) + { + return curIndex[IndexOfFirstMatch(nanMask)]; + } + + TIndexOfMinMax.Invoke(ref result, current, ref resultIndex, curIndex); + } + + // Aggregate the lanes in the vector to create the final scalar result. + return TIndexOfMinMax.Invoke(result, resultIndex); + } + + // Scalar path used when either vectorization is not supported or the input is too small to vectorize. + float curResult = x[0]; + int curIn = 0; + if (float.IsNaN(curResult)) + { + return curIn; + } + + for (int i = 1; i < x.Length; i++) + { + float current = x[i]; + if (float.IsNaN(current)) + { + return i; + } + + curIn = TIndexOfMinMax.Invoke(ref curResult, current, curIn, i); + } + + return curIn; + } + + private static int IndexOfFirstMatch(Vector128 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } + + private static int IndexOfFirstMatch(Vector256 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } + +#if NET8_0_OR_GREATER + private static int IndexOfFirstMatch(Vector512 mask) + { + return BitOperations.TrailingZeroCount(mask.ExtractMostSignificantBits()); + } +#endif + /// Performs an element-wise operation on and writes the results to . /// Specifies the operation to perform on each element loaded from . private static void InvokeSpanIntoSpan( @@ -7629,6 +7850,26 @@ private static Vector512 IsNegative(Vector512 vector) => Vector512.LessThan(vector.AsInt32(), Vector512.Zero).AsSingle(); #endif + /// Gets whether the specified is positive. + private static bool IsPositive(float f) => float.IsPositive(f); + + /// Gets whether each specified is positive. 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector128 IsPositive(Vector128 vector) => + Vector128.GreaterThan(vector.AsInt32(), Vector128.AllBitsSet).AsSingle(); + + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector256 IsPositive(Vector256 vector) => + Vector256.GreaterThan(vector.AsInt32(), Vector256.AllBitsSet).AsSingle(); + +#if NET8_0_OR_GREATER + /// Gets whether each specified is positive. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector512 IsPositive(Vector512 vector) => + Vector512.GreaterThan(vector.AsInt32(), Vector512.AllBitsSet).AsSingle(); +#endif + /// Finds and returns the first NaN value in . /// The vector must have already been validated to contain a NaN. private static float GetFirstNaN(Vector128 vector) @@ -7637,6 +7878,14 @@ private static float GetFirstNaN(Vector128 vector) return vector.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); } + /// Finds and returns the first NaN index value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector128 vector, Vector128 index) + { + Debug.Assert(!Vector128.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector128.Equals(vector, vector)).ExtractMostSignificantBits())); + } + /// Finds and returns the first NaN value in . /// The vector must have already been validated to contain a NaN. private static float GetFirstNaN(Vector256 vector) @@ -7645,6 +7894,14 @@ private static float GetFirstNaN(Vector256 vector) return vector.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); } + /// Finds and returns the first NaN index value in . + /// The vector must have already been validated to contain a NaN. 
+ private static int GetFirstNaNIndex(Vector256 vector, Vector256 index) + { + Debug.Assert(!Vector256.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector256.Equals(vector, vector)).ExtractMostSignificantBits())); + } + #if NET8_0_OR_GREATER /// Finds and returns the first NaN value in . /// The vector must have already been validated to contain a NaN. @@ -7653,6 +7910,14 @@ private static float GetFirstNaN(Vector512 vector) Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); return vector.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); } + + /// Finds and returns the first NaN value in . + /// The vector must have already been validated to contain a NaN. + private static int GetFirstNaNIndex(Vector512 vector, Vector512 index) + { + Debug.Assert(!Vector512.EqualsAll(vector, vector), "Expected vector to contain a NaN"); + return index.GetElement(BitOperations.TrailingZeroCount((~Vector512.Equals(vector, vector)).ExtractMostSignificantBits())); + } #endif /// Gets the base 2 logarithm of . 
@@ -7824,6 +8089,514 @@ public static Vector512 Invoke(Vector512 x, Vector512 y) => #endif } + private interface IIndexOfOperator + { + static abstract int Invoke(ref float result, float current, int resultIndex, int curIndex); + static abstract int Invoke(Vector128 result, Vector128 resultIndex); + static abstract void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex); + static abstract int Invoke(Vector256 result, Vector256 resultIndex); + static abstract void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex); +#if NET8_0_OR_GREATER + static abstract int Invoke(Vector512 result, Vector512 resultIndex); + static abstract void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex); +#endif + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 greaterThanMask = Vector128.GreaterThan(max, current); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + 
greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 greaterThanMask = Vector256.GreaterThan(max, current); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 greaterThanMask = 
Vector512.GreaterThan(max, current); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (current > result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 maxIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(maxIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref maxIndex, tmpIndex); + return maxIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 max, Vector128 current, ref Vector128 maxIndex, Vector128 curIndex) + { + Vector128 maxMag = Vector128.Abs(max), currentMag = Vector128.Abs(current); + + Vector128 greaterThanMask = Vector128.GreaterThan(maxMag, currentMag); + + Vector128 equalMask = Vector128.Equals(max, current); + if (equalMask.AsInt32() != Vector128.Zero) 
+ { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanMask = Vector128.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 maxIndex) + { + // Max the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = maxIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, maxIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 max, Vector256 current, ref Vector256 maxIndex, Vector256 curIndex) + { + Vector256 maxMag = Vector256.Abs(max), currentMag = Vector256.Abs(current); + + Vector256 greaterThanMask = Vector256.GreaterThan(maxMag, currentMag); + + Vector256 equalMask = Vector256.Equals(max, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanMask = Vector256.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); 
+ } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 max, Vector512 current, ref Vector512 maxIndex, Vector512 curIndex) + { + Vector512 maxMag = Vector512.Abs(max), currentMag = Vector512.Abs(current); + Vector512 greaterThanMask = Vector512.GreaterThan(maxMag, currentMag); + + Vector512 equalMask = Vector512.Equals(max, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanMask = Vector512.LessThan(maxIndex, curIndex); + + greaterThanMask |= (negativeMask & equalMask) | (~IsNegative(max) & equalMask & lessThanMask.AsSingle()); + } + + max = ElementWiseSelect(greaterThanMask, max, current); + + maxIndex = ElementWiseSelect(greaterThanMask.AsInt32(), maxIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + /// Returns the index of MathF.Min(x, y) + private readonly struct IndexOfMinOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return 
resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 lessThanMask = Vector128.LessThan(result, current); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 lessThanMask = Vector256.LessThan(result, current); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + 
public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 lessThanMask = Vector512.LessThan(result, current); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector128 result, Vector128 resultIndex) + { + Vector128 tmpResult = Vector128.Shuffle(result, Vector128.Create(2, 3, 0, 1)); + Vector128 tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(2, 3, 0, 1)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + + tmpResult = Vector128.Shuffle(result, Vector128.Create(1, 0, 3, 2)); + 
tmpIndex = Vector128.Shuffle(resultIndex, Vector128.Create(1, 0, 3, 2)); + + Invoke(ref result, tmpResult, ref resultIndex, tmpIndex); + return resultIndex.ToScalar(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector128 result, Vector128 current, ref Vector128 resultIndex, Vector128 curIndex) + { + Vector128 minMag = Vector128.Abs(result), currentMag = Vector128.Abs(current); + + Vector128 lessThanMask = Vector128.LessThan(minMag, currentMag); + + Vector128 equalMask = Vector128.Equals(result, current); + if (equalMask.AsInt32() != Vector128.Zero) + { + Vector128 negativeMask = IsNegative(current); + Vector128 lessThanIndexMask = Vector128.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector256 result, Vector256 resultIndex) + { + // Min the upper/lower halves of the Vector256 + Vector128 resultLower = result.GetLower(); + Vector128 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector256 result, Vector256 current, ref Vector256 resultIndex, Vector256 curIndex) + { + Vector256 minMag = Vector256.Abs(result), currentMag = Vector256.Abs(current); + + Vector256 lessThanMask = Vector256.LessThan(minMag, currentMag); + + Vector256 equalMask = Vector256.Equals(result, current); + if (equalMask.AsInt32() != Vector256.Zero) + { + Vector256 negativeMask = IsNegative(current); + Vector256 lessThanIndexMask = Vector256.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & 
equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } + +#if NET8_0_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(Vector512 result, Vector512 resultIndex) + { + // Min the upper/lower halves of the Vector512 + Vector256 resultLower = result.GetLower(); + Vector256 indexLower = resultIndex.GetLower(); + + Invoke(ref resultLower, result.GetUpper(), ref indexLower, resultIndex.GetUpper()); + return Invoke(resultLower, indexLower); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static void Invoke(ref Vector512 result, Vector512 current, ref Vector512 resultIndex, Vector512 curIndex) + { + Vector512 minMag = Vector512.Abs(result), currentMag = Vector512.Abs(current); + + Vector512 lessThanMask = Vector512.LessThan(minMag, currentMag); + + Vector512 equalMask = Vector512.Equals(result, current); + if (equalMask.AsInt32() != Vector512.Zero) + { + Vector512 negativeMask = IsNegative(current); + Vector512 lessThanIndexMask = Vector512.LessThan(resultIndex, curIndex); + + lessThanMask |= (~negativeMask & equalMask) | (IsNegative(result) & equalMask & lessThanIndexMask.AsSingle()); + } + + result = ElementWiseSelect(lessThanMask, result, current); + + resultIndex = ElementWiseSelect(lessThanMask.AsInt32(), resultIndex, curIndex); + } +#endif + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMinAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + /// MathF.Max(x, y) 
private readonly struct MaxPropagateNaNOperator : IBinaryOperator { @@ -9231,6 +10004,62 @@ public static Vector512 Invoke(Vector512 x) #endif }

    /// <summary>Per-element select: returns <paramref name="left"/> where <paramref name="mask"/> is all-ones, else <paramref name="right"/>.</summary>
    /// <remarks>
    /// blendvps/blendvpd key off each element's most significant bit and pick "right" where it is set,
    /// so the mask is inverted to preserve ConditionalSelect's "left where set" contract.
    /// Masks are assumed to be full comparison masks (all-ones or all-zeros per element).
    /// </remarks>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector128<float> ElementWiseSelect(Vector128<float> mask, Vector128<float> left, Vector128<float> right) =>
        Sse41.IsSupported
            ? Sse41.BlendVariable(left, right, ~mask)
            : Vector128.ConditionalSelect(mask, left, right);

    /// <summary>Per-element select for Vector128 of int; see the float overload for semantics.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector128<int> ElementWiseSelect(Vector128<int> mask, Vector128<int> left, Vector128<int> right) =>
        Sse41.IsSupported
            ? Sse41.BlendVariable(left, right, ~mask)
            : Vector128.ConditionalSelect(mask, left, right);

    /// <summary>Per-element select for Vector256 of float; see the Vector128 overload for semantics.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector256<float> ElementWiseSelect(Vector256<float> mask, Vector256<float> left, Vector256<float> right) =>
        Avx2.IsSupported
            ? Avx2.BlendVariable(left, right, ~mask)
            : Vector256.ConditionalSelect(mask, left, right);

    /// <summary>Per-element select for Vector256 of int; see the Vector128 overload for semantics.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector256<int> ElementWiseSelect(Vector256<int> mask, Vector256<int> left, Vector256<int> right) =>
        Avx2.IsSupported
            ? Avx2.BlendVariable(left, right, ~mask)
            : Vector256.ConditionalSelect(mask, left, right);

#if NET8_0_OR_GREATER
    /// <summary>Per-element select for Vector512 of float; see the Vector128 overload for semantics.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector512<float> ElementWiseSelect(Vector512<float> mask, Vector512<float> left, Vector512<float> right) =>
        Avx512F.IsSupported
            ? Avx512F.BlendVariable(left, right, ~mask)
            : Vector512.ConditionalSelect(mask, left, right);

    /// <summary>Per-element select for Vector512 of int; see the Vector128 overload for semantics.</summary>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector512<int> ElementWiseSelect(Vector512<int> mask, Vector512<int> left, Vector512<int> right) =>
        Avx512F.IsSupported
            ? Avx512F.BlendVariable(left, right, ~mask)
            : Vector512.ConditionalSelect(mask, left, right);
#endif

    /// <summary>1f / (1f + MathF.Exp(-x))</summary>
    private readonly struct SigmoidOperator :
IUnaryOperator {

diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
index 5e6e9ac6252e3c..4dc8ffcf56d82c 100644
--- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
+++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
@@ -311,6 +311,95 @@ private static float MinMaxCore(ReadOnlySpan x, TMinMaxO
            return result;
        }

        // Lane indices [0..7] used to seed the index vector; supports vector widths of up to
        // 8 floats, hence the Vector<int>.Count <= 8 guard below.
        private static readonly int[] s_0through7 = [0, 1, 2, 3, 4, 5, 6, 7];

        /// <summary>
        /// Finds the index of the min/max (or min/max magnitude) element of <paramref name="x"/>,
        /// as selected by <typeparamref name="TIndexOfMinMaxOperator"/>.
        /// This matches the IEEE 754:2019 `maximum`/`minimum` functions.
        /// It propagates NaN inputs back to the caller (returning the index of the first NaN) and
        /// otherwise returns the index of the operator's preferred element.
        /// It treats +0 as greater than -0 as per the specification.
        /// </summary>
        private static int IndexOfMinMaxCore<TIndexOfMinMaxOperator>(ReadOnlySpan<float> x, TIndexOfMinMaxOperator op = default)
            where TIndexOfMinMaxOperator : struct, IIndexOfOperator
        {
            int result;
            int i = 0;

            if (Vector.IsHardwareAccelerated && Vector<int>.Count <= 8 && x.Length >= Vector<float>.Count)
            {
                ref float xRef = ref MemoryMarshal.GetReference(x);

                // resultIndex tracks, per lane, the source index of the current candidate;
                // curIndex tracks the indices of the vector being merged in.
                Vector<int> resultIndex = new Vector<int>(s_0through7);
                Vector<int> curIndex = resultIndex;
                Vector<int> increment = new Vector<int>(Vector<int>.Count);

                // Load the first vector as the initial set of results, and bail immediately
                // to scalar handling if it contains any NaNs (which don't compare equally to themselves).
                Vector<float> resultVector = AsVector(ref xRef, 0), current;
                if (Vector.EqualsAll(resultVector, resultVector))
                {
                    int oneVectorFromEnd = x.Length - Vector<float>.Count;
                    i = Vector<float>.Count;

                    // Aggregate additional vectors into the result as long as there's at least one full vector left to process.
                    while (i <= oneVectorFromEnd)
                    {
                        // Load the next vector, and early exit on NaN.
                        current = AsVector(ref xRef, i);
                        curIndex = Vector.Add(curIndex, increment);

                        if (!Vector.EqualsAll(current, current))
                        {
                            goto Scalar;
                        }

                        op.Invoke(ref resultVector, current, ref resultIndex, curIndex);
                        i += Vector<float>.Count;
                    }

                    // If any elements remain, handle them in one final vector that ends at the last
                    // element. It overlaps already-processed elements, but those re-merge with their
                    // true indices, which is harmless because the merge is idempotent.
                    if (i != x.Length)
                    {
                        // Shift curIndex so its lanes hold the true indices [x.Length - Count, x.Length).
                        curIndex = Vector.Add(curIndex, new Vector<int>(x.Length - i));

                        current = AsVector(ref xRef, x.Length - Vector<float>.Count);
                        if (!Vector.EqualsAll(current, current))
                        {
                            goto Scalar;
                        }

                        op.Invoke(ref resultVector, current, ref resultIndex, curIndex);
                    }

                    // Horizontally reduce the per-lane candidates to the single winning index.
                    result = op.Invoke(resultVector, resultIndex);

                    return result;
                }
            }

            // Scalar path used when either vectorization is not supported, the input is too small to vectorize,
            // or a NaN is encountered. On a mid-loop NaN bail, i points at the vector containing the NaN;
            // every earlier element was NaN-free, so scanning forward from i still returns the first NaN's index.
            Scalar:
            float curResult = x[i];
            int curIn = i;
            if (float.IsNaN(curResult))
            {
                return curIn;
            }

            for (; i < x.Length; i++)
            {
                float current = x[i];
                if (float.IsNaN(current))
                {
                    return i;
                }

                curIn = op.Invoke(ref curResult, current, curIn, i);
            }

            return curIn;
        }

        /// Performs an element-wise operation on x and writes the results to destination.
        /// Specifies the operation to perform on each element loaded from x.
        private static unsafe void InvokeSpanIntoSpan(
@@ -2280,6 +2369,20 @@ private static ref Vector AsVector(ref float start, nuint offset) =>
            ref Unsafe.As<float, Vector<float>>(
                ref Unsafe.Add(ref start, (nint)(offset)));

        /// <summary>Loads a <see cref="Vector{Int32}"/> that begins at the specified <paramref name="offset"/> from <paramref name="start"/>.</summary>
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static ref Vector<int> AsVector(ref int start, int offset) =>
            ref Unsafe.As<int, Vector<int>>(
                ref Unsafe.Add(ref start, offset));

        /// <summary>Gets whether the specified <see cref="float"/> is positive (sign bit clear).</summary>
        private static bool IsPositive(float f) => !IsNegative(f);

        /// <summary>Gets whether each specified <see cref="float"/> is positive.</summary>
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Vector IsPositive(Vector vector) => + ((Vector)Vector.GreaterThan(((Vector)vector), Vector.Zero)); + /// Gets whether the specified is negative. private static unsafe bool IsNegative(float f) => *(int*)&f < 0; @@ -2349,6 +2452,283 @@ public Vector Invoke(Vector x, Vector y) public Vector Invoke(Vector x, Vector y) => x / y; } + private interface IIndexOfOperator + { + int Invoke(ref float result, float current, int resultIndex, int curIndex); + int Invoke(Vector result, Vector resultIndex); + void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex); + } + + /// Returns the index of MathF.Max(x, y) + private readonly struct IndexOfMaxOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMax = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (result[i] == curMax && IsNegative(curMax) && !IsNegative(result[i])) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + else if (result[i] > curMax) + { + curMax = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector lessThanMask = Vector.GreaterThan(result, current); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public 
int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (current > result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + } + + private readonly struct IndexOfMaxMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMaxAbs = MathF.Abs(result); + float currentAbs = MathF.Abs(current); + + if (curMaxAbs == currentAbs) + { + if (IsNegative(result) && !IsNegative(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs > curMaxAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector maxIndex) + { + float curMax = result[0]; + int curIn = maxIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (MathF.Abs(result[i]) == MathF.Abs(curMax) && IsNegative(curMax) && !IsNegative(result[i])) + { + curMax = result[i]; + curIn = maxIndex[i]; + } + else if (MathF.Abs(result[i]) > MathF.Abs(curMax)) + { + curMax = result[i]; + curIn = maxIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector maxMag = Vector.Abs(result), currentMag = Vector.Abs(current); + + Vector lessThanMask = Vector.GreaterThan(maxMag, currentMag); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)~negativeMask & equalMask) | ((Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = 
Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + } + + private readonly struct IndexOfMinOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + if (result == current) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (current < result) + { + result = current; + return curIndex; + } + + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (result[i] == curMin && IsPositive(curMin) && !IsPositive(result[i])) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + else if (result[i] < curMin) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector lessThanMask = Vector.LessThan(result, current); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + } + + private readonly struct IndexOfMinMagnitudeOperator : IIndexOfOperator + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(ref float result, float current, int resultIndex, int curIndex) + { + float curMinAbs = 
MathF.Abs(result); + float currentAbs = MathF.Abs(current); + if (curMinAbs == currentAbs) + { + if (IsPositive(result) && !IsPositive(current)) + { + result = current; + return curIndex; + } + } + else if (currentAbs < curMinAbs) + { + result = current; + return curIndex; + } + + return resultIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int Invoke(Vector result, Vector resultIndex) + { + float curMin = result[0]; + int curIn = resultIndex[0]; + for (int i = 1; i < Vector.Count; i++) + { + if (MathF.Abs(result[i]) == MathF.Abs(curMin) && IsPositive(curMin) && !IsPositive(result[i])) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + else if (MathF.Abs(result[i]) < MathF.Abs(curMin)) + { + curMin = result[i]; + curIn = resultIndex[i]; + } + } + + return curIn; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Invoke(ref Vector result, Vector current, ref Vector resultIndex, Vector curIndex) + { + Vector minMag = Vector.Abs(result), currentMag = Vector.Abs(current); + + Vector lessThanMask = Vector.LessThan(minMag, currentMag); + + Vector equalMask = Vector.Equals(result, current); + + if (equalMask != Vector.Zero) + { + Vector negativeMask = IsNegative(current); + Vector lessThanIndexMask = Vector.LessThan(resultIndex, curIndex); + + lessThanMask |= ((Vector)negativeMask & equalMask) | (~(Vector)IsNegative(result) & equalMask & lessThanIndexMask); + } + + result = Vector.ConditionalSelect(lessThanMask, result, current); + + resultIndex = Vector.ConditionalSelect(lessThanMask, resultIndex, curIndex); + } + } + /// MathF.Max(x, y) (but without guaranteed NaN propagation) private readonly struct MaxOperator : IBinaryOperator {