Skip to content

Commit

Permalink
Entity equality support for non-anonymous DTOs
Browse files Browse the repository at this point in the history
Also various modifications to support MemberMemberBinding better.

Fixes #16789
  • Loading branch information
roji committed Jul 27, 2019
1 parent 83030cb commit 5011d20
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -232,37 +232,43 @@ protected override Expression VisitNew(NewExpression newExpression)

protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression)
{
var newExpression = (NewExpression)Visit(memberInitExpression.NewExpression);
var newExpression = VisitAndConvert(memberInitExpression.NewExpression, nameof(VisitMemberInit));
if (newExpression == null)
{
return null;
}

var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count];
var newBindings = new MemberBinding[memberInitExpression.Bindings.Count];
for (var i = 0; i < newBindings.Length; i++)
{
var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i];
if (_clientEval)
newBindings[i] = VisitMemberBinding(memberInitExpression.Bindings[i]);
if (newBindings[i] == null)
{
newBindings[i] = memberAssignment.Update(Visit(memberAssignment.Expression));
return null;
}
else
{
var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member);
_projectionMembers.Push(projectionMember);
}

var visitedExpression = Visit(memberAssignment.Expression);
if (visitedExpression == null)
{
return null;
}
return memberInitExpression.Update(newExpression, newBindings);
}

newBindings[i] = memberAssignment.Update(visitedExpression);
_projectionMembers.Pop();
}
protected override MemberAssignment VisitMemberAssignment(MemberAssignment memberAssignment)
{
if (_clientEval)
{
return memberAssignment.Update(Visit(memberAssignment.Expression));
}

return memberInitExpression.Update(newExpression, newBindings);
var projectionMember = _projectionMembers.Peek().Append(memberAssignment.Member);
_projectionMembers.Push(projectionMember);

var visitedExpression = Visit(memberAssignment.Expression);
if (visitedExpression == null)
{
return null;
}

_projectionMembers.Pop();
return memberAssignment.Update(visitedExpression);
}

// TODO: Debugging
Expand Down
142 changes: 116 additions & 26 deletions src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -13,6 +14,7 @@
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion.Internal;
using Index = System.Index;

namespace Microsoft.EntityFrameworkCore.Query.Internal
{
Expand Down Expand Up @@ -56,13 +58,83 @@ protected override Expression VisitNew(NewExpression newExpression)
var visitedArgs = Visit(newExpression.Arguments);
var visitedExpression = newExpression.Update(visitedArgs.Select(Unwrap));

return (newExpression.Members?.Count ?? 0) == 0
// NewExpression.Members is populated for anonymous types, mapping constructor arguments to the properties
// which receive their values. If not populated, a non-anonymous type is being constructed, and we have no idea where
// its constructor arguments will end up.
if (newExpression.Members == null)
{
return visitedExpression;
}

var entityReferenceInfo = visitedArgs
.Select((a, i) => (Arg: a, Index: i))
.Where(ai => ai.Arg is EntityReferenceExpression)
.ToDictionary(
ai => visitedExpression.Members[ai.Index].Name,
ai => EntityOrDtoType.FromEntityReferenceExpression((EntityReferenceExpression)ai.Arg));

return entityReferenceInfo.Count == 0
? (Expression)visitedExpression
: new EntityReferenceExpression(visitedExpression, visitedExpression.Members
.Select((m, i) => (Member: m, Index: i))
.ToDictionary(
mi => mi.Member.Name,
mi => visitedArgs[mi.Index]));
: new EntityReferenceExpression(visitedExpression, entityReferenceInfo);
}

protected override Expression VisitMemberInit(MemberInitExpression memberInitExpression)
{
var visitedNew = Visit(memberInitExpression.NewExpression);
var (visitedBindings, entityReferenceInfo) = VisitMemberBindings(memberInitExpression.Bindings);
var visitedMemberInit = memberInitExpression.Update((NewExpression)Unwrap(visitedNew), visitedBindings);

return entityReferenceInfo == null
? (Expression)visitedMemberInit
: new EntityReferenceExpression(visitedMemberInit, entityReferenceInfo);

// Visits member bindings, unwrapping expressions and surfacing entity reference information via the dictionary
(IEnumerable<MemberBinding>, Dictionary<string, EntityOrDtoType>) VisitMemberBindings(ReadOnlyCollection<MemberBinding> bindings)
{
var newBindings = new MemberBinding[bindings.Count];
Dictionary<string, EntityOrDtoType> bindingEntityReferenceInfo = null;

for (var i = 0; i < bindings.Count; i++)
{
switch (bindings[i])
{
case MemberAssignment assignment:
var visitedAssignment = VisitMemberAssignment(assignment);
if (visitedAssignment.Expression is EntityReferenceExpression ere)
{
AddEntityReferenceInfo(assignment.Member.Name, EntityOrDtoType.FromEntityReferenceExpression(ere));
}
newBindings[i] = assignment.Update(Unwrap(visitedAssignment.Expression));
continue;

case MemberMemberBinding memberMember:
var (visitedSubBindings, subEntityReferenceInformation) = VisitMemberBindings(memberMember.Bindings);
if (subEntityReferenceInformation != null)
{
AddEntityReferenceInfo(memberMember.Member.Name, EntityOrDtoType.FromDtoType(subEntityReferenceInformation));
}
newBindings[i] = memberMember.Update(visitedSubBindings);
continue;

case MemberListBinding memberList:
throw new NotImplementedException();

default:
throw new InvalidOperationException("Unhandled member binding type: " + bindings[i].BindingType);
}
}

return (newBindings, bindingEntityReferenceInfo);

void AddEntityReferenceInfo(string memberName, EntityOrDtoType entityOrDtoType)
{
if (bindingEntityReferenceInfo == null)
{
bindingEntityReferenceInfo = new Dictionary<string, EntityOrDtoType>();
}
bindingEntityReferenceInfo[memberName] = entityOrDtoType;
}
}
}

protected override Expression VisitMember(MemberExpression memberExpression)
Expand Down Expand Up @@ -238,7 +310,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
methodCallExpression.Update(null, newArguments),
newSourceWrapper.EntityType,
lastNavigation: null,
newSourceWrapper.AnonymousType,
newSourceWrapper.DtoType,
subqueryTraversed: true);
}
}
Expand Down Expand Up @@ -557,10 +629,10 @@ protected virtual Expression RewriteEquality(bool equality, Expression left, Exp
var leftTypeWrapper = left as EntityReferenceExpression;
var rightTypeWrapper = right as EntityReferenceExpression;

// If one of the sides is an anonymous object, or both sides are unknown, abort
// If one of the sides is a DTO, or both sides are unknown, abort
if (leftTypeWrapper == null && rightTypeWrapper == null
|| leftTypeWrapper?.IsAnonymousType == true
|| rightTypeWrapper?.IsAnonymousType == true)
|| leftTypeWrapper?.IsDtoType == true
|| rightTypeWrapper?.IsDtoType == true)
{
return null;
}
Expand Down Expand Up @@ -786,6 +858,25 @@ protected static Expression Unwrap(Expression expression)
_ => expression
};

protected struct EntityOrDtoType
{
public static EntityOrDtoType FromEntityReferenceExpression(EntityReferenceExpression ere)
=> new EntityOrDtoType
{
EntityType = ere.IsEntityType ? ere.EntityType : null,
DtoType = ere.IsDtoType ? ere.DtoType : null
};

public static EntityOrDtoType FromDtoType(Dictionary<string, EntityOrDtoType> dtoType)
=> new EntityOrDtoType { DtoType = dtoType };

public bool IsEntityType => EntityType != null;
public bool IsDto => DtoType != null;

public IEntityType EntityType;
public Dictionary<string, EntityOrDtoType> DtoType;
}

protected class EntityReferenceExpression : Expression
{
public sealed override ExpressionType NodeType => ExpressionType.Extension;
Expand All @@ -808,17 +899,17 @@ protected class EntityReferenceExpression : Expression
private readonly INavigation _lastNavigation;

[CanBeNull]
public Dictionary<string, Expression> AnonymousType { get; }
public Dictionary<string, EntityOrDtoType> DtoType { get; }

public bool SubqueryTraversed { get; }

public bool IsAnonymousType => AnonymousType != null;
public bool IsDtoType => DtoType != null;
public bool IsEntityType => EntityType != null;

public EntityReferenceExpression(Expression underlying, Dictionary<string, Expression> anonymousType)
public EntityReferenceExpression(Expression underlying, Dictionary<string, EntityOrDtoType> dtoType)
{
Underlying = underlying;
AnonymousType = anonymousType;
DtoType = dtoType;
}

public EntityReferenceExpression(Expression underlying, IEntityType entityType)
Expand All @@ -838,13 +929,13 @@ public EntityReferenceExpression(
Expression underlying,
IEntityType entityType,
INavigation lastNavigation,
Dictionary<string, Expression> anonymousType,
Dictionary<string, EntityOrDtoType> dtoType,
bool subqueryTraversed)
{
Underlying = underlying;
EntityType = entityType;
_lastNavigation = lastNavigation;
AnonymousType = anonymousType;
DtoType = dtoType;
SubqueryTraversed = subqueryTraversed;
}

Expand All @@ -866,14 +957,13 @@ public virtual Expression TraverseProperty(string propertyName, Expression desti
: destinationExpression;
}

if (IsAnonymousType)
if (IsDtoType)
{
if (AnonymousType.TryGetValue(propertyName, out var expression)
&& expression is EntityReferenceExpression wrapper)
if (DtoType.TryGetValue(propertyName, out var entityOrDto))
{
return wrapper.IsEntityType
? new EntityReferenceExpression(destinationExpression, wrapper.EntityType)
: new EntityReferenceExpression(destinationExpression, wrapper.AnonymousType);
return entityOrDto.IsEntityType
? new EntityReferenceExpression(destinationExpression, entityOrDto.EntityType)
: new EntityReferenceExpression(destinationExpression, entityOrDto.DtoType);
}

return destinationExpression;
Expand All @@ -883,7 +973,7 @@ public virtual Expression TraverseProperty(string propertyName, Expression desti
}

public EntityReferenceExpression Update(Expression newUnderlying)
=> new EntityReferenceExpression(newUnderlying, EntityType, _lastNavigation, AnonymousType, SubqueryTraversed);
=> new EntityReferenceExpression(newUnderlying, EntityType, _lastNavigation, DtoType, SubqueryTraversed);

protected override Expression VisitChildren(ExpressionVisitor visitor)
=> Update(visitor.Visit(Underlying));
Expand All @@ -896,9 +986,9 @@ public virtual void Print(ExpressionPrinter expressionPrinter)
{
expressionPrinter.StringBuilder.Append($".EntityType({EntityType})");
}
else if (IsAnonymousType)
else if (IsDtoType)
{
expressionPrinter.StringBuilder.Append(".AnonymousObject");
expressionPrinter.StringBuilder.Append(".DTO");
}

if (SubqueryTraversed)
Expand All @@ -907,7 +997,7 @@ public virtual void Print(ExpressionPrinter expressionPrinter)
}
}

public override string ToString() => $"{Underlying}[{(IsEntityType ? EntityType.ShortName() : "AnonymousObject")}{(SubqueryTraversed ? ", Subquery" : "")}]";
public override string ToString() => $"{Underlying}[{(IsEntityType ? EntityType.ShortName() : "DTO")}{(SubqueryTraversed ? ", Subquery" : "")}]";
}
}
}
55 changes: 31 additions & 24 deletions src/EFCore/Query/Internal/ExpressionEqualityComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,30 +213,9 @@ public virtual int GetHashCode(Expression obj)

hash.Add(memberInitExpression.NewExpression, this);

for (var i = 0; i < memberInitExpression.Bindings.Count; i++)
foreach (var binding in memberInitExpression.Bindings)
{
var memberBinding = memberInitExpression.Bindings[i];

hash.Add(memberBinding.Member);
hash.Add(memberBinding.BindingType);

switch (memberBinding.BindingType)
{
case MemberBindingType.Assignment:
var memberAssignment = (MemberAssignment)memberBinding;
hash.Add(memberAssignment.Expression, this);
break;
case MemberBindingType.ListBinding:
var memberListBinding = (MemberListBinding)memberBinding;
for (var j = 0; j < memberListBinding.Initializers.Count; j++)
{
AddListToHash(ref hash, memberListBinding.Initializers[j].Arguments);
}

break;
default:
throw new NotImplementedException();
}
ProcessMemberBinding(ref hash, binding);
}

break;
Expand Down Expand Up @@ -303,6 +282,34 @@ public virtual int GetHashCode(Expression obj)

return hash.ToHashCode();
}

void ProcessMemberBinding(ref HashCode hash, MemberBinding memberBinding)
{
hash.Add(memberBinding.Member);
hash.Add(memberBinding.BindingType);

switch (memberBinding)
{
case MemberAssignment assignment:
hash.Add(assignment.Expression, this);
break;
case MemberMemberBinding memberMember:
foreach (var subBinding in memberMember.Bindings)
{
ProcessMemberBinding(ref hash, subBinding);
}
break;
case MemberListBinding memberList:
for (var j = 0; j < memberList.Initializers.Count; j++)
{
AddListToHash(ref hash, memberList.Initializers[j].Arguments);
}

break;
default:
throw new InvalidOperationException("Unhandled member binding type: " + memberBinding.BindingType);
}
}
}

private void AddListToHash<T>(ref HashCode hash, IReadOnlyList<T> expressions)
Expand Down Expand Up @@ -711,7 +718,7 @@ private bool CompareBinding(MemberBinding a, MemberBinding b)
case MemberBindingType.MemberBinding:
return CompareMemberMemberBinding((MemberMemberBinding)a, (MemberMemberBinding)b);
default:
throw new NotImplementedException();
throw new InvalidOperationException("Unhandled member binding type: " + a.BindingType);
}
}

Expand Down
Loading

0 comments on commit 5011d20

Please sign in to comment.