diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index 816250589f8..4a4ea51deba 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -181,6 +181,12 @@ private static readonly MethodInfo _getItemMethodInfo = typeof(JObject).GetTypeInfo().GetRuntimeProperties() .Single(pi => pi.Name == "Item" && pi.GetIndexParameters()[0].ParameterType == typeof(string)) .GetMethod; + private static readonly PropertyInfo _jTokenTypePropertyInfo + = typeof(JToken).GetTypeInfo().GetRuntimeProperties() + .Single(mi => mi.Name == nameof(JToken.Type)); + private static readonly MethodInfo _jTokenToObjectMethodInfo + = typeof(JToken).GetTypeInfo().GetRuntimeMethods() + .Single(mi => mi.Name == nameof(JToken.ToObject) && mi.GetParameters().Length == 0); private static readonly MethodInfo _toObjectMethodInfo = typeof(CosmosProjectionBindingRemovingExpressionVisitor).GetTypeInfo().GetRuntimeMethods() .Single(mi => mi.Name == nameof(SafeToObject)); @@ -199,8 +205,8 @@ private readonly IDictionary _materializationCo = new Dictionary(); private readonly IDictionary _projectionBindings = new Dictionary(); - private readonly IDictionary _ownerMappings - = new Dictionary(); + private readonly IDictionary _ownerMappings + = new Dictionary(); private (IEntityType EntityType, ParameterExpression JObjectVariable) _ownerInfo; private ParameterExpression _ordinalParameter; @@ -329,14 +335,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp (ParameterExpression)((MethodCallExpression)methodCallExpression.Arguments[0]).Object]; } - var readExpression = CreateGetValueExpression(innerExpression, property); - if (readExpression.Type.IsValueType - && methodCallExpression.Type == typeof(object)) - { - readExpression = Expression.Convert(readExpression, typeof(object)); - } - - return readExpression; + return CreateGetValueExpression(innerExpression, property, methodCallExpression.Type); } return base.VisitMethodCall(methodCallExpression); @@ -381,6 +380,10 @@ protected override Expression VisitExtension(Expression extensionExpression) { _ownerMappings[objectArrayProjection.InnerProjection.AccessExpression] = _ownerInfo; } + else + { + _ownerMappings[objectArrayProjection.InnerProjection.AccessExpression] = (objectArrayProjection.Navigation.DeclaringEntityType, objectArrayProjection.AccessExpression); + } var previousOrdinalParameter = _ordinalParameter; _ordinalParameter = ordinalParameter; @@ -676,7 +679,8 @@ private static Expression CreateReadJTokenExpression(Expression jObjectExpressio private Expression CreateGetValueExpression( Expression jObjectExpression, - IProperty property) + IProperty property, + Type clrType) { if (property.Name == StoreKeyConvention.JObjectPropertyName) { @@ -697,7 +701,13 @@ private Expression CreateGetValueExpression( && !property.IsForeignKey() && property.ClrType == typeof(int)) { - return _ordinalParameter; + Expression readExpression = _ordinalParameter; + if (readExpression.Type != clrType) + { + readExpression = Expression.Convert(readExpression, clrType); + } + + return readExpression; } var principalProperty = property.FindFirstPrincipal(); @@ -708,7 +718,7 @@ private Expression CreateGetValueExpression( { Debug.Assert(principalProperty.DeclaringEntityType.IsAssignableFrom(ownerInfo.EntityType)); - ownerJObjectExpression = ownerInfo.JObjectVariable; + ownerJObjectExpression = ownerInfo.JObjectExpression; } else if (jObjectExpression is RootReferenceExpression rootReferenceExpression) { @@ -721,15 +731,15 @@ private Expression CreateGetValueExpression( if (ownerJObjectExpression != null) { - return CreateGetValueExpression(ownerJObjectExpression, principalProperty); + return CreateGetValueExpression(ownerJObjectExpression, principalProperty, clrType); } } } - return Expression.Default(property.ClrType); + return Expression.Default(clrType); } - return CreateGetValueExpression(jObjectExpression, storeName, property.ClrType, property.GetTypeMapping()); + return CreateGetValueExpression(jObjectExpression, storeName, clrType, property.GetTypeMapping()); } private Expression CreateGetValueExpression( @@ -762,21 +772,39 @@ private Expression CreateGetValueExpression( var converter = typeMapping?.Converter; if (converter != null) { - valueExpression = ConvertJTokenToType(jTokenExpression, converter.ProviderClrType); + var jTokenParameter = Expression.Parameter(typeof(JToken)); - valueExpression = ReplacingExpressionVisitor.Replace( - converter.ConvertFromProviderExpression.Parameters.Single(), - valueExpression, - converter.ConvertFromProviderExpression.Body); + var body + = ReplacingExpressionVisitor.Replace( + converter.ConvertFromProviderExpression.Parameters.Single(), + Expression.Call( + jTokenParameter, + _jTokenToObjectMethodInfo.MakeGenericMethod(converter.ProviderClrType)), + converter.ConvertFromProviderExpression.Body); - if (valueExpression.Type != clrType) + if (body.Type != clrType) { - valueExpression = Expression.Convert(valueExpression, clrType); + body = Expression.Convert(body, clrType); } + + body = Expression.Condition( + Expression.OrElse( + Expression.Equal(jTokenParameter, Expression.Default(typeof(JToken))), + Expression.Equal(Expression.MakeMemberAccess(jTokenParameter, _jTokenTypePropertyInfo), + Expression.Constant(JTokenType.Null))), + Expression.Default(clrType), + body); + + valueExpression = Expression.Invoke(Expression.Lambda(body, jTokenParameter), jTokenExpression); } else { - valueExpression = ConvertJTokenToType(jTokenExpression, clrType); + valueExpression = ConvertJTokenToType(jTokenExpression, typeMapping?.ClrType.MakeNullable() ?? clrType); + + if (valueExpression.Type != clrType) + { + valueExpression = Expression.Convert(valueExpression, clrType); + } } return valueExpression; diff --git a/src/EFCore.Cosmos/Storage/Internal/CosmosTypeMappingSource.cs b/src/EFCore.Cosmos/Storage/Internal/CosmosTypeMappingSource.cs index a4bef19cd91..5273880476b 100644 --- a/src/EFCore.Cosmos/Storage/Internal/CosmosTypeMappingSource.cs +++ b/src/EFCore.Cosmos/Storage/Internal/CosmosTypeMappingSource.cs @@ -52,7 +52,8 @@ protected override CoreTypeMapping FindMapping(in TypeMappingInfo mappingInfo) return mapping; } - if (clrType.IsValueType + if ((clrType.IsValueType + && !clrType.IsEnum) || clrType == typeof(string)) { return new CosmosTypeMapping(clrType); diff --git a/src/EFCore.Cosmos/Update/Internal/DocumentSource.cs b/src/EFCore.Cosmos/Update/Internal/DocumentSource.cs index 8d49ae081e5..3d89dcd58a0 100644 --- a/src/EFCore.Cosmos/Update/Internal/DocumentSource.cs +++ b/src/EFCore.Cosmos/Update/Internal/DocumentSource.cs @@ -211,12 +211,9 @@ public virtual JObject UpdateDocument(JObject document, IUpdateEntry entry) } public virtual JObject GetCurrentDocument(IUpdateEntry entry) - { - var document = _jObjectProperty != null + => _jObjectProperty != null ? (JObject)(entry.SharedIdentityEntry ?? entry).GetCurrentValue(_jObjectProperty) : null; - return document; - } private static JToken ConvertPropertyValue(IProperty property, object value) { diff --git a/src/EFCore/Query/Internal/QueryCompiler.cs b/src/EFCore/Query/Internal/QueryCompiler.cs index 8aa4ed5a4e6..5f3123908b2 100644 --- a/src/EFCore/Query/Internal/QueryCompiler.cs +++ b/src/EFCore/Query/Internal/QueryCompiler.cs @@ -103,9 +103,7 @@ public virtual Func CompileQueryCore( Expression query, IModel model, bool async) - { - return database.CompileQuery(query, async); - } + => database.CompileQuery(query, async); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to diff --git a/src/EFCore/Storage/Database.cs b/src/EFCore/Storage/Database.cs index 753397d9bb9..d2e95de4fa7 100644 --- a/src/EFCore/Storage/Database.cs +++ b/src/EFCore/Storage/Database.cs @@ -68,10 +68,8 @@ public abstract Task SaveChangesAsync( CancellationToken cancellationToken = default); public virtual Func CompileQuery(Expression query, bool async) - { - return Dependencies.QueryCompilationContextFactory + => Dependencies.QueryCompilationContextFactory .Create(async) .CreateQueryExecutor(query); - } } } diff --git a/src/EFCore/Storage/ValueConversion/ValueConverterSelector.cs b/src/EFCore/Storage/ValueConversion/ValueConverterSelector.cs index d16df064ee7..b69b3006a24 100644 --- a/src/EFCore/Storage/ValueConversion/ValueConverterSelector.cs +++ b/src/EFCore/Storage/ValueConversion/ValueConverterSelector.cs @@ -69,176 +69,173 @@ public virtual IEnumerable Select( { Check.NotNull(modelClrType, nameof(modelClrType)); - var underlyingModelType = modelClrType.UnwrapNullableType(); - var underlyingProviderType = providerClrType?.UnwrapNullableType(); - - if (underlyingModelType.IsEnum) + if (modelClrType.IsEnum) { foreach (var converterInfo in FindNumericConventions( - underlyingModelType, - underlyingProviderType, + modelClrType, + providerClrType, typeof(EnumToNumberConverter<,>), EnumToStringOrBytes)) { yield return converterInfo; } } - else if (underlyingModelType == typeof(bool)) + else if (modelClrType == typeof(bool)) { foreach (var converterInfo in FindNumericConventions( typeof(bool), - underlyingProviderType, + providerClrType, typeof(BoolToZeroOneConverter<>), null)) { yield return converterInfo; } - if (underlyingProviderType == null - || underlyingProviderType == typeof(string)) + if (providerClrType == null + || providerClrType == typeof(string)) { yield return BoolToStringConverter.DefaultInfo; } - if (underlyingProviderType == null - || underlyingProviderType == typeof(byte[])) + if (providerClrType == null + || providerClrType == typeof(byte[])) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(byte[])), + (modelClrType, typeof(byte[])), k => new ValueConverterInfo( - underlyingModelType, + modelClrType, typeof(byte[]), info => new BoolToZeroOneConverter().ComposeWith( NumberToBytesConverter.DefaultInfo.Create()), new ConverterMappingHints(size: 1))); } } - else if (underlyingModelType == typeof(char)) + else if (modelClrType == typeof(char)) { - foreach (var valueConverterInfo in ForChar(typeof(char), underlyingProviderType)) + foreach (var valueConverterInfo in ForChar(typeof(char), providerClrType)) { yield return valueConverterInfo; } } - else if (underlyingModelType == typeof(Guid)) + else if (modelClrType == typeof(Guid)) { - if (underlyingProviderType == null - || underlyingProviderType == typeof(byte[])) + if (providerClrType == null + || providerClrType == typeof(byte[])) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(byte[])), + (modelClrType, typeof(byte[])), k => GuidToBytesConverter.DefaultInfo); } - if (underlyingProviderType == null - || underlyingProviderType == typeof(string)) + if (providerClrType == null + || providerClrType == typeof(string)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(string)), + (modelClrType, typeof(string)), k => GuidToStringConverter.DefaultInfo); } } - else if (underlyingModelType == typeof(byte[])) + else if (modelClrType == typeof(byte[])) { - if (underlyingProviderType == null - || underlyingProviderType == typeof(string)) + if (providerClrType == null + || providerClrType == typeof(string)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(string)), + (modelClrType, typeof(string)), k => BytesToStringConverter.DefaultInfo); } } - else if (underlyingModelType == typeof(Uri)) + else if (modelClrType == typeof(Uri)) { - if (underlyingProviderType == null - || underlyingProviderType == typeof(string)) + if (providerClrType == null + || providerClrType == typeof(string)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(string)), + (modelClrType, typeof(string)), k => UriToStringConverter.DefaultInfo); } } - else if (underlyingModelType == typeof(string)) + else if (modelClrType == typeof(string)) { - if (underlyingProviderType == null - || underlyingProviderType == typeof(byte[])) + if (providerClrType == null + || providerClrType == typeof(byte[])) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(byte[])), + (modelClrType, typeof(byte[])), k => StringToBytesConverter.DefaultInfo); } - else if (underlyingProviderType.IsEnum) + else if (providerClrType.IsEnum) { yield return _converters.GetOrAdd( - (typeof(string), underlyingProviderType), + (typeof(string), providerClrType), k => (ValueConverterInfo)typeof(StringToEnumConverter<>) .MakeGenericType(k.ProviderClrType) .GetAnyProperty("DefaultInfo") .GetValue(null)); } - else if (_numerics.Contains(underlyingProviderType)) + else if (_numerics.Contains(providerClrType)) { foreach (var converterInfo in FindNumericConventions( typeof(string), - underlyingProviderType, + providerClrType, typeof(StringToNumberConverter<>), null)) { yield return converterInfo; } } - else if (underlyingProviderType == typeof(DateTime)) + else if (providerClrType == typeof(DateTime)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(DateTime)), + (modelClrType, typeof(DateTime)), k => StringToDateTimeConverter.DefaultInfo); } - else if (underlyingProviderType == typeof(DateTimeOffset)) + else if (providerClrType == typeof(DateTimeOffset)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(DateTimeOffset)), + (modelClrType, typeof(DateTimeOffset)), k => StringToDateTimeOffsetConverter.DefaultInfo); } - else if (underlyingProviderType == typeof(TimeSpan)) + else if (providerClrType == typeof(TimeSpan)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(TimeSpan)), + (modelClrType, typeof(TimeSpan)), k => StringToTimeSpanConverter.DefaultInfo); } - else if (underlyingProviderType == typeof(Guid)) + else if (providerClrType == typeof(Guid)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(Guid)), + (modelClrType, typeof(Guid)), k => StringToGuidConverter.DefaultInfo); } - else if (underlyingProviderType == typeof(bool)) + else if (providerClrType == typeof(bool)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(bool)), + (modelClrType, typeof(bool)), k => StringToBoolConverter.DefaultInfo); } - else if (underlyingProviderType == typeof(char)) + else if (providerClrType == typeof(char)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(char)), + (modelClrType, typeof(char)), k => StringToCharConverter.DefaultInfo); } - else if (underlyingProviderType == typeof(Uri)) + else if (providerClrType == typeof(Uri)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(Uri)), + (modelClrType, typeof(Uri)), k => StringToUriConverter.DefaultInfo); } } - else if (underlyingModelType == typeof(DateTime) - || underlyingModelType == typeof(DateTimeOffset) - || underlyingModelType == typeof(TimeSpan)) + else if (modelClrType == typeof(DateTime) + || modelClrType == typeof(DateTimeOffset) + || modelClrType == typeof(TimeSpan)) { - if (underlyingProviderType == null - || underlyingProviderType == typeof(string)) + if (providerClrType == null + || providerClrType == typeof(string)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(string)), + (modelClrType, typeof(string)), k => k.ModelClrType == typeof(DateTime) ? DateTimeToStringConverter.DefaultInfo : k.ModelClrType == typeof(DateTimeOffset) @@ -246,11 +243,11 @@ public virtual IEnumerable Select( : TimeSpanToStringConverter.DefaultInfo); } - if (underlyingProviderType == null - || underlyingProviderType == typeof(long)) + if (providerClrType == null + || providerClrType == typeof(long)) { yield return _converters.GetOrAdd( - (underlyingModelType, typeof(long)), + (modelClrType, typeof(long)), k => k.ModelClrType == typeof(DateTime) ? DateTimeToBinaryConverter.DefaultInfo : k.ModelClrType == typeof(DateTimeOffset) @@ -258,17 +255,17 @@ public virtual IEnumerable Select( : TimeSpanToTicksConverter.DefaultInfo); } - if (underlyingProviderType == null - || underlyingProviderType == typeof(byte[])) + if (providerClrType == null + || providerClrType == typeof(byte[])) { - yield return underlyingModelType == typeof(DateTimeOffset) + yield return modelClrType == typeof(DateTimeOffset) ? _converters.GetOrAdd( - (underlyingModelType, typeof(byte[])), + (modelClrType, typeof(byte[])), k => DateTimeOffsetToBytesConverter.DefaultInfo) : _converters.GetOrAdd( - (underlyingModelType, typeof(byte[])), + (modelClrType, typeof(byte[])), k => new ValueConverterInfo( - underlyingModelType, + modelClrType, typeof(byte[]), i => (i.ModelClrType == typeof(DateTime) ? DateTimeToBinaryConverter.DefaultInfo.Create() @@ -278,15 +275,15 @@ public virtual IEnumerable Select( NumberToBytesConverter.DefaultInfo.MappingHints)); } } - else if (_numerics.Contains(underlyingModelType) - && (underlyingProviderType == null - || underlyingProviderType == typeof(byte[]) - || underlyingProviderType == typeof(string) - || _numerics.Contains(underlyingProviderType))) + else if (_numerics.Contains(modelClrType) + && (providerClrType == null + || providerClrType == typeof(byte[]) + || providerClrType == typeof(string) + || _numerics.Contains(providerClrType))) { foreach (var converterInfo in FindNumericConventions( - underlyingModelType, - underlyingProviderType, + modelClrType, + providerClrType, typeof(CastingConverter<,>), NumberToStringOrBytes)) { diff --git a/test/EFCore.Cosmos.FunctionalTests/BuiltInDataTypesCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/BuiltInDataTypesCosmosTest.cs index 9b892d28034..3527c62c34c 100644 --- a/test/EFCore.Cosmos.FunctionalTests/BuiltInDataTypesCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/BuiltInDataTypesCosmosTest.cs @@ -2,29 +2,80 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.Cosmos.TestUtilities; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.TestUtilities; +using Xunit; namespace Microsoft.EntityFrameworkCore.Cosmos { - // TODO: Issue #12086 - internal class BuiltInDataTypesCosmosTest : BuiltInDataTypesTestBase + public class BuiltInDataTypesCosmosTest : BuiltInDataTypesTestBase { public BuiltInDataTypesCosmosTest(BuiltInDataTypesCosmosFixture fixture) : base(fixture) { } + [ConditionalTheory(Skip = "Issue #16919")] + public override Task Can_filter_projection_with_inline_enum_variable(bool async) + { + return base.Can_filter_projection_with_inline_enum_variable(async); + } + + [ConditionalTheory(Skip = "Issue #16919")] + public override Task Can_filter_projection_with_captured_enum_variable(bool async) + { + return base.Can_filter_projection_with_captured_enum_variable(async); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_data_type() + { + base.Can_query_using_any_data_type(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_data_type_nullable_shadow() + { + base.Can_query_using_any_data_type_nullable_shadow(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_data_type_shadow() + { + base.Can_query_using_any_data_type_shadow(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_nullable_data_type() + { + base.Can_query_using_any_nullable_data_type(); + } + + [ConditionalFact(Skip = "Issue #16919")] public override void Can_query_using_any_nullable_data_type_as_literal() { - // TODO: Requires ReLinq to be removed + base.Can_query_using_any_nullable_data_type_as_literal(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_with_null_parameters_using_any_nullable_data_type() + { + base.Can_query_with_null_parameters_using_any_nullable_data_type(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_insert_and_read_back_with_string_key() + { + base.Can_insert_and_read_back_with_string_key(); } + [ConditionalFact(Skip = "Issue #16920")] public override void Can_insert_and_read_back_with_binary_key() { - // TODO: For this to work Join needs to be translated or compiled as a Join with custom equality comparer + base.Can_insert_and_read_back_with_binary_key(); } public override void Can_perform_query_with_max_length() @@ -38,6 +89,8 @@ public class BuiltInDataTypesCosmosFixture : BuiltInDataTypesFixtureBase public override bool StrictEquality => true; + public override int IntegerPrecision => 53; + public override bool SupportsAnsi => false; public override bool SupportsUnicodeToAnsiConversion => false; diff --git a/test/EFCore.Cosmos.FunctionalTests/CustomConvertersCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/CustomConvertersCosmosTest.cs index 9fbe90f5106..210e606d699 100644 --- a/test/EFCore.Cosmos.FunctionalTests/CustomConvertersCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/CustomConvertersCosmosTest.cs @@ -2,15 +2,16 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.Cosmos.TestUtilities; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.EntityFrameworkCore.Metadata.Internal; using Microsoft.EntityFrameworkCore.TestUtilities; +using Xunit; namespace Microsoft.EntityFrameworkCore.Cosmos { - // TODO: Issue #12086 - internal class CustomConvertersCosmosTest : CustomConvertersTestBase + public class CustomConvertersCosmosTest : CustomConvertersTestBase { public CustomConvertersCosmosTest(CustomConvertersCosmosFixture fixture) : base(fixture) @@ -22,13 +23,88 @@ public override void Can_perform_query_with_max_length() // Over the 2Mb document limit } - // TODO: For these to work Join needs to be translated or compiled as a Join with custom equality comparer + [ConditionalTheory(Skip = "Issue #16919")] + public override Task Can_filter_projection_with_inline_enum_variable(bool async) + { + return base.Can_filter_projection_with_inline_enum_variable(async); + } + + [ConditionalTheory(Skip = "Issue #16919")] + public override Task Can_filter_projection_with_captured_enum_variable(bool async) + { + return base.Can_filter_projection_with_captured_enum_variable(async); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_data_type() + { + base.Can_query_using_any_data_type(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_data_type_nullable_shadow() + { + base.Can_query_using_any_data_type_nullable_shadow(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_data_type_shadow() + { + base.Can_query_using_any_data_type_shadow(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_nullable_data_type() + { + base.Can_query_using_any_nullable_data_type(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_using_any_nullable_data_type_as_literal() + { + base.Can_query_using_any_nullable_data_type_as_literal(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_with_null_parameters_using_any_nullable_data_type() + { + base.Can_query_with_null_parameters_using_any_nullable_data_type(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_insert_and_read_back_with_string_key() + { + base.Can_insert_and_read_back_with_string_key(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_and_update_with_conversion_for_custom_struct() + { + base.Can_query_and_update_with_conversion_for_custom_struct(); + } + + [ConditionalFact(Skip = "Issue #16919")] + public override void Can_query_and_update_with_conversion_for_custom_type() + { + base.Can_query_and_update_with_conversion_for_custom_type(); + } + + [ConditionalFact(Skip = "Issue #16920")] + public override void Can_query_and_update_with_nullable_converter_on_primary_key() + { + base.Can_query_and_update_with_nullable_converter_on_primary_key(); + } + + [ConditionalFact(Skip = "Issue #16920")] public override void Can_insert_and_read_back_with_binary_key() { + base.Can_insert_and_read_back_with_binary_key(); } + [ConditionalFact(Skip = "Issue #16920")] public override void Can_insert_and_read_back_with_case_insensitive_string_key() { + base.Can_insert_and_read_back_with_case_insensitive_string_key(); } public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase @@ -37,6 +113,8 @@ public class CustomConvertersCosmosFixture : CustomConvertersFixtureBase public override bool StrictEquality => true; + public override int IntegerPrecision => 53; + public override bool SupportsAnsi => false; public override bool SupportsUnicodeToAnsiConversion => false; diff --git a/test/EFCore.Cosmos.FunctionalTests/EmbeddedDocumentsTest.cs b/test/EFCore.Cosmos.FunctionalTests/EmbeddedDocumentsTest.cs index 8015e807934..0fa4dc4f2b8 100644 --- a/test/EFCore.Cosmos.FunctionalTests/EmbeddedDocumentsTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/EmbeddedDocumentsTest.cs @@ -192,7 +192,22 @@ public virtual async Task Can_query_just_nested_collection() { using (var context = CreateContext()) { - context.Add(new Person { Id = 3, Addresses = new[] { new Address { Street = "First", City = "City" }, new Address { Street = "Second", City = "City" } } }); + context.Add( + new Person + { + Id = 3, + Addresses = new[] + { + new Address + { + Street = "First", City = "City" + }, + new Address + { + Street = "Second", City = "City" + } + } + }); context.SaveChanges(); } diff --git a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs index ad9d710e04f..66c8b062a41 100644 --- a/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs +++ b/test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs @@ -6,6 +6,7 @@ using System.ComponentModel.DataAnnotations.Schema; using System.Linq; using System.Linq.Expressions; +using System.Reflection; using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.ChangeTracking; using Microsoft.EntityFrameworkCore.Metadata; @@ -249,9 +250,12 @@ private void QueryBuiltInDataTypesTest(EntityEntry source) set.Where(e => e.Id == 11 && EF.Property(e, nameof(BuiltInDataTypes.TestInt32)) == param2).ToList().Single()); var param3 = -1234567890123456789L; - Assert.Same( - entity, - set.Where(e => e.Id == 11 && EF.Property(e, nameof(BuiltInDataTypes.TestInt64)) == param3).ToList().Single()); + if (Fixture.IntegerPrecision == 64) + { + Assert.Same( + entity, + set.Where(e => e.Id == 11 && EF.Property(e, nameof(BuiltInDataTypes.TestInt64)) == param3).ToList().Single()); + } double? param4 = -1.23456789; if (Fixture.StrictEquality) @@ -879,9 +883,6 @@ private void QueryBuiltInNullableDataTypesTest(EntityEntry sou } } - private static Type UnwrapNullableType(Type type) - => type == null ? null : Nullable.GetUnderlyingType(type) ?? type; - protected virtual EntityEntry AddTestBuiltInNullableDataTypes(DbSet set) where TEntity : BuiltInNullableDataTypesBase, new() { @@ -1550,15 +1551,86 @@ public virtual void Can_insert_and_read_back_with_null_string_foreign_key() } } - // ReSharper disable once ParameterOnlyUsedForPreconditionCheck.Local - private static void AssertEqualIfMapped(IEntityType entityType, T expected, Expression> actual) + private void AssertEqualIfMapped(IEntityType entityType, T expected, Expression> actualExpression) + { + if (entityType.FindProperty(((MemberExpression)actualExpression.Body).Member.Name) != null) + { + var actual = actualExpression.Compile()(); + var type = UnwrapNullableEnumType(typeof(T)); + if (IsSignedInteger(type)) + { + Assert.True(Equal(Convert.ToInt64(expected), Convert.ToInt64(actual)), $"Expected:\t{expected}\r\nActual:\t{actual}"); + } + else if (IsUnsignedInteger(type)) + { + Assert.True(Equal(Convert.ToUInt64(expected), Convert.ToUInt64(actual)), $"Expected:\t{expected}\r\nActual:\t{actual}"); + } + else + { + Assert.Equal(expected, actual); + } + } + } + + private bool Equal(long left, long right) + { + if (left >= 0 + && right >= 0) + { + return Equal((ulong)left, (ulong)right); + } + + if (left < 0 + && right < 0) + { + return Equal((ulong)-left, (ulong)-right); + } + + return false; + } + + private bool Equal(ulong left, ulong right) { - if (entityType.FindProperty(((MemberExpression)actual.Body).Member.Name) != null) + if (Fixture.IntegerPrecision < 64) { - Assert.Equal(expected, actual.Compile()()); + var largestPrecise = 1ul << Fixture.IntegerPrecision; + while (left > largestPrecise) + { + left >>= 1; + right >>= 1; + } } + + return left == right; } + private static Type UnwrapNullableType(Type type) + => type == null ? null : Nullable.GetUnderlyingType(type) ?? type; + + public static Type UnwrapNullableEnumType(Type type) + { + var underlyingNonNullableType = UnwrapNullableType(type); + if (!underlyingNonNullableType.GetTypeInfo().IsEnum) + { + return underlyingNonNullableType; + } + + return Enum.GetUnderlyingType(underlyingNonNullableType); + } + + private static bool IsSignedInteger(Type type) + => type == typeof(int) + || type == typeof(long) + || type == typeof(short) + || type == typeof(sbyte); + + private static bool IsUnsignedInteger(Type type) + => type == typeof(byte) + || type == typeof(uint) + || type == typeof(ulong) + || type == typeof(ushort) + || type == typeof(char); + [ConditionalFact] public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values_set_to_null() { @@ -1652,7 +1724,7 @@ public virtual void Can_insert_and_read_back_all_nullable_data_types_with_values { var dt = context.Set().Where(ndt => ndt.Id == 101).ToList().Single(); - var entityType = context.Model.FindEntityType(typeof(BuiltInDataTypes)); + var entityType = context.Model.FindEntityType(typeof(BuiltInNullableDataTypes)); AssertEqualIfMapped(entityType, "TestString", () => dt.TestString); AssertEqualIfMapped(entityType, new byte[] { 10, 9, 8, 7, 6 }, () => dt.TestByteArray); AssertEqualIfMapped(entityType, (short)-1234, () => dt.TestNullableInt16); @@ -1729,32 +1801,33 @@ public virtual void Can_insert_and_read_back_object_backed_data_types() { var dt = context.Set().Where(ndt => ndt.Id == 101).ToList().Single(); - Assert.Equal("TestString", dt.String); - Assert.Equal(new byte[] { 10, 9, 8, 7, 6 }, dt.Bytes); - Assert.Equal((short)-1234, dt.Int16); - Assert.Equal(-123456789, dt.Int32); - Assert.Equal(-1234567890123456789L, dt.Int64); - Assert.Equal(-1.23456789, dt.Double); - Assert.Equal(-1234567890.01M, dt.Decimal); - Assert.Equal(DateTime.Parse("01/01/2000 12:34:56"), dt.DateTime); - Assert.Equal(new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), dt.DateTimeOffset); - Assert.Equal(new TimeSpan(0, 10, 9, 8, 7), dt.TimeSpan); - Assert.Equal(-1.234F, dt.Single); - Assert.Equal(false, dt.Boolean); - Assert.Equal((byte)255, dt.Byte); - Assert.Equal(Enum64.SomeValue, dt.Enum64); - Assert.Equal(Enum32.SomeValue, dt.Enum32); - Assert.Equal(Enum16.SomeValue, dt.Enum16); - Assert.Equal(Enum8.SomeValue, dt.Enum8); - Assert.Equal((ushort)1234, dt.UnsignedInt16); - Assert.Equal(1234565789U, dt.UnsignedInt32); - Assert.Equal(1234567890123456789UL, dt.UnsignedInt64); - Assert.Equal('a', dt.Character); - Assert.Equal((sbyte)-128, dt.SignedByte); - Assert.Equal(EnumU64.SomeValue, dt.EnumU64); - Assert.Equal(EnumU32.SomeValue, dt.EnumU32); - Assert.Equal(EnumU16.SomeValue, dt.EnumU16); - Assert.Equal(EnumS8.SomeValue, dt.EnumS8); + var entityType = context.Model.FindEntityType(typeof(ObjectBackedDataTypes)); + AssertEqualIfMapped(entityType, "TestString", () => dt.String); + AssertEqualIfMapped(entityType, new byte[] { 10, 9, 8, 7, 6 }, () => dt.Bytes); + AssertEqualIfMapped(entityType, (short)-1234, () => dt.Int16); + AssertEqualIfMapped(entityType, -123456789, () => dt.Int32); + AssertEqualIfMapped(entityType, -1234567890123456789L, () => dt.Int64); + AssertEqualIfMapped(entityType, -1.23456789, () => dt.Double); + AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.Decimal); + AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56"), () => dt.DateTime); + AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), () => dt.DateTimeOffset); + AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TimeSpan); + AssertEqualIfMapped(entityType, -1.234F, () => dt.Single); + AssertEqualIfMapped(entityType, false, () => dt.Boolean); + AssertEqualIfMapped(entityType, (byte)255, () => dt.Byte); + AssertEqualIfMapped(entityType, Enum64.SomeValue, () => dt.Enum64); + AssertEqualIfMapped(entityType, Enum32.SomeValue, () => dt.Enum32); + AssertEqualIfMapped(entityType, Enum16.SomeValue, () => dt.Enum16); + AssertEqualIfMapped(entityType, Enum8.SomeValue, () => dt.Enum8); + AssertEqualIfMapped(entityType, (ushort)1234, () => dt.UnsignedInt16); + AssertEqualIfMapped(entityType, 1234565789U, () => dt.UnsignedInt32); + AssertEqualIfMapped(entityType, 1234567890123456789UL, () => dt.UnsignedInt64); + AssertEqualIfMapped(entityType, 'a', () => dt.Character); + AssertEqualIfMapped(entityType, (sbyte)-128, () => dt.SignedByte); + AssertEqualIfMapped(entityType, EnumU64.SomeValue, () => dt.EnumU64); + AssertEqualIfMapped(entityType, EnumU32.SomeValue, () => dt.EnumU32); + AssertEqualIfMapped(entityType, EnumU16.SomeValue, () => dt.EnumU16); + AssertEqualIfMapped(entityType, EnumS8.SomeValue, () => dt.EnumS8); } } @@ -1801,30 +1874,31 @@ public virtual void Can_insert_and_read_back_nullable_backed_data_types() { var dt = context.Set().Where(ndt => ndt.Id == 101).ToList().Single(); - Assert.Equal((short)-1234, dt.Int16); - Assert.Equal(-123456789, dt.Int32); - Assert.Equal(-1234567890123456789L, dt.Int64); - Assert.Equal(-1.23456789, dt.Double); - Assert.Equal(-1234567890.01M, dt.Decimal); - Assert.Equal(DateTime.Parse("01/01/2000 12:34:56"), dt.DateTime); - Assert.Equal(new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), dt.DateTimeOffset); - Assert.Equal(new TimeSpan(0, 10, 9, 8, 7), dt.TimeSpan); - Assert.Equal(-1.234F, dt.Single); - Assert.Equal(false, dt.Boolean); - Assert.Equal((byte)255, dt.Byte); - Assert.Equal(Enum64.SomeValue, dt.Enum64); - Assert.Equal(Enum32.SomeValue, dt.Enum32); - Assert.Equal(Enum16.SomeValue, dt.Enum16); - Assert.Equal(Enum8.SomeValue, dt.Enum8); - Assert.Equal((ushort)1234, dt.UnsignedInt16); - Assert.Equal(1234565789U, dt.UnsignedInt32); - Assert.Equal(1234567890123456789UL, dt.UnsignedInt64); - Assert.Equal('a', dt.Character); - Assert.Equal((sbyte)-128, dt.SignedByte); - Assert.Equal(EnumU64.SomeValue, dt.EnumU64); - Assert.Equal(EnumU32.SomeValue, dt.EnumU32); - Assert.Equal(EnumU16.SomeValue, dt.EnumU16); - Assert.Equal(EnumS8.SomeValue, dt.EnumS8); + var entityType = context.Model.FindEntityType(typeof(NullableBackedDataTypes)); + AssertEqualIfMapped(entityType, (short)-1234, () => dt.Int16); + AssertEqualIfMapped(entityType, -123456789, () => dt.Int32); + AssertEqualIfMapped(entityType, -1234567890123456789L, () => dt.Int64); + AssertEqualIfMapped(entityType, -1.23456789, () => dt.Double); + AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.Decimal); + AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56"), () => dt.DateTime); + AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), () => dt.DateTimeOffset); + AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TimeSpan); + AssertEqualIfMapped(entityType, -1.234F, () => dt.Single); + AssertEqualIfMapped(entityType, false, () => dt.Boolean); + AssertEqualIfMapped(entityType, (byte)255, () => dt.Byte); + AssertEqualIfMapped(entityType, Enum64.SomeValue, () => dt.Enum64); + AssertEqualIfMapped(entityType, Enum32.SomeValue, () => dt.Enum32); + AssertEqualIfMapped(entityType, Enum16.SomeValue, () => dt.Enum16); + AssertEqualIfMapped(entityType, Enum8.SomeValue, () => dt.Enum8); + AssertEqualIfMapped(entityType, (ushort)1234, () => dt.UnsignedInt16); + AssertEqualIfMapped(entityType, 1234565789U, () => dt.UnsignedInt32); + AssertEqualIfMapped(entityType, 1234567890123456789UL, () => dt.UnsignedInt64); + AssertEqualIfMapped(entityType, 'a', () => dt.Character); + AssertEqualIfMapped(entityType, (sbyte)-128, () => dt.SignedByte); + AssertEqualIfMapped(entityType, EnumU64.SomeValue, () => dt.EnumU64); + AssertEqualIfMapped(entityType, EnumU32.SomeValue, () => dt.EnumU32); + AssertEqualIfMapped(entityType, EnumU16.SomeValue, () => dt.EnumU16); + AssertEqualIfMapped(entityType, EnumS8.SomeValue, () => dt.EnumS8); } } @@ -1871,30 +1945,32 @@ public virtual void Can_insert_and_read_back_non_nullable_backed_data_types() { var dt = context.Set().Where(ndt => ndt.Id == 101).ToList().Single(); - Assert.Equal((short)-1234, dt.Int16); - Assert.Equal(-123456789, dt.Int32); - Assert.Equal(-1234567890123456789L, dt.Int64); - Assert.Equal(-1.23456789, dt.Double); - Assert.Equal(-1234567890.01M, dt.Decimal); - Assert.Equal(DateTime.Parse("01/01/2000 12:34:56"), dt.DateTime); - Assert.Equal(new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), dt.DateTimeOffset); - Assert.Equal(new TimeSpan(0, 10, 9, 8, 7), dt.TimeSpan); - Assert.Equal(-1.234F, dt.Single); - Assert.Equal(false, dt.Boolean); - Assert.Equal((byte)255, dt.Byte); - Assert.Equal(Enum64.SomeValue, dt.Enum64); - Assert.Equal(Enum32.SomeValue, dt.Enum32); - Assert.Equal(Enum16.SomeValue, dt.Enum16); - Assert.Equal(Enum8.SomeValue, dt.Enum8); - Assert.Equal((ushort)1234, dt.UnsignedInt16); - Assert.Equal(1234565789U, dt.UnsignedInt32); - Assert.Equal(1234567890123456789UL, dt.UnsignedInt64); - Assert.Equal('a', dt.Character); - Assert.Equal((sbyte)-128, dt.SignedByte); - Assert.Equal(EnumU64.SomeValue, dt.EnumU64); - Assert.Equal(EnumU32.SomeValue, dt.EnumU32); - Assert.Equal(EnumU16.SomeValue, dt.EnumU16); - Assert.Equal(EnumS8.SomeValue, dt.EnumS8); + var entityType = context.Model.FindEntityType(typeof(NonNullableBackedDataTypes)); + AssertEqualIfMapped(entityType, (short)-1234, () => dt.Int16); + AssertEqualIfMapped(entityType, -123456789, () => dt.Int32); + AssertEqualIfMapped(entityType, -1234567890123456789L, () => dt.Int64); + AssertEqualIfMapped(entityType, -1234567890123456789L, () => dt.Int64); + AssertEqualIfMapped(entityType, -1.23456789, () => dt.Double); + AssertEqualIfMapped(entityType, -1234567890.01M, () => dt.Decimal); + AssertEqualIfMapped(entityType, DateTime.Parse("01/01/2000 12:34:56"), () => dt.DateTime); + AssertEqualIfMapped(entityType, new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)), () => dt.DateTimeOffset); + AssertEqualIfMapped(entityType, new TimeSpan(0, 10, 9, 8, 7), () => dt.TimeSpan); + AssertEqualIfMapped(entityType, -1.234F, () => dt.Single); + AssertEqualIfMapped(entityType, false, () => dt.Boolean); + AssertEqualIfMapped(entityType, (byte)255, () => dt.Byte); + AssertEqualIfMapped(entityType, Enum64.SomeValue, () => dt.Enum64); + AssertEqualIfMapped(entityType, Enum32.SomeValue, () => dt.Enum32); + AssertEqualIfMapped(entityType, Enum16.SomeValue, () => dt.Enum16); + AssertEqualIfMapped(entityType, Enum8.SomeValue, () => dt.Enum8); + AssertEqualIfMapped(entityType, (ushort)1234, () => dt.UnsignedInt16); + AssertEqualIfMapped(entityType, 1234565789U, () => dt.UnsignedInt32); + AssertEqualIfMapped(entityType, 1234567890123456789UL, () => dt.UnsignedInt64); + AssertEqualIfMapped(entityType, 'a', () => dt.Character); + AssertEqualIfMapped(entityType, (sbyte)-128, () => dt.SignedByte); + AssertEqualIfMapped(entityType, EnumU64.SomeValue, () => dt.EnumU64); + AssertEqualIfMapped(entityType, EnumU32.SomeValue, () => dt.EnumU32); + AssertEqualIfMapped(entityType, EnumU16.SomeValue, () => dt.EnumU16); + AssertEqualIfMapped(entityType, EnumS8.SomeValue, () => dt.EnumS8); } } @@ -2142,6 +2218,8 @@ protected static void MakeRequired(ModelBuilder modelBuilder) public abstract bool StrictEquality { get; } + public virtual int IntegerPrecision => 19; + public abstract bool SupportsAnsi { get; } public abstract bool SupportsUnicodeToAnsiConversion { get; }