Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic FromSql() support in new pipeline #15752

Merged
merged 1 commit into from
May 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions src/EFCore.Relational/Extensions/RelationalQueryableExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace Microsoft.EntityFrameworkCore
/// </summary>
public static class RelationalQueryableExtensions
{
internal static readonly MethodInfo FromSqlMethodInfo
private static readonly MethodInfo _fromSqlOnQueryableMethodInfo
= typeof(RelationalQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(FromSqlRaw))
.GetTypeInfo().GetDeclaredMethods(nameof(FromSqlOnQueryable))
.Single();

/// <summary>
Expand Down Expand Up @@ -54,7 +54,10 @@ internal static readonly MethodInfo FromSqlMethodInfo
/// <returns> An <see cref="IQueryable{T}" /> representing the raw SQL query. </returns>
[StringFormatMethod("sql")]
[Obsolete(
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead.")]
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. " +
"For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead. " +
"Call either new method directly on the DbSet at the root of the query.",
error: true)]
public static IQueryable<TEntity> FromSql<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotParameterized] RawSqlString sql,
Expand All @@ -68,7 +71,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
return source.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
Expression.Constant(sql.Format),
Expression.Constant(parameters)));
Expand Down Expand Up @@ -96,7 +99,10 @@ public static IQueryable<TEntity> FromSql<TEntity>(
/// <param name="sql"> The interpolated string representing a SQL query. </param>
/// <returns> An <see cref="IQueryable{T}" /> representing the interpolated string SQL query. </returns>
[Obsolete(
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead.")]
"For returning objects from SQL queries using plain strings, use FromSqlRaw instead. " +
"For returning objects from SQL queries using interpolated string syntax to create parameters, use FromSqlInterpolated instead. " +
"Call either new method directly on the DbSet at the root of the query.",
error: true)]
public static IQueryable<TEntity> FromSql<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] [NotParameterized] FormattableString sql)
Expand All @@ -109,7 +115,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
return source.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
Expression.Constant(sql.Format),
Expression.Constant(sql.GetArguments())));
Expand Down Expand Up @@ -145,7 +151,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
/// <returns> An <see cref="IQueryable{T}" /> representing the raw SQL query. </returns>
[StringFormatMethod("sql")]
public static IQueryable<TEntity> FromSqlRaw<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] this DbSet<TEntity> source,
[NotParameterized] string sql,
[NotNull] params object[] parameters)
where TEntity : class
Expand All @@ -154,11 +160,12 @@ public static IQueryable<TEntity> FromSqlRaw<TEntity>(
Check.NotEmpty(sql, nameof(sql));
Check.NotNull(parameters, nameof(parameters));

return source.Provider.CreateQuery<TEntity>(
var queryableSource = (IQueryable)source;
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql),
Expression.Constant(parameters)));
}
Expand All @@ -185,21 +192,29 @@ public static IQueryable<TEntity> FromSqlRaw<TEntity>(
/// <param name="sql"> The interpolated string representing a SQL query with parameters. </param>
/// <returns> An <see cref="IQueryable{T}" /> representing the interpolated string SQL query. </returns>
public static IQueryable<TEntity> FromSqlInterpolated<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotNull] this DbSet<TEntity> source,
[NotNull] [NotParameterized] FormattableString sql)
where TEntity : class
{
Check.NotNull(source, nameof(source));
Check.NotNull(sql, nameof(sql));
Check.NotEmpty(sql.Format, nameof(source));

return source.Provider.CreateQuery<TEntity>(
var queryableSource = (IQueryable)source;
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
FromSqlMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql.Format),
Expression.Constant(sql.GetArguments())));
}

internal static IQueryable<TEntity> FromSqlOnQueryable<TEntity>(
[NotNull] this IQueryable<TEntity> source,
[NotParameterized] string sql,
[NotNull] params object[] parameters)
where TEntity : class
=> throw new NotSupportedException();
}
}
16 changes: 16 additions & 0 deletions src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,22 @@ protected override Expression VisitTable(TableExpression tableExpression)
return tableExpression;
}

protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression)
{
_relationalCommandBuilder.AppendLine("(");

using (_relationalCommandBuilder.Indent())
{
_relationalCommandBuilder.AppendLines(fromSqlExpression.Sql);
// TODO: Generate parameters
}

_relationalCommandBuilder.Append(") AS ")
.Append(_sqlGenerationHelper.DelimitIdentifier(fromSqlExpression.Alias));

return fromSqlExpression;
}

protected override Expression VisitSqlBinary(SqlBinaryExpression sqlBinaryExpression)
{
if (sqlBinaryExpression.OperatorType == ExpressionType.Coalesce)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,42 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query.Pipeline;
using Microsoft.EntityFrameworkCore.Internal;

namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline
{
public class RelationalEntityQueryableExpressionVisitor2 : EntityQueryableExpressionVisitor2
{
private IModel _model;
private readonly IModel _model;

public RelationalEntityQueryableExpressionVisitor2(IModel model)
{
_model = model;
}

protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
return new RelationalShapedQueryExpression(_model.FindEntityType(elementType));
if (methodCallExpression.Method.DeclaringType == typeof(RelationalQueryableExtensions)
&& methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSqlOnQueryable))
{
// TODO: Implement parameters
var sql = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value;
var queryable = (IQueryable)((ConstantExpression)methodCallExpression.Arguments[0]).Value;
return CreateShapedQueryExpression(queryable.ElementType, sql);
}

return base.VisitMethodCall(methodCallExpression);
}

protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
=> new RelationalShapedQueryExpression(_model.FindEntityType(elementType));

protected virtual ShapedQueryExpression CreateShapedQueryExpression(Type elementType, string sql)
=> new RelationalShapedQueryExpression(_model.FindEntityType(elementType), sql);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,17 @@ public RelationalShapedQueryExpression(IEntityType entityType)
typeof(ValueBuffer)),
false);
}

public RelationalShapedQueryExpression(IEntityType entityType, string sql)
{
QueryExpression = new SelectExpression(entityType, sql);
ShaperExpression = new EntityShaperExpression(
entityType,
new ProjectionBindingExpression(
QueryExpression,
new ProjectionMember(),
typeof(ValueBuffer)),
false);
}
}
}
4 changes: 4 additions & 0 deletions src/EFCore.Relational/Query/Pipeline/SqlExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ protected override Expression VisitExtension(Expression extensionExpression)
case ExistsExpression existsExpression:
return VisitExists(existsExpression);

case FromSqlExpression fromSqlExpression:
return VisitFromSql(fromSqlExpression);

case InExpression inExpression:
return VisitIn(inExpression);

Expand Down Expand Up @@ -76,6 +79,7 @@ protected override Expression VisitExtension(Expression extensionExpression)
protected abstract Expression VisitExists(ExistsExpression existsExpression);
protected abstract Expression VisitIn(InExpression inExpression);
protected abstract Expression VisitCrossJoin(CrossJoinExpression crossJoinExpression);
protected abstract Expression VisitFromSql(FromSqlExpression fromSqlExpression);
protected abstract Expression VisitInnerJoin(InnerJoinExpression innerJoinExpression);
protected abstract Expression VisitLeftJoin(LeftJoinExpression leftJoinExpression);
protected abstract Expression VisitProjection(ProjectionExpression projectionExpression);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System.Linq.Expressions;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Relational.Query.Pipeline.SqlExpressions
{
public class FromSqlExpression : TableExpressionBase
{
#region Fields & Constructors
public FromSqlExpression(
[NotNull] string sql,
[NotNull] string alias)
: base(alias)
{
Sql = sql;
}
#endregion

#region Public Properties

/// <summary>
/// Gets the SQL.
/// </summary>
/// <value>
/// The SQL.
/// </value>
public string Sql { get; }

#endregion

#region Expression-based methods

protected override Expression VisitChildren(ExpressionVisitor visitor)
=> this;

public override void Print(ExpressionPrinter expressionPrinter)
=> expressionPrinter.StringBuilder.Append(Sql);

#endregion

#region Equality & HashCode

public override bool Equals(object obj)
=> obj != null
&& (ReferenceEquals(this, obj)
|| obj is FromSqlExpression fromSqlExpression
&& Equals(fromSqlExpression));

private bool Equals(FromSqlExpression fromSqlExpression)
=> base.Equals(fromSqlExpression)
&& string.Equals(Sql, fromSqlExpression.Sql);

public override int GetHashCode()
{
unchecked
{
var hashCode = base.GetHashCode();
hashCode = (hashCode * 397) ^ Sql.GetHashCode();

return hashCode;
}
}

#endregion
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ public SelectExpression(IEntityType entityType)
_projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, tableExpression, false);
}

public SelectExpression(IEntityType entityType, string sql)
: base("")
{
var fromSqlExpression = new FromSqlExpression(
sql,
entityType.GetTableName().ToLower().Substring(0, 1));

_tables.Add(fromSqlExpression);

_projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, fromSqlExpression, false);
}

public SqlExpression BindProperty(Expression projectionExpression, IProperty property)
{
var member = (projectionExpression as ProjectionBindingExpression).ProjectionMember;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ public class FromSqlExpressionNode : ResultOperatorExpressionNodeBase
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static readonly IReadOnlyCollection<MethodInfo> SupportedMethods = new[] { RelationalQueryableExtensions.FromSqlMethodInfo };
public static readonly IReadOnlyCollection<MethodInfo> SupportedMethods = new List<MethodInfo>
{
// RelationalQueryableExtensions.FromSqlMethodInfo
};

private readonly string _sql;
private readonly Expression _arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ protected override Expression VisitExists(ExistsExpression existsExpression)
return ApplyConversion(existsExpression.Update(subquery), condition: true);
}

protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression)
=> fromSqlExpression;

protected override Expression VisitIn(InExpression inExpression)
{
var parentSearchCondition = _isSearchCondition;
Expand Down
Loading