Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release/9.0-staging] Fix LINQ handling of iterator.Take(...).Last(...) #112714

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,12 @@ public override Iterator<TSource> Take(int count)
{
if (_source is Iterator<TSource> iterator &&
iterator.GetCount(onlyIfCheap: true) is int count &&
count >= _minIndexInclusive)
count > _minIndexInclusive)
{
return !HasLimit ?
// If there's no upper bound, or if there are fewer items in the list
// than the upper bound allows, just return the last element of the list.
// Otherwise, get the element at the upper bound.
return (uint)count <= (uint)_maxIndexInclusive ?
iterator.TryGetLast(out found) :
iterator.TryGetElementAt(_maxIndexInclusive, out found);
}
Expand Down
6 changes: 3 additions & 3 deletions src/libraries/System.Linq/tests/AggregateByTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ public class AggregateByTests : EnumerableTests
[Fact]
public void Empty()
{
Assert.All(IdentityTransforms<int>(), transform =>
Assert.All(CreateSources<int>([]), source =>
{
Assert.Equal(Enumerable.Empty<KeyValuePair<int, int>>(), transform(Enumerable.Empty<int>()).AggregateBy(i => i, i => i, (a, i) => a + i));
Assert.Equal(Enumerable.Empty<KeyValuePair<int, int>>(), transform(Enumerable.Empty<int>()).AggregateBy(i => i, 0, (a, i) => a + i));
Assert.Equal([], source.AggregateBy(i => i, i => i, (a, i) => a + i));
Assert.Equal([], source.AggregateBy(i => i, 0, (a, i) => a + i));
});
}

Expand Down
20 changes: 5 additions & 15 deletions src/libraries/System.Linq/tests/ChunkTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ public void ChunkSourceLazily()
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2, -12345})]
public void ChunkSourceRepeatCalls(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

Assert.Equal(source.Chunk(3), source.Chunk(3));
});
}
Expand All @@ -54,10 +52,8 @@ public void ChunkSourceRepeatCalls(int[] array)
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2, -12345})]
public void ChunkSourceEvenly(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
chunks.MoveNext();
Assert.Equal(new[] { 9999, 0, 888 }, chunks.Current);
Expand All @@ -73,10 +69,8 @@ public void ChunkSourceEvenly(int[] array)
[InlineData(new[] {9999, 0, 888, -1, 66, -777, 1, 2})]
public void ChunkSourceUnevenly(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
chunks.MoveNext();
Assert.Equal(new[] { 9999, 0, 888 }, chunks.Current);
Expand All @@ -92,10 +86,8 @@ public void ChunkSourceUnevenly(int[] array)
[InlineData(new[] {9999, 0})]
public void ChunkSourceSmallerThanMaxSize(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
chunks.MoveNext();
Assert.Equal(new[] { 9999, 0 }, chunks.Current);
Expand All @@ -107,10 +99,8 @@ public void ChunkSourceSmallerThanMaxSize(int[] array)
[InlineData(new int[0])]
public void EmptySourceYieldsNoChunks(int[] array)
{
Assert.All(IdentityTransforms<int>(), t =>
Assert.All(CreateSources(array), source =>
{
IEnumerable<int> source = t(array);

using IEnumerator<int[]> chunks = source.Chunk(3).GetEnumerator();
Assert.False(chunks.MoveNext());
});
Expand Down
29 changes: 6 additions & 23 deletions src/libraries/System.Linq/tests/ConcatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ private static void SameResultsWithQueryAndRepeatCallsWorker<T>(IEnumerable<T> f
first = from item in first select item;
second = from item in second select item;

VerifyEqualsWorker(first.Concat(second), first.Concat(second));
VerifyEqualsWorker(second.Concat(first), second.Concat(first));
Assert.Equal(first.Concat(second), first.Concat(second));
Assert.Equal(second.Concat(first), second.Concat(first));
}

[Theory]
Expand All @@ -41,8 +41,8 @@ private static void SameResultsWithQueryAndRepeatCallsWorker<T>(IEnumerable<T> f
[InlineData(new int[] { 2, 3, 5, 9 }, new int[] { 8, 10 }, new int[] { 2, 3, 5, 9, 8, 10 })] // Neither side is empty
public void PossiblyEmptyInputs(IEnumerable<int> first, IEnumerable<int> second, IEnumerable<int> expected)
{
VerifyEqualsWorker(expected, first.Concat(second));
VerifyEqualsWorker(expected.Skip(first.Count()).Concat(expected.Take(first.Count())), second.Concat(first)); // Swap the inputs around
Assert.Equal(expected, first.Concat(second));
Assert.Equal(expected.Skip(first.Count()).Concat(expected.Take(first.Count())), second.Concat(first)); // Swap the inputs around
}

[Fact]
Expand Down Expand Up @@ -80,7 +80,7 @@ public void SecondNull()
public void VerifyEquals(IEnumerable<int> expected, IEnumerable<int> actual)
{
// workaround: xUnit type inference doesn't work if the input type is not T (like IEnumerable<T>)
VerifyEqualsWorker(expected, actual);
Assert.Equal(expected, actual);
}

[Theory]
Expand Down Expand Up @@ -133,23 +133,6 @@ public void First_Last_ElementAt(IEnumerable<int> _, IEnumerable<int> actual)
}
}

private static void VerifyEqualsWorker<T>(IEnumerable<T> expected, IEnumerable<T> actual)
{
// Returns a list of functions that, when applied to enumerable, should return
// another one that has equivalent contents.
var identityTransforms = IdentityTransforms<T>();

// We run the transforms N^2 times, by testing all transforms
// of expected against all transforms of actual.
foreach (var outTransform in identityTransforms)
{
foreach (var inTransform in identityTransforms)
{
Assert.Equal(outTransform(expected), inTransform(actual));
}
}
}

public static IEnumerable<object[]> ArraySourcesData() => GenerateSourcesData(outerTransform: e => e.ToArray());

public static IEnumerable<object[]> SelectArraySourcesData() => GenerateSourcesData(outerTransform: e => e.Select(i => i).ToArray());
Expand Down Expand Up @@ -292,7 +275,7 @@ public void ManyConcats(IEnumerable<IEnumerable<int>> sources)
}

Assert.Equal(sources.Sum(s => s.Count()), concatee.Count());
VerifyEqualsWorker(sources.SelectMany(s => s), concatee);
Assert.Equal(sources.SelectMany(s => s), concatee);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/libraries/System.Linq/tests/CountTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ public void RunOnce<T>(int count, IEnumerable<T> enumerable)

private static IEnumerable<object[]> EnumerateCollectionTypesAndCounts<T>(int count, IEnumerable<T> enumerable)
{
foreach (var transform in IdentityTransforms<T>())
foreach (IEnumerable<T> source in CreateSources(enumerable))
{
yield return new object[] { count, transform(enumerable) };
yield return [count, source];
}
}

Expand Down
91 changes: 75 additions & 16 deletions src/libraries/System.Linq/tests/EnumerableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using Xunit;
using Xunit.Sdk;

namespace System.Linq.Tests
{
Expand Down Expand Up @@ -243,6 +244,7 @@ protected static IEnumerable<T> FlipIsCollection<T>(IEnumerable<T> source)
{
return source is ICollection<T> ? ForceNotCollection(source) : new List<T>(source);
}

protected static T[] Repeat<T>(Func<int, T> factory, int count)
{
T[] results = new T[count];
Expand Down Expand Up @@ -316,26 +318,83 @@ protected static IEnumerable<IEnumerable<T>> CreateSources<T>(IEnumerable<T> sou
}
}

protected static List<Func<IEnumerable<T>, IEnumerable<T>>> IdentityTransforms<T>()
protected static IEnumerable<Func<IEnumerable<T>, IEnumerable<T>>> IdentityTransforms<T>()
{
// All of these transforms should take an enumerable and produce
// another enumerable with the same contents.
return new List<Func<IEnumerable<T>, IEnumerable<T>>>
// Various collection types all representing the same source.
List<Func<IEnumerable<T>, IEnumerable<T>>> sources =
[
e => e, // original
e => e.ToArray(), // T[]
e => e.ToList(), // List<T>
e => new ReadOnlyCollection<T>(e.ToArray()), // IList<T> that's not List<T>/T[]
e => new TestCollection<T>(e.ToArray()), // ICollection<T> that's not IList<T>
e => new TestReadOnlyCollection<T>(e.ToArray()), // IReadOnlyCollection<T> that's not ICollection<T>
e => ForceNotCollection(e), // IEnumerable<T> with no other interfaces
];
if (typeof(T) == typeof(char))
{
e => e,
e => e.ToArray(),
e => e.ToList(),
e => e.ToList().Take(int.MaxValue),
sources.Add(e => (IEnumerable<T>)(object)string.Concat((IEnumerable<char>)(object)e)); // string
}

// Various transforms that all yield the same elements as the source.
List<Func<IEnumerable<T>, IEnumerable<T>>> transforms =
[
// Append
e =>
{
T[] values = e.ToArray();
return values.Length == 0 ? [] : values[0..^1].Append(values[^1]);
},

// Concat
e => e.Concat(ForceNotCollection<T>([])),
e => ForceNotCollection<T>([]).Concat(e),

// Prepend
e =>
{
T[] values = e.ToArray();
return values.Length == 0 ? [] : values[1..].Prepend(values[0]);
},

// Reverse
e => e.Reverse().Reverse(),

// Select
e => e.Select(i => i),
e => e.Select(i => i).Take(int.MaxValue),
e => e.Select(i => i).Where(i => true),

// SelectMany
e => e.SelectMany<T, T>(i => [i]),

// Take
e => e.Take(int.MaxValue),
e => e.TakeLast(int.MaxValue),
e => e.TakeWhile(i => true),

// Skip
e => e.SkipWhile(i => false),

// Where
e => e.Where(i => true),
e => e.Concat(Array.Empty<T>()),
e => e.Concat(ForceNotCollection(Array.Empty<T>())),
e => ForceNotCollection(e),
e => ForceNotCollection(e).Skip(0),
e => new ReadOnlyCollection<T>(e.ToArray()),
};
];

foreach (Func<IEnumerable<T>, IEnumerable<T>> source in sources)
{
// Yield the source itself.
yield return source;

foreach (Func<IEnumerable<T>, IEnumerable<T>> transform in transforms)
{
// Yield a single transform on the source
yield return e => transform(source(e));

foreach (Func<IEnumerable<T>, IEnumerable<T>> transform2 in transforms)
{
// Yield a second transform on the first transform on the source.
yield return e => transform2(transform(source(e)));
}
}
}
}

protected sealed class DelegateIterator<TSource> : IEnumerable<TSource>, IEnumerator<TSource>
Expand Down
34 changes: 13 additions & 21 deletions src/libraries/System.Linq/tests/SelectManyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -372,31 +372,23 @@ public void ForcedToEnumeratorDoesntEnumerateIndexedResultSel()
Assert.False(en is not null && en.MoveNext());
}

[Theory]
[MemberData(nameof(ParameterizedTestsData))]
public void ParameterizedTests(IEnumerable<int> source, Func<int, IEnumerable<int>> selector)
[Fact]
public void ParameterizedTests()
{
Assert.All(CreateSources(source), source =>
for (int i = 1; i <= 20; i++)
{
var expected = source.Select(i => selector(i)).Aggregate((l, r) => l.Concat(r));
var actual = source.SelectMany(selector);
Assert.All(CreateSources(Enumerable.Range(1, i)), source =>
{
Func<int, IEnumerable<int>> selector = n => Enumerable.Range(i, n);

Assert.Equal(expected, actual);
Assert.Equal(expected.Count(), actual.Count()); // SelectMany may employ an optimized Count implementation.
Assert.Equal(expected.ToArray(), actual.ToArray());
Assert.Equal(expected.ToList(), actual.ToList());
});
}
var expected = source.Select(i => selector(i)).Aggregate((l, r) => l.Concat(r)).ToArray();
var actual = source.SelectMany(selector);

public static IEnumerable<object[]> ParameterizedTestsData()
{
foreach (Func<IEnumerable<int>, IEnumerable<int>> transform in IdentityTransforms<int>())
{
for (int i = 1; i <= 20; i++)
{
Func<int, IEnumerable<int>> selector = n => transform(Enumerable.Range(i, n));
yield return new object[] { Enumerable.Range(1, i), selector };
}
Assert.Equal(expected, actual);
Assert.Equal(expected.Length, actual.Count()); // SelectMany may employ an optimized Count implementation.
Assert.Equal(expected, actual.ToArray());
Assert.Equal(expected, actual.ToList());
});
}
}

Expand Down
Loading
Loading