Skip to content

Commit

Permalink
Query: Translate GroupBy-Aggregate with condition
Browse files Browse the repository at this point in the history
Resolves #18836
Resolves #11711
  • Loading branch information
smitpatel committed Aug 3, 2020
1 parent a320478 commit fb7db77
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 89 deletions.
252 changes: 184 additions & 68 deletions src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,20 @@ public virtual SqlExpression TranslateAverage([NotNull] Expression expression)
/// <returns> A SQL translation of Count over the given expression. </returns>
public virtual SqlExpression TranslateCount([CanBeNull] Expression expression = null)
{
if (expression != null)
if (expression == null)
{
// TODO: Translate Count with predicate for GroupBy
return null;
expression = _sqlExpressionFactory.Fragment("*");
}

if (!(expression is SqlExpression sqlExpression))
{
sqlExpression = TranslateInternal(expression);
}

return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { _sqlExpressionFactory.Fragment("*") },
new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));
Expand All @@ -221,16 +225,20 @@ public virtual SqlExpression TranslateCount([CanBeNull] Expression expression =
/// <returns> A SQL translation of LongCount over the given expression. </returns>
public virtual SqlExpression TranslateLongCount([CanBeNull] Expression expression = null)
{
if (expression != null)
if (expression == null)
{
// TODO: Translate Count with predicate for GroupBy
return null;
expression = _sqlExpressionFactory.Fragment("*");
}

if (!(expression is SqlExpression sqlExpression))
{
sqlExpression = TranslateInternal(expression);
}

return _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { _sqlExpressionFactory.Fragment("*") },
new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));
Expand Down Expand Up @@ -448,6 +456,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
.GetMappedProjection(projectionBindingExpression.ProjectionMember)
: null;

case GroupByShaperExpression groupByShaperExpression:
return new GroupingElementExpression(groupByShaperExpression.ElementSelector);

default:
return null;
}
Expand Down Expand Up @@ -498,29 +509,124 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// GroupBy Aggregate case
if (methodCallExpression.Object == null
&& methodCallExpression.Method.DeclaringType == typeof(Enumerable)
&& methodCallExpression.Arguments.Count > 0
&& methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression)
&& methodCallExpression.Arguments.Count > 0)
{
var translatedAggregate = methodCallExpression.Method.Name switch
{
nameof(Enumerable.Average) => TranslateAverage(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Count) => TranslateCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.LongCount) => TranslateLongCount(GetPredicateOnGrouping(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Max) => TranslateMax(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Min) => TranslateMin(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
nameof(Enumerable.Sum) => TranslateSum(GetSelectorOnGrouping(methodCallExpression, groupByShaperExpression)),
_ => null
};

if (translatedAggregate == null)
if (Visit(methodCallExpression.Arguments[0]) is GroupingElementExpression groupingElementExpression)
{
throw new InvalidOperationException(
TranslationErrorDetails == null
? CoreStrings.TranslationFailed(methodCallExpression.Print())
: CoreStrings.TranslationFailedWithDetails(methodCallExpression.Print(), TranslationErrorDetails));
}
switch (methodCallExpression.Method.Name)
{
case nameof(Enumerable.Average):
if (methodCallExpression.Arguments.Count == 2)
{
groupingElementExpression = ApplySelector(
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

return TranslateAverage(GetExpressionForAggregation(groupingElementExpression));

case nameof(Enumerable.Count):
if (methodCallExpression.Arguments.Count == 2)
{
groupingElementExpression = ApplyPredicate(
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

return TranslateCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true));

case nameof(Enumerable.LongCount):
if (methodCallExpression.Arguments.Count == 2)
{
groupingElementExpression = ApplyPredicate(
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

return TranslateLongCount(GetExpressionForAggregation(groupingElementExpression, starProjection: true));

case nameof(Enumerable.Max):
if (methodCallExpression.Arguments.Count == 2)
{
groupingElementExpression = ApplySelector(
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

return TranslateMax(GetExpressionForAggregation(groupingElementExpression));

case nameof(Enumerable.Select):
return ApplySelector(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());

case nameof(Enumerable.Min):
if (methodCallExpression.Arguments.Count == 2)
{
groupingElementExpression = ApplySelector(
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

return TranslateMin(GetExpressionForAggregation(groupingElementExpression));

case nameof(Enumerable.Sum):
if (methodCallExpression.Arguments.Count == 2)
{
groupingElementExpression = ApplySelector(
groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());
}

return TranslateSum(GetExpressionForAggregation(groupingElementExpression));

case nameof(Enumerable.Where):
return ApplyPredicate(groupingElementExpression, methodCallExpression.Arguments[1].UnwrapLambdaFromQuote());

default:
return null;
}

GroupingElementExpression ApplyPredicate(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
{
var predicate = TranslateInternal(RemapLambda(groupingElementExpression, lambdaExpression));

return groupingElementExpression.ApplyPredicate(predicate);
}

GroupingElementExpression ApplySelector(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
{
var selector = RemapLambda(groupingElementExpression, lambdaExpression);

return groupingElementExpression.ApplySelector(selector);
}

static Expression RemapLambda(GroupingElementExpression groupingElement, LambdaExpression lambdaExpression)
=> ReplacingExpressionVisitor.Replace(
lambdaExpression.Parameters[0],
groupingElement.Element,
lambdaExpression.Body);

SqlExpression GetExpressionForAggregation(GroupingElementExpression groupingElement, bool starProjection = false)
{
var selector = TranslateInternal(groupingElement.Element);

if (selector == null
&& starProjection)
{
selector = _sqlExpressionFactory.Fragment("*");
}

if (groupingElement.Predicate != null)
{
if (selector is SqlFragmentExpression)
{
selector = _sqlExpressionFactory.Constant(1);
}

return _sqlExpressionFactory.Case(
new List<CaseWhenClause>
{
new CaseWhenClause(groupingElement.Predicate, selector)
},
elseResult: null);
}

return translatedAggregate;
return selector;
}
}
}

// Subquery case
Expand Down Expand Up @@ -990,46 +1096,6 @@ private SqlExpression BindProperty(EntityReferenceExpression entityReferenceExpr
return null;
}

private static Expression GetSelectorOnGrouping(
MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
{
if (methodCallExpression.Arguments.Count == 1)
{
return groupByShaperExpression.ElementSelector;
}

if (methodCallExpression.Arguments.Count == 2)
{
var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
return ReplacingExpressionVisitor.Replace(
selectorLambda.Parameters[0],
groupByShaperExpression.ElementSelector,
selectorLambda.Body);
}

throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

private static Expression GetPredicateOnGrouping(
MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
{
if (methodCallExpression.Arguments.Count == 1)
{
return null;
}

if (methodCallExpression.Arguments.Count == 2)
{
var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
return ReplacingExpressionVisitor.Replace(
selectorLambda.Parameters[0],
groupByShaperExpression.ElementSelector,
selectorLambda.Body);
}

throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
}

private static Expression TryRemoveImplicitConvert(Expression expression)
{
if (expression is UnaryExpression unaryExpression
Expand Down Expand Up @@ -1412,6 +1478,56 @@ public Expression Convert(Type type)
}
}

private sealed class GroupingElementExpression : Expression
{
public GroupingElementExpression(Expression element)
{
Element = element;
}
public Expression Element { get; private set; }
public bool IsDistinct { get; private set; }
public SqlExpression Predicate { get; private set; }

public GroupingElementExpression ApplyDistinct()
{
IsDistinct = true;

return this;
}

public GroupingElementExpression ApplySelector(Expression expression)
{
Element = expression;

return this;
}

public GroupingElementExpression ApplyPredicate(SqlExpression expression)
{
Check.NotNull(expression, nameof(expression));

if (expression is SqlConstantExpression sqlConstant
&& sqlConstant.Value is bool boolValue
&& boolValue)
{
return this;
}

Predicate = Predicate == null
? expression
: new SqlBinaryExpression(
ExpressionType.AndAlso,
Predicate,
expression,
typeof(bool),
expression.TypeMapping);

return this;
}
public override Type Type => typeof(IEnumerable<>).MakeGenericType(Element.Type);
public override ExpressionType NodeType => ExpressionType.Extension;
}

private sealed class SqlTypeMappingVerifyingExpressionVisitor : ExpressionVisitor
{
protected override Expression VisitExtension(Expression extensionExpression)
Expand Down
4 changes: 3 additions & 1 deletion src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ public ConstantVerifyingExpressionVisitor(ITypeMappingSource typeMappingSource)
private bool ValidConstant(ConstantExpression constantExpression)
{
return constantExpression.Value == null
|| _typeMappingSource.FindMapping(constantExpression.Type) != null;
|| _typeMappingSource.FindMapping(constantExpression.Type) != null
|| constantExpression.Value is Array array
&& array.Length == 0;
}

protected override Expression VisitConstant(ConstantExpression constantExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public virtual Task GroupBy_Property_Select_Average(bool async)

[ConditionalTheory(Skip = "issue #18923")]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Property_Select_Average_with_navigation_expansion(bool async)
public virtual Task GroupBy_Property_Select_Average_with_group_enumerable_projected(bool async)
{
return AssertQueryScalar(
async,
Expand Down Expand Up @@ -2007,7 +2007,7 @@ public virtual Task Distinct_GroupBy_OrderBy_key(bool async)
assertOrder: true);
}

[ConditionalTheory(Skip = "Issue #18923")]
[ConditionalTheory(Skip = "Issue #15873")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_nested_collection_with_groupby(bool async)
{
Expand All @@ -2020,7 +2020,7 @@ public virtual Task Select_nested_collection_with_groupby(bool async)
: Array.Empty<int>()));
}

[ConditionalTheory(Skip = "Issue #18923")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Select_GroupBy_All(bool async)
{
Expand Down Expand Up @@ -2050,7 +2050,7 @@ public override bool Equals(object obj)
public override int GetHashCode() => Order.GetHashCode();
}

[ConditionalTheory(Skip = "Issue #18836")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_Where_in_aggregate(bool async)
{
Expand Down Expand Up @@ -2460,9 +2460,9 @@ public virtual Task Count_after_GroupBy_aggregate(bool async)
ss => ss.Set<Order>().GroupBy(o => o.CustomerID).Select(g => g.Sum(gg => gg.OrderID)).CountAsync(default));
}

[ConditionalTheory(Skip = "Issue #18836")]
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task LongCount_after_client_GroupBy(bool async)
public virtual Task LongCount_after_GroupBy_aggregate(bool async)
{
return AssertSingleResult(
async,
Expand Down Expand Up @@ -2662,7 +2662,7 @@ public virtual Task Complex_query_with_groupBy_in_subquery3(bool async)
}

// also 15279
[ConditionalTheory(Skip = "issue #11711")]
[ConditionalTheory(Skip = "issue #15873")]
[MemberData(nameof(IsAsyncData))]
public virtual Task Complex_query_with_groupBy_in_subquery4(bool async)
{
Expand Down
Loading

0 comments on commit fb7db77

Please sign in to comment.