diff --git a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs
index 44aae76545c..ab750c726a8 100644
--- a/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs
+++ b/src/EFCore.Relational/Query/Pipeline/RelationalShapedQueryOptimizingExpressionVisitors.cs
@@ -51,7 +51,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
var collectionId = _collectionId++;
var selectExpression = (SelectExpression)collectionShaperExpression.Projection.QueryExpression;
// Do pushdown beforehand so it updates all pending collections first
- if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null)
+ if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null || selectExpression.IsSetOperation)
{
selectExpression.PushdownIntoSubquery();
}
diff --git a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs
index e170d8531a0..a555efe1ede 100644
--- a/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs
+++ b/src/EFCore.Relational/Query/Pipeline/SqlExpressions/SelectExpression.cs
@@ -842,10 +842,12 @@ private SqlBinaryExpression ValidateKeyComparison(SelectExpression inner, SqlBin
return null;
}
+ // We treat a set operation as a transparent wrapper over its left operand (the ColumnExpression projection mappings
+ // found on a set operation SelectExpression are actually those of its left operand).
private bool ContainsTableReference(TableExpressionBase table)
- {
- return _tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table));
- }
+ => IsSetOperation
+ ? ((SelectExpression)Tables[0]).ContainsTableReference(table)
+ : Tables.Any(te => ReferenceEquals(te is JoinExpressionBase jeb ? jeb.Table : te, table));
public void AddInnerJoin(SelectExpression innerSelectExpression, SqlExpression joinPredicate, Type transparentIdentifierType)
{
diff --git a/src/EFCore/Properties/CoreStrings.Designer.cs b/src/EFCore/Properties/CoreStrings.Designer.cs
index a6a242d418d..52c1179dff9 100644
--- a/src/EFCore/Properties/CoreStrings.Designer.cs
+++ b/src/EFCore/Properties/CoreStrings.Designer.cs
@@ -2148,6 +2148,12 @@ public static string UnableToDiscriminate([CanBeNull] object entityType, [CanBeN
GetString("UnableToDiscriminate", nameof(entityType), nameof(discriminator)),
entityType, discriminator);
+ ///
+ /// When performing a set operation, both operands must have the same Include operations.
+ ///
+ public static string SetOperationWithDifferentIncludesInOperands
+ => GetString("SetOperationWithDifferentIncludesInOperands");
+
private static string GetString(string name, params string[] formatterNames)
{
var value = _resourceManager.GetString(name);
diff --git a/src/EFCore/Properties/CoreStrings.resx b/src/EFCore/Properties/CoreStrings.resx
index cd4066e6ad0..1e2ccbe89ab 100644
--- a/src/EFCore/Properties/CoreStrings.resx
+++ b/src/EFCore/Properties/CoreStrings.resx
@@ -1186,4 +1186,7 @@
Unable to materialize entity of type '{entityType}'. No discriminators matched '{discriminator}'.
+
+ When performing a set operation, both operands must have the same Include operations.
+
diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs
index fa633a0e870..202d3153f04 100644
--- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs
+++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpandingVisitor_MethodCall.cs
@@ -7,6 +7,7 @@
using System.Linq.Expressions;
using System.Reflection;
using System.Xml;
+using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
@@ -35,71 +36,85 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
: methodCallExpression.Update(methodCallExpression.Object, new[] { newSource, methodCallExpression.Arguments[1] });
}
- switch (methodCallExpression.Method.Name)
+ if (methodCallExpression.Method.DeclaringType == typeof(Queryable)
+ || methodCallExpression.Method.DeclaringType == typeof(QueryableExtensions)
+ || methodCallExpression.Method.DeclaringType == typeof(Enumerable)
+ || methodCallExpression.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions))
{
- case nameof(Queryable.Where):
- return ProcessWhere(methodCallExpression);
+ switch (methodCallExpression.Method.Name)
+ {
+ case nameof(Queryable.Where):
+ return ProcessWhere(methodCallExpression);
- case nameof(Queryable.Select):
- return ProcessSelect(methodCallExpression);
+ case nameof(Queryable.Select):
+ return ProcessSelect(methodCallExpression);
- case nameof(Queryable.OrderBy):
- case nameof(Queryable.OrderByDescending):
- return ProcessOrderBy(methodCallExpression);
+ case nameof(Queryable.OrderBy):
+ case nameof(Queryable.OrderByDescending):
+ return ProcessOrderBy(methodCallExpression);
- case nameof(Queryable.ThenBy):
- case nameof(Queryable.ThenByDescending):
- return ProcessThenByBy(methodCallExpression);
+ case nameof(Queryable.ThenBy):
+ case nameof(Queryable.ThenByDescending):
+ return ProcessThenByBy(methodCallExpression);
- case nameof(Queryable.Join):
- return ProcessJoin(methodCallExpression);
+ case nameof(Queryable.Join):
+ return ProcessJoin(methodCallExpression);
- case nameof(Queryable.GroupJoin):
- return ProcessGroupJoin(methodCallExpression);
+ case nameof(Queryable.GroupJoin):
+ return ProcessGroupJoin(methodCallExpression);
- case nameof(Queryable.SelectMany):
- return ProcessSelectMany(methodCallExpression);
+ case nameof(Queryable.SelectMany):
+ return ProcessSelectMany(methodCallExpression);
- case nameof(Queryable.All):
- return ProcessAll(methodCallExpression);
+ case nameof(Queryable.All):
+ return ProcessAll(methodCallExpression);
- case nameof(Queryable.Any):
- case nameof(Queryable.Count):
- case nameof(Queryable.LongCount):
- return ProcessAnyCountLongCount(methodCallExpression);
+ case nameof(Queryable.Any):
+ case nameof(Queryable.Count):
+ case nameof(Queryable.LongCount):
+ return ProcessAnyCountLongCount(methodCallExpression);
- case nameof(Queryable.Average):
- case nameof(Queryable.Sum):
- case nameof(Queryable.Min):
- case nameof(Queryable.Max):
- return ProcessAverageSumMinMax(methodCallExpression);
+ case nameof(Queryable.Average):
+ case nameof(Queryable.Sum):
+ case nameof(Queryable.Min):
+ case nameof(Queryable.Max):
+ return ProcessAverageSumMinMax(methodCallExpression);
- case nameof(Queryable.Distinct):
- return ProcessDistinct(methodCallExpression);
+ case nameof(Queryable.Distinct):
+ return ProcessDistinct(methodCallExpression);
- case nameof(Queryable.DefaultIfEmpty):
- return ProcessDefaultIfEmpty(methodCallExpression);
+ case nameof(Queryable.DefaultIfEmpty):
+ return ProcessDefaultIfEmpty(methodCallExpression);
- case nameof(Queryable.First):
- case nameof(Queryable.FirstOrDefault):
- case nameof(Queryable.Single):
- case nameof(Queryable.SingleOrDefault):
- return ProcessCardinalityReducingOperation(methodCallExpression);
+ case nameof(Queryable.First):
+ case nameof(Queryable.FirstOrDefault):
+ case nameof(Queryable.Single):
+ case nameof(Queryable.SingleOrDefault):
+ return ProcessCardinalityReducingOperation(methodCallExpression);
- case nameof(Queryable.OfType):
- return ProcessOfType(methodCallExpression);
+ case nameof(Queryable.OfType):
+ return ProcessOfType(methodCallExpression);
- case nameof(Queryable.Skip):
- case nameof(Queryable.Take):
- return ProcessSkipTake(methodCallExpression);
+ case nameof(Queryable.Skip):
+ case nameof(Queryable.Take):
+ return ProcessSkipTake(methodCallExpression);
- case "Include":
- case "ThenInclude":
- return ProcessInclude(methodCallExpression);
+ case nameof(Queryable.Union):
+ case nameof(Queryable.Concat):
+ case nameof(Queryable.Intersect):
+ case nameof(Queryable.Except):
+ return ProcessSetOperation(methodCallExpression);
- default:
- return ProcessUnknownMethod(methodCallExpression);
+ case "Include":
+ case "ThenInclude":
+ return ProcessInclude(methodCallExpression);
+
+ default:
+ return ProcessUnknownMethod(methodCallExpression);
+ }
}
+
+ return ProcessUnknownMethod(methodCallExpression);
}
private Expression ProcessUnknownMethod(MethodCallExpression methodCallExpression)
@@ -820,6 +835,51 @@ private Expression ProcessSkipTake(MethodCallExpression methodCallExpression)
return new NavigationExpansionExpression(rewritten, preProcessResult.state, methodCallExpression.Type);
}
+ private Expression ProcessSetOperation(MethodCallExpression methodCallExpression)
+ {
+ // TODO: We shouldn't terminate if both sides are identical, #16246
+
+ var source1 = VisitSourceExpression(methodCallExpression.Arguments[0]);
+ var preProcessResult1 = PreProcessTerminatingOperation(source1);
+
+ var source2 = VisitSourceExpression(methodCallExpression.Arguments[1]);
+ var preProcessResult2 = PreProcessTerminatingOperation(source2);
+
+ // Extract the includes from each side and compare to make sure they're identical.
+ // We don't allow set operations over operands with different includes.
+ var pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false);
+ pendingIncludeFindingVisitor.Visit(preProcessResult1.state.PendingSelector.Body);
+ var pendingIncludes1 = pendingIncludeFindingVisitor.PendingIncludes;
+
+ pendingIncludeFindingVisitor = new PendingIncludeFindingVisitor(skipCollectionNavigations: false);
+ pendingIncludeFindingVisitor.Visit(preProcessResult2.state.PendingSelector.Body);
+ var pendingIncludes2 = pendingIncludeFindingVisitor.PendingIncludes;
+
+ if (pendingIncludes1.Count != pendingIncludes2.Count)
+ {
+ throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands);
+ }
+
+ foreach (var (i1, i2) in pendingIncludes1.Zip(pendingIncludes2, (i1, i2) => (i1, i2)))
+ {
+ if (i1.SourceMapping.RootEntityType != i2.SourceMapping.RootEntityType
+ || i1.NavTreeNode.Navigation != i2.NavTreeNode.Navigation)
+ {
+ throw new NotSupportedException(CoreStrings.SetOperationWithDifferentIncludesInOperands);
+ }
+ }
+
+ // If the siblings are different types, one is derived from the other the set operation returns the less derived type.
+ // Find that.
+ var clrType1 = preProcessResult1.state.CurrentParameter.Type;
+ var clrType2 = preProcessResult2.state.CurrentParameter.Type;
+ var parentState = clrType1.IsAssignableFrom(clrType2) ? preProcessResult1.state : preProcessResult2.state;
+
+ var rewritten = methodCallExpression.Update(null, new[] { preProcessResult1.source, preProcessResult2.source });
+
+ return new NavigationExpansionExpression(rewritten, parentState, methodCallExpression.Type);
+ }
+
private (Expression source, NavigationExpansionExpressionState state) PreProcessTerminatingOperation(NavigationExpansionExpression source)
{
var applyOrderingsResult = ApplyPendingOrderings(source.Operand, source.State);
diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs
index 828d4ccc9c9..0648f4d476f 100644
--- a/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs
+++ b/src/EFCore/Query/NavigationExpansion/Visitors/NavigationExpansionReducingVisitor.cs
@@ -120,8 +120,8 @@ protected override Expression VisitExtension(Expression extensionExpression)
result = NavigationExpansionHelpers.AddNavigationJoin(
result.source,
result.parameter,
- pendingIncludeNode.Value,
- pendingIncludeNode.Key,
+ pendingIncludeNode.SourceMapping,
+ pendingIncludeNode.NavTreeNode,
navigationExpansionExpression.State,
new List(),
include: true);
diff --git a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs b/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs
index ffcbde22396..963432490d5 100644
--- a/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs
+++ b/src/EFCore/Query/NavigationExpansion/Visitors/PendingIncludeFindingVisitor.cs
@@ -10,7 +10,15 @@ namespace Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Visitors
public class PendingIncludeFindingVisitor : ExpressionVisitor
{
- public virtual Dictionary PendingIncludes { get; } = new Dictionary();
+ private bool _skipCollectionNavigations;
+
+ public PendingIncludeFindingVisitor(bool skipCollectionNavigations = true)
+ {
+ _skipCollectionNavigations = skipCollectionNavigations;
+ }
+
+ public virtual List<(NavigationTreeNode NavTreeNode, SourceMapping SourceMapping)> PendingIncludes { get; } =
+ new List<(NavigationTreeNode, SourceMapping)>();
protected override Expression VisitMember(MemberExpression memberExpression)
{
@@ -81,14 +89,16 @@ protected override Expression VisitExtension(Expression extensionExpression)
private void FindPendingReferenceIncludes(NavigationTreeNode node, SourceMapping sourceMapping)
{
- if (node.Navigation != null && node.Navigation.IsCollection())
+ if (_skipCollectionNavigations && node.Navigation != null && node.Navigation.IsCollection())
{
return;
}
- if (node.Included == NavigationTreeNodeIncludeMode.ReferencePending && node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete)
+ if (node.ExpansionMode != NavigationTreeNodeExpansionMode.ReferenceComplete
+ && (node.Included == NavigationTreeNodeIncludeMode.ReferencePending
+ || !_skipCollectionNavigations && node.Included == NavigationTreeNodeIncludeMode.Collection))
{
- PendingIncludes[node] = sourceMapping;
+ PendingIncludes.Add((node, sourceMapping));
}
foreach (var child in node.Children)
diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs
index 2b8a8f0cf36..2f9fd79f91f 100644
--- a/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs
+++ b/test/EFCore.Cosmos.FunctionalTests/Query/SimpleQueryCosmosTest.SetOperations.cs
@@ -33,7 +33,11 @@ public override void Union_non_entity(bool isAsync) {}
public override Task Union_with_anonymous_type_projection(bool isAsync) => Task.CompletedTask;
public override Task Select_Union_unrelated(bool isAsync) => Task.CompletedTask;
public override Task Select_Union_different_fields_in_anonymous_with_subquery(bool isAsync) => Task.CompletedTask;
+ public override Task Union_Include(bool isAsync) => Task.CompletedTask;
+ public override Task Include_Union(bool isAsync) => Task.CompletedTask;
public override Task Select_Except_reference_projection(bool isAsync) => Task.CompletedTask;
+ public override void Include_Union_only_on_one_side_throws() {}
+ public override void Include_Union_different_includes_throws() {}
public override Task SubSelect_Union(bool isAsync) => Task.CompletedTask;
public override Task Client_eval_Union_FirstOrDefault(bool isAsync) => Task.CompletedTask;
}
diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs
index 241d2c10f43..caf56aa4569 100644
--- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs
+++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.SetOperations.cs
@@ -250,6 +250,26 @@ public virtual Task Select_Union_different_fields_in_anonymous_with_subquery(boo
.Where(x => x.Foo == "Berlin"),
entryCount: 1);
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Union_Include(bool isAsync)
+ => AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Union(cs.Where(c => c.City == "London"))
+ .Include(c => c.Orders),
+ entryCount: 59);
+
+ [ConditionalTheory]
+ [MemberData(nameof(IsAsyncData))]
+ public virtual Task Include_Union(bool isAsync)
+ => AssertQuery(isAsync, cs => cs
+ .Where(c => c.City == "Berlin")
+ .Include(c => c.Orders)
+ .Union(cs
+ .Where(c => c.City == "London")
+ .Include(c => c.Orders)),
+ entryCount: 59);
+
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_Except_reference_projection(bool isAsync)
@@ -260,6 +280,45 @@ public virtual Task Select_Except_reference_projection(bool isAsync)
.Select(o => o.Customer)),
entryCount: 88);
+ [ConditionalFact]
+ public virtual void Include_Union_only_on_one_side_throws()
+ {
+ using (var ctx = CreateContext())
+ {
+ Assert.Throws(() =>
+ ctx.Customers
+ .Where(c => c.City == "Berlin")
+ .Include(c => c.Orders)
+ .Union(ctx.Customers.Where(c => c.City == "London"))
+ .ToList());
+
+ Assert.Throws(() =>
+ ctx.Customers
+ .Where(c => c.City == "Berlin")
+ .Union(ctx.Customers
+ .Where(c => c.City == "London")
+ .Include(c => c.Orders))
+ .ToList());
+ }
+ }
+
+ [ConditionalFact]
+ public virtual void Include_Union_different_includes_throws()
+ {
+ using (var ctx = CreateContext())
+ {
+ Assert.Throws(() =>
+ ctx.Customers
+ .Where(c => c.City == "Berlin")
+ .Include(c => c.Orders)
+ .Union(ctx.Customers
+ .Where(c => c.City == "London")
+ .Include(c => c.Orders)
+ .ThenInclude(o => o.OrderDetails))
+ .ToList());
+ }
+ }
+
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task SubSelect_Union(bool isAsync)
diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs
index b7e9b957e81..3bdf17787a8 100644
--- a/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs
+++ b/test/EFCore.SqlServer.FunctionalTests/Query/SimpleQuerySqlServerTest.SetOperations.cs
@@ -264,6 +264,44 @@ OFFSET @__p_0 ROWS FETCH NEXT @__p_1 ROWS ONLY
ORDER BY [t0].[Foo]");
}
+ public override async Task Union_Include(bool isAsync)
+ {
+ await base.Union_Include(isAsync);
+
+ AssertSql(
+ @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
+FROM (
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+ UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
+) AS [t]
+LEFT JOIN [Orders] AS [o] ON [t].[CustomerID] = [o].[CustomerID]
+ORDER BY [t].[CustomerID], [o].[OrderID]");
+ }
+
+ public override async Task Include_Union(bool isAsync)
+ {
+ await base.Include_Union(isAsync);
+
+ AssertSql(
+ @"SELECT [t].[CustomerID], [t].[Address], [t].[City], [t].[CompanyName], [t].[ContactName], [t].[ContactTitle], [t].[Country], [t].[Fax], [t].[Phone], [t].[PostalCode], [t].[Region], [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
+FROM (
+ SELECT [c].[CustomerID], [c].[Address], [c].[City], [c].[CompanyName], [c].[ContactName], [c].[ContactTitle], [c].[Country], [c].[Fax], [c].[Phone], [c].[PostalCode], [c].[Region]
+ FROM [Customers] AS [c]
+ WHERE ([c].[City] = N'Berlin') AND [c].[City] IS NOT NULL
+ UNION
+ SELECT [c0].[CustomerID], [c0].[Address], [c0].[City], [c0].[CompanyName], [c0].[ContactName], [c0].[ContactTitle], [c0].[Country], [c0].[Fax], [c0].[Phone], [c0].[PostalCode], [c0].[Region]
+ FROM [Customers] AS [c0]
+ WHERE ([c0].[City] = N'London') AND [c0].[City] IS NOT NULL
+) AS [t]
+LEFT JOIN [Orders] AS [o] ON [t].[CustomerID] = [o].[CustomerID]
+ORDER BY [t].[CustomerID], [o].[OrderID]");
+ }
+
public override async Task Select_Except_reference_projection(bool isAsync)
{
await base.Select_Except_reference_projection(isAsync);