Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enumerable.TryGetNonEnumeratedCount (Implements #27183) #48239

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,6 @@ namespace System.Collections.Generic
/// </summary>
internal static partial class EnumerableHelpers
{
/// <summary>
/// Tries to get the count of the enumerable cheaply.
/// </summary>
/// <typeparam name="T">The element type of the source enumerable.</typeparam>
/// <param name="source">The enumerable to count.</param>
/// <param name="count">The count of the enumerable, if it could be obtained cheaply.</param>
/// <returns><c>true</c> if the enumerable could be counted cheaply; otherwise, <c>false</c>.</returns>
internal static bool TryGetCount<T>(IEnumerable<T> source, out int count)
{
Debug.Assert(source != null);

if (source is ICollection<T> collection)
{
count = collection.Count;
return true;
}

if (source is IIListProvider<T> provider)
{
return (count = provider.GetCount(onlyIfCheap: true)) >= 0;
}

count = -1;
return false;
}

/// <summary>
/// Copies items from an enumerable to an array.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public void Reserve(int count)
public bool ReserveOrAdd(IEnumerable<T> items)
{
int itemCount;
if (EnumerableHelpers.TryGetCount(items, out itemCount))
if (System.Linq.Enumerable.TryGetNonEnumeratedCount(items, out itemCount))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: using System.Linq;

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to avoid bringing all the enumerable methods into scope since this is a System.Collections namespace.

{
if (itemCount > 0)
{
Expand Down
19 changes: 10 additions & 9 deletions src/libraries/System.Linq.Queryable/tests/Queryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,17 @@ public static void MatchSequencePattern()
typeof(Enumerable),
typeof(Queryable),
new [] {
"ToLookup",
"ToDictionary",
"ToArray",
"AsEnumerable",
"ToList",
nameof(Enumerable.ToLookup),
nameof(Enumerable.ToDictionary),
nameof(Enumerable.ToArray),
nameof(Enumerable.AsEnumerable),
nameof(Enumerable.ToList),
nameof(Enumerable.Append),
nameof(Enumerable.Prepend),
nameof(Enumerable.ToHashSet),
nameof(Enumerable.TryGetNonEnumeratedCount),
"Fold",
"LeftJoin",
"Append",
"Prepend",
"ToHashSet"
}
);

Expand All @@ -140,7 +141,7 @@ public static void MatchSequencePattern()
typeof(Queryable),
typeof(Enumerable),
new [] {
"AsQueryable"
nameof(Queryable.AsQueryable)
}
);

Expand Down
2 changes: 1 addition & 1 deletion src/libraries/System.Linq/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
<PropertyGroup>
<StrongNameKeyId>Microsoft</StrongNameKeyId>
</PropertyGroup>
</Project>
</Project>
1 change: 1 addition & 0 deletions src/libraries/System.Linq/ref/System.Linq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ public static System.Collections.Generic.IEnumerable<
public static System.Linq.ILookup<TKey, TSource> ToLookup<TSource, TKey>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector) { throw null; }
public static System.Linq.ILookup<TKey, TElement> ToLookup<TSource, TKey, TElement>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Func<TSource, TElement> elementSelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer) { throw null; }
public static bool TryGetNonEnumeratedCount<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, out int count) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Union<TSource>(this System.Collections.Generic.IEnumerable<TSource> first, System.Collections.Generic.IEnumerable<TSource> second, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> Where<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
Expand Down
4 changes: 2 additions & 2 deletions src/libraries/System.Linq/src/System/Linq/Concat.SpeedOpt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ private sealed partial class Concat2Iterator<TSource> : ConcatIterator<TSource>
public override int GetCount(bool onlyIfCheap)
{
int firstCount, secondCount;
if (!EnumerableHelpers.TryGetCount(_first, out firstCount))
if (!_first.TryGetNonEnumeratedCount(out firstCount))
{
if (onlyIfCheap)
{
Expand All @@ -23,7 +23,7 @@ public override int GetCount(bool onlyIfCheap)
firstCount = _first.Count();
}

if (!EnumerableHelpers.TryGetCount(_second, out secondCount))
if (!_second.TryGetNonEnumeratedCount(out secondCount))
{
if (onlyIfCheap)
{
Expand Down
53 changes: 53 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/Count.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,59 @@ public static int Count<TSource>(this IEnumerable<TSource> source, Func<TSource,
return count;
}

/// <summary>
/// Attempts to determine the number of elements in a sequence without forcing an enumeration.
/// </summary>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <param name="source">A sequence that contains elements to be counted.</param>
/// <param name="count">
/// When this method returns, contains the count of <paramref name="source" /> if successful,
/// or zero if the method failed to determine the count.</param>
/// <returns>
/// <see langword="true" /> if the count of <paramref name="source"/> can be determined without enumeration;
/// otherwise, <see langword="false" />.
/// </returns>
/// <remarks>
/// The method performs a series of type tests, identifying common subtypes whose
/// count can be determined without enumerating; this includes <see cref="ICollection{T}"/>,
/// <see cref="ICollection"/> as well as internal types used in the LINQ implementation.
///
/// The method is typically a constant-time operation, but ultimately this depends on the complexity
/// characteristics of the underlying collection implementation.
/// </remarks>
public static bool TryGetNonEnumeratedCount<TSource>(this IEnumerable<TSource> source, out int count)
Copy link
Member

@alrz alrz Feb 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to be generic? if it does, I think a non-generic overload would also make sense.
(roslyn is considering using this method as part of list pattern lowering, we probably don't want to skip it for IEnumerable see https://github.com/dotnet/csharplang/blob/master/meetings/2021/LDM-2021-02-03.md).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do that, however it becomes more difficult to check for generic interfaces, which would require some form of reflection.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which would require some form of reflection.

That's definitely a no-go because the point of using this is performance. Perhaps those generic interfaces need a non-generic base with Count prop. IMO TryGetNonEnumeratedCount shouldn't care about TSource.

{
if (source == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
}

if (source is ICollection<TSource> collectionoft)
{
count = collectionoft.Count;
return true;
}

if (source is IIListProvider<TSource> listProv)
{
int c = listProv.GetCount(onlyIfCheap: true);
if (c >= 0)
{
count = c;
return true;
}
}

if (source is ICollection collection)
{
count = collection.Count;
return true;
}

count = 0;
return false;
}

public static long LongCount<TSource>(this IEnumerable<TSource> source)
{
if (source == null)
Expand Down
3 changes: 2 additions & 1 deletion src/libraries/System.Linq/tests/ConsistencyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ private static IEnumerable<string> GetExcludedMethods()
nameof(Enumerable.ToArray),
nameof(Enumerable.AsEnumerable),
nameof(Enumerable.ToList),
nameof(Enumerable.ToHashSet),
nameof(Enumerable.TryGetNonEnumeratedCount),
"Fold",
"LeftJoin",
"ToHashSet"
};

return result;
Expand Down
84 changes: 84 additions & 0 deletions src/libraries/System.Linq/tests/CountTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,89 @@ public void NullPredicate_ThrowsArgumentNullException()
Func<int, bool> predicate = null;
AssertExtensions.Throws<ArgumentNullException>("predicate", () => Enumerable.Range(0, 3).Count(predicate));
}

[Fact]
public void NonEnumeratingCount_NullSource_ThrowsArgumentNullException()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IEnumerable<int>)null).TryGetNonEnumeratedCount(out _));
}

[Theory]
[MemberData(nameof(NonEnumeratingCount_SupportedEnumerables))]
public void NonEnumeratingCount_SupportedEnumerables_ShouldReturnExpectedCount<T>(int expectedCount, IEnumerable<T> source)
{
Assert.True(source.TryGetNonEnumeratedCount(out int actualCount));
Assert.Equal(expectedCount, actualCount);
}

[Theory]
[MemberData(nameof(NonEnumeratingCount_UnsupportedEnumerables))]
public void NonEnumeratingCount_UnsupportedEnumerables_ShouldReturnFalse<T>(IEnumerable<T> source)
{
Assert.False(source.TryGetNonEnumeratedCount(out int actualCount));
Assert.Equal(0, actualCount);
}

[Fact]
public void NonEnumeratingCount_ShouldNotEnumerateSource()
{
bool isEnumerated = false;
Assert.False(Source().TryGetNonEnumeratedCount(out int count));
Assert.Equal(0, count);
Assert.False(isEnumerated);

IEnumerable<int> Source()
{
isEnumerated = true;
yield return 42;
}
}

public static IEnumerable<object[]> NonEnumeratingCount_SupportedEnumerables()
{
yield return WrapArgs(4, new int[]{ 1, 2, 3, 4 });
yield return WrapArgs(4, new List<int>(new int[] { 1, 2, 3, 4 }));
yield return WrapArgs(4, new Stack<int>(new int[] { 1, 2, 3, 4 }));

yield return WrapArgs(0, Enumerable.Empty<string>());

if (PlatformDetection.IsSpeedOptimized)
{
yield return WrapArgs(100, Enumerable.Range(1, 100));
yield return WrapArgs(80, Enumerable.Repeat(1, 80));
yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1));
yield return WrapArgs(4, new int[] { 1, 2, 3, 4 }.Select(x => x + 1));
yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
yield return WrapArgs(7, Enumerable.Range(1, 20).ToLookup(x => x % 7));
yield return WrapArgs(20, Enumerable.Range(1, 20).Reverse());
yield return WrapArgs(20, Enumerable.Range(1, 20).OrderBy(x => -x));
yield return WrapArgs(20, Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
}

static object[] WrapArgs<T>(int expectedCount, IEnumerable<T> source) => new object[] { expectedCount, source };
}

public static IEnumerable<object[]> NonEnumeratingCount_UnsupportedEnumerables()
{
yield return WrapArgs(Enumerable.Range(1, 100).Where(x => x % 2 == 0));
yield return WrapArgs(Enumerable.Range(1, 100).GroupBy(x => x % 2 == 0));
yield return WrapArgs(new Stack<int>(new int[] { 1, 2, 3, 4 }).Select(x => x + 1));
yield return WrapArgs(Enumerable.Range(1, 100).Distinct());

if (!PlatformDetection.IsSpeedOptimized)
{
yield return WrapArgs(Enumerable.Range(1, 100));
yield return WrapArgs(Enumerable.Repeat(1, 80));
yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1));
yield return WrapArgs(new int[] { 1, 2, 3, 4 }.Select(x => x + 1));
yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1));
yield return WrapArgs(Enumerable.Range(1, 20).ToLookup(x => x % 7));
yield return WrapArgs(Enumerable.Range(1, 20).Reverse());
yield return WrapArgs(Enumerable.Range(1, 20).OrderBy(x => -x));
yield return WrapArgs(Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10)));
}

static object[] WrapArgs<T>(IEnumerable<T> source) => new object[] { source };
}
}
}